diff --git a/moto/kms/models.py b/moto/kms/models.py index cceb96342..1015aa72a 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -15,7 +15,7 @@ from .utils import decrypt, encrypt, generate_key_id, generate_master_key class Key(BaseModel): - def __init__(self, policy, key_usage, description, tags, region): + def __init__(self, policy, key_usage, customer_master_key_spec, description, tags, region): self.id = generate_key_id() self.policy = policy self.key_usage = key_usage @@ -30,9 +30,7 @@ class Key(BaseModel): self.key_material = generate_master_key() self.origin = "AWS_KMS" self.key_manager = "CUSTOMER" - self.customer_master_key_spec = "SYMMETRIC_DEFAULT" - self.encryption_algorithms = ["SYMMETRIC_DEFAULT"] - self.signing_algorithms = None + self.customer_master_key_spec = customer_master_key_spec or "SYMMETRIC_DEFAULT" @property def physical_resource_id(self): @@ -44,6 +42,38 @@ class Key(BaseModel): self.region, self.account_id, self.id ) + @property + def encryption_algorithms(self): + if self.key_usage == "SIGN_VERIFY": + return None + elif self.customer_master_key_spec == "SYMMETRIC_DEFAULT": + return ["SYMMETRIC_DEFAULT"] + else: + return [ + "RSAES_OAEP_SHA_1", + "RSAES_OAEP_SHA_256" + ] + + @property + def signing_algorithms(self): + if self.key_usage == "ENCRYPT_DECRYPT": + return None + elif self.customer_master_key_spec in ["ECC_NIST_P256", "ECC_SECG_P256K1"]: + return ["ECDSA_SHA_256"] + elif self.customer_master_key_spec == "ECC_NIST_P384": + return ["ECDSA_SHA_384"] + elif self.customer_master_key_spec == "ECC_NIST_P521": + return ["ECDSA_SHA_512"] + else: + return [ + "RSASSA_PKCS1_V1_5_SHA_256", + "RSASSA_PKCS1_V1_5_SHA_384", + "RSASSA_PKCS1_V1_5_SHA_512", + "RSASSA_PSS_SHA_256", + "RSASSA_PSS_SHA_384", + "RSASSA_PSS_SHA_512" + ] + def to_dict(self): key_dict = { "KeyMetadata": { @@ -81,6 +111,7 @@ class Key(BaseModel): key = kms_backend.create_key( policy=properties["KeyPolicy"], key_usage="ENCRYPT_DECRYPT", + customer_master_key_spec="SYMMETRIC_DEFAULT", description=properties["Description"], tags=properties.get("Tags"), region=region_name, @@ -102,8 +133,8 @@ class KmsBackend(BaseBackend): self.keys = {} self.key_to_aliases = defaultdict(set) - def create_key(self, policy, key_usage, description, tags, region): - key = Key(policy, key_usage, description, tags, region) + def create_key(self, policy, key_usage, customer_master_key_spec, description, tags, region): + key = Key(policy, key_usage, customer_master_key_spec, description, tags, region) self.keys[key.id] = key return key diff --git a/moto/kms/responses.py b/moto/kms/responses.py index d3a9726e1..15b990bbb 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -118,11 +118,12 @@ class KmsResponse(BaseResponse): """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html""" policy = self.parameters.get("Policy") key_usage = self.parameters.get("KeyUsage") + customer_master_key_spec = self.parameters.get("CustomerMasterKeySpec") description = self.parameters.get("Description") tags = self.parameters.get("Tags") key = self.kms_backend.create_key( - policy, key_usage, description, tags, self.region + policy, key_usage, customer_master_key_spec, description, tags, self.region ) return json.dumps(key.to_dict()) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index 8c2843ee4..c5a49b974 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -64,6 +64,53 @@ def test_create_key(): key["KeyMetadata"]["Origin"].should.equal("AWS_KMS") key["KeyMetadata"].should_not.have.key("SigningAlgorithms") + key = conn.create_key( + KeyUsage = "ENCRYPT_DECRYPT", + CustomerMasterKeySpec = 'RSA_2048', + ) + + sorted(key["KeyMetadata"]["EncryptionAlgorithms"]).should.equal(["RSAES_OAEP_SHA_1", "RSAES_OAEP_SHA_256"]) + key["KeyMetadata"].should_not.have.key("SigningAlgorithms") + + key = conn.create_key( + KeyUsage = "SIGN_VERIFY", + CustomerMasterKeySpec = 'RSA_2048', + ) + + key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms") + sorted(key["KeyMetadata"]["SigningAlgorithms"]).should.equal([ + "RSASSA_PKCS1_V1_5_SHA_256", + "RSASSA_PKCS1_V1_5_SHA_384", + "RSASSA_PKCS1_V1_5_SHA_512", + "RSASSA_PSS_SHA_256", + "RSASSA_PSS_SHA_384", + "RSASSA_PSS_SHA_512" + ]) + + key = conn.create_key( + KeyUsage = "SIGN_VERIFY", + CustomerMasterKeySpec = 'ECC_SECG_P256K1', + ) + + key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms") + key["KeyMetadata"]["SigningAlgorithms"].should.equal(["ECDSA_SHA_256"]) + + key = conn.create_key( + KeyUsage = "SIGN_VERIFY", + CustomerMasterKeySpec = 'ECC_NIST_P384', + ) + + key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms") + key["KeyMetadata"]["SigningAlgorithms"].should.equal(["ECDSA_SHA_384"]) + + key = conn.create_key( + KeyUsage = "SIGN_VERIFY", + CustomerMasterKeySpec = 'ECC_NIST_P521', + ) + + key["KeyMetadata"].should_not.have.key("EncryptionAlgorithms") + key["KeyMetadata"]["SigningAlgorithms"].should.equal(["ECDSA_SHA_512"]) + @mock_kms_deprecated def test_describe_key(): diff --git a/tests/test_kms/test_utils.py b/tests/test_kms/test_utils.py index f5478e0ef..4c84ed127 100644 --- a/tests/test_kms/test_utils.py +++ b/tests/test_kms/test_utils.py @@ -102,7 +102,7 @@ def test_deserialize_ciphertext_blob(raw, serialized): @parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS)) def test_encrypt_decrypt_cycle(encryption_context): plaintext = b"some secret plaintext" - master_key = Key("nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop", [], "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = encrypt( @@ -133,7 +133,7 @@ def test_encrypt_unknown_key_id(): def test_decrypt_invalid_ciphertext_format(): - master_key = Key("nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop", [], "nop") master_key_map = {master_key.id: master_key} with assert_raises(InvalidCiphertextException): @@ -153,7 +153,7 @@ def test_decrypt_unknwown_key_id(): def test_decrypt_invalid_ciphertext(): - master_key = Key("nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop", [], "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = ( master_key.id.encode("utf-8") + b"123456789012" @@ -171,7 +171,7 @@ def test_decrypt_invalid_ciphertext(): def test_decrypt_invalid_encryption_context(): plaintext = b"some secret plaintext" - master_key = Key("nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop", [], "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = encrypt(