KMS: Add RSASSA_PSS_SHA_256
private key (#6702)
This commit is contained in:
parent
e11ba21b83
commit
0d75cdc38b
@ -3,9 +3,6 @@ import os
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from cryptography.exceptions import InvalidSignature
|
|
||||||
from cryptography.hazmat.primitives import hashes, serialization
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import padding
|
|
||||||
from typing import Any, Dict, List, Tuple, Optional, Iterable, Set
|
from typing import Any, Dict, List, Tuple, Optional, Iterable, Set
|
||||||
|
|
||||||
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
|
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
|
||||||
@ -82,10 +79,10 @@ class Key(CloudFormationModel):
|
|||||||
self.key_rotation_status = False
|
self.key_rotation_status = False
|
||||||
self.deletion_date: Optional[datetime] = None
|
self.deletion_date: Optional[datetime] = None
|
||||||
self.key_material = generate_master_key()
|
self.key_material = generate_master_key()
|
||||||
self.private_key = generate_private_key()
|
|
||||||
self.origin = "AWS_KMS"
|
self.origin = "AWS_KMS"
|
||||||
self.key_manager = "CUSTOMER"
|
self.key_manager = "CUSTOMER"
|
||||||
self.key_spec = key_spec or "SYMMETRIC_DEFAULT"
|
self.key_spec = key_spec or "SYMMETRIC_DEFAULT"
|
||||||
|
self.private_key = generate_private_key(self.key_spec)
|
||||||
self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}"
|
self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}"
|
||||||
|
|
||||||
self.grants: Dict[str, Grant] = dict()
|
self.grants: Dict[str, Grant] = dict()
|
||||||
@ -641,14 +638,7 @@ class KmsBackend(BaseBackend):
|
|||||||
self.__ensure_valid_sign_and_verify_key(key)
|
self.__ensure_valid_sign_and_verify_key(key)
|
||||||
self.__ensure_valid_signing_algorithm(key, signing_algorithm)
|
self.__ensure_valid_signing_algorithm(key, signing_algorithm)
|
||||||
|
|
||||||
# TODO: support more than one hardcoded algorithm based on KeySpec
|
signature = key.private_key.sign(message, signing_algorithm)
|
||||||
signature = key.private_key.sign(
|
|
||||||
message,
|
|
||||||
padding.PSS(
|
|
||||||
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
|
|
||||||
),
|
|
||||||
hashes.SHA256(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return key.arn, signature, signing_algorithm
|
return key.arn, signature, signing_algorithm
|
||||||
|
|
||||||
@ -678,31 +668,15 @@ class KmsBackend(BaseBackend):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
public_key = key.private_key.public_key()
|
return (
|
||||||
|
key.arn,
|
||||||
try:
|
key.private_key.verify(message, signature, signing_algorithm),
|
||||||
# TODO: support more than one hardcoded algorithm based on KeySpec
|
signing_algorithm,
|
||||||
public_key.verify(
|
)
|
||||||
signature,
|
|
||||||
message,
|
|
||||||
padding.PSS(
|
|
||||||
mgf=padding.MGF1(hashes.SHA256()),
|
|
||||||
salt_length=padding.PSS.MAX_LENGTH,
|
|
||||||
),
|
|
||||||
hashes.SHA256(),
|
|
||||||
)
|
|
||||||
return key.arn, True, signing_algorithm
|
|
||||||
except InvalidSignature:
|
|
||||||
return key.arn, False, signing_algorithm
|
|
||||||
|
|
||||||
def get_public_key(self, key_id: str) -> Tuple[Key, bytes]:
|
def get_public_key(self, key_id: str) -> Tuple[Key, bytes]:
|
||||||
key = self.describe_key(key_id)
|
key = self.describe_key(key_id)
|
||||||
public_key = key.private_key.public_key().public_bytes(
|
return key, key.private_key.public_key()
|
||||||
encoding=serialization.Encoding.DER,
|
|
||||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
|
||||||
)
|
|
||||||
|
|
||||||
return key, public_key
|
|
||||||
|
|
||||||
|
|
||||||
kms_backends = BackendDict(KmsBackend, "kms")
|
kms_backends = BackendDict(KmsBackend, "kms")
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
from abc import abstractmethod, ABCMeta
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Any, Dict, Tuple, List
|
from typing import Any, Dict, Tuple, List
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@ -6,9 +7,12 @@ import os
|
|||||||
import struct
|
import struct
|
||||||
from moto.moto_api._internal import mock_random
|
from moto.moto_api._internal import mock_random
|
||||||
|
|
||||||
|
from cryptography.exceptions import InvalidSignature
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
|
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
from cryptography.hazmat.primitives import hashes, serialization
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||||
|
|
||||||
|
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
InvalidCiphertextException,
|
InvalidCiphertextException,
|
||||||
@ -134,16 +138,92 @@ def generate_master_key() -> bytes:
|
|||||||
return generate_data_key(MASTER_KEY_LEN)
|
return generate_data_key(MASTER_KEY_LEN)
|
||||||
|
|
||||||
|
|
||||||
def generate_private_key() -> rsa.RSAPrivateKey:
|
class AbstractPrivateKey(metaclass=ABCMeta):
|
||||||
"""Generate a private key to be used on asymmetric sign/verify.
|
@abstractmethod
|
||||||
|
def sign(self, message: bytes, signing_algorithm: str) -> bytes:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
NOTE: KeySpec is not taken into consideration and the key is always RSA_2048
|
@abstractmethod
|
||||||
this could be improved to support multiple key types
|
def verify(self, message: bytes, signature: bytes, signing_algorithm: str) -> bool:
|
||||||
"""
|
raise NotImplementedError
|
||||||
return rsa.generate_private_key(
|
|
||||||
public_exponent=65537,
|
@abstractmethod
|
||||||
key_size=2048,
|
def public_key(self) -> bytes:
|
||||||
)
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def validate_signing_algorithm(
|
||||||
|
target_algorithm: str, valid_algorithms: List[str]
|
||||||
|
) -> None:
|
||||||
|
if target_algorithm not in valid_algorithms:
|
||||||
|
raise ValidationException(
|
||||||
|
(
|
||||||
|
"1 validation error detected: Value at 'signing_algorithm' failed"
|
||||||
|
"to satisfy constraint: Member must satisfy enum value set: {valid_signing_algorithms}"
|
||||||
|
).format(valid_signing_algorithms=valid_algorithms)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RSAPrivateKey(AbstractPrivateKey):
|
||||||
|
def __init__(self, key_size: int):
|
||||||
|
self.key_size = key_size
|
||||||
|
self.private_key = rsa.generate_private_key(
|
||||||
|
public_exponent=65537, key_size=self.key_size
|
||||||
|
)
|
||||||
|
|
||||||
|
def sign(self, message: bytes, signing_algorithm: str) -> bytes:
|
||||||
|
validate_signing_algorithm(
|
||||||
|
signing_algorithm, SigningAlgorithm.rsa_signing_algorithms()
|
||||||
|
)
|
||||||
|
|
||||||
|
if signing_algorithm == SigningAlgorithm.RSASSA_PSS_SHA_256:
|
||||||
|
pad = padding.PSS(
|
||||||
|
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
|
||||||
|
)
|
||||||
|
algorithm = hashes.SHA256()
|
||||||
|
else:
|
||||||
|
pad = padding.PSS(
|
||||||
|
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
|
||||||
|
)
|
||||||
|
algorithm = hashes.SHA256()
|
||||||
|
return self.private_key.sign(message, pad, algorithm)
|
||||||
|
|
||||||
|
def verify(self, message: bytes, signature: bytes, signing_algorithm: str) -> bool:
|
||||||
|
validate_signing_algorithm(
|
||||||
|
signing_algorithm, SigningAlgorithm.rsa_signing_algorithms()
|
||||||
|
)
|
||||||
|
|
||||||
|
if signing_algorithm == SigningAlgorithm.RSASSA_PSS_SHA_256:
|
||||||
|
pad = padding.PSS(
|
||||||
|
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
|
||||||
|
)
|
||||||
|
algorithm = hashes.SHA256()
|
||||||
|
else:
|
||||||
|
pad = padding.PSS(
|
||||||
|
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
|
||||||
|
)
|
||||||
|
algorithm = hashes.SHA256()
|
||||||
|
|
||||||
|
public_key = self.private_key.public_key()
|
||||||
|
try:
|
||||||
|
public_key.verify(signature, message, pad, algorithm)
|
||||||
|
return True
|
||||||
|
except InvalidSignature:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def public_key(self) -> bytes:
|
||||||
|
return self.private_key.public_key().public_bytes(
|
||||||
|
encoding=serialization.Encoding.DER,
|
||||||
|
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_private_key(key_spec: str) -> AbstractPrivateKey:
|
||||||
|
"""Generate a private key to be used on asymmetric sign/verify."""
|
||||||
|
if key_spec == KeySpec.RSA_2048:
|
||||||
|
return RSAPrivateKey(key_size=2048)
|
||||||
|
else:
|
||||||
|
return RSAPrivateKey(key_size=2048)
|
||||||
|
|
||||||
|
|
||||||
def _serialize_ciphertext_blob(ciphertext: Ciphertext) -> bytes:
|
def _serialize_ciphertext_blob(ciphertext: Ciphertext) -> bytes:
|
||||||
|
@ -1163,7 +1163,7 @@ def test_sign_and_verify_ignoring_grant_tokens():
|
|||||||
|
|
||||||
|
|
||||||
@mock_kms
|
@mock_kms
|
||||||
def test_sign_and_verify_digest_message_type_256():
|
def test_sign_and_verify_digest_message_type_RSASSA_PSS_SHA_256():
|
||||||
client = boto3.client("kms", region_name="us-west-2")
|
client = boto3.client("kms", region_name="us-west-2")
|
||||||
|
|
||||||
key = client.create_key(
|
key = client.create_key(
|
||||||
|
Loading…
Reference in New Issue
Block a user