diff --git a/moto/kms/models.py b/moto/kms/models.py index ff5d0a356..36f72e6de 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -8,7 +8,8 @@ from boto3 import Session from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time - +from moto.utilities.tagging_service import TaggingService +from moto.core.exceptions import JsonRESTError from moto.iam.models import ACCOUNT_ID from .utils import decrypt, encrypt, generate_key_id, generate_master_key @@ -16,7 +17,7 @@ from .utils import decrypt, encrypt, generate_key_id, generate_master_key class Key(BaseModel): def __init__( - self, policy, key_usage, customer_master_key_spec, description, tags, region + self, policy, key_usage, customer_master_key_spec, description, region ): self.id = generate_key_id() self.creation_date = unix_time() @@ -29,7 +30,6 @@ class Key(BaseModel): self.account_id = ACCOUNT_ID self.key_rotation_status = False self.deletion_date = None - self.tags = tags or {} self.key_material = generate_master_key() self.origin = "AWS_KMS" self.key_manager = "CUSTOMER" @@ -111,11 +111,12 @@ class Key(BaseModel): key_usage="ENCRYPT_DECRYPT", customer_master_key_spec="SYMMETRIC_DEFAULT", description=properties["Description"], - tags=properties.get("Tags"), + tags=properties.get("Tags", []), region=region_name, ) key.key_rotation_status = properties["EnableKeyRotation"] key.enabled = properties["Enabled"] + return key def get_cfn_attribute(self, attribute_name): @@ -130,32 +131,26 @@ class KmsBackend(BaseBackend): def __init__(self): self.keys = {} self.key_to_aliases = defaultdict(set) + self.tagger = TaggingService(keyName="TagKey", valueName="TagValue") def create_key( self, policy, key_usage, customer_master_key_spec, description, tags, region ): - key = Key( - policy, key_usage, customer_master_key_spec, description, tags, region - ) + key = Key(policy, key_usage, customer_master_key_spec, description, region) self.keys[key.id] = key + if tags is not None and len(tags) > 0: + self.tag_resource(key.id, tags) return key def update_key_description(self, key_id, description): key = self.keys[self.get_key_id(key_id)] key.description = description - def tag_resource(self, key_id, tags): - key = self.keys[self.get_key_id(key_id)] - key.tags = tags - - def list_resource_tags(self, key_id): - key = self.keys[self.get_key_id(key_id)] - return key.tags - def delete_key(self, key_id): if key_id in self.keys: if key_id in self.key_to_aliases: self.key_to_aliases.pop(key_id) + self.tagger.delete_all_tags_for_resource(key_id) return self.keys.pop(key_id) @@ -325,6 +320,32 @@ class KmsBackend(BaseBackend): return plaintext, ciphertext_blob, arn + def list_resource_tags(self, key_id): + if key_id in self.keys: + return self.tagger.list_tags_for_resource(key_id) + raise JsonRESTError( + "NotFoundException", + "The request was rejected because the specified entity or resource could not be found.", + ) + + def tag_resource(self, key_id, tags): + if key_id in self.keys: + self.tagger.tag_resource(key_id, tags) + return {} + raise JsonRESTError( + "NotFoundException", + "The request was rejected because the specified entity or resource could not be found.", + ) + + def untag_resource(self, key_id, tag_names): + if key_id in self.keys: + self.tagger.untag_resource_using_names(key_id, tag_names) + return {} + raise JsonRESTError( + "NotFoundException", + "The request was rejected because the specified entity or resource could not be found.", + ) + kms_backends = {} for region in Session().get_available_regions("kms"): diff --git a/moto/kms/responses.py b/moto/kms/responses.py index 15b990bbb..995c097e0 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -144,17 +144,27 @@ class KmsResponse(BaseResponse): self._validate_cmk_id(key_id) - self.kms_backend.tag_resource(key_id, tags) - return json.dumps({}) + result = self.kms_backend.tag_resource(key_id, tags) + return json.dumps(result) + + def untag_resource(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_UntagResource.html""" + key_id = self.parameters.get("KeyId") + tag_names = self.parameters.get("TagKeys") + + self._validate_cmk_id(key_id) + + result = self.kms_backend.untag_resource(key_id, tag_names) + return json.dumps(result) def list_resource_tags(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html""" key_id = self.parameters.get("KeyId") - self._validate_cmk_id(key_id) tags = self.kms_backend.list_resource_tags(key_id) - return json.dumps({"Tags": tags, "NextMarker": None, "Truncated": False}) + tags.update({"NextMarker": None, "Truncated": False}) + return json.dumps(tags) def describe_key(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html""" diff --git a/moto/resourcegroupstaggingapi/models.py b/moto/resourcegroupstaggingapi/models.py index 850ab5c04..d05a53f81 100644 --- a/moto/resourcegroupstaggingapi/models.py +++ b/moto/resourcegroupstaggingapi/models.py @@ -318,7 +318,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): # KMS def get_kms_tags(kms_key_id): result = [] - for tag in self.kms_backend.list_resource_tags(kms_key_id): + for tag in self.kms_backend.list_resource_tags(kms_key_id).get("Tags", []): result.append({"Key": tag["TagKey"], "Value": tag["TagValue"]}) return result diff --git a/tests/test_events/test_events.py b/tests/test_events/test_events.py index 4ecb2d882..80fadb449 100644 --- a/tests/test_events/test_events.py +++ b/tests/test_events/test_events.py @@ -1,15 +1,15 @@ +from moto.events.models import EventsBackend +from moto.events import mock_events import json import random import unittest import boto3 from botocore.exceptions import ClientError +from moto.core.exceptions import JsonRESTError from nose.tools import assert_raises from moto.core import ACCOUNT_ID -from moto.core.exceptions import JsonRESTError -from moto.events import mock_events -from moto.events.models import EventsBackend RULES = [ {"Name": "test1", "ScheduleExpression": "rate(5 minutes)"}, diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index 9ce324373..a04a24a82 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -4,15 +4,17 @@ import base64 import re import boto.kms +import boto3 import six import sure # noqa from boto.exception import JSONResponseError from boto.kms.exceptions import AlreadyExistsException, NotFoundException from nose.tools import assert_raises from parameterized import parameterized - +from moto.core.exceptions import JsonRESTError +from moto.kms.models import KmsBackend from moto.kms.exceptions import NotFoundException as MotoNotFoundException -from moto import mock_kms_deprecated +from moto import mock_kms_deprecated, mock_kms PLAINTEXT_VECTORS = ( (b"some encodeable plaintext",), @@ -679,3 +681,77 @@ def test__assert_default_policy(): _assert_default_policy.when.called_with("default").should_not.throw( MotoNotFoundException ) + + +if six.PY2: + sort = sorted +else: + sort = lambda l: sorted(l, key=lambda d: d.keys()) + + +@mock_kms +def test_key_tag_on_create_key_happy(): + client = boto3.client("kms", region_name="us-east-1") + + tags = [ + {"TagKey": "key1", "TagValue": "value1"}, + {"TagKey": "key2", "TagValue": "value2"}, + ] + key = client.create_key(Description="test-key-tagging", Tags=tags) + key_id = key["KeyMetadata"]["KeyId"] + + result = client.list_resource_tags(KeyId=key_id) + actual = result.get("Tags", []) + assert sort(tags) == sort(actual) + + client.untag_resource(KeyId=key_id, TagKeys=["key1"]) + + actual = client.list_resource_tags(KeyId=key_id).get("Tags", []) + expected = [{"TagKey": "key2", "TagValue": "value2"}] + assert sort(expected) == sort(actual) + + +@mock_kms +def test_key_tag_added_happy(): + client = boto3.client("kms", region_name="us-east-1") + + key = client.create_key(Description="test-key-tagging") + key_id = key["KeyMetadata"]["KeyId"] + tags = [ + {"TagKey": "key1", "TagValue": "value1"}, + {"TagKey": "key2", "TagValue": "value2"}, + ] + client.tag_resource(KeyId=key_id, Tags=tags) + + result = client.list_resource_tags(KeyId=key_id) + actual = result.get("Tags", []) + assert sort(tags) == sort(actual) + + client.untag_resource(KeyId=key_id, TagKeys=["key1"]) + + actual = client.list_resource_tags(KeyId=key_id).get("Tags", []) + expected = [{"TagKey": "key2", "TagValue": "value2"}] + assert sort(expected) == sort(actual) + + +@mock_kms_deprecated +def test_key_tagging_sad(): + b = KmsBackend() + + try: + b.tag_resource("unknown", []) + raise "tag_resource should fail if KeyId is not known" + except JsonRESTError: + pass + + try: + b.untag_resource("unknown", []) + raise "untag_resource should fail if KeyId is not known" + except JsonRESTError: + pass + + try: + b.list_resource_tags("unknown") + raise "list_resource_tags should fail if KeyId is not known" + except JsonRESTError: + pass diff --git a/tests/test_kms/test_utils.py b/tests/test_kms/test_utils.py index 4c84ed127..4446635f3 100644 --- a/tests/test_kms/test_utils.py +++ b/tests/test_kms/test_utils.py @@ -102,7 +102,7 @@ def test_deserialize_ciphertext_blob(raw, serialized): @parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS)) def test_encrypt_decrypt_cycle(encryption_context): plaintext = b"some secret plaintext" - master_key = Key("nop", "nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = encrypt( @@ -133,7 +133,7 @@ def test_encrypt_unknown_key_id(): def test_decrypt_invalid_ciphertext_format(): - master_key = Key("nop", "nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} with assert_raises(InvalidCiphertextException): @@ -153,7 +153,7 @@ def test_decrypt_unknwown_key_id(): def test_decrypt_invalid_ciphertext(): - master_key = Key("nop", "nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = ( master_key.id.encode("utf-8") + b"123456789012" @@ -171,7 +171,7 @@ def test_decrypt_invalid_ciphertext(): def test_decrypt_invalid_encryption_context(): plaintext = b"some secret plaintext" - master_key = Key("nop", "nop", "nop", "nop", [], "nop") + master_key = Key("nop", "nop", "nop", "nop", "nop") master_key_map = {master_key.id: master_key} ciphertext_blob = encrypt(