updates KMS service to use TaggingService

This commit is contained in:
Bryan Alexander 2020-01-16 12:10:38 -06:00
parent 6cb0428d20
commit 85207b885b
4 changed files with 87 additions and 47 deletions

View File

@ -7,13 +7,14 @@ from datetime import datetime, timedelta
from boto3 import Session from boto3 import Session
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.exceptions import JsonRESTError
from moto.core.utils import iso_8601_datetime_without_milliseconds from moto.core.utils import iso_8601_datetime_without_milliseconds
from moto.utilities.tagging_service import TaggingService
from .utils import decrypt, encrypt, generate_key_id, generate_master_key from .utils import decrypt, encrypt, generate_key_id, generate_master_key
class Key(BaseModel): class Key(BaseModel):
def __init__(self, policy, key_usage, description, tags, region): def __init__(self, policy, key_usage, description, region):
self.id = generate_key_id() self.id = generate_key_id()
self.policy = policy self.policy = policy
self.key_usage = key_usage self.key_usage = key_usage
@ -24,7 +25,6 @@ class Key(BaseModel):
self.account_id = "012345678912" self.account_id = "012345678912"
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date = None self.deletion_date = None
self.tags = tags or {}
self.key_material = generate_master_key() self.key_material = generate_master_key()
@property @property
@ -70,11 +70,12 @@ class Key(BaseModel):
policy=properties["KeyPolicy"], policy=properties["KeyPolicy"],
key_usage="ENCRYPT_DECRYPT", key_usage="ENCRYPT_DECRYPT",
description=properties["Description"], description=properties["Description"],
tags=properties.get("Tags"),
region=region_name, region=region_name,
) )
key.key_rotation_status = properties["EnableKeyRotation"] key.key_rotation_status = properties["EnableKeyRotation"]
key.enabled = properties["Enabled"] key.enabled = properties["Enabled"]
kms_backend.tag_resource(key.id, properties.get("Tags"))
return key return key
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
@ -89,24 +90,19 @@ class KmsBackend(BaseBackend):
def __init__(self): def __init__(self):
self.keys = {} self.keys = {}
self.key_to_aliases = defaultdict(set) self.key_to_aliases = defaultdict(set)
self.tagger = TaggingService(keyName='TagKey', valueName='TagValue')
def create_key(self, policy, key_usage, description, tags, region): def create_key(self, policy, key_usage, description, tags, region):
key = Key(policy, key_usage, description, tags, region) key = Key(policy, key_usage, description, region)
self.keys[key.id] = key self.keys[key.id] = key
if tags != None and len(tags) > 0:
self.tag_resource(key.id, tags)
return key return key
def update_key_description(self, key_id, description): def update_key_description(self, key_id, description):
key = self.keys[self.get_key_id(key_id)] key = self.keys[self.get_key_id(key_id)]
key.description = description key.description = description
def tag_resource(self, key_id, tags):
key = self.keys[self.get_key_id(key_id)]
key.tags = tags
def list_resource_tags(self, key_id):
key = self.keys[self.get_key_id(key_id)]
return key.tags
def delete_key(self, key_id): def delete_key(self, key_id):
if key_id in self.keys: if key_id in self.keys:
if key_id in self.key_to_aliases: if key_id in self.key_to_aliases:
@ -282,6 +278,29 @@ class KmsBackend(BaseBackend):
return plaintext, ciphertext_blob, arn return plaintext, ciphertext_blob, arn
def list_resource_tags(self, key_id):
if key_id in self.keys:
return self.tagger.list_tags_for_resource(key_id)
raise JsonRESTError(
"NotFoundException", "The request was rejected because the specified entity or resource could not be found."
)
def tag_resource(self, key_id, tags):
if key_id in self.keys:
self.tagger.tag_resource(key_id, tags)
return {}
raise JsonRESTError(
"NotFoundException", "The request was rejected because the specified entity or resource could not be found."
)
def untag_resource(self, key_id, tag_names):
if key_id in self.keys:
self.tagger.untag_resource_using_names(key_id, tag_names)
return {}
raise JsonRESTError(
"NotFoundException", "The request was rejected because the specified entity or resource could not be found."
)
kms_backends = {} kms_backends = {}
for region in Session().get_available_regions("kms"): for region in Session().get_available_regions("kms"):

View File

@ -143,17 +143,27 @@ class KmsResponse(BaseResponse):
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
self.kms_backend.tag_resource(key_id, tags) result = self.kms_backend.tag_resource(key_id, tags)
return json.dumps({}) return json.dumps(result)
def untag_resource(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UntagResource.html"""
key_id = self.parameters.get("KeyId")
tag_names = self.parameters.get("TagKeys")
self._validate_cmk_id(key_id)
result = self.kms_backend.untag_resource(key_id, tag_names)
return json.dumps(result)
def list_resource_tags(self): def list_resource_tags(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html"""
key_id = self.parameters.get("KeyId") key_id = self.parameters.get("KeyId")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
tags = self.kms_backend.list_resource_tags(key_id) tags = self.kms_backend.list_resource_tags(key_id)
return json.dumps({"Tags": tags, "NextMarker": None, "Truncated": False}) tags.update({"NextMarker": None, "Truncated": False})
return json.dumps(tags)
def describe_key(self): def describe_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html"""

View File

@ -17,7 +17,8 @@ from boto.kms.exceptions import AlreadyExistsException, NotFoundException
from freezegun import freeze_time from freezegun import freeze_time
from nose.tools import assert_raises from nose.tools import assert_raises
from parameterized import parameterized from parameterized import parameterized
from moto.core.exceptions import JsonRESTError
from moto.kms.models import KmsBackend
from moto.kms.exceptions import NotFoundException as MotoNotFoundException from moto.kms.exceptions import NotFoundException as MotoNotFoundException
from moto import mock_kms, mock_kms_deprecated from moto import mock_kms, mock_kms_deprecated
@ -910,36 +911,46 @@ def test_update_key_description():
result = client.update_key_description(KeyId=key_id, Description="new_description") result = client.update_key_description(KeyId=key_id, Description="new_description")
assert "ResponseMetadata" in result assert "ResponseMetadata" in result
@mock_kms
def test_key_tagging_happy():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="test-key-tagging")
key_id = key["KeyMetadata"]["KeyId"]
tags = [{"TagKey": "key1", "TagValue": "value1"}, {"TagKey": "key2", "TagValue": "value2"}]
client.tag_resource(KeyId=key_id, Tags=tags)
result = client.list_resource_tags(KeyId=key_id)
actual = result.get("Tags", [])
assert tags == actual
client.untag_resource(KeyId=key_id, TagKeys=["key1"])
actual = client.list_resource_tags(KeyId=key_id).get("Tags", [])
expected = [{"TagKey": "key2", "TagValue": "value2"}]
assert expected == actual
@mock_kms @mock_kms
def test_tag_resource(): def test_key_tagging_sad():
client = boto3.client("kms", region_name="us-east-1") b = KmsBackend()
key = client.create_key(Description="cancel-key-deletion")
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
keyid = response["KeyId"] try:
response = client.tag_resource( b.tag_resource('unknown', [])
KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}] raise 'tag_resource should fail if KeyId is not known'
) except JsonRESTError:
pass
# Shouldn't have any data, just header try:
assert len(response.keys()) == 1 b.untag_resource('unknown', [])
raise 'untag_resource should fail if KeyId is not known'
except JsonRESTError:
pass
try:
@mock_kms b.list_resource_tags('unknown')
def test_list_resource_tags(): raise 'list_resource_tags should fail if KeyId is not known'
client = boto3.client("kms", region_name="us-east-1") except JsonRESTError:
key = client.create_key(Description="cancel-key-deletion") pass
response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"])
keyid = response["KeyId"]
response = client.tag_resource(
KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}]
)
response = client.list_resource_tags(KeyId=keyid)
assert response["Tags"][0]["TagKey"] == "string"
assert response["Tags"][0]["TagValue"] == "string"
@parameterized( @parameterized(

View File

@ -102,7 +102,7 @@ def test_deserialize_ciphertext_blob(raw, serialized):
@parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS)) @parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS))
def test_encrypt_decrypt_cycle(encryption_context): def test_encrypt_decrypt_cycle(encryption_context):
plaintext = b"some secret plaintext" plaintext = b"some secret plaintext"
master_key = Key("nop", "nop", "nop", [], "nop") master_key = Key("nop", "nop", "nop", "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
ciphertext_blob = encrypt( ciphertext_blob = encrypt(
@ -133,7 +133,7 @@ def test_encrypt_unknown_key_id():
def test_decrypt_invalid_ciphertext_format(): def test_decrypt_invalid_ciphertext_format():
master_key = Key("nop", "nop", "nop", [], "nop") master_key = Key("nop", "nop", "nop", "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
with assert_raises(InvalidCiphertextException): with assert_raises(InvalidCiphertextException):
@ -153,7 +153,7 @@ def test_decrypt_unknwown_key_id():
def test_decrypt_invalid_ciphertext(): def test_decrypt_invalid_ciphertext():
master_key = Key("nop", "nop", "nop", [], "nop") master_key = Key("nop", "nop", "nop", "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
ciphertext_blob = ( ciphertext_blob = (
master_key.id.encode("utf-8") + b"123456789012" master_key.id.encode("utf-8") + b"123456789012"
@ -171,7 +171,7 @@ def test_decrypt_invalid_ciphertext():
def test_decrypt_invalid_encryption_context(): def test_decrypt_invalid_encryption_context():
plaintext = b"some secret plaintext" plaintext = b"some secret plaintext"
master_key = Key("nop", "nop", "nop", [], "nop") master_key = Key("nop", "nop", "nop", "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
ciphertext_blob = encrypt( ciphertext_blob = encrypt(