KMS: Add RSASSA_PSS_SHA_256
private key (#6702)
This commit is contained in:
parent
e11ba21b83
commit
0d75cdc38b
@ -3,9 +3,6 @@ import os
|
||||
from collections import defaultdict
|
||||
from copy import copy
|
||||
from datetime import datetime, timedelta
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import padding
|
||||
from typing import Any, Dict, List, Tuple, Optional, Iterable, Set
|
||||
|
||||
from moto.core import BaseBackend, BackendDict, BaseModel, CloudFormationModel
|
||||
@ -82,10 +79,10 @@ class Key(CloudFormationModel):
|
||||
self.key_rotation_status = False
|
||||
self.deletion_date: Optional[datetime] = None
|
||||
self.key_material = generate_master_key()
|
||||
self.private_key = generate_private_key()
|
||||
self.origin = "AWS_KMS"
|
||||
self.key_manager = "CUSTOMER"
|
||||
self.key_spec = key_spec or "SYMMETRIC_DEFAULT"
|
||||
self.private_key = generate_private_key(self.key_spec)
|
||||
self.arn = f"arn:aws:kms:{region}:{account_id}:key/{self.id}"
|
||||
|
||||
self.grants: Dict[str, Grant] = dict()
|
||||
@ -641,14 +638,7 @@ class KmsBackend(BaseBackend):
|
||||
self.__ensure_valid_sign_and_verify_key(key)
|
||||
self.__ensure_valid_signing_algorithm(key, signing_algorithm)
|
||||
|
||||
# TODO: support more than one hardcoded algorithm based on KeySpec
|
||||
signature = key.private_key.sign(
|
||||
message,
|
||||
padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
|
||||
),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
signature = key.private_key.sign(message, signing_algorithm)
|
||||
|
||||
return key.arn, signature, signing_algorithm
|
||||
|
||||
@ -678,31 +668,15 @@ class KmsBackend(BaseBackend):
|
||||
)
|
||||
)
|
||||
|
||||
public_key = key.private_key.public_key()
|
||||
|
||||
try:
|
||||
# TODO: support more than one hardcoded algorithm based on KeySpec
|
||||
public_key.verify(
|
||||
signature,
|
||||
message,
|
||||
padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()),
|
||||
salt_length=padding.PSS.MAX_LENGTH,
|
||||
),
|
||||
hashes.SHA256(),
|
||||
)
|
||||
return key.arn, True, signing_algorithm
|
||||
except InvalidSignature:
|
||||
return key.arn, False, signing_algorithm
|
||||
return (
|
||||
key.arn,
|
||||
key.private_key.verify(message, signature, signing_algorithm),
|
||||
signing_algorithm,
|
||||
)
|
||||
|
||||
def get_public_key(self, key_id: str) -> Tuple[Key, bytes]:
|
||||
key = self.describe_key(key_id)
|
||||
public_key = key.private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
return key, public_key
|
||||
return key, key.private_key.public_key()
|
||||
|
||||
|
||||
kms_backends = BackendDict(KmsBackend, "kms")
|
||||
|
@ -1,3 +1,4 @@
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, Tuple, List
|
||||
from enum import Enum
|
||||
@ -6,9 +7,12 @@ import os
|
||||
import struct
|
||||
from moto.moto_api._internal import mock_random
|
||||
|
||||
from cryptography.exceptions import InvalidSignature
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa, padding
|
||||
|
||||
|
||||
from .exceptions import (
|
||||
InvalidCiphertextException,
|
||||
@ -134,16 +138,92 @@ def generate_master_key() -> bytes:
|
||||
return generate_data_key(MASTER_KEY_LEN)
|
||||
|
||||
|
||||
def generate_private_key() -> rsa.RSAPrivateKey:
|
||||
"""Generate a private key to be used on asymmetric sign/verify.
|
||||
class AbstractPrivateKey(metaclass=ABCMeta):
|
||||
@abstractmethod
|
||||
def sign(self, message: bytes, signing_algorithm: str) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
NOTE: KeySpec is not taken into consideration and the key is always RSA_2048
|
||||
this could be improved to support multiple key types
|
||||
"""
|
||||
return rsa.generate_private_key(
|
||||
public_exponent=65537,
|
||||
key_size=2048,
|
||||
)
|
||||
@abstractmethod
|
||||
def verify(self, message: bytes, signature: bytes, signing_algorithm: str) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def public_key(self) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def validate_signing_algorithm(
|
||||
target_algorithm: str, valid_algorithms: List[str]
|
||||
) -> None:
|
||||
if target_algorithm not in valid_algorithms:
|
||||
raise ValidationException(
|
||||
(
|
||||
"1 validation error detected: Value at 'signing_algorithm' failed"
|
||||
"to satisfy constraint: Member must satisfy enum value set: {valid_signing_algorithms}"
|
||||
).format(valid_signing_algorithms=valid_algorithms)
|
||||
)
|
||||
|
||||
|
||||
class RSAPrivateKey(AbstractPrivateKey):
|
||||
def __init__(self, key_size: int):
|
||||
self.key_size = key_size
|
||||
self.private_key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=self.key_size
|
||||
)
|
||||
|
||||
def sign(self, message: bytes, signing_algorithm: str) -> bytes:
|
||||
validate_signing_algorithm(
|
||||
signing_algorithm, SigningAlgorithm.rsa_signing_algorithms()
|
||||
)
|
||||
|
||||
if signing_algorithm == SigningAlgorithm.RSASSA_PSS_SHA_256:
|
||||
pad = padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
|
||||
)
|
||||
algorithm = hashes.SHA256()
|
||||
else:
|
||||
pad = padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
|
||||
)
|
||||
algorithm = hashes.SHA256()
|
||||
return self.private_key.sign(message, pad, algorithm)
|
||||
|
||||
def verify(self, message: bytes, signature: bytes, signing_algorithm: str) -> bool:
|
||||
validate_signing_algorithm(
|
||||
signing_algorithm, SigningAlgorithm.rsa_signing_algorithms()
|
||||
)
|
||||
|
||||
if signing_algorithm == SigningAlgorithm.RSASSA_PSS_SHA_256:
|
||||
pad = padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
|
||||
)
|
||||
algorithm = hashes.SHA256()
|
||||
else:
|
||||
pad = padding.PSS(
|
||||
mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH
|
||||
)
|
||||
algorithm = hashes.SHA256()
|
||||
|
||||
public_key = self.private_key.public_key()
|
||||
try:
|
||||
public_key.verify(signature, message, pad, algorithm)
|
||||
return True
|
||||
except InvalidSignature:
|
||||
return False
|
||||
|
||||
def public_key(self) -> bytes:
|
||||
return self.private_key.public_key().public_bytes(
|
||||
encoding=serialization.Encoding.DER,
|
||||
format=serialization.PublicFormat.SubjectPublicKeyInfo,
|
||||
)
|
||||
|
||||
|
||||
def generate_private_key(key_spec: str) -> AbstractPrivateKey:
|
||||
"""Generate a private key to be used on asymmetric sign/verify."""
|
||||
if key_spec == KeySpec.RSA_2048:
|
||||
return RSAPrivateKey(key_size=2048)
|
||||
else:
|
||||
return RSAPrivateKey(key_size=2048)
|
||||
|
||||
|
||||
def _serialize_ciphertext_blob(ciphertext: Ciphertext) -> bytes:
|
||||
|
@ -1163,7 +1163,7 @@ def test_sign_and_verify_ignoring_grant_tokens():
|
||||
|
||||
|
||||
@mock_kms
|
||||
def test_sign_and_verify_digest_message_type_256():
|
||||
def test_sign_and_verify_digest_message_type_RSASSA_PSS_SHA_256():
|
||||
client = boto3.client("kms", region_name="us-west-2")
|
||||
|
||||
key = client.create_key(
|
||||
|
Loading…
Reference in New Issue
Block a user