Merge pull request #2758 from brady-gsa/kms-tagging

Kms tagging and untag support
This commit is contained in:
Bert Blommers 2020-02-22 10:25:43 +00:00 committed by GitHub
commit dc9129955b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 136 additions and 29 deletions

View File

@ -8,7 +8,8 @@ from boto3 import Session
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time from moto.core.utils import unix_time
from moto.utilities.tagging_service import TaggingService
from moto.core.exceptions import JsonRESTError
from moto.iam.models import ACCOUNT_ID from moto.iam.models import ACCOUNT_ID
from .utils import decrypt, encrypt, generate_key_id, generate_master_key from .utils import decrypt, encrypt, generate_key_id, generate_master_key
@ -16,7 +17,7 @@ from .utils import decrypt, encrypt, generate_key_id, generate_master_key
class Key(BaseModel): class Key(BaseModel):
def __init__( def __init__(
self, policy, key_usage, customer_master_key_spec, description, tags, region self, policy, key_usage, customer_master_key_spec, description, region
): ):
self.id = generate_key_id() self.id = generate_key_id()
self.creation_date = unix_time() self.creation_date = unix_time()
@ -29,7 +30,6 @@ class Key(BaseModel):
self.account_id = ACCOUNT_ID self.account_id = ACCOUNT_ID
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date = None self.deletion_date = None
self.tags = tags or {}
self.key_material = generate_master_key() self.key_material = generate_master_key()
self.origin = "AWS_KMS" self.origin = "AWS_KMS"
self.key_manager = "CUSTOMER" self.key_manager = "CUSTOMER"
@ -111,11 +111,12 @@ class Key(BaseModel):
key_usage="ENCRYPT_DECRYPT", key_usage="ENCRYPT_DECRYPT",
customer_master_key_spec="SYMMETRIC_DEFAULT", customer_master_key_spec="SYMMETRIC_DEFAULT",
description=properties["Description"], description=properties["Description"],
tags=properties.get("Tags"), tags=properties.get("Tags", []),
region=region_name, region=region_name,
) )
key.key_rotation_status = properties["EnableKeyRotation"] key.key_rotation_status = properties["EnableKeyRotation"]
key.enabled = properties["Enabled"] key.enabled = properties["Enabled"]
return key return key
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
@ -130,32 +131,26 @@ class KmsBackend(BaseBackend):
def __init__(self): def __init__(self):
self.keys = {} self.keys = {}
self.key_to_aliases = defaultdict(set) self.key_to_aliases = defaultdict(set)
self.tagger = TaggingService(keyName="TagKey", valueName="TagValue")
def create_key( def create_key(
self, policy, key_usage, customer_master_key_spec, description, tags, region self, policy, key_usage, customer_master_key_spec, description, tags, region
): ):
key = Key( key = Key(policy, key_usage, customer_master_key_spec, description, region)
policy, key_usage, customer_master_key_spec, description, tags, region
)
self.keys[key.id] = key self.keys[key.id] = key
if tags is not None and len(tags) > 0:
self.tag_resource(key.id, tags)
return key return key
def update_key_description(self, key_id, description): def update_key_description(self, key_id, description):
key = self.keys[self.get_key_id(key_id)] key = self.keys[self.get_key_id(key_id)]
key.description = description key.description = description
def tag_resource(self, key_id, tags):
key = self.keys[self.get_key_id(key_id)]
key.tags = tags
def list_resource_tags(self, key_id):
key = self.keys[self.get_key_id(key_id)]
return key.tags
def delete_key(self, key_id): def delete_key(self, key_id):
if key_id in self.keys: if key_id in self.keys:
if key_id in self.key_to_aliases: if key_id in self.key_to_aliases:
self.key_to_aliases.pop(key_id) self.key_to_aliases.pop(key_id)
self.tagger.delete_all_tags_for_resource(key_id)
return self.keys.pop(key_id) return self.keys.pop(key_id)
@ -325,6 +320,32 @@ class KmsBackend(BaseBackend):
return plaintext, ciphertext_blob, arn return plaintext, ciphertext_blob, arn
def list_resource_tags(self, key_id):
if key_id in self.keys:
return self.tagger.list_tags_for_resource(key_id)
raise JsonRESTError(
"NotFoundException",
"The request was rejected because the specified entity or resource could not be found.",
)
def tag_resource(self, key_id, tags):
if key_id in self.keys:
self.tagger.tag_resource(key_id, tags)
return {}
raise JsonRESTError(
"NotFoundException",
"The request was rejected because the specified entity or resource could not be found.",
)
def untag_resource(self, key_id, tag_names):
if key_id in self.keys:
self.tagger.untag_resource_using_names(key_id, tag_names)
return {}
raise JsonRESTError(
"NotFoundException",
"The request was rejected because the specified entity or resource could not be found.",
)
kms_backends = {} kms_backends = {}
for region in Session().get_available_regions("kms"): for region in Session().get_available_regions("kms"):

View File

@ -144,17 +144,27 @@ class KmsResponse(BaseResponse):
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
self.kms_backend.tag_resource(key_id, tags) result = self.kms_backend.tag_resource(key_id, tags)
return json.dumps({}) return json.dumps(result)
def untag_resource(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UntagResource.html"""
key_id = self.parameters.get("KeyId")
tag_names = self.parameters.get("TagKeys")
self._validate_cmk_id(key_id)
result = self.kms_backend.untag_resource(key_id, tag_names)
return json.dumps(result)
def list_resource_tags(self): def list_resource_tags(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html"""
key_id = self.parameters.get("KeyId") key_id = self.parameters.get("KeyId")
self._validate_cmk_id(key_id) self._validate_cmk_id(key_id)
tags = self.kms_backend.list_resource_tags(key_id) tags = self.kms_backend.list_resource_tags(key_id)
return json.dumps({"Tags": tags, "NextMarker": None, "Truncated": False}) tags.update({"NextMarker": None, "Truncated": False})
return json.dumps(tags)
def describe_key(self): def describe_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html"""

View File

@ -318,7 +318,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend):
# KMS # KMS
def get_kms_tags(kms_key_id): def get_kms_tags(kms_key_id):
result = [] result = []
for tag in self.kms_backend.list_resource_tags(kms_key_id): for tag in self.kms_backend.list_resource_tags(kms_key_id).get("Tags", []):
result.append({"Key": tag["TagKey"], "Value": tag["TagValue"]}) result.append({"Key": tag["TagKey"], "Value": tag["TagValue"]})
return result return result

View File

@ -1,15 +1,15 @@
from moto.events.models import EventsBackend
from moto.events import mock_events
import json import json
import random import random
import unittest import unittest
import boto3 import boto3
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from moto.core.exceptions import JsonRESTError
from nose.tools import assert_raises from nose.tools import assert_raises
from moto.core import ACCOUNT_ID from moto.core import ACCOUNT_ID
from moto.core.exceptions import JsonRESTError
from moto.events import mock_events
from moto.events.models import EventsBackend
RULES = [ RULES = [
{"Name": "test1", "ScheduleExpression": "rate(5 minutes)"}, {"Name": "test1", "ScheduleExpression": "rate(5 minutes)"},

View File

@ -4,15 +4,17 @@ import base64
import re import re
import boto.kms import boto.kms
import boto3
import six import six
import sure # noqa import sure # noqa
from boto.exception import JSONResponseError from boto.exception import JSONResponseError
from boto.kms.exceptions import AlreadyExistsException, NotFoundException from boto.kms.exceptions import AlreadyExistsException, NotFoundException
from nose.tools import assert_raises from nose.tools import assert_raises
from parameterized import parameterized from parameterized import parameterized
from moto.core.exceptions import JsonRESTError
from moto.kms.models import KmsBackend
from moto.kms.exceptions import NotFoundException as MotoNotFoundException from moto.kms.exceptions import NotFoundException as MotoNotFoundException
from moto import mock_kms_deprecated from moto import mock_kms_deprecated, mock_kms
PLAINTEXT_VECTORS = ( PLAINTEXT_VECTORS = (
(b"some encodeable plaintext",), (b"some encodeable plaintext",),
@ -679,3 +681,77 @@ def test__assert_default_policy():
_assert_default_policy.when.called_with("default").should_not.throw( _assert_default_policy.when.called_with("default").should_not.throw(
MotoNotFoundException MotoNotFoundException
) )
if six.PY2:
sort = sorted
else:
sort = lambda l: sorted(l, key=lambda d: d.keys())
@mock_kms
def test_key_tag_on_create_key_happy():
client = boto3.client("kms", region_name="us-east-1")
tags = [
{"TagKey": "key1", "TagValue": "value1"},
{"TagKey": "key2", "TagValue": "value2"},
]
key = client.create_key(Description="test-key-tagging", Tags=tags)
key_id = key["KeyMetadata"]["KeyId"]
result = client.list_resource_tags(KeyId=key_id)
actual = result.get("Tags", [])
assert sort(tags) == sort(actual)
client.untag_resource(KeyId=key_id, TagKeys=["key1"])
actual = client.list_resource_tags(KeyId=key_id).get("Tags", [])
expected = [{"TagKey": "key2", "TagValue": "value2"}]
assert sort(expected) == sort(actual)
@mock_kms
def test_key_tag_added_happy():
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="test-key-tagging")
key_id = key["KeyMetadata"]["KeyId"]
tags = [
{"TagKey": "key1", "TagValue": "value1"},
{"TagKey": "key2", "TagValue": "value2"},
]
client.tag_resource(KeyId=key_id, Tags=tags)
result = client.list_resource_tags(KeyId=key_id)
actual = result.get("Tags", [])
assert sort(tags) == sort(actual)
client.untag_resource(KeyId=key_id, TagKeys=["key1"])
actual = client.list_resource_tags(KeyId=key_id).get("Tags", [])
expected = [{"TagKey": "key2", "TagValue": "value2"}]
assert sort(expected) == sort(actual)
@mock_kms_deprecated
def test_key_tagging_sad():
b = KmsBackend()
try:
b.tag_resource("unknown", [])
raise "tag_resource should fail if KeyId is not known"
except JsonRESTError:
pass
try:
b.untag_resource("unknown", [])
raise "untag_resource should fail if KeyId is not known"
except JsonRESTError:
pass
try:
b.list_resource_tags("unknown")
raise "list_resource_tags should fail if KeyId is not known"
except JsonRESTError:
pass

View File

@ -102,7 +102,7 @@ def test_deserialize_ciphertext_blob(raw, serialized):
@parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS)) @parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS))
def test_encrypt_decrypt_cycle(encryption_context): def test_encrypt_decrypt_cycle(encryption_context):
plaintext = b"some secret plaintext" plaintext = b"some secret plaintext"
master_key = Key("nop", "nop", "nop", "nop", [], "nop") master_key = Key("nop", "nop", "nop", "nop", "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
ciphertext_blob = encrypt( ciphertext_blob = encrypt(
@ -133,7 +133,7 @@ def test_encrypt_unknown_key_id():
def test_decrypt_invalid_ciphertext_format(): def test_decrypt_invalid_ciphertext_format():
master_key = Key("nop", "nop", "nop", "nop", [], "nop") master_key = Key("nop", "nop", "nop", "nop", "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
with assert_raises(InvalidCiphertextException): with assert_raises(InvalidCiphertextException):
@ -153,7 +153,7 @@ def test_decrypt_unknwown_key_id():
def test_decrypt_invalid_ciphertext(): def test_decrypt_invalid_ciphertext():
master_key = Key("nop", "nop", "nop", "nop", [], "nop") master_key = Key("nop", "nop", "nop", "nop", "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
ciphertext_blob = ( ciphertext_blob = (
master_key.id.encode("utf-8") + b"123456789012" master_key.id.encode("utf-8") + b"123456789012"
@ -171,7 +171,7 @@ def test_decrypt_invalid_ciphertext():
def test_decrypt_invalid_encryption_context(): def test_decrypt_invalid_encryption_context():
plaintext = b"some secret plaintext" plaintext = b"some secret plaintext"
master_key = Key("nop", "nop", "nop", "nop", [], "nop") master_key = Key("nop", "nop", "nop", "nop", "nop")
master_key_map = {master_key.id: master_key} master_key_map = {master_key.id: master_key}
ciphertext_blob = encrypt( ciphertext_blob = encrypt(