KMS : Adding support for multi-region keys and implementing replicate_key API. (#5288)

This commit is contained in:
taras-kobernyk-localstack 2022-07-27 11:30:41 +02:00 committed by GitHub
parent be6e02e5fa
commit 9d26ec7422
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 102 additions and 8 deletions

View File

@ -1,6 +1,7 @@
import json import json
import os import os
from collections import defaultdict from collections import defaultdict
from copy import copy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from cryptography.exceptions import InvalidSignature from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
@ -55,8 +56,10 @@ class Grant(BaseModel):
class Key(CloudFormationModel): class Key(CloudFormationModel):
def __init__(self, policy, key_usage, key_spec, description, region): def __init__(
self.id = generate_key_id() self, policy, key_usage, key_spec, description, region, multi_region=False
):
self.id = generate_key_id(multi_region)
self.creation_date = unix_time() self.creation_date = unix_time()
self.policy = policy or self.generate_default_policy() self.policy = policy or self.generate_default_policy()
self.key_usage = key_usage self.key_usage = key_usage
@ -64,6 +67,7 @@ class Key(CloudFormationModel):
self.description = description or "" self.description = description or ""
self.enabled = True self.enabled = True
self.region = region self.region = region
self.multi_region = multi_region
self.account_id = get_account_id() self.account_id = get_account_id()
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date = None self.deletion_date = None
@ -184,6 +188,7 @@ class Key(CloudFormationModel):
"KeyManager": self.key_manager, "KeyManager": self.key_manager,
"KeyUsage": self.key_usage, "KeyUsage": self.key_usage,
"KeyState": self.key_state, "KeyState": self.key_state,
"MultiRegion": self.multi_region,
"Origin": self.origin, "Origin": self.origin,
"SigningAlgorithms": self.signing_algorithms, "SigningAlgorithms": self.signing_algorithms,
} }
@ -264,13 +269,31 @@ class KmsBackend(BaseBackend):
self.add_alias(key.id, alias_name) self.add_alias(key.id, alias_name)
return key.id return key.id
def create_key(self, policy, key_usage, key_spec, description, tags, region): def create_key(
key = Key(policy, key_usage, key_spec, description, region) self, policy, key_usage, key_spec, description, tags, region, multi_region=False
):
key = Key(policy, key_usage, key_spec, description, region, multi_region)
self.keys[key.id] = key self.keys[key.id] = key
if tags is not None and len(tags) > 0: if tags is not None and len(tags) > 0:
self.tag_resource(key.id, tags) self.tag_resource(key.id, tags)
return key return key
# https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html#mrk-sync-properties
# In AWS replicas of a key only share some properties with the original key. Some of those properties get updated
# in all replicas automatically if those properties change in the original key. Also, such properties can not be
# changed for replicas directly.
#
# In our implementation with just create a copy of all the properties once without any protection from change,
# as the exact implementation is currently infeasible.
def replicate_key(self, key_id, replica_region):
# Using copy() instead of deepcopy(), as the latter results in exception:
# TypeError: cannot pickle '_cffi_backend.FFI' object
# Since we only update top level properties, copy() should suffice.
replica_key = copy(self.keys[key_id])
replica_key.region = replica_region
to_region_backend = kms_backends[replica_region]
to_region_backend.keys[replica_key.id] = replica_key
def update_key_description(self, key_id, description): def update_key_description(self, key_id, description):
key = self.keys[self.get_key_id(key_id)] key = self.keys[self.get_key_id(key_id)]
key.description = description key.description = description

View File

@ -51,8 +51,11 @@ class KmsResponse(BaseResponse):
- key ARN - key ARN
""" """
is_arn = key_id.startswith("arn:") and ":key/" in key_id is_arn = key_id.startswith("arn:") and ":key/" in key_id
# https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html
# "Notice that multi-Region keys have a distinctive key ID that begins with mrk-. You can use the mrk- prefix to
# identify MRKs programmatically."
is_raw_key_id = re.match( is_raw_key_id = re.match(
r"^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$", r"^(mrk-)?[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$",
key_id, key_id,
re.IGNORECASE, re.IGNORECASE,
) )
@ -114,12 +117,19 @@ class KmsResponse(BaseResponse):
) )
description = self.parameters.get("Description") description = self.parameters.get("Description")
tags = self.parameters.get("Tags") tags = self.parameters.get("Tags")
multi_region = self.parameters.get("MultiRegion")
key = self.kms_backend.create_key( key = self.kms_backend.create_key(
policy, key_usage, key_spec, description, tags, self.region policy, key_usage, key_spec, description, tags, self.region, multi_region
) )
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())
def replicate_key(self):
key_id = self.parameters.get("KeyId")
self._validate_key_id(key_id)
replica_region = self.parameters.get("ReplicaRegion")
self.kms_backend.replicate_key(key_id, replica_region)
def update_key_description(self): def update_key_description(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html"""
key_id = self.parameters.get("KeyId") key_id = self.parameters.get("KeyId")

View File

@ -45,8 +45,15 @@ RESERVED_ALIASES = [
] ]
def generate_key_id(): def generate_key_id(multi_region=False):
return str(uuid.uuid4()) key = str(uuid.uuid4())
# https://docs.aws.amazon.com/kms/latest/developerguide/multi-region-keys-overview.html
# "Notice that multi-Region keys have a distinctive key ID that begins with mrk-. You can use the mrk- prefix to
# identify MRKs programmatically."
if multi_region:
key = "mrk-" + key
return key
def generate_data_key(number_of_bytes): def generate_data_key(number_of_bytes):

View File

@ -120,6 +120,60 @@ def test_create_key():
key["KeyMetadata"]["SigningAlgorithms"].should.equal(["ECDSA_SHA_512"]) key["KeyMetadata"]["SigningAlgorithms"].should.equal(["ECDSA_SHA_512"])
@mock_kms
def test_create_multi_region_key():
conn = boto3.client("kms", region_name="us-east-1")
key = conn.create_key(
Policy="my policy",
Description="my key",
KeyUsage="ENCRYPT_DECRYPT",
MultiRegion=True,
Tags=[{"TagKey": "project", "TagValue": "moto"}],
)
key["KeyMetadata"]["KeyId"].should.match("^mrk-")
key["KeyMetadata"]["MultiRegion"].should.equal(True)
@mock_kms
def test_non_multi_region_keys_should_not_have_multi_region_properties():
conn = boto3.client("kms", region_name="us-east-1")
key = conn.create_key(
Policy="my policy",
Description="my key",
KeyUsage="ENCRYPT_DECRYPT",
MultiRegion=False,
Tags=[{"TagKey": "project", "TagValue": "moto"}],
)
key["KeyMetadata"]["KeyId"].should_not.match("^mrk-")
key["KeyMetadata"]["MultiRegion"].should.equal(False)
@mock_kms
def test_replicate_key():
region_to_replicate_from = "us-east-1"
region_to_replicate_to = "us-west-1"
from_region_client = boto3.client("kms", region_name=region_to_replicate_from)
to_region_client = boto3.client("kms", region_name=region_to_replicate_to)
response = from_region_client.create_key(
Policy="my policy",
Description="my key",
KeyUsage="ENCRYPT_DECRYPT",
MultiRegion=True,
Tags=[{"TagKey": "project", "TagValue": "moto"}],
)
key_id = response["KeyMetadata"]["KeyId"]
with pytest.raises(to_region_client.exceptions.NotFoundException):
to_region_client.describe_key(KeyId=key_id)
from_region_client.replicate_key(KeyId=key_id, ReplicaRegion=region_to_replicate_to)
to_region_client.describe_key(KeyId=key_id)
from_region_client.describe_key(KeyId=key_id)
@mock_kms @mock_kms
def test_create_key_deprecated_master_custom_key_spec(): def test_create_key_deprecated_master_custom_key_spec():
conn = boto3.client("kms", region_name="us-east-1") conn = boto3.client("kms", region_name="us-east-1")