diff --git a/IMPLEMENTATION_COVERAGE.md b/IMPLEMENTATION_COVERAGE.md index d149b0dd8..1ad96aeb4 100644 --- a/IMPLEMENTATION_COVERAGE.md +++ b/IMPLEMENTATION_COVERAGE.md @@ -1237,7 +1237,7 @@ - [ ] delete_identities - [ ] delete_identity_pool - [ ] describe_identity -- [ ] describe_identity_pool +- [X] describe_identity_pool - [X] get_credentials_for_identity - [X] get_id - [ ] get_identity_pool_roles @@ -3801,14 +3801,14 @@ - [ ] update_stream ## kms -41% implemented +54% implemented - [X] cancel_key_deletion - [ ] connect_custom_key_store - [ ] create_alias - [ ] create_custom_key_store - [ ] create_grant - [X] create_key -- [ ] decrypt +- [X] decrypt - [X] delete_alias - [ ] delete_custom_key_store - [ ] delete_imported_key_material @@ -3819,10 +3819,10 @@ - [ ] disconnect_custom_key_store - [X] enable_key - [X] enable_key_rotation -- [ ] encrypt +- [X] encrypt - [X] generate_data_key -- [ ] generate_data_key_without_plaintext -- [ ] generate_random +- [X] generate_data_key_without_plaintext +- [X] generate_random - [X] get_key_policy - [X] get_key_rotation_status - [ ] get_parameters_for_import @@ -3834,7 +3834,7 @@ - [X] list_resource_tags - [ ] list_retirable_grants - [X] put_key_policy -- [ ] re_encrypt +- [X] re_encrypt - [ ] retire_grant - [ ] revoke_grant - [X] schedule_key_deletion @@ -6050,24 +6050,24 @@ ## stepfunctions 0% implemented - [ ] create_activity -- [ ] create_state_machine +- [X] create_state_machine - [ ] delete_activity -- [ ] delete_state_machine +- [X] delete_state_machine - [ ] describe_activity -- [ ] describe_execution -- [ ] describe_state_machine -- [ ] describe_state_machine_for_execution +- [X] describe_execution +- [X] describe_state_machine +- [x] describe_state_machine_for_execution - [ ] get_activity_task - [ ] get_execution_history - [ ] list_activities -- [ ] list_executions -- [ ] list_state_machines -- [ ] list_tags_for_resource +- [X] list_executions +- [X] list_state_machines +- [X] list_tags_for_resource - [ ] send_task_failure - [ ] send_task_heartbeat - [ ] send_task_success -- [ ] start_execution -- [ ] stop_execution +- [X] start_execution +- [X] stop_execution - [ ] tag_resource - [ ] untag_resource - [ ] update_state_machine diff --git a/docs/index.rst b/docs/index.rst index 4811fb797..6311597fe 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -94,6 +94,8 @@ Currently implemented Services: +---------------------------+-----------------------+------------------------------------+ | SES | @mock_ses | all endpoints done | +---------------------------+-----------------------+------------------------------------+ +| SFN | @mock_stepfunctions | basic endpoints done | ++---------------------------+-----------------------+------------------------------------+ | SNS | @mock_sns | all endpoints done | +---------------------------+-----------------------+------------------------------------+ | SQS | @mock_sqs | core endpoints done | diff --git a/moto/__init__.py b/moto/__init__.py index 8594cedd2..f82a411cf 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -42,6 +42,7 @@ from .ses import mock_ses, mock_ses_deprecated # flake8: noqa from .secretsmanager import mock_secretsmanager # flake8: noqa from .sns import mock_sns, mock_sns_deprecated # flake8: noqa from .sqs import mock_sqs, mock_sqs_deprecated # flake8: noqa +from .stepfunctions import mock_stepfunctions # flake8: noqa from .sts import mock_sts, mock_sts_deprecated # flake8: noqa from .ssm import mock_ssm # flake8: noqa from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index 6be062d7f..9d6305ef9 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -298,7 +298,7 @@ class Stage(BaseModel, dict): class ApiKey(BaseModel, dict): def __init__(self, name=None, description=None, enabled=True, - generateDistinctId=False, value=None, stageKeys=None, customerId=None): + generateDistinctId=False, value=None, stageKeys=None, tags=None, customerId=None): super(ApiKey, self).__init__() self['id'] = create_id() self['value'] = value if value else ''.join(random.sample(string.ascii_letters + string.digits, 40)) @@ -308,6 +308,7 @@ class ApiKey(BaseModel, dict): self['enabled'] = enabled self['createdDate'] = self['lastUpdatedDate'] = int(time.time()) self['stageKeys'] = stageKeys + self['tags'] = tags def update_operations(self, patch_operations): for op in patch_operations: @@ -407,10 +408,16 @@ class RestAPI(BaseModel): stage_url_upper = STAGE_URL.format(api_id=self.id.upper(), region_name=self.region_name, stage_name=stage_name) - responses.add_callback(responses.GET, stage_url_lower, - callback=self.resource_callback) - responses.add_callback(responses.GET, stage_url_upper, - callback=self.resource_callback) + for url in [stage_url_lower, stage_url_upper]: + responses._default_mock._matches.insert(0, + responses.CallbackResponse( + url=url, + method=responses.GET, + callback=self.resource_callback, + content_type="text/plain", + match_querystring=False, + ) + ) def create_stage(self, name, deployment_id, variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None): if variables is None: diff --git a/moto/backends.py b/moto/backends.py index 6ea85093d..8a20697c2 100644 --- a/moto/backends.py +++ b/moto/backends.py @@ -40,6 +40,7 @@ from moto.secretsmanager import secretsmanager_backends from moto.sns import sns_backends from moto.sqs import sqs_backends from moto.ssm import ssm_backends +from moto.stepfunctions import stepfunction_backends from moto.sts import sts_backends from moto.swf import swf_backends from moto.xray import xray_backends @@ -91,6 +92,7 @@ BACKENDS = { 'sns': sns_backends, 'sqs': sqs_backends, 'ssm': ssm_backends, + 'stepfunctions': stepfunction_backends, 'sts': sts_backends, 'swf': swf_backends, 'route53': route53_backends, diff --git a/moto/cognitoidentity/exceptions.py b/moto/cognitoidentity/exceptions.py new file mode 100644 index 000000000..ec22f3b42 --- /dev/null +++ b/moto/cognitoidentity/exceptions.py @@ -0,0 +1,15 @@ +from __future__ import unicode_literals + +import json + +from werkzeug.exceptions import BadRequest + + +class ResourceNotFoundError(BadRequest): + + def __init__(self, message): + super(ResourceNotFoundError, self).__init__() + self.description = json.dumps({ + "message": message, + '__type': 'ResourceNotFoundException', + }) diff --git a/moto/cognitoidentity/models.py b/moto/cognitoidentity/models.py index c916b7f62..6f752ab69 100644 --- a/moto/cognitoidentity/models.py +++ b/moto/cognitoidentity/models.py @@ -8,7 +8,7 @@ import boto.cognito.identity from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_with_milliseconds - +from .exceptions import ResourceNotFoundError from .utils import get_random_identity_id @@ -39,10 +39,29 @@ class CognitoIdentityBackend(BaseBackend): self.__dict__ = {} self.__init__(region) - def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities, - supported_login_providers, developer_provider_name, open_id_connect_provider_arns, - cognito_identity_providers, saml_provider_arns): + def describe_identity_pool(self, identity_pool_id): + identity_pool = self.identity_pools.get(identity_pool_id, None) + if not identity_pool: + raise ResourceNotFoundError(identity_pool) + + response = json.dumps({ + 'AllowUnauthenticatedIdentities': identity_pool.allow_unauthenticated_identities, + 'CognitoIdentityProviders': identity_pool.cognito_identity_providers, + 'DeveloperProviderName': identity_pool.developer_provider_name, + 'IdentityPoolId': identity_pool.identity_pool_id, + 'IdentityPoolName': identity_pool.identity_pool_name, + 'IdentityPoolTags': {}, + 'OpenIdConnectProviderARNs': identity_pool.open_id_connect_provider_arns, + 'SamlProviderARNs': identity_pool.saml_provider_arns, + 'SupportedLoginProviders': identity_pool.supported_login_providers + }) + + return response + + def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities, + supported_login_providers, developer_provider_name, open_id_connect_provider_arns, + cognito_identity_providers, saml_provider_arns): new_identity = CognitoIdentity(self.region, identity_pool_name, allow_unauthenticated_identities=allow_unauthenticated_identities, supported_login_providers=supported_login_providers, @@ -77,12 +96,12 @@ class CognitoIdentityBackend(BaseBackend): response = json.dumps( { "Credentials": - { - "AccessKeyId": "TESTACCESSKEY12345", - "Expiration": expiration_str, - "SecretKey": "ABCSECRETKEY", - "SessionToken": "ABC12345" - }, + { + "AccessKeyId": "TESTACCESSKEY12345", + "Expiration": expiration_str, + "SecretKey": "ABCSECRETKEY", + "SessionToken": "ABC12345" + }, "IdentityId": identity_id }) return response diff --git a/moto/cognitoidentity/responses.py b/moto/cognitoidentity/responses.py index 33faaa300..709fdb40a 100644 --- a/moto/cognitoidentity/responses.py +++ b/moto/cognitoidentity/responses.py @@ -1,7 +1,6 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse - from .models import cognitoidentity_backends from .utils import get_random_identity_id @@ -16,6 +15,7 @@ class CognitoIdentityResponse(BaseResponse): open_id_connect_provider_arns = self._get_param('OpenIdConnectProviderARNs') cognito_identity_providers = self._get_param('CognitoIdentityProviders') saml_provider_arns = self._get_param('SamlProviderARNs') + return cognitoidentity_backends[self.region].create_identity_pool( identity_pool_name=identity_pool_name, allow_unauthenticated_identities=allow_unauthenticated_identities, @@ -28,6 +28,9 @@ class CognitoIdentityResponse(BaseResponse): def get_id(self): return cognitoidentity_backends[self.region].get_id() + def describe_identity_pool(self): + return cognitoidentity_backends[self.region].describe_identity_pool(self._get_param('IdentityPoolId')) + def get_credentials_for_identity(self): return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId')) diff --git a/moto/core/exceptions.py b/moto/core/exceptions.py index 06cfd8895..a81d89093 100644 --- a/moto/core/exceptions.py +++ b/moto/core/exceptions.py @@ -14,7 +14,7 @@ SINGLE_ERROR_RESPONSE = u""" """ ERROR_RESPONSE = u""" - + {{error_type}} @@ -23,7 +23,7 @@ ERROR_RESPONSE = u""" 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE - + """ ERROR_JSON_RESPONSE = u"""{ diff --git a/moto/core/models.py b/moto/core/models.py index 896f9ac4a..63287608d 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -197,53 +197,9 @@ class CallbackResponse(responses.CallbackResponse): botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send') responses_mock = responses._default_mock - - -class ResponsesMockAWS(BaseMockAWS): - def reset(self): - botocore_mock.reset() - responses_mock.reset() - - def enable_patching(self): - if not hasattr(botocore_mock, '_patcher') or not hasattr(botocore_mock._patcher, 'target'): - # Check for unactivated patcher - botocore_mock.start() - - if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'): - responses_mock.start() - - for method in RESPONSES_METHODS: - for backend in self.backends_for_urls.values(): - for key, value in backend.urls.items(): - responses_mock.add( - CallbackResponse( - method=method, - url=re.compile(key), - callback=convert_flask_to_responses_response(value), - stream=True, - match_querystring=False, - ) - ) - botocore_mock.add( - CallbackResponse( - method=method, - url=re.compile(key), - callback=convert_flask_to_responses_response(value), - stream=True, - match_querystring=False, - ) - ) - - def disable_patching(self): - try: - botocore_mock.stop() - except RuntimeError: - pass - - try: - responses_mock.stop() - except RuntimeError: - pass +# Add passthrough to allow any other requests to work +# Since this uses .startswith, it applies to http and https requests. +responses_mock.add_passthru("http") BOTOCORE_HTTP_METHODS = [ @@ -310,6 +266,14 @@ botocore_stubber = BotocoreStubber() BUILTIN_HANDLERS.append(('before-send', botocore_stubber)) +def not_implemented_callback(request): + status = 400 + headers = {} + response = "The method is not implemented" + + return status, headers, response + + class BotocoreEventMockAWS(BaseMockAWS): def reset(self): botocore_stubber.reset() @@ -339,6 +303,24 @@ class BotocoreEventMockAWS(BaseMockAWS): match_querystring=False, ) ) + responses_mock.add( + CallbackResponse( + method=method, + url=re.compile("https?://.+.amazonaws.com/.*"), + callback=not_implemented_callback, + stream=True, + match_querystring=False, + ) + ) + botocore_mock.add( + CallbackResponse( + method=method, + url=re.compile("https?://.+.amazonaws.com/.*"), + callback=not_implemented_callback, + stream=True, + match_querystring=False, + ) + ) def disable_patching(self): botocore_stubber.enabled = False diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index 151a314f1..dbc0bd57d 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -941,8 +941,7 @@ class OpAnd(Op): def expr(self, item): lhs = self.lhs.expr(item) - rhs = self.rhs.expr(item) - return lhs and rhs + return lhs and self.rhs.expr(item) class OpLessThan(Op): diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index e868caaa8..4ef4461cd 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -363,7 +363,7 @@ class StreamRecord(BaseModel): 'dynamodb': { 'StreamViewType': stream_type, 'ApproximateCreationDateTime': datetime.datetime.utcnow().isoformat(), - 'SequenceNumber': seq, + 'SequenceNumber': str(seq), 'SizeBytes': 1, 'Keys': keys } diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index 3e9fbb553..15c1130f8 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -356,9 +356,18 @@ class DynamoHandler(BaseResponse): if projection_expression and expression_attribute_names: expressions = [x.strip() for x in projection_expression.split(',')] + projection_expression = None for expression in expressions: + if projection_expression is not None: + projection_expression = projection_expression + ", " + else: + projection_expression = "" + if expression in expression_attribute_names: - projection_expression = projection_expression.replace(expression, expression_attribute_names[expression]) + projection_expression = projection_expression + \ + expression_attribute_names[expression] + else: + projection_expression = projection_expression + expression filter_kwargs = {} diff --git a/moto/dynamodbstreams/models.py b/moto/dynamodbstreams/models.py index 41cc6e280..3e20ae13f 100644 --- a/moto/dynamodbstreams/models.py +++ b/moto/dynamodbstreams/models.py @@ -39,7 +39,7 @@ class ShardIterator(BaseModel): def get(self, limit=1000): items = self.stream_shard.get(self.sequence_number, limit) try: - last_sequence_number = max(i['dynamodb']['SequenceNumber'] for i in items) + last_sequence_number = max(int(i['dynamodb']['SequenceNumber']) for i in items) new_shard_iterator = ShardIterator(self.streams_backend, self.stream_shard, 'AFTER_SEQUENCE_NUMBER', diff --git a/moto/dynamodbstreams/responses.py b/moto/dynamodbstreams/responses.py index c9c113615..7774f3239 100644 --- a/moto/dynamodbstreams/responses.py +++ b/moto/dynamodbstreams/responses.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse from .models import dynamodbstreams_backends +from six import string_types class DynamoDBStreamsHandler(BaseResponse): @@ -23,8 +24,13 @@ class DynamoDBStreamsHandler(BaseResponse): arn = self._get_param('StreamArn') shard_id = self._get_param('ShardId') shard_iterator_type = self._get_param('ShardIteratorType') + sequence_number = self._get_param('SequenceNumber') + # according to documentation sequence_number param should be string + if isinstance(sequence_number, string_types): + sequence_number = int(sequence_number) + return self.backend.get_shard_iterator(arn, shard_id, - shard_iterator_type) + shard_iterator_type, sequence_number) def get_records(self): arn = self._get_param('ShardIterator') diff --git a/moto/ec2/responses/instances.py b/moto/ec2/responses/instances.py index 82c2b1997..28123b995 100644 --- a/moto/ec2/responses/instances.py +++ b/moto/ec2/responses/instances.py @@ -6,6 +6,7 @@ from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores from moto.ec2.utils import filters_from_querystring, \ dict_from_querystring +from moto.elbv2 import elbv2_backends class InstanceResponse(BaseResponse): @@ -68,6 +69,7 @@ class InstanceResponse(BaseResponse): if self.is_not_dryrun('TerminateInstance'): instances = self.ec2_backend.terminate_instances(instance_ids) autoscaling_backends[self.region].notify_terminate_instances(instance_ids) + elbv2_backends[self.region].notify_terminate_instances(instance_ids) template = self.response_template(EC2_TERMINATE_INSTANCES) return template.render(instances=instances) diff --git a/moto/elbv2/exceptions.py b/moto/elbv2/exceptions.py index 11dcbcb21..ccbfd06dd 100644 --- a/moto/elbv2/exceptions.py +++ b/moto/elbv2/exceptions.py @@ -131,7 +131,7 @@ class InvalidActionTypeError(ELBClientError): def __init__(self, invalid_name, index): super(InvalidActionTypeError, self).__init__( "ValidationError", - "1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect]" % (invalid_name, index) + "1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect, fixed-response]" % (invalid_name, index) ) @@ -190,3 +190,18 @@ class InvalidModifyRuleArgumentsError(ELBClientError): "ValidationError", "Either conditions or actions must be specified" ) + + +class InvalidStatusCodeActionTypeError(ELBClientError): + def __init__(self, msg): + super(InvalidStatusCodeActionTypeError, self).__init__( + "ValidationError", msg + ) + + +class InvalidLoadBalancerActionException(ELBClientError): + + def __init__(self, msg): + super(InvalidLoadBalancerActionException, self).__init__( + "InvalidLoadBalancerAction", msg + ) diff --git a/moto/elbv2/models.py b/moto/elbv2/models.py index 7e73c7042..636cc56a1 100644 --- a/moto/elbv2/models.py +++ b/moto/elbv2/models.py @@ -3,10 +3,11 @@ from __future__ import unicode_literals import datetime import re from jinja2 import Template +from botocore.exceptions import ParamValidationError from moto.compat import OrderedDict from moto.core.exceptions import RESTError from moto.core import BaseBackend, BaseModel -from moto.core.utils import camelcase_to_underscores +from moto.core.utils import camelcase_to_underscores, underscores_to_camelcase from moto.ec2.models import ec2_backends from moto.acm.models import acm_backends from .utils import make_arn_for_target_group @@ -31,8 +32,8 @@ from .exceptions import ( RuleNotFoundError, DuplicatePriorityError, InvalidTargetGroupNameError, - InvalidModifyRuleArgumentsError -) + InvalidModifyRuleArgumentsError, + InvalidStatusCodeActionTypeError, InvalidLoadBalancerActionException) class FakeHealthStatus(BaseModel): @@ -110,6 +111,11 @@ class FakeTargetGroup(BaseModel): if not t: raise InvalidTargetError() + def deregister_terminated_instances(self, instance_ids): + for target_id in list(self.targets.keys()): + if target_id in instance_ids: + del self.targets[target_id] + def add_tag(self, key, value): if len(self.tags) >= 10 and key not in self.tags: raise TooManyTagsError() @@ -215,9 +221,9 @@ class FakeListener(BaseModel): action_type = action['Type'] if action_type == 'forward': default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']}) - elif action_type in ['redirect', 'authenticate-cognito']: + elif action_type in ['redirect', 'authenticate-cognito', 'fixed-response']: redirect_action = {'type': action_type} - key = 'RedirectConfig' if action_type == 'redirect' else 'AuthenticateCognitoConfig' + key = underscores_to_camelcase(action_type.capitalize().replace('-', '_')) + 'Config' for redirect_config_key, redirect_config_value in action[key].items(): # need to match the output of _get_list_prefix redirect_action[camelcase_to_underscores(key) + '._' + camelcase_to_underscores(redirect_config_key)] = redirect_config_value @@ -253,6 +259,12 @@ class FakeAction(BaseModel): {{ action.data["authenticate_cognito_config._user_pool_client_id"] }} {{ action.data["authenticate_cognito_config._user_pool_domain"] }} + {% elif action.type == "fixed-response" %} + + {{ action.data["fixed_response_config._content_type"] }} + {{ action.data["fixed_response_config._message_body"] }} + {{ action.data["fixed_response_config._status_code"] }} + {% endif %} """) return template.render(action=self) @@ -477,11 +489,30 @@ class ELBv2Backend(BaseBackend): action_target_group_arn = action.data['target_group_arn'] if action_target_group_arn not in target_group_arns: raise ActionTargetGroupNotFoundError(action_target_group_arn) + elif action_type == 'fixed-response': + self._validate_fixed_response_action(action, i, index) elif action_type in ['redirect', 'authenticate-cognito']: pass else: raise InvalidActionTypeError(action_type, index) + def _validate_fixed_response_action(self, action, i, index): + status_code = action.data.get('fixed_response_config._status_code') + if status_code is None: + raise ParamValidationError( + report='Missing required parameter in Actions[%s].FixedResponseConfig: "StatusCode"' % i) + if not re.match(r'^(2|4|5)\d\d$', status_code): + raise InvalidStatusCodeActionTypeError( + "1 validation error detected: Value '%s' at 'actions.%s.member.fixedResponseConfig.statusCode' failed to satisfy constraint: \ +Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, index) + ) + content_type = action.data['fixed_response_config._content_type'] + if content_type and content_type not in ['text/plain', 'text/css', 'text/html', 'application/javascript', + 'application/json']: + raise InvalidLoadBalancerActionException( + "The ContentType must be one of:'text/html', 'application/json', 'application/javascript', 'text/css', 'text/plain'" + ) + def create_target_group(self, name, **kwargs): if len(name) > 32: raise InvalidTargetGroupNameError( @@ -936,6 +967,10 @@ class ELBv2Backend(BaseBackend): return True return False + def notify_terminate_instances(self, instance_ids): + for target_group in self.target_groups.values(): + target_group.deregister_terminated_instances(instance_ids) + elbv2_backends = {} for region in ec2_backends.keys(): diff --git a/moto/iam/policy_validation.py b/moto/iam/policy_validation.py index 6ee286072..d9a4b0282 100644 --- a/moto/iam/policy_validation.py +++ b/moto/iam/policy_validation.py @@ -152,8 +152,10 @@ class IAMPolicyDocumentValidator: sids = [] for statement in self._statements: if "Sid" in statement: - assert statement["Sid"] not in sids - sids.append(statement["Sid"]) + statementId = statement["Sid"] + if statementId: + assert statementId not in sids + sids.append(statementId) def _validate_statements_syntax(self): assert "Statement" in self._policy_json diff --git a/moto/iotdata/exceptions.py b/moto/iotdata/exceptions.py index ddc6b37fd..f2c209eed 100644 --- a/moto/iotdata/exceptions.py +++ b/moto/iotdata/exceptions.py @@ -21,3 +21,11 @@ class InvalidRequestException(IoTDataPlaneClientError): super(InvalidRequestException, self).__init__( "InvalidRequestException", message ) + + +class ConflictException(IoTDataPlaneClientError): + def __init__(self, message): + self.code = 409 + super(ConflictException, self).__init__( + "ConflictException", message + ) diff --git a/moto/iotdata/models.py b/moto/iotdata/models.py index ad4caa89e..fec066f07 100644 --- a/moto/iotdata/models.py +++ b/moto/iotdata/models.py @@ -6,6 +6,7 @@ import jsondiff from moto.core import BaseBackend, BaseModel from moto.iot import iot_backends from .exceptions import ( + ConflictException, ResourceNotFoundException, InvalidRequestException ) @@ -161,6 +162,8 @@ class IoTDataPlaneBackend(BaseBackend): if any(_ for _ in payload['state'].keys() if _ not in ['desired', 'reported']): raise InvalidRequestException('State contains an invalid node') + if 'version' in payload and thing.thing_shadow.version != payload['version']: + raise ConflictException('Version conflict') new_shadow = FakeShadow.create_from_previous_version(thing.thing_shadow, payload) thing.thing_shadow = new_shadow return thing.thing_shadow diff --git a/moto/kms/exceptions.py b/moto/kms/exceptions.py index 70edd3dcd..c9094e8f8 100644 --- a/moto/kms/exceptions.py +++ b/moto/kms/exceptions.py @@ -34,3 +34,23 @@ class NotAuthorizedException(JsonRESTError): "NotAuthorizedException", None) self.description = '{"__type":"NotAuthorizedException"}' + + +class AccessDeniedException(JsonRESTError): + code = 400 + + def __init__(self, message): + super(AccessDeniedException, self).__init__( + "AccessDeniedException", message) + + self.description = '{"__type":"AccessDeniedException"}' + + +class InvalidCiphertextException(JsonRESTError): + code = 400 + + def __init__(self): + super(InvalidCiphertextException, self).__init__( + "InvalidCiphertextException", None) + + self.description = '{"__type":"InvalidCiphertextException"}' diff --git a/moto/kms/models.py b/moto/kms/models.py index 577840b06..9e1b08bf9 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -1,16 +1,18 @@ from __future__ import unicode_literals import os -import boto.kms -from moto.core import BaseBackend, BaseModel -from moto.core.utils import iso_8601_datetime_without_milliseconds -from .utils import generate_key_id from collections import defaultdict from datetime import datetime, timedelta +import boto.kms + +from moto.core import BaseBackend, BaseModel +from moto.core.utils import iso_8601_datetime_without_milliseconds + +from .utils import decrypt, encrypt, generate_key_id, generate_master_key + class Key(BaseModel): - def __init__(self, policy, key_usage, description, tags, region): self.id = generate_key_id() self.policy = policy @@ -19,10 +21,11 @@ class Key(BaseModel): self.description = description self.enabled = True self.region = region - self.account_id = "0123456789012" + self.account_id = "012345678912" self.key_rotation_status = False self.deletion_date = None self.tags = tags or {} + self.key_material = generate_master_key() @property def physical_resource_id(self): @@ -45,8 +48,8 @@ class Key(BaseModel): "KeyState": self.key_state, } } - if self.key_state == 'PendingDeletion': - key_dict['KeyMetadata']['DeletionDate'] = iso_8601_datetime_without_milliseconds(self.deletion_date) + if self.key_state == "PendingDeletion": + key_dict["KeyMetadata"]["DeletionDate"] = iso_8601_datetime_without_milliseconds(self.deletion_date) return key_dict def delete(self, region_name): @@ -55,28 +58,28 @@ class Key(BaseModel): @classmethod def create_from_cloudformation_json(self, resource_name, cloudformation_json, region_name): kms_backend = kms_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] key = kms_backend.create_key( - policy=properties['KeyPolicy'], - key_usage='ENCRYPT_DECRYPT', - description=properties['Description'], - tags=properties.get('Tags'), + policy=properties["KeyPolicy"], + key_usage="ENCRYPT_DECRYPT", + description=properties["Description"], + tags=properties.get("Tags"), region=region_name, ) - key.key_rotation_status = properties['EnableKeyRotation'] - key.enabled = properties['Enabled'] + key.key_rotation_status = properties["EnableKeyRotation"] + key.enabled = properties["Enabled"] return key def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": return self.arn raise UnformattedGetAttTemplateException() class KmsBackend(BaseBackend): - def __init__(self): self.keys = {} self.key_to_aliases = defaultdict(set) @@ -109,16 +112,43 @@ class KmsBackend(BaseBackend): # allow the different methods (alias, ARN :key/, keyId, ARN alias) to # describe key not just KeyId key_id = self.get_key_id(key_id) - if r'alias/' in str(key_id).lower(): - key_id = self.get_key_id_from_alias(key_id.split('alias/')[1]) + if r"alias/" in str(key_id).lower(): + key_id = self.get_key_id_from_alias(key_id.split("alias/")[1]) return self.keys[self.get_key_id(key_id)] def list_keys(self): return self.keys.values() - def get_key_id(self, key_id): + @staticmethod + def get_key_id(key_id): # Allow use of ARN as well as pure KeyId - return str(key_id).split(r':key/')[1] if r':key/' in str(key_id).lower() else key_id + if key_id.startswith("arn:") and ":key/" in key_id: + return key_id.split(":key/")[1] + + return key_id + + @staticmethod + def get_alias_name(alias_name): + # Allow use of ARN as well as alias name + if alias_name.startswith("arn:") and ":alias/" in alias_name: + return alias_name.split(":alias/")[1] + + return alias_name + + def any_id_to_key_id(self, key_id): + """Go from any valid key ID to the raw key ID. + + Acceptable inputs: + - raw key ID + - key ARN + - alias name + - alias ARN + """ + key_id = self.get_alias_name(key_id) + key_id = self.get_key_id(key_id) + if key_id.startswith("alias/"): + key_id = self.get_key_id_from_alias(key_id) + return key_id def alias_exists(self, alias_name): for aliases in self.key_to_aliases.values(): @@ -162,37 +192,69 @@ class KmsBackend(BaseBackend): def disable_key(self, key_id): self.keys[key_id].enabled = False - self.keys[key_id].key_state = 'Disabled' + self.keys[key_id].key_state = "Disabled" def enable_key(self, key_id): self.keys[key_id].enabled = True - self.keys[key_id].key_state = 'Enabled' + self.keys[key_id].key_state = "Enabled" def cancel_key_deletion(self, key_id): - self.keys[key_id].key_state = 'Disabled' + self.keys[key_id].key_state = "Disabled" self.keys[key_id].deletion_date = None def schedule_key_deletion(self, key_id, pending_window_in_days): if 7 <= pending_window_in_days <= 30: self.keys[key_id].enabled = False - self.keys[key_id].key_state = 'PendingDeletion' + self.keys[key_id].key_state = "PendingDeletion" self.keys[key_id].deletion_date = datetime.now() + timedelta(days=pending_window_in_days) return iso_8601_datetime_without_milliseconds(self.keys[key_id].deletion_date) + def encrypt(self, key_id, plaintext, encryption_context): + key_id = self.any_id_to_key_id(key_id) + + ciphertext_blob = encrypt( + master_keys=self.keys, key_id=key_id, plaintext=plaintext, encryption_context=encryption_context + ) + arn = self.keys[key_id].arn + return ciphertext_blob, arn + + def decrypt(self, ciphertext_blob, encryption_context): + plaintext, key_id = decrypt( + master_keys=self.keys, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context + ) + arn = self.keys[key_id].arn + return plaintext, arn + + def re_encrypt( + self, ciphertext_blob, source_encryption_context, destination_key_id, destination_encryption_context + ): + destination_key_id = self.any_id_to_key_id(destination_key_id) + + plaintext, decrypting_arn = self.decrypt( + ciphertext_blob=ciphertext_blob, encryption_context=source_encryption_context + ) + new_ciphertext_blob, encrypting_arn = self.encrypt( + key_id=destination_key_id, plaintext=plaintext, encryption_context=destination_encryption_context + ) + return new_ciphertext_blob, decrypting_arn, encrypting_arn + def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens): - key = self.keys[self.get_key_id(key_id)] + key_id = self.any_id_to_key_id(key_id) if key_spec: - if key_spec == 'AES_128': - bytes = 16 + # Note: Actual validation of key_spec is done in kms.responses + if key_spec == "AES_128": + plaintext_len = 16 else: - bytes = 32 + plaintext_len = 32 else: - bytes = number_of_bytes + plaintext_len = number_of_bytes - plaintext = os.urandom(bytes) + plaintext = os.urandom(plaintext_len) - return plaintext, key.arn + ciphertext_blob, arn = self.encrypt(key_id=key_id, plaintext=plaintext, encryption_context=encryption_context) + + return plaintext, ciphertext_blob, arn kms_backends = {} diff --git a/moto/kms/responses.py b/moto/kms/responses.py index 53012b7f8..998d5cc4b 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -2,13 +2,16 @@ from __future__ import unicode_literals import base64 import json +import os import re + import six from moto.core.responses import BaseResponse from .models import kms_backends from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException +ACCOUNT_ID = "012345678912" reserved_aliases = [ 'alias/aws/ebs', 'alias/aws/s3', @@ -21,13 +24,86 @@ class KmsResponse(BaseResponse): @property def parameters(self): - return json.loads(self.body) + params = json.loads(self.body) + + for key in ("Plaintext", "CiphertextBlob"): + if key in params: + params[key] = base64.b64decode(params[key].encode("utf-8")) + + return params @property def kms_backend(self): return kms_backends[self.region] + def _display_arn(self, key_id): + if key_id.startswith("arn:"): + return key_id + + if key_id.startswith("alias/"): + id_type = "" + else: + id_type = "key/" + + return "arn:aws:kms:{region}:{account}:{id_type}{key_id}".format( + region=self.region, account=ACCOUNT_ID, id_type=id_type, key_id=key_id + ) + + def _validate_cmk_id(self, key_id): + """Determine whether a CMK ID exists. + + - raw key ID + - key ARN + """ + is_arn = key_id.startswith("arn:") and ":key/" in key_id + is_raw_key_id = re.match(r"^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$", key_id, re.IGNORECASE) + + if not is_arn and not is_raw_key_id: + raise NotFoundException("Invalid keyId {key_id}".format(key_id=key_id)) + + cmk_id = self.kms_backend.get_key_id(key_id) + + if cmk_id not in self.kms_backend.keys: + raise NotFoundException("Key '{key_id}' does not exist".format(key_id=self._display_arn(key_id))) + + def _validate_alias(self, key_id): + """Determine whether an alias exists. + + - alias name + - alias ARN + """ + error = NotFoundException("Alias {key_id} is not found.".format(key_id=self._display_arn(key_id))) + + is_arn = key_id.startswith("arn:") and ":alias/" in key_id + is_name = key_id.startswith("alias/") + + if not is_arn and not is_name: + raise error + + alias_name = self.kms_backend.get_alias_name(key_id) + cmk_id = self.kms_backend.get_key_id_from_alias(alias_name) + if cmk_id is None: + raise error + + def _validate_key_id(self, key_id): + """Determine whether or not a key ID exists. + + - raw key ID + - key ARN + - alias name + - alias ARN + """ + is_alias_arn = key_id.startswith("arn:") and ":alias/" in key_id + is_alias_name = key_id.startswith("alias/") + + if is_alias_arn or is_alias_name: + self._validate_alias(key_id) + return + + self._validate_cmk_id(key_id) + def create_key(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html""" policy = self.parameters.get('Policy') key_usage = self.parameters.get('KeyUsage') description = self.parameters.get('Description') @@ -38,20 +114,31 @@ class KmsResponse(BaseResponse): return json.dumps(key.to_dict()) def update_key_description(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html""" key_id = self.parameters.get('KeyId') description = self.parameters.get('Description') + self._validate_cmk_id(key_id) + self.kms_backend.update_key_description(key_id, description) return json.dumps(None) def tag_resource(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html""" key_id = self.parameters.get('KeyId') tags = self.parameters.get('Tags') + + self._validate_cmk_id(key_id) + self.kms_backend.tag_resource(key_id, tags) return json.dumps({}) def list_resource_tags(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html""" key_id = self.parameters.get('KeyId') + + self._validate_cmk_id(key_id) + tags = self.kms_backend.list_resource_tags(key_id) return json.dumps({ "Tags": tags, @@ -60,17 +147,19 @@ class KmsResponse(BaseResponse): }) def describe_key(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html""" key_id = self.parameters.get('KeyId') - try: - key = self.kms_backend.describe_key( - self.kms_backend.get_key_id(key_id)) - except KeyError: - headers = dict(self.headers) - headers['status'] = 404 - return "{}", headers + + self._validate_key_id(key_id) + + key = self.kms_backend.describe_key( + self.kms_backend.get_key_id(key_id) + ) + return json.dumps(key.to_dict()) def list_keys(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html""" keys = self.kms_backend.list_keys() return json.dumps({ @@ -85,6 +174,7 @@ class KmsResponse(BaseResponse): }) def create_alias(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html""" alias_name = self.parameters['AliasName'] target_key_id = self.parameters['TargetKeyId'] @@ -110,27 +200,31 @@ class KmsResponse(BaseResponse): raise AlreadyExistsException('An alias with the name arn:aws:kms:{region}:012345678912:{alias_name} ' 'already exists'.format(region=self.region, alias_name=alias_name)) + self._validate_cmk_id(target_key_id) + self.kms_backend.add_alias(target_key_id, alias_name) return json.dumps(None) def delete_alias(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html""" alias_name = self.parameters['AliasName'] if not alias_name.startswith('alias/'): raise ValidationException('Invalid identifier') - if not self.kms_backend.alias_exists(alias_name): - raise NotFoundException('Alias arn:aws:kms:{region}:012345678912:' - '{alias_name} is not found.'.format(region=self.region, alias_name=alias_name)) + self._validate_alias(alias_name) self.kms_backend.delete_alias(alias_name) return json.dumps(None) def list_aliases(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListAliases.html""" region = self.region + # TODO: The actual API can filter on KeyId. + response_aliases = [ { 'AliasArn': u'arn:aws:kms:{region}:012345678912:{reserved_alias}'.format(region=region, @@ -155,191 +249,239 @@ class KmsResponse(BaseResponse): }) def enable_key_rotation(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.enable_key_rotation(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.enable_key_rotation(key_id) return json.dumps(None) def disable_key_rotation(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.disable_key_rotation(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.disable_key_rotation(key_id) + return json.dumps(None) def get_key_rotation_status(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - rotation_enabled = self.kms_backend.get_key_rotation_status(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + rotation_enabled = self.kms_backend.get_key_rotation_status(key_id) + return json.dumps({'KeyRotationEnabled': rotation_enabled}) def put_key_policy(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html""" key_id = self.parameters.get('KeyId') policy_name = self.parameters.get('PolicyName') policy = self.parameters.get('Policy') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) _assert_default_policy(policy_name) - try: - self.kms_backend.put_key_policy(key_id, policy) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + self._validate_cmk_id(key_id) + + self.kms_backend.put_key_policy(key_id, policy) return json.dumps(None) def get_key_policy(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html""" key_id = self.parameters.get('KeyId') policy_name = self.parameters.get('PolicyName') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) _assert_default_policy(policy_name) - try: - return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)}) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + self._validate_cmk_id(key_id) + + return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)}) def list_key_policies(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.describe_key(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.describe_key(key_id) return json.dumps({'Truncated': False, 'PolicyNames': ['default']}) def encrypt(self): - """ - We perform no encryption, we just encode the value as base64 and then - decode it in decrypt(). - """ - value = self.parameters.get("Plaintext") - if isinstance(value, six.text_type): - value = value.encode('utf-8') - return json.dumps({"CiphertextBlob": base64.b64encode(value).decode("utf-8"), 'KeyId': 'key_id'}) + """https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html""" + key_id = self.parameters.get("KeyId") + encryption_context = self.parameters.get('EncryptionContext', {}) + plaintext = self.parameters.get("Plaintext") + + self._validate_key_id(key_id) + + if isinstance(plaintext, six.text_type): + plaintext = plaintext.encode('utf-8') + + ciphertext_blob, arn = self.kms_backend.encrypt( + key_id=key_id, + plaintext=plaintext, + encryption_context=encryption_context, + ) + ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8") + + return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn}) def decrypt(self): - # TODO refuse decode if EncryptionContext is not the same as when it was encrypted / generated + """https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html""" + ciphertext_blob = self.parameters.get("CiphertextBlob") + encryption_context = self.parameters.get('EncryptionContext', {}) - value = self.parameters.get("CiphertextBlob") - try: - return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8"), 'KeyId': 'key_id'}) - except UnicodeDecodeError: - # Generate data key will produce random bytes which when decrypted is still returned as base64 - return json.dumps({"Plaintext": value}) + plaintext, arn = self.kms_backend.decrypt( + ciphertext_blob=ciphertext_blob, + encryption_context=encryption_context, + ) + + plaintext_response = base64.b64encode(plaintext).decode("utf-8") + + return json.dumps({"Plaintext": plaintext_response, 'KeyId': arn}) + + def re_encrypt(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html""" + ciphertext_blob = self.parameters.get("CiphertextBlob") + source_encryption_context = self.parameters.get("SourceEncryptionContext", {}) + destination_key_id = self.parameters.get("DestinationKeyId") + destination_encryption_context = self.parameters.get("DestinationEncryptionContext", {}) + + self._validate_cmk_id(destination_key_id) + + new_ciphertext_blob, decrypting_arn, encrypting_arn = self.kms_backend.re_encrypt( + ciphertext_blob=ciphertext_blob, + source_encryption_context=source_encryption_context, + destination_key_id=destination_key_id, + destination_encryption_context=destination_encryption_context, + ) + + response_ciphertext_blob = base64.b64encode(new_ciphertext_blob).decode("utf-8") + + return json.dumps( + {"CiphertextBlob": response_ciphertext_blob, "KeyId": encrypting_arn, "SourceKeyId": decrypting_arn} + ) def disable_key(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.disable_key(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.disable_key(key_id) + return json.dumps(None) def enable_key(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.enable_key(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.enable_key(key_id) + return json.dumps(None) def cancel_key_deletion(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html""" key_id = self.parameters.get('KeyId') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - self.kms_backend.cancel_key_deletion(key_id) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + self.kms_backend.cancel_key_deletion(key_id) + return json.dumps({'KeyId': key_id}) def schedule_key_deletion(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html""" key_id = self.parameters.get('KeyId') if self.parameters.get('PendingWindowInDays') is None: pending_window_in_days = 30 else: pending_window_in_days = self.parameters.get('PendingWindowInDays') - _assert_valid_key_id(self.kms_backend.get_key_id(key_id)) - try: - return json.dumps({ - 'KeyId': key_id, - 'DeletionDate': self.kms_backend.schedule_key_deletion(key_id, pending_window_in_days) - }) - except KeyError: - raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/" - "{key_id}' does not exist".format(region=self.region, key_id=key_id)) + + self._validate_cmk_id(key_id) + + return json.dumps({ + 'KeyId': key_id, + 'DeletionDate': self.kms_backend.schedule_key_deletion(key_id, pending_window_in_days) + }) def generate_data_key(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html""" key_id = self.parameters.get('KeyId') - encryption_context = self.parameters.get('EncryptionContext') + encryption_context = self.parameters.get('EncryptionContext', {}) number_of_bytes = self.parameters.get('NumberOfBytes') key_spec = self.parameters.get('KeySpec') grant_tokens = self.parameters.get('GrantTokens') # Param validation - if key_id.startswith('alias'): - if self.kms_backend.get_key_id_from_alias(key_id) is None: - raise NotFoundException('Alias arn:aws:kms:{region}:012345678912:{alias_name} is not found.'.format( - region=self.region, alias_name=key_id)) - else: - if self.kms_backend.get_key_id(key_id) not in self.kms_backend.keys: - raise NotFoundException('Invalid keyId') + self._validate_key_id(key_id) - if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 0): - raise ValidationException("1 validation error detected: Value '2048' at 'numberOfBytes' failed " - "to satisfy constraint: Member must have value less than or " - "equal to 1024") + if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1): + raise ValidationException(( + "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " + "to satisfy constraint: Member must have value less than or " + "equal to 1024" + ).format(number_of_bytes=number_of_bytes)) if key_spec and key_spec not in ('AES_256', 'AES_128'): - raise ValidationException("1 validation error detected: Value 'AES_257' at 'keySpec' failed " - "to satisfy constraint: Member must satisfy enum value set: " - "[AES_256, AES_128]") + raise ValidationException(( + "1 validation error detected: Value '{key_spec}' at 'keySpec' failed " + "to satisfy constraint: Member must satisfy enum value set: " + "[AES_256, AES_128]" + ).format(key_spec=key_spec)) if not key_spec and not number_of_bytes: raise ValidationException("Please specify either number of bytes or key spec.") + if key_spec and number_of_bytes: raise ValidationException("Please specify either number of bytes or key spec.") - plaintext, key_arn = self.kms_backend.generate_data_key(key_id, encryption_context, - number_of_bytes, key_spec, grant_tokens) + plaintext, ciphertext_blob, key_arn = self.kms_backend.generate_data_key( + key_id=key_id, + encryption_context=encryption_context, + number_of_bytes=number_of_bytes, + key_spec=key_spec, + grant_tokens=grant_tokens + ) - plaintext = base64.b64encode(plaintext).decode() + plaintext_response = base64.b64encode(plaintext).decode("utf-8") + ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8") return json.dumps({ - 'CiphertextBlob': plaintext, - 'Plaintext': plaintext, + 'CiphertextBlob': ciphertext_blob_response, + 'Plaintext': plaintext_response, 'KeyId': key_arn # not alias }) def generate_data_key_without_plaintext(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html""" result = json.loads(self.generate_data_key()) del result['Plaintext'] return json.dumps(result) + def generate_random(self): + """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html""" + number_of_bytes = self.parameters.get("NumberOfBytes") -def _assert_valid_key_id(key_id): - if not re.match(r'^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$', key_id, re.IGNORECASE): - raise NotFoundException('Invalid keyId') + if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1): + raise ValidationException(( + "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " + "to satisfy constraint: Member must have value less than or " + "equal to 1024" + ).format(number_of_bytes=number_of_bytes)) + + entropy = os.urandom(number_of_bytes) + + response_entropy = base64.b64encode(entropy).decode("utf-8") + + return json.dumps({"Plaintext": response_entropy}) def _assert_default_policy(policy_name): diff --git a/moto/kms/utils.py b/moto/kms/utils.py index fad38150f..96d3f25cc 100644 --- a/moto/kms/utils.py +++ b/moto/kms/utils.py @@ -1,7 +1,142 @@ from __future__ import unicode_literals +from collections import namedtuple +import io +import os +import struct import uuid +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes + +from .exceptions import InvalidCiphertextException, AccessDeniedException, NotFoundException + + +MASTER_KEY_LEN = 32 +KEY_ID_LEN = 36 +IV_LEN = 12 +TAG_LEN = 16 +HEADER_LEN = KEY_ID_LEN + IV_LEN + TAG_LEN +# NOTE: This is just a simple binary format. It is not what KMS actually does. +CIPHERTEXT_HEADER_FORMAT = ">{key_id_len}s{iv_len}s{tag_len}s".format( + key_id_len=KEY_ID_LEN, iv_len=IV_LEN, tag_len=TAG_LEN +) +Ciphertext = namedtuple("Ciphertext", ("key_id", "iv", "ciphertext", "tag")) + def generate_key_id(): return str(uuid.uuid4()) + + +def generate_data_key(number_of_bytes): + """Generate a data key.""" + return os.urandom(number_of_bytes) + + +def generate_master_key(): + """Generate a master key.""" + return generate_data_key(MASTER_KEY_LEN) + + +def _serialize_ciphertext_blob(ciphertext): + """Serialize Ciphertext object into a ciphertext blob. + + NOTE: This is just a simple binary format. It is not what KMS actually does. + """ + header = struct.pack(CIPHERTEXT_HEADER_FORMAT, ciphertext.key_id.encode("utf-8"), ciphertext.iv, ciphertext.tag) + return header + ciphertext.ciphertext + + +def _deserialize_ciphertext_blob(ciphertext_blob): + """Deserialize ciphertext blob into a Ciphertext object. + + NOTE: This is just a simple binary format. It is not what KMS actually does. + """ + header = ciphertext_blob[:HEADER_LEN] + ciphertext = ciphertext_blob[HEADER_LEN:] + key_id, iv, tag = struct.unpack(CIPHERTEXT_HEADER_FORMAT, header) + return Ciphertext(key_id=key_id.decode("utf-8"), iv=iv, ciphertext=ciphertext, tag=tag) + + +def _serialize_encryption_context(encryption_context): + """Serialize encryption context for use a AAD. + + NOTE: This is not necessarily what KMS does, but it retains the same properties. + """ + aad = io.BytesIO() + for key, value in sorted(encryption_context.items(), key=lambda x: x[0]): + aad.write(key.encode("utf-8")) + aad.write(value.encode("utf-8")) + return aad.getvalue() + + +def encrypt(master_keys, key_id, plaintext, encryption_context): + """Encrypt data using a master key material. + + NOTE: This is not necessarily what KMS does, but it retains the same properties. + + NOTE: This function is NOT compatible with KMS APIs. + :param dict master_keys: Mapping of a KmsBackend's known master keys + :param str key_id: Key ID of moto master key + :param bytes plaintext: Plaintext data to encrypt + :param dict[str, str] encryption_context: KMS-style encryption context + :returns: Moto-structured ciphertext blob encrypted under a moto master key in master_keys + :rtype: bytes + """ + try: + key = master_keys[key_id] + except KeyError: + is_alias = key_id.startswith("alias/") or ":alias/" in key_id + raise NotFoundException( + "{id_type} {key_id} is not found.".format(id_type="Alias" if is_alias else "keyId", key_id=key_id) + ) + + iv = os.urandom(IV_LEN) + aad = _serialize_encryption_context(encryption_context=encryption_context) + + encryptor = Cipher(algorithms.AES(key.key_material), modes.GCM(iv), backend=default_backend()).encryptor() + encryptor.authenticate_additional_data(aad) + ciphertext = encryptor.update(plaintext) + encryptor.finalize() + return _serialize_ciphertext_blob( + ciphertext=Ciphertext(key_id=key_id, iv=iv, ciphertext=ciphertext, tag=encryptor.tag) + ) + + +def decrypt(master_keys, ciphertext_blob, encryption_context): + """Decrypt a ciphertext blob using a master key material. + + NOTE: This is not necessarily what KMS does, but it retains the same properties. + + NOTE: This function is NOT compatible with KMS APIs. + + :param dict master_keys: Mapping of a KmsBackend's known master keys + :param bytes ciphertext_blob: moto-structured ciphertext blob encrypted under a moto master key in master_keys + :param dict[str, str] encryption_context: KMS-style encryption context + :returns: plaintext bytes and moto key ID + :rtype: bytes and str + """ + try: + ciphertext = _deserialize_ciphertext_blob(ciphertext_blob=ciphertext_blob) + except Exception: + raise InvalidCiphertextException() + + aad = _serialize_encryption_context(encryption_context=encryption_context) + + try: + key = master_keys[ciphertext.key_id] + except KeyError: + raise AccessDeniedException( + "The ciphertext refers to a customer master key that does not exist, " + "does not exist in this region, or you are not allowed to access." + ) + + try: + decryptor = Cipher( + algorithms.AES(key.key_material), modes.GCM(ciphertext.iv, ciphertext.tag), backend=default_backend() + ).decryptor() + decryptor.authenticate_additional_data(aad) + plaintext = decryptor.update(ciphertext.ciphertext) + decryptor.finalize() + except Exception: + raise InvalidCiphertextException() + + return plaintext, ciphertext.key_id diff --git a/moto/logs/models.py b/moto/logs/models.py index 2b8dcfeb4..3c5360371 100644 --- a/moto/logs/models.py +++ b/moto/logs/models.py @@ -41,7 +41,7 @@ class LogStream: self.region = region self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format( region=region, id=self.__class__._log_ids, log_group=log_group, log_stream=name) - self.creationTime = unix_time_millis() + self.creationTime = int(unix_time_millis()) self.firstEventTimestamp = None self.lastEventTimestamp = None self.lastIngestionTime = None @@ -80,7 +80,7 @@ class LogStream: def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token): # TODO: ensure sequence_token # TODO: to be thread safe this would need a lock - self.lastIngestionTime = unix_time_millis() + self.lastIngestionTime = int(unix_time_millis()) # TODO: make this match AWS if possible self.storedBytes += sum([len(log_event["message"]) for log_event in log_events]) self.events += [LogEvent(self.lastIngestionTime, log_event) for log_event in log_events] @@ -115,6 +115,8 @@ class LogStream: events_page = [event.to_response_dict() for event in events[next_index: next_index + limit]] if next_index + limit < len(self.events): next_index += limit + else: + next_index = len(self.events) back_index -= limit if back_index <= 0: @@ -146,7 +148,7 @@ class LogGroup: self.region = region self.arn = "arn:aws:logs:{region}:1:log-group:{log_group}".format( region=region, log_group=name) - self.creationTime = unix_time_millis() + self.creationTime = int(unix_time_millis()) self.tags = tags self.streams = dict() # {name: LogStream} self.retentionInDays = None # AWS defaults to Never Expire for log group retention diff --git a/moto/redshift/models.py b/moto/redshift/models.py index c0b783bde..8a2b7e6b6 100644 --- a/moto/redshift/models.py +++ b/moto/redshift/models.py @@ -74,7 +74,7 @@ class Cluster(TaggableResourceMixin, BaseModel): automated_snapshot_retention_period, port, cluster_version, allow_version_upgrade, number_of_nodes, publicly_accessible, encrypted, region_name, tags=None, iam_roles_arn=None, - restored_from_snapshot=False): + enhanced_vpc_routing=None, restored_from_snapshot=False): super(Cluster, self).__init__(region_name, tags) self.redshift_backend = redshift_backend self.cluster_identifier = cluster_identifier @@ -85,6 +85,7 @@ class Cluster(TaggableResourceMixin, BaseModel): self.master_user_password = master_user_password self.db_name = db_name if db_name else "dev" self.vpc_security_group_ids = vpc_security_group_ids + self.enhanced_vpc_routing = enhanced_vpc_routing if enhanced_vpc_routing is not None else False self.cluster_subnet_group_name = cluster_subnet_group_name self.publicly_accessible = publicly_accessible self.encrypted = encrypted @@ -154,6 +155,7 @@ class Cluster(TaggableResourceMixin, BaseModel): port=properties.get('Port'), cluster_version=properties.get('ClusterVersion'), allow_version_upgrade=properties.get('AllowVersionUpgrade'), + enhanced_vpc_routing=properties.get('EnhancedVpcRouting'), number_of_nodes=properties.get('NumberOfNodes'), publicly_accessible=properties.get("PubliclyAccessible"), encrypted=properties.get("Encrypted"), @@ -241,6 +243,7 @@ class Cluster(TaggableResourceMixin, BaseModel): 'ClusterCreateTime': self.create_time, "PendingModifiedValues": [], "Tags": self.tags, + "EnhancedVpcRouting": self.enhanced_vpc_routing, "IamRoles": [{ "ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn @@ -427,6 +430,7 @@ class Snapshot(TaggableResourceMixin, BaseModel): 'NumberOfNodes': self.cluster.number_of_nodes, 'DBName': self.cluster.db_name, 'Tags': self.tags, + 'EnhancedVpcRouting': self.cluster.enhanced_vpc_routing, "IamRoles": [{ "ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn @@ -678,7 +682,8 @@ class RedshiftBackend(BaseBackend): "number_of_nodes": snapshot.cluster.number_of_nodes, "encrypted": snapshot.cluster.encrypted, "tags": snapshot.cluster.tags, - "restored_from_snapshot": True + "restored_from_snapshot": True, + "enhanced_vpc_routing": snapshot.cluster.enhanced_vpc_routing } create_kwargs.update(kwargs) return self.create_cluster(**create_kwargs) diff --git a/moto/redshift/responses.py b/moto/redshift/responses.py index a7758febb..7ac73d470 100644 --- a/moto/redshift/responses.py +++ b/moto/redshift/responses.py @@ -135,6 +135,7 @@ class RedshiftResponse(BaseResponse): "region_name": self.region, "tags": self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')), "iam_roles_arn": self._get_iam_roles(), + "enhanced_vpc_routing": self._get_param('EnhancedVpcRouting'), } cluster = self.redshift_backend.create_cluster(**cluster_kwargs).to_json() cluster['ClusterStatus'] = 'creating' @@ -150,6 +151,7 @@ class RedshiftResponse(BaseResponse): }) def restore_from_cluster_snapshot(self): + enhanced_vpc_routing = self._get_bool_param('EnhancedVpcRouting') restore_kwargs = { "snapshot_identifier": self._get_param('SnapshotIdentifier'), "cluster_identifier": self._get_param('ClusterIdentifier'), @@ -171,6 +173,8 @@ class RedshiftResponse(BaseResponse): "region_name": self.region, "iam_roles_arn": self._get_iam_roles(), } + if enhanced_vpc_routing is not None: + restore_kwargs['enhanced_vpc_routing'] = enhanced_vpc_routing cluster = self.redshift_backend.restore_from_cluster_snapshot(**restore_kwargs).to_json() cluster['ClusterStatus'] = 'creating' return self.get_response({ @@ -218,6 +222,7 @@ class RedshiftResponse(BaseResponse): "publicly_accessible": self._get_param("PubliclyAccessible"), "encrypted": self._get_param("Encrypted"), "iam_roles_arn": self._get_iam_roles(), + "enhanced_vpc_routing": self._get_param("EnhancedVpcRouting") } cluster_kwargs = {} # We only want parameters that were actually passed in, otherwise diff --git a/moto/route53/models.py b/moto/route53/models.py index 61a6609aa..77a0e59e6 100644 --- a/moto/route53/models.py +++ b/moto/route53/models.py @@ -305,6 +305,7 @@ class Route53Backend(BaseBackend): def list_tags_for_resource(self, resource_id): if resource_id in self.resource_tags: return self.resource_tags[resource_id] + return {} def get_all_hosted_zones(self): return self.zones.values() diff --git a/moto/s3/responses.py b/moto/s3/responses.py index ee047a14f..61ebff9d0 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -20,7 +20,7 @@ from .exceptions import BucketAlreadyExists, S3ClientError, MissingBucket, Missi MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent, ObjectNotInActiveTierError from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeTagging, FakeTagSet, \ FakeTag -from .utils import bucket_name_from_url, clean_key_name, metadata_from_headers, parse_region_from_url +from .utils import bucket_name_from_url, clean_key_name, undo_clean_key_name, metadata_from_headers, parse_region_from_url from xml.dom import minidom @@ -451,17 +451,16 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): continuation_token = querystring.get('continuation-token', [None])[0] start_after = querystring.get('start-after', [None])[0] + # sort the combination of folders and keys into lexicographical order + all_keys = result_keys + result_folders + all_keys.sort(key=self._get_name) + if continuation_token or start_after: limit = continuation_token or start_after - if not delimiter: - result_keys = self._get_results_from_token(result_keys, limit) - else: - result_folders = self._get_results_from_token(result_folders, limit) + all_keys = self._get_results_from_token(all_keys, limit) - if not delimiter: - result_keys, is_truncated, next_continuation_token = self._truncate_result(result_keys, max_keys) - else: - result_folders, is_truncated, next_continuation_token = self._truncate_result(result_folders, max_keys) + truncated_keys, is_truncated, next_continuation_token = self._truncate_result(all_keys, max_keys) + result_keys, result_folders = self._split_truncated_keys(truncated_keys) key_count = len(result_keys) + len(result_folders) @@ -479,6 +478,24 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): start_after=None if continuation_token else start_after ) + @staticmethod + def _get_name(key): + if isinstance(key, FakeKey): + return key.name + else: + return key + + @staticmethod + def _split_truncated_keys(truncated_keys): + result_keys = [] + result_folders = [] + for key in truncated_keys: + if isinstance(key, FakeKey): + result_keys.append(key) + else: + result_folders.append(key) + return result_keys, result_folders + def _get_results_from_token(self, result_keys, token): continuation_index = 0 for key in result_keys: @@ -694,7 +711,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): for k in keys: key_name = k.firstChild.nodeValue - success = self.backend.delete_key(bucket_name, key_name) + success = self.backend.delete_key(bucket_name, undo_clean_key_name(key_name)) if success: deleted_names.append(key_name) else: diff --git a/moto/s3/urls.py b/moto/s3/urls.py index fa81568a4..1388c81e5 100644 --- a/moto/s3/urls.py +++ b/moto/s3/urls.py @@ -15,4 +15,6 @@ url_paths = { '{0}/(?P[^/]+)/?$': S3ResponseInstance.ambiguous_response, # path-based bucket + key '{0}/(?P[^/]+)/(?P.+)': S3ResponseInstance.key_response, + # subdomain bucket + key with empty first part of path + '{0}//(?P.*)$': S3ResponseInstance.key_response, } diff --git a/moto/s3/utils.py b/moto/s3/utils.py index 85a812aad..3bdd24cc4 100644 --- a/moto/s3/utils.py +++ b/moto/s3/utils.py @@ -5,7 +5,7 @@ import os from boto.s3.key import Key import re import six -from six.moves.urllib.parse import urlparse, unquote +from six.moves.urllib.parse import urlparse, unquote, quote import sys @@ -71,10 +71,15 @@ def metadata_from_headers(headers): def clean_key_name(key_name): if six.PY2: return unquote(key_name.encode('utf-8')).decode('utf-8') - return unquote(key_name) +def undo_clean_key_name(key_name): + if six.PY2: + return quote(key_name.encode('utf-8')).decode('utf-8') + return quote(key_name) + + class _VersionedKeyStore(dict): """ A simplified/modified version of Django's `MultiValueDict` taken from: diff --git a/moto/secretsmanager/models.py b/moto/secretsmanager/models.py index 3e0424b6b..63d847c49 100644 --- a/moto/secretsmanager/models.py +++ b/moto/secretsmanager/models.py @@ -154,9 +154,9 @@ class SecretsManagerBackend(BaseBackend): return version_id - def put_secret_value(self, secret_id, secret_string, version_stages): + def put_secret_value(self, secret_id, secret_string, secret_binary, version_stages): - version_id = self._add_secret(secret_id, secret_string, version_stages=version_stages) + version_id = self._add_secret(secret_id, secret_string, secret_binary, version_stages=version_stages) response = json.dumps({ 'ARN': secret_arn(self.region, secret_id), diff --git a/moto/secretsmanager/responses.py b/moto/secretsmanager/responses.py index 090688351..4995c4bc7 100644 --- a/moto/secretsmanager/responses.py +++ b/moto/secretsmanager/responses.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse +from moto.secretsmanager.exceptions import InvalidRequestException from .models import secretsmanager_backends @@ -71,10 +72,14 @@ class SecretsManagerResponse(BaseResponse): def put_secret_value(self): secret_id = self._get_param('SecretId', if_none='') - secret_string = self._get_param('SecretString', if_none='') + secret_string = self._get_param('SecretString') + secret_binary = self._get_param('SecretBinary') + if not secret_binary and not secret_string: + raise InvalidRequestException('You must provide either SecretString or SecretBinary.') version_stages = self._get_param('VersionStages', if_none=['AWSCURRENT']) return secretsmanager_backends[self.region].put_secret_value( secret_id=secret_id, + secret_binary=secret_binary, secret_string=secret_string, version_stages=version_stages, ) diff --git a/moto/server.py b/moto/server.py index 89be47093..b245f6e6f 100644 --- a/moto/server.py +++ b/moto/server.py @@ -174,10 +174,11 @@ def create_backend_app(service): backend_app.url_map.converters['regex'] = RegexConverter backend = list(BACKENDS[service].values())[0] for url_path, handler in backend.flask_paths.items(): + view_func = convert_flask_to_httpretty_response(handler) if handler.__name__ == 'dispatch': endpoint = '{0}.dispatch'.format(handler.__self__.__name__) else: - endpoint = None + endpoint = view_func.__name__ original_endpoint = endpoint index = 2 @@ -191,7 +192,7 @@ def create_backend_app(service): url_path, endpoint=endpoint, methods=HTTP_METHODS, - view_func=convert_flask_to_httpretty_response(handler), + view_func=view_func, strict_slashes=False, ) diff --git a/moto/sns/exceptions.py b/moto/sns/exceptions.py index 0e7a0bdcf..706b3b5cc 100644 --- a/moto/sns/exceptions.py +++ b/moto/sns/exceptions.py @@ -40,3 +40,11 @@ class InvalidParameterValue(RESTError): def __init__(self, message): super(InvalidParameterValue, self).__init__( "InvalidParameterValue", message) + + +class InternalError(RESTError): + code = 500 + + def __init__(self, message): + super(InternalError, self).__init__( + "InternalFailure", message) diff --git a/moto/sns/models.py b/moto/sns/models.py index f1293eb0f..92e6c61de 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -18,7 +18,7 @@ from moto.awslambda import lambda_backends from .exceptions import ( SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled, SNSInvalidParameter, - InvalidParameterValue + InvalidParameterValue, InternalError ) from .utils import make_arn_for_topic, make_arn_for_subscription @@ -131,13 +131,47 @@ class Subscription(BaseModel): message_attributes = {} def _field_match(field, rules, message_attributes): - if field not in message_attributes: - return False for rule in rules: + # TODO: boolean value matching is not supported, SNS behavior unknown if isinstance(rule, six.string_types): - # only string value matching is supported + if field not in message_attributes: + return False if message_attributes[field]['Value'] == rule: return True + try: + json_data = json.loads(message_attributes[field]['Value']) + if rule in json_data: + return True + except (ValueError, TypeError): + pass + if isinstance(rule, (six.integer_types, float)): + if field not in message_attributes: + return False + if message_attributes[field]['Type'] == 'Number': + attribute_values = [message_attributes[field]['Value']] + elif message_attributes[field]['Type'] == 'String.Array': + try: + attribute_values = json.loads(message_attributes[field]['Value']) + if not isinstance(attribute_values, list): + attribute_values = [attribute_values] + except (ValueError, TypeError): + return False + else: + return False + + for attribute_values in attribute_values: + # Even the offical documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6 + # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints + if int(attribute_values * 1000000) == int(rule * 1000000): + return True + if isinstance(rule, dict): + keyword = list(rule.keys())[0] + attributes = list(rule.values())[0] + if keyword == 'exists': + if attributes and field in message_attributes: + return True + elif not attributes and field not in message_attributes: + return True return False return all(_field_match(field, rules, message_attributes) @@ -421,7 +455,49 @@ class SNSBackend(BaseBackend): subscription.attributes[name] = value if name == 'FilterPolicy': - subscription._filter_policy = json.loads(value) + filter_policy = json.loads(value) + self._validate_filter_policy(filter_policy) + subscription._filter_policy = filter_policy + + def _validate_filter_policy(self, value): + # TODO: extend validation checks + combinations = 1 + for rules in six.itervalues(value): + combinations *= len(rules) + # Even the offical documentation states the total combination of values must not exceed 100, in reality it is 150 + # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints + if combinations > 150: + raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Filter policy is too complex") + + for field, rules in six.iteritems(value): + for rule in rules: + if rule is None: + continue + if isinstance(rule, six.string_types): + continue + if isinstance(rule, bool): + continue + if isinstance(rule, (six.integer_types, float)): + if rule <= -1000000000 or rule >= 1000000000: + raise InternalError("Unknown") + continue + if isinstance(rule, dict): + keyword = list(rule.keys())[0] + attributes = list(rule.values())[0] + if keyword == 'anything-but': + continue + elif keyword == 'exists': + if not isinstance(attributes, bool): + raise SNSInvalidParameter("Invalid parameter: FilterPolicy: exists match pattern must be either true or false.") + continue + elif keyword == 'numeric': + continue + elif keyword == 'prefix': + continue + else: + raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Unrecognized match type {type}".format(type=keyword)) + + raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null") sns_backends = {} diff --git a/moto/sns/responses.py b/moto/sns/responses.py index 440115429..578c5ea65 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -57,7 +57,16 @@ class SNSResponse(BaseResponse): transform_value = None if 'StringValue' in value: - transform_value = value['StringValue'] + if data_type == 'Number': + try: + transform_value = float(value['StringValue']) + except ValueError: + raise InvalidParameterValue( + "An error occurred (ParameterValueInvalid) " + "when calling the Publish operation: " + "Could not cast message attribute '{0}' value to number.".format(name)) + else: + transform_value = value['StringValue'] elif 'BinaryValue' in value: transform_value = value['BinaryValue'] if not transform_value: diff --git a/moto/sqs/models.py b/moto/sqs/models.py index e774e261c..188e25e9e 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -265,6 +265,9 @@ class Queue(BaseModel): if 'maxReceiveCount' not in self.redrive_policy: raise RESTError('InvalidParameterValue', 'Redrive policy does not contain maxReceiveCount') + # 'maxReceiveCount' is stored as int + self.redrive_policy['maxReceiveCount'] = int(self.redrive_policy['maxReceiveCount']) + for queue in sqs_backends[self.region].queues.values(): if queue.queue_arn == self.redrive_policy['deadLetterTargetArn']: self.dead_letter_queue = queue @@ -424,13 +427,26 @@ class SQSBackend(BaseBackend): queue_attributes = queue.attributes new_queue_attributes = new_queue.attributes + static_attributes = ( + 'DelaySeconds', + 'MaximumMessageSize', + 'MessageRetentionPeriod', + 'Policy', + 'QueueArn', + 'ReceiveMessageWaitTimeSeconds', + 'RedrivePolicy', + 'VisibilityTimeout', + 'KmsMasterKeyId', + 'KmsDataKeyReusePeriodSeconds', + 'FifoQueue', + 'ContentBasedDeduplication', + ) - for key in ['CreatedTimestamp', 'LastModifiedTimestamp']: - queue_attributes.pop(key) - new_queue_attributes.pop(key) - - if queue_attributes != new_queue_attributes: - raise QueueAlreadyExists("The specified queue already exists.") + for key in static_attributes: + if queue_attributes.get(key) != new_queue_attributes.get(key): + raise QueueAlreadyExists( + "The specified queue already exists.", + ) else: try: kwargs.pop('region') diff --git a/moto/stepfunctions/__init__.py b/moto/stepfunctions/__init__.py new file mode 100644 index 000000000..dc2b0ba13 --- /dev/null +++ b/moto/stepfunctions/__init__.py @@ -0,0 +1,6 @@ +from __future__ import unicode_literals +from .models import stepfunction_backends +from ..core.models import base_decorator + +stepfunction_backend = stepfunction_backends['us-east-1'] +mock_stepfunctions = base_decorator(stepfunction_backends) diff --git a/moto/stepfunctions/exceptions.py b/moto/stepfunctions/exceptions.py new file mode 100644 index 000000000..8af4686c7 --- /dev/null +++ b/moto/stepfunctions/exceptions.py @@ -0,0 +1,35 @@ +from __future__ import unicode_literals +import json + + +class AWSError(Exception): + TYPE = None + STATUS = 400 + + def __init__(self, message, type=None, status=None): + self.message = message + self.type = type if type is not None else self.TYPE + self.status = status if status is not None else self.STATUS + + def response(self): + return json.dumps({'__type': self.type, 'message': self.message}), dict(status=self.status) + + +class ExecutionDoesNotExist(AWSError): + TYPE = 'ExecutionDoesNotExist' + STATUS = 400 + + +class InvalidArn(AWSError): + TYPE = 'InvalidArn' + STATUS = 400 + + +class InvalidName(AWSError): + TYPE = 'InvalidName' + STATUS = 400 + + +class StateMachineDoesNotExist(AWSError): + TYPE = 'StateMachineDoesNotExist' + STATUS = 400 diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py new file mode 100644 index 000000000..7784919b0 --- /dev/null +++ b/moto/stepfunctions/models.py @@ -0,0 +1,162 @@ +import boto +import re +from datetime import datetime +from moto.core import BaseBackend +from moto.core.utils import iso_8601_datetime_without_milliseconds +from moto.sts.models import ACCOUNT_ID +from uuid import uuid4 +from .exceptions import ExecutionDoesNotExist, InvalidArn, InvalidName, StateMachineDoesNotExist + + +class StateMachine(): + def __init__(self, arn, name, definition, roleArn, tags=None): + self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now()) + self.arn = arn + self.name = name + self.definition = definition + self.roleArn = roleArn + self.tags = tags + + +class Execution(): + def __init__(self, region_name, account_id, state_machine_name, execution_name, state_machine_arn): + execution_arn = 'arn:aws:states:{}:{}:execution:{}:{}' + execution_arn = execution_arn.format(region_name, account_id, state_machine_name, execution_name) + self.execution_arn = execution_arn + self.name = execution_name + self.start_date = iso_8601_datetime_without_milliseconds(datetime.now()) + self.state_machine_arn = state_machine_arn + self.status = 'RUNNING' + self.stop_date = None + + def stop(self): + self.status = 'SUCCEEDED' + self.stop_date = iso_8601_datetime_without_milliseconds(datetime.now()) + + +class StepFunctionBackend(BaseBackend): + + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.create_state_machine + # A name must not contain: + # whitespace + # brackets < > { } [ ] + # wildcard characters ? * + # special characters " # % \ ^ | ~ ` $ & , ; : / + invalid_chars_for_name = [' ', '{', '}', '[', ']', '<', '>', + '?', '*', + '"', '#', '%', '\\', '^', '|', '~', '`', '$', '&', ',', ';', ':', '/'] + # control characters (U+0000-001F , U+007F-009F ) + invalid_unicodes_for_name = [u'\u0000', u'\u0001', u'\u0002', u'\u0003', u'\u0004', + u'\u0005', u'\u0006', u'\u0007', u'\u0008', u'\u0009', + u'\u000A', u'\u000B', u'\u000C', u'\u000D', u'\u000E', u'\u000F', + u'\u0010', u'\u0011', u'\u0012', u'\u0013', u'\u0014', + u'\u0015', u'\u0016', u'\u0017', u'\u0018', u'\u0019', + u'\u001A', u'\u001B', u'\u001C', u'\u001D', u'\u001E', u'\u001F', + u'\u007F', + u'\u0080', u'\u0081', u'\u0082', u'\u0083', u'\u0084', u'\u0085', + u'\u0086', u'\u0087', u'\u0088', u'\u0089', + u'\u008A', u'\u008B', u'\u008C', u'\u008D', u'\u008E', u'\u008F', + u'\u0090', u'\u0091', u'\u0092', u'\u0093', u'\u0094', u'\u0095', + u'\u0096', u'\u0097', u'\u0098', u'\u0099', + u'\u009A', u'\u009B', u'\u009C', u'\u009D', u'\u009E', u'\u009F'] + accepted_role_arn_format = re.compile('arn:aws:iam:(?P[0-9]{12}):role/.+') + accepted_mchn_arn_format = re.compile('arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):stateMachine:.+') + accepted_exec_arn_format = re.compile('arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):execution:.+') + + def __init__(self, region_name): + self.state_machines = [] + self.executions = [] + self.region_name = region_name + self._account_id = None + + def create_state_machine(self, name, definition, roleArn, tags=None): + self._validate_name(name) + self._validate_role_arn(roleArn) + arn = 'arn:aws:states:' + self.region_name + ':' + str(self._get_account_id()) + ':stateMachine:' + name + try: + return self.describe_state_machine(arn) + except StateMachineDoesNotExist: + state_machine = StateMachine(arn, name, definition, roleArn, tags) + self.state_machines.append(state_machine) + return state_machine + + def list_state_machines(self): + return self.state_machines + + def describe_state_machine(self, arn): + self._validate_machine_arn(arn) + sm = next((x for x in self.state_machines if x.arn == arn), None) + if not sm: + raise StateMachineDoesNotExist("State Machine Does Not Exist: '" + arn + "'") + return sm + + def delete_state_machine(self, arn): + self._validate_machine_arn(arn) + sm = next((x for x in self.state_machines if x.arn == arn), None) + if sm: + self.state_machines.remove(sm) + + def start_execution(self, state_machine_arn): + state_machine_name = self.describe_state_machine(state_machine_arn).name + execution = Execution(region_name=self.region_name, + account_id=self._get_account_id(), + state_machine_name=state_machine_name, + execution_name=str(uuid4()), + state_machine_arn=state_machine_arn) + self.executions.append(execution) + return execution + + def stop_execution(self, execution_arn): + execution = next((x for x in self.executions if x.execution_arn == execution_arn), None) + if not execution: + raise ExecutionDoesNotExist("Execution Does Not Exist: '" + execution_arn + "'") + execution.stop() + return execution + + def list_executions(self, state_machine_arn): + return [execution for execution in self.executions if execution.state_machine_arn == state_machine_arn] + + def describe_execution(self, arn): + self._validate_execution_arn(arn) + exctn = next((x for x in self.executions if x.execution_arn == arn), None) + if not exctn: + raise ExecutionDoesNotExist("Execution Does Not Exist: '" + arn + "'") + return exctn + + def reset(self): + region_name = self.region_name + self.__dict__ = {} + self.__init__(region_name) + + def _validate_name(self, name): + if any(invalid_char in name for invalid_char in self.invalid_chars_for_name): + raise InvalidName("Invalid Name: '" + name + "'") + + if any(name.find(char) >= 0 for char in self.invalid_unicodes_for_name): + raise InvalidName("Invalid Name: '" + name + "'") + + def _validate_role_arn(self, role_arn): + self._validate_arn(arn=role_arn, + regex=self.accepted_role_arn_format, + invalid_msg="Invalid Role Arn: '" + role_arn + "'") + + def _validate_machine_arn(self, machine_arn): + self._validate_arn(arn=machine_arn, + regex=self.accepted_mchn_arn_format, + invalid_msg="Invalid Role Arn: '" + machine_arn + "'") + + def _validate_execution_arn(self, execution_arn): + self._validate_arn(arn=execution_arn, + regex=self.accepted_exec_arn_format, + invalid_msg="Execution Does Not Exist: '" + execution_arn + "'") + + def _validate_arn(self, arn, regex, invalid_msg): + match = regex.match(arn) + if not arn or not match: + raise InvalidArn(invalid_msg) + + def _get_account_id(self): + return ACCOUNT_ID + + +stepfunction_backends = {_region.name: StepFunctionBackend(_region.name) for _region in boto.awslambda.regions()} diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py new file mode 100644 index 000000000..0a170aa57 --- /dev/null +++ b/moto/stepfunctions/responses.py @@ -0,0 +1,138 @@ +from __future__ import unicode_literals + +import json + +from moto.core.responses import BaseResponse +from moto.core.utils import amzn_request_id +from .exceptions import AWSError +from .models import stepfunction_backends + + +class StepFunctionResponse(BaseResponse): + + @property + def stepfunction_backend(self): + return stepfunction_backends[self.region] + + @amzn_request_id + def create_state_machine(self): + name = self._get_param('name') + definition = self._get_param('definition') + roleArn = self._get_param('roleArn') + tags = self._get_param('tags') + try: + state_machine = self.stepfunction_backend.create_state_machine(name=name, definition=definition, + roleArn=roleArn, + tags=tags) + response = { + 'creationDate': state_machine.creation_date, + 'stateMachineArn': state_machine.arn + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def list_state_machines(self): + list_all = self.stepfunction_backend.list_state_machines() + list_all = sorted([{'creationDate': sm.creation_date, + 'name': sm.name, + 'stateMachineArn': sm.arn} for sm in list_all], + key=lambda x: x['name']) + response = {'stateMachines': list_all} + return 200, {}, json.dumps(response) + + @amzn_request_id + def describe_state_machine(self): + arn = self._get_param('stateMachineArn') + return self._describe_state_machine(arn) + + @amzn_request_id + def _describe_state_machine(self, state_machine_arn): + try: + state_machine = self.stepfunction_backend.describe_state_machine(state_machine_arn) + response = { + 'creationDate': state_machine.creation_date, + 'stateMachineArn': state_machine.arn, + 'definition': state_machine.definition, + 'name': state_machine.name, + 'roleArn': state_machine.roleArn, + 'status': 'ACTIVE' + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def delete_state_machine(self): + arn = self._get_param('stateMachineArn') + try: + self.stepfunction_backend.delete_state_machine(arn) + return 200, {}, json.dumps('{}') + except AWSError as err: + return err.response() + + @amzn_request_id + def list_tags_for_resource(self): + arn = self._get_param('resourceArn') + try: + state_machine = self.stepfunction_backend.describe_state_machine(arn) + tags = state_machine.tags or [] + except AWSError: + tags = [] + response = {'tags': tags} + return 200, {}, json.dumps(response) + + @amzn_request_id + def start_execution(self): + arn = self._get_param('stateMachineArn') + execution = self.stepfunction_backend.start_execution(arn) + response = {'executionArn': execution.execution_arn, + 'startDate': execution.start_date} + return 200, {}, json.dumps(response) + + @amzn_request_id + def list_executions(self): + arn = self._get_param('stateMachineArn') + state_machine = self.stepfunction_backend.describe_state_machine(arn) + executions = self.stepfunction_backend.list_executions(arn) + executions = [{'executionArn': execution.execution_arn, + 'name': execution.name, + 'startDate': execution.start_date, + 'stateMachineArn': state_machine.arn, + 'status': execution.status} for execution in executions] + return 200, {}, json.dumps({'executions': executions}) + + @amzn_request_id + def describe_execution(self): + arn = self._get_param('executionArn') + try: + execution = self.stepfunction_backend.describe_execution(arn) + response = { + 'executionArn': arn, + 'input': '{}', + 'name': execution.name, + 'startDate': execution.start_date, + 'stateMachineArn': execution.state_machine_arn, + 'status': execution.status, + 'stopDate': execution.stop_date + } + return 200, {}, json.dumps(response) + except AWSError as err: + return err.response() + + @amzn_request_id + def describe_state_machine_for_execution(self): + arn = self._get_param('executionArn') + try: + execution = self.stepfunction_backend.describe_execution(arn) + return self._describe_state_machine(execution.state_machine_arn) + except AWSError as err: + return err.response() + + @amzn_request_id + def stop_execution(self): + arn = self._get_param('executionArn') + execution = self.stepfunction_backend.stop_execution(arn) + response = {'stopDate': execution.stop_date} + return 200, {}, json.dumps(response) diff --git a/moto/stepfunctions/urls.py b/moto/stepfunctions/urls.py new file mode 100644 index 000000000..f8d5fb1e8 --- /dev/null +++ b/moto/stepfunctions/urls.py @@ -0,0 +1,10 @@ +from __future__ import unicode_literals +from .responses import StepFunctionResponse + +url_bases = [ + "https?://states.(.+).amazonaws.com", +] + +url_paths = { + '{0}/$': StepFunctionResponse.dispatch, +} diff --git a/requirements-dev.txt b/requirements-dev.txt index f87ab3db6..1dd8ef1f8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -10,6 +10,7 @@ boto>=2.45.0 boto3>=1.4.4 botocore>=1.12.13 six>=1.9 +parameterized>=0.7.0 prompt-toolkit==1.0.14 click==6.7 inflection==0.3.1 diff --git a/tests/test_apigateway/test_apigateway.py b/tests/test_apigateway/test_apigateway.py index 0a33f2f9f..20cc078b8 100644 --- a/tests/test_apigateway/test_apigateway.py +++ b/tests/test_apigateway/test_apigateway.py @@ -981,11 +981,13 @@ def test_api_keys(): apikey['value'].should.equal(apikey_value) apikey_name = 'TESTKEY2' - payload = {'name': apikey_name } + payload = {'name': apikey_name, 'tags': {'tag1': 'test_tag1', 'tag2': '1'}} response = client.create_api_key(**payload) apikey_id = response['id'] apikey = client.get_api_key(apiKey=apikey_id) apikey['name'].should.equal(apikey_name) + apikey['tags']['tag1'].should.equal('test_tag1') + apikey['tags']['tag2'].should.equal('1') len(apikey['value']).should.equal(40) apikey_name = 'TESTKEY3' diff --git a/tests/test_batch/test_batch.py b/tests/test_batch/test_batch.py index 89a8d4d0e..5487cfb91 100644 --- a/tests/test_batch/test_batch.py +++ b/tests/test_batch/test_batch.py @@ -563,6 +563,38 @@ def test_reregister_task_definition(): resp2['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn']) + resp3 = batch_client.register_job_definition( + jobDefinitionName='sleep10', + type='container', + containerProperties={ + 'image': 'busybox', + 'vcpus': 1, + 'memory': 42, + 'command': ['sleep', '10'] + } + ) + resp3['revision'].should.equal(3) + + resp3['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn']) + resp3['jobDefinitionArn'].should_not.equal(resp2['jobDefinitionArn']) + + resp4 = batch_client.register_job_definition( + jobDefinitionName='sleep10', + type='container', + containerProperties={ + 'image': 'busybox', + 'vcpus': 1, + 'memory': 41, + 'command': ['sleep', '10'] + } + ) + resp4['revision'].should.equal(4) + + resp4['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn']) + resp4['jobDefinitionArn'].should_not.equal(resp2['jobDefinitionArn']) + resp4['jobDefinitionArn'].should_not.equal(resp3['jobDefinitionArn']) + + @mock_ec2 @mock_ecs diff --git a/tests/test_cognitoidentity/test_cognitoidentity.py b/tests/test_cognitoidentity/test_cognitoidentity.py index ea9ccbc78..67679e896 100644 --- a/tests/test_cognitoidentity/test_cognitoidentity.py +++ b/tests/test_cognitoidentity/test_cognitoidentity.py @@ -1,10 +1,10 @@ from __future__ import unicode_literals import boto3 +from botocore.exceptions import ClientError +from nose.tools import assert_raises from moto import mock_cognitoidentity -import sure # noqa - from moto.cognitoidentity.utils import get_random_identity_id @@ -28,6 +28,47 @@ def test_create_identity_pool(): assert result['IdentityPoolId'] != '' +@mock_cognitoidentity +def test_describe_identity_pool(): + conn = boto3.client('cognito-identity', 'us-west-2') + + res = conn.create_identity_pool(IdentityPoolName='TestPool', + AllowUnauthenticatedIdentities=False, + SupportedLoginProviders={'graph.facebook.com': '123456789012345'}, + DeveloperProviderName='devname', + OpenIdConnectProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db'], + CognitoIdentityProviders=[ + { + 'ProviderName': 'testprovider', + 'ClientId': 'CLIENT12345', + 'ServerSideTokenCheck': True + }, + ], + SamlProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db']) + + result = conn.describe_identity_pool(IdentityPoolId=res['IdentityPoolId']) + + assert result['IdentityPoolId'] == res['IdentityPoolId'] + assert result['AllowUnauthenticatedIdentities'] == res['AllowUnauthenticatedIdentities'] + assert result['SupportedLoginProviders'] == res['SupportedLoginProviders'] + assert result['DeveloperProviderName'] == res['DeveloperProviderName'] + assert result['OpenIdConnectProviderARNs'] == res['OpenIdConnectProviderARNs'] + assert result['CognitoIdentityProviders'] == res['CognitoIdentityProviders'] + assert result['SamlProviderARNs'] == res['SamlProviderARNs'] + + +@mock_cognitoidentity +def test_describe_identity_pool_with_invalid_id_raises_error(): + conn = boto3.client('cognito-identity', 'us-west-2') + + with assert_raises(ClientError) as cm: + conn.describe_identity_pool(IdentityPoolId='us-west-2_non-existent') + + cm.exception.operation_name.should.equal('DescribeIdentityPool') + cm.exception.response['Error']['Code'].should.equal('ResourceNotFoundException') + cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + + # testing a helper function def test_get_random_identity_id(): assert len(get_random_identity_id('us-west-2')) > 0 @@ -44,7 +85,8 @@ def test_get_id(): 'someurl': '12345' }) print(result) - assert result.get('IdentityId', "").startswith('us-west-2') or result.get('ResponseMetadata').get('HTTPStatusCode') == 200 + assert result.get('IdentityId', "").startswith('us-west-2') or result.get('ResponseMetadata').get( + 'HTTPStatusCode') == 200 @mock_cognitoidentity @@ -71,6 +113,7 @@ def test_get_open_id_token_for_developer_identity(): assert len(result['Token']) > 0 assert result['IdentityId'] == '12345' + @mock_cognitoidentity def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id(): conn = boto3.client('cognito-identity', 'us-west-2') @@ -84,6 +127,7 @@ def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id() assert len(result['Token']) > 0 assert len(result['IdentityId']) > 0 + @mock_cognitoidentity def test_get_open_id_token(): conn = boto3.client('cognito-identity', 'us-west-2') diff --git a/tests/test_core/test_request_mocking.py b/tests/test_core/test_request_mocking.py new file mode 100644 index 000000000..ee3ec5f88 --- /dev/null +++ b/tests/test_core/test_request_mocking.py @@ -0,0 +1,22 @@ +import requests +import sure # noqa + +import boto3 +from moto import mock_sqs, settings + + +@mock_sqs +def test_passthrough_requests(): + conn = boto3.client("sqs", region_name='us-west-1') + conn.create_queue(QueueName="queue1") + + res = requests.get("https://httpbin.org/ip") + assert res.status_code == 200 + + +if not settings.TEST_SERVER_MODE: + @mock_sqs + def test_requests_to_amazon_subdomains_dont_work(): + res = requests.get("https://fakeservice.amazonaws.com/foo/bar") + assert res.content == b"The method is not implemented" + assert res.status_code == 400 diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index fb6c0e17d..4e7b2dfeb 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -973,6 +973,53 @@ def test_query_filter(): assert response['Count'] == 2 +@mock_dynamodb2 +def test_query_filter_overlapping_expression_prefixes(): + client = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + + # Create the DynamoDB table. + client.create_table( + TableName='test1', + AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], + KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], + ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + ) + + client.put_item( + TableName='test1', + Item={ + 'client': {'S': 'client1'}, + 'app': {'S': 'app1'}, + 'nested': {'M': { + 'version': {'S': 'version1'}, + 'contents': {'L': [ + {'S': 'value1'}, {'S': 'value2'}, + ]}, + }}, + }) + + table = dynamodb.Table('test1') + response = table.query( + KeyConditionExpression=Key('client').eq('client1') & Key('app').eq('app1'), + ProjectionExpression='#1, #10, nested', + ExpressionAttributeNames={ + '#1': 'client', + '#10': 'app', + } + ) + + assert response['Count'] == 1 + assert response['Items'][0] == { + 'client': 'client1', + 'app': 'app1', + 'nested': { + 'version': 'version1', + 'contents': ['value1', 'value2'] + } + } + + @mock_dynamodb2 def test_scan_filter(): client = boto3.client('dynamodb', region_name='us-east-1') @@ -2034,6 +2081,36 @@ def test_condition_expression__or_order(): ) +@mock_dynamodb2 +def test_condition_expression__and_order(): + client = boto3.client('dynamodb', region_name='us-east-1') + + client.create_table( + TableName='test', + KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], + AttributeDefinitions=[ + {'AttributeName': 'forum_name', 'AttributeType': 'S'}, + ], + ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ) + + # ensure that the RHS of the AND expression is not evaluated if the LHS + # returns true (as it would result an error) + with assert_raises(client.exceptions.ConditionalCheckFailedException): + client.update_item( + TableName='test', + Key={ + 'forum_name': {'S': 'the-key'}, + }, + UpdateExpression='set #ttl=:ttl', + ConditionExpression='attribute_exists(#ttl) AND #ttl <= :old_ttl', + ExpressionAttributeNames={'#ttl': 'ttl'}, + ExpressionAttributeValues={ + ':ttl': {'N': '6'}, + ':old_ttl': {'N': '5'}, + } + ) + @mock_dynamodb2 def test_query_gsi_with_range_key(): dynamodb = boto3.client('dynamodb', region_name='us-east-1') diff --git a/tests/test_dynamodbstreams/test_dynamodbstreams.py b/tests/test_dynamodbstreams/test_dynamodbstreams.py index b60c21053..deb9f9283 100644 --- a/tests/test_dynamodbstreams/test_dynamodbstreams.py +++ b/tests/test_dynamodbstreams/test_dynamodbstreams.py @@ -76,6 +76,34 @@ class TestCore(): ShardIteratorType='TRIM_HORIZON' ) assert 'ShardIterator' in resp + + def test_get_shard_iterator_at_sequence_number(self): + conn = boto3.client('dynamodbstreams', region_name='us-east-1') + + resp = conn.describe_stream(StreamArn=self.stream_arn) + shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] + + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType='AT_SEQUENCE_NUMBER', + SequenceNumber=resp['StreamDescription']['Shards'][0]['SequenceNumberRange']['StartingSequenceNumber'] + ) + assert 'ShardIterator' in resp + + def test_get_shard_iterator_after_sequence_number(self): + conn = boto3.client('dynamodbstreams', region_name='us-east-1') + + resp = conn.describe_stream(StreamArn=self.stream_arn) + shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] + + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType='AFTER_SEQUENCE_NUMBER', + SequenceNumber=resp['StreamDescription']['Shards'][0]['SequenceNumberRange']['StartingSequenceNumber'] + ) + assert 'ShardIterator' in resp def test_get_records_empty(self): conn = boto3.client('dynamodbstreams', region_name='us-east-1') @@ -135,11 +163,39 @@ class TestCore(): assert resp['Records'][1]['eventName'] == 'MODIFY' assert resp['Records'][2]['eventName'] == 'DELETE' + sequence_number_modify = resp['Records'][1]['dynamodb']['SequenceNumber'] + # now try fetching from the next shard iterator, it should be # empty resp = conn.get_records(ShardIterator=resp['NextShardIterator']) assert len(resp['Records']) == 0 + # check that if we get the shard iterator AT_SEQUENCE_NUMBER will get the MODIFY event + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType='AT_SEQUENCE_NUMBER', + SequenceNumber=sequence_number_modify + ) + iterator_id = resp['ShardIterator'] + resp = conn.get_records(ShardIterator=iterator_id) + assert len(resp['Records']) == 2 + assert resp['Records'][0]['eventName'] == 'MODIFY' + assert resp['Records'][1]['eventName'] == 'DELETE' + + # check that if we get the shard iterator AFTER_SEQUENCE_NUMBER will get the DELETE event + resp = conn.get_shard_iterator( + StreamArn=self.stream_arn, + ShardId=shard_id, + ShardIteratorType='AFTER_SEQUENCE_NUMBER', + SequenceNumber=sequence_number_modify + ) + iterator_id = resp['ShardIterator'] + resp = conn.get_records(ShardIterator=iterator_id) + assert len(resp['Records']) == 1 + assert resp['Records'][0]['eventName'] == 'DELETE' + + class TestEdges(): mocks = [] diff --git a/tests/test_elbv2/test_elbv2.py b/tests/test_elbv2/test_elbv2.py index 36772c02e..97b876fec 100644 --- a/tests/test_elbv2/test_elbv2.py +++ b/tests/test_elbv2/test_elbv2.py @@ -4,7 +4,7 @@ import json import os import boto3 import botocore -from botocore.exceptions import ClientError +from botocore.exceptions import ClientError, ParamValidationError from nose.tools import assert_raises import sure # noqa @@ -752,6 +752,83 @@ def test_stopped_instance_target(): }) +@mock_ec2 +@mock_elbv2 +def test_terminated_instance_target(): + target_group_port = 8080 + + conn = boto3.client('elbv2', region_name='us-east-1') + ec2 = boto3.resource('ec2', region_name='us-east-1') + + security_group = ec2.create_security_group( + GroupName='a-security-group', Description='First One') + vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + subnet1 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1a') + subnet2 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.0/26', + AvailabilityZone='us-east-1b') + + conn.create_load_balancer( + Name='my-lb', + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme='internal', + Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + + response = conn.create_target_group( + Name='a-target', + Protocol='HTTP', + Port=target_group_port, + VpcId=vpc.id, + HealthCheckProtocol='HTTP', + HealthCheckPath='/', + HealthCheckIntervalSeconds=5, + HealthCheckTimeoutSeconds=5, + HealthyThresholdCount=5, + UnhealthyThresholdCount=2, + Matcher={'HttpCode': '200'}) + target_group = response.get('TargetGroups')[0] + + # No targets registered yet + response = conn.describe_target_health( + TargetGroupArn=target_group.get('TargetGroupArn')) + response.get('TargetHealthDescriptions').should.have.length_of(0) + + response = ec2.create_instances( + ImageId='ami-1234abcd', MinCount=1, MaxCount=1) + instance = response[0] + + target_dict = { + 'Id': instance.id, + 'Port': 500 + } + + response = conn.register_targets( + TargetGroupArn=target_group.get('TargetGroupArn'), + Targets=[target_dict]) + + response = conn.describe_target_health( + TargetGroupArn=target_group.get('TargetGroupArn')) + response.get('TargetHealthDescriptions').should.have.length_of(1) + target_health_description = response.get('TargetHealthDescriptions')[0] + + target_health_description['Target'].should.equal(target_dict) + target_health_description['HealthCheckPort'].should.equal(str(target_group_port)) + target_health_description['TargetHealth'].should.equal({ + 'State': 'healthy' + }) + + instance.terminate() + + response = conn.describe_target_health( + TargetGroupArn=target_group.get('TargetGroupArn')) + response.get('TargetHealthDescriptions').should.have.length_of(0) + + @mock_ec2 @mock_elbv2 def test_target_group_attributes(): @@ -1940,3 +2017,279 @@ def test_cognito_action_listener_rule_cloudformation(): 'UserPoolDomain': 'testpool', } },]) + + +@mock_elbv2 +@mock_ec2 +def test_fixed_response_action_listener_rule(): + conn = boto3.client('elbv2', region_name='us-east-1') + ec2 = boto3.resource('ec2', region_name='us-east-1') + + security_group = ec2.create_security_group( + GroupName='a-security-group', Description='First One') + vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + subnet1 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1a') + subnet2 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.128/26', + AvailabilityZone='us-east-1b') + + response = conn.create_load_balancer( + Name='my-lb', + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme='internal', + Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + + action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + 'StatusCode': '404', + } + } + response = conn.create_listener(LoadBalancerArn=load_balancer_arn, + Protocol='HTTP', + Port=80, + DefaultActions=[action]) + + listener = response.get('Listeners')[0] + listener.get('DefaultActions')[0].should.equal(action) + listener_arn = listener.get('ListenerArn') + + describe_rules_response = conn.describe_rules(ListenerArn=listener_arn) + describe_rules_response['Rules'][0]['Actions'][0].should.equal(action) + + describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn, ]) + describe_listener_actions = describe_listener_response['Listeners'][0]['DefaultActions'][0] + describe_listener_actions.should.equal(action) + + +@mock_elbv2 +@mock_cloudformation +def test_fixed_response_action_listener_rule_cloudformation(): + cnf_conn = boto3.client('cloudformation', region_name='us-east-1') + elbv2_client = boto3.client('elbv2', region_name='us-east-1') + + template = { + "AWSTemplateFormatVersion": "2010-09-09", + "Description": "ECS Cluster Test CloudFormation", + "Resources": { + "testVPC": { + "Type": "AWS::EC2::VPC", + "Properties": { + "CidrBlock": "10.0.0.0/16", + }, + }, + "subnet1": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "CidrBlock": "10.0.0.0/24", + "VpcId": {"Ref": "testVPC"}, + "AvalabilityZone": "us-east-1b", + }, + }, + "subnet2": { + "Type": "AWS::EC2::Subnet", + "Properties": { + "CidrBlock": "10.0.1.0/24", + "VpcId": {"Ref": "testVPC"}, + "AvalabilityZone": "us-east-1b", + }, + }, + "testLb": { + "Type": "AWS::ElasticLoadBalancingV2::LoadBalancer", + "Properties": { + "Name": "my-lb", + "Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}], + "Type": "application", + "SecurityGroups": [], + } + }, + "testListener": { + "Type": "AWS::ElasticLoadBalancingV2::Listener", + "Properties": { + "LoadBalancerArn": {"Ref": "testLb"}, + "Port": 80, + "Protocol": "HTTP", + "DefaultActions": [{ + "Type": "fixed-response", + "FixedResponseConfig": { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + 'StatusCode': '404', + } + }] + } + + } + } + } + template_json = json.dumps(template) + cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json) + + describe_load_balancers_response = elbv2_client.describe_load_balancers(Names=['my-lb',]) + load_balancer_arn = describe_load_balancers_response['LoadBalancers'][0]['LoadBalancerArn'] + describe_listeners_response = elbv2_client.describe_listeners(LoadBalancerArn=load_balancer_arn) + + describe_listeners_response['Listeners'].should.have.length_of(1) + describe_listeners_response['Listeners'][0]['DefaultActions'].should.equal([{ + 'Type': 'fixed-response', + "FixedResponseConfig": { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + 'StatusCode': '404', + } + },]) + + +@mock_elbv2 +@mock_ec2 +def test_fixed_response_action_listener_rule_validates_status_code(): + conn = boto3.client('elbv2', region_name='us-east-1') + ec2 = boto3.resource('ec2', region_name='us-east-1') + + security_group = ec2.create_security_group( + GroupName='a-security-group', Description='First One') + vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + subnet1 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1a') + subnet2 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.128/26', + AvailabilityZone='us-east-1b') + + response = conn.create_load_balancer( + Name='my-lb', + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme='internal', + Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + + missing_status_code_action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + } + } + with assert_raises(ParamValidationError): + conn.create_listener(LoadBalancerArn=load_balancer_arn, + Protocol='HTTP', + Port=80, + DefaultActions=[missing_status_code_action]) + + invalid_status_code_action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + 'StatusCode': '100' + } + } + + @mock_elbv2 + @mock_ec2 + def test_fixed_response_action_listener_rule_validates_status_code(): + conn = boto3.client('elbv2', region_name='us-east-1') + ec2 = boto3.resource('ec2', region_name='us-east-1') + + security_group = ec2.create_security_group( + GroupName='a-security-group', Description='First One') + vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + subnet1 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1a') + subnet2 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.128/26', + AvailabilityZone='us-east-1b') + + response = conn.create_load_balancer( + Name='my-lb', + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme='internal', + Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + + missing_status_code_action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + } + } + with assert_raises(ParamValidationError): + conn.create_listener(LoadBalancerArn=load_balancer_arn, + Protocol='HTTP', + Port=80, + DefaultActions=[missing_status_code_action]) + + invalid_status_code_action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'text/plain', + 'MessageBody': 'This page does not exist', + 'StatusCode': '100' + } + } + + with assert_raises(ClientError) as invalid_status_code_exception: + conn.create_listener(LoadBalancerArn=load_balancer_arn, + Protocol='HTTP', + Port=80, + DefaultActions=[invalid_status_code_action]) + + invalid_status_code_exception.exception.response['Error']['Code'].should.equal('ValidationError') + + +@mock_elbv2 +@mock_ec2 +def test_fixed_response_action_listener_rule_validates_content_type(): + conn = boto3.client('elbv2', region_name='us-east-1') + ec2 = boto3.resource('ec2', region_name='us-east-1') + + security_group = ec2.create_security_group( + GroupName='a-security-group', Description='First One') + vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + subnet1 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.192/26', + AvailabilityZone='us-east-1a') + subnet2 = ec2.create_subnet( + VpcId=vpc.id, + CidrBlock='172.28.7.128/26', + AvailabilityZone='us-east-1b') + + response = conn.create_load_balancer( + Name='my-lb', + Subnets=[subnet1.id, subnet2.id], + SecurityGroups=[security_group.id], + Scheme='internal', + Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + + invalid_content_type_action = { + 'Type': 'fixed-response', + 'FixedResponseConfig': { + 'ContentType': 'Fake content type', + 'MessageBody': 'This page does not exist', + 'StatusCode': '200' + } + } + with assert_raises(ClientError) as invalid_content_type_exception: + conn.create_listener(LoadBalancerArn=load_balancer_arn, + Protocol='HTTP', + Port=80, + DefaultActions=[invalid_content_type_action]) + invalid_content_type_exception.exception.response['Error']['Code'].should.equal('InvalidLoadBalancerAction') diff --git a/tests/test_iam/test_iam_policies.py b/tests/test_iam/test_iam_policies.py index e1924a559..adb8bd990 100644 --- a/tests/test_iam/test_iam_policies.py +++ b/tests/test_iam/test_iam_policies.py @@ -1827,6 +1827,23 @@ valid_policy_documents = [ "Resource": ["*"] } ] + }, + { + "Version": "2012-10-17", + "Statement": [ + { + "Sid": "", + "Effect": "Allow", + "Action": "rds:*", + "Resource": ["arn:aws:rds:region:*:*"] + }, + { + "Sid": "", + "Effect": "Allow", + "Action": ["rds:Describe*"], + "Resource": ["*"] + } + ] } ] diff --git a/tests/test_iotdata/test_iotdata.py b/tests/test_iotdata/test_iotdata.py index 09c1ada4c..1cedcaa72 100644 --- a/tests/test_iotdata/test_iotdata.py +++ b/tests/test_iotdata/test_iotdata.py @@ -86,6 +86,12 @@ def test_update(): payload.should.have.key('version').which.should.equal(2) payload.should.have.key('timestamp') + raw_payload = b'{"state": {"desired": {"led": "on"}}, "version": 1}' + with assert_raises(ClientError) as ex: + client.update_thing_shadow(thingName=name, payload=raw_payload) + ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(409) + ex.exception.response['Error']['Message'].should.equal('Version conflict') + @mock_iotdata def test_publish(): diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index f189fbe41..49c0f886e 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -1,333 +1,392 @@ +# -*- coding: utf-8 -*- from __future__ import unicode_literals -import os, re -import boto3 -import boto.kms -import botocore.exceptions -from boto.exception import JSONResponseError -from boto.kms.exceptions import AlreadyExistsException, NotFoundException - -from moto.kms.exceptions import NotFoundException as MotoNotFoundException -import sure # noqa -from moto import mock_kms, mock_kms_deprecated -from nose.tools import assert_raises -from freezegun import freeze_time from datetime import date from datetime import datetime from dateutil.tz import tzutc +import base64 +import os +import re + +import boto3 +import boto.kms +import botocore.exceptions +import six +import sure # noqa +from boto.exception import JSONResponseError +from boto.kms.exceptions import AlreadyExistsException, NotFoundException +from freezegun import freeze_time +from nose.tools import assert_raises +from parameterized import parameterized + +from moto.kms.exceptions import NotFoundException as MotoNotFoundException +from moto import mock_kms, mock_kms_deprecated + +PLAINTEXT_VECTORS = ( + (b"some encodeable plaintext",), + (b"some unencodeable plaintext \xec\x8a\xcf\xb6r\xe9\xb5\xeb\xff\xa23\x16",), + (u"some unicode characters ø˚∆øˆˆ∆ßçøˆˆçßøˆ¨¥",), +) + + +def _get_encoded_value(plaintext): + if isinstance(plaintext, six.binary_type): + return plaintext + + return plaintext.encode("utf-8") @mock_kms def test_create_key(): - conn = boto3.client('kms', region_name='us-east-1') + conn = boto3.client("kms", region_name="us-east-1") with freeze_time("2015-01-01 00:00:00"): - key = conn.create_key(Policy="my policy", - Description="my key", - KeyUsage='ENCRYPT_DECRYPT', - Tags=[ - { - 'TagKey': 'project', - 'TagValue': 'moto', - }, - ]) + key = conn.create_key( + Policy="my policy", + Description="my key", + KeyUsage="ENCRYPT_DECRYPT", + Tags=[{"TagKey": "project", "TagValue": "moto"}], + ) - key['KeyMetadata']['Description'].should.equal("my key") - key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") - key['KeyMetadata']['Enabled'].should.equal(True) - key['KeyMetadata']['CreationDate'].should.be.a(date) + key["KeyMetadata"]["Description"].should.equal("my key") + key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") + key["KeyMetadata"]["Enabled"].should.equal(True) + key["KeyMetadata"]["CreationDate"].should.be.a(date) @mock_kms_deprecated def test_describe_key(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] key = conn.describe_key(key_id) - key['KeyMetadata']['Description'].should.equal("my key") - key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") + key["KeyMetadata"]["Description"].should.equal("my key") + key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") @mock_kms_deprecated def test_describe_key_via_alias(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) - alias_key = conn.describe_key('alias/my-key-alias') - alias_key['KeyMetadata']['Description'].should.equal("my key") - alias_key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") - alias_key['KeyMetadata']['Arn'].should.equal(key['KeyMetadata']['Arn']) + alias_key = conn.describe_key("alias/my-key-alias") + alias_key["KeyMetadata"]["Description"].should.equal("my key") + alias_key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") + alias_key["KeyMetadata"]["Arn"].should.equal(key["KeyMetadata"]["Arn"]) @mock_kms_deprecated def test_describe_key_via_alias_not_found(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) - conn.describe_key.when.called_with( - 'alias/not-found-alias').should.throw(JSONResponseError) + conn.describe_key.when.called_with("alias/not-found-alias").should.throw(NotFoundException) + + +@parameterized(( + ("alias/does-not-exist",), + ("arn:aws:kms:us-east-1:012345678912:alias/does-not-exist",), + ("invalid",), +)) +@mock_kms +def test_describe_key_via_alias_invalid_alias(key_id): + client = boto3.client("kms", region_name="us-east-1") + client.create_key(Description="key") + + with assert_raises(client.exceptions.NotFoundException): + client.describe_key(KeyId=key_id) @mock_kms_deprecated def test_describe_key_via_arn(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - arn = key['KeyMetadata']['Arn'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + arn = key["KeyMetadata"]["Arn"] the_key = conn.describe_key(arn) - the_key['KeyMetadata']['Description'].should.equal("my key") - the_key['KeyMetadata']['KeyUsage'].should.equal("ENCRYPT_DECRYPT") - the_key['KeyMetadata']['KeyId'].should.equal(key['KeyMetadata']['KeyId']) + the_key["KeyMetadata"]["Description"].should.equal("my key") + the_key["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") + the_key["KeyMetadata"]["KeyId"].should.equal(key["KeyMetadata"]["KeyId"]) @mock_kms_deprecated def test_describe_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.describe_key.when.called_with( - "not-a-key").should.throw(JSONResponseError) + conn.describe_key.when.called_with("not-a-key").should.throw(NotFoundException) @mock_kms_deprecated def test_list_keys(): conn = boto.kms.connect_to_region("us-west-2") - conn.create_key(policy="my policy", description="my key1", - key_usage='ENCRYPT_DECRYPT') - conn.create_key(policy="my policy", description="my key2", - key_usage='ENCRYPT_DECRYPT') + conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + conn.create_key(policy="my policy", description="my key2", key_usage="ENCRYPT_DECRYPT") keys = conn.list_keys() - keys['Keys'].should.have.length_of(2) + keys["Keys"].should.have.length_of(2) @mock_kms_deprecated def test_enable_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] conn.enable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(True) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(True) @mock_kms_deprecated def test_enable_key_rotation_via_arn(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['Arn'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["Arn"] conn.enable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(True) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(True) @mock_kms_deprecated def test_enable_key_rotation_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.enable_key_rotation.when.called_with( - "not-a-key").should.throw(NotFoundException) + conn.enable_key_rotation.when.called_with("not-a-key").should.throw(NotFoundException) @mock_kms_deprecated def test_enable_key_rotation_with_alias_name_should_fail(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) - alias_key = conn.describe_key('alias/my-key-alias') - alias_key['KeyMetadata']['Arn'].should.equal(key['KeyMetadata']['Arn']) + alias_key = conn.describe_key("alias/my-key-alias") + alias_key["KeyMetadata"]["Arn"].should.equal(key["KeyMetadata"]["Arn"]) - conn.enable_key_rotation.when.called_with( - 'alias/my-alias').should.throw(NotFoundException) + conn.enable_key_rotation.when.called_with("alias/my-alias").should.throw(NotFoundException) @mock_kms_deprecated def test_disable_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] conn.enable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(True) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(True) conn.disable_key_rotation(key_id) - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(False) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(False) @mock_kms_deprecated -def test_encrypt(): - """ - test_encrypt - Using base64 encoding to merely test that the endpoint was called - """ +def test_generate_data_key(): conn = boto.kms.connect_to_region("us-west-2") - response = conn.encrypt('key_id', 'encryptme'.encode('utf-8')) - response['CiphertextBlob'].should.equal(b'ZW5jcnlwdG1l') - response['KeyId'].should.equal('key_id') + + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + response = conn.generate_data_key(key_id=key_id, number_of_bytes=32) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["CiphertextBlob"], validate=True) + # Plaintext must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["Plaintext"], validate=True) + + response["KeyId"].should.equal(key_arn) -@mock_kms_deprecated -def test_decrypt(): - conn = boto.kms.connect_to_region('us-west-2') - response = conn.decrypt('ZW5jcnlwdG1l'.encode('utf-8')) - response['Plaintext'].should.equal(b'encryptme') - response['KeyId'].should.equal('key_id') +@mock_kms +def test_boto3_generate_data_key(): + kms = boto3.client("kms", region_name="us-west-2") + + key = kms.create_key() + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + response = kms.generate_data_key(KeyId=key_id, NumberOfBytes=32) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["CiphertextBlob"], validate=True) + # Plaintext must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["Plaintext"], validate=True) + + response["KeyId"].should.equal(key_arn) + + +@parameterized(PLAINTEXT_VECTORS) +@mock_kms +def test_encrypt(plaintext): + client = boto3.client("kms", region_name="us-west-2") + + key = client.create_key(Description="key") + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + response = client.encrypt(KeyId=key_id, Plaintext=plaintext) + response["CiphertextBlob"].should_not.equal(plaintext) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["CiphertextBlob"], validate=True) + + response["KeyId"].should.equal(key_arn) + + +@parameterized(PLAINTEXT_VECTORS) +@mock_kms +def test_decrypt(plaintext): + client = boto3.client("kms", region_name="us-west-2") + + key = client.create_key(Description="key") + key_id = key["KeyMetadata"]["KeyId"] + key_arn = key["KeyMetadata"]["Arn"] + + encrypt_response = client.encrypt(KeyId=key_id, Plaintext=plaintext) + + client.create_key(Description="key") + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(encrypt_response["CiphertextBlob"], validate=True) + + decrypt_response = client.decrypt(CiphertextBlob=encrypt_response["CiphertextBlob"]) + + # Plaintext must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(decrypt_response["Plaintext"], validate=True) + + decrypt_response["Plaintext"].should.equal(_get_encoded_value(plaintext)) + decrypt_response["KeyId"].should.equal(key_arn) @mock_kms_deprecated def test_disable_key_rotation_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.disable_key_rotation.when.called_with( - "not-a-key").should.throw(NotFoundException) + conn.disable_key_rotation.when.called_with("not-a-key").should.throw(NotFoundException) @mock_kms_deprecated def test_get_key_rotation_status_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.get_key_rotation_status.when.called_with( - "not-a-key").should.throw(NotFoundException) + conn.get_key_rotation_status.when.called_with("not-a-key").should.throw(NotFoundException) @mock_kms_deprecated def test_get_key_rotation_status(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(False) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(False) @mock_kms_deprecated def test_create_key_defaults_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", - description="my key", key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] - conn.get_key_rotation_status( - key_id)['KeyRotationEnabled'].should.equal(False) + conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(False) @mock_kms_deprecated def test_get_key_policy(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] - policy = conn.get_key_policy(key_id, 'default') - policy['Policy'].should.equal('my policy') + policy = conn.get_key_policy(key_id, "default") + policy["Policy"].should.equal("my policy") @mock_kms_deprecated def test_get_key_policy_via_arn(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - policy = conn.get_key_policy(key['KeyMetadata']['Arn'], 'default') + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + policy = conn.get_key_policy(key["KeyMetadata"]["Arn"], "default") - policy['Policy'].should.equal('my policy') + policy["Policy"].should.equal("my policy") @mock_kms_deprecated def test_put_key_policy(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] - conn.put_key_policy(key_id, 'default', 'new policy') - policy = conn.get_key_policy(key_id, 'default') - policy['Policy'].should.equal('new policy') + conn.put_key_policy(key_id, "default", "new policy") + policy = conn.get_key_policy(key_id, "default") + policy["Policy"].should.equal("new policy") @mock_kms_deprecated def test_put_key_policy_via_arn(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['Arn'] + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["Arn"] - conn.put_key_policy(key_id, 'default', 'new policy') - policy = conn.get_key_policy(key_id, 'default') - policy['Policy'].should.equal('new policy') + conn.put_key_policy(key_id, "default", "new policy") + policy = conn.get_key_policy(key_id, "default") + policy["Policy"].should.equal("new policy") @mock_kms_deprecated def test_put_key_policy_via_alias_should_not_update(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - conn.create_alias(alias_name='alias/my-key-alias', - target_key_id=key['KeyMetadata']['KeyId']) + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) - conn.put_key_policy.when.called_with( - 'alias/my-key-alias', 'default', 'new policy').should.throw(NotFoundException) + conn.put_key_policy.when.called_with("alias/my-key-alias", "default", "new policy").should.throw(NotFoundException) - policy = conn.get_key_policy(key['KeyMetadata']['KeyId'], 'default') - policy['Policy'].should.equal('my policy') + policy = conn.get_key_policy(key["KeyMetadata"]["KeyId"], "default") + policy["Policy"].should.equal("my policy") @mock_kms_deprecated def test_put_key_policy(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - conn.put_key_policy(key['KeyMetadata']['Arn'], 'default', 'new policy') + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + conn.put_key_policy(key["KeyMetadata"]["Arn"], "default", "new policy") - policy = conn.get_key_policy(key['KeyMetadata']['KeyId'], 'default') - policy['Policy'].should.equal('new policy') + policy = conn.get_key_policy(key["KeyMetadata"]["KeyId"], "default") + policy["Policy"].should.equal("new policy") @mock_kms_deprecated def test_list_key_policies(): - conn = boto.kms.connect_to_region('us-west-2') + conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy='my policy', - description='my key1', key_usage='ENCRYPT_DECRYPT') - key_id = key['KeyMetadata']['KeyId'] + key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key_id = key["KeyMetadata"]["KeyId"] policies = conn.list_key_policies(key_id) - policies['PolicyNames'].should.equal(['default']) + policies["PolicyNames"].should.equal(["default"]) @mock_kms_deprecated def test__create_alias__returns_none_if_correct(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - resp = kms.create_alias('alias/my-alias', key_id) + resp = kms.create_alias("alias/my-alias", key_id) resp.should.be.none @@ -336,14 +395,9 @@ def test__create_alias__returns_none_if_correct(): def test__create_alias__raises_if_reserved_alias(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - reserved_aliases = [ - 'alias/aws/ebs', - 'alias/aws/s3', - 'alias/aws/redshift', - 'alias/aws/rds', - ] + reserved_aliases = ["alias/aws/ebs", "alias/aws/s3", "alias/aws/redshift", "alias/aws/rds"] for alias_name in reserved_aliases: with assert_raises(JSONResponseError) as err: @@ -351,9 +405,9 @@ def test__create_alias__raises_if_reserved_alias(): ex = err.exception ex.error_message.should.be.none - ex.error_code.should.equal('NotAuthorizedException') - ex.body.should.equal({'__type': 'NotAuthorizedException'}) - ex.reason.should.equal('Bad Request') + ex.error_code.should.equal("NotAuthorizedException") + ex.body.should.equal({"__type": "NotAuthorizedException"}) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -361,38 +415,37 @@ def test__create_alias__raises_if_reserved_alias(): def test__create_alias__can_create_multiple_aliases_for_same_key_id(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - kms.create_alias('alias/my-alias3', key_id).should.be.none - kms.create_alias('alias/my-alias4', key_id).should.be.none - kms.create_alias('alias/my-alias5', key_id).should.be.none + kms.create_alias("alias/my-alias3", key_id).should.be.none + kms.create_alias("alias/my-alias4", key_id).should.be.none + kms.create_alias("alias/my-alias5", key_id).should.be.none @mock_kms_deprecated def test__create_alias__raises_if_wrong_prefix(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] with assert_raises(JSONResponseError) as err: - kms.create_alias('wrongprefix/my-alias', key_id) + kms.create_alias("wrongprefix/my-alias", key_id) ex = err.exception - ex.error_message.should.equal('Invalid identifier') - ex.error_code.should.equal('ValidationException') - ex.body.should.equal({'message': 'Invalid identifier', - '__type': 'ValidationException'}) - ex.reason.should.equal('Bad Request') + ex.error_message.should.equal("Invalid identifier") + ex.error_code.should.equal("ValidationException") + ex.body.should.equal({"message": "Invalid identifier", "__type": "ValidationException"}) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @mock_kms_deprecated def test__create_alias__raises_if_duplicate(): - region = 'us-west-2' + region = "us-west-2" kms = boto.kms.connect_to_region(region) create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - alias = 'alias/my-alias' + key_id = create_resp["KeyMetadata"]["KeyId"] + alias = "alias/my-alias" kms.create_alias(alias, key_id) @@ -400,15 +453,17 @@ def test__create_alias__raises_if_duplicate(): kms.create_alias(alias, key_id) ex = err.exception - ex.error_message.should.match(r'An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists' - .format(**locals())) + ex.error_message.should.match( + r"An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists".format(**locals()) + ) ex.error_code.should.be.none ex.box_usage.should.be.none ex.request_id.should.be.none - ex.body['message'].should.match(r'An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists' - .format(**locals())) - ex.body['__type'].should.equal('AlreadyExistsException') - ex.reason.should.equal('Bad Request') + ex.body["message"].should.match( + r"An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists".format(**locals()) + ) + ex.body["__type"].should.equal("AlreadyExistsException") + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -416,25 +471,27 @@ def test__create_alias__raises_if_duplicate(): def test__create_alias__raises_if_alias_has_restricted_characters(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - alias_names_with_restricted_characters = [ - 'alias/my-alias!', - 'alias/my-alias$', - 'alias/my-alias@', - ] + alias_names_with_restricted_characters = ["alias/my-alias!", "alias/my-alias$", "alias/my-alias@"] for alias_name in alias_names_with_restricted_characters: with assert_raises(JSONResponseError) as err: kms.create_alias(alias_name, key_id) ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal( - "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format(**locals())) - ex.error_code.should.equal('ValidationException') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal( + "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format( + **locals() + ) + ) + ex.error_code.should.equal("ValidationException") ex.message.should.equal( - "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format(**locals())) - ex.reason.should.equal('Bad Request') + "1 validation error detected: Value '{alias_name}' at 'aliasName' failed to satisfy constraint: Member must satisfy regular expression pattern: ^[a-zA-Z0-9:/_-]+$".format( + **locals() + ) + ) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -444,47 +501,41 @@ def test__create_alias__raises_if_alias_has_colon_character(): # are accepted by regex ^[a-zA-Z0-9:/_-]+$ kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - alias_names_with_restricted_characters = [ - 'alias/my:alias', - ] + alias_names_with_restricted_characters = ["alias/my:alias"] for alias_name in alias_names_with_restricted_characters: with assert_raises(JSONResponseError) as err: kms.create_alias(alias_name, key_id) ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal( - "{alias_name} contains invalid characters for an alias".format(**locals())) - ex.error_code.should.equal('ValidationException') - ex.message.should.equal( - "{alias_name} contains invalid characters for an alias".format(**locals())) - ex.reason.should.equal('Bad Request') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal("{alias_name} contains invalid characters for an alias".format(**locals())) + ex.error_code.should.equal("ValidationException") + ex.message.should.equal("{alias_name} contains invalid characters for an alias".format(**locals())) + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) +@parameterized(( + ("alias/my-alias_/",), + ("alias/my_alias-/",), +)) @mock_kms_deprecated -def test__create_alias__accepted_characters(): +def test__create_alias__accepted_characters(alias_name): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] + key_id = create_resp["KeyMetadata"]["KeyId"] - alias_names_with_accepted_characters = [ - 'alias/my-alias_/', - 'alias/my_alias-/', - ] - - for alias_name in alias_names_with_accepted_characters: - kms.create_alias(alias_name, key_id) + kms.create_alias(alias_name, key_id) @mock_kms_deprecated def test__create_alias__raises_if_target_key_id_is_existing_alias(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - alias = 'alias/my-alias' + key_id = create_resp["KeyMetadata"]["KeyId"] + alias = "alias/my-alias" kms.create_alias(alias, key_id) @@ -492,11 +543,11 @@ def test__create_alias__raises_if_target_key_id_is_existing_alias(): kms.create_alias(alias, alias) ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal('Aliases must refer to keys. Not aliases') - ex.error_code.should.equal('ValidationException') - ex.message.should.equal('Aliases must refer to keys. Not aliases') - ex.reason.should.equal('Bad Request') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal("Aliases must refer to keys. Not aliases") + ex.error_code.should.equal("ValidationException") + ex.message.should.equal("Aliases must refer to keys. Not aliases") + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -504,14 +555,14 @@ def test__create_alias__raises_if_target_key_id_is_existing_alias(): def test__delete_alias(): kms = boto.connect_kms() create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - alias = 'alias/my-alias' + key_id = create_resp["KeyMetadata"]["KeyId"] + alias = "alias/my-alias" # added another alias here to make sure that the deletion of the alias can # be done when there are multiple existing aliases. another_create_resp = kms.create_key() - another_key_id = create_resp['KeyMetadata']['KeyId'] - another_alias = 'alias/another-alias' + another_key_id = create_resp["KeyMetadata"]["KeyId"] + another_alias = "alias/another-alias" kms.create_alias(alias, key_id) kms.create_alias(another_alias, another_key_id) @@ -529,35 +580,37 @@ def test__delete_alias__raises_if_wrong_prefix(): kms = boto.connect_kms() with assert_raises(JSONResponseError) as err: - kms.delete_alias('wrongprefix/my-alias') + kms.delete_alias("wrongprefix/my-alias") ex = err.exception - ex.body['__type'].should.equal('ValidationException') - ex.body['message'].should.equal('Invalid identifier') - ex.error_code.should.equal('ValidationException') - ex.message.should.equal('Invalid identifier') - ex.reason.should.equal('Bad Request') + ex.body["__type"].should.equal("ValidationException") + ex.body["message"].should.equal("Invalid identifier") + ex.error_code.should.equal("ValidationException") + ex.message.should.equal("Invalid identifier") + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @mock_kms_deprecated def test__delete_alias__raises_if_alias_is_not_found(): - region = 'us-west-2' + region = "us-west-2" kms = boto.kms.connect_to_region(region) - alias_name = 'alias/unexisting-alias' + alias_name = "alias/unexisting-alias" with assert_raises(NotFoundException) as err: kms.delete_alias(alias_name) + expected_message_match = r"Alias arn:aws:kms:{region}:[0-9]{{12}}:{alias_name} is not found.".format( + region=region, + alias_name=alias_name + ) ex = err.exception - ex.body['__type'].should.equal('NotFoundException') - ex.body['message'].should.match( - r'Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.'.format(**locals())) + ex.body["__type"].should.equal("NotFoundException") + ex.body["message"].should.match(expected_message_match) ex.box_usage.should.be.none ex.error_code.should.be.none - ex.message.should.match( - r'Alias arn:aws:kms:{region}:\d{{12}}:{alias_name} is not found.'.format(**locals())) - ex.reason.should.equal('Bad Request') + ex.message.should.match(expected_message_match) + ex.reason.should.equal("Bad Request") ex.request_id.should.be.none ex.status.should.equal(400) @@ -568,198 +621,176 @@ def test__list_aliases(): kms = boto.kms.connect_to_region(region) create_resp = kms.create_key() - key_id = create_resp['KeyMetadata']['KeyId'] - kms.create_alias('alias/my-alias1', key_id) - kms.create_alias('alias/my-alias2', key_id) - kms.create_alias('alias/my-alias3', key_id) + key_id = create_resp["KeyMetadata"]["KeyId"] + kms.create_alias("alias/my-alias1", key_id) + kms.create_alias("alias/my-alias2", key_id) + kms.create_alias("alias/my-alias3", key_id) resp = kms.list_aliases() - resp['Truncated'].should.be.false + resp["Truncated"].should.be.false - aliases = resp['Aliases'] + aliases = resp["Aliases"] def has_correct_arn(alias_obj): - alias_name = alias_obj['AliasName'] - alias_arn = alias_obj['AliasArn'] - return re.match(r'arn:aws:kms:{region}:\d{{12}}:{alias_name}'.format(region=region, alias_name=alias_name), - alias_arn) + alias_name = alias_obj["AliasName"] + alias_arn = alias_obj["AliasArn"] + return re.match( + r"arn:aws:kms:{region}:\d{{12}}:{alias_name}".format(region=region, alias_name=alias_name), alias_arn + ) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/ebs' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/rds' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/redshift' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/aws/s3' == alias['AliasName']]).should.equal(1) + len([alias for alias in aliases if has_correct_arn(alias) and "alias/aws/ebs" == alias["AliasName"]]).should.equal( + 1 + ) + len([alias for alias in aliases if has_correct_arn(alias) and "alias/aws/rds" == alias["AliasName"]]).should.equal( + 1 + ) + len( + [alias for alias in aliases if has_correct_arn(alias) and "alias/aws/redshift" == alias["AliasName"]] + ).should.equal(1) + len([alias for alias in aliases if has_correct_arn(alias) and "alias/aws/s3" == alias["AliasName"]]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/my-alias1' == alias['AliasName']]).should.equal(1) - len([alias for alias in aliases if - has_correct_arn(alias) and 'alias/my-alias2' == alias['AliasName']]).should.equal(1) + len( + [alias for alias in aliases if has_correct_arn(alias) and "alias/my-alias1" == alias["AliasName"]] + ).should.equal(1) + len( + [alias for alias in aliases if has_correct_arn(alias) and "alias/my-alias2" == alias["AliasName"]] + ).should.equal(1) - len([alias for alias in aliases if 'TargetKeyId' in alias and key_id == - alias['TargetKeyId']]).should.equal(3) + len([alias for alias in aliases if "TargetKeyId" in alias and key_id == alias["TargetKeyId"]]).should.equal(3) len(aliases).should.equal(7) -@mock_kms_deprecated -def test__assert_valid_key_id(): - from moto.kms.responses import _assert_valid_key_id - import uuid +@parameterized(( + ("not-a-uuid",), + ("alias/DoesNotExist",), + ("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",), + ("d25652e4-d2d2-49f7-929a-671ccda580c6",), + ("arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",), +)) +@mock_kms +def test_invalid_key_ids(key_id): + client = boto3.client("kms", region_name="us-east-1") - _assert_valid_key_id.when.called_with( - "not-a-key").should.throw(MotoNotFoundException) - _assert_valid_key_id.when.called_with( - str(uuid.uuid4())).should_not.throw(MotoNotFoundException) + with assert_raises(client.exceptions.NotFoundException): + client.generate_data_key(KeyId=key_id, NumberOfBytes=5) @mock_kms_deprecated def test__assert_default_policy(): from moto.kms.responses import _assert_default_policy - _assert_default_policy.when.called_with( - "not-default").should.throw(MotoNotFoundException) - _assert_default_policy.when.called_with( - "default").should_not.throw(MotoNotFoundException) + _assert_default_policy.when.called_with("not-default").should.throw(MotoNotFoundException) + _assert_default_policy.when.called_with("default").should_not.throw(MotoNotFoundException) +@parameterized(PLAINTEXT_VECTORS) @mock_kms -def test_kms_encrypt_boto3(): - client = boto3.client('kms', region_name='us-east-1') - response = client.encrypt(KeyId='foo', Plaintext=b'bar') +def test_kms_encrypt_boto3(plaintext): + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="key") + response = client.encrypt(KeyId=key["KeyMetadata"]["KeyId"], Plaintext=plaintext) - response = client.decrypt(CiphertextBlob=response['CiphertextBlob']) - response['Plaintext'].should.equal(b'bar') + response = client.decrypt(CiphertextBlob=response["CiphertextBlob"]) + response["Plaintext"].should.equal(_get_encoded_value(plaintext)) @mock_kms def test_disable_key(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='disable-key') - client.disable_key( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="disable-key") + client.disable_key(KeyId=key["KeyMetadata"]["KeyId"]) - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'Disabled' + assert result["KeyMetadata"]["KeyState"] == "Disabled" @mock_kms def test_enable_key(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='enable-key') - client.disable_key( - KeyId=key['KeyMetadata']['KeyId'] - ) - client.enable_key( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="enable-key") + client.disable_key(KeyId=key["KeyMetadata"]["KeyId"]) + client.enable_key(KeyId=key["KeyMetadata"]["KeyId"]) - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == True - assert result["KeyMetadata"]["KeyState"] == 'Enabled' + assert result["KeyMetadata"]["KeyState"] == "Enabled" @mock_kms def test_schedule_key_deletion(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='schedule-key-deletion') - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'false': + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="schedule-key-deletion") + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false": with freeze_time("2015-01-01 12:00:00"): - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] - assert response['DeletionDate'] == datetime(2015, 1, 31, 12, 0, tzinfo=tzutc()) + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] + assert response["DeletionDate"] == datetime(2015, 1, 31, 12, 0, tzinfo=tzutc()) else: # Can't manipulate time in server mode - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'PendingDeletion' - assert 'DeletionDate' in result["KeyMetadata"] + assert result["KeyMetadata"]["KeyState"] == "PendingDeletion" + assert "DeletionDate" in result["KeyMetadata"] @mock_kms def test_schedule_key_deletion_custom(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='schedule-key-deletion') - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'false': + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="schedule-key-deletion") + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false": with freeze_time("2015-01-01 12:00:00"): - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'], - PendingWindowInDays=7 - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] - assert response['DeletionDate'] == datetime(2015, 1, 8, 12, 0, tzinfo=tzutc()) + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] + assert response["DeletionDate"] == datetime(2015, 1, 8, 12, 0, tzinfo=tzutc()) else: # Can't manipulate time in server mode - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'], - PendingWindowInDays=7 - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'PendingDeletion' - assert 'DeletionDate' in result["KeyMetadata"] + assert result["KeyMetadata"]["KeyState"] == "PendingDeletion" + assert "DeletionDate" in result["KeyMetadata"] @mock_kms def test_cancel_key_deletion(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='cancel-key-deletion') - client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) - response = client.cancel_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) - assert response['KeyId'] == key['KeyMetadata']['KeyId'] + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="cancel-key-deletion") + client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + response = client.cancel_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) + assert response["KeyId"] == key["KeyMetadata"]["KeyId"] - result = client.describe_key(KeyId=key['KeyMetadata']['KeyId']) + result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) assert result["KeyMetadata"]["Enabled"] == False - assert result["KeyMetadata"]["KeyState"] == 'Disabled' - assert 'DeletionDate' not in result["KeyMetadata"] + assert result["KeyMetadata"]["KeyState"] == "Disabled" + assert "DeletionDate" not in result["KeyMetadata"] @mock_kms def test_update_key_description(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='old_description') - key_id = key['KeyMetadata']['KeyId'] + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="old_description") + key_id = key["KeyMetadata"]["KeyId"] - result = client.update_key_description(KeyId=key_id, Description='new_description') - assert 'ResponseMetadata' in result + result = client.update_key_description(KeyId=key_id, Description="new_description") + assert "ResponseMetadata" in result @mock_kms def test_tag_resource(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='cancel-key-deletion') - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="cancel-key-deletion") + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) - keyid = response['KeyId'] - response = client.tag_resource( - KeyId=keyid, - Tags=[ - { - 'TagKey': 'string', - 'TagValue': 'string' - }, - ] - ) + keyid = response["KeyId"] + response = client.tag_resource(KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}]) # Shouldn't have any data, just header assert len(response.keys()) == 1 @@ -767,226 +798,279 @@ def test_tag_resource(): @mock_kms def test_list_resource_tags(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='cancel-key-deletion') - response = client.schedule_key_deletion( - KeyId=key['KeyMetadata']['KeyId'] - ) + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="cancel-key-deletion") + response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) - keyid = response['KeyId'] - response = client.tag_resource( - KeyId=keyid, - Tags=[ - { - 'TagKey': 'string', - 'TagValue': 'string' - }, - ] - ) + keyid = response["KeyId"] + response = client.tag_resource(KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}]) response = client.list_resource_tags(KeyId=keyid) - assert response['Tags'][0]['TagKey'] == 'string' - assert response['Tags'][0]['TagValue'] == 'string' + assert response["Tags"][0]["TagKey"] == "string" + assert response["Tags"][0]["TagValue"] == "string" +@parameterized(( + (dict(KeySpec="AES_256"), 32), + (dict(KeySpec="AES_128"), 16), + (dict(NumberOfBytes=64), 64), + (dict(NumberOfBytes=1), 1), + (dict(NumberOfBytes=1024), 1024), +)) @mock_kms -def test_generate_data_key_sizes(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-size') +def test_generate_data_key_sizes(kwargs, expected_key_length): + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-size") - resp1 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_256' - ) - resp2 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_128' - ) - resp3 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - NumberOfBytes=64 - ) + response = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs) - assert len(resp1['Plaintext']) == 32 - assert len(resp2['Plaintext']) == 16 - assert len(resp3['Plaintext']) == 64 + assert len(response["Plaintext"]) == expected_key_length @mock_kms def test_generate_data_key_decrypt(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-decrypt') + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-decrypt") - resp1 = client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_256' - ) - resp2 = client.decrypt( - CiphertextBlob=resp1['CiphertextBlob'] - ) + resp1 = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256") + resp2 = client.decrypt(CiphertextBlob=resp1["CiphertextBlob"]) - assert resp1['Plaintext'] == resp2['Plaintext'] + assert resp1["Plaintext"] == resp2["Plaintext"] +@parameterized(( + (dict(KeySpec="AES_257"),), + (dict(KeySpec="AES_128", NumberOfBytes=16),), + (dict(NumberOfBytes=2048),), + (dict(NumberOfBytes=0),), + (dict(),), +)) @mock_kms -def test_generate_data_key_invalid_size_params(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-size') +def test_generate_data_key_invalid_size_params(kwargs): + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-size") - with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_257' - ) - - with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_128', - NumberOfBytes=16 - ) - - with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'], - NumberOfBytes=2048 - ) - - with assert_raises(botocore.exceptions.ClientError) as err: - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'] - ) + with assert_raises((botocore.exceptions.ClientError, botocore.exceptions.ParamValidationError)) as err: + client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs) +@parameterized(( + ("alias/DoesNotExist",), + ("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",), + ("d25652e4-d2d2-49f7-929a-671ccda580c6",), + ("arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",), +)) @mock_kms -def test_generate_data_key_invalid_key(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-size') +def test_generate_data_key_invalid_key(key_id): + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.generate_data_key( - KeyId='alias/randomnonexistantkey', - KeySpec='AES_256' - ) + client.generate_data_key(KeyId=key_id, KeySpec="AES_256") - with assert_raises(client.exceptions.NotFoundException): - client.generate_data_key( - KeyId=key['KeyMetadata']['KeyId'] + '4', - KeySpec='AES_256' - ) + +@parameterized(( + ("alias/DoesExist", False), + ("arn:aws:kms:us-east-1:012345678912:alias/DoesExist", False), + ("", True), + ("arn:aws:kms:us-east-1:012345678912:key/", True), +)) +@mock_kms +def test_generate_data_key_all_valid_key_ids(prefix, append_key_id): + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key() + key_id = key["KeyMetadata"]["KeyId"] + client.create_alias(AliasName="alias/DoesExist", TargetKeyId=key_id) + + target_id = prefix + if append_key_id: + target_id += key_id + + client.generate_data_key(KeyId=key_id, NumberOfBytes=32) @mock_kms def test_generate_data_key_without_plaintext_decrypt(): - client = boto3.client('kms', region_name='us-east-1') - key = client.create_key(Description='generate-data-key-decrypt') + client = boto3.client("kms", region_name="us-east-1") + key = client.create_key(Description="generate-data-key-decrypt") - resp1 = client.generate_data_key_without_plaintext( - KeyId=key['KeyMetadata']['KeyId'], - KeySpec='AES_256' + resp1 = client.generate_data_key_without_plaintext(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256") + + assert "Plaintext" not in resp1 + + +@parameterized(PLAINTEXT_VECTORS) +@mock_kms +def test_re_encrypt_decrypt(plaintext): + client = boto3.client("kms", region_name="us-west-2") + + key_1 = client.create_key(Description="key 1") + key_1_id = key_1["KeyMetadata"]["KeyId"] + key_1_arn = key_1["KeyMetadata"]["Arn"] + key_2 = client.create_key(Description="key 2") + key_2_id = key_2["KeyMetadata"]["KeyId"] + key_2_arn = key_2["KeyMetadata"]["Arn"] + + encrypt_response = client.encrypt( + KeyId=key_1_id, + Plaintext=plaintext, + EncryptionContext={"encryption": "context"}, ) - assert 'Plaintext' not in resp1 + re_encrypt_response = client.re_encrypt( + CiphertextBlob=encrypt_response["CiphertextBlob"], + SourceEncryptionContext={"encryption": "context"}, + DestinationKeyId=key_2_id, + DestinationEncryptionContext={"another": "context"}, + ) + + # CiphertextBlob must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(re_encrypt_response["CiphertextBlob"], validate=True) + + re_encrypt_response["SourceKeyId"].should.equal(key_1_arn) + re_encrypt_response["KeyId"].should.equal(key_2_arn) + + decrypt_response_1 = client.decrypt( + CiphertextBlob=encrypt_response["CiphertextBlob"], + EncryptionContext={"encryption": "context"}, + ) + decrypt_response_1["Plaintext"].should.equal(_get_encoded_value(plaintext)) + decrypt_response_1["KeyId"].should.equal(key_1_arn) + + decrypt_response_2 = client.decrypt( + CiphertextBlob=re_encrypt_response["CiphertextBlob"], + EncryptionContext={"another": "context"}, + ) + decrypt_response_2["Plaintext"].should.equal(_get_encoded_value(plaintext)) + decrypt_response_2["KeyId"].should.equal(key_2_arn) + + decrypt_response_1["Plaintext"].should.equal(decrypt_response_2["Plaintext"]) + + +@mock_kms +def test_re_encrypt_to_invalid_destination(): + client = boto3.client("kms", region_name="us-west-2") + + key = client.create_key(Description="key 1") + key_id = key["KeyMetadata"]["KeyId"] + + encrypt_response = client.encrypt( + KeyId=key_id, + Plaintext=b"some plaintext", + ) + + with assert_raises(client.exceptions.NotFoundException): + client.re_encrypt( + CiphertextBlob=encrypt_response["CiphertextBlob"], + DestinationKeyId="alias/DoesNotExist", + ) + + +@parameterized(((12,), (44,), (91,), (1,), (1024,))) +@mock_kms +def test_generate_random(number_of_bytes): + client = boto3.client("kms", region_name="us-west-2") + + response = client.generate_random(NumberOfBytes=number_of_bytes) + + # Plaintext must NOT be base64-encoded + with assert_raises(Exception): + base64.b64decode(response["Plaintext"], validate=True) + + response["Plaintext"].should.be.a(bytes) + len(response["Plaintext"]).should.equal(number_of_bytes) + + +@parameterized(( + (2048, botocore.exceptions.ClientError), + (1025, botocore.exceptions.ClientError), + (0, botocore.exceptions.ParamValidationError), + (-1, botocore.exceptions.ParamValidationError), + (-1024, botocore.exceptions.ParamValidationError) +)) +@mock_kms +def test_generate_random_invalid_number_of_bytes(number_of_bytes, error_type): + client = boto3.client("kms", region_name="us-west-2") + + with assert_raises(error_type): + client.generate_random(NumberOfBytes=number_of_bytes) @mock_kms def test_enable_key_rotation_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.enable_key_rotation( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.enable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_disable_key_rotation_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.disable_key_rotation( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.disable_key_rotation(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_enable_key_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.enable_key( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.enable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_disable_key_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.disable_key( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.disable_key(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_cancel_key_deletion_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.cancel_key_deletion( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.cancel_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_schedule_key_deletion_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.schedule_key_deletion( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.schedule_key_deletion(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_get_key_rotation_status_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.get_key_rotation_status( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.get_key_rotation_status(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_get_key_policy_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.get_key_policy( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02', - PolicyName='default' - ) + client.get_key_policy(KeyId="12366f9b-1230-123d-123e-123e6ae60c02", PolicyName="default") @mock_kms def test_list_key_policies_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.list_key_policies( - KeyId='12366f9b-1230-123d-123e-123e6ae60c02' - ) + client.list_key_policies(KeyId="12366f9b-1230-123d-123e-123e6ae60c02") @mock_kms def test_put_key_policy_key_not_found(): - client = boto3.client('kms', region_name='us-east-1') + client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.put_key_policy( - KeyId='00000000-0000-0000-0000-000000000000', - PolicyName='default', - Policy='new policy' - ) + client.put_key_policy(KeyId="00000000-0000-0000-0000-000000000000", PolicyName="default", Policy="new policy") diff --git a/tests/test_kms/test_utils.py b/tests/test_kms/test_utils.py new file mode 100644 index 000000000..73d7d3580 --- /dev/null +++ b/tests/test_kms/test_utils.py @@ -0,0 +1,157 @@ +from __future__ import unicode_literals + +import sure # noqa +from nose.tools import assert_raises +from parameterized import parameterized + +from moto.kms.exceptions import AccessDeniedException, InvalidCiphertextException, NotFoundException +from moto.kms.models import Key +from moto.kms.utils import ( + _deserialize_ciphertext_blob, + _serialize_ciphertext_blob, + _serialize_encryption_context, + generate_data_key, + generate_master_key, + MASTER_KEY_LEN, + encrypt, + decrypt, + Ciphertext, +) + +ENCRYPTION_CONTEXT_VECTORS = ( + ({"this": "is", "an": "encryption", "context": "example"}, b"an" b"encryption" b"context" b"example" b"this" b"is"), + ({"a_this": "one", "b_is": "actually", "c_in": "order"}, b"a_this" b"one" b"b_is" b"actually" b"c_in" b"order"), +) +CIPHERTEXT_BLOB_VECTORS = ( + ( + Ciphertext( + key_id="d25652e4-d2d2-49f7-929a-671ccda580c6", + iv=b"123456789012", + ciphertext=b"some ciphertext", + tag=b"1234567890123456", + ), + b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext", + ), + ( + Ciphertext( + key_id="d25652e4-d2d2-49f7-929a-671ccda580c6", + iv=b"123456789012", + ciphertext=b"some ciphertext that is much longer now", + tag=b"1234567890123456", + ), + b"d25652e4-d2d2-49f7-929a-671ccda580c6" + b"123456789012" + b"1234567890123456" + b"some ciphertext that is much longer now", + ), +) + + +def test_generate_data_key(): + test = generate_data_key(123) + + test.should.be.a(bytes) + len(test).should.equal(123) + + +def test_generate_master_key(): + test = generate_master_key() + + test.should.be.a(bytes) + len(test).should.equal(MASTER_KEY_LEN) + + +@parameterized(ENCRYPTION_CONTEXT_VECTORS) +def test_serialize_encryption_context(raw, serialized): + test = _serialize_encryption_context(raw) + test.should.equal(serialized) + + +@parameterized(CIPHERTEXT_BLOB_VECTORS) +def test_cycle_ciphertext_blob(raw, _serialized): + test_serialized = _serialize_ciphertext_blob(raw) + test_deserialized = _deserialize_ciphertext_blob(test_serialized) + test_deserialized.should.equal(raw) + + +@parameterized(CIPHERTEXT_BLOB_VECTORS) +def test_serialize_ciphertext_blob(raw, serialized): + test = _serialize_ciphertext_blob(raw) + test.should.equal(serialized) + + +@parameterized(CIPHERTEXT_BLOB_VECTORS) +def test_deserialize_ciphertext_blob(raw, serialized): + test = _deserialize_ciphertext_blob(serialized) + test.should.equal(raw) + + +@parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS)) +def test_encrypt_decrypt_cycle(encryption_context): + plaintext = b"some secret plaintext" + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + + ciphertext_blob = encrypt( + master_keys=master_key_map, key_id=master_key.id, plaintext=plaintext, encryption_context=encryption_context + ) + ciphertext_blob.should_not.equal(plaintext) + + decrypted, decrypting_key_id = decrypt( + master_keys=master_key_map, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context + ) + decrypted.should.equal(plaintext) + decrypting_key_id.should.equal(master_key.id) + + +def test_encrypt_unknown_key_id(): + with assert_raises(NotFoundException): + encrypt(master_keys={}, key_id="anything", plaintext=b"secrets", encryption_context={}) + + +def test_decrypt_invalid_ciphertext_format(): + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + + with assert_raises(InvalidCiphertextException): + decrypt(master_keys=master_key_map, ciphertext_blob=b"", encryption_context={}) + + +def test_decrypt_unknwown_key_id(): + ciphertext_blob = b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext" + + with assert_raises(AccessDeniedException): + decrypt(master_keys={}, ciphertext_blob=ciphertext_blob, encryption_context={}) + + +def test_decrypt_invalid_ciphertext(): + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + ciphertext_blob = master_key.id.encode("utf-8") + b"123456789012" b"1234567890123456" b"some ciphertext" + + with assert_raises(InvalidCiphertextException): + decrypt( + master_keys=master_key_map, + ciphertext_blob=ciphertext_blob, + encryption_context={}, + ) + + +def test_decrypt_invalid_encryption_context(): + plaintext = b"some secret plaintext" + master_key = Key("nop", "nop", "nop", [], "nop") + master_key_map = {master_key.id: master_key} + + ciphertext_blob = encrypt( + master_keys=master_key_map, + key_id=master_key.id, + plaintext=plaintext, + encryption_context={"some": "encryption", "context": "here"}, + ) + + with assert_raises(InvalidCiphertextException): + decrypt( + master_keys=master_key_map, + ciphertext_blob=ciphertext_blob, + encryption_context={}, + ) diff --git a/tests/test_logs/test_logs.py b/tests/test_logs/test_logs.py index 49e593fdc..0a63308c2 100644 --- a/tests/test_logs/test_logs.py +++ b/tests/test_logs/test_logs.py @@ -190,6 +190,8 @@ def test_get_log_events(): resp['events'].should.have.length_of(10) resp.should.have.key('nextForwardToken') resp.should.have.key('nextBackwardToken') + resp['nextForwardToken'].should.equal('f/00000000000000000000000000000000000000000000000000000010') + resp['nextBackwardToken'].should.equal('b/00000000000000000000000000000000000000000000000000000000') for i in range(10): resp['events'][i]['timestamp'].should.equal(i) resp['events'][i]['message'].should.equal(str(i)) @@ -205,7 +207,8 @@ def test_get_log_events(): resp['events'].should.have.length_of(10) resp.should.have.key('nextForwardToken') resp.should.have.key('nextBackwardToken') - resp['nextForwardToken'].should.equal(next_token) + resp['nextForwardToken'].should.equal('f/00000000000000000000000000000000000000000000000000000020') + resp['nextBackwardToken'].should.equal('b/00000000000000000000000000000000000000000000000000000000') for i in range(10): resp['events'][i]['timestamp'].should.equal(i+10) resp['events'][i]['message'].should.equal(str(i+10)) diff --git a/tests/test_redshift/test_redshift.py b/tests/test_redshift/test_redshift.py index 2c9b42a1d..79e283e5b 100644 --- a/tests/test_redshift/test_redshift.py +++ b/tests/test_redshift/test_redshift.py @@ -37,6 +37,25 @@ def test_create_cluster_boto3(): create_time = response['Cluster']['ClusterCreateTime'] create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo)) create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1)) + response['Cluster']['EnhancedVpcRouting'].should.equal(False) + +@mock_redshift +def test_create_cluster_boto3(): + client = boto3.client('redshift', region_name='us-east-1') + response = client.create_cluster( + DBName='test', + ClusterIdentifier='test', + ClusterType='single-node', + NodeType='ds2.xlarge', + MasterUsername='user', + MasterUserPassword='password', + EnhancedVpcRouting=True + ) + response['Cluster']['NodeType'].should.equal('ds2.xlarge') + create_time = response['Cluster']['ClusterCreateTime'] + create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo)) + create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1)) + response['Cluster']['EnhancedVpcRouting'].should.equal(True) @mock_redshift @@ -425,6 +444,58 @@ def test_delete_cluster(): "not-a-cluster").should.throw(ClusterNotFound) +@mock_redshift +def test_modify_cluster_vpc_routing(): + iam_roles_arn = ['arn:aws:iam:::role/my-iam-role', ] + client = boto3.client('redshift', region_name='us-east-1') + cluster_identifier = 'my_cluster' + + client.create_cluster( + ClusterIdentifier=cluster_identifier, + NodeType="single-node", + MasterUsername="username", + MasterUserPassword="password", + IamRoles=iam_roles_arn + ) + + cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier) + cluster = cluster_response['Clusters'][0] + cluster['EnhancedVpcRouting'].should.equal(False) + + client.create_cluster_security_group(ClusterSecurityGroupName='security_group', + Description='security_group') + + client.create_cluster_parameter_group(ParameterGroupName='my_parameter_group', + ParameterGroupFamily='redshift-1.0', + Description='my_parameter_group') + + client.modify_cluster( + ClusterIdentifier=cluster_identifier, + ClusterType='multi-node', + NodeType="ds2.8xlarge", + NumberOfNodes=3, + ClusterSecurityGroups=["security_group"], + MasterUserPassword="new_password", + ClusterParameterGroupName="my_parameter_group", + AutomatedSnapshotRetentionPeriod=7, + PreferredMaintenanceWindow="Tue:03:00-Tue:11:00", + AllowVersionUpgrade=False, + NewClusterIdentifier=cluster_identifier, + EnhancedVpcRouting=True + ) + + cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier) + cluster = cluster_response['Clusters'][0] + cluster['ClusterIdentifier'].should.equal(cluster_identifier) + cluster['NodeType'].should.equal("ds2.8xlarge") + cluster['PreferredMaintenanceWindow'].should.equal("Tue:03:00-Tue:11:00") + cluster['AutomatedSnapshotRetentionPeriod'].should.equal(7) + cluster['AllowVersionUpgrade'].should.equal(False) + # This one should remain unmodified. + cluster['NumberOfNodes'].should.equal(3) + cluster['EnhancedVpcRouting'].should.equal(True) + + @mock_redshift_deprecated def test_modify_cluster(): conn = boto.connect_redshift() @@ -446,6 +517,10 @@ def test_modify_cluster(): master_user_password="password", ) + cluster_response = conn.describe_clusters(cluster_identifier) + cluster = cluster_response['DescribeClustersResponse']['DescribeClustersResult']['Clusters'][0] + cluster['EnhancedVpcRouting'].should.equal(False) + conn.modify_cluster( cluster_identifier, cluster_type="multi-node", @@ -456,14 +531,13 @@ def test_modify_cluster(): automated_snapshot_retention_period=7, preferred_maintenance_window="Tue:03:00-Tue:11:00", allow_version_upgrade=False, - new_cluster_identifier="new_identifier", + new_cluster_identifier=cluster_identifier, ) - cluster_response = conn.describe_clusters("new_identifier") + cluster_response = conn.describe_clusters(cluster_identifier) cluster = cluster_response['DescribeClustersResponse'][ 'DescribeClustersResult']['Clusters'][0] - - cluster['ClusterIdentifier'].should.equal("new_identifier") + cluster['ClusterIdentifier'].should.equal(cluster_identifier) cluster['NodeType'].should.equal("dw.hs1.xlarge") cluster['ClusterSecurityGroups'][0][ 'ClusterSecurityGroupName'].should.equal("security_group") @@ -674,6 +748,7 @@ def test_create_cluster_snapshot(): NodeType='ds2.xlarge', MasterUsername='username', MasterUserPassword='password', + EnhancedVpcRouting=True ) cluster_response['Cluster']['NodeType'].should.equal('ds2.xlarge') @@ -823,11 +898,14 @@ def test_create_cluster_from_snapshot(): NodeType='ds2.xlarge', MasterUsername='username', MasterUserPassword='password', + EnhancedVpcRouting=True, ) + client.create_cluster_snapshot( SnapshotIdentifier=original_snapshot_identifier, ClusterIdentifier=original_cluster_identifier ) + response = client.restore_from_cluster_snapshot( ClusterIdentifier=new_cluster_identifier, SnapshotIdentifier=original_snapshot_identifier, @@ -842,7 +920,7 @@ def test_create_cluster_from_snapshot(): new_cluster['NodeType'].should.equal('ds2.xlarge') new_cluster['MasterUsername'].should.equal('username') new_cluster['Endpoint']['Port'].should.equal(1234) - + new_cluster['EnhancedVpcRouting'].should.equal(True) @mock_redshift def test_create_cluster_from_snapshot_with_waiter(): @@ -857,6 +935,7 @@ def test_create_cluster_from_snapshot_with_waiter(): NodeType='ds2.xlarge', MasterUsername='username', MasterUserPassword='password', + EnhancedVpcRouting=True ) client.create_cluster_snapshot( SnapshotIdentifier=original_snapshot_identifier, @@ -883,6 +962,7 @@ def test_create_cluster_from_snapshot_with_waiter(): new_cluster = response['Clusters'][0] new_cluster['NodeType'].should.equal('ds2.xlarge') new_cluster['MasterUsername'].should.equal('username') + new_cluster['EnhancedVpcRouting'].should.equal(True) new_cluster['Endpoint']['Port'].should.equal(1234) diff --git a/tests/test_route53/test_route53.py b/tests/test_route53/test_route53.py index de9465d6d..babd54d26 100644 --- a/tests/test_route53/test_route53.py +++ b/tests/test_route53/test_route53.py @@ -404,6 +404,11 @@ def test_list_or_change_tags_for_resource_request(): ) healthcheck_id = health_check['HealthCheck']['Id'] + # confirm this works for resources with zero tags + response = conn.list_tags_for_resource( + ResourceType="healthcheck", ResourceId=healthcheck_id) + response["ResourceTagSet"]["Tags"].should.be.empty + tag1 = {"Key": "Deploy", "Value": "True"} tag2 = {"Key": "Name", "Value": "UnitTest"} diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index cd57fc92b..336639a8c 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals import datetime +import os from six.moves.urllib.request import urlopen from six.moves.urllib.error import HTTPError from functools import wraps @@ -20,9 +21,11 @@ from botocore.handlers import disable_signing from boto.s3.connection import S3Connection from boto.s3.key import Key from freezegun import freeze_time +from parameterized import parameterized import six import requests import tests.backport_assert_raises # noqa +from nose import SkipTest from nose.tools import assert_raises import sure # noqa @@ -1390,6 +1393,34 @@ def test_boto3_list_objects_v2_fetch_owner(): assert len(owner.keys()) == 2 +@mock_s3 +def test_boto3_list_objects_v2_truncate_combined_keys_and_folders(): + s3 = boto3.client('s3', region_name='us-east-1') + s3.create_bucket(Bucket='mybucket') + s3.put_object(Bucket='mybucket', Key='1/2', Body='') + s3.put_object(Bucket='mybucket', Key='2', Body='') + s3.put_object(Bucket='mybucket', Key='3/4', Body='') + s3.put_object(Bucket='mybucket', Key='4', Body='') + + resp = s3.list_objects_v2(Bucket='mybucket', Prefix='', MaxKeys=2, Delimiter='/') + assert 'Delimiter' in resp + assert resp['IsTruncated'] is True + assert resp['KeyCount'] == 2 + assert len(resp['Contents']) == 1 + assert resp['Contents'][0]['Key'] == '2' + assert len(resp['CommonPrefixes']) == 1 + assert resp['CommonPrefixes'][0]['Prefix'] == '1/' + + last_tail = resp['NextContinuationToken'] + resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=2, Prefix='', Delimiter='/', StartAfter=last_tail) + assert resp['KeyCount'] == 2 + assert resp['IsTruncated'] is False + assert len(resp['Contents']) == 1 + assert resp['Contents'][0]['Key'] == '4' + assert len(resp['CommonPrefixes']) == 1 + assert resp['CommonPrefixes'][0]['Prefix'] == '3/' + + @mock_s3 def test_boto3_bucket_create(): s3 = boto3.resource('s3', region_name='us-east-1') @@ -2991,3 +3022,64 @@ def test_accelerate_configuration_is_not_supported_when_bucket_name_has_dots(): AccelerateConfiguration={'Status': 'Enabled'}, ) exc.exception.response['Error']['Code'].should.equal('InvalidRequest') + +def store_and_read_back_a_key(key): + s3 = boto3.client('s3', region_name='us-east-1') + bucket_name = 'mybucket' + body = b'Some body' + + s3.create_bucket(Bucket=bucket_name) + s3.put_object( + Bucket=bucket_name, + Key=key, + Body=body + ) + + response = s3.get_object(Bucket=bucket_name, Key=key) + response['Body'].read().should.equal(body) + +@mock_s3 +def test_paths_with_leading_slashes_work(): + store_and_read_back_a_key('/a-key') + +@mock_s3 +def test_root_dir_with_empty_name_works(): + if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': + raise SkipTest('Does not work in server mode due to error in Workzeug') + store_and_read_back_a_key('/') + + +@parameterized([ + ('foo/bar/baz',), + ('foo',), + ('foo/run_dt%3D2019-01-01%252012%253A30%253A00',), +]) +@mock_s3 +def test_delete_objects_with_url_encoded_key(key): + s3 = boto3.client('s3', region_name='us-east-1') + bucket_name = 'mybucket' + body = b'Some body' + + s3.create_bucket(Bucket=bucket_name) + + def put_object(): + s3.put_object( + Bucket=bucket_name, + Key=key, + Body=body + ) + + def assert_deleted(): + with assert_raises(ClientError) as e: + s3.get_object(Bucket=bucket_name, Key=key) + + e.exception.response['Error']['Code'].should.equal('NoSuchKey') + + put_object() + s3.delete_object(Bucket=bucket_name, Key=key) + assert_deleted() + + put_object() + s3.delete_objects(Bucket=bucket_name, Delete={'Objects': [{'Key': key}]}) + assert_deleted() + diff --git a/tests/test_s3/test_s3_utils.py b/tests/test_s3/test_s3_utils.py index ce9f54c75..93a4683e6 100644 --- a/tests/test_s3/test_s3_utils.py +++ b/tests/test_s3/test_s3_utils.py @@ -1,7 +1,8 @@ from __future__ import unicode_literals import os from sure import expect -from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url +from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url, clean_key_name, undo_clean_key_name +from parameterized import parameterized def test_base_url(): @@ -78,3 +79,29 @@ def test_parse_region_from_url(): 'https://s3.amazonaws.com/bucket', 'https://bucket.s3.amazonaws.com']: parse_region_from_url(url).should.equal(expected) + + +@parameterized([ + ('foo/bar/baz', + 'foo/bar/baz'), + ('foo', + 'foo'), + ('foo/run_dt%3D2019-01-01%252012%253A30%253A00', + 'foo/run_dt=2019-01-01%2012%3A30%3A00'), +]) +def test_clean_key_name(key, expected): + clean_key_name(key).should.equal(expected) + + +@parameterized([ + ('foo/bar/baz', + 'foo/bar/baz'), + ('foo', + 'foo'), + ('foo/run_dt%3D2019-01-01%252012%253A30%253A00', + 'foo/run_dt%253D2019-01-01%25252012%25253A30%25253A00'), +]) +def test_undo_clean_key_name(key, expected): + undo_clean_key_name(key).should.equal(expected) + + diff --git a/tests/test_secretsmanager/test_secretsmanager.py b/tests/test_secretsmanager/test_secretsmanager.py index 78b95ee6a..62de93bab 100644 --- a/tests/test_secretsmanager/test_secretsmanager.py +++ b/tests/test_secretsmanager/test_secretsmanager.py @@ -5,9 +5,9 @@ import boto3 from moto import mock_secretsmanager from botocore.exceptions import ClientError import string -import unittest import pytz from datetime import datetime +import sure # noqa from nose.tools import assert_raises from six import b @@ -23,6 +23,7 @@ def test_get_secret_value(): result = conn.get_secret_value(SecretId='java-util-test-password') assert result['SecretString'] == 'foosecret' + @mock_secretsmanager def test_get_secret_value_binary(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -32,6 +33,7 @@ def test_get_secret_value_binary(): result = conn.get_secret_value(SecretId='java-util-test-password') assert result['SecretBinary'] == b('foosecret') + @mock_secretsmanager def test_get_secret_that_does_not_exist(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -39,6 +41,7 @@ def test_get_secret_that_does_not_exist(): with assert_raises(ClientError): result = conn.get_secret_value(SecretId='i-dont-exist') + @mock_secretsmanager def test_get_secret_that_does_not_match(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -72,6 +75,7 @@ def test_create_secret(): secret = conn.get_secret_value(SecretId='test-secret') assert secret['SecretString'] == 'foosecret' + @mock_secretsmanager def test_create_secret_with_tags(): conn = boto3.client('secretsmanager', region_name='us-east-1') @@ -216,6 +220,7 @@ def test_get_random_exclude_lowercase(): ExcludeLowercase=True) assert any(c.islower() for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_exclude_uppercase(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -224,6 +229,7 @@ def test_get_random_exclude_uppercase(): ExcludeUppercase=True) assert any(c.isupper() for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_exclude_characters_and_symbols(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -232,6 +238,7 @@ def test_get_random_exclude_characters_and_symbols(): ExcludeCharacters='xyzDje@?!.') assert any(c in 'xyzDje@?!.' for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_exclude_numbers(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -240,6 +247,7 @@ def test_get_random_exclude_numbers(): ExcludeNumbers=True) assert any(c.isdigit() for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_exclude_punctuation(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -249,6 +257,7 @@ def test_get_random_exclude_punctuation(): assert any(c in string.punctuation for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_include_space_false(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -256,6 +265,7 @@ def test_get_random_include_space_false(): random_password = conn.get_random_password(PasswordLength=300) assert any(c.isspace() for c in random_password['RandomPassword']) == False + @mock_secretsmanager def test_get_random_include_space_true(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -264,6 +274,7 @@ def test_get_random_include_space_true(): IncludeSpace=True) assert any(c.isspace() for c in random_password['RandomPassword']) == True + @mock_secretsmanager def test_get_random_require_each_included_type(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -275,6 +286,7 @@ def test_get_random_require_each_included_type(): assert any(c in string.ascii_uppercase for c in random_password['RandomPassword']) == True assert any(c in string.digits for c in random_password['RandomPassword']) == True + @mock_secretsmanager def test_get_random_too_short_password(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -282,6 +294,7 @@ def test_get_random_too_short_password(): with assert_raises(ClientError): random_password = conn.get_random_password(PasswordLength=3) + @mock_secretsmanager def test_get_random_too_long_password(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -289,6 +302,7 @@ def test_get_random_too_long_password(): with assert_raises(Exception): random_password = conn.get_random_password(PasswordLength=5555) + @mock_secretsmanager def test_describe_secret(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -307,6 +321,7 @@ def test_describe_secret(): assert secret_description_2['Name'] == ('test-secret-2') assert secret_description_2['ARN'] != '' # Test arn not empty + @mock_secretsmanager def test_describe_secret_that_does_not_exist(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -314,6 +329,7 @@ def test_describe_secret_that_does_not_exist(): with assert_raises(ClientError): result = conn.get_secret_value(SecretId='i-dont-exist') + @mock_secretsmanager def test_describe_secret_that_does_not_match(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -500,6 +516,7 @@ def test_rotate_secret_rotation_period_zero(): # test_server actually handles this error. assert True + @mock_secretsmanager def test_rotate_secret_rotation_period_too_long(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -511,6 +528,7 @@ def test_rotate_secret_rotation_period_too_long(): result = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME, RotationRules=rotation_rules) + @mock_secretsmanager def test_put_secret_value_puts_new_secret(): conn = boto3.client('secretsmanager', region_name='us-west-2') @@ -526,6 +544,45 @@ def test_put_secret_value_puts_new_secret(): assert get_secret_value_dict assert get_secret_value_dict['SecretString'] == 'foosecret' + +@mock_secretsmanager +def test_put_secret_binary_value_puts_new_secret(): + conn = boto3.client('secretsmanager', region_name='us-west-2') + put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, + SecretBinary=b('foosecret'), + VersionStages=['AWSCURRENT']) + version_id = put_secret_value_dict['VersionId'] + + get_secret_value_dict = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME, + VersionId=version_id, + VersionStage='AWSCURRENT') + + assert get_secret_value_dict + assert get_secret_value_dict['SecretBinary'] == b('foosecret') + + +@mock_secretsmanager +def test_create_and_put_secret_binary_value_puts_new_secret(): + conn = boto3.client('secretsmanager', region_name='us-west-2') + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretBinary=b("foosecret")) + conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, SecretBinary=b('foosecret_update')) + + latest_secret = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME) + + assert latest_secret + assert latest_secret['SecretBinary'] == b('foosecret_update') + + +@mock_secretsmanager +def test_put_secret_binary_requires_either_string_or_binary(): + conn = boto3.client('secretsmanager', region_name='us-west-2') + with assert_raises(ClientError) as ire: + conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME) + + ire.exception.response['Error']['Code'].should.equal('InvalidRequestException') + ire.exception.response['Error']['Message'].should.equal('You must provide either SecretString or SecretBinary.') + + @mock_secretsmanager def test_put_secret_value_can_get_first_version_if_put_twice(): conn = boto3.client('secretsmanager', region_name='us-west-2') diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index 3d598d406..d7bf32e51 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -109,6 +109,17 @@ def test_publish_to_sqs_bad(): }}) except ClientError as err: err.response['Error']['Code'].should.equal('InvalidParameterValue') + try: + # Test Number DataType, with a non numeric value + conn.publish( + TopicArn=topic_arn, Message=message, + MessageAttributes={'price': { + 'DataType': 'Number', + 'StringValue': 'error' + }}) + except ClientError as err: + err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response['Error']['Message'].should.equal("An error occurred (ParameterValueInvalid) when calling the Publish operation: Could not cast message attribute 'price' value to number.") @mock_sqs @@ -487,3 +498,380 @@ def test_filtering_exact_string_no_attributes_no_match(): message_attributes = [ json.loads(m.body)['MessageAttributes'] for m in messages] message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_int(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'Number', + 'StringValue': '100'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'Number', 'Value': 100}}]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_float(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100.1]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'Number', + 'StringValue': '100.1'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'Number', 'Value': 100.1}}]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_float_accuracy(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100.123456789]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'Number', + 'StringValue': '100.1234561'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'Number', 'Value': 100.1234561}}]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100]}) + + topic.publish( + Message='no match', + MessageAttributes={'price': {'DataType': 'Number', + 'StringValue': '101'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_exact_number_with_string_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100]}) + + topic.publish( + Message='no match', + MessageAttributes={'price': {'DataType': 'String', + 'StringValue': '100'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'customer_interests': ['basketball', 'baseball']}) + + topic.publish( + Message='match', + MessageAttributes={'customer_interests': {'DataType': 'String.Array', + 'StringValue': json.dumps(['basketball', 'rugby'])}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'customer_interests': {'Type': 'String.Array', 'Value': json.dumps(['basketball', 'rugby'])}}]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'customer_interests': ['baseball']}) + + topic.publish( + Message='no_match', + MessageAttributes={'customer_interests': {'DataType': 'String.Array', + 'StringValue': json.dumps(['basketball', 'rugby'])}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100, 500]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'String.Array', + 'StringValue': json.dumps([100, 50])}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'String.Array', 'Value': json.dumps([100, 50])}}]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_float_accuracy_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100.123456789, 500]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'String.Array', + 'StringValue': json.dumps([100.1234561, 50])}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'String.Array', 'Value': json.dumps([100.1234561, 50])}}]) + + +@mock_sqs +@mock_sns +# this is the correct behavior from SNS +def test_filtering_string_array_with_number_no_array_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100, 500]}) + + topic.publish( + Message='match', + MessageAttributes={'price': {'DataType': 'String.Array', + 'StringValue': '100'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'price': {'Type': 'String.Array', 'Value': '100'}}]) + + +@mock_sqs +@mock_sns +def test_filtering_string_array_with_number_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [500]}) + + topic.publish( + Message='no_match', + MessageAttributes={'price': {'DataType': 'String.Array', + 'StringValue': json.dumps([100, 50])}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +# this is the correct behavior from SNS +def test_filtering_string_array_with_string_no_array_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'price': [100]}) + + topic.publish( + Message='no_match', + MessageAttributes={'price': {'DataType': 'String.Array', + 'StringValue': 'one hundread'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_exists_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': True}]}) + + topic.publish( + Message='match', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'store': {'Type': 'String', 'Value': 'example_corp'}}]) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_exists_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': True}]}) + + topic.publish( + Message='no match', + MessageAttributes={'event': {'DataType': 'String', + 'StringValue': 'order_cancelled'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_not_exists_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': False}]}) + + topic.publish( + Message='match', + MessageAttributes={'event': {'DataType': 'String', + 'StringValue': 'order_cancelled'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal(['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal( + [{'event': {'Type': 'String', 'Value': 'order_cancelled'}}]) + + +@mock_sqs +@mock_sns +def test_filtering_attribute_key_not_exists_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': False}]}) + + topic.publish( + Message='no match', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) + + +@mock_sqs +@mock_sns +def test_filtering_all_AND_matching_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': True}], + 'event': ['order_cancelled'], + 'customer_interests': ['basketball', 'baseball'], + 'price': [100]}) + + topic.publish( + Message='match', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}, + 'event': {'DataType': 'String', + 'StringValue': 'order_cancelled'}, + 'customer_interests': {'DataType': 'String.Array', + 'StringValue': json.dumps(['basketball', 'rugby'])}, + 'price': {'DataType': 'Number', + 'StringValue': '100'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal( + ['match']) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([{ + 'store': {'Type': 'String', 'Value': 'example_corp'}, + 'event': {'Type': 'String', 'Value': 'order_cancelled'}, + 'customer_interests': {'Type': 'String.Array', 'Value': json.dumps(['basketball', 'rugby'])}, + 'price': {'Type': 'Number', 'Value': 100}}]) + + +@mock_sqs +@mock_sns +def test_filtering_all_AND_matching_no_match(): + topic, subscription, queue = _setup_filter_policy_test( + {'store': [{'exists': True}], + 'event': ['order_cancelled'], + 'customer_interests': ['basketball', 'baseball'], + 'price': [100], + "encrypted": [False]}) + + topic.publish( + Message='no match', + MessageAttributes={'store': {'DataType': 'String', + 'StringValue': 'example_corp'}, + 'event': {'DataType': 'String', + 'StringValue': 'order_cancelled'}, + 'customer_interests': {'DataType': 'String.Array', + 'StringValue': json.dumps(['basketball', 'rugby'])}, + 'price': {'DataType': 'Number', + 'StringValue': '100'}}) + + messages = queue.receive_messages(MaxNumberOfMessages=5) + message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies.should.equal([]) + message_attributes = [ + json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes.should.equal([]) diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index 2a56c8213..012cd6470 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -201,7 +201,9 @@ def test_creating_subscription_with_attributes(): "store": ["example_corp"], "event": ["order_cancelled"], "encrypted": [False], - "customer_interests": ["basketball", "baseball"] + "customer_interests": ["basketball", "baseball"], + "price": [100, 100.12], + "error": [None] }) conn.subscribe(TopicArn=topic_arn, @@ -294,7 +296,9 @@ def test_set_subscription_attributes(): "store": ["example_corp"], "event": ["order_cancelled"], "encrypted": [False], - "customer_interests": ["basketball", "baseball"] + "customer_interests": ["basketball", "baseball"], + "price": [100, 100.12], + "error": [None] }) conn.set_subscription_attributes( SubscriptionArn=subscription_arn, @@ -332,6 +336,77 @@ def test_set_subscription_attributes(): ) +@mock_sns +def test_subscribe_invalid_filter_policy(): + conn = boto3.client('sns', region_name = 'us-east-1') + conn.create_topic(Name = 'some-topic') + response = conn.list_topics() + topic_arn = response['Topics'][0]['TopicArn'] + + try: + conn.subscribe(TopicArn = topic_arn, + Protocol = 'http', + Endpoint = 'http://example.com/', + Attributes = { + 'FilterPolicy': json.dumps({ + 'store': [str(i) for i in range(151)] + }) + }) + except ClientError as err: + err.response['Error']['Code'].should.equal('InvalidParameter') + err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Filter policy is too complex') + + try: + conn.subscribe(TopicArn = topic_arn, + Protocol = 'http', + Endpoint = 'http://example.com/', + Attributes = { + 'FilterPolicy': json.dumps({ + 'store': [['example_corp']] + }) + }) + except ClientError as err: + err.response['Error']['Code'].should.equal('InvalidParameter') + err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null') + + try: + conn.subscribe(TopicArn = topic_arn, + Protocol = 'http', + Endpoint = 'http://example.com/', + Attributes = { + 'FilterPolicy': json.dumps({ + 'store': [{'exists': None}] + }) + }) + except ClientError as err: + err.response['Error']['Code'].should.equal('InvalidParameter') + err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: exists match pattern must be either true or false.') + + try: + conn.subscribe(TopicArn = topic_arn, + Protocol = 'http', + Endpoint = 'http://example.com/', + Attributes = { + 'FilterPolicy': json.dumps({ + 'store': [{'error': True}] + }) + }) + except ClientError as err: + err.response['Error']['Code'].should.equal('InvalidParameter') + err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Unrecognized match type error') + + try: + conn.subscribe(TopicArn = topic_arn, + Protocol = 'http', + Endpoint = 'http://example.com/', + Attributes = { + 'FilterPolicy': json.dumps({ + 'store': [1000000001] + }) + }) + except ClientError as err: + err.response['Error']['Code'].should.equal('InternalFailure') + @mock_sns def test_check_not_opted_out(): conn = boto3.client('sns', region_name='us-east-1') diff --git a/tests/test_sqs/test_sqs.py b/tests/test_sqs/test_sqs.py index d53ae50f7..56d82ea61 100644 --- a/tests/test_sqs/test_sqs.py +++ b/tests/test_sqs/test_sqs.py @@ -1117,6 +1117,28 @@ def test_redrive_policy_set_attributes(): assert copy_policy == redrive_policy +@mock_sqs +def test_redrive_policy_set_attributes_with_string_value(): + sqs = boto3.resource('sqs', region_name='us-east-1') + + queue = sqs.create_queue(QueueName='test-queue') + deadletter_queue = sqs.create_queue(QueueName='test-deadletter') + + queue.set_attributes(Attributes={ + 'RedrivePolicy': json.dumps({ + 'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'], + 'maxReceiveCount': '1', + })}) + + copy = sqs.get_queue_by_name(QueueName='test-queue') + assert 'RedrivePolicy' in copy.attributes + copy_policy = json.loads(copy.attributes['RedrivePolicy']) + assert copy_policy == { + 'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'], + 'maxReceiveCount': 1, + } + + @mock_sqs def test_receive_messages_with_message_group_id(): sqs = boto3.resource('sqs', region_name='us-east-1') diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py new file mode 100644 index 000000000..10953ce2d --- /dev/null +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -0,0 +1,378 @@ +from __future__ import unicode_literals + +import boto3 +import sure # noqa +import datetime + +from datetime import datetime +from botocore.exceptions import ClientError +from nose.tools import assert_raises + +from moto import mock_sts, mock_stepfunctions + + +region = 'us-east-1' +simple_definition = '{"Comment": "An example of the Amazon States Language using a choice state.",' \ + '"StartAt": "DefaultState",' \ + '"States": ' \ + '{"DefaultState": {"Type": "Fail","Error": "DefaultStateError","Cause": "No Matches!"}}}' +account_id = None + + +@mock_stepfunctions +@mock_sts +def test_state_machine_creation_succeeds(): + client = boto3.client('stepfunctions', region_name=region) + name = 'example_step_function' + # + response = client.create_state_machine(name=name, + definition=str(simple_definition), + roleArn=_get_default_role()) + # + response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response['creationDate'].should.be.a(datetime) + response['stateMachineArn'].should.equal('arn:aws:states:' + region + ':123456789012:stateMachine:' + name) + + +@mock_stepfunctions +def test_state_machine_creation_fails_with_invalid_names(): + client = boto3.client('stepfunctions', region_name=region) + invalid_names = [ + 'with space', + 'withbracket', 'with{bracket', 'with}bracket', 'with[bracket', 'with]bracket', + 'with?wildcard', 'with*wildcard', + 'special"char', 'special#char', 'special%char', 'special\\char', 'special^char', 'special|char', + 'special~char', 'special`char', 'special$char', 'special&char', 'special,char', 'special;char', + 'special:char', 'special/char', + u'uni\u0000code', u'uni\u0001code', u'uni\u0002code', u'uni\u0003code', u'uni\u0004code', + u'uni\u0005code', u'uni\u0006code', u'uni\u0007code', u'uni\u0008code', u'uni\u0009code', + u'uni\u000Acode', u'uni\u000Bcode', u'uni\u000Ccode', + u'uni\u000Dcode', u'uni\u000Ecode', u'uni\u000Fcode', + u'uni\u0010code', u'uni\u0011code', u'uni\u0012code', u'uni\u0013code', u'uni\u0014code', + u'uni\u0015code', u'uni\u0016code', u'uni\u0017code', u'uni\u0018code', u'uni\u0019code', + u'uni\u001Acode', u'uni\u001Bcode', u'uni\u001Ccode', + u'uni\u001Dcode', u'uni\u001Ecode', u'uni\u001Fcode', + u'uni\u007Fcode', + u'uni\u0080code', u'uni\u0081code', u'uni\u0082code', u'uni\u0083code', u'uni\u0084code', + u'uni\u0085code', u'uni\u0086code', u'uni\u0087code', u'uni\u0088code', u'uni\u0089code', + u'uni\u008Acode', u'uni\u008Bcode', u'uni\u008Ccode', + u'uni\u008Dcode', u'uni\u008Ecode', u'uni\u008Fcode', + u'uni\u0090code', u'uni\u0091code', u'uni\u0092code', u'uni\u0093code', u'uni\u0094code', + u'uni\u0095code', u'uni\u0096code', u'uni\u0097code', u'uni\u0098code', u'uni\u0099code', + u'uni\u009Acode', u'uni\u009Bcode', u'uni\u009Ccode', + u'uni\u009Dcode', u'uni\u009Ecode', u'uni\u009Fcode'] + # + + for invalid_name in invalid_names: + with assert_raises(ClientError) as exc: + client.create_state_machine(name=invalid_name, + definition=str(simple_definition), + roleArn=_get_default_role()) + + +@mock_stepfunctions +def test_state_machine_creation_requires_valid_role_arn(): + client = boto3.client('stepfunctions', region_name=region) + name = 'example_step_function' + # + with assert_raises(ClientError) as exc: + client.create_state_machine(name=name, + definition=str(simple_definition), + roleArn='arn:aws:iam:1234:role/unknown_role') + + +@mock_stepfunctions +def test_state_machine_list_returns_empty_list_by_default(): + client = boto3.client('stepfunctions', region_name=region) + # + list = client.list_state_machines() + list['stateMachines'].should.be.empty + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_returns_created_state_machines(): + client = boto3.client('stepfunctions', region_name=region) + # + machine2 = client.create_state_machine(name='name2', + definition=str(simple_definition), + roleArn=_get_default_role()) + machine1 = client.create_state_machine(name='name1', + definition=str(simple_definition), + roleArn=_get_default_role(), + tags=[{'key': 'tag_key', 'value': 'tag_value'}]) + list = client.list_state_machines() + # + list['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + list['stateMachines'].should.have.length_of(2) + list['stateMachines'][0]['creationDate'].should.be.a(datetime) + list['stateMachines'][0]['creationDate'].should.equal(machine1['creationDate']) + list['stateMachines'][0]['name'].should.equal('name1') + list['stateMachines'][0]['stateMachineArn'].should.equal(machine1['stateMachineArn']) + list['stateMachines'][1]['creationDate'].should.be.a(datetime) + list['stateMachines'][1]['creationDate'].should.equal(machine2['creationDate']) + list['stateMachines'][1]['name'].should.equal('name2') + list['stateMachines'][1]['stateMachineArn'].should.equal(machine2['stateMachineArn']) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_creation_is_idempotent_by_name(): + client = boto3.client('stepfunctions', region_name=region) + # + client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + sm_list = client.list_state_machines() + sm_list['stateMachines'].should.have.length_of(1) + # + client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + sm_list = client.list_state_machines() + sm_list['stateMachines'].should.have.length_of(1) + # + client.create_state_machine(name='diff_name', definition=str(simple_definition), roleArn=_get_default_role()) + sm_list = client.list_state_machines() + sm_list['stateMachines'].should.have.length_of(2) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_creation_can_be_described(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + desc = client.describe_state_machine(stateMachineArn=sm['stateMachineArn']) + desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + desc['creationDate'].should.equal(sm['creationDate']) + desc['definition'].should.equal(str(simple_definition)) + desc['name'].should.equal('name') + desc['roleArn'].should.equal(_get_default_role()) + desc['stateMachineArn'].should.equal(sm['stateMachineArn']) + desc['status'].should.equal('ACTIVE') + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_unknown_machine(): + client = boto3.client('stepfunctions', region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown' + client.describe_state_machine(stateMachineArn=unknown_state_machine) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_machine_in_different_account(): + client = boto3.client('stepfunctions', region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_state_machine = 'arn:aws:states:' + region + ':000000000000:stateMachine:unknown' + client.describe_state_machine(stateMachineArn=unknown_state_machine) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_can_be_deleted(): + client = boto3.client('stepfunctions', region_name=region) + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + # + response = client.delete_state_machine(stateMachineArn=sm['stateMachineArn']) + response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + # + sm_list = client.list_state_machines() + sm_list['stateMachines'].should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_can_deleted_nonexisting_machine(): + client = boto3.client('stepfunctions', region_name=region) + # + unknown_state_machine = 'arn:aws:states:' + region + ':123456789012:stateMachine:unknown' + response = client.delete_state_machine(stateMachineArn=unknown_state_machine) + response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + # + sm_list = client.list_state_machines() + sm_list['stateMachines'].should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_tags_for_created_machine(): + client = boto3.client('stepfunctions', region_name=region) + # + machine = client.create_state_machine(name='name1', + definition=str(simple_definition), + roleArn=_get_default_role(), + tags=[{'key': 'tag_key', 'value': 'tag_value'}]) + response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn']) + tags = response['tags'] + tags.should.have.length_of(1) + tags[0].should.equal({'key': 'tag_key', 'value': 'tag_value'}) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_tags_for_machine_without_tags(): + client = boto3.client('stepfunctions', region_name=region) + # + machine = client.create_state_machine(name='name1', + definition=str(simple_definition), + roleArn=_get_default_role()) + response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn']) + tags = response['tags'] + tags.should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_tags_for_nonexisting_machine(): + client = boto3.client('stepfunctions', region_name=region) + # + non_existing_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown' + response = client.list_tags_for_resource(resourceArn=non_existing_state_machine) + tags = response['tags'] + tags.should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_start_execution(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) + # + execution['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + expected_exec_name = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:name:[a-zA-Z0-9-]+' + execution['executionArn'].should.match(expected_exec_name) + execution['startDate'].should.be.a(datetime) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_executions(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) + execution_arn = execution['executionArn'] + execution_name = execution_arn[execution_arn.rindex(':')+1:] + executions = client.list_executions(stateMachineArn=sm['stateMachineArn']) + # + executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + executions['executions'].should.have.length_of(1) + executions['executions'][0]['executionArn'].should.equal(execution_arn) + executions['executions'][0]['name'].should.equal(execution_name) + executions['executions'][0]['startDate'].should.equal(execution['startDate']) + executions['executions'][0]['stateMachineArn'].should.equal(sm['stateMachineArn']) + executions['executions'][0]['status'].should.equal('RUNNING') + executions['executions'][0].shouldnt.have('stopDate') + + +@mock_stepfunctions +@mock_sts +def test_state_machine_list_executions_when_none_exist(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + executions = client.list_executions(stateMachineArn=sm['stateMachineArn']) + # + executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + executions['executions'].should.have.length_of(0) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_describe_execution(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) + description = client.describe_execution(executionArn=execution['executionArn']) + # + description['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + description['executionArn'].should.equal(execution['executionArn']) + description['input'].should.equal("{}") + description['name'].shouldnt.be.empty + description['startDate'].should.equal(execution['startDate']) + description['stateMachineArn'].should.equal(sm['stateMachineArn']) + description['status'].should.equal('RUNNING') + description.shouldnt.have('stopDate') + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_unknown_machine(): + client = boto3.client('stepfunctions', region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown' + client.describe_execution(executionArn=unknown_execution) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_can_be_described_by_execution(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) + desc = client.describe_state_machine_for_execution(executionArn=execution['executionArn']) + desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + desc['definition'].should.equal(str(simple_definition)) + desc['name'].should.equal('name') + desc['roleArn'].should.equal(_get_default_role()) + desc['stateMachineArn'].should.equal(sm['stateMachineArn']) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_throws_error_when_describing_unknown_execution(): + client = boto3.client('stepfunctions', region_name=region) + # + with assert_raises(ClientError) as exc: + unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown' + client.describe_state_machine_for_execution(executionArn=unknown_execution) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_stop_execution(): + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + start = client.start_execution(stateMachineArn=sm['stateMachineArn']) + stop = client.stop_execution(executionArn=start['executionArn']) + # + stop['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + stop['stopDate'].should.be.a(datetime) + + +@mock_stepfunctions +@mock_sts +def test_state_machine_describe_execution_after_stoppage(): + account_id + client = boto3.client('stepfunctions', region_name=region) + # + sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) + client.stop_execution(executionArn=execution['executionArn']) + description = client.describe_execution(executionArn=execution['executionArn']) + # + description['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + description['status'].should.equal('SUCCEEDED') + description['stopDate'].should.be.a(datetime) + + +def _get_account_id(): + global account_id + if account_id: + return account_id + sts = boto3.client("sts") + identity = sts.get_caller_identity() + account_id = identity['Account'] + return account_id + + +def _get_default_role(): + return 'arn:aws:iam:' + _get_account_id() + ':role/unknown_sf_role'