KMS: Add RSASSA_PSS_SHA_256 private key (#6702)

This commit is contained in:
Akira Noda 2023-08-20 23:32:26 +09:00 committed by GitHub
parent e11ba21b83
commit 0d75cdc38b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 45 deletions

View File

@ -3,9 +3,6 @@ import os
from collections import defaultdict
from copy import copy
from datetime import datetime, timedelta
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from typing import Any, Dict, List, Tuple, Optional, Iterable, Set
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
@ -82,10 +79,10 @@ class Key(CloudFormationModel):
self.key_rotation_status = False
self.deletion_date: Optional[datetime] = None
self.key_material = generate_master_key()
self.private_key = generate_private_key()
self.origin = "AWS_KMS"
self.key_manager = "CUSTOMER"
self.key_spec = key_spec or "SYMMETRIC_DEFAULT"
self.private_key = generate_private_key(self.key_spec)
self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}"
self.grants: Dict[str, Grant] = dict()
@ -641,14 +638,7 @@ class KmsBackend(BaseBackend):
self.__ensure_valid_sign_and_verify_key(key)
self.__ensure_valid_signing_algorithm(key, signing_algorithm)
# TODO: support more than one hardcoded algorithm based on KeySpec
signature = key.private_key.sign(
message,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
),
hashes.SHA256(),
)
signature = key.private_key.sign(message, signing_algorithm)
return key.arn, signature, signing_algorithm
@ -678,31 +668,15 @@ class KmsBackend(BaseBackend):
)
)
public_key = key.private_key.public_key()
try:
# TODO: support more than one hardcoded algorithm based on KeySpec
public_key.verify(
signature,
message,
padding.PSS(
mgf=padding.MGF1(hashes.SHA256()),
salt_length=padding.PSS.MAX_LENGTH,
),
hashes.SHA256(),
)
return key.arn, True, signing_algorithm
except InvalidSignature:
return key.arn, False, signing_algorithm
return (
key.arn,
key.private_key.verify(message, signature, signing_algorithm),
signing_algorithm,
)
def get_public_key(self, key_id: str) -> Tuple[Key, bytes]:
key = self.describe_key(key_id)
public_key = key.private_key.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return key, public_key
return key, key.private_key.public_key()
kms_backends = BackendDict(KmsBackend, "kms")

View File

@ -1,3 +1,4 @@
from abc import abstractmethod, ABCMeta
from collections import namedtuple
from typing import Any, Dict, Tuple, List
from enum import Enum
@ -6,9 +7,12 @@ import os
import struct
from moto.moto_api._internal import mock_random
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from .exceptions import (
InvalidCiphertextException,
@ -134,16 +138,92 @@ def generate_master_key() -> bytes:
return generate_data_key(MASTER_KEY_LEN)
def generate_private_key() -> rsa.RSAPrivateKey:
"""Generate a private key to be used on asymmetric sign/verify.
class AbstractPrivateKey(metaclass=ABCMeta):
@abstractmethod
def sign(self, message: bytes, signing_algorithm: str) -> bytes:
raise NotImplementedError
NOTE: KeySpec is not taken into consideration and the key is always RSA_2048
this could be improved to support multiple key types
"""
return rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
)
@abstractmethod
def verify(self, message: bytes, signature: bytes, signing_algorithm: str) -> bool:
raise NotImplementedError
@abstractmethod
def public_key(self) -> bytes:
raise NotImplementedError
def validate_signing_algorithm(
target_algorithm: str, valid_algorithms: List[str]
) -> None:
if target_algorithm not in valid_algorithms:
raise ValidationException(
(
"1 validation error detected: Value at 'signing_algorithm' failed"
"to satisfy constraint: Member must satisfy enum value set: {valid_signing_algorithms}"
).format(valid_signing_algorithms=valid_algorithms)
)
class RSAPrivateKey(AbstractPrivateKey):
def __init__(self, key_size: int):
self.key_size = key_size
self.private_key = rsa.generate_private_key(
public_exponent=65537, key_size=self.key_size
)
def sign(self, message: bytes, signing_algorithm: str) -> bytes:
validate_signing_algorithm(
signing_algorithm, SigningAlgorithm.rsa_signing_algorithms()
)
if signing_algorithm == SigningAlgorithm.RSASSA_PSS_SHA_256:
pad = padding.PSS(
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
)
algorithm = hashes.SHA256()
else:
pad = padding.PSS(
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
)
algorithm = hashes.SHA256()
return self.private_key.sign(message, pad, algorithm)
def verify(self, message: bytes, signature: bytes, signing_algorithm: str) -> bool:
validate_signing_algorithm(
signing_algorithm, SigningAlgorithm.rsa_signing_algorithms()
)
if signing_algorithm == SigningAlgorithm.RSASSA_PSS_SHA_256:
pad = padding.PSS(
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
)
algorithm = hashes.SHA256()
else:
pad = padding.PSS(
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
)
algorithm = hashes.SHA256()
public_key = self.private_key.public_key()
try:
public_key.verify(signature, message, pad, algorithm)
return True
except InvalidSignature:
return False
def public_key(self) -> bytes:
return self.private_key.public_key().public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
def generate_private_key(key_spec: str) -> AbstractPrivateKey:
"""Generate a private key to be used on asymmetric sign/verify."""
if key_spec == KeySpec.RSA_2048:
return RSAPrivateKey(key_size=2048)
else:
return RSAPrivateKey(key_size=2048)
def _serialize_ciphertext_blob(ciphertext: Ciphertext) -> bytes:

View File

@ -1163,7 +1163,7 @@ def test_sign_and_verify_ignoring_grant_tokens():
@mock_kms
def test_sign_and_verify_digest_message_type_256():
def test_sign_and_verify_digest_message_type_RSASSA_PSS_SHA_256():
client = boto3.client("kms", region_name="us-west-2")
key = client.create_key(