Merge pull request #2420 from mattsb42-aws/key_id_validation

Normalize KMS key ID validation
This commit is contained in:
Steve Pulec 2019-09-23 21:33:03 -05:00 committed by GitHub
commit 8f4a05581c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 304 additions and 134 deletions

View File

@ -1,13 +1,16 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import os import os
import boto.kms
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from .utils import decrypt, encrypt, generate_key_id, generate_master_key
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
import boto.kms
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from .utils import decrypt, encrypt, generate_key_id, generate_master_key
class Key(BaseModel): class Key(BaseModel):
def __init__(self, policy, key_usage, description, tags, region): def __init__(self, policy, key_usage, description, tags, region):
@ -18,7 +21,7 @@ class Key(BaseModel):
self.description = description self.description = description
self.enabled = True self.enabled = True
self.region = region self.region = region
self.account_id = "0123456789012" self.account_id = "012345678912"
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date = None self.deletion_date = None
self.tags = tags or {} self.tags = tags or {}
@ -116,13 +119,21 @@ class KmsBackend(BaseBackend):
def list_keys(self): def list_keys(self):
return self.keys.values() return self.keys.values()
def get_key_id(self, key_id): @staticmethod
def get_key_id(key_id):
# Allow use of ARN as well as pure KeyId # Allow use of ARN as well as pure KeyId
return str(key_id).split(r":key/")[1] if r":key/" in str(key_id).lower() else key_id if key_id.startswith("arn:") and ":key/" in key_id:
return key_id.split(":key/")[1]
def get_alias_name(self, alias_name): return key_id
@staticmethod
def get_alias_name(alias_name):
# Allow use of ARN as well as alias name # Allow use of ARN as well as alias name
return str(alias_name).split(r":alias/")[1] if r":alias/" in str(alias_name).lower() else alias_name if alias_name.startswith("arn:") and ":alias/" in alias_name:
return alias_name.split(":alias/")[1]
return alias_name
def any_id_to_key_id(self, key_id): def any_id_to_key_id(self, key_id):
"""Go from any valid key ID to the raw key ID. """Go from any valid key ID to the raw key ID.

View File

@ -11,6 +11,7 @@ from moto.core.responses import BaseResponse
from .models import kms_backends from .models import kms_backends
from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException
ACCOUNT_ID = "012345678912"
reserved_aliases = [ reserved_aliases = [
'alias/aws/ebs', 'alias/aws/ebs',
'alias/aws/s3', 'alias/aws/s3',
@ -35,7 +36,74 @@ class KmsResponse(BaseResponse):
def kms_backend(self): def kms_backend(self):
return kms_backends[self.region] return kms_backends[self.region]
def _display_arn(self, key_id):
if key_id.startswith("arn:"):
return key_id
if key_id.startswith("alias/"):
id_type = ""
else:
id_type = "key/"
return "arn:aws:kms:{region}:{account}:{id_type}{key_id}".format(
region=self.region, account=ACCOUNT_ID, id_type=id_type, key_id=key_id
)
def _validate_cmk_id(self, key_id):
"""Determine whether a CMK ID exists.
- raw key ID
- key ARN
"""
is_arn = key_id.startswith("arn:") and ":key/" in key_id
is_raw_key_id = re.match(r"^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$", key_id, re.IGNORECASE)
if not is_arn and not is_raw_key_id:
raise NotFoundException("Invalid keyId {key_id}".format(key_id=key_id))
cmk_id = self.kms_backend.get_key_id(key_id)
if cmk_id not in self.kms_backend.keys:
raise NotFoundException("Key '{key_id}' does not exist".format(key_id=self._display_arn(key_id)))
def _validate_alias(self, key_id):
"""Determine whether an alias exists.
- alias name
- alias ARN
"""
error = NotFoundException("Alias {key_id} is not found.".format(key_id=self._display_arn(key_id)))
is_arn = key_id.startswith("arn:") and ":alias/" in key_id
is_name = key_id.startswith("alias/")
if not is_arn and not is_name:
raise error
alias_name = self.kms_backend.get_alias_name(key_id)
cmk_id = self.kms_backend.get_key_id_from_alias(alias_name)
if cmk_id is None:
raise error
def _validate_key_id(self, key_id):
"""Determine whether or not a key ID exists.
- raw key ID
- key ARN
- alias name
- alias ARN
"""
is_alias_arn = key_id.startswith("arn:") and ":alias/" in key_id
is_alias_name = key_id.startswith("alias/")
if is_alias_arn or is_alias_name:
self._validate_alias(key_id)
return
self._validate_cmk_id(key_id)
def create_key(self): def create_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html"""
policy = self.parameters.get('Policy') policy = self.parameters.get('Policy')
key_usage = self.parameters.get('KeyUsage') key_usage = self.parameters.get('KeyUsage')
description = self.parameters.get('Description') description = self.parameters.get('Description')
@ -46,20 +114,31 @@ class KmsResponse(BaseResponse):
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())
def update_key_description(self): def update_key_description(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
description = self.parameters.get('Description') description = self.parameters.get('Description')
self._validate_cmk_id(key_id)
self.kms_backend.update_key_description(key_id, description) self.kms_backend.update_key_description(key_id, description)
return json.dumps(None) return json.dumps(None)
def tag_resource(self): def tag_resource(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
tags = self.parameters.get('Tags') tags = self.parameters.get('Tags')
self._validate_cmk_id(key_id)
self.kms_backend.tag_resource(key_id, tags) self.kms_backend.tag_resource(key_id, tags)
return json.dumps({}) return json.dumps({})
def list_resource_tags(self): def list_resource_tags(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
self._validate_cmk_id(key_id)
tags = self.kms_backend.list_resource_tags(key_id) tags = self.kms_backend.list_resource_tags(key_id)
return json.dumps({ return json.dumps({
"Tags": tags, "Tags": tags,
@ -68,17 +147,19 @@ class KmsResponse(BaseResponse):
}) })
def describe_key(self): def describe_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
try:
self._validate_key_id(key_id)
key = self.kms_backend.describe_key( key = self.kms_backend.describe_key(
self.kms_backend.get_key_id(key_id)) self.kms_backend.get_key_id(key_id)
except KeyError: )
headers = dict(self.headers)
headers['status'] = 404
return "{}", headers
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())
def list_keys(self): def list_keys(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html"""
keys = self.kms_backend.list_keys() keys = self.kms_backend.list_keys()
return json.dumps({ return json.dumps({
@ -93,6 +174,7 @@ class KmsResponse(BaseResponse):
}) })
def create_alias(self): def create_alias(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html"""
alias_name = self.parameters['AliasName'] alias_name = self.parameters['AliasName']
target_key_id = self.parameters['TargetKeyId'] target_key_id = self.parameters['TargetKeyId']
@ -118,27 +200,31 @@ class KmsResponse(BaseResponse):
raise AlreadyExistsException('An alias with the name arn:aws:kms:{region}:012345678912:{alias_name} ' raise AlreadyExistsException('An alias with the name arn:aws:kms:{region}:012345678912:{alias_name} '
'already exists'.format(region=self.region, alias_name=alias_name)) 'already exists'.format(region=self.region, alias_name=alias_name))
self._validate_cmk_id(target_key_id)
self.kms_backend.add_alias(target_key_id, alias_name) self.kms_backend.add_alias(target_key_id, alias_name)
return json.dumps(None) return json.dumps(None)
def delete_alias(self): def delete_alias(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html"""
alias_name = self.parameters['AliasName'] alias_name = self.parameters['AliasName']
if not alias_name.startswith('alias/'): if not alias_name.startswith('alias/'):
raise ValidationException('Invalid identifier') raise ValidationException('Invalid identifier')
if not self.kms_backend.alias_exists(alias_name): self._validate_alias(alias_name)
raise NotFoundException('Alias arn:aws:kms:{region}:012345678912:'
'{alias_name} is not found.'.format(region=self.region, alias_name=alias_name))
self.kms_backend.delete_alias(alias_name) self.kms_backend.delete_alias(alias_name)
return json.dumps(None) return json.dumps(None)
def list_aliases(self): def list_aliases(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListAliases.html"""
region = self.region region = self.region
# TODO: The actual API can filter on KeyId.
response_aliases = [ response_aliases = [
{ {
'AliasArn': u'arn:aws:kms:{region}:012345678912:{reserved_alias}'.format(region=region, 'AliasArn': u'arn:aws:kms:{region}:012345678912:{reserved_alias}'.format(region=region,
@ -163,79 +249,76 @@ class KmsResponse(BaseResponse):
}) })
def enable_key_rotation(self): def enable_key_rotation(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
self.kms_backend.enable_key_rotation(key_id) self.kms_backend.enable_key_rotation(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None) return json.dumps(None)
def disable_key_rotation(self): def disable_key_rotation(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
self.kms_backend.disable_key_rotation(key_id) self.kms_backend.disable_key_rotation(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None) return json.dumps(None)
def get_key_rotation_status(self): def get_key_rotation_status(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
rotation_enabled = self.kms_backend.get_key_rotation_status(key_id) rotation_enabled = self.kms_backend.get_key_rotation_status(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps({'KeyRotationEnabled': rotation_enabled}) return json.dumps({'KeyRotationEnabled': rotation_enabled})
def put_key_policy(self): def put_key_policy(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
policy_name = self.parameters.get('PolicyName') policy_name = self.parameters.get('PolicyName')
policy = self.parameters.get('Policy') policy = self.parameters.get('Policy')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
_assert_default_policy(policy_name) _assert_default_policy(policy_name)
try: self._validate_cmk_id(key_id)
self.kms_backend.put_key_policy(key_id, policy) self.kms_backend.put_key_policy(key_id, policy)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None) return json.dumps(None)
def get_key_policy(self): def get_key_policy(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
policy_name = self.parameters.get('PolicyName') policy_name = self.parameters.get('PolicyName')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
_assert_default_policy(policy_name) _assert_default_policy(policy_name)
try: self._validate_cmk_id(key_id)
return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)}) return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)})
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
def list_key_policies(self): def list_key_policies(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
self.kms_backend.describe_key(key_id) self.kms_backend.describe_key(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps({'Truncated': False, 'PolicyNames': ['default']}) return json.dumps({'Truncated': False, 'PolicyNames': ['default']})
def encrypt(self): def encrypt(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html"""
key_id = self.parameters.get("KeyId") key_id = self.parameters.get("KeyId")
encryption_context = self.parameters.get('EncryptionContext', {}) encryption_context = self.parameters.get('EncryptionContext', {})
plaintext = self.parameters.get("Plaintext") plaintext = self.parameters.get("Plaintext")
self._validate_key_id(key_id)
if isinstance(plaintext, six.text_type): if isinstance(plaintext, six.text_type):
plaintext = plaintext.encode('utf-8') plaintext = plaintext.encode('utf-8')
@ -249,6 +332,7 @@ class KmsResponse(BaseResponse):
return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn}) return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn})
def decrypt(self): def decrypt(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html"""
ciphertext_blob = self.parameters.get("CiphertextBlob") ciphertext_blob = self.parameters.get("CiphertextBlob")
encryption_context = self.parameters.get('EncryptionContext', {}) encryption_context = self.parameters.get('EncryptionContext', {})
@ -262,11 +346,14 @@ class KmsResponse(BaseResponse):
return json.dumps({"Plaintext": plaintext_response, 'KeyId': arn}) return json.dumps({"Plaintext": plaintext_response, 'KeyId': arn})
def re_encrypt(self): def re_encrypt(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html"""
ciphertext_blob = self.parameters.get("CiphertextBlob") ciphertext_blob = self.parameters.get("CiphertextBlob")
source_encryption_context = self.parameters.get("SourceEncryptionContext", {}) source_encryption_context = self.parameters.get("SourceEncryptionContext", {})
destination_key_id = self.parameters.get("DestinationKeyId") destination_key_id = self.parameters.get("DestinationKeyId")
destination_encryption_context = self.parameters.get("DestinationEncryptionContext", {}) destination_encryption_context = self.parameters.get("DestinationEncryptionContext", {})
self._validate_cmk_id(destination_key_id)
new_ciphertext_blob, decrypting_arn, encrypting_arn = self.kms_backend.re_encrypt( new_ciphertext_blob, decrypting_arn, encrypting_arn = self.kms_backend.re_encrypt(
ciphertext_blob=ciphertext_blob, ciphertext_blob=ciphertext_blob,
source_encryption_context=source_encryption_context, source_encryption_context=source_encryption_context,
@ -281,52 +368,52 @@ class KmsResponse(BaseResponse):
) )
def disable_key(self): def disable_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
self.kms_backend.disable_key(key_id) self.kms_backend.disable_key(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None) return json.dumps(None)
def enable_key(self): def enable_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
self.kms_backend.enable_key(key_id) self.kms_backend.enable_key(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None) return json.dumps(None)
def cancel_key_deletion(self): def cancel_key_deletion(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
self.kms_backend.cancel_key_deletion(key_id) self.kms_backend.cancel_key_deletion(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps({'KeyId': key_id}) return json.dumps({'KeyId': key_id})
def schedule_key_deletion(self): def schedule_key_deletion(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
if self.parameters.get('PendingWindowInDays') is None: if self.parameters.get('PendingWindowInDays') is None:
pending_window_in_days = 30 pending_window_in_days = 30
else: else:
pending_window_in_days = self.parameters.get('PendingWindowInDays') pending_window_in_days = self.parameters.get('PendingWindowInDays')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
return json.dumps({ return json.dumps({
'KeyId': key_id, 'KeyId': key_id,
'DeletionDate': self.kms_backend.schedule_key_deletion(key_id, pending_window_in_days) 'DeletionDate': self.kms_backend.schedule_key_deletion(key_id, pending_window_in_days)
}) })
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
def generate_data_key(self): def generate_data_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
encryption_context = self.parameters.get('EncryptionContext', {}) encryption_context = self.parameters.get('EncryptionContext', {})
number_of_bytes = self.parameters.get('NumberOfBytes') number_of_bytes = self.parameters.get('NumberOfBytes')
@ -334,15 +421,9 @@ class KmsResponse(BaseResponse):
grant_tokens = self.parameters.get('GrantTokens') grant_tokens = self.parameters.get('GrantTokens')
# Param validation # Param validation
if key_id.startswith('alias'): self._validate_key_id(key_id)
if self.kms_backend.get_key_id_from_alias(key_id) is None:
raise NotFoundException('Alias arn:aws:kms:{region}:012345678912:{alias_name} is not found.'.format(
region=self.region, alias_name=key_id))
else:
if self.kms_backend.get_key_id(key_id) not in self.kms_backend.keys:
raise NotFoundException('Invalid keyId')
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 0): if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1):
raise ValidationException(( raise ValidationException((
"1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed "
"to satisfy constraint: Member must have value less than or " "to satisfy constraint: Member must have value less than or "
@ -357,6 +438,7 @@ class KmsResponse(BaseResponse):
).format(key_spec=key_spec)) ).format(key_spec=key_spec))
if not key_spec and not number_of_bytes: if not key_spec and not number_of_bytes:
raise ValidationException("Please specify either number of bytes or key spec.") raise ValidationException("Please specify either number of bytes or key spec.")
if key_spec and number_of_bytes: if key_spec and number_of_bytes:
raise ValidationException("Please specify either number of bytes or key spec.") raise ValidationException("Please specify either number of bytes or key spec.")
@ -378,14 +460,23 @@ class KmsResponse(BaseResponse):
}) })
def generate_data_key_without_plaintext(self): def generate_data_key_without_plaintext(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html"""
result = json.loads(self.generate_data_key()) result = json.loads(self.generate_data_key())
del result['Plaintext'] del result['Plaintext']
return json.dumps(result) return json.dumps(result)
def generate_random(self): def generate_random(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html"""
number_of_bytes = self.parameters.get("NumberOfBytes") number_of_bytes = self.parameters.get("NumberOfBytes")
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1):
raise ValidationException((
"1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed "
"to satisfy constraint: Member must have value less than or "
"equal to 1024"
).format(number_of_bytes=number_of_bytes))
entropy = os.urandom(number_of_bytes) entropy = os.urandom(number_of_bytes)
response_entropy = base64.b64encode(entropy).decode("utf-8") response_entropy = base64.b64encode(entropy).decode("utf-8")
@ -393,11 +484,6 @@ class KmsResponse(BaseResponse):
return json.dumps({"Plaintext": response_entropy}) return json.dumps({"Plaintext": response_entropy})
def _assert_valid_key_id(key_id):
if not re.match(r'^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$', key_id, re.IGNORECASE):
raise NotFoundException('Invalid keyId')
def _assert_default_policy(policy_name): def _assert_default_policy(policy_name):
if policy_name != 'default': if policy_name != 'default':
raise NotFoundException("No such policy exists") raise NotFoundException("No such policy exists")

View File

@ -1,15 +1,16 @@
# -*- coding: utf-8 -*-
from __future__ import unicode_literals from __future__ import unicode_literals
from datetime import date from datetime import date
from datetime import datetime from datetime import datetime
from dateutil.tz import tzutc from dateutil.tz import tzutc
import base64 import base64
import binascii
import os import os
import re import re
import boto3 import boto3
import boto.kms import boto.kms
import botocore.exceptions import botocore.exceptions
import six
import sure # noqa import sure # noqa
from boto.exception import JSONResponseError from boto.exception import JSONResponseError
from boto.kms.exceptions import AlreadyExistsException, NotFoundException from boto.kms.exceptions import AlreadyExistsException, NotFoundException
@ -23,9 +24,17 @@ from moto import mock_kms, mock_kms_deprecated
PLAINTEXT_VECTORS = ( PLAINTEXT_VECTORS = (
(b"some encodeable plaintext",), (b"some encodeable plaintext",),
(b"some unencodeable plaintext \xec\x8a\xcf\xb6r\xe9\xb5\xeb\xff\xa23\x16",), (b"some unencodeable plaintext \xec\x8a\xcf\xb6r\xe9\xb5\xeb\xff\xa23\x16",),
(u"some unicode characters ø˚∆øˆˆ∆ßçøˆˆçßøˆ¨¥",),
) )
def _get_encoded_value(plaintext):
if isinstance(plaintext, six.binary_type):
return plaintext
return plaintext.encode("utf-8")
@mock_kms @mock_kms
def test_create_key(): def test_create_key():
conn = boto3.client("kms", region_name="us-east-1") conn = boto3.client("kms", region_name="us-east-1")
@ -72,7 +81,21 @@ def test_describe_key_via_alias_not_found():
key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT")
conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"])
conn.describe_key.when.called_with("alias/not-found-alias").should.throw(JSONResponseError) conn.describe_key.when.called_with("alias/not-found-alias").should.throw(NotFoundException)
@parameterized((
("alias/does-not-exist",),
("arn:aws:kms:us-east-1:012345678912:alias/does-not-exist",),
("invalid",),
))
@mock_kms
def test_describe_key_via_alias_invalid_alias(key_id):
client = boto3.client("kms", region_name="us-east-1")
client.create_key(Description="key")
with assert_raises(client.exceptions.NotFoundException):
client.describe_key(KeyId=key_id)
@mock_kms_deprecated @mock_kms_deprecated
@ -90,7 +113,7 @@ def test_describe_key_via_arn():
@mock_kms_deprecated @mock_kms_deprecated
def test_describe_missing_key(): def test_describe_missing_key():
conn = boto.kms.connect_to_region("us-west-2") conn = boto.kms.connect_to_region("us-west-2")
conn.describe_key.when.called_with("not-a-key").should.throw(JSONResponseError) conn.describe_key.when.called_with("not-a-key").should.throw(NotFoundException)
@mock_kms_deprecated @mock_kms_deprecated
@ -201,15 +224,15 @@ def test_boto3_generate_data_key():
@parameterized(PLAINTEXT_VECTORS) @parameterized(PLAINTEXT_VECTORS)
@mock_kms_deprecated @mock_kms
def test_encrypt(plaintext): def test_encrypt(plaintext):
conn = boto.kms.connect_to_region("us-west-2") client = boto3.client("kms", region_name="us-west-2")
key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") key = client.create_key(Description="key")
key_id = key["KeyMetadata"]["KeyId"] key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"] key_arn = key["KeyMetadata"]["Arn"]
response = conn.encrypt(key_id, plaintext) response = client.encrypt(KeyId=key_id, Plaintext=plaintext)
response["CiphertextBlob"].should_not.equal(plaintext) response["CiphertextBlob"].should_not.equal(plaintext)
# CiphertextBlob must NOT be base64-encoded # CiphertextBlob must NOT be base64-encoded
@ -220,27 +243,28 @@ def test_encrypt(plaintext):
@parameterized(PLAINTEXT_VECTORS) @parameterized(PLAINTEXT_VECTORS)
@mock_kms_deprecated @mock_kms
def test_decrypt(plaintext): def test_decrypt(plaintext):
conn = boto.kms.connect_to_region("us-west-2") client = boto3.client("kms", region_name="us-west-2")
key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") key = client.create_key(Description="key")
key_id = key["KeyMetadata"]["KeyId"] key_id = key["KeyMetadata"]["KeyId"]
key_arn = key["KeyMetadata"]["Arn"] key_arn = key["KeyMetadata"]["Arn"]
encrypt_response = conn.encrypt(key_id, plaintext) encrypt_response = client.encrypt(KeyId=key_id, Plaintext=plaintext)
client.create_key(Description="key")
# CiphertextBlob must NOT be base64-encoded # CiphertextBlob must NOT be base64-encoded
with assert_raises(Exception): with assert_raises(Exception):
base64.b64decode(encrypt_response["CiphertextBlob"], validate=True) base64.b64decode(encrypt_response["CiphertextBlob"], validate=True)
decrypt_response = conn.decrypt(encrypt_response["CiphertextBlob"]) decrypt_response = client.decrypt(CiphertextBlob=encrypt_response["CiphertextBlob"])
# Plaintext must NOT be base64-encoded # Plaintext must NOT be base64-encoded
with assert_raises(Exception): with assert_raises(Exception):
base64.b64decode(decrypt_response["Plaintext"], validate=True) base64.b64decode(decrypt_response["Plaintext"], validate=True)
decrypt_response["Plaintext"].should.equal(plaintext) decrypt_response["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response["KeyId"].should.equal(key_arn) decrypt_response["KeyId"].should.equal(key_arn)
@ -493,15 +517,16 @@ def test__create_alias__raises_if_alias_has_colon_character():
ex.status.should.equal(400) ex.status.should.equal(400)
@parameterized((
("alias/my-alias_/",),
("alias/my_alias-/",),
))
@mock_kms_deprecated @mock_kms_deprecated
def test__create_alias__accepted_characters(): def test__create_alias__accepted_characters(alias_name):
kms = boto.connect_kms() kms = boto.connect_kms()
create_resp = kms.create_key() create_resp = kms.create_key()
key_id = create_resp["KeyMetadata"]["KeyId"] key_id = create_resp["KeyMetadata"]["KeyId"]
alias_names_with_accepted_characters = ["alias/my-alias_/", "alias/my_alias-/"]
for alias_name in alias_names_with_accepted_characters:
kms.create_alias(alias_name, key_id) kms.create_alias(alias_name, key_id)
@ -575,14 +600,16 @@ def test__delete_alias__raises_if_alias_is_not_found():
with assert_raises(NotFoundException) as err: with assert_raises(NotFoundException) as err:
kms.delete_alias(alias_name) kms.delete_alias(alias_name)
expected_message_match = r"Alias arn:aws:kms:{region}:[0-9]{{12}}:{alias_name} is not found.".format(
region=region,
alias_name=alias_name
)
ex = err.exception ex = err.exception
ex.body["__type"].should.equal("NotFoundException") ex.body["__type"].should.equal("NotFoundException")
ex.body["message"].should.match( ex.body["message"].should.match(expected_message_match)
r"Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.".format(**locals())
)
ex.box_usage.should.be.none ex.box_usage.should.be.none
ex.error_code.should.be.none ex.error_code.should.be.none
ex.message.should.match(r"Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.".format(**locals())) ex.message.should.match(expected_message_match)
ex.reason.should.equal("Bad Request") ex.reason.should.equal("Bad Request")
ex.request_id.should.be.none ex.request_id.should.be.none
ex.status.should.equal(400) ex.status.should.equal(400)
@ -635,13 +662,19 @@ def test__list_aliases():
len(aliases).should.equal(7) len(aliases).should.equal(7)
@mock_kms_deprecated @parameterized((
def test__assert_valid_key_id(): ("not-a-uuid",),
from moto.kms.responses import _assert_valid_key_id ("alias/DoesNotExist",),
import uuid ("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",),
("d25652e4-d2d2-49f7-929a-671ccda580c6",),
("arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",),
))
@mock_kms
def test_invalid_key_ids(key_id):
client = boto3.client("kms", region_name="us-east-1")
_assert_valid_key_id.when.called_with("not-a-key").should.throw(MotoNotFoundException) with assert_raises(client.exceptions.NotFoundException):
_assert_valid_key_id.when.called_with(str(uuid.uuid4())).should_not.throw(MotoNotFoundException) client.generate_data_key(KeyId=key_id, NumberOfBytes=5)
@mock_kms_deprecated @mock_kms_deprecated
@ -660,7 +693,7 @@ def test_kms_encrypt_boto3(plaintext):
response = client.encrypt(KeyId=key["KeyMetadata"]["KeyId"], Plaintext=plaintext) response = client.encrypt(KeyId=key["KeyMetadata"]["KeyId"], Plaintext=plaintext)
response = client.decrypt(CiphertextBlob=response["CiphertextBlob"]) response = client.decrypt(CiphertextBlob=response["CiphertextBlob"])
response["Plaintext"].should.equal(plaintext) response["Plaintext"].should.equal(_get_encoded_value(plaintext))
@mock_kms @mock_kms
@ -781,6 +814,8 @@ def test_list_resource_tags():
(dict(KeySpec="AES_256"), 32), (dict(KeySpec="AES_256"), 32),
(dict(KeySpec="AES_128"), 16), (dict(KeySpec="AES_128"), 16),
(dict(NumberOfBytes=64), 64), (dict(NumberOfBytes=64), 64),
(dict(NumberOfBytes=1), 1),
(dict(NumberOfBytes=1024), 1024),
)) ))
@mock_kms @mock_kms
def test_generate_data_key_sizes(kwargs, expected_key_length): def test_generate_data_key_sizes(kwargs, expected_key_length):
@ -807,6 +842,7 @@ def test_generate_data_key_decrypt():
(dict(KeySpec="AES_257"),), (dict(KeySpec="AES_257"),),
(dict(KeySpec="AES_128", NumberOfBytes=16),), (dict(KeySpec="AES_128", NumberOfBytes=16),),
(dict(NumberOfBytes=2048),), (dict(NumberOfBytes=2048),),
(dict(NumberOfBytes=0),),
(dict(),), (dict(),),
)) ))
@mock_kms @mock_kms
@ -814,20 +850,42 @@ def test_generate_data_key_invalid_size_params(kwargs):
client = boto3.client("kms", region_name="us-east-1") client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-size") key = client.create_key(Description="generate-data-key-size")
with assert_raises(botocore.exceptions.ClientError) as err: with assert_raises((botocore.exceptions.ClientError, botocore.exceptions.ParamValidationError)) as err:
client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs) client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs)
@parameterized((
("alias/DoesNotExist",),
("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",),
("d25652e4-d2d2-49f7-929a-671ccda580c6",),
("arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",),
))
@mock_kms @mock_kms
def test_generate_data_key_invalid_key(): def test_generate_data_key_invalid_key(key_id):
client = boto3.client("kms", region_name="us-east-1") client = boto3.client("kms", region_name="us-east-1")
key = client.create_key(Description="generate-data-key-size")
with assert_raises(client.exceptions.NotFoundException): with assert_raises(client.exceptions.NotFoundException):
client.generate_data_key(KeyId="alias/randomnonexistantkey", KeySpec="AES_256") client.generate_data_key(KeyId=key_id, KeySpec="AES_256")
with assert_raises(client.exceptions.NotFoundException):
client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"] + "4", KeySpec="AES_256") @parameterized((
("alias/DoesExist", False),
("arn:aws:kms:us-east-1:012345678912:alias/DoesExist", False),
("", True),
("arn:aws:kms:us-east-1:012345678912:key/", True),
))
@mock_kms
def test_generate_data_key_all_valid_key_ids(prefix, append_key_id):
client = boto3.client("kms", region_name="us-east-1")
key = client.create_key()
key_id = key["KeyMetadata"]["KeyId"]
client.create_alias(AliasName="alias/DoesExist", TargetKeyId=key_id)
target_id = prefix
if append_key_id:
target_id += key_id
client.generate_data_key(KeyId=key_id, NumberOfBytes=32)
@mock_kms @mock_kms
@ -876,14 +934,14 @@ def test_re_encrypt_decrypt(plaintext):
CiphertextBlob=encrypt_response["CiphertextBlob"], CiphertextBlob=encrypt_response["CiphertextBlob"],
EncryptionContext={"encryption": "context"}, EncryptionContext={"encryption": "context"},
) )
decrypt_response_1["Plaintext"].should.equal(plaintext) decrypt_response_1["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response_1["KeyId"].should.equal(key_1_arn) decrypt_response_1["KeyId"].should.equal(key_1_arn)
decrypt_response_2 = client.decrypt( decrypt_response_2 = client.decrypt(
CiphertextBlob=re_encrypt_response["CiphertextBlob"], CiphertextBlob=re_encrypt_response["CiphertextBlob"],
EncryptionContext={"another": "context"}, EncryptionContext={"another": "context"},
) )
decrypt_response_2["Plaintext"].should.equal(plaintext) decrypt_response_2["Plaintext"].should.equal(_get_encoded_value(plaintext))
decrypt_response_2["KeyId"].should.equal(key_2_arn) decrypt_response_2["KeyId"].should.equal(key_2_arn)
decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"]) decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"])
@ -904,11 +962,11 @@ def test_re_encrypt_to_invalid_destination():
with assert_raises(client.exceptions.NotFoundException): with assert_raises(client.exceptions.NotFoundException):
client.re_encrypt( client.re_encrypt(
CiphertextBlob=encrypt_response["CiphertextBlob"], CiphertextBlob=encrypt_response["CiphertextBlob"],
DestinationKeyId="8327948729348", DestinationKeyId="alias/DoesNotExist",
) )
@parameterized(((12,), (44,), (91,))) @parameterized(((12,), (44,), (91,), (1,), (1024,)))
@mock_kms @mock_kms
def test_generate_random(number_of_bytes): def test_generate_random(number_of_bytes):
client = boto3.client("kms", region_name="us-west-2") client = boto3.client("kms", region_name="us-west-2")
@ -923,6 +981,21 @@ def test_generate_random(number_of_bytes):
len(response["Plaintext"]).should.equal(number_of_bytes) len(response["Plaintext"]).should.equal(number_of_bytes)
@parameterized((
(2048, botocore.exceptions.ClientError),
(1025, botocore.exceptions.ClientError),
(0, botocore.exceptions.ParamValidationError),
(-1, botocore.exceptions.ParamValidationError),
(-1024, botocore.exceptions.ParamValidationError)
))
@mock_kms
def test_generate_random_invalid_number_of_bytes(number_of_bytes, error_type):
client = boto3.client("kms", region_name="us-west-2")
with assert_raises(error_type):
client.generate_random(NumberOfBytes=number_of_bytes)
@mock_kms @mock_kms
def test_enable_key_rotation_key_not_found(): def test_enable_key_rotation_key_not_found():
client = boto3.client("kms", region_name="us-east-1") client = boto3.client("kms", region_name="us-east-1")