diff --git a/moto/kms/models.py b/moto/kms/models.py index 22f0039b2..32fcd23ae 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -7,13 +7,14 @@ from datetime import datetime, timedelta from boto3 import Session from moto.core import BaseBackend, BaseModel +from moto.core.exceptions import JsonRESTError from moto.core.utils import iso_8601_datetime_without_milliseconds - +from moto.utilities.tagging_service import TaggingService from .utils import decrypt, encrypt, generate_key_id, generate_master_key class Key(BaseModel): - def __init__(self, policy, key_usage, description, tags, region): + def __init__(self, policy, key_usage, description, region): self.id = generate_key_id() self.policy = policy self.key_usage = key_usage @@ -24,7 +25,6 @@ class Key(BaseModel): self.account_id = "012345678912" self.key_rotation_status = False self.deletion_date = None - self.tags = tags or {} self.key_material = generate_master_key() @property @@ -70,11 +70,12 @@ class Key(BaseModel): policy=properties["KeyPolicy"], key_usage="ENCRYPT_DECRYPT", description=properties["Description"], - tags=properties.get("Tags"), region=region_name, ) key.key_rotation_status = properties["EnableKeyRotation"] key.enabled = properties["Enabled"] + kms_backend.tag_resource(key.id, properties.get("Tags")) + return key def get_cfn_attribute(self, attribute_name): @@ -89,24 +90,19 @@ class KmsBackend(BaseBackend): def __init__(self): self.keys = {} self.key_to_aliases = defaultdict(set) + self.tagger = TaggingService(keyName='TagKey', valueName='TagValue') def create_key(self, policy, key_usage, description, tags, region): - key = Key(policy, key_usage, description, tags, region) + key = Key(policy, key_usage, description, region) self.keys[key.id] = key + if tags != None and len(tags) > 0: + self.tag_resource(key.id, tags) return key def update_key_description(self, key_id, description): key = self.keys[self.get_key_id(key_id)] key.description = description - def tag_resource(self, key_id, tags): - key = self.keys[self.get_key_id(key_id)] - key.tags = tags - - def list_resource_tags(self, key_id): - key = self.keys[self.get_key_id(key_id)] - return key.tags - def delete_key(self, key_id): if key_id in self.keys: if key_id in self.key_to_aliases: @@ -282,6 +278,29 @@ class KmsBackend(BaseBackend): return plaintext, ciphertext_blob, arn + def list_resource_tags(self, key_id): + if key_id in self.keys: + return self.tagger.list_tags_for_resource(key_id) + raise JsonRESTError( + "NotFoundException", "The request was rejected because the specified entity or resource could not be found." + ) + + def tag_resource(self, key_id, tags): + if key_id in self.keys: + self.tagger.tag_resource(key_id, tags) + return {} + raise JsonRESTError( + "NotFoundException", "The request was rejected because the specified entity or resource could not be found." + ) + + def untag_resource(self, key_id, tag_names): + if key_id in self.keys: + self.tagger.untag_resource_using_names(key_id, tag_names) + return {} + raise JsonRESTError( + "NotFoundException", "The request was rejected because the specified entity or resource could not be found." + ) + kms_backends = {} for region in Session().get_available_regions("kms"): diff --git a/moto/kms/responses.py b/moto/kms/responses.py index d3a9726e1..3658f0d37 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -143,17 +143,27 @@ class KmsResponse(BaseResponse): self._validate_cmk_id(key_id) - self.kms_backend.tag_resource(key_id, tags) - return json.dumps({}) + result = self.kms_backend.tag_resource(key_id, tags) + return json.dumps(result) + + def untag_resource(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_UntagResource.html""" + key_id = self.parameters.get("KeyId") + tag_names = self.parameters.get("TagKeys") + + self._validate_cmk_id(key_id) + + result = self.kms_backend.untag_resource(key_id, tag_names) + return json.dumps(result) def list_resource_tags(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html""" key_id = self.parameters.get("KeyId") - self._validate_cmk_id(key_id) tags = self.kms_backend.list_resource_tags(key_id) - return json.dumps({"Tags": tags, "NextMarker": None, "Truncated": False}) + tags.update({"NextMarker": None, "Truncated": False}) + return json.dumps(tags) def describe_key(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html""" diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index 70fa68787..6a35ee2c8 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -17,7 +17,8 @@ from boto.kms.exceptions import AlreadyExistsException, NotFoundException from freezegun import freeze_time from nose.tools import assert_raises from parameterized import parameterized - +from moto.core.exceptions import JsonRESTError +from moto.kms.models import KmsBackend from moto.kms.exceptions import NotFoundException as MotoNotFoundException from moto import mock_kms, mock_kms_deprecated @@ -910,36 +911,46 @@ def test_update_key_description(): result = client.update_key_description(KeyId=key_id, Description="new_description") assert "ResponseMetadata" in result +@mock_kms +def test_key_tagging_happy(): + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="test-key-tagging") + key_id = key["KeyMetadata"]["KeyId"] + + tags = [{"TagKey": "key1", "TagValue": "value1"}, {"TagKey": "key2", "TagValue": "value2"}] + client.tag_resource(KeyId=key_id, Tags=tags) + + result = client.list_resource_tags(KeyId=key_id) + actual = result.get("Tags", []) + assert tags == actual + + client.untag_resource(KeyId=key_id, TagKeys=["key1"]) + + actual = client.list_resource_tags(KeyId=key_id).get("Tags", []) + expected = [{"TagKey": "key2", "TagValue": "value2"}] + assert expected == actual @mock_kms -def test_tag_resource(): - client = boto3.client("kms", region_name="us-east-1") - key = client.create_key(Description="cancel-key-deletion") - response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) +def test_key_tagging_sad(): + b = KmsBackend() - keyid = response["KeyId"] - response = client.tag_resource( - KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}] - ) + try: + b.tag_resource('unknown', []) + raise 'tag_resource should fail if KeyId is not known' + except JsonRESTError: + pass - # Shouldn't have any data, just header - assert len(response.keys()) == 1 + try: + b.untag_resource('unknown', []) + raise 'untag_resource should fail if KeyId is not known' + except JsonRESTError: + pass - -@mock_kms -def test_list_resource_tags(): - client = boto3.client("kms", region_name="us-east-1") - key = client.create_key(Description="cancel-key-deletion") - response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) - - keyid = response["KeyId"] - response = client.tag_resource( - KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}] - ) - - response = client.list_resource_tags(KeyId=keyid) - assert response["Tags"][0]["TagKey"] == "string" - assert response["Tags"][0]["TagValue"] == "string" + try: + b.list_resource_tags('unknown') + raise 'list_resource_tags should fail if KeyId is not known' + except JsonRESTError: + pass @parameterized( diff --git a/tests/test_kms/test_utils.py b/tests/test_kms/test_utils.py index f5478e0ef..29ea969b5 100644 --- a/tests/test_kms/test_utils.py +++ b/tests/test_kms/test_utils.py @@ -102,7 +102,7 @@ def test_deserialize_ciphertext_blob(raw, serialized): @parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS)) def test_encrypt_decrypt_cycle(encryption_context): plaintext = b"some secret plaintext" - master_key = Key("nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = encrypt( @@ -133,7 +133,7 @@ def test_encrypt_unknown_key_id(): def test_decrypt_invalid_ciphertext_format(): - master_key = Key("nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} with assert_raises(InvalidCiphertextException): @@ -153,7 +153,7 @@ def test_decrypt_unknwown_key_id(): def test_decrypt_invalid_ciphertext(): - master_key = Key("nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = ( master_key.id.encode("utf-8") + b"123456789012" @@ -171,7 +171,7 @@ def test_decrypt_invalid_ciphertext(): def test_decrypt_invalid_encryption_context(): plaintext = b"some secret plaintext" - master_key = Key("nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = encrypt(