KMS: Add RSASSA_PSS_SHA_256 private key (#6702)

This commit is contained in:
Akira Noda 2023-08-20 23:32:26 +09:00 committed by GitHub
parent e11ba21b83
commit 0d75cdc38b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 45 deletions

View File

@ -3,9 +3,6 @@ import os
from collections import defaultdict from collections import defaultdict
from copy import copy from copy import copy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from typing import Any, Dict, List, Tuple, Optional, Iterable, Set from typing import Any, Dict, List, Tuple, Optional, Iterable, Set
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
@ -82,10 +79,10 @@ class Key(CloudFormationModel):
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date: Optional[datetime] = None self.deletion_date: Optional[datetime] = None
self.key_material = generate_master_key() self.key_material = generate_master_key()
self.private_key = generate_private_key()
self.origin = "AWS_KMS" self.origin = "AWS_KMS"
self.key_manager = "CUSTOMER" self.key_manager = "CUSTOMER"
self.key_spec = key_spec or "SYMMETRIC_DEFAULT" self.key_spec = key_spec or "SYMMETRIC_DEFAULT"
self.private_key = generate_private_key(self.key_spec)
self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}" self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}"
self.grants: Dict[str, Grant] = dict() self.grants: Dict[str, Grant] = dict()
@ -641,14 +638,7 @@ class KmsBackend(BaseBackend):
self.__ensure_valid_sign_and_verify_key(key) self.__ensure_valid_sign_and_verify_key(key)
self.__ensure_valid_signing_algorithm(key, signing_algorithm) self.__ensure_valid_signing_algorithm(key, signing_algorithm)
# TODO: support more than one hardcoded algorithm based on KeySpec signature = key.private_key.sign(message, signing_algorithm)
signature = key.private_key.sign(
message,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256(),
)
return key.arn, signature, signing_algorithm return key.arn, signature, signing_algorithm
@ -678,31 +668,15 @@ class KmsBackend(BaseBackend):
) )
) )
public_key = key.private_key.public_key() return (
key.arn,
try: key.private_key.verify(message, signature, signing_algorithm),
# TODO: support more than one hardcoded algorithm based on KeySpec signing_algorithm,
public_key.verify( )
signature,
message,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
hashes.SHA256(),
)
return key.arn, True, signing_algorithm
except InvalidSignature:
return key.arn, False, signing_algorithm
def get_public_key(self, key_id: str) -> Tuple[Key, bytes]: def get_public_key(self, key_id: str) -> Tuple[Key, bytes]:
key = self.describe_key(key_id) key = self.describe_key(key_id)
public_key = key.private_key.public_key().public_bytes( return key, key.private_key.public_key()
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return key, public_key
kms_backends = BackendDict(KmsBackend, "kms") kms_backends = BackendDict(KmsBackend, "kms")

View File

@ -1,3 +1,4 @@
from abc import abstractmethod, ABCMeta
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict, Tuple, List from typing import Any, Dict, Tuple, List
from enum import Enum from enum import Enum
@ -6,9 +7,12 @@ import os
import struct import struct
from moto.moto_api._internal import mock_random from moto.moto_api._internal import mock_random
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from .exceptions import ( from .exceptions import (
InvalidCiphertextException, InvalidCiphertextException,
@ -134,16 +138,92 @@ def generate_master_key() -> bytes:
return generate_data_key(MASTER_KEY_LEN) return generate_data_key(MASTER_KEY_LEN)
def generate_private_key() -> rsa.RSAPrivateKey: class AbstractPrivateKey(metaclass=ABCMeta):
"""Generate a private key to be used on asymmetric sign/verify. @abstractmethod
def sign(self, message: bytes, signing_algorithm: str) -> bytes:
raise NotImplementedError
NOTE: KeySpec is not taken into consideration and the key is always RSA_2048 @abstractmethod
this could be improved to support multiple key types def verify(self, message: bytes, signature: bytes, signing_algorithm: str) -> bool:
""" raise NotImplementedError
return rsa.generate_private_key(
public_exponent=65537, @abstractmethod
key_size=2048, def public_key(self) -> bytes:
) raise NotImplementedError
def validate_signing_algorithm(
target_algorithm: str, valid_algorithms: List[str]
) -> None:
if target_algorithm not in valid_algorithms:
raise ValidationException(
(
"1 validation error detected: Value at 'signing_algorithm' failed"
"to satisfy constraint: Member must satisfy enum value set: {valid_signing_algorithms}"
).format(valid_signing_algorithms=valid_algorithms)
)
class RSAPrivateKey(AbstractPrivateKey):
def __init__(self, key_size: int):
self.key_size = key_size
self.private_key = rsa.generate_private_key(
public_exponent=65537, key_size=self.key_size
)
def sign(self, message: bytes, signing_algorithm: str) -> bytes:
validate_signing_algorithm(
signing_algorithm, SigningAlgorithm.rsa_signing_algorithms()
)
if signing_algorithm == SigningAlgorithm.RSASSA_PSS_SHA_256:
pad = padding.PSS(
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
)
algorithm = hashes.SHA256()
else:
pad = padding.PSS(
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
)
algorithm = hashes.SHA256()
return self.private_key.sign(message, pad, algorithm)
def verify(self, message: bytes, signature: bytes, signing_algorithm: str) -> bool:
validate_signing_algorithm(
signing_algorithm, SigningAlgorithm.rsa_signing_algorithms()
)
if signing_algorithm == SigningAlgorithm.RSASSA_PSS_SHA_256:
pad = padding.PSS(
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
)
algorithm = hashes.SHA256()
else:
pad = padding.PSS(
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
)
algorithm = hashes.SHA256()
public_key = self.private_key.public_key()
try:
public_key.verify(signature, message, pad, algorithm)
return True
except InvalidSignature:
return False
def public_key(self) -> bytes:
return self.private_key.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
def generate_private_key(key_spec: str) -> AbstractPrivateKey:
"""Generate a private key to be used on asymmetric sign/verify."""
if key_spec == KeySpec.RSA_2048:
return RSAPrivateKey(key_size=2048)
else:
return RSAPrivateKey(key_size=2048)
def _serialize_ciphertext_blob(ciphertext: Ciphertext) -> bytes: def _serialize_ciphertext_blob(ciphertext: Ciphertext) -> bytes:

View File

@ -1163,7 +1163,7 @@ def test_sign_and_verify_ignoring_grant_tokens():
@mock_kms @mock_kms
def test_sign_and_verify_digest_message_type_256(): def test_sign_and_verify_digest_message_type_RSASSA_PSS_SHA_256():
client = boto3.client("kms", region_name="us-west-2") client = boto3.client("kms", region_name="us-west-2")
key = client.create_key( key = client.create_key(