diff --git a/moto/kms/utils.py b/moto/kms/utils.py index 89898d032..8d53051c2 100644 --- a/moto/kms/utils.py +++ b/moto/kms/utils.py @@ -190,12 +190,17 @@ class RSAPrivateKey(AbstractPrivateKey): pad = padding.PSS( mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH ) - algorithm = hashes.SHA256() + algorithm = hashes.SHA256() # type: Any + elif signing_algorithm == SigningAlgorithm.RSASSA_PSS_SHA_384: + pad = padding.PSS( + mgf=padding.MGF1(hashes.SHA384()), salt_length=padding.PSS.MAX_LENGTH + ) + algorithm = hashes.SHA384() else: pad = padding.PSS( - mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH + mgf=padding.MGF1(hashes.SHA512()), salt_length=padding.PSS.MAX_LENGTH ) - algorithm = hashes.SHA256() + algorithm = hashes.SHA512() return self.private_key.sign(message, pad, algorithm) def verify(self, message: bytes, signature: bytes, signing_algorithm: str) -> bool: @@ -207,12 +212,17 @@ class RSAPrivateKey(AbstractPrivateKey): pad = padding.PSS( mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH ) - algorithm = hashes.SHA256() + algorithm = hashes.SHA256() # type: Any + elif signing_algorithm == SigningAlgorithm.RSASSA_PSS_SHA_384: + pad = padding.PSS( + mgf=padding.MGF1(hashes.SHA384()), salt_length=padding.PSS.MAX_LENGTH + ) + algorithm = hashes.SHA384() else: pad = padding.PSS( - mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH + mgf=padding.MGF1(hashes.SHA512()), salt_length=padding.PSS.MAX_LENGTH ) - algorithm = hashes.SHA256() + algorithm = hashes.SHA512() public_key = self.private_key.public_key() try: diff --git a/tests/test_kms/test_kms_boto3.py b/tests/test_kms/test_kms_boto3.py index 3895df617..ca2b75241 100644 --- a/tests/test_kms/test_kms_boto3.py +++ b/tests/test_kms/test_kms_boto3.py @@ -2,11 +2,13 @@ import json from datetime import datetime from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.asymmetric import rsa +import itertools from unittest import mock from dateutil.tz import tzutc import base64 import os + import boto3 import botocore.exceptions from botocore.exceptions import ClientError @@ -1163,8 +1165,16 @@ def test_sign_and_verify_ignoring_grant_tokens(): @mock_kms -@pytest.mark.parametrize("key_spec", ["RSA_2048", "RSA_3072", "RSA_4096"]) -def test_sign_and_verify_digest_message_type_RSASSA_PSS_SHA_256(key_spec): +@pytest.mark.parametrize( + "key_spec, signing_algorithm", + list( + itertools.product( + ["RSA_2048", "RSA_3072", "RSA_4096"], + ["RSASSA_PSS_SHA_256", "RSASSA_PSS_SHA_384", "RSASSA_PSS_SHA_512"], + ) + ), +) +def test_sign_and_verify_digest_message_type_RSA(key_spec, signing_algorithm): client = boto3.client("kms", region_name="us-west-2") key = client.create_key( @@ -1176,7 +1186,6 @@ def test_sign_and_verify_digest_message_type_RSASSA_PSS_SHA_256(key_spec): digest.update(b"this works") digest.update(b"as well") message = digest.finalize() - signing_algorithm = "RSASSA_PSS_SHA_256" sign_response = client.sign( KeyId=key_id,