add proper KMS encrypt, decrypt, and generate_data_key functionality and tests
This commit is contained in:
parent
3fe8afaa60
commit
98581b9196
@ -4,13 +4,12 @@ import os
|
|||||||
import boto.kms
|
import boto.kms
|
||||||
from moto.core import BaseBackend, BaseModel
|
from moto.core import BaseBackend, BaseModel
|
||||||
from moto.core.utils import iso_8601_datetime_without_milliseconds
|
from moto.core.utils import iso_8601_datetime_without_milliseconds
|
||||||
from .utils import generate_key_id, generate_master_key
|
from .utils import decrypt, encrypt, generate_key_id, generate_master_key
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
class Key(BaseModel):
|
class Key(BaseModel):
|
||||||
|
|
||||||
def __init__(self, policy, key_usage, description, tags, region):
|
def __init__(self, policy, key_usage, description, tags, region):
|
||||||
self.id = generate_key_id()
|
self.id = generate_key_id()
|
||||||
self.policy = policy
|
self.policy = policy
|
||||||
@ -46,8 +45,8 @@ class Key(BaseModel):
|
|||||||
"KeyState": self.key_state,
|
"KeyState": self.key_state,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if self.key_state == 'PendingDeletion':
|
if self.key_state == "PendingDeletion":
|
||||||
key_dict['KeyMetadata']['DeletionDate'] = iso_8601_datetime_without_milliseconds(self.deletion_date)
|
key_dict["KeyMetadata"]["DeletionDate"] = iso_8601_datetime_without_milliseconds(self.deletion_date)
|
||||||
return key_dict
|
return key_dict
|
||||||
|
|
||||||
def delete(self, region_name):
|
def delete(self, region_name):
|
||||||
@ -56,28 +55,28 @@ class Key(BaseModel):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def create_from_cloudformation_json(self, resource_name, cloudformation_json, region_name):
|
def create_from_cloudformation_json(self, resource_name, cloudformation_json, region_name):
|
||||||
kms_backend = kms_backends[region_name]
|
kms_backend = kms_backends[region_name]
|
||||||
properties = cloudformation_json['Properties']
|
properties = cloudformation_json["Properties"]
|
||||||
|
|
||||||
key = kms_backend.create_key(
|
key = kms_backend.create_key(
|
||||||
policy=properties['KeyPolicy'],
|
policy=properties["KeyPolicy"],
|
||||||
key_usage='ENCRYPT_DECRYPT',
|
key_usage="ENCRYPT_DECRYPT",
|
||||||
description=properties['Description'],
|
description=properties["Description"],
|
||||||
tags=properties.get('Tags'),
|
tags=properties.get("Tags"),
|
||||||
region=region_name,
|
region=region_name,
|
||||||
)
|
)
|
||||||
key.key_rotation_status = properties['EnableKeyRotation']
|
key.key_rotation_status = properties["EnableKeyRotation"]
|
||||||
key.enabled = properties['Enabled']
|
key.enabled = properties["Enabled"]
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def get_cfn_attribute(self, attribute_name):
|
def get_cfn_attribute(self, attribute_name):
|
||||||
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
|
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
|
||||||
if attribute_name == 'Arn':
|
|
||||||
|
if attribute_name == "Arn":
|
||||||
return self.arn
|
return self.arn
|
||||||
raise UnformattedGetAttTemplateException()
|
raise UnformattedGetAttTemplateException()
|
||||||
|
|
||||||
|
|
||||||
class KmsBackend(BaseBackend):
|
class KmsBackend(BaseBackend):
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.keys = {}
|
self.keys = {}
|
||||||
self.key_to_aliases = defaultdict(set)
|
self.key_to_aliases = defaultdict(set)
|
||||||
@ -110,8 +109,8 @@ class KmsBackend(BaseBackend):
|
|||||||
# allow the different methods (alias, ARN :key/, keyId, ARN alias) to
|
# allow the different methods (alias, ARN :key/, keyId, ARN alias) to
|
||||||
# describe key not just KeyId
|
# describe key not just KeyId
|
||||||
key_id = self.get_key_id(key_id)
|
key_id = self.get_key_id(key_id)
|
||||||
if r'alias/' in str(key_id).lower():
|
if r"alias/" in str(key_id).lower():
|
||||||
key_id = self.get_key_id_from_alias(key_id.split('alias/')[1])
|
key_id = self.get_key_id_from_alias(key_id.split("alias/")[1])
|
||||||
return self.keys[self.get_key_id(key_id)]
|
return self.keys[self.get_key_id(key_id)]
|
||||||
|
|
||||||
def list_keys(self):
|
def list_keys(self):
|
||||||
@ -119,7 +118,26 @@ class KmsBackend(BaseBackend):
|
|||||||
|
|
||||||
def get_key_id(self, key_id):
|
def get_key_id(self, key_id):
|
||||||
# Allow use of ARN as well as pure KeyId
|
# Allow use of ARN as well as pure KeyId
|
||||||
return str(key_id).split(r':key/')[1] if r':key/' in str(key_id).lower() else key_id
|
return str(key_id).split(r":key/")[1] if r":key/" in str(key_id).lower() else key_id
|
||||||
|
|
||||||
|
def get_alias_name(self, alias_name):
|
||||||
|
# Allow use of ARN as well as alias name
|
||||||
|
return str(alias_name).split(r":alias/")[1] if r":alias/" in str(alias_name).lower() else alias_name
|
||||||
|
|
||||||
|
def any_id_to_key_id(self, key_id):
|
||||||
|
"""Go from any valid key ID to the raw key ID.
|
||||||
|
|
||||||
|
Acceptable inputs:
|
||||||
|
- raw key ID
|
||||||
|
- key ARN
|
||||||
|
- alias name
|
||||||
|
- alias ARN
|
||||||
|
"""
|
||||||
|
key_id = self.get_alias_name(key_id)
|
||||||
|
key_id = self.get_key_id(key_id)
|
||||||
|
if key_id.startswith("alias/"):
|
||||||
|
key_id = self.get_key_id_from_alias(key_id)
|
||||||
|
return key_id
|
||||||
|
|
||||||
def alias_exists(self, alias_name):
|
def alias_exists(self, alias_name):
|
||||||
for aliases in self.key_to_aliases.values():
|
for aliases in self.key_to_aliases.values():
|
||||||
@ -163,37 +181,56 @@ class KmsBackend(BaseBackend):
|
|||||||
|
|
||||||
def disable_key(self, key_id):
|
def disable_key(self, key_id):
|
||||||
self.keys[key_id].enabled = False
|
self.keys[key_id].enabled = False
|
||||||
self.keys[key_id].key_state = 'Disabled'
|
self.keys[key_id].key_state = "Disabled"
|
||||||
|
|
||||||
def enable_key(self, key_id):
|
def enable_key(self, key_id):
|
||||||
self.keys[key_id].enabled = True
|
self.keys[key_id].enabled = True
|
||||||
self.keys[key_id].key_state = 'Enabled'
|
self.keys[key_id].key_state = "Enabled"
|
||||||
|
|
||||||
def cancel_key_deletion(self, key_id):
|
def cancel_key_deletion(self, key_id):
|
||||||
self.keys[key_id].key_state = 'Disabled'
|
self.keys[key_id].key_state = "Disabled"
|
||||||
self.keys[key_id].deletion_date = None
|
self.keys[key_id].deletion_date = None
|
||||||
|
|
||||||
def schedule_key_deletion(self, key_id, pending_window_in_days):
|
def schedule_key_deletion(self, key_id, pending_window_in_days):
|
||||||
if 7 <= pending_window_in_days <= 30:
|
if 7 <= pending_window_in_days <= 30:
|
||||||
self.keys[key_id].enabled = False
|
self.keys[key_id].enabled = False
|
||||||
self.keys[key_id].key_state = 'PendingDeletion'
|
self.keys[key_id].key_state = "PendingDeletion"
|
||||||
self.keys[key_id].deletion_date = datetime.now() + timedelta(days=pending_window_in_days)
|
self.keys[key_id].deletion_date = datetime.now() + timedelta(days=pending_window_in_days)
|
||||||
return iso_8601_datetime_without_milliseconds(self.keys[key_id].deletion_date)
|
return iso_8601_datetime_without_milliseconds(self.keys[key_id].deletion_date)
|
||||||
|
|
||||||
|
def encrypt(self, key_id, plaintext, encryption_context):
|
||||||
|
key_id = self.any_id_to_key_id(key_id)
|
||||||
|
|
||||||
|
ciphertext_blob = encrypt(
|
||||||
|
master_keys=self.keys, key_id=key_id, plaintext=plaintext, encryption_context=encryption_context
|
||||||
|
)
|
||||||
|
arn = self.keys[key_id].arn
|
||||||
|
return ciphertext_blob, arn
|
||||||
|
|
||||||
|
def decrypt(self, ciphertext_blob, encryption_context):
|
||||||
|
plaintext, key_id = decrypt(
|
||||||
|
master_keys=self.keys, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context
|
||||||
|
)
|
||||||
|
arn = self.keys[key_id].arn
|
||||||
|
return plaintext, arn
|
||||||
|
|
||||||
def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens):
|
def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens):
|
||||||
key = self.keys[self.get_key_id(key_id)]
|
key_id = self.any_id_to_key_id(key_id)
|
||||||
|
|
||||||
if key_spec:
|
if key_spec:
|
||||||
if key_spec == 'AES_128':
|
# Note: Actual validation of key_spec is done in kms.responses
|
||||||
bytes = 16
|
if key_spec == "AES_128":
|
||||||
|
plaintext_len = 16
|
||||||
else:
|
else:
|
||||||
bytes = 32
|
plaintext_len = 32
|
||||||
else:
|
else:
|
||||||
bytes = number_of_bytes
|
plaintext_len = number_of_bytes
|
||||||
|
|
||||||
plaintext = os.urandom(bytes)
|
plaintext = os.urandom(plaintext_len)
|
||||||
|
|
||||||
return plaintext, key.arn
|
ciphertext_blob, arn = self.encrypt(key_id=key_id, plaintext=plaintext, encryption_context=encryption_context)
|
||||||
|
|
||||||
|
return plaintext, ciphertext_blob, arn
|
||||||
|
|
||||||
|
|
||||||
kms_backends = {}
|
kms_backends = {}
|
||||||
|
@ -8,6 +8,7 @@ import six
|
|||||||
from moto.core.responses import BaseResponse
|
from moto.core.responses import BaseResponse
|
||||||
from .models import kms_backends
|
from .models import kms_backends
|
||||||
from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException
|
from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException
|
||||||
|
from .utils import decrypt, encrypt
|
||||||
|
|
||||||
reserved_aliases = [
|
reserved_aliases = [
|
||||||
'alias/aws/ebs',
|
'alias/aws/ebs',
|
||||||
@ -21,7 +22,13 @@ class KmsResponse(BaseResponse):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def parameters(self):
|
def parameters(self):
|
||||||
return json.loads(self.body)
|
params = json.loads(self.body)
|
||||||
|
|
||||||
|
for key in ("Plaintext", "CiphertextBlob"):
|
||||||
|
if key in params:
|
||||||
|
params[key] = base64.b64decode(params[key].encode("utf-8"))
|
||||||
|
|
||||||
|
return params
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def kms_backend(self):
|
def kms_backend(self):
|
||||||
@ -224,24 +231,34 @@ class KmsResponse(BaseResponse):
|
|||||||
return json.dumps({'Truncated': False, 'PolicyNames': ['default']})
|
return json.dumps({'Truncated': False, 'PolicyNames': ['default']})
|
||||||
|
|
||||||
def encrypt(self):
|
def encrypt(self):
|
||||||
"""
|
key_id = self.parameters.get("KeyId")
|
||||||
We perform no encryption, we just encode the value as base64 and then
|
encryption_context = self.parameters.get('EncryptionContext', {})
|
||||||
decode it in decrypt().
|
plaintext = self.parameters.get("Plaintext")
|
||||||
"""
|
|
||||||
value = self.parameters.get("Plaintext")
|
if isinstance(plaintext, six.text_type):
|
||||||
if isinstance(value, six.text_type):
|
plaintext = plaintext.encode('utf-8')
|
||||||
value = value.encode('utf-8')
|
|
||||||
return json.dumps({"CiphertextBlob": base64.b64encode(value).decode("utf-8"), 'KeyId': 'key_id'})
|
ciphertext_blob, arn = self.kms_backend.encrypt(
|
||||||
|
key_id=key_id,
|
||||||
|
plaintext=plaintext,
|
||||||
|
encryption_context=encryption_context,
|
||||||
|
)
|
||||||
|
ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8")
|
||||||
|
|
||||||
|
return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn})
|
||||||
|
|
||||||
def decrypt(self):
|
def decrypt(self):
|
||||||
# TODO refuse decode if EncryptionContext is not the same as when it was encrypted / generated
|
ciphertext_blob = self.parameters.get("CiphertextBlob")
|
||||||
|
encryption_context = self.parameters.get('EncryptionContext', {})
|
||||||
|
|
||||||
value = self.parameters.get("CiphertextBlob")
|
plaintext, arn = self.kms_backend.decrypt(
|
||||||
try:
|
ciphertext_blob=ciphertext_blob,
|
||||||
return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8"), 'KeyId': 'key_id'})
|
encryption_context=encryption_context,
|
||||||
except UnicodeDecodeError:
|
)
|
||||||
# Generate data key will produce random bytes which when decrypted is still returned as base64
|
|
||||||
return json.dumps({"Plaintext": value})
|
plaintext_response = base64.b64encode(plaintext).decode("utf-8")
|
||||||
|
|
||||||
|
return json.dumps({"Plaintext": plaintext_response, 'KeyId': arn})
|
||||||
|
|
||||||
def disable_key(self):
|
def disable_key(self):
|
||||||
key_id = self.parameters.get('KeyId')
|
key_id = self.parameters.get('KeyId')
|
||||||
@ -291,7 +308,7 @@ class KmsResponse(BaseResponse):
|
|||||||
|
|
||||||
def generate_data_key(self):
|
def generate_data_key(self):
|
||||||
key_id = self.parameters.get('KeyId')
|
key_id = self.parameters.get('KeyId')
|
||||||
encryption_context = self.parameters.get('EncryptionContext')
|
encryption_context = self.parameters.get('EncryptionContext', {})
|
||||||
number_of_bytes = self.parameters.get('NumberOfBytes')
|
number_of_bytes = self.parameters.get('NumberOfBytes')
|
||||||
key_spec = self.parameters.get('KeySpec')
|
key_spec = self.parameters.get('KeySpec')
|
||||||
grant_tokens = self.parameters.get('GrantTokens')
|
grant_tokens = self.parameters.get('GrantTokens')
|
||||||
@ -306,27 +323,39 @@ class KmsResponse(BaseResponse):
|
|||||||
raise NotFoundException('Invalid keyId')
|
raise NotFoundException('Invalid keyId')
|
||||||
|
|
||||||
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 0):
|
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 0):
|
||||||
raise ValidationException("1 validation error detected: Value '2048' at 'numberOfBytes' failed "
|
raise ValidationException((
|
||||||
|
"1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed "
|
||||||
"to satisfy constraint: Member must have value less than or "
|
"to satisfy constraint: Member must have value less than or "
|
||||||
"equal to 1024")
|
"equal to 1024"
|
||||||
|
).format(number_of_bytes=number_of_bytes)
|
||||||
|
)
|
||||||
|
|
||||||
if key_spec and key_spec not in ('AES_256', 'AES_128'):
|
if key_spec and key_spec not in ('AES_256', 'AES_128'):
|
||||||
raise ValidationException("1 validation error detected: Value 'AES_257' at 'keySpec' failed "
|
raise ValidationException((
|
||||||
|
"1 validation error detected: Value '{key_spec}' at 'keySpec' failed "
|
||||||
"to satisfy constraint: Member must satisfy enum value set: "
|
"to satisfy constraint: Member must satisfy enum value set: "
|
||||||
"[AES_256, AES_128]")
|
"[AES_256, AES_128]"
|
||||||
|
).format(key_spec=key_spec)
|
||||||
|
)
|
||||||
if not key_spec and not number_of_bytes:
|
if not key_spec and not number_of_bytes:
|
||||||
raise ValidationException("Please specify either number of bytes or key spec.")
|
raise ValidationException("Please specify either number of bytes or key spec.")
|
||||||
if key_spec and number_of_bytes:
|
if key_spec and number_of_bytes:
|
||||||
raise ValidationException("Please specify either number of bytes or key spec.")
|
raise ValidationException("Please specify either number of bytes or key spec.")
|
||||||
|
|
||||||
plaintext, key_arn = self.kms_backend.generate_data_key(key_id, encryption_context,
|
plaintext, ciphertext_blob, key_arn = self.kms_backend.generate_data_key(
|
||||||
number_of_bytes, key_spec, grant_tokens)
|
key_id=key_id,
|
||||||
|
encryption_context=encryption_context,
|
||||||
|
number_of_bytes=number_of_bytes,
|
||||||
|
key_spec=key_spec,
|
||||||
|
grant_tokens=grant_tokens
|
||||||
|
)
|
||||||
|
|
||||||
plaintext = base64.b64encode(plaintext).decode()
|
plaintext_response = base64.b64encode(plaintext).decode("utf-8")
|
||||||
|
ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8")
|
||||||
|
|
||||||
return json.dumps({
|
return json.dumps({
|
||||||
'CiphertextBlob': plaintext,
|
'CiphertextBlob': ciphertext_blob_response,
|
||||||
'Plaintext': plaintext,
|
'Plaintext': plaintext_response,
|
||||||
'KeyId': key_arn # not alias
|
'KeyId': key_arn # not alias
|
||||||
})
|
})
|
||||||
|
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,6 @@
|
|||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import sure # noqa
|
||||||
from nose.tools import assert_raises
|
from nose.tools import assert_raises
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
|
|
||||||
@ -104,26 +105,23 @@ def test_encrypt_decrypt_cycle(encryption_context):
|
|||||||
|
|
||||||
|
|
||||||
def test_encrypt_unknown_key_id():
|
def test_encrypt_unknown_key_id():
|
||||||
assert_raises(
|
with assert_raises(NotFoundException):
|
||||||
NotFoundException, encrypt, master_keys={}, key_id="anything", plaintext=b"secrets", encryption_context={}
|
encrypt(master_keys={}, key_id="anything", plaintext=b"secrets", encryption_context={})
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_decrypt_invalid_ciphertext_format():
|
def test_decrypt_invalid_ciphertext_format():
|
||||||
master_key = Key("nop", "nop", "nop", [], "nop")
|
master_key = Key("nop", "nop", "nop", [], "nop")
|
||||||
master_key_map = {master_key.id: master_key}
|
master_key_map = {master_key.id: master_key}
|
||||||
|
|
||||||
assert_raises(
|
with assert_raises(InvalidCiphertextException):
|
||||||
InvalidCiphertextException, decrypt, master_keys=master_key_map, ciphertext_blob=b"", encryption_context={}
|
decrypt(master_keys=master_key_map, ciphertext_blob=b"", encryption_context={})
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_decrypt_unknwown_key_id():
|
def test_decrypt_unknwown_key_id():
|
||||||
ciphertext_blob = b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext"
|
ciphertext_blob = b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext"
|
||||||
|
|
||||||
assert_raises(
|
with assert_raises(AccessDeniedException):
|
||||||
AccessDeniedException, decrypt, master_keys={}, ciphertext_blob=ciphertext_blob, encryption_context={}
|
decrypt(master_keys={}, ciphertext_blob=ciphertext_blob, encryption_context={})
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_decrypt_invalid_ciphertext():
|
def test_decrypt_invalid_ciphertext():
|
||||||
@ -131,9 +129,8 @@ def test_decrypt_invalid_ciphertext():
|
|||||||
master_key_map = {master_key.id: master_key}
|
master_key_map = {master_key.id: master_key}
|
||||||
ciphertext_blob = master_key.id.encode("utf-8") + b"123456789012" b"1234567890123456" b"some ciphertext"
|
ciphertext_blob = master_key.id.encode("utf-8") + b"123456789012" b"1234567890123456" b"some ciphertext"
|
||||||
|
|
||||||
assert_raises(
|
with assert_raises(InvalidCiphertextException):
|
||||||
InvalidCiphertextException,
|
decrypt(
|
||||||
decrypt,
|
|
||||||
master_keys=master_key_map,
|
master_keys=master_key_map,
|
||||||
ciphertext_blob=ciphertext_blob,
|
ciphertext_blob=ciphertext_blob,
|
||||||
encryption_context={},
|
encryption_context={},
|
||||||
@ -152,9 +149,8 @@ def test_decrypt_invalid_encryption_context():
|
|||||||
encryption_context={"some": "encryption", "context": "here"},
|
encryption_context={"some": "encryption", "context": "here"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert_raises(
|
with assert_raises(InvalidCiphertextException):
|
||||||
InvalidCiphertextException,
|
decrypt(
|
||||||
decrypt,
|
|
||||||
master_keys=master_key_map,
|
master_keys=master_key_map,
|
||||||
ciphertext_blob=ciphertext_blob,
|
ciphertext_blob=ciphertext_blob,
|
||||||
encryption_context={},
|
encryption_context={},
|
||||||
|
Loading…
Reference in New Issue
Block a user