Merge pull request #2 from spulec/master

Upstream changes
Bert Blommers 2019-10-03 09:18:17 +01:00 committed by GitHub
commit 82ce5ad430
66 changed files with 3740 additions and 790 deletions

View File

@ -1237,7 +1237,7 @@
- [ ] delete_identities - [ ] delete_identities
- [ ] delete_identity_pool - [ ] delete_identity_pool
- [ ] describe_identity - [ ] describe_identity
- [ ] describe_identity_pool - [X] describe_identity_pool
- [X] get_credentials_for_identity - [X] get_credentials_for_identity
- [X] get_id - [X] get_id
- [ ] get_identity_pool_roles - [ ] get_identity_pool_roles
@ -3801,14 +3801,14 @@
- [ ] update_stream - [ ] update_stream
## kms ## kms
41% implemented 54% implemented
- [X] cancel_key_deletion - [X] cancel_key_deletion
- [ ] connect_custom_key_store - [ ] connect_custom_key_store
- [ ] create_alias - [ ] create_alias
- [ ] create_custom_key_store - [ ] create_custom_key_store
- [ ] create_grant - [ ] create_grant
- [X] create_key - [X] create_key
- [ ] decrypt - [X] decrypt
- [X] delete_alias - [X] delete_alias
- [ ] delete_custom_key_store - [ ] delete_custom_key_store
- [ ] delete_imported_key_material - [ ] delete_imported_key_material
@ -3819,10 +3819,10 @@
- [ ] disconnect_custom_key_store - [ ] disconnect_custom_key_store
- [X] enable_key - [X] enable_key
- [X] enable_key_rotation - [X] enable_key_rotation
- [ ] encrypt - [X] encrypt
- [X] generate_data_key - [X] generate_data_key
- [ ] generate_data_key_without_plaintext - [X] generate_data_key_without_plaintext
- [ ] generate_random - [X] generate_random
- [X] get_key_policy - [X] get_key_policy
- [X] get_key_rotation_status - [X] get_key_rotation_status
- [ ] get_parameters_for_import - [ ] get_parameters_for_import
@ -3834,7 +3834,7 @@
- [X] list_resource_tags - [X] list_resource_tags
- [ ] list_retirable_grants - [ ] list_retirable_grants
- [X] put_key_policy - [X] put_key_policy
- [ ] re_encrypt - [X] re_encrypt
- [ ] retire_grant - [ ] retire_grant
- [ ] revoke_grant - [ ] revoke_grant
- [X] schedule_key_deletion - [X] schedule_key_deletion
@ -6050,24 +6050,24 @@
## stepfunctions ## stepfunctions
0% implemented 0% implemented
- [ ] create_activity - [ ] create_activity
- [ ] create_state_machine - [X] create_state_machine
- [ ] delete_activity - [ ] delete_activity
- [ ] delete_state_machine - [X] delete_state_machine
- [ ] describe_activity - [ ] describe_activity
- [ ] describe_execution - [X] describe_execution
- [ ] describe_state_machine - [X] describe_state_machine
- [ ] describe_state_machine_for_execution - [x] describe_state_machine_for_execution
- [ ] get_activity_task - [ ] get_activity_task
- [ ] get_execution_history - [ ] get_execution_history
- [ ] list_activities - [ ] list_activities
- [ ] list_executions - [X] list_executions
- [ ] list_state_machines - [X] list_state_machines
- [ ] list_tags_for_resource - [X] list_tags_for_resource
- [ ] send_task_failure - [ ] send_task_failure
- [ ] send_task_heartbeat - [ ] send_task_heartbeat
- [ ] send_task_success - [ ] send_task_success
- [ ] start_execution - [X] start_execution
- [ ] stop_execution - [X] stop_execution
- [ ] tag_resource - [ ] tag_resource
- [ ] untag_resource - [ ] untag_resource
- [ ] update_state_machine - [ ] update_state_machine

View File

@ -94,6 +94,8 @@ Currently implemented Services:
+---------------------------+-----------------------+------------------------------------+ +---------------------------+-----------------------+------------------------------------+
| SES | @mock_ses | all endpoints done | | SES | @mock_ses | all endpoints done |
+---------------------------+-----------------------+------------------------------------+ +---------------------------+-----------------------+------------------------------------+
| SFN | @mock_stepfunctions | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SNS | @mock_sns | all endpoints done | | SNS | @mock_sns | all endpoints done |
+---------------------------+-----------------------+------------------------------------+ +---------------------------+-----------------------+------------------------------------+
| SQS | @mock_sqs | core endpoints done | | SQS | @mock_sqs | core endpoints done |
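The new @mock_stepfunctions decorator follows the same pattern as the other mocks in this table. A minimal sketch of exercising the endpoints marked implemented above (region, state-machine name and role ARN are illustrative, not taken from this change):

    import boto3
    from moto import mock_stepfunctions

    @mock_stepfunctions
    def test_state_machine_roundtrip():
        client = boto3.client("stepfunctions", region_name="us-east-1")
        # create_state_machine / list_state_machines are among the newly implemented endpoints
        machine = client.create_state_machine(
            name="HelloWorld",
            definition='{"StartAt": "Done", "States": {"Done": {"Type": "Succeed"}}}',
            roleArn="arn:aws:iam::123456789012:role/StepFunctionsRole",
        )
        listed = client.list_state_machines()["stateMachines"]
        assert any(m["stateMachineArn"] == machine["stateMachineArn"] for m in listed)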

View File

@ -42,6 +42,7 @@ from .ses import mock_ses, mock_ses_deprecated # flake8: noqa
from .secretsmanager import mock_secretsmanager # flake8: noqa from .secretsmanager import mock_secretsmanager # flake8: noqa
from .sns import mock_sns, mock_sns_deprecated # flake8: noqa from .sns import mock_sns, mock_sns_deprecated # flake8: noqa
from .sqs import mock_sqs, mock_sqs_deprecated # flake8: noqa from .sqs import mock_sqs, mock_sqs_deprecated # flake8: noqa
from .stepfunctions import mock_stepfunctions # flake8: noqa
from .sts import mock_sts, mock_sts_deprecated # flake8: noqa from .sts import mock_sts, mock_sts_deprecated # flake8: noqa
from .ssm import mock_ssm # flake8: noqa from .ssm import mock_ssm # flake8: noqa
from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa

View File

@ -298,7 +298,7 @@ class Stage(BaseModel, dict):
class ApiKey(BaseModel, dict): class ApiKey(BaseModel, dict):
def __init__(self, name=None, description=None, enabled=True, def __init__(self, name=None, description=None, enabled=True,
generateDistinctId=False, value=None, stageKeys=None, customerId=None): generateDistinctId=False, value=None, stageKeys=None, tags=None, customerId=None):
super(ApiKey, self).__init__() super(ApiKey, self).__init__()
self['id'] = create_id() self['id'] = create_id()
self['value'] = value if value else ''.join(random.sample(string.ascii_letters + string.digits, 40)) self['value'] = value if value else ''.join(random.sample(string.ascii_letters + string.digits, 40))
@ -308,6 +308,7 @@ class ApiKey(BaseModel, dict):
self['enabled'] = enabled self['enabled'] = enabled
self['createdDate'] = self['lastUpdatedDate'] = int(time.time()) self['createdDate'] = self['lastUpdatedDate'] = int(time.time())
self['stageKeys'] = stageKeys self['stageKeys'] = stageKeys
self['tags'] = tags
def update_operations(self, patch_operations): def update_operations(self, patch_operations):
for op in patch_operations: for op in patch_operations:
@ -407,10 +408,16 @@ class RestAPI(BaseModel):
stage_url_upper = STAGE_URL.format(api_id=self.id.upper(), stage_url_upper = STAGE_URL.format(api_id=self.id.upper(),
region_name=self.region_name, stage_name=stage_name) region_name=self.region_name, stage_name=stage_name)
responses.add_callback(responses.GET, stage_url_lower, for url in [stage_url_lower, stage_url_upper]:
callback=self.resource_callback) responses._default_mock._matches.insert(0,
responses.add_callback(responses.GET, stage_url_upper, responses.CallbackResponse(
callback=self.resource_callback) url=url,
method=responses.GET,
callback=self.resource_callback,
content_type="text/plain",
match_querystring=False,
)
)
def create_stage(self, name, deployment_id, variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None): def create_stage(self, name, deployment_id, variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None):
if variables is None: if variables is None:
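Since ApiKey now stores the optional tags argument, a tagged key can be created through the mocked API Gateway endpoint. A minimal sketch, assuming the responses layer passes the request body straight into the model (tag values are illustrative):

    import boto3
    from moto import mock_apigateway

    @mock_apigateway
    def test_api_key_tags():
        client = boto3.client("apigateway", region_name="us-east-1")
        key = client.create_api_key(name="my-key", tags={"team": "platform"})
        # the ApiKey dict now carries the tags through to the response
        assert key["tags"] == {"team": "platform"}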

View File

@ -40,6 +40,7 @@ from moto.secretsmanager import secretsmanager_backends
from moto.sns import sns_backends from moto.sns import sns_backends
from moto.sqs import sqs_backends from moto.sqs import sqs_backends
from moto.ssm import ssm_backends from moto.ssm import ssm_backends
from moto.stepfunctions import stepfunction_backends
from moto.sts import sts_backends from moto.sts import sts_backends
from moto.swf import swf_backends from moto.swf import swf_backends
from moto.xray import xray_backends from moto.xray import xray_backends
@ -91,6 +92,7 @@ BACKENDS = {
'sns': sns_backends, 'sns': sns_backends,
'sqs': sqs_backends, 'sqs': sqs_backends,
'ssm': ssm_backends, 'ssm': ssm_backends,
'stepfunctions': stepfunction_backends,
'sts': sts_backends, 'sts': sts_backends,
'swf': swf_backends, 'swf': swf_backends,
'route53': route53_backends, 'route53': route53_backends,

View File

@ -0,0 +1,15 @@
from __future__ import unicode_literals
import json
from werkzeug.exceptions import BadRequest
class ResourceNotFoundError(BadRequest):
def __init__(self, message):
super(ResourceNotFoundError, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'ResourceNotFoundException',
})

View File

@ -8,7 +8,7 @@ import boto.cognito.identity
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core.utils import iso_8601_datetime_with_milliseconds
from .exceptions import ResourceNotFoundError
from .utils import get_random_identity_id from .utils import get_random_identity_id
@ -39,10 +39,29 @@ class CognitoIdentityBackend(BaseBackend):
self.__dict__ = {} self.__dict__ = {}
self.__init__(region) self.__init__(region)
def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities, def describe_identity_pool(self, identity_pool_id):
supported_login_providers, developer_provider_name, open_id_connect_provider_arns, identity_pool = self.identity_pools.get(identity_pool_id, None)
cognito_identity_providers, saml_provider_arns):
if not identity_pool:
raise ResourceNotFoundError(identity_pool)
response = json.dumps({
'AllowUnauthenticatedIdentities': identity_pool.allow_unauthenticated_identities,
'CognitoIdentityProviders': identity_pool.cognito_identity_providers,
'DeveloperProviderName': identity_pool.developer_provider_name,
'IdentityPoolId': identity_pool.identity_pool_id,
'IdentityPoolName': identity_pool.identity_pool_name,
'IdentityPoolTags': {},
'OpenIdConnectProviderARNs': identity_pool.open_id_connect_provider_arns,
'SamlProviderARNs': identity_pool.saml_provider_arns,
'SupportedLoginProviders': identity_pool.supported_login_providers
})
return response
def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities,
supported_login_providers, developer_provider_name, open_id_connect_provider_arns,
cognito_identity_providers, saml_provider_arns):
new_identity = CognitoIdentity(self.region, identity_pool_name, new_identity = CognitoIdentity(self.region, identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities, allow_unauthenticated_identities=allow_unauthenticated_identities,
supported_login_providers=supported_login_providers, supported_login_providers=supported_login_providers,
@ -77,12 +96,12 @@ class CognitoIdentityBackend(BaseBackend):
response = json.dumps( response = json.dumps(
{ {
"Credentials": "Credentials":
{ {
"AccessKeyId": "TESTACCESSKEY12345", "AccessKeyId": "TESTACCESSKEY12345",
"Expiration": expiration_str, "Expiration": expiration_str,
"SecretKey": "ABCSECRETKEY", "SecretKey": "ABCSECRETKEY",
"SessionToken": "ABC12345" "SessionToken": "ABC12345"
}, },
"IdentityId": identity_id "IdentityId": identity_id
}) })
return response return response
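A minimal sketch of driving the new describe_identity_pool endpoint through boto3 (pool name and region are illustrative):

    import boto3
    from moto import mock_cognitoidentity

    @mock_cognitoidentity
    def test_describe_identity_pool():
        client = boto3.client("cognito-identity", region_name="us-west-2")
        pool = client.create_identity_pool(
            IdentityPoolName="TestPool",
            AllowUnauthenticatedIdentities=True,
        )
        described = client.describe_identity_pool(IdentityPoolId=pool["IdentityPoolId"])
        assert described["IdentityPoolName"] == "TestPool"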

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import cognitoidentity_backends from .models import cognitoidentity_backends
from .utils import get_random_identity_id from .utils import get_random_identity_id
@ -16,6 +15,7 @@ class CognitoIdentityResponse(BaseResponse):
open_id_connect_provider_arns = self._get_param('OpenIdConnectProviderARNs') open_id_connect_provider_arns = self._get_param('OpenIdConnectProviderARNs')
cognito_identity_providers = self._get_param('CognitoIdentityProviders') cognito_identity_providers = self._get_param('CognitoIdentityProviders')
saml_provider_arns = self._get_param('SamlProviderARNs') saml_provider_arns = self._get_param('SamlProviderARNs')
return cognitoidentity_backends[self.region].create_identity_pool( return cognitoidentity_backends[self.region].create_identity_pool(
identity_pool_name=identity_pool_name, identity_pool_name=identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities, allow_unauthenticated_identities=allow_unauthenticated_identities,
@ -28,6 +28,9 @@ class CognitoIdentityResponse(BaseResponse):
def get_id(self): def get_id(self):
return cognitoidentity_backends[self.region].get_id() return cognitoidentity_backends[self.region].get_id()
def describe_identity_pool(self):
return cognitoidentity_backends[self.region].describe_identity_pool(self._get_param('IdentityPoolId'))
def get_credentials_for_identity(self): def get_credentials_for_identity(self):
return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId')) return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId'))

View File

@ -14,7 +14,7 @@ SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
""" """
ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?> ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
<Response> <ErrorResponse>
<Errors> <Errors>
<Error> <Error>
<Code>{{error_type}}</Code> <Code>{{error_type}}</Code>
@ -23,7 +23,7 @@ ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</Error> </Error>
</Errors> </Errors>
<RequestID>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestID> <RequestID>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestID>
</Response> </ErrorResponse>
""" """
ERROR_JSON_RESPONSE = u"""{ ERROR_JSON_RESPONSE = u"""{

View File

@ -197,53 +197,9 @@ class CallbackResponse(responses.CallbackResponse):
botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send') botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send')
responses_mock = responses._default_mock responses_mock = responses._default_mock
# Add passthrough to allow any other requests to work
# Since this uses .startswith, it applies to http and https requests.
class ResponsesMockAWS(BaseMockAWS): responses_mock.add_passthru("http")
def reset(self):
botocore_mock.reset()
responses_mock.reset()
def enable_patching(self):
if not hasattr(botocore_mock, '_patcher') or not hasattr(botocore_mock._patcher, 'target'):
# Check for unactivated patcher
botocore_mock.start()
if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'):
responses_mock.start()
for method in RESPONSES_METHODS:
for backend in self.backends_for_urls.values():
for key, value in backend.urls.items():
responses_mock.add(
CallbackResponse(
method=method,
url=re.compile(key),
callback=convert_flask_to_responses_response(value),
stream=True,
match_querystring=False,
)
)
botocore_mock.add(
CallbackResponse(
method=method,
url=re.compile(key),
callback=convert_flask_to_responses_response(value),
stream=True,
match_querystring=False,
)
)
def disable_patching(self):
try:
botocore_mock.stop()
except RuntimeError:
pass
try:
responses_mock.stop()
except RuntimeError:
pass
BOTOCORE_HTTP_METHODS = [ BOTOCORE_HTTP_METHODS = [
@ -310,6 +266,14 @@ botocore_stubber = BotocoreStubber()
BUILTIN_HANDLERS.append(('before-send', botocore_stubber)) BUILTIN_HANDLERS.append(('before-send', botocore_stubber))
def not_implemented_callback(request):
status = 400
headers = {}
response = "The method is not implemented"
return status, headers, response
class BotocoreEventMockAWS(BaseMockAWS): class BotocoreEventMockAWS(BaseMockAWS):
def reset(self): def reset(self):
botocore_stubber.reset() botocore_stubber.reset()
@ -339,6 +303,24 @@ class BotocoreEventMockAWS(BaseMockAWS):
match_querystring=False, match_querystring=False,
) )
) )
responses_mock.add(
CallbackResponse(
method=method,
url=re.compile("https?://.+.amazonaws.com/.*"),
callback=not_implemented_callback,
stream=True,
match_querystring=False,
)
)
botocore_mock.add(
CallbackResponse(
method=method,
url=re.compile("https?://.+.amazonaws.com/.*"),
callback=not_implemented_callback,
stream=True,
match_querystring=False,
)
)
def disable_patching(self): def disable_patching(self):
botocore_stubber.enabled = False botocore_stubber.enabled = False

View File

@ -941,8 +941,7 @@ class OpAnd(Op):
def expr(self, item): def expr(self, item):
lhs = self.lhs.expr(item) lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item) return lhs and self.rhs.expr(item)
return lhs and rhs
class OpLessThan(Op): class OpLessThan(Op):
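The rewritten return statement makes AND short-circuit: the right-hand operand is no longer evaluated once the left-hand side is falsy, so a right-hand expression that would raise on a missing attribute is skipped. A standalone sketch of the behaviour the new code relies on (the classes are illustrative stand-ins, not moto types):

    class Falsy:
        def expr(self, item):
            return False

    class Boom:
        def expr(self, item):
            raise KeyError("rhs must not be evaluated when lhs is already False")

    lhs, rhs = Falsy(), Boom()
    # mirrors `return lhs.expr(item) and rhs.expr(item)`: rhs is only reached when lhs is truthy
    assert (lhs.expr({}) and rhs.expr({})) is False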

View File

@ -363,7 +363,7 @@ class StreamRecord(BaseModel):
'dynamodb': { 'dynamodb': {
'StreamViewType': stream_type, 'StreamViewType': stream_type,
'ApproximateCreationDateTime': datetime.datetime.utcnow().isoformat(), 'ApproximateCreationDateTime': datetime.datetime.utcnow().isoformat(),
'SequenceNumber': seq, 'SequenceNumber': str(seq),
'SizeBytes': 1, 'SizeBytes': 1,
'Keys': keys 'Keys': keys
} }

View File

@ -356,9 +356,18 @@ class DynamoHandler(BaseResponse):
if projection_expression and expression_attribute_names: if projection_expression and expression_attribute_names:
expressions = [x.strip() for x in projection_expression.split(',')] expressions = [x.strip() for x in projection_expression.split(',')]
projection_expression = None
for expression in expressions: for expression in expressions:
if projection_expression is not None:
projection_expression = projection_expression + ", "
else:
projection_expression = ""
if expression in expression_attribute_names: if expression in expression_attribute_names:
projection_expression = projection_expression.replace(expression, expression_attribute_names[expression]) projection_expression = projection_expression + \
expression_attribute_names[expression]
else:
projection_expression = projection_expression + expression
filter_kwargs = {} filter_kwargs = {}
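The rebuilt loop stitches the projection expression back together while substituting ExpressionAttributeNames placeholders with their real attribute names. A standalone sketch of the same transformation (sample names are illustrative):

    expression_attribute_names = {"#nm": "name", "#st": "status"}
    projection_expression = "id, #nm, #st"

    expressions = [x.strip() for x in projection_expression.split(",")]
    rebuilt = None
    for expression in expressions:
        rebuilt = "" if rebuilt is None else rebuilt + ", "
        # substitute the placeholder when it is defined, otherwise keep the raw name
        rebuilt += expression_attribute_names.get(expression, expression)

    assert rebuilt == "id, name, status"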

View File

@ -39,7 +39,7 @@ class ShardIterator(BaseModel):
def get(self, limit=1000): def get(self, limit=1000):
items = self.stream_shard.get(self.sequence_number, limit) items = self.stream_shard.get(self.sequence_number, limit)
try: try:
last_sequence_number = max(i['dynamodb']['SequenceNumber'] for i in items) last_sequence_number = max(int(i['dynamodb']['SequenceNumber']) for i in items)
new_shard_iterator = ShardIterator(self.streams_backend, new_shard_iterator = ShardIterator(self.streams_backend,
self.stream_shard, self.stream_shard,
'AFTER_SEQUENCE_NUMBER', 'AFTER_SEQUENCE_NUMBER',

View File

@ -3,6 +3,7 @@ from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import dynamodbstreams_backends from .models import dynamodbstreams_backends
from six import string_types
class DynamoDBStreamsHandler(BaseResponse): class DynamoDBStreamsHandler(BaseResponse):
@ -23,8 +24,13 @@ class DynamoDBStreamsHandler(BaseResponse):
arn = self._get_param('StreamArn') arn = self._get_param('StreamArn')
shard_id = self._get_param('ShardId') shard_id = self._get_param('ShardId')
shard_iterator_type = self._get_param('ShardIteratorType') shard_iterator_type = self._get_param('ShardIteratorType')
sequence_number = self._get_param('SequenceNumber')
# According to the documentation, the SequenceNumber parameter arrives as a string
if isinstance(sequence_number, string_types):
sequence_number = int(sequence_number)
return self.backend.get_shard_iterator(arn, shard_id, return self.backend.get_shard_iterator(arn, shard_id,
shard_iterator_type) shard_iterator_type, sequence_number)
def get_records(self): def get_records(self):
arn = self._get_param('ShardIterator') arn = self._get_param('ShardIterator')

View File

@ -6,6 +6,7 @@ from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import filters_from_querystring, \ from moto.ec2.utils import filters_from_querystring, \
dict_from_querystring dict_from_querystring
from moto.elbv2 import elbv2_backends
class InstanceResponse(BaseResponse): class InstanceResponse(BaseResponse):
@ -68,6 +69,7 @@ class InstanceResponse(BaseResponse):
if self.is_not_dryrun('TerminateInstance'): if self.is_not_dryrun('TerminateInstance'):
instances = self.ec2_backend.terminate_instances(instance_ids) instances = self.ec2_backend.terminate_instances(instance_ids)
autoscaling_backends[self.region].notify_terminate_instances(instance_ids) autoscaling_backends[self.region].notify_terminate_instances(instance_ids)
elbv2_backends[self.region].notify_terminate_instances(instance_ids)
template = self.response_template(EC2_TERMINATE_INSTANCES) template = self.response_template(EC2_TERMINATE_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)

View File

@ -131,7 +131,7 @@ class InvalidActionTypeError(ELBClientError):
def __init__(self, invalid_name, index): def __init__(self, invalid_name, index):
super(InvalidActionTypeError, self).__init__( super(InvalidActionTypeError, self).__init__(
"ValidationError", "ValidationError",
"1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect]" % (invalid_name, index) "1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect, fixed-response]" % (invalid_name, index)
) )
@ -190,3 +190,18 @@ class InvalidModifyRuleArgumentsError(ELBClientError):
"ValidationError", "ValidationError",
"Either conditions or actions must be specified" "Either conditions or actions must be specified"
) )
class InvalidStatusCodeActionTypeError(ELBClientError):
def __init__(self, msg):
super(InvalidStatusCodeActionTypeError, self).__init__(
"ValidationError", msg
)
class InvalidLoadBalancerActionException(ELBClientError):
def __init__(self, msg):
super(InvalidLoadBalancerActionException, self).__init__(
"InvalidLoadBalancerAction", msg
)

View File

@ -3,10 +3,11 @@ from __future__ import unicode_literals
import datetime import datetime
import re import re
from jinja2 import Template from jinja2 import Template
from botocore.exceptions import ParamValidationError
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores, underscores_to_camelcase
from moto.ec2.models import ec2_backends from moto.ec2.models import ec2_backends
from moto.acm.models import acm_backends from moto.acm.models import acm_backends
from .utils import make_arn_for_target_group from .utils import make_arn_for_target_group
@ -31,8 +32,8 @@ from .exceptions import (
RuleNotFoundError, RuleNotFoundError,
DuplicatePriorityError, DuplicatePriorityError,
InvalidTargetGroupNameError, InvalidTargetGroupNameError,
InvalidModifyRuleArgumentsError InvalidModifyRuleArgumentsError,
) InvalidStatusCodeActionTypeError, InvalidLoadBalancerActionException)
class FakeHealthStatus(BaseModel): class FakeHealthStatus(BaseModel):
@ -110,6 +111,11 @@ class FakeTargetGroup(BaseModel):
if not t: if not t:
raise InvalidTargetError() raise InvalidTargetError()
def deregister_terminated_instances(self, instance_ids):
for target_id in list(self.targets.keys()):
if target_id in instance_ids:
del self.targets[target_id]
def add_tag(self, key, value): def add_tag(self, key, value):
if len(self.tags) >= 10 and key not in self.tags: if len(self.tags) >= 10 and key not in self.tags:
raise TooManyTagsError() raise TooManyTagsError()
@ -215,9 +221,9 @@ class FakeListener(BaseModel):
action_type = action['Type'] action_type = action['Type']
if action_type == 'forward': if action_type == 'forward':
default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']}) default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']})
elif action_type in ['redirect', 'authenticate-cognito']: elif action_type in ['redirect', 'authenticate-cognito', 'fixed-response']:
redirect_action = {'type': action_type} redirect_action = {'type': action_type}
key = 'RedirectConfig' if action_type == 'redirect' else 'AuthenticateCognitoConfig' key = underscores_to_camelcase(action_type.capitalize().replace('-', '_')) + 'Config'
for redirect_config_key, redirect_config_value in action[key].items(): for redirect_config_key, redirect_config_value in action[key].items():
# need to match the output of _get_list_prefix # need to match the output of _get_list_prefix
redirect_action[camelcase_to_underscores(key) + '._' + camelcase_to_underscores(redirect_config_key)] = redirect_config_value redirect_action[camelcase_to_underscores(key) + '._' + camelcase_to_underscores(redirect_config_key)] = redirect_config_value
@ -253,6 +259,12 @@ class FakeAction(BaseModel):
<UserPoolClientId>{{ action.data["authenticate_cognito_config._user_pool_client_id"] }}</UserPoolClientId> <UserPoolClientId>{{ action.data["authenticate_cognito_config._user_pool_client_id"] }}</UserPoolClientId>
<UserPoolDomain>{{ action.data["authenticate_cognito_config._user_pool_domain"] }}</UserPoolDomain> <UserPoolDomain>{{ action.data["authenticate_cognito_config._user_pool_domain"] }}</UserPoolDomain>
</AuthenticateCognitoConfig> </AuthenticateCognitoConfig>
{% elif action.type == "fixed-response" %}
<FixedResponseConfig>
<ContentType>{{ action.data["fixed_response_config._content_type"] }}</ContentType>
<MessageBody>{{ action.data["fixed_response_config._message_body"] }}</MessageBody>
<StatusCode>{{ action.data["fixed_response_config._status_code"] }}</StatusCode>
</FixedResponseConfig>
{% endif %} {% endif %}
""") """)
return template.render(action=self) return template.render(action=self)
@ -477,11 +489,30 @@ class ELBv2Backend(BaseBackend):
action_target_group_arn = action.data['target_group_arn'] action_target_group_arn = action.data['target_group_arn']
if action_target_group_arn not in target_group_arns: if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn) raise ActionTargetGroupNotFoundError(action_target_group_arn)
elif action_type == 'fixed-response':
self._validate_fixed_response_action(action, i, index)
elif action_type in ['redirect', 'authenticate-cognito']: elif action_type in ['redirect', 'authenticate-cognito']:
pass pass
else: else:
raise InvalidActionTypeError(action_type, index) raise InvalidActionTypeError(action_type, index)
def _validate_fixed_response_action(self, action, i, index):
status_code = action.data.get('fixed_response_config._status_code')
if status_code is None:
raise ParamValidationError(
report='Missing required parameter in Actions[%s].FixedResponseConfig: "StatusCode"' % i)
if not re.match(r'^(2|4|5)\d\d$', status_code):
raise InvalidStatusCodeActionTypeError(
"1 validation error detected: Value '%s' at 'actions.%s.member.fixedResponseConfig.statusCode' failed to satisfy constraint: \
Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, index)
)
content_type = action.data['fixed_response_config._content_type']
if content_type and content_type not in ['text/plain', 'text/css', 'text/html', 'application/javascript',
'application/json']:
raise InvalidLoadBalancerActionException(
"The ContentType must be one of:'text/html', 'application/json', 'application/javascript', 'text/css', 'text/plain'"
)
def create_target_group(self, name, **kwargs): def create_target_group(self, name, **kwargs):
if len(name) > 32: if len(name) > 32:
raise InvalidTargetGroupNameError( raise InvalidTargetGroupNameError(
@ -936,6 +967,10 @@ class ELBv2Backend(BaseBackend):
return True return True
return False return False
def notify_terminate_instances(self, instance_ids):
for target_group in self.target_groups.values():
target_group.deregister_terminated_instances(instance_ids)
elbv2_backends = {} elbv2_backends = {}
for region in ec2_backends.keys(): for region in ec2_backends.keys():

View File

@ -152,8 +152,10 @@ class IAMPolicyDocumentValidator:
sids = [] sids = []
for statement in self._statements: for statement in self._statements:
if "Sid" in statement: if "Sid" in statement:
assert statement["Sid"] not in sids statementId = statement["Sid"]
sids.append(statement["Sid"]) if statementId:
assert statementId not in sids
sids.append(statementId)
def _validate_statements_syntax(self): def _validate_statements_syntax(self):
assert "Statement" in self._policy_json assert "Statement" in self._policy_json

View File

@ -21,3 +21,11 @@ class InvalidRequestException(IoTDataPlaneClientError):
super(InvalidRequestException, self).__init__( super(InvalidRequestException, self).__init__(
"InvalidRequestException", message "InvalidRequestException", message
) )
class ConflictException(IoTDataPlaneClientError):
def __init__(self, message):
self.code = 409
super(ConflictException, self).__init__(
"ConflictException", message
)

View File

@ -6,6 +6,7 @@ import jsondiff
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.iot import iot_backends from moto.iot import iot_backends
from .exceptions import ( from .exceptions import (
ConflictException,
ResourceNotFoundException, ResourceNotFoundException,
InvalidRequestException InvalidRequestException
) )
@ -161,6 +162,8 @@ class IoTDataPlaneBackend(BaseBackend):
if any(_ for _ in payload['state'].keys() if _ not in ['desired', 'reported']): if any(_ for _ in payload['state'].keys() if _ not in ['desired', 'reported']):
raise InvalidRequestException('State contains an invalid node') raise InvalidRequestException('State contains an invalid node')
if 'version' in payload and thing.thing_shadow.version != payload['version']:
raise ConflictException('Version conflict')
new_shadow = FakeShadow.create_from_previous_version(thing.thing_shadow, payload) new_shadow = FakeShadow.create_from_previous_version(thing.thing_shadow, payload)
thing.thing_shadow = new_shadow thing.thing_shadow = new_shadow
return thing.thing_shadow return thing.thing_shadow
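A minimal sketch of the new version check from the client side, assuming the iot and iot-data mocks and an existing thing (names are illustrative):

    import json
    import boto3
    from botocore.exceptions import ClientError
    from moto import mock_iot, mock_iotdata

    @mock_iot
    @mock_iotdata
    def test_shadow_version_conflict():
        boto3.client("iot", region_name="us-east-1").create_thing(thingName="my-thing")
        data = boto3.client("iot-data", region_name="us-east-1")

        # the first update creates shadow version 1
        data.update_thing_shadow(
            thingName="my-thing",
            payload=json.dumps({"state": {"desired": {"led": "on"}}}),
        )

        # sending a stale version now raises the new ConflictException
        try:
            data.update_thing_shadow(
                thingName="my-thing",
                payload=json.dumps({"state": {"desired": {"led": "off"}}, "version": 42}),
            )
        except ClientError as err:
            assert "ConflictException" in str(err)
        else:
            raise AssertionError("expected a version conflict")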

View File

@ -34,3 +34,23 @@ class NotAuthorizedException(JsonRESTError):
"NotAuthorizedException", None) "NotAuthorizedException", None)
self.description = '{"__type":"NotAuthorizedException"}' self.description = '{"__type":"NotAuthorizedException"}'
class AccessDeniedException(JsonRESTError):
code = 400
def __init__(self, message):
super(AccessDeniedException, self).__init__(
"AccessDeniedException", message)
self.description = '{"__type":"AccessDeniedException"}'
class InvalidCiphertextException(JsonRESTError):
code = 400
def __init__(self):
super(InvalidCiphertextException, self).__init__(
"InvalidCiphertextException", None)
self.description = '{"__type":"InvalidCiphertextException"}'

View File

@ -1,16 +1,18 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import os import os
import boto.kms
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from .utils import generate_key_id
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timedelta from datetime import datetime, timedelta
import boto.kms
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from .utils import decrypt, encrypt, generate_key_id, generate_master_key
class Key(BaseModel): class Key(BaseModel):
def __init__(self, policy, key_usage, description, tags, region): def __init__(self, policy, key_usage, description, tags, region):
self.id = generate_key_id() self.id = generate_key_id()
self.policy = policy self.policy = policy
@ -19,10 +21,11 @@ class Key(BaseModel):
self.description = description self.description = description
self.enabled = True self.enabled = True
self.region = region self.region = region
self.account_id = "0123456789012" self.account_id = "012345678912"
self.key_rotation_status = False self.key_rotation_status = False
self.deletion_date = None self.deletion_date = None
self.tags = tags or {} self.tags = tags or {}
self.key_material = generate_master_key()
@property @property
def physical_resource_id(self): def physical_resource_id(self):
@ -45,8 +48,8 @@ class Key(BaseModel):
"KeyState": self.key_state, "KeyState": self.key_state,
} }
} }
if self.key_state == 'PendingDeletion': if self.key_state == "PendingDeletion":
key_dict['KeyMetadata']['DeletionDate'] = iso_8601_datetime_without_milliseconds(self.deletion_date) key_dict["KeyMetadata"]["DeletionDate"] = iso_8601_datetime_without_milliseconds(self.deletion_date)
return key_dict return key_dict
def delete(self, region_name): def delete(self, region_name):
@ -55,28 +58,28 @@ class Key(BaseModel):
@classmethod @classmethod
def create_from_cloudformation_json(self, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(self, resource_name, cloudformation_json, region_name):
kms_backend = kms_backends[region_name] kms_backend = kms_backends[region_name]
properties = cloudformation_json['Properties'] properties = cloudformation_json["Properties"]
key = kms_backend.create_key( key = kms_backend.create_key(
policy=properties['KeyPolicy'], policy=properties["KeyPolicy"],
key_usage='ENCRYPT_DECRYPT', key_usage="ENCRYPT_DECRYPT",
description=properties['Description'], description=properties["Description"],
tags=properties.get('Tags'), tags=properties.get("Tags"),
region=region_name, region=region_name,
) )
key.key_rotation_status = properties['EnableKeyRotation'] key.key_rotation_status = properties["EnableKeyRotation"]
key.enabled = properties['Enabled'] key.enabled = properties["Enabled"]
return key return key
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == 'Arn':
if attribute_name == "Arn":
return self.arn return self.arn
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
class KmsBackend(BaseBackend): class KmsBackend(BaseBackend):
def __init__(self): def __init__(self):
self.keys = {} self.keys = {}
self.key_to_aliases = defaultdict(set) self.key_to_aliases = defaultdict(set)
@ -109,16 +112,43 @@ class KmsBackend(BaseBackend):
# allow the different methods (alias, ARN :key/, keyId, ARN alias) to # allow the different methods (alias, ARN :key/, keyId, ARN alias) to
# describe key not just KeyId # describe key not just KeyId
key_id = self.get_key_id(key_id) key_id = self.get_key_id(key_id)
if r'alias/' in str(key_id).lower(): if r"alias/" in str(key_id).lower():
key_id = self.get_key_id_from_alias(key_id.split('alias/')[1]) key_id = self.get_key_id_from_alias(key_id.split("alias/")[1])
return self.keys[self.get_key_id(key_id)] return self.keys[self.get_key_id(key_id)]
def list_keys(self): def list_keys(self):
return self.keys.values() return self.keys.values()
def get_key_id(self, key_id): @staticmethod
def get_key_id(key_id):
# Allow use of ARN as well as pure KeyId # Allow use of ARN as well as pure KeyId
return str(key_id).split(r':key/')[1] if r':key/' in str(key_id).lower() else key_id if key_id.startswith("arn:") and ":key/" in key_id:
return key_id.split(":key/")[1]
return key_id
@staticmethod
def get_alias_name(alias_name):
# Allow use of ARN as well as alias name
if alias_name.startswith("arn:") and ":alias/" in alias_name:
return alias_name.split(":alias/")[1]
return alias_name
def any_id_to_key_id(self, key_id):
"""Go from any valid key ID to the raw key ID.
Acceptable inputs:
- raw key ID
- key ARN
- alias name
- alias ARN
"""
key_id = self.get_alias_name(key_id)
key_id = self.get_key_id(key_id)
if key_id.startswith("alias/"):
key_id = self.get_key_id_from_alias(key_id)
return key_id
def alias_exists(self, alias_name): def alias_exists(self, alias_name):
for aliases in self.key_to_aliases.values(): for aliases in self.key_to_aliases.values():
@ -162,37 +192,69 @@ class KmsBackend(BaseBackend):
def disable_key(self, key_id): def disable_key(self, key_id):
self.keys[key_id].enabled = False self.keys[key_id].enabled = False
self.keys[key_id].key_state = 'Disabled' self.keys[key_id].key_state = "Disabled"
def enable_key(self, key_id): def enable_key(self, key_id):
self.keys[key_id].enabled = True self.keys[key_id].enabled = True
self.keys[key_id].key_state = 'Enabled' self.keys[key_id].key_state = "Enabled"
def cancel_key_deletion(self, key_id): def cancel_key_deletion(self, key_id):
self.keys[key_id].key_state = 'Disabled' self.keys[key_id].key_state = "Disabled"
self.keys[key_id].deletion_date = None self.keys[key_id].deletion_date = None
def schedule_key_deletion(self, key_id, pending_window_in_days): def schedule_key_deletion(self, key_id, pending_window_in_days):
if 7 <= pending_window_in_days <= 30: if 7 <= pending_window_in_days <= 30:
self.keys[key_id].enabled = False self.keys[key_id].enabled = False
self.keys[key_id].key_state = 'PendingDeletion' self.keys[key_id].key_state = "PendingDeletion"
self.keys[key_id].deletion_date = datetime.now() + timedelta(days=pending_window_in_days) self.keys[key_id].deletion_date = datetime.now() + timedelta(days=pending_window_in_days)
return iso_8601_datetime_without_milliseconds(self.keys[key_id].deletion_date) return iso_8601_datetime_without_milliseconds(self.keys[key_id].deletion_date)
def encrypt(self, key_id, plaintext, encryption_context):
key_id = self.any_id_to_key_id(key_id)
ciphertext_blob = encrypt(
master_keys=self.keys, key_id=key_id, plaintext=plaintext, encryption_context=encryption_context
)
arn = self.keys[key_id].arn
return ciphertext_blob, arn
def decrypt(self, ciphertext_blob, encryption_context):
plaintext, key_id = decrypt(
master_keys=self.keys, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context
)
arn = self.keys[key_id].arn
return plaintext, arn
def re_encrypt(
self, ciphertext_blob, source_encryption_context, destination_key_id, destination_encryption_context
):
destination_key_id = self.any_id_to_key_id(destination_key_id)
plaintext, decrypting_arn = self.decrypt(
ciphertext_blob=ciphertext_blob, encryption_context=source_encryption_context
)
new_ciphertext_blob, encrypting_arn = self.encrypt(
key_id=destination_key_id, plaintext=plaintext, encryption_context=destination_encryption_context
)
return new_ciphertext_blob, decrypting_arn, encrypting_arn
def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens): def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens):
key = self.keys[self.get_key_id(key_id)] key_id = self.any_id_to_key_id(key_id)
if key_spec: if key_spec:
if key_spec == 'AES_128': # Note: Actual validation of key_spec is done in kms.responses
bytes = 16 if key_spec == "AES_128":
plaintext_len = 16
else: else:
bytes = 32 plaintext_len = 32
else: else:
bytes = number_of_bytes plaintext_len = number_of_bytes
plaintext = os.urandom(bytes) plaintext = os.urandom(plaintext_len)
return plaintext, key.arn ciphertext_blob, arn = self.encrypt(key_id=key_id, plaintext=plaintext, encryption_context=encryption_context)
return plaintext, ciphertext_blob, arn
kms_backends = {} kms_backends = {}
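With per-key master key material now generated in the backend, encrypt and decrypt round-trip real ciphertext instead of base64. A minimal sketch through boto3 (description and plaintext are illustrative):

    import boto3
    from moto import mock_kms

    @mock_kms
    def test_encrypt_decrypt_roundtrip():
        client = boto3.client("kms", region_name="us-east-1")
        key_id = client.create_key(Description="test key")["KeyMetadata"]["KeyId"]

        ciphertext = client.encrypt(KeyId=key_id, Plaintext=b"secret data")["CiphertextBlob"]
        # decrypt() resolves the key from the ciphertext blob itself, so no KeyId is passed
        decrypted = client.decrypt(CiphertextBlob=ciphertext)

        assert decrypted["Plaintext"] == b"secret data"
        assert decrypted["KeyId"].endswith(key_id)  # the responses now return the key ARN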

View File

@ -2,13 +2,16 @@ from __future__ import unicode_literals
import base64 import base64
import json import json
import os
import re import re
import six import six
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from .models import kms_backends from .models import kms_backends
from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException
ACCOUNT_ID = "012345678912"
reserved_aliases = [ reserved_aliases = [
'alias/aws/ebs', 'alias/aws/ebs',
'alias/aws/s3', 'alias/aws/s3',
@ -21,13 +24,86 @@ class KmsResponse(BaseResponse):
@property @property
def parameters(self): def parameters(self):
return json.loads(self.body) params = json.loads(self.body)
for key in ("Plaintext", "CiphertextBlob"):
if key in params:
params[key] = base64.b64decode(params[key].encode("utf-8"))
return params
@property @property
def kms_backend(self): def kms_backend(self):
return kms_backends[self.region] return kms_backends[self.region]
def _display_arn(self, key_id):
if key_id.startswith("arn:"):
return key_id
if key_id.startswith("alias/"):
id_type = ""
else:
id_type = "key/"
return "arn:aws:kms:{region}:{account}:{id_type}{key_id}".format(
region=self.region, account=ACCOUNT_ID, id_type=id_type, key_id=key_id
)
def _validate_cmk_id(self, key_id):
"""Determine whether a CMK ID exists.
- raw key ID
- key ARN
"""
is_arn = key_id.startswith("arn:") and ":key/" in key_id
is_raw_key_id = re.match(r"^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$", key_id, re.IGNORECASE)
if not is_arn and not is_raw_key_id:
raise NotFoundException("Invalid keyId {key_id}".format(key_id=key_id))
cmk_id = self.kms_backend.get_key_id(key_id)
if cmk_id not in self.kms_backend.keys:
raise NotFoundException("Key '{key_id}' does not exist".format(key_id=self._display_arn(key_id)))
def _validate_alias(self, key_id):
"""Determine whether an alias exists.
- alias name
- alias ARN
"""
error = NotFoundException("Alias {key_id} is not found.".format(key_id=self._display_arn(key_id)))
is_arn = key_id.startswith("arn:") and ":alias/" in key_id
is_name = key_id.startswith("alias/")
if not is_arn and not is_name:
raise error
alias_name = self.kms_backend.get_alias_name(key_id)
cmk_id = self.kms_backend.get_key_id_from_alias(alias_name)
if cmk_id is None:
raise error
def _validate_key_id(self, key_id):
"""Determine whether or not a key ID exists.
- raw key ID
- key ARN
- alias name
- alias ARN
"""
is_alias_arn = key_id.startswith("arn:") and ":alias/" in key_id
is_alias_name = key_id.startswith("alias/")
if is_alias_arn or is_alias_name:
self._validate_alias(key_id)
return
self._validate_cmk_id(key_id)
def create_key(self): def create_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html"""
policy = self.parameters.get('Policy') policy = self.parameters.get('Policy')
key_usage = self.parameters.get('KeyUsage') key_usage = self.parameters.get('KeyUsage')
description = self.parameters.get('Description') description = self.parameters.get('Description')
@ -38,20 +114,31 @@ class KmsResponse(BaseResponse):
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())
def update_key_description(self): def update_key_description(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
description = self.parameters.get('Description') description = self.parameters.get('Description')
self._validate_cmk_id(key_id)
self.kms_backend.update_key_description(key_id, description) self.kms_backend.update_key_description(key_id, description)
return json.dumps(None) return json.dumps(None)
def tag_resource(self): def tag_resource(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
tags = self.parameters.get('Tags') tags = self.parameters.get('Tags')
self._validate_cmk_id(key_id)
self.kms_backend.tag_resource(key_id, tags) self.kms_backend.tag_resource(key_id, tags)
return json.dumps({}) return json.dumps({})
def list_resource_tags(self): def list_resource_tags(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
self._validate_cmk_id(key_id)
tags = self.kms_backend.list_resource_tags(key_id) tags = self.kms_backend.list_resource_tags(key_id)
return json.dumps({ return json.dumps({
"Tags": tags, "Tags": tags,
@ -60,17 +147,19 @@ class KmsResponse(BaseResponse):
}) })
def describe_key(self): def describe_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
try:
key = self.kms_backend.describe_key( self._validate_key_id(key_id)
self.kms_backend.get_key_id(key_id))
except KeyError: key = self.kms_backend.describe_key(
headers = dict(self.headers) self.kms_backend.get_key_id(key_id)
headers['status'] = 404 )
return "{}", headers
return json.dumps(key.to_dict()) return json.dumps(key.to_dict())
def list_keys(self): def list_keys(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html"""
keys = self.kms_backend.list_keys() keys = self.kms_backend.list_keys()
return json.dumps({ return json.dumps({
@ -85,6 +174,7 @@ class KmsResponse(BaseResponse):
}) })
def create_alias(self): def create_alias(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html"""
alias_name = self.parameters['AliasName'] alias_name = self.parameters['AliasName']
target_key_id = self.parameters['TargetKeyId'] target_key_id = self.parameters['TargetKeyId']
@ -110,27 +200,31 @@ class KmsResponse(BaseResponse):
raise AlreadyExistsException('An alias with the name arn:aws:kms:{region}:012345678912:{alias_name} ' raise AlreadyExistsException('An alias with the name arn:aws:kms:{region}:012345678912:{alias_name} '
'already exists'.format(region=self.region, alias_name=alias_name)) 'already exists'.format(region=self.region, alias_name=alias_name))
self._validate_cmk_id(target_key_id)
self.kms_backend.add_alias(target_key_id, alias_name) self.kms_backend.add_alias(target_key_id, alias_name)
return json.dumps(None) return json.dumps(None)
def delete_alias(self): def delete_alias(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html"""
alias_name = self.parameters['AliasName'] alias_name = self.parameters['AliasName']
if not alias_name.startswith('alias/'): if not alias_name.startswith('alias/'):
raise ValidationException('Invalid identifier') raise ValidationException('Invalid identifier')
if not self.kms_backend.alias_exists(alias_name): self._validate_alias(alias_name)
raise NotFoundException('Alias arn:aws:kms:{region}:012345678912:'
'{alias_name} is not found.'.format(region=self.region, alias_name=alias_name))
self.kms_backend.delete_alias(alias_name) self.kms_backend.delete_alias(alias_name)
return json.dumps(None) return json.dumps(None)
def list_aliases(self): def list_aliases(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListAliases.html"""
region = self.region region = self.region
# TODO: The actual API can filter on KeyId.
response_aliases = [ response_aliases = [
{ {
'AliasArn': u'arn:aws:kms:{region}:012345678912:{reserved_alias}'.format(region=region, 'AliasArn': u'arn:aws:kms:{region}:012345678912:{reserved_alias}'.format(region=region,
@ -155,191 +249,239 @@ class KmsResponse(BaseResponse):
}) })
def enable_key_rotation(self): def enable_key_rotation(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
self.kms_backend.enable_key_rotation(key_id)
except KeyError: self.kms_backend.enable_key_rotation(key_id)
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None) return json.dumps(None)
def disable_key_rotation(self): def disable_key_rotation(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
self.kms_backend.disable_key_rotation(key_id)
except KeyError: self.kms_backend.disable_key_rotation(key_id)
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None) return json.dumps(None)
def get_key_rotation_status(self): def get_key_rotation_status(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
rotation_enabled = self.kms_backend.get_key_rotation_status(key_id)
except KeyError: rotation_enabled = self.kms_backend.get_key_rotation_status(key_id)
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps({'KeyRotationEnabled': rotation_enabled}) return json.dumps({'KeyRotationEnabled': rotation_enabled})
def put_key_policy(self): def put_key_policy(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
policy_name = self.parameters.get('PolicyName') policy_name = self.parameters.get('PolicyName')
policy = self.parameters.get('Policy') policy = self.parameters.get('Policy')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
_assert_default_policy(policy_name) _assert_default_policy(policy_name)
try: self._validate_cmk_id(key_id)
self.kms_backend.put_key_policy(key_id, policy)
except KeyError: self.kms_backend.put_key_policy(key_id, policy)
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None) return json.dumps(None)
def get_key_policy(self): def get_key_policy(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
policy_name = self.parameters.get('PolicyName') policy_name = self.parameters.get('PolicyName')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
_assert_default_policy(policy_name) _assert_default_policy(policy_name)
try: self._validate_cmk_id(key_id)
return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)})
except KeyError: return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)})
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
def list_key_policies(self): def list_key_policies(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html"""
key_id = self.parameters.get('KeyId') key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try: self._validate_cmk_id(key_id)
self.kms_backend.describe_key(key_id)
except KeyError: self.kms_backend.describe_key(key_id)
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps({'Truncated': False, 'PolicyNames': ['default']}) return json.dumps({'Truncated': False, 'PolicyNames': ['default']})
def encrypt(self): def encrypt(self):
""" """https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html"""
We perform no encryption, we just encode the value as base64 and then key_id = self.parameters.get("KeyId")
decode it in decrypt(). encryption_context = self.parameters.get('EncryptionContext', {})
""" plaintext = self.parameters.get("Plaintext")
value = self.parameters.get("Plaintext")
if isinstance(value, six.text_type): self._validate_key_id(key_id)
value = value.encode('utf-8')
return json.dumps({"CiphertextBlob": base64.b64encode(value).decode("utf-8"), 'KeyId': 'key_id'}) if isinstance(plaintext, six.text_type):
plaintext = plaintext.encode('utf-8')
ciphertext_blob, arn = self.kms_backend.encrypt(
key_id=key_id,
plaintext=plaintext,
encryption_context=encryption_context,
)
ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8")
return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn})
def decrypt(self): def decrypt(self):
# TODO refuse decode if EncryptionContext is not the same as when it was encrypted / generated """https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html"""
ciphertext_blob = self.parameters.get("CiphertextBlob")
encryption_context = self.parameters.get('EncryptionContext', {})
value = self.parameters.get("CiphertextBlob") plaintext, arn = self.kms_backend.decrypt(
try: ciphertext_blob=ciphertext_blob,
return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8"), 'KeyId': 'key_id'}) encryption_context=encryption_context,
except UnicodeDecodeError: )
# Generate data key will produce random bytes which when decrypted is still returned as base64
return json.dumps({"Plaintext": value}) plaintext_response = base64.b64encode(plaintext).decode("utf-8")
return json.dumps({"Plaintext": plaintext_response, 'KeyId': arn})
def re_encrypt(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html"""
ciphertext_blob = self.parameters.get("CiphertextBlob")
source_encryption_context = self.parameters.get("SourceEncryptionContext", {})
destination_key_id = self.parameters.get("DestinationKeyId")
destination_encryption_context = self.parameters.get("DestinationEncryptionContext", {})
self._validate_cmk_id(destination_key_id)
new_ciphertext_blob, decrypting_arn, encrypting_arn = self.kms_backend.re_encrypt(
ciphertext_blob=ciphertext_blob,
source_encryption_context=source_encryption_context,
destination_key_id=destination_key_id,
destination_encryption_context=destination_encryption_context,
)
response_ciphertext_blob = base64.b64encode(new_ciphertext_blob).decode("utf-8")
return json.dumps(
{"CiphertextBlob": response_ciphertext_blob, "KeyId": encrypting_arn, "SourceKeyId": decrypting_arn}
)
def disable_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html"""
key_id = self.parameters.get('KeyId')
self._validate_cmk_id(key_id)
self.kms_backend.disable_key(key_id)
return json.dumps(None)
def enable_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html"""
key_id = self.parameters.get('KeyId')
self._validate_cmk_id(key_id)
self.kms_backend.enable_key(key_id)
return json.dumps(None)
def cancel_key_deletion(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html"""
key_id = self.parameters.get('KeyId')
self._validate_cmk_id(key_id)
self.kms_backend.cancel_key_deletion(key_id)
return json.dumps({'KeyId': key_id})
def schedule_key_deletion(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html"""
key_id = self.parameters.get('KeyId')
if self.parameters.get('PendingWindowInDays') is None:
pending_window_in_days = 30
else:
pending_window_in_days = self.parameters.get('PendingWindowInDays')
self._validate_cmk_id(key_id)
return json.dumps({
'KeyId': key_id,
'DeletionDate': self.kms_backend.schedule_key_deletion(key_id, pending_window_in_days)
})
def generate_data_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html"""
key_id = self.parameters.get('KeyId')
encryption_context = self.parameters.get('EncryptionContext', {})
number_of_bytes = self.parameters.get('NumberOfBytes')
key_spec = self.parameters.get('KeySpec')
grant_tokens = self.parameters.get('GrantTokens')
# Param validation
self._validate_key_id(key_id)
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1):
raise ValidationException((
"1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed "
"to satisfy constraint: Member must have value less than or "
"equal to 1024"
).format(number_of_bytes=number_of_bytes))
if key_spec and key_spec not in ('AES_256', 'AES_128'):
raise ValidationException((
"1 validation error detected: Value '{key_spec}' at 'keySpec' failed "
"to satisfy constraint: Member must satisfy enum value set: "
"[AES_256, AES_128]"
).format(key_spec=key_spec))
if not key_spec and not number_of_bytes:
raise ValidationException("Please specify either number of bytes or key spec.")
if key_spec and number_of_bytes:
raise ValidationException("Please specify either number of bytes or key spec.")
plaintext, ciphertext_blob, key_arn = self.kms_backend.generate_data_key(
key_id=key_id,
encryption_context=encryption_context,
number_of_bytes=number_of_bytes,
key_spec=key_spec,
grant_tokens=grant_tokens
)
plaintext_response = base64.b64encode(plaintext).decode("utf-8")
ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8")
return json.dumps({
'CiphertextBlob': ciphertext_blob_response,
'Plaintext': plaintext_response,
'KeyId': key_arn  # not alias
})
def generate_data_key_without_plaintext(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html"""
result = json.loads(self.generate_data_key())
del result['Plaintext']
return json.dumps(result)
def generate_random(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html"""
number_of_bytes = self.parameters.get("NumberOfBytes")
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1):
raise ValidationException((
"1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed "
"to satisfy constraint: Member must have value less than or "
"equal to 1024"
).format(number_of_bytes=number_of_bytes))
entropy = os.urandom(number_of_bytes)
response_entropy = base64.b64encode(entropy).decode("utf-8")
return json.dumps({"Plaintext": response_entropy})
def _assert_default_policy(policy_name):
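Once these handlers are in place, the service can be exercised end-to-end with a plain boto3 client. A minimal round-trip sketch (the region and key values are illustrative, not taken from the change itself):

import boto3
from moto import mock_kms

@mock_kms
def kms_round_trip_example():
    client = boto3.client("kms", region_name="us-east-1")
    key_id = client.create_key(Description="example key")["KeyMetadata"]["KeyId"]

    # encrypt/decrypt round-trip through the mocked endpoints
    ciphertext = client.encrypt(KeyId=key_id, Plaintext=b"secret")["CiphertextBlob"]
    assert client.decrypt(CiphertextBlob=ciphertext)["Plaintext"] == b"secret"

    # generate_data_key returns both the plaintext key and an encrypted copy of it
    data_key = client.generate_data_key(KeyId=key_id, KeySpec="AES_256")
    assert len(data_key["Plaintext"]) == 32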

View File

@ -1,7 +1,142 @@
from __future__ import unicode_literals
from collections import namedtuple
import io
import os
import struct
import uuid
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
from .exceptions import InvalidCiphertextException, AccessDeniedException, NotFoundException
MASTER_KEY_LEN = 32
KEY_ID_LEN = 36
IV_LEN = 12
TAG_LEN = 16
HEADER_LEN = KEY_ID_LEN + IV_LEN + TAG_LEN
# NOTE: This is just a simple binary format. It is not what KMS actually does.
CIPHERTEXT_HEADER_FORMAT = ">{key_id_len}s{iv_len}s{tag_len}s".format(
key_id_len=KEY_ID_LEN, iv_len=IV_LEN, tag_len=TAG_LEN
)
Ciphertext = namedtuple("Ciphertext", ("key_id", "iv", "ciphertext", "tag"))
def generate_key_id():
return str(uuid.uuid4())
def generate_data_key(number_of_bytes):
"""Generate a data key."""
return os.urandom(number_of_bytes)
def generate_master_key():
"""Generate a master key."""
return generate_data_key(MASTER_KEY_LEN)
def _serialize_ciphertext_blob(ciphertext):
"""Serialize Ciphertext object into a ciphertext blob.
NOTE: This is just a simple binary format. It is not what KMS actually does.
"""
header = struct.pack(CIPHERTEXT_HEADER_FORMAT, ciphertext.key_id.encode("utf-8"), ciphertext.iv, ciphertext.tag)
return header + ciphertext.ciphertext
def _deserialize_ciphertext_blob(ciphertext_blob):
"""Deserialize ciphertext blob into a Ciphertext object.
NOTE: This is just a simple binary format. It is not what KMS actually does.
"""
header = ciphertext_blob[:HEADER_LEN]
ciphertext = ciphertext_blob[HEADER_LEN:]
key_id, iv, tag = struct.unpack(CIPHERTEXT_HEADER_FORMAT, header)
return Ciphertext(key_id=key_id.decode("utf-8"), iv=iv, ciphertext=ciphertext, tag=tag)
def _serialize_encryption_context(encryption_context):
"""Serialize encryption context for use a AAD.
NOTE: This is not necessarily what KMS does, but it retains the same properties.
"""
aad = io.BytesIO()
for key, value in sorted(encryption_context.items(), key=lambda x: x[0]):
aad.write(key.encode("utf-8"))
aad.write(value.encode("utf-8"))
return aad.getvalue()
def encrypt(master_keys, key_id, plaintext, encryption_context):
"""Encrypt data using a master key material.
NOTE: This is not necessarily what KMS does, but it retains the same properties.
NOTE: This function is NOT compatible with KMS APIs.
:param dict master_keys: Mapping of a KmsBackend's known master keys
:param str key_id: Key ID of moto master key
:param bytes plaintext: Plaintext data to encrypt
:param dict[str, str] encryption_context: KMS-style encryption context
:returns: Moto-structured ciphertext blob encrypted under a moto master key in master_keys
:rtype: bytes
"""
try:
key = master_keys[key_id]
except KeyError:
is_alias = key_id.startswith("alias/") or ":alias/" in key_id
raise NotFoundException(
"{id_type} {key_id} is not found.".format(id_type="Alias" if is_alias else "keyId", key_id=key_id)
)
iv = os.urandom(IV_LEN)
aad = _serialize_encryption_context(encryption_context=encryption_context)
encryptor = Cipher(algorithms.AES(key.key_material), modes.GCM(iv), backend=default_backend()).encryptor()
encryptor.authenticate_additional_data(aad)
ciphertext = encryptor.update(plaintext) + encryptor.finalize()
return _serialize_ciphertext_blob(
ciphertext=Ciphertext(key_id=key_id, iv=iv, ciphertext=ciphertext, tag=encryptor.tag)
)
def decrypt(master_keys, ciphertext_blob, encryption_context):
"""Decrypt a ciphertext blob using a master key material.
NOTE: This is not necessarily what KMS does, but it retains the same properties.
NOTE: This function is NOT compatible with KMS APIs.
:param dict master_keys: Mapping of a KmsBackend's known master keys
:param bytes ciphertext_blob: moto-structured ciphertext blob encrypted under a moto master key in master_keys
:param dict[str, str] encryption_context: KMS-style encryption context
:returns: plaintext bytes and moto key ID
:rtype: bytes and str
"""
try:
ciphertext = _deserialize_ciphertext_blob(ciphertext_blob=ciphertext_blob)
except Exception:
raise InvalidCiphertextException()
aad = _serialize_encryption_context(encryption_context=encryption_context)
try:
key = master_keys[ciphertext.key_id]
except KeyError:
raise AccessDeniedException(
"The ciphertext refers to a customer master key that does not exist, "
"does not exist in this region, or you are not allowed to access."
)
try:
decryptor = Cipher(
algorithms.AES(key.key_material), modes.GCM(ciphertext.iv, ciphertext.tag), backend=default_backend()
).decryptor()
decryptor.authenticate_additional_data(aad)
plaintext = decryptor.update(ciphertext.ciphertext) + decryptor.finalize()
except Exception:
raise InvalidCiphertextException()
return plaintext, ciphertext.key_id
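Taken together, encrypt() and decrypt() round-trip a payload through the moto-specific blob format defined above. A small sketch (the namedtuple stand-in and its key_material field are assumptions for illustration; the real backend stores its own key objects):

from collections import namedtuple

FakeMasterKey = namedtuple("FakeMasterKey", ("key_material",))  # hypothetical stand-in for a backend key

key_id = generate_key_id()
master_keys = {key_id: FakeMasterKey(key_material=generate_master_key())}
context = {"purpose": "example"}

blob = encrypt(master_keys, key_id, b"secret data", context)
plaintext, used_key_id = decrypt(master_keys, blob, context)
assert plaintext == b"secret data" and used_key_id == key_id

Decrypting the same blob with a different encryption context changes the AAD, so it fails with InvalidCiphertextException, which is exactly the property the serialization helpers are meant to preserve.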

View File

@ -41,7 +41,7 @@ class LogStream:
self.region = region
self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format(
region=region, id=self.__class__._log_ids, log_group=log_group, log_stream=name)
self.creationTime = int(unix_time_millis())
self.firstEventTimestamp = None
self.lastEventTimestamp = None
self.lastIngestionTime = None
@ -80,7 +80,7 @@ class LogStream:
def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token):
# TODO: ensure sequence_token
# TODO: to be thread safe this would need a lock
self.lastIngestionTime = int(unix_time_millis())
# TODO: make this match AWS if possible
self.storedBytes += sum([len(log_event["message"]) for log_event in log_events])
self.events += [LogEvent(self.lastIngestionTime, log_event) for log_event in log_events]
@ -115,6 +115,8 @@ class LogStream:
events_page = [event.to_response_dict() for event in events[next_index: next_index + limit]]
if next_index + limit < len(self.events):
next_index += limit
else:
next_index = len(self.events)
back_index -= limit
if back_index <= 0:
@ -146,7 +148,7 @@ class LogGroup:
self.region = region
self.arn = "arn:aws:logs:{region}:1:log-group:{log_group}".format(
region=region, log_group=name)
self.creationTime = int(unix_time_millis())
self.tags = tags
self.streams = dict()  # {name: LogStream}
self.retentionInDays = None  # AWS defaults to Never Expire for log group retention

View File

@ -74,7 +74,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
automated_snapshot_retention_period, port, cluster_version,
allow_version_upgrade, number_of_nodes, publicly_accessible,
encrypted, region_name, tags=None, iam_roles_arn=None,
enhanced_vpc_routing=None, restored_from_snapshot=False):
super(Cluster, self).__init__(region_name, tags)
self.redshift_backend = redshift_backend
self.cluster_identifier = cluster_identifier
@ -85,6 +85,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
self.master_user_password = master_user_password
self.db_name = db_name if db_name else "dev"
self.vpc_security_group_ids = vpc_security_group_ids
self.enhanced_vpc_routing = enhanced_vpc_routing if enhanced_vpc_routing is not None else False
self.cluster_subnet_group_name = cluster_subnet_group_name
self.publicly_accessible = publicly_accessible
self.encrypted = encrypted
@ -154,6 +155,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
port=properties.get('Port'),
cluster_version=properties.get('ClusterVersion'),
allow_version_upgrade=properties.get('AllowVersionUpgrade'),
enhanced_vpc_routing=properties.get('EnhancedVpcRouting'),
number_of_nodes=properties.get('NumberOfNodes'),
publicly_accessible=properties.get("PubliclyAccessible"),
encrypted=properties.get("Encrypted"),
@ -241,6 +243,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
'ClusterCreateTime': self.create_time,
"PendingModifiedValues": [],
"Tags": self.tags,
"EnhancedVpcRouting": self.enhanced_vpc_routing,
"IamRoles": [{
"ApplyStatus": "in-sync",
"IamRoleArn": iam_role_arn
@ -427,6 +430,7 @@ class Snapshot(TaggableResourceMixin, BaseModel):
'NumberOfNodes': self.cluster.number_of_nodes,
'DBName': self.cluster.db_name,
'Tags': self.tags,
'EnhancedVpcRouting': self.cluster.enhanced_vpc_routing,
"IamRoles": [{
"ApplyStatus": "in-sync",
"IamRoleArn": iam_role_arn
@ -678,7 +682,8 @@ class RedshiftBackend(BaseBackend):
"number_of_nodes": snapshot.cluster.number_of_nodes,
"encrypted": snapshot.cluster.encrypted,
"tags": snapshot.cluster.tags,
"restored_from_snapshot": True,
"enhanced_vpc_routing": snapshot.cluster.enhanced_vpc_routing
}
create_kwargs.update(kwargs)
return self.create_cluster(**create_kwargs)

View File

@ -135,6 +135,7 @@ class RedshiftResponse(BaseResponse):
"region_name": self.region, "region_name": self.region,
"tags": self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')), "tags": self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')),
"iam_roles_arn": self._get_iam_roles(), "iam_roles_arn": self._get_iam_roles(),
"enhanced_vpc_routing": self._get_param('EnhancedVpcRouting'),
} }
cluster = self.redshift_backend.create_cluster(**cluster_kwargs).to_json() cluster = self.redshift_backend.create_cluster(**cluster_kwargs).to_json()
cluster['ClusterStatus'] = 'creating' cluster['ClusterStatus'] = 'creating'
@ -150,6 +151,7 @@ class RedshiftResponse(BaseResponse):
}) })
def restore_from_cluster_snapshot(self): def restore_from_cluster_snapshot(self):
enhanced_vpc_routing = self._get_bool_param('EnhancedVpcRouting')
restore_kwargs = { restore_kwargs = {
"snapshot_identifier": self._get_param('SnapshotIdentifier'), "snapshot_identifier": self._get_param('SnapshotIdentifier'),
"cluster_identifier": self._get_param('ClusterIdentifier'), "cluster_identifier": self._get_param('ClusterIdentifier'),
@ -171,6 +173,8 @@ class RedshiftResponse(BaseResponse):
"region_name": self.region, "region_name": self.region,
"iam_roles_arn": self._get_iam_roles(), "iam_roles_arn": self._get_iam_roles(),
} }
if enhanced_vpc_routing is not None:
restore_kwargs['enhanced_vpc_routing'] = enhanced_vpc_routing
cluster = self.redshift_backend.restore_from_cluster_snapshot(**restore_kwargs).to_json() cluster = self.redshift_backend.restore_from_cluster_snapshot(**restore_kwargs).to_json()
cluster['ClusterStatus'] = 'creating' cluster['ClusterStatus'] = 'creating'
return self.get_response({ return self.get_response({
@ -218,6 +222,7 @@ class RedshiftResponse(BaseResponse):
"publicly_accessible": self._get_param("PubliclyAccessible"), "publicly_accessible": self._get_param("PubliclyAccessible"),
"encrypted": self._get_param("Encrypted"), "encrypted": self._get_param("Encrypted"),
"iam_roles_arn": self._get_iam_roles(), "iam_roles_arn": self._get_iam_roles(),
"enhanced_vpc_routing": self._get_param("EnhancedVpcRouting")
} }
cluster_kwargs = {} cluster_kwargs = {}
# We only want parameters that were actually passed in, otherwise # We only want parameters that were actually passed in, otherwise

View File

@ -305,6 +305,7 @@ class Route53Backend(BaseBackend):
def list_tags_for_resource(self, resource_id):
if resource_id in self.resource_tags:
return self.resource_tags[resource_id]
return {}
def get_all_hosted_zones(self):
return self.zones.values()

View File

@ -20,7 +20,7 @@ from .exceptions import BucketAlreadyExists, S3ClientError, MissingBucket, Missi
MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent, ObjectNotInActiveTierError
from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeTagging, FakeTagSet, \
FakeTag
from .utils import bucket_name_from_url, clean_key_name, undo_clean_key_name, metadata_from_headers, parse_region_from_url
from xml.dom import minidom
@ -451,17 +451,16 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
continuation_token = querystring.get('continuation-token', [None])[0]
start_after = querystring.get('start-after', [None])[0]
# sort the combination of folders and keys into lexicographical order
all_keys = result_keys + result_folders
all_keys.sort(key=self._get_name)
if continuation_token or start_after:
limit = continuation_token or start_after
all_keys = self._get_results_from_token(all_keys, limit)
truncated_keys, is_truncated, next_continuation_token = self._truncate_result(all_keys, max_keys)
result_keys, result_folders = self._split_truncated_keys(truncated_keys)
key_count = len(result_keys) + len(result_folders)
@ -479,6 +478,24 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
start_after=None if continuation_token else start_after
)
@staticmethod
def _get_name(key):
if isinstance(key, FakeKey):
return key.name
else:
return key
@staticmethod
def _split_truncated_keys(truncated_keys):
result_keys = []
result_folders = []
for key in truncated_keys:
if isinstance(key, FakeKey):
result_keys.append(key)
else:
result_folders.append(key)
return result_keys, result_folders
def _get_results_from_token(self, result_keys, token):
continuation_index = 0
for key in result_keys:
@ -694,7 +711,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
for k in keys:
key_name = k.firstChild.nodeValue
success = self.backend.delete_key(bucket_name, undo_clean_key_name(key_name))
if success:
deleted_names.append(key_name)
else:

View File

@ -15,4 +15,6 @@ url_paths = {
'{0}/(?P<key_or_bucket_name>[^/]+)/?$': S3ResponseInstance.ambiguous_response, '{0}/(?P<key_or_bucket_name>[^/]+)/?$': S3ResponseInstance.ambiguous_response,
# path-based bucket + key # path-based bucket + key
'{0}/(?P<bucket_name_path>[^/]+)/(?P<key_name>.+)': S3ResponseInstance.key_response, '{0}/(?P<bucket_name_path>[^/]+)/(?P<key_name>.+)': S3ResponseInstance.key_response,
# subdomain bucket + key with empty first part of path
'{0}//(?P<key_name>.*)$': S3ResponseInstance.key_response,
} }

View File

@ -5,7 +5,7 @@ import os
from boto.s3.key import Key
import re
import six
from six.moves.urllib.parse import urlparse, unquote, quote
import sys
@ -71,10 +71,15 @@ def metadata_from_headers(headers):
def clean_key_name(key_name):
if six.PY2:
return unquote(key_name.encode('utf-8')).decode('utf-8')
return unquote(key_name)
def undo_clean_key_name(key_name):
if six.PY2:
return quote(key_name.encode('utf-8')).decode('utf-8')
return quote(key_name)
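A quick round-trip of the two helpers above (the key name is an assumed example):

# URL-encoded key from a DeleteObjects payload -> decoded name -> re-encoded lookup key
assert clean_key_name('my%20folder/photo.jpg') == 'my folder/photo.jpg'
assert undo_clean_key_name('my folder/photo.jpg') == 'my%20folder/photo.jpg'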
class _VersionedKeyStore(dict):
""" A simplified/modified version of Django's `MultiValueDict` taken from:

View File

@ -154,9 +154,9 @@ class SecretsManagerBackend(BaseBackend):
return version_id
def put_secret_value(self, secret_id, secret_string, secret_binary, version_stages):
version_id = self._add_secret(secret_id, secret_string, secret_binary, version_stages=version_stages)
response = json.dumps({
'ARN': secret_arn(self.region, secret_id),

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.secretsmanager.exceptions import InvalidRequestException
from .models import secretsmanager_backends
@ -71,10 +72,14 @@ class SecretsManagerResponse(BaseResponse):
def put_secret_value(self):
secret_id = self._get_param('SecretId', if_none='')
secret_string = self._get_param('SecretString')
secret_binary = self._get_param('SecretBinary')
if not secret_binary and not secret_string:
raise InvalidRequestException('You must provide either SecretString or SecretBinary.')
version_stages = self._get_param('VersionStages', if_none=['AWSCURRENT'])
return secretsmanager_backends[self.region].put_secret_value(
secret_id=secret_id,
secret_binary=secret_binary,
secret_string=secret_string,
version_stages=version_stages,
)
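A hedged sketch of the new validation path (names are illustrative): calling PutSecretValue with neither SecretString nor SecretBinary is now rejected before it reaches the backend.

import boto3
from botocore.exceptions import ClientError
from moto import mock_secretsmanager

@mock_secretsmanager
def put_secret_value_requires_a_value():
    client = boto3.client("secretsmanager", region_name="us-east-1")
    client.create_secret(Name="example-secret", SecretString="initial")
    try:
        client.put_secret_value(SecretId="example-secret", VersionStages=["AWSCURRENT"])
    except ClientError:
        pass  # rejected: neither SecretString nor SecretBinary was supplied
    else:
        raise AssertionError("expected the request to be rejected")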

View File

@ -174,10 +174,11 @@ def create_backend_app(service):
backend_app.url_map.converters['regex'] = RegexConverter
backend = list(BACKENDS[service].values())[0]
for url_path, handler in backend.flask_paths.items():
view_func = convert_flask_to_httpretty_response(handler)
if handler.__name__ == 'dispatch':
endpoint = '{0}.dispatch'.format(handler.__self__.__name__)
else:
endpoint = view_func.__name__
original_endpoint = endpoint
index = 2
@ -191,7 +192,7 @@ def create_backend_app(service):
url_path,
endpoint=endpoint,
methods=HTTP_METHODS,
view_func=view_func,
strict_slashes=False,
)

View File

@ -40,3 +40,11 @@ class InvalidParameterValue(RESTError):
def __init__(self, message):
super(InvalidParameterValue, self).__init__(
"InvalidParameterValue", message)
class InternalError(RESTError):
code = 500
def __init__(self, message):
super(InternalError, self).__init__(
"InternalFailure", message)

View File

@ -18,7 +18,7 @@ from moto.awslambda import lambda_backends
from .exceptions import (
SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled, SNSInvalidParameter,
InvalidParameterValue, InternalError
)
from .utils import make_arn_for_topic, make_arn_for_subscription
@ -131,13 +131,47 @@ class Subscription(BaseModel):
message_attributes = {}
def _field_match(field, rules, message_attributes):
for rule in rules:
# TODO: boolean value matching is not supported, SNS behavior unknown
if isinstance(rule, six.string_types):
if field not in message_attributes:
return False
if message_attributes[field]['Value'] == rule:
return True
try:
json_data = json.loads(message_attributes[field]['Value'])
if rule in json_data:
return True
except (ValueError, TypeError):
pass
if isinstance(rule, (six.integer_types, float)):
if field not in message_attributes:
return False
if message_attributes[field]['Type'] == 'Number':
attribute_values = [message_attributes[field]['Value']]
elif message_attributes[field]['Type'] == 'String.Array':
try:
attribute_values = json.loads(message_attributes[field]['Value'])
if not isinstance(attribute_values, list):
attribute_values = [attribute_values]
except (ValueError, TypeError):
return False
else:
return False
for attribute_values in attribute_values:
# Even the official documentation states 5 digits of accuracy after the decimal point for numerics; in reality it is 6
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
if int(attribute_values * 1000000) == int(rule * 1000000):
return True
if isinstance(rule, dict):
keyword = list(rule.keys())[0]
attributes = list(rule.values())[0]
if keyword == 'exists':
if attributes and field in message_attributes:
return True
elif not attributes and field not in message_attributes:
return True
return False
return all(_field_match(field, rules, message_attributes) return all(_field_match(field, rules, message_attributes)
@ -421,7 +455,49 @@ class SNSBackend(BaseBackend):
subscription.attributes[name] = value
if name == 'FilterPolicy':
filter_policy = json.loads(value)
self._validate_filter_policy(filter_policy)
subscription._filter_policy = filter_policy
def _validate_filter_policy(self, value):
# TODO: extend validation checks
combinations = 1
for rules in six.itervalues(value):
combinations *= len(rules)
# Even the official documentation states the total combination of values must not exceed 100; in reality it is 150
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
if combinations > 150:
raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Filter policy is too complex")
for field, rules in six.iteritems(value):
for rule in rules:
if rule is None:
continue
if isinstance(rule, six.string_types):
continue
if isinstance(rule, bool):
continue
if isinstance(rule, (six.integer_types, float)):
if rule <= -1000000000 or rule >= 1000000000:
raise InternalError("Unknown")
continue
if isinstance(rule, dict):
keyword = list(rule.keys())[0]
attributes = list(rule.values())[0]
if keyword == 'anything-but':
continue
elif keyword == 'exists':
if not isinstance(attributes, bool):
raise SNSInvalidParameter("Invalid parameter: FilterPolicy: exists match pattern must be either true or false.")
continue
elif keyword == 'numeric':
continue
elif keyword == 'prefix':
continue
else:
raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Unrecognized match type {type}".format(type=keyword))
raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null")
sns_backends = {}
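A short boto3 sketch of a filter policy that passes the validation above and matches a published message (topic, queue, and attribute values are illustrative):

import json
import boto3
from moto import mock_sns, mock_sqs

@mock_sns
@mock_sqs
def filter_policy_example():
    sns = boto3.client("sns", region_name="us-east-1")
    sqs = boto3.client("sqs", region_name="us-east-1")
    topic_arn = sns.create_topic(Name="example-topic")["TopicArn"]
    queue_url = sqs.create_queue(QueueName="example-queue")["QueueUrl"]
    queue_arn = sqs.get_queue_attributes(
        QueueUrl=queue_url, AttributeNames=["QueueArn"])["Attributes"]["QueueArn"]

    sub_arn = sns.subscribe(TopicArn=topic_arn, Protocol="sqs", Endpoint=queue_arn)["SubscriptionArn"]
    sns.set_subscription_attributes(
        SubscriptionArn=sub_arn,
        AttributeName="FilterPolicy",
        AttributeValue=json.dumps({"store": ["example_corp"], "price_usd": [{"exists": True}]}),
    )

    # Both attributes satisfy the policy, so the message is delivered to the queue.
    sns.publish(
        TopicArn=topic_arn,
        Message="order placed",
        MessageAttributes={
            "store": {"DataType": "String", "StringValue": "example_corp"},
            "price_usd": {"DataType": "Number", "StringValue": "100"},
        },
    )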

View File

@ -57,7 +57,16 @@ class SNSResponse(BaseResponse):
transform_value = None
if 'StringValue' in value:
if data_type == 'Number':
try:
transform_value = float(value['StringValue'])
except ValueError:
raise InvalidParameterValue(
"An error occurred (ParameterValueInvalid) "
"when calling the Publish operation: "
"Could not cast message attribute '{0}' value to number.".format(name))
else:
transform_value = value['StringValue']
elif 'BinaryValue' in value:
transform_value = value['BinaryValue']
if not transform_value:

View File

@ -265,6 +265,9 @@ class Queue(BaseModel):
if 'maxReceiveCount' not in self.redrive_policy:
raise RESTError('InvalidParameterValue', 'Redrive policy does not contain maxReceiveCount')
# 'maxReceiveCount' is stored as int
self.redrive_policy['maxReceiveCount'] = int(self.redrive_policy['maxReceiveCount'])
for queue in sqs_backends[self.region].queues.values():
if queue.queue_arn == self.redrive_policy['deadLetterTargetArn']:
self.dead_letter_queue = queue
@ -424,13 +427,26 @@ class SQSBackend(BaseBackend):
queue_attributes = queue.attributes
new_queue_attributes = new_queue.attributes
static_attributes = (
'DelaySeconds',
'MaximumMessageSize',
'MessageRetentionPeriod',
'Policy',
'QueueArn',
'ReceiveMessageWaitTimeSeconds',
'RedrivePolicy',
'VisibilityTimeout',
'KmsMasterKeyId',
'KmsDataKeyReusePeriodSeconds',
'FifoQueue',
'ContentBasedDeduplication',
)
for key in static_attributes:
if queue_attributes.get(key) != new_queue_attributes.get(key):
raise QueueAlreadyExists(
"The specified queue already exists.",
)
else:
try:
kwargs.pop('region')
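The comparison above only rejects re-creation when one of the listed static attributes differs. A small sketch (queue name and timeouts are assumed values):

import boto3
from botocore.exceptions import ClientError
from moto import mock_sqs

@mock_sqs
def recreate_queue_example():
    client = boto3.client("sqs", region_name="us-east-1")
    client.create_queue(QueueName="example-queue", Attributes={"VisibilityTimeout": "30"})
    # Re-creating the queue with identical static attributes succeeds...
    client.create_queue(QueueName="example-queue", Attributes={"VisibilityTimeout": "30"})
    # ...but changing one of them is rejected with "The specified queue already exists."
    try:
        client.create_queue(QueueName="example-queue", Attributes={"VisibilityTimeout": "60"})
    except ClientError:
        pass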

View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import stepfunction_backends
from ..core.models import base_decorator
stepfunction_backend = stepfunction_backends['us-east-1']
mock_stepfunctions = base_decorator(stepfunction_backends)

View File

@ -0,0 +1,35 @@
from __future__ import unicode_literals
import json
class AWSError(Exception):
TYPE = None
STATUS = 400
def __init__(self, message, type=None, status=None):
self.message = message
self.type = type if type is not None else self.TYPE
self.status = status if status is not None else self.STATUS
def response(self):
return json.dumps({'__type': self.type, 'message': self.message}), dict(status=self.status)
class ExecutionDoesNotExist(AWSError):
TYPE = 'ExecutionDoesNotExist'
STATUS = 400
class InvalidArn(AWSError):
TYPE = 'InvalidArn'
STATUS = 400
class InvalidName(AWSError):
TYPE = 'InvalidName'
STATUS = 400
class StateMachineDoesNotExist(AWSError):
TYPE = 'StateMachineDoesNotExist'
STATUS = 400

View File

@ -0,0 +1,162 @@
import boto
import re
from datetime import datetime
from moto.core import BaseBackend
from moto.core.utils import iso_8601_datetime_without_milliseconds
from moto.sts.models import ACCOUNT_ID
from uuid import uuid4
from .exceptions import ExecutionDoesNotExist, InvalidArn, InvalidName, StateMachineDoesNotExist
class StateMachine():
def __init__(self, arn, name, definition, roleArn, tags=None):
self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now())
self.arn = arn
self.name = name
self.definition = definition
self.roleArn = roleArn
self.tags = tags
class Execution():
def __init__(self, region_name, account_id, state_machine_name, execution_name, state_machine_arn):
execution_arn = 'arn:aws:states:{}:{}:execution:{}:{}'
execution_arn = execution_arn.format(region_name, account_id, state_machine_name, execution_name)
self.execution_arn = execution_arn
self.name = execution_name
self.start_date = iso_8601_datetime_without_milliseconds(datetime.now())
self.state_machine_arn = state_machine_arn
self.status = 'RUNNING'
self.stop_date = None
def stop(self):
self.status = 'SUCCEEDED'
self.stop_date = iso_8601_datetime_without_milliseconds(datetime.now())
class StepFunctionBackend(BaseBackend):
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.create_state_machine
# A name must not contain:
# whitespace
# brackets < > { } [ ]
# wildcard characters ? *
# special characters " # % \ ^ | ~ ` $ & , ; : /
invalid_chars_for_name = [' ', '{', '}', '[', ']', '<', '>',
'?', '*',
'"', '#', '%', '\\', '^', '|', '~', '`', '$', '&', ',', ';', ':', '/']
# control characters (U+0000-001F , U+007F-009F )
invalid_unicodes_for_name = [u'\u0000', u'\u0001', u'\u0002', u'\u0003', u'\u0004',
u'\u0005', u'\u0006', u'\u0007', u'\u0008', u'\u0009',
u'\u000A', u'\u000B', u'\u000C', u'\u000D', u'\u000E', u'\u000F',
u'\u0010', u'\u0011', u'\u0012', u'\u0013', u'\u0014',
u'\u0015', u'\u0016', u'\u0017', u'\u0018', u'\u0019',
u'\u001A', u'\u001B', u'\u001C', u'\u001D', u'\u001E', u'\u001F',
u'\u007F',
u'\u0080', u'\u0081', u'\u0082', u'\u0083', u'\u0084', u'\u0085',
u'\u0086', u'\u0087', u'\u0088', u'\u0089',
u'\u008A', u'\u008B', u'\u008C', u'\u008D', u'\u008E', u'\u008F',
u'\u0090', u'\u0091', u'\u0092', u'\u0093', u'\u0094', u'\u0095',
u'\u0096', u'\u0097', u'\u0098', u'\u0099',
u'\u009A', u'\u009B', u'\u009C', u'\u009D', u'\u009E', u'\u009F']
accepted_role_arn_format = re.compile('arn:aws:iam:(?P<account_id>[0-9]{12}):role/.+')
accepted_mchn_arn_format = re.compile('arn:aws:states:[-0-9a-zA-Z]+:(?P<account_id>[0-9]{12}):stateMachine:.+')
accepted_exec_arn_format = re.compile('arn:aws:states:[-0-9a-zA-Z]+:(?P<account_id>[0-9]{12}):execution:.+')
def __init__(self, region_name):
self.state_machines = []
self.executions = []
self.region_name = region_name
self._account_id = None
def create_state_machine(self, name, definition, roleArn, tags=None):
self._validate_name(name)
self._validate_role_arn(roleArn)
arn = 'arn:aws:states:' + self.region_name + ':' + str(self._get_account_id()) + ':stateMachine:' + name
try:
return self.describe_state_machine(arn)
except StateMachineDoesNotExist:
state_machine = StateMachine(arn, name, definition, roleArn, tags)
self.state_machines.append(state_machine)
return state_machine
def list_state_machines(self):
return self.state_machines
def describe_state_machine(self, arn):
self._validate_machine_arn(arn)
sm = next((x for x in self.state_machines if x.arn == arn), None)
if not sm:
raise StateMachineDoesNotExist("State Machine Does Not Exist: '" + arn + "'")
return sm
def delete_state_machine(self, arn):
self._validate_machine_arn(arn)
sm = next((x for x in self.state_machines if x.arn == arn), None)
if sm:
self.state_machines.remove(sm)
def start_execution(self, state_machine_arn):
state_machine_name = self.describe_state_machine(state_machine_arn).name
execution = Execution(region_name=self.region_name,
account_id=self._get_account_id(),
state_machine_name=state_machine_name,
execution_name=str(uuid4()),
state_machine_arn=state_machine_arn)
self.executions.append(execution)
return execution
def stop_execution(self, execution_arn):
execution = next((x for x in self.executions if x.execution_arn == execution_arn), None)
if not execution:
raise ExecutionDoesNotExist("Execution Does Not Exist: '" + execution_arn + "'")
execution.stop()
return execution
def list_executions(self, state_machine_arn):
return [execution for execution in self.executions if execution.state_machine_arn == state_machine_arn]
def describe_execution(self, arn):
self._validate_execution_arn(arn)
exctn = next((x for x in self.executions if x.execution_arn == arn), None)
if not exctn:
raise ExecutionDoesNotExist("Execution Does Not Exist: '" + arn + "'")
return exctn
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def _validate_name(self, name):
if any(invalid_char in name for invalid_char in self.invalid_chars_for_name):
raise InvalidName("Invalid Name: '" + name + "'")
if any(name.find(char) >= 0 for char in self.invalid_unicodes_for_name):
raise InvalidName("Invalid Name: '" + name + "'")
def _validate_role_arn(self, role_arn):
self._validate_arn(arn=role_arn,
regex=self.accepted_role_arn_format,
invalid_msg="Invalid Role Arn: '" + role_arn + "'")
def _validate_machine_arn(self, machine_arn):
self._validate_arn(arn=machine_arn,
regex=self.accepted_mchn_arn_format,
invalid_msg="Invalid Role Arn: '" + machine_arn + "'")
def _validate_execution_arn(self, execution_arn):
self._validate_arn(arn=execution_arn,
regex=self.accepted_exec_arn_format,
invalid_msg="Execution Does Not Exist: '" + execution_arn + "'")
def _validate_arn(self, arn, regex, invalid_msg):
match = regex.match(arn)
if not arn or not match:
raise InvalidArn(invalid_msg)
def _get_account_id(self):
return ACCOUNT_ID
stepfunction_backends = {_region.name: StepFunctionBackend(_region.name) for _region in boto.awslambda.regions()}
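A short usage sketch against the new backend (the region, names, and role ARN are illustrative; the role ARN merely has to satisfy the accepted_role_arn_format pattern above):

import boto3
from moto.stepfunctions import mock_stepfunctions

@mock_stepfunctions
def stepfunctions_example():
    client = boto3.client('stepfunctions', region_name='us-east-1')
    machine = client.create_state_machine(
        name='example-machine',
        definition='{"StartAt": "Done", "States": {"Done": {"Type": "Succeed"}}}',
        roleArn='arn:aws:iam:123456789012:role/example-role')  # format chosen to match accepted_role_arn_format
    execution = client.start_execution(stateMachineArn=machine['stateMachineArn'])
    status = client.describe_execution(executionArn=execution['executionArn'])['status']
    assert status == 'RUNNING'
    client.stop_execution(executionArn=execution['executionArn'])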

View File

@ -0,0 +1,138 @@
from __future__ import unicode_literals
import json
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from .exceptions import AWSError
from .models import stepfunction_backends
class StepFunctionResponse(BaseResponse):
@property
def stepfunction_backend(self):
return stepfunction_backends[self.region]
@amzn_request_id
def create_state_machine(self):
name = self._get_param('name')
definition = self._get_param('definition')
roleArn = self._get_param('roleArn')
tags = self._get_param('tags')
try:
state_machine = self.stepfunction_backend.create_state_machine(name=name, definition=definition,
roleArn=roleArn,
tags=tags)
response = {
'creationDate': state_machine.creation_date,
'stateMachineArn': state_machine.arn
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def list_state_machines(self):
list_all = self.stepfunction_backend.list_state_machines()
list_all = sorted([{'creationDate': sm.creation_date,
'name': sm.name,
'stateMachineArn': sm.arn} for sm in list_all],
key=lambda x: x['name'])
response = {'stateMachines': list_all}
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_state_machine(self):
arn = self._get_param('stateMachineArn')
return self._describe_state_machine(arn)
@amzn_request_id
def _describe_state_machine(self, state_machine_arn):
try:
state_machine = self.stepfunction_backend.describe_state_machine(state_machine_arn)
response = {
'creationDate': state_machine.creation_date,
'stateMachineArn': state_machine.arn,
'definition': state_machine.definition,
'name': state_machine.name,
'roleArn': state_machine.roleArn,
'status': 'ACTIVE'
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def delete_state_machine(self):
arn = self._get_param('stateMachineArn')
try:
self.stepfunction_backend.delete_state_machine(arn)
return 200, {}, json.dumps('{}')
except AWSError as err:
return err.response()
@amzn_request_id
def list_tags_for_resource(self):
arn = self._get_param('resourceArn')
try:
state_machine = self.stepfunction_backend.describe_state_machine(arn)
tags = state_machine.tags or []
except AWSError:
tags = []
response = {'tags': tags}
return 200, {}, json.dumps(response)
@amzn_request_id
def start_execution(self):
arn = self._get_param('stateMachineArn')
execution = self.stepfunction_backend.start_execution(arn)
response = {'executionArn': execution.execution_arn,
'startDate': execution.start_date}
return 200, {}, json.dumps(response)
@amzn_request_id
def list_executions(self):
arn = self._get_param('stateMachineArn')
state_machine = self.stepfunction_backend.describe_state_machine(arn)
executions = self.stepfunction_backend.list_executions(arn)
executions = [{'executionArn': execution.execution_arn,
'name': execution.name,
'startDate': execution.start_date,
'stateMachineArn': state_machine.arn,
'status': execution.status} for execution in executions]
return 200, {}, json.dumps({'executions': executions})
@amzn_request_id
def describe_execution(self):
arn = self._get_param('executionArn')
try:
execution = self.stepfunction_backend.describe_execution(arn)
response = {
'executionArn': arn,
'input': '{}',
'name': execution.name,
'startDate': execution.start_date,
'stateMachineArn': execution.state_machine_arn,
'status': execution.status,
'stopDate': execution.stop_date
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_state_machine_for_execution(self):
arn = self._get_param('executionArn')
try:
execution = self.stepfunction_backend.describe_execution(arn)
return self._describe_state_machine(execution.state_machine_arn)
except AWSError as err:
return err.response()
@amzn_request_id
def stop_execution(self):
arn = self._get_param('executionArn')
execution = self.stepfunction_backend.stop_execution(arn)
response = {'stopDate': execution.stop_date}
return 200, {}, json.dumps(response)

View File

@ -0,0 +1,10 @@
from __future__ import unicode_literals
from .responses import StepFunctionResponse
url_bases = [
"https?://states.(.+).amazonaws.com",
]
url_paths = {
'{0}/$': StepFunctionResponse.dispatch,
}

View File

@ -10,6 +10,7 @@ boto>=2.45.0
boto3>=1.4.4
botocore>=1.12.13
six>=1.9
parameterized>=0.7.0
prompt-toolkit==1.0.14
click==6.7
inflection==0.3.1

View File

@ -981,11 +981,13 @@ def test_api_keys():
apikey['value'].should.equal(apikey_value)
apikey_name = 'TESTKEY2'
payload = {'name': apikey_name, 'tags': {'tag1': 'test_tag1', 'tag2': '1'}}
response = client.create_api_key(**payload)
apikey_id = response['id']
apikey = client.get_api_key(apiKey=apikey_id)
apikey['name'].should.equal(apikey_name)
apikey['tags']['tag1'].should.equal('test_tag1')
apikey['tags']['tag2'].should.equal('1')
len(apikey['value']).should.equal(40)
apikey_name = 'TESTKEY3'

View File

@ -563,6 +563,38 @@ def test_reregister_task_definition():
resp2['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn'])
resp3 = batch_client.register_job_definition(
jobDefinitionName='sleep10',
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 42,
'command': ['sleep', '10']
}
)
resp3['revision'].should.equal(3)
resp3['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn'])
resp3['jobDefinitionArn'].should_not.equal(resp2['jobDefinitionArn'])
resp4 = batch_client.register_job_definition(
jobDefinitionName='sleep10',
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 41,
'command': ['sleep', '10']
}
)
resp4['revision'].should.equal(4)
resp4['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn'])
resp4['jobDefinitionArn'].should_not.equal(resp2['jobDefinitionArn'])
resp4['jobDefinitionArn'].should_not.equal(resp3['jobDefinitionArn'])
@mock_ec2
@mock_ecs

View File

@ -1,10 +1,10 @@
from __future__ import unicode_literals
import boto3
from botocore.exceptions import ClientError
from nose.tools import assert_raises
from moto import mock_cognitoidentity
import sure # noqa
from moto.cognitoidentity.utils import get_random_identity_id
@ -28,6 +28,47 @@ def test_create_identity_pool():
assert result['IdentityPoolId'] != ''
@mock_cognitoidentity
def test_describe_identity_pool():
conn = boto3.client('cognito-identity', 'us-west-2')
res = conn.create_identity_pool(IdentityPoolName='TestPool',
AllowUnauthenticatedIdentities=False,
SupportedLoginProviders={'graph.facebook.com': '123456789012345'},
DeveloperProviderName='devname',
OpenIdConnectProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db'],
CognitoIdentityProviders=[
{
'ProviderName': 'testprovider',
'ClientId': 'CLIENT12345',
'ServerSideTokenCheck': True
},
],
SamlProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db'])
result = conn.describe_identity_pool(IdentityPoolId=res['IdentityPoolId'])
assert result['IdentityPoolId'] == res['IdentityPoolId']
assert result['AllowUnauthenticatedIdentities'] == res['AllowUnauthenticatedIdentities']
assert result['SupportedLoginProviders'] == res['SupportedLoginProviders']
assert result['DeveloperProviderName'] == res['DeveloperProviderName']
assert result['OpenIdConnectProviderARNs'] == res['OpenIdConnectProviderARNs']
assert result['CognitoIdentityProviders'] == res['CognitoIdentityProviders']
assert result['SamlProviderARNs'] == res['SamlProviderARNs']
@mock_cognitoidentity
def test_describe_identity_pool_with_invalid_id_raises_error():
conn = boto3.client('cognito-identity', 'us-west-2')
with assert_raises(ClientError) as cm:
conn.describe_identity_pool(IdentityPoolId='us-west-2_non-existent')
cm.exception.operation_name.should.equal('DescribeIdentityPool')
cm.exception.response['Error']['Code'].should.equal('ResourceNotFoundException')
cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400)
# testing a helper function
def test_get_random_identity_id():
assert len(get_random_identity_id('us-west-2')) > 0
@ -44,7 +85,8 @@ def test_get_id():
'someurl': '12345'
})
print(result)
assert result.get('IdentityId', "").startswith('us-west-2') or result.get('ResponseMetadata').get(
'HTTPStatusCode') == 200
@mock_cognitoidentity
@ -71,6 +113,7 @@ def test_get_open_id_token_for_developer_identity():
assert len(result['Token']) > 0
assert result['IdentityId'] == '12345'
@mock_cognitoidentity
def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id():
conn = boto3.client('cognito-identity', 'us-west-2')
@ -84,6 +127,7 @@ def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id()
assert len(result['Token']) > 0
assert len(result['IdentityId']) > 0
@mock_cognitoidentity
def test_get_open_id_token():
conn = boto3.client('cognito-identity', 'us-west-2')

View File

@ -0,0 +1,22 @@
import requests
import sure # noqa
import boto3
from moto import mock_sqs, settings
@mock_sqs
def test_passthrough_requests():
conn = boto3.client("sqs", region_name='us-west-1')
conn.create_queue(QueueName="queue1")
res = requests.get("https://httpbin.org/ip")
assert res.status_code == 200
if not settings.TEST_SERVER_MODE:
@mock_sqs
def test_requests_to_amazon_subdomains_dont_work():
res = requests.get("https://fakeservice.amazonaws.com/foo/bar")
assert res.content == b"The method is not implemented"
assert res.status_code == 400

View File

@ -973,6 +973,53 @@ def test_query_filter():
assert response['Count'] == 2
@mock_dynamodb2
def test_query_filter_overlapping_expression_prefixes():
client = boto3.client('dynamodb', region_name='us-east-1')
dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
# Create the DynamoDB table.
client.create_table(
TableName='test1',
AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}],
KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}],
ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123}
)
client.put_item(
TableName='test1',
Item={
'client': {'S': 'client1'},
'app': {'S': 'app1'},
'nested': {'M': {
'version': {'S': 'version1'},
'contents': {'L': [
{'S': 'value1'}, {'S': 'value2'},
]},
}},
})
table = dynamodb.Table('test1')
response = table.query(
KeyConditionExpression=Key('client').eq('client1') & Key('app').eq('app1'),
ProjectionExpression='#1, #10, nested',
ExpressionAttributeNames={
'#1': 'client',
'#10': 'app',
}
)
assert response['Count'] == 1
assert response['Items'][0] == {
'client': 'client1',
'app': 'app1',
'nested': {
'version': 'version1',
'contents': ['value1', 'value2']
}
}
@mock_dynamodb2
def test_scan_filter():
client = boto3.client('dynamodb', region_name='us-east-1')
@ -2034,6 +2081,36 @@ def test_condition_expression__or_order():
)
@mock_dynamodb2
def test_condition_expression__and_order():
client = boto3.client('dynamodb', region_name='us-east-1')
client.create_table(
TableName='test',
KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}],
AttributeDefinitions=[
{'AttributeName': 'forum_name', 'AttributeType': 'S'},
],
ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1},
)
# ensure that the RHS of the AND expression is not evaluated if the LHS
# returns false (as evaluating it would result in an error)
with assert_raises(client.exceptions.ConditionalCheckFailedException):
client.update_item(
TableName='test',
Key={
'forum_name': {'S': 'the-key'},
},
UpdateExpression='set #ttl=:ttl',
ConditionExpression='attribute_exists(#ttl) AND #ttl <= :old_ttl',
ExpressionAttributeNames={'#ttl': 'ttl'},
ExpressionAttributeValues={
':ttl': {'N': '6'},
':old_ttl': {'N': '5'},
}
)
@mock_dynamodb2
def test_query_gsi_with_range_key():
dynamodb = boto3.client('dynamodb', region_name='us-east-1')

View File

@ -76,6 +76,34 @@ class TestCore():
ShardIteratorType='TRIM_HORIZON'
)
assert 'ShardIterator' in resp
def test_get_shard_iterator_at_sequence_number(self):
conn = boto3.client('dynamodbstreams', region_name='us-east-1')
resp = conn.describe_stream(StreamArn=self.stream_arn)
shard_id = resp['StreamDescription']['Shards'][0]['ShardId']
resp = conn.get_shard_iterator(
StreamArn=self.stream_arn,
ShardId=shard_id,
ShardIteratorType='AT_SEQUENCE_NUMBER',
SequenceNumber=resp['StreamDescription']['Shards'][0]['SequenceNumberRange']['StartingSequenceNumber']
)
assert 'ShardIterator' in resp
def test_get_shard_iterator_after_sequence_number(self):
conn = boto3.client('dynamodbstreams', region_name='us-east-1')
resp = conn.describe_stream(StreamArn=self.stream_arn)
shard_id = resp['StreamDescription']['Shards'][0]['ShardId']
resp = conn.get_shard_iterator(
StreamArn=self.stream_arn,
ShardId=shard_id,
ShardIteratorType='AFTER_SEQUENCE_NUMBER',
SequenceNumber=resp['StreamDescription']['Shards'][0]['SequenceNumberRange']['StartingSequenceNumber']
)
assert 'ShardIterator' in resp
def test_get_records_empty(self):
conn = boto3.client('dynamodbstreams', region_name='us-east-1')
@ -135,11 +163,39 @@ class TestCore():
assert resp['Records'][1]['eventName'] == 'MODIFY'
assert resp['Records'][2]['eventName'] == 'DELETE'
sequence_number_modify = resp['Records'][1]['dynamodb']['SequenceNumber']
# now try fetching from the next shard iterator, it should be
# empty
resp = conn.get_records(ShardIterator=resp['NextShardIterator'])
assert len(resp['Records']) == 0
# check that if we get the shard iterator AT_SEQUENCE_NUMBER will get the MODIFY event
resp = conn.get_shard_iterator(
StreamArn=self.stream_arn,
ShardId=shard_id,
ShardIteratorType='AT_SEQUENCE_NUMBER',
SequenceNumber=sequence_number_modify
)
iterator_id = resp['ShardIterator']
resp = conn.get_records(ShardIterator=iterator_id)
assert len(resp['Records']) == 2
assert resp['Records'][0]['eventName'] == 'MODIFY'
assert resp['Records'][1]['eventName'] == 'DELETE'
# check that if we get the shard iterator AFTER_SEQUENCE_NUMBER will get the DELETE event
resp = conn.get_shard_iterator(
StreamArn=self.stream_arn,
ShardId=shard_id,
ShardIteratorType='AFTER_SEQUENCE_NUMBER',
SequenceNumber=sequence_number_modify
)
iterator_id = resp['ShardIterator']
resp = conn.get_records(ShardIterator=iterator_id)
assert len(resp['Records']) == 1
assert resp['Records'][0]['eventName'] == 'DELETE'
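# Note (summarising the assertions above): an AT_SEQUENCE_NUMBER iterator
# includes the record with the given sequence number, so MODIFY and DELETE are
# returned, while AFTER_SEQUENCE_NUMBER starts just past it, so only DELETE is.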
class TestEdges(): class TestEdges():
mocks = [] mocks = []

View File

@ -4,7 +4,7 @@ import json
import os import os
import boto3 import boto3
import botocore import botocore
from botocore.exceptions import ClientError from botocore.exceptions import ClientError, ParamValidationError
from nose.tools import assert_raises from nose.tools import assert_raises
import sure # noqa import sure # noqa
@ -752,6 +752,83 @@ def test_stopped_instance_target():
}) })
@mock_ec2
@mock_elbv2
def test_terminated_instance_target():
target_group_port = 8080
conn = boto3.client('elbv2', region_name='us-east-1')
ec2 = boto3.resource('ec2', region_name='us-east-1')
security_group = ec2.create_security_group(
GroupName='a-security-group', Description='First One')
vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default')
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.192/26',
AvailabilityZone='us-east-1a')
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.0/26',
AvailabilityZone='us-east-1b')
conn.create_load_balancer(
Name='my-lb',
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme='internal',
Tags=[{'Key': 'key_name', 'Value': 'a_value'}])
response = conn.create_target_group(
Name='a-target',
Protocol='HTTP',
Port=target_group_port,
VpcId=vpc.id,
HealthCheckProtocol='HTTP',
HealthCheckPath='/',
HealthCheckIntervalSeconds=5,
HealthCheckTimeoutSeconds=5,
HealthyThresholdCount=5,
UnhealthyThresholdCount=2,
Matcher={'HttpCode': '200'})
target_group = response.get('TargetGroups')[0]
# No targets registered yet
response = conn.describe_target_health(
TargetGroupArn=target_group.get('TargetGroupArn'))
response.get('TargetHealthDescriptions').should.have.length_of(0)
response = ec2.create_instances(
ImageId='ami-1234abcd', MinCount=1, MaxCount=1)
instance = response[0]
target_dict = {
'Id': instance.id,
'Port': 500
}
response = conn.register_targets(
TargetGroupArn=target_group.get('TargetGroupArn'),
Targets=[target_dict])
response = conn.describe_target_health(
TargetGroupArn=target_group.get('TargetGroupArn'))
response.get('TargetHealthDescriptions').should.have.length_of(1)
target_health_description = response.get('TargetHealthDescriptions')[0]
target_health_description['Target'].should.equal(target_dict)
target_health_description['HealthCheckPort'].should.equal(str(target_group_port))
target_health_description['TargetHealth'].should.equal({
'State': 'healthy'
})
instance.terminate()
response = conn.describe_target_health(
TargetGroupArn=target_group.get('TargetGroupArn'))
response.get('TargetHealthDescriptions').should.have.length_of(0)
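# Terminating the instance implicitly deregisters it: the target group goes
# back to reporting zero target health descriptions.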
@mock_ec2 @mock_ec2
@mock_elbv2 @mock_elbv2
def test_target_group_attributes(): def test_target_group_attributes():
@ -1940,3 +2017,279 @@ def test_cognito_action_listener_rule_cloudformation():
'UserPoolDomain': 'testpool', 'UserPoolDomain': 'testpool',
} }
},]) },])
@mock_elbv2
@mock_ec2
def test_fixed_response_action_listener_rule():
conn = boto3.client('elbv2', region_name='us-east-1')
ec2 = boto3.resource('ec2', region_name='us-east-1')
security_group = ec2.create_security_group(
GroupName='a-security-group', Description='First One')
vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default')
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.192/26',
AvailabilityZone='us-east-1a')
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.128/26',
AvailabilityZone='us-east-1b')
response = conn.create_load_balancer(
Name='my-lb',
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme='internal',
Tags=[{'Key': 'key_name', 'Value': 'a_value'}])
load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn')
action = {
'Type': 'fixed-response',
'FixedResponseConfig': {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
'StatusCode': '404',
}
}
response = conn.create_listener(LoadBalancerArn=load_balancer_arn,
Protocol='HTTP',
Port=80,
DefaultActions=[action])
listener = response.get('Listeners')[0]
listener.get('DefaultActions')[0].should.equal(action)
listener_arn = listener.get('ListenerArn')
describe_rules_response = conn.describe_rules(ListenerArn=listener_arn)
describe_rules_response['Rules'][0]['Actions'][0].should.equal(action)
describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn, ])
describe_listener_actions = describe_listener_response['Listeners'][0]['DefaultActions'][0]
describe_listener_actions.should.equal(action)
@mock_elbv2
@mock_cloudformation
def test_fixed_response_action_listener_rule_cloudformation():
cnf_conn = boto3.client('cloudformation', region_name='us-east-1')
elbv2_client = boto3.client('elbv2', region_name='us-east-1')
template = {
"AWSTemplateFormatVersion": "2010-09-09",
"Description": "ECS Cluster Test CloudFormation",
"Resources": {
"testVPC": {
"Type": "AWS::EC2::VPC",
"Properties": {
"CidrBlock": "10.0.0.0/16",
},
},
"subnet1": {
"Type": "AWS::EC2::Subnet",
"Properties": {
"CidrBlock": "10.0.0.0/24",
"VpcId": {"Ref": "testVPC"},
"AvalabilityZone": "us-east-1b",
},
},
"subnet2": {
"Type": "AWS::EC2::Subnet",
"Properties": {
"CidrBlock": "10.0.1.0/24",
"VpcId": {"Ref": "testVPC"},
"AvalabilityZone": "us-east-1b",
},
},
"testLb": {
"Type": "AWS::ElasticLoadBalancingV2::LoadBalancer",
"Properties": {
"Name": "my-lb",
"Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}],
"Type": "application",
"SecurityGroups": [],
}
},
"testListener": {
"Type": "AWS::ElasticLoadBalancingV2::Listener",
"Properties": {
"LoadBalancerArn": {"Ref": "testLb"},
"Port": 80,
"Protocol": "HTTP",
"DefaultActions": [{
"Type": "fixed-response",
"FixedResponseConfig": {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
'StatusCode': '404',
}
}]
}
}
}
}
template_json = json.dumps(template)
cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json)
describe_load_balancers_response = elbv2_client.describe_load_balancers(Names=['my-lb',])
load_balancer_arn = describe_load_balancers_response['LoadBalancers'][0]['LoadBalancerArn']
describe_listeners_response = elbv2_client.describe_listeners(LoadBalancerArn=load_balancer_arn)
describe_listeners_response['Listeners'].should.have.length_of(1)
describe_listeners_response['Listeners'][0]['DefaultActions'].should.equal([{
'Type': 'fixed-response',
"FixedResponseConfig": {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
'StatusCode': '404',
}
},])
@mock_elbv2
@mock_ec2
def test_fixed_response_action_listener_rule_validates_status_code():
conn = boto3.client('elbv2', region_name='us-east-1')
ec2 = boto3.resource('ec2', region_name='us-east-1')
security_group = ec2.create_security_group(
GroupName='a-security-group', Description='First One')
vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default')
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.192/26',
AvailabilityZone='us-east-1a')
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.128/26',
AvailabilityZone='us-east-1b')
response = conn.create_load_balancer(
Name='my-lb',
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme='internal',
Tags=[{'Key': 'key_name', 'Value': 'a_value'}])
load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn')
missing_status_code_action = {
'Type': 'fixed-response',
'FixedResponseConfig': {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
}
}
with assert_raises(ParamValidationError):
conn.create_listener(LoadBalancerArn=load_balancer_arn,
Protocol='HTTP',
Port=80,
DefaultActions=[missing_status_code_action])
invalid_status_code_action = {
'Type': 'fixed-response',
'FixedResponseConfig': {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
'StatusCode': '100'
}
}
with assert_raises(ClientError) as invalid_status_code_exception:
conn.create_listener(LoadBalancerArn=load_balancer_arn,
Protocol='HTTP',
Port=80,
DefaultActions=[invalid_status_code_action])
invalid_status_code_exception.exception.response['Error']['Code'].should.equal('ValidationError')
@mock_elbv2
@mock_ec2
def test_fixed_response_action_listener_rule_validates_content_type():
conn = boto3.client('elbv2', region_name='us-east-1')
ec2 = boto3.resource('ec2', region_name='us-east-1')
security_group = ec2.create_security_group(
GroupName='a-security-group', Description='First One')
vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default')
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.192/26',
AvailabilityZone='us-east-1a')
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.128/26',
AvailabilityZone='us-east-1b')
response = conn.create_load_balancer(
Name='my-lb',
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme='internal',
Tags=[{'Key': 'key_name', 'Value': 'a_value'}])
load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn')
invalid_content_type_action = {
'Type': 'fixed-response',
'FixedResponseConfig': {
'ContentType': 'Fake content type',
'MessageBody': 'This page does not exist',
'StatusCode': '200'
}
}
with assert_raises(ClientError) as invalid_content_type_exception:
conn.create_listener(LoadBalancerArn=load_balancer_arn,
Protocol='HTTP',
Port=80,
DefaultActions=[invalid_content_type_action])
invalid_content_type_exception.exception.response['Error']['Code'].should.equal('InvalidLoadBalancerAction')

View File

@ -1827,6 +1827,23 @@ valid_policy_documents = [
"Resource": ["*"] "Resource": ["*"]
} }
] ]
},
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "",
"Effect": "Allow",
"Action": "rds:*",
"Resource": ["arn:aws:rds:region:*:*"]
},
{
"Sid": "",
"Effect": "Allow",
"Action": ["rds:Describe*"],
"Resource": ["*"]
}
]
} }
] ]

View File

@ -86,6 +86,12 @@ def test_update():
payload.should.have.key('version').which.should.equal(2) payload.should.have.key('version').which.should.equal(2)
payload.should.have.key('timestamp') payload.should.have.key('timestamp')
raw_payload = b'{"state": {"desired": {"led": "on"}}, "version": 1}'
with assert_raises(ClientError) as ex:
client.update_thing_shadow(thingName=name, payload=raw_payload)
ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(409)
ex.exception.response['Error']['Message'].should.equal('Version conflict')
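# The shadow is already at version 2 (asserted above), so a payload pinned to
# version 1 is rejected with HTTP 409 and a 'Version conflict' message.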
@mock_iotdata @mock_iotdata
def test_publish(): def test_publish():

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,157 @@
from __future__ import unicode_literals
import sure # noqa
from nose.tools import assert_raises
from parameterized import parameterized
from moto.kms.exceptions import AccessDeniedException, InvalidCiphertextException, NotFoundException
from moto.kms.models import Key
from moto.kms.utils import (
_deserialize_ciphertext_blob,
_serialize_ciphertext_blob,
_serialize_encryption_context,
generate_data_key,
generate_master_key,
MASTER_KEY_LEN,
encrypt,
decrypt,
Ciphertext,
)
ENCRYPTION_CONTEXT_VECTORS = (
({"this": "is", "an": "encryption", "context": "example"}, b"an" b"encryption" b"context" b"example" b"this" b"is"),
({"a_this": "one", "b_is": "actually", "c_in": "order"}, b"a_this" b"one" b"b_is" b"actually" b"c_in" b"order"),
)
CIPHERTEXT_BLOB_VECTORS = (
(
Ciphertext(
key_id="d25652e4-d2d2-49f7-929a-671ccda580c6",
iv=b"123456789012",
ciphertext=b"some ciphertext",
tag=b"1234567890123456",
),
b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext",
),
(
Ciphertext(
key_id="d25652e4-d2d2-49f7-929a-671ccda580c6",
iv=b"123456789012",
ciphertext=b"some ciphertext that is much longer now",
tag=b"1234567890123456",
),
b"d25652e4-d2d2-49f7-929a-671ccda580c6"
b"123456789012"
b"1234567890123456"
b"some ciphertext that is much longer now",
),
)
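# Layout implied by the vectors above: the serialized blob is
# key_id (36 bytes) + IV (12 bytes) + tag (16 bytes) + ciphertext.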
def test_generate_data_key():
test = generate_data_key(123)
test.should.be.a(bytes)
len(test).should.equal(123)
def test_generate_master_key():
test = generate_master_key()
test.should.be.a(bytes)
len(test).should.equal(MASTER_KEY_LEN)
@parameterized(ENCRYPTION_CONTEXT_VECTORS)
def test_serialize_encryption_context(raw, serialized):
test = _serialize_encryption_context(raw)
test.should.equal(serialized)
@parameterized(CIPHERTEXT_BLOB_VECTORS)
def test_cycle_ciphertext_blob(raw, _serialized):
test_serialized = _serialize_ciphertext_blob(raw)
test_deserialized = _deserialize_ciphertext_blob(test_serialized)
test_deserialized.should.equal(raw)
@parameterized(CIPHERTEXT_BLOB_VECTORS)
def test_serialize_ciphertext_blob(raw, serialized):
test = _serialize_ciphertext_blob(raw)
test.should.equal(serialized)
@parameterized(CIPHERTEXT_BLOB_VECTORS)
def test_deserialize_ciphertext_blob(raw, serialized):
test = _deserialize_ciphertext_blob(serialized)
test.should.equal(raw)
@parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS))
def test_encrypt_decrypt_cycle(encryption_context):
plaintext = b"some secret plaintext"
master_key = Key("nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key}
ciphertext_blob = encrypt(
master_keys=master_key_map, key_id=master_key.id, plaintext=plaintext, encryption_context=encryption_context
)
ciphertext_blob.should_not.equal(plaintext)
decrypted, decrypting_key_id = decrypt(
master_keys=master_key_map, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context
)
decrypted.should.equal(plaintext)
decrypting_key_id.should.equal(master_key.id)
def test_encrypt_unknown_key_id():
with assert_raises(NotFoundException):
encrypt(master_keys={}, key_id="anything", plaintext=b"secrets", encryption_context={})
def test_decrypt_invalid_ciphertext_format():
master_key = Key("nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key}
with assert_raises(InvalidCiphertextException):
decrypt(master_keys=master_key_map, ciphertext_blob=b"", encryption_context={})
def test_decrypt_unknown_key_id():
ciphertext_blob = b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext"
with assert_raises(AccessDeniedException):
decrypt(master_keys={}, ciphertext_blob=ciphertext_blob, encryption_context={})
def test_decrypt_invalid_ciphertext():
master_key = Key("nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key}
ciphertext_blob = master_key.id.encode("utf-8") + b"123456789012" b"1234567890123456" b"some ciphertext"
with assert_raises(InvalidCiphertextException):
decrypt(
master_keys=master_key_map,
ciphertext_blob=ciphertext_blob,
encryption_context={},
)
def test_decrypt_invalid_encryption_context():
plaintext = b"some secret plaintext"
master_key = Key("nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key}
ciphertext_blob = encrypt(
master_keys=master_key_map,
key_id=master_key.id,
plaintext=plaintext,
encryption_context={"some": "encryption", "context": "here"},
)
with assert_raises(InvalidCiphertextException):
decrypt(
master_keys=master_key_map,
ciphertext_blob=ciphertext_blob,
encryption_context={},
)
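# As the test above shows, the encryption context acts as additional
# authenticated data here: decrypting with a different context raises
# InvalidCiphertextException even though the key is correct.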

View File

@ -190,6 +190,8 @@ def test_get_log_events():
resp['events'].should.have.length_of(10) resp['events'].should.have.length_of(10)
resp.should.have.key('nextForwardToken') resp.should.have.key('nextForwardToken')
resp.should.have.key('nextBackwardToken') resp.should.have.key('nextBackwardToken')
resp['nextForwardToken'].should.equal('f/00000000000000000000000000000000000000000000000000000010')
resp['nextBackwardToken'].should.equal('b/00000000000000000000000000000000000000000000000000000000')
for i in range(10): for i in range(10):
resp['events'][i]['timestamp'].should.equal(i) resp['events'][i]['timestamp'].should.equal(i)
resp['events'][i]['message'].should.equal(str(i)) resp['events'][i]['message'].should.equal(str(i))
@ -205,7 +207,8 @@ def test_get_log_events():
resp['events'].should.have.length_of(10) resp['events'].should.have.length_of(10)
resp.should.have.key('nextForwardToken') resp.should.have.key('nextForwardToken')
resp.should.have.key('nextBackwardToken') resp.should.have.key('nextBackwardToken')
resp['nextForwardToken'].should.equal(next_token) resp['nextForwardToken'].should.equal('f/00000000000000000000000000000000000000000000000000000020')
resp['nextBackwardToken'].should.equal('b/00000000000000000000000000000000000000000000000000000000')
for i in range(10): for i in range(10):
resp['events'][i]['timestamp'].should.equal(i+10) resp['events'][i]['timestamp'].should.equal(i+10)
resp['events'][i]['message'].should.equal(str(i+10)) resp['events'][i]['message'].should.equal(str(i+10))

View File

@ -37,6 +37,25 @@ def test_create_cluster_boto3():
create_time = response['Cluster']['ClusterCreateTime'] create_time = response['Cluster']['ClusterCreateTime']
create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo)) create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo))
create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1)) create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1))
response['Cluster']['EnhancedVpcRouting'].should.equal(False)
@mock_redshift
def test_create_cluster_enhanced_vpc_routing_boto3():
client = boto3.client('redshift', region_name='us-east-1')
response = client.create_cluster(
DBName='test',
ClusterIdentifier='test',
ClusterType='single-node',
NodeType='ds2.xlarge',
MasterUsername='user',
MasterUserPassword='password',
EnhancedVpcRouting=True
)
response['Cluster']['NodeType'].should.equal('ds2.xlarge')
create_time = response['Cluster']['ClusterCreateTime']
create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo))
create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1))
response['Cluster']['EnhancedVpcRouting'].should.equal(True)
@mock_redshift @mock_redshift
@ -425,6 +444,58 @@ def test_delete_cluster():
"not-a-cluster").should.throw(ClusterNotFound) "not-a-cluster").should.throw(ClusterNotFound)
@mock_redshift
def test_modify_cluster_vpc_routing():
iam_roles_arn = ['arn:aws:iam:::role/my-iam-role', ]
client = boto3.client('redshift', region_name='us-east-1')
cluster_identifier = 'my_cluster'
client.create_cluster(
ClusterIdentifier=cluster_identifier,
NodeType="single-node",
MasterUsername="username",
MasterUserPassword="password",
IamRoles=iam_roles_arn
)
cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier)
cluster = cluster_response['Clusters'][0]
cluster['EnhancedVpcRouting'].should.equal(False)
client.create_cluster_security_group(ClusterSecurityGroupName='security_group',
Description='security_group')
client.create_cluster_parameter_group(ParameterGroupName='my_parameter_group',
ParameterGroupFamily='redshift-1.0',
Description='my_parameter_group')
client.modify_cluster(
ClusterIdentifier=cluster_identifier,
ClusterType='multi-node',
NodeType="ds2.8xlarge",
NumberOfNodes=3,
ClusterSecurityGroups=["security_group"],
MasterUserPassword="new_password",
ClusterParameterGroupName="my_parameter_group",
AutomatedSnapshotRetentionPeriod=7,
PreferredMaintenanceWindow="Tue:03:00-Tue:11:00",
AllowVersionUpgrade=False,
NewClusterIdentifier=cluster_identifier,
EnhancedVpcRouting=True
)
cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier)
cluster = cluster_response['Clusters'][0]
cluster['ClusterIdentifier'].should.equal(cluster_identifier)
cluster['NodeType'].should.equal("ds2.8xlarge")
cluster['PreferredMaintenanceWindow'].should.equal("Tue:03:00-Tue:11:00")
cluster['AutomatedSnapshotRetentionPeriod'].should.equal(7)
cluster['AllowVersionUpgrade'].should.equal(False)
# NumberOfNodes was explicitly set to 3 in the modify_cluster call above.
cluster['NumberOfNodes'].should.equal(3)
cluster['EnhancedVpcRouting'].should.equal(True)
@mock_redshift_deprecated @mock_redshift_deprecated
def test_modify_cluster(): def test_modify_cluster():
conn = boto.connect_redshift() conn = boto.connect_redshift()
@ -446,6 +517,10 @@ def test_modify_cluster():
master_user_password="password", master_user_password="password",
) )
cluster_response = conn.describe_clusters(cluster_identifier)
cluster = cluster_response['DescribeClustersResponse']['DescribeClustersResult']['Clusters'][0]
cluster['EnhancedVpcRouting'].should.equal(False)
conn.modify_cluster( conn.modify_cluster(
cluster_identifier, cluster_identifier,
cluster_type="multi-node", cluster_type="multi-node",
@ -456,14 +531,13 @@ def test_modify_cluster():
automated_snapshot_retention_period=7, automated_snapshot_retention_period=7,
preferred_maintenance_window="Tue:03:00-Tue:11:00", preferred_maintenance_window="Tue:03:00-Tue:11:00",
allow_version_upgrade=False, allow_version_upgrade=False,
new_cluster_identifier="new_identifier", new_cluster_identifier=cluster_identifier,
) )
cluster_response = conn.describe_clusters("new_identifier") cluster_response = conn.describe_clusters(cluster_identifier)
cluster = cluster_response['DescribeClustersResponse'][ cluster = cluster_response['DescribeClustersResponse'][
'DescribeClustersResult']['Clusters'][0] 'DescribeClustersResult']['Clusters'][0]
cluster['ClusterIdentifier'].should.equal(cluster_identifier)
cluster['ClusterIdentifier'].should.equal("new_identifier")
cluster['NodeType'].should.equal("dw.hs1.xlarge") cluster['NodeType'].should.equal("dw.hs1.xlarge")
cluster['ClusterSecurityGroups'][0][ cluster['ClusterSecurityGroups'][0][
'ClusterSecurityGroupName'].should.equal("security_group") 'ClusterSecurityGroupName'].should.equal("security_group")
@ -674,6 +748,7 @@ def test_create_cluster_snapshot():
NodeType='ds2.xlarge', NodeType='ds2.xlarge',
MasterUsername='username', MasterUsername='username',
MasterUserPassword='password', MasterUserPassword='password',
EnhancedVpcRouting=True
) )
cluster_response['Cluster']['NodeType'].should.equal('ds2.xlarge') cluster_response['Cluster']['NodeType'].should.equal('ds2.xlarge')
@ -823,11 +898,14 @@ def test_create_cluster_from_snapshot():
NodeType='ds2.xlarge', NodeType='ds2.xlarge',
MasterUsername='username', MasterUsername='username',
MasterUserPassword='password', MasterUserPassword='password',
EnhancedVpcRouting=True,
) )
client.create_cluster_snapshot( client.create_cluster_snapshot(
SnapshotIdentifier=original_snapshot_identifier, SnapshotIdentifier=original_snapshot_identifier,
ClusterIdentifier=original_cluster_identifier ClusterIdentifier=original_cluster_identifier
) )
response = client.restore_from_cluster_snapshot( response = client.restore_from_cluster_snapshot(
ClusterIdentifier=new_cluster_identifier, ClusterIdentifier=new_cluster_identifier,
SnapshotIdentifier=original_snapshot_identifier, SnapshotIdentifier=original_snapshot_identifier,
@ -842,7 +920,7 @@ def test_create_cluster_from_snapshot():
new_cluster['NodeType'].should.equal('ds2.xlarge') new_cluster['NodeType'].should.equal('ds2.xlarge')
new_cluster['MasterUsername'].should.equal('username') new_cluster['MasterUsername'].should.equal('username')
new_cluster['Endpoint']['Port'].should.equal(1234) new_cluster['Endpoint']['Port'].should.equal(1234)
new_cluster['EnhancedVpcRouting'].should.equal(True)
@mock_redshift @mock_redshift
def test_create_cluster_from_snapshot_with_waiter(): def test_create_cluster_from_snapshot_with_waiter():
@ -857,6 +935,7 @@ def test_create_cluster_from_snapshot_with_waiter():
NodeType='ds2.xlarge', NodeType='ds2.xlarge',
MasterUsername='username', MasterUsername='username',
MasterUserPassword='password', MasterUserPassword='password',
EnhancedVpcRouting=True
) )
client.create_cluster_snapshot( client.create_cluster_snapshot(
SnapshotIdentifier=original_snapshot_identifier, SnapshotIdentifier=original_snapshot_identifier,
@ -883,6 +962,7 @@ def test_create_cluster_from_snapshot_with_waiter():
new_cluster = response['Clusters'][0] new_cluster = response['Clusters'][0]
new_cluster['NodeType'].should.equal('ds2.xlarge') new_cluster['NodeType'].should.equal('ds2.xlarge')
new_cluster['MasterUsername'].should.equal('username') new_cluster['MasterUsername'].should.equal('username')
new_cluster['EnhancedVpcRouting'].should.equal(True)
new_cluster['Endpoint']['Port'].should.equal(1234) new_cluster['Endpoint']['Port'].should.equal(1234)

View File

@ -404,6 +404,11 @@ def test_list_or_change_tags_for_resource_request():
) )
healthcheck_id = health_check['HealthCheck']['Id'] healthcheck_id = health_check['HealthCheck']['Id']
# confirm this works for resources with zero tags
response = conn.list_tags_for_resource(
ResourceType="healthcheck", ResourceId=healthcheck_id)
response["ResourceTagSet"]["Tags"].should.be.empty
tag1 = {"Key": "Deploy", "Value": "True"} tag1 = {"Key": "Deploy", "Value": "True"}
tag2 = {"Key": "Name", "Value": "UnitTest"} tag2 = {"Key": "Name", "Value": "UnitTest"}

View File

@ -2,6 +2,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import datetime import datetime
import os
from six.moves.urllib.request import urlopen from six.moves.urllib.request import urlopen
from six.moves.urllib.error import HTTPError from six.moves.urllib.error import HTTPError
from functools import wraps from functools import wraps
@ -20,9 +21,11 @@ from botocore.handlers import disable_signing
from boto.s3.connection import S3Connection from boto.s3.connection import S3Connection
from boto.s3.key import Key from boto.s3.key import Key
from freezegun import freeze_time from freezegun import freeze_time
from parameterized import parameterized
import six import six
import requests import requests
import tests.backport_assert_raises # noqa import tests.backport_assert_raises # noqa
from nose import SkipTest
from nose.tools import assert_raises from nose.tools import assert_raises
import sure # noqa import sure # noqa
@ -1390,6 +1393,34 @@ def test_boto3_list_objects_v2_fetch_owner():
assert len(owner.keys()) == 2 assert len(owner.keys()) == 2
@mock_s3
def test_boto3_list_objects_v2_truncate_combined_keys_and_folders():
s3 = boto3.client('s3', region_name='us-east-1')
s3.create_bucket(Bucket='mybucket')
s3.put_object(Bucket='mybucket', Key='1/2', Body='')
s3.put_object(Bucket='mybucket', Key='2', Body='')
s3.put_object(Bucket='mybucket', Key='3/4', Body='')
s3.put_object(Bucket='mybucket', Key='4', Body='')
resp = s3.list_objects_v2(Bucket='mybucket', Prefix='', MaxKeys=2, Delimiter='/')
assert 'Delimiter' in resp
assert resp['IsTruncated'] is True
assert resp['KeyCount'] == 2
assert len(resp['Contents']) == 1
assert resp['Contents'][0]['Key'] == '2'
assert len(resp['CommonPrefixes']) == 1
assert resp['CommonPrefixes'][0]['Prefix'] == '1/'
last_tail = resp['NextContinuationToken']
resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=2, Prefix='', Delimiter='/', StartAfter=last_tail)
assert resp['KeyCount'] == 2
assert resp['IsTruncated'] is False
assert len(resp['Contents']) == 1
assert resp['Contents'][0]['Key'] == '4'
assert len(resp['CommonPrefixes']) == 1
assert resp['CommonPrefixes'][0]['Prefix'] == '3/'
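# Note that MaxKeys counts keys and common prefixes combined: each page above
# reports KeyCount == 2, made up of one object and one CommonPrefix.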
@mock_s3 @mock_s3
def test_boto3_bucket_create(): def test_boto3_bucket_create():
s3 = boto3.resource('s3', region_name='us-east-1') s3 = boto3.resource('s3', region_name='us-east-1')
@ -2991,3 +3022,64 @@ def test_accelerate_configuration_is_not_supported_when_bucket_name_has_dots():
AccelerateConfiguration={'Status': 'Enabled'}, AccelerateConfiguration={'Status': 'Enabled'},
) )
exc.exception.response['Error']['Code'].should.equal('InvalidRequest') exc.exception.response['Error']['Code'].should.equal('InvalidRequest')
def store_and_read_back_a_key(key):
s3 = boto3.client('s3', region_name='us-east-1')
bucket_name = 'mybucket'
body = b'Some body'
s3.create_bucket(Bucket=bucket_name)
s3.put_object(
Bucket=bucket_name,
Key=key,
Body=body
)
response = s3.get_object(Bucket=bucket_name, Key=key)
response['Body'].read().should.equal(body)
@mock_s3
def test_paths_with_leading_slashes_work():
store_and_read_back_a_key('/a-key')
@mock_s3
def test_root_dir_with_empty_name_works():
if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true':
raise SkipTest('Does not work in server mode due to error in Werkzeug')
store_and_read_back_a_key('/')
@parameterized([
('foo/bar/baz',),
('foo',),
('foo/run_dt%3D2019-01-01%252012%253A30%253A00',),
])
@mock_s3
def test_delete_objects_with_url_encoded_key(key):
s3 = boto3.client('s3', region_name='us-east-1')
bucket_name = 'mybucket'
body = b'Some body'
s3.create_bucket(Bucket=bucket_name)
def put_object():
s3.put_object(
Bucket=bucket_name,
Key=key,
Body=body
)
def assert_deleted():
with assert_raises(ClientError) as e:
s3.get_object(Bucket=bucket_name, Key=key)
e.exception.response['Error']['Code'].should.equal('NoSuchKey')
put_object()
s3.delete_object(Bucket=bucket_name, Key=key)
assert_deleted()
put_object()
s3.delete_objects(Bucket=bucket_name, Delete={'Objects': [{'Key': key}]})
assert_deleted()

View File

@ -1,7 +1,8 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import os import os
from sure import expect from sure import expect
from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url, clean_key_name, undo_clean_key_name
from parameterized import parameterized
def test_base_url(): def test_base_url():
@ -78,3 +79,29 @@ def test_parse_region_from_url():
'https://s3.amazonaws.com/bucket', 'https://s3.amazonaws.com/bucket',
'https://bucket.s3.amazonaws.com']: 'https://bucket.s3.amazonaws.com']:
parse_region_from_url(url).should.equal(expected) parse_region_from_url(url).should.equal(expected)
@parameterized([
('foo/bar/baz',
'foo/bar/baz'),
('foo',
'foo'),
('foo/run_dt%3D2019-01-01%252012%253A30%253A00',
'foo/run_dt=2019-01-01%2012%3A30%3A00'),
])
def test_clean_key_name(key, expected):
clean_key_name(key).should.equal(expected)
@parameterized([
('foo/bar/baz',
'foo/bar/baz'),
('foo',
'foo'),
('foo/run_dt%3D2019-01-01%252012%253A30%253A00',
'foo/run_dt%253D2019-01-01%25252012%25253A30%25253A00'),
])
def test_undo_clean_key_name(key, expected):
undo_clean_key_name(key).should.equal(expected)
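# Judging from the vectors above, clean_key_name applies a single round of
# URL-decoding to the key name, while undo_clean_key_name applies a single
# round of URL-encoding (each '%' becomes '%25').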

View File

@ -5,9 +5,9 @@ import boto3
from moto import mock_secretsmanager from moto import mock_secretsmanager
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
import string import string
import unittest
import pytz import pytz
from datetime import datetime from datetime import datetime
import sure # noqa
from nose.tools import assert_raises from nose.tools import assert_raises
from six import b from six import b
@ -23,6 +23,7 @@ def test_get_secret_value():
result = conn.get_secret_value(SecretId='java-util-test-password') result = conn.get_secret_value(SecretId='java-util-test-password')
assert result['SecretString'] == 'foosecret' assert result['SecretString'] == 'foosecret'
@mock_secretsmanager @mock_secretsmanager
def test_get_secret_value_binary(): def test_get_secret_value_binary():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -32,6 +33,7 @@ def test_get_secret_value_binary():
result = conn.get_secret_value(SecretId='java-util-test-password') result = conn.get_secret_value(SecretId='java-util-test-password')
assert result['SecretBinary'] == b('foosecret') assert result['SecretBinary'] == b('foosecret')
@mock_secretsmanager @mock_secretsmanager
def test_get_secret_that_does_not_exist(): def test_get_secret_that_does_not_exist():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -39,6 +41,7 @@ def test_get_secret_that_does_not_exist():
with assert_raises(ClientError): with assert_raises(ClientError):
result = conn.get_secret_value(SecretId='i-dont-exist') result = conn.get_secret_value(SecretId='i-dont-exist')
@mock_secretsmanager @mock_secretsmanager
def test_get_secret_that_does_not_match(): def test_get_secret_that_does_not_match():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -72,6 +75,7 @@ def test_create_secret():
secret = conn.get_secret_value(SecretId='test-secret') secret = conn.get_secret_value(SecretId='test-secret')
assert secret['SecretString'] == 'foosecret' assert secret['SecretString'] == 'foosecret'
@mock_secretsmanager @mock_secretsmanager
def test_create_secret_with_tags(): def test_create_secret_with_tags():
conn = boto3.client('secretsmanager', region_name='us-east-1') conn = boto3.client('secretsmanager', region_name='us-east-1')
@ -216,6 +220,7 @@ def test_get_random_exclude_lowercase():
ExcludeLowercase=True) ExcludeLowercase=True)
assert any(c.islower() for c in random_password['RandomPassword']) == False assert any(c.islower() for c in random_password['RandomPassword']) == False
@mock_secretsmanager @mock_secretsmanager
def test_get_random_exclude_uppercase(): def test_get_random_exclude_uppercase():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -224,6 +229,7 @@ def test_get_random_exclude_uppercase():
ExcludeUppercase=True) ExcludeUppercase=True)
assert any(c.isupper() for c in random_password['RandomPassword']) == False assert any(c.isupper() for c in random_password['RandomPassword']) == False
@mock_secretsmanager @mock_secretsmanager
def test_get_random_exclude_characters_and_symbols(): def test_get_random_exclude_characters_and_symbols():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -232,6 +238,7 @@ def test_get_random_exclude_characters_and_symbols():
ExcludeCharacters='xyzDje@?!.') ExcludeCharacters='xyzDje@?!.')
assert any(c in 'xyzDje@?!.' for c in random_password['RandomPassword']) == False assert any(c in 'xyzDje@?!.' for c in random_password['RandomPassword']) == False
@mock_secretsmanager @mock_secretsmanager
def test_get_random_exclude_numbers(): def test_get_random_exclude_numbers():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -240,6 +247,7 @@ def test_get_random_exclude_numbers():
ExcludeNumbers=True) ExcludeNumbers=True)
assert any(c.isdigit() for c in random_password['RandomPassword']) == False assert any(c.isdigit() for c in random_password['RandomPassword']) == False
@mock_secretsmanager @mock_secretsmanager
def test_get_random_exclude_punctuation(): def test_get_random_exclude_punctuation():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -249,6 +257,7 @@ def test_get_random_exclude_punctuation():
assert any(c in string.punctuation assert any(c in string.punctuation
for c in random_password['RandomPassword']) == False for c in random_password['RandomPassword']) == False
@mock_secretsmanager @mock_secretsmanager
def test_get_random_include_space_false(): def test_get_random_include_space_false():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -256,6 +265,7 @@ def test_get_random_include_space_false():
random_password = conn.get_random_password(PasswordLength=300) random_password = conn.get_random_password(PasswordLength=300)
assert any(c.isspace() for c in random_password['RandomPassword']) == False assert any(c.isspace() for c in random_password['RandomPassword']) == False
@mock_secretsmanager @mock_secretsmanager
def test_get_random_include_space_true(): def test_get_random_include_space_true():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -264,6 +274,7 @@ def test_get_random_include_space_true():
IncludeSpace=True) IncludeSpace=True)
assert any(c.isspace() for c in random_password['RandomPassword']) == True assert any(c.isspace() for c in random_password['RandomPassword']) == True
@mock_secretsmanager @mock_secretsmanager
def test_get_random_require_each_included_type(): def test_get_random_require_each_included_type():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -275,6 +286,7 @@ def test_get_random_require_each_included_type():
assert any(c in string.ascii_uppercase for c in random_password['RandomPassword']) == True assert any(c in string.ascii_uppercase for c in random_password['RandomPassword']) == True
assert any(c in string.digits for c in random_password['RandomPassword']) == True assert any(c in string.digits for c in random_password['RandomPassword']) == True
@mock_secretsmanager @mock_secretsmanager
def test_get_random_too_short_password(): def test_get_random_too_short_password():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -282,6 +294,7 @@ def test_get_random_too_short_password():
with assert_raises(ClientError): with assert_raises(ClientError):
random_password = conn.get_random_password(PasswordLength=3) random_password = conn.get_random_password(PasswordLength=3)
@mock_secretsmanager @mock_secretsmanager
def test_get_random_too_long_password(): def test_get_random_too_long_password():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -289,6 +302,7 @@ def test_get_random_too_long_password():
with assert_raises(Exception): with assert_raises(Exception):
random_password = conn.get_random_password(PasswordLength=5555) random_password = conn.get_random_password(PasswordLength=5555)
@mock_secretsmanager @mock_secretsmanager
def test_describe_secret(): def test_describe_secret():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -307,6 +321,7 @@ def test_describe_secret():
assert secret_description_2['Name'] == ('test-secret-2') assert secret_description_2['Name'] == ('test-secret-2')
assert secret_description_2['ARN'] != '' # Test arn not empty assert secret_description_2['ARN'] != '' # Test arn not empty
@mock_secretsmanager @mock_secretsmanager
def test_describe_secret_that_does_not_exist(): def test_describe_secret_that_does_not_exist():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -314,6 +329,7 @@ def test_describe_secret_that_does_not_exist():
with assert_raises(ClientError): with assert_raises(ClientError):
result = conn.get_secret_value(SecretId='i-dont-exist') result = conn.get_secret_value(SecretId='i-dont-exist')
@mock_secretsmanager @mock_secretsmanager
def test_describe_secret_that_does_not_match(): def test_describe_secret_that_does_not_match():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -500,6 +516,7 @@ def test_rotate_secret_rotation_period_zero():
# test_server actually handles this error. # test_server actually handles this error.
assert True assert True
@mock_secretsmanager @mock_secretsmanager
def test_rotate_secret_rotation_period_too_long(): def test_rotate_secret_rotation_period_too_long():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -511,6 +528,7 @@ def test_rotate_secret_rotation_period_too_long():
result = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME, result = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME,
RotationRules=rotation_rules) RotationRules=rotation_rules)
@mock_secretsmanager @mock_secretsmanager
def test_put_secret_value_puts_new_secret(): def test_put_secret_value_puts_new_secret():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -526,6 +544,45 @@ def test_put_secret_value_puts_new_secret():
assert get_secret_value_dict assert get_secret_value_dict
assert get_secret_value_dict['SecretString'] == 'foosecret' assert get_secret_value_dict['SecretString'] == 'foosecret'
@mock_secretsmanager
def test_put_secret_binary_value_puts_new_secret():
conn = boto3.client('secretsmanager', region_name='us-west-2')
put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME,
SecretBinary=b('foosecret'),
VersionStages=['AWSCURRENT'])
version_id = put_secret_value_dict['VersionId']
get_secret_value_dict = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME,
VersionId=version_id,
VersionStage='AWSCURRENT')
assert get_secret_value_dict
assert get_secret_value_dict['SecretBinary'] == b('foosecret')
@mock_secretsmanager
def test_create_and_put_secret_binary_value_puts_new_secret():
conn = boto3.client('secretsmanager', region_name='us-west-2')
conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretBinary=b("foosecret"))
conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, SecretBinary=b('foosecret_update'))
latest_secret = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME)
assert latest_secret
assert latest_secret['SecretBinary'] == b('foosecret_update')
@mock_secretsmanager
def test_put_secret_binary_requires_either_string_or_binary():
conn = boto3.client('secretsmanager', region_name='us-west-2')
with assert_raises(ClientError) as ire:
conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME)
ire.exception.response['Error']['Code'].should.equal('InvalidRequestException')
ire.exception.response['Error']['Message'].should.equal('You must provide either SecretString or SecretBinary.')
@mock_secretsmanager @mock_secretsmanager
def test_put_secret_value_can_get_first_version_if_put_twice(): def test_put_secret_value_can_get_first_version_if_put_twice():
conn = boto3.client('secretsmanager', region_name='us-west-2') conn = boto3.client('secretsmanager', region_name='us-west-2')

View File

@ -109,6 +109,17 @@ def test_publish_to_sqs_bad():
}}) }})
except ClientError as err: except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameterValue') err.response['Error']['Code'].should.equal('InvalidParameterValue')
try:
# Test Number DataType with a non-numeric value
conn.publish(
TopicArn=topic_arn, Message=message,
MessageAttributes={'price': {
'DataType': 'Number',
'StringValue': 'error'
}})
except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameterValue')
err.response['Error']['Message'].should.equal("An error occurred (ParameterValueInvalid) when calling the Publish operation: Could not cast message attribute 'price' value to number.")
@mock_sqs @mock_sqs
@ -487,3 +498,380 @@ def test_filtering_exact_string_no_attributes_no_match():
message_attributes = [ message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages] json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([]) message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_exact_number_int():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'Number',
'StringValue': '100'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'Number', 'Value': 100}}])
@mock_sqs
@mock_sns
def test_filtering_exact_number_float():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100.1]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'Number',
'StringValue': '100.1'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'Number', 'Value': 100.1}}])
@mock_sqs
@mock_sns
def test_filtering_exact_number_float_accuracy():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100.123456789]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'Number',
'StringValue': '100.1234561'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'Number', 'Value': 100.1234561}}])
@mock_sqs
@mock_sns
def test_filtering_exact_number_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100]})
topic.publish(
Message='no match',
MessageAttributes={'price': {'DataType': 'Number',
'StringValue': '101'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_exact_number_with_string_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100]})
topic.publish(
Message='no match',
MessageAttributes={'price': {'DataType': 'String',
'StringValue': '100'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_string_array_match():
topic, subscription, queue = _setup_filter_policy_test(
{'customer_interests': ['basketball', 'baseball']})
topic.publish(
Message='match',
MessageAttributes={'customer_interests': {'DataType': 'String.Array',
'StringValue': json.dumps(['basketball', 'rugby'])}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'customer_interests': {'Type': 'String.Array', 'Value': json.dumps(['basketball', 'rugby'])}}])
@mock_sqs
@mock_sns
def test_filtering_string_array_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'customer_interests': ['baseball']})
topic.publish(
Message='no_match',
MessageAttributes={'customer_interests': {'DataType': 'String.Array',
'StringValue': json.dumps(['basketball', 'rugby'])}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_string_array_with_number_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100, 500]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'String.Array',
'StringValue': json.dumps([100, 50])}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'String.Array', 'Value': json.dumps([100, 50])}}])
@mock_sqs
@mock_sns
def test_filtering_string_array_with_number_float_accuracy_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100.123456789, 500]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'String.Array',
'StringValue': json.dumps([100.1234561, 50])}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'String.Array', 'Value': json.dumps([100.1234561, 50])}}])
@mock_sqs
@mock_sns
# this is the correct behavior from SNS
def test_filtering_string_array_with_number_no_array_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100, 500]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'String.Array',
'StringValue': '100'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'String.Array', 'Value': '100'}}])
@mock_sqs
@mock_sns
def test_filtering_string_array_with_number_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [500]})
topic.publish(
Message='no_match',
MessageAttributes={'price': {'DataType': 'String.Array',
'StringValue': json.dumps([100, 50])}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
# this is the correct behavior from SNS
def test_filtering_string_array_with_string_no_array_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100]})
topic.publish(
Message='no_match',
MessageAttributes={'price': {'DataType': 'String.Array',
'StringValue': 'one hundred'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_attribute_key_exists_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': True}]})
topic.publish(
Message='match',
MessageAttributes={'store': {'DataType': 'String',
'StringValue': 'example_corp'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'store': {'Type': 'String', 'Value': 'example_corp'}}])
@mock_sqs
@mock_sns
def test_filtering_attribute_key_exists_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': True}]})
topic.publish(
Message='no match',
MessageAttributes={'event': {'DataType': 'String',
'StringValue': 'order_cancelled'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_attribute_key_not_exists_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': False}]})
topic.publish(
Message='match',
MessageAttributes={'event': {'DataType': 'String',
'StringValue': 'order_cancelled'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'event': {'Type': 'String', 'Value': 'order_cancelled'}}])
@mock_sqs
@mock_sns
def test_filtering_attribute_key_not_exists_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': False}]})
topic.publish(
Message='no match',
MessageAttributes={'store': {'DataType': 'String',
'StringValue': 'example_corp'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_all_AND_matching_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': True}],
'event': ['order_cancelled'],
'customer_interests': ['basketball', 'baseball'],
'price': [100]})
topic.publish(
Message='match',
MessageAttributes={'store': {'DataType': 'String',
'StringValue': 'example_corp'},
'event': {'DataType': 'String',
'StringValue': 'order_cancelled'},
'customer_interests': {'DataType': 'String.Array',
'StringValue': json.dumps(['basketball', 'rugby'])},
'price': {'DataType': 'Number',
'StringValue': '100'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(
['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([{
'store': {'Type': 'String', 'Value': 'example_corp'},
'event': {'Type': 'String', 'Value': 'order_cancelled'},
'customer_interests': {'Type': 'String.Array', 'Value': json.dumps(['basketball', 'rugby'])},
'price': {'Type': 'Number', 'Value': 100}}])
@mock_sqs
@mock_sns
def test_filtering_all_AND_matching_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': True}],
'event': ['order_cancelled'],
'customer_interests': ['basketball', 'baseball'],
'price': [100],
"encrypted": [False]})
topic.publish(
Message='no match',
MessageAttributes={'store': {'DataType': 'String',
'StringValue': 'example_corp'},
'event': {'DataType': 'String',
'StringValue': 'order_cancelled'},
'customer_interests': {'DataType': 'String.Array',
'StringValue': json.dumps(['basketball', 'rugby'])},
'price': {'DataType': 'Number',
'StringValue': '100'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
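# Taken together, the filtering tests above follow the documented SNS filter
# policy semantics: values listed for a single attribute are OR-ed, while
# separate attributes in the policy must all match (AND).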

View File

@ -201,7 +201,9 @@ def test_creating_subscription_with_attributes():
"store": ["example_corp"], "store": ["example_corp"],
"event": ["order_cancelled"], "event": ["order_cancelled"],
"encrypted": [False], "encrypted": [False],
"customer_interests": ["basketball", "baseball"] "customer_interests": ["basketball", "baseball"],
"price": [100, 100.12],
"error": [None]
}) })
conn.subscribe(TopicArn=topic_arn, conn.subscribe(TopicArn=topic_arn,
@ -294,7 +296,9 @@ def test_set_subscription_attributes():
        "store": ["example_corp"],
        "event": ["order_cancelled"],
        "encrypted": [False],
        "customer_interests": ["basketball", "baseball"],
        "price": [100, 100.12],
        "error": [None]
    })
    conn.set_subscription_attributes(
        SubscriptionArn=subscription_arn,
@ -332,6 +336,77 @@ def test_set_subscription_attributes():
    )
@mock_sns
def test_subscribe_invalid_filter_policy():
conn = boto3.client('sns', region_name = 'us-east-1')
conn.create_topic(Name = 'some-topic')
response = conn.list_topics()
topic_arn = response['Topics'][0]['TopicArn']
try:
conn.subscribe(TopicArn = topic_arn,
Protocol = 'http',
Endpoint = 'http://example.com/',
Attributes = {
'FilterPolicy': json.dumps({
'store': [str(i) for i in range(151)]
})
})
except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameter')
err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Filter policy is too complex')
try:
conn.subscribe(TopicArn = topic_arn,
Protocol = 'http',
Endpoint = 'http://example.com/',
Attributes = {
'FilterPolicy': json.dumps({
'store': [['example_corp']]
})
})
except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameter')
err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null')
try:
conn.subscribe(TopicArn = topic_arn,
Protocol = 'http',
Endpoint = 'http://example.com/',
Attributes = {
'FilterPolicy': json.dumps({
'store': [{'exists': None}]
})
})
except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameter')
err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: exists match pattern must be either true or false.')
try:
conn.subscribe(TopicArn = topic_arn,
Protocol = 'http',
Endpoint = 'http://example.com/',
Attributes = {
'FilterPolicy': json.dumps({
'store': [{'error': True}]
})
})
except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameter')
err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Unrecognized match type error')
try:
conn.subscribe(TopicArn = topic_arn,
Protocol = 'http',
Endpoint = 'http://example.com/',
Attributes = {
'FilterPolicy': json.dumps({
'store': [1000000001]
})
})
except ClientError as err:
err.response['Error']['Code'].should.equal('InternalFailure')
@mock_sns
def test_check_not_opted_out():
    conn = boto3.client('sns', region_name='us-east-1')
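test_subscribe_invalid_filter_policy above expects four distinct InvalidParameter messages, plus an InternalFailure for an out-of-range number. A rough sketch of those first four checks, assuming the complexity limit is simply 150 values in total (the 151-value case above is what trips it); validate_filter_policy is a hypothetical name, and the out-of-range-number check is deliberately left out:
def validate_filter_policy(policy):
    values = [v for rules in policy.values() for v in rules]
    if len(values) > 150:
        raise ValueError('Invalid parameter: FilterPolicy: Filter policy is too complex')
    for value in values:
        if isinstance(value, dict):
            # keyword operators: only 'exists' is recognised, and it must be a bool
            for op, arg in value.items():
                if op == 'exists':
                    if not isinstance(arg, bool):
                        raise ValueError('Invalid parameter: FilterPolicy: '
                                         'exists match pattern must be either true or false.')
                else:
                    raise ValueError('Invalid parameter: FilterPolicy: '
                                     'Unrecognized match type ' + op)
        elif value is not None and not isinstance(value, (str, bool, int, float)):
            # nested lists and other objects are rejected as match values
            raise ValueError('Invalid parameter: FilterPolicy: '
                             'Match value must be String, number, true, false, or null')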

View File

@ -1117,6 +1117,28 @@ def test_redrive_policy_set_attributes():
    assert copy_policy == redrive_policy
@mock_sqs
def test_redrive_policy_set_attributes_with_string_value():
sqs = boto3.resource('sqs', region_name='us-east-1')
queue = sqs.create_queue(QueueName='test-queue')
deadletter_queue = sqs.create_queue(QueueName='test-deadletter')
queue.set_attributes(Attributes={
'RedrivePolicy': json.dumps({
'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'],
'maxReceiveCount': '1',
})})
copy = sqs.get_queue_by_name(QueueName='test-queue')
assert 'RedrivePolicy' in copy.attributes
copy_policy = json.loads(copy.attributes['RedrivePolicy'])
assert copy_policy == {
'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'],
'maxReceiveCount': 1,
}
@mock_sqs
def test_receive_messages_with_message_group_id():
    sqs = boto3.resource('sqs', region_name='us-east-1')
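test_redrive_policy_set_attributes_with_string_value above relies on maxReceiveCount being read back as an integer even when the caller supplied it as the string '1'. A tiny sketch of that normalisation, with normalize_redrive_policy as a hypothetical helper rather than moto's actual code:
import json
def normalize_redrive_policy(policy_json):
    # Coerce maxReceiveCount to int so '1' and 1 round-trip to the same policy.
    policy = json.loads(policy_json)
    policy['maxReceiveCount'] = int(policy['maxReceiveCount'])
    return policy
normalize_redrive_policy(
    '{"deadLetterTargetArn": "arn:aws:sqs:us-east-1:123456789012:test-deadletter",'
    ' "maxReceiveCount": "1"}')['maxReceiveCount']  # -> 1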

View File

@ -0,0 +1,378 @@
from __future__ import unicode_literals
import boto3
import sure # noqa
from datetime import datetime
from botocore.exceptions import ClientError
from nose.tools import assert_raises
from moto import mock_sts, mock_stepfunctions
region = 'us-east-1'
simple_definition = '{"Comment": "An example of the Amazon States Language using a choice state.",' \
'"StartAt": "DefaultState",' \
'"States": ' \
'{"DefaultState": {"Type": "Fail","Error": "DefaultStateError","Cause": "No Matches!"}}}'
account_id = None
@mock_stepfunctions
@mock_sts
def test_state_machine_creation_succeeds():
client = boto3.client('stepfunctions', region_name=region)
name = 'example_step_function'
#
response = client.create_state_machine(name=name,
definition=str(simple_definition),
roleArn=_get_default_role())
#
response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
response['creationDate'].should.be.a(datetime)
response['stateMachineArn'].should.equal('arn:aws:states:' + region + ':123456789012:stateMachine:' + name)
@mock_stepfunctions
def test_state_machine_creation_fails_with_invalid_names():
client = boto3.client('stepfunctions', region_name=region)
invalid_names = [
'with space',
'with<bracket', 'with>bracket', 'with{bracket', 'with}bracket', 'with[bracket', 'with]bracket',
'with?wildcard', 'with*wildcard',
'special"char', 'special#char', 'special%char', 'special\\char', 'special^char', 'special|char',
'special~char', 'special`char', 'special$char', 'special&char', 'special,char', 'special;char',
'special:char', 'special/char',
u'uni\u0000code', u'uni\u0001code', u'uni\u0002code', u'uni\u0003code', u'uni\u0004code',
u'uni\u0005code', u'uni\u0006code', u'uni\u0007code', u'uni\u0008code', u'uni\u0009code',
u'uni\u000Acode', u'uni\u000Bcode', u'uni\u000Ccode',
u'uni\u000Dcode', u'uni\u000Ecode', u'uni\u000Fcode',
u'uni\u0010code', u'uni\u0011code', u'uni\u0012code', u'uni\u0013code', u'uni\u0014code',
u'uni\u0015code', u'uni\u0016code', u'uni\u0017code', u'uni\u0018code', u'uni\u0019code',
u'uni\u001Acode', u'uni\u001Bcode', u'uni\u001Ccode',
u'uni\u001Dcode', u'uni\u001Ecode', u'uni\u001Fcode',
u'uni\u007Fcode',
u'uni\u0080code', u'uni\u0081code', u'uni\u0082code', u'uni\u0083code', u'uni\u0084code',
u'uni\u0085code', u'uni\u0086code', u'uni\u0087code', u'uni\u0088code', u'uni\u0089code',
u'uni\u008Acode', u'uni\u008Bcode', u'uni\u008Ccode',
u'uni\u008Dcode', u'uni\u008Ecode', u'uni\u008Fcode',
u'uni\u0090code', u'uni\u0091code', u'uni\u0092code', u'uni\u0093code', u'uni\u0094code',
u'uni\u0095code', u'uni\u0096code', u'uni\u0097code', u'uni\u0098code', u'uni\u0099code',
u'uni\u009Acode', u'uni\u009Bcode', u'uni\u009Ccode',
u'uni\u009Dcode', u'uni\u009Ecode', u'uni\u009Fcode']
#
for invalid_name in invalid_names:
with assert_raises(ClientError) as exc:
client.create_state_machine(name=invalid_name,
definition=str(simple_definition),
roleArn=_get_default_role())
@mock_stepfunctions
def test_state_machine_creation_requires_valid_role_arn():
client = boto3.client('stepfunctions', region_name=region)
name = 'example_step_function'
#
with assert_raises(ClientError) as exc:
client.create_state_machine(name=name,
definition=str(simple_definition),
roleArn='arn:aws:iam:1234:role/unknown_role')
@mock_stepfunctions
def test_state_machine_list_returns_empty_list_by_default():
client = boto3.client('stepfunctions', region_name=region)
#
    sm_list = client.list_state_machines()
    sm_list['stateMachines'].should.be.empty
@mock_stepfunctions
@mock_sts
def test_state_machine_list_returns_created_state_machines():
client = boto3.client('stepfunctions', region_name=region)
#
machine2 = client.create_state_machine(name='name2',
definition=str(simple_definition),
roleArn=_get_default_role())
machine1 = client.create_state_machine(name='name1',
definition=str(simple_definition),
roleArn=_get_default_role(),
tags=[{'key': 'tag_key', 'value': 'tag_value'}])
    sm_list = client.list_state_machines()
    #
    sm_list['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
    sm_list['stateMachines'].should.have.length_of(2)
    sm_list['stateMachines'][0]['creationDate'].should.be.a(datetime)
    sm_list['stateMachines'][0]['creationDate'].should.equal(machine1['creationDate'])
    sm_list['stateMachines'][0]['name'].should.equal('name1')
    sm_list['stateMachines'][0]['stateMachineArn'].should.equal(machine1['stateMachineArn'])
    sm_list['stateMachines'][1]['creationDate'].should.be.a(datetime)
    sm_list['stateMachines'][1]['creationDate'].should.equal(machine2['creationDate'])
    sm_list['stateMachines'][1]['name'].should.equal('name2')
    sm_list['stateMachines'][1]['stateMachineArn'].should.equal(machine2['stateMachineArn'])
@mock_stepfunctions
@mock_sts
def test_state_machine_creation_is_idempotent_by_name():
client = boto3.client('stepfunctions', region_name=region)
#
client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
sm_list = client.list_state_machines()
sm_list['stateMachines'].should.have.length_of(1)
#
client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
sm_list = client.list_state_machines()
sm_list['stateMachines'].should.have.length_of(1)
#
client.create_state_machine(name='diff_name', definition=str(simple_definition), roleArn=_get_default_role())
sm_list = client.list_state_machines()
sm_list['stateMachines'].should.have.length_of(2)
@mock_stepfunctions
@mock_sts
def test_state_machine_creation_can_be_described():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
desc = client.describe_state_machine(stateMachineArn=sm['stateMachineArn'])
desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
desc['creationDate'].should.equal(sm['creationDate'])
desc['definition'].should.equal(str(simple_definition))
desc['name'].should.equal('name')
desc['roleArn'].should.equal(_get_default_role())
desc['stateMachineArn'].should.equal(sm['stateMachineArn'])
desc['status'].should.equal('ACTIVE')
@mock_stepfunctions
@mock_sts
def test_state_machine_throws_error_when_describing_unknown_machine():
client = boto3.client('stepfunctions', region_name=region)
#
with assert_raises(ClientError) as exc:
unknown_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown'
client.describe_state_machine(stateMachineArn=unknown_state_machine)
@mock_stepfunctions
@mock_sts
def test_state_machine_throws_error_when_describing_machine_in_different_account():
client = boto3.client('stepfunctions', region_name=region)
#
with assert_raises(ClientError) as exc:
unknown_state_machine = 'arn:aws:states:' + region + ':000000000000:stateMachine:unknown'
client.describe_state_machine(stateMachineArn=unknown_state_machine)
@mock_stepfunctions
@mock_sts
def test_state_machine_can_be_deleted():
client = boto3.client('stepfunctions', region_name=region)
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
#
response = client.delete_state_machine(stateMachineArn=sm['stateMachineArn'])
response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
#
sm_list = client.list_state_machines()
sm_list['stateMachines'].should.have.length_of(0)
@mock_stepfunctions
@mock_sts
def test_state_machine_can_delete_nonexisting_machine():
client = boto3.client('stepfunctions', region_name=region)
#
unknown_state_machine = 'arn:aws:states:' + region + ':123456789012:stateMachine:unknown'
response = client.delete_state_machine(stateMachineArn=unknown_state_machine)
response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
#
sm_list = client.list_state_machines()
sm_list['stateMachines'].should.have.length_of(0)
@mock_stepfunctions
@mock_sts
def test_state_machine_list_tags_for_created_machine():
client = boto3.client('stepfunctions', region_name=region)
#
machine = client.create_state_machine(name='name1',
definition=str(simple_definition),
roleArn=_get_default_role(),
tags=[{'key': 'tag_key', 'value': 'tag_value'}])
response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn'])
tags = response['tags']
tags.should.have.length_of(1)
tags[0].should.equal({'key': 'tag_key', 'value': 'tag_value'})
@mock_stepfunctions
@mock_sts
def test_state_machine_list_tags_for_machine_without_tags():
client = boto3.client('stepfunctions', region_name=region)
#
machine = client.create_state_machine(name='name1',
definition=str(simple_definition),
roleArn=_get_default_role())
response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn'])
tags = response['tags']
tags.should.have.length_of(0)
@mock_stepfunctions
@mock_sts
def test_state_machine_list_tags_for_nonexisting_machine():
client = boto3.client('stepfunctions', region_name=region)
#
non_existing_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown'
response = client.list_tags_for_resource(resourceArn=non_existing_state_machine)
tags = response['tags']
tags.should.have.length_of(0)
@mock_stepfunctions
@mock_sts
def test_state_machine_start_execution():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
execution = client.start_execution(stateMachineArn=sm['stateMachineArn'])
#
execution['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
expected_exec_name = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:name:[a-zA-Z0-9-]+'
execution['executionArn'].should.match(expected_exec_name)
execution['startDate'].should.be.a(datetime)
@mock_stepfunctions
@mock_sts
def test_state_machine_list_executions():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
execution = client.start_execution(stateMachineArn=sm['stateMachineArn'])
execution_arn = execution['executionArn']
execution_name = execution_arn[execution_arn.rindex(':')+1:]
executions = client.list_executions(stateMachineArn=sm['stateMachineArn'])
#
executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
executions['executions'].should.have.length_of(1)
executions['executions'][0]['executionArn'].should.equal(execution_arn)
executions['executions'][0]['name'].should.equal(execution_name)
executions['executions'][0]['startDate'].should.equal(execution['startDate'])
executions['executions'][0]['stateMachineArn'].should.equal(sm['stateMachineArn'])
executions['executions'][0]['status'].should.equal('RUNNING')
executions['executions'][0].shouldnt.have('stopDate')
@mock_stepfunctions
@mock_sts
def test_state_machine_list_executions_when_none_exist():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
executions = client.list_executions(stateMachineArn=sm['stateMachineArn'])
#
executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
executions['executions'].should.have.length_of(0)
@mock_stepfunctions
@mock_sts
def test_state_machine_describe_execution():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
execution = client.start_execution(stateMachineArn=sm['stateMachineArn'])
description = client.describe_execution(executionArn=execution['executionArn'])
#
description['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
description['executionArn'].should.equal(execution['executionArn'])
description['input'].should.equal("{}")
description['name'].shouldnt.be.empty
description['startDate'].should.equal(execution['startDate'])
description['stateMachineArn'].should.equal(sm['stateMachineArn'])
description['status'].should.equal('RUNNING')
description.shouldnt.have('stopDate')
@mock_stepfunctions
@mock_sts
def test_execution_throws_error_when_describing_unknown_execution():
client = boto3.client('stepfunctions', region_name=region)
#
with assert_raises(ClientError) as exc:
unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown'
client.describe_execution(executionArn=unknown_execution)
@mock_stepfunctions
@mock_sts
def test_state_machine_can_be_described_by_execution():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
execution = client.start_execution(stateMachineArn=sm['stateMachineArn'])
desc = client.describe_state_machine_for_execution(executionArn=execution['executionArn'])
desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
desc['definition'].should.equal(str(simple_definition))
desc['name'].should.equal('name')
desc['roleArn'].should.equal(_get_default_role())
desc['stateMachineArn'].should.equal(sm['stateMachineArn'])
@mock_stepfunctions
@mock_sts
def test_state_machine_throws_error_when_describing_unknown_execution():
client = boto3.client('stepfunctions', region_name=region)
#
with assert_raises(ClientError) as exc:
unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown'
client.describe_state_machine_for_execution(executionArn=unknown_execution)
@mock_stepfunctions
@mock_sts
def test_state_machine_stop_execution():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
start = client.start_execution(stateMachineArn=sm['stateMachineArn'])
stop = client.stop_execution(executionArn=start['executionArn'])
#
stop['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
stop['stopDate'].should.be.a(datetime)
@mock_stepfunctions
@mock_sts
def test_state_machine_describe_execution_after_stoppage():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
execution = client.start_execution(stateMachineArn=sm['stateMachineArn'])
client.stop_execution(executionArn=execution['executionArn'])
description = client.describe_execution(executionArn=execution['executionArn'])
#
description['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
description['status'].should.equal('SUCCEEDED')
description['stopDate'].should.be.a(datetime)
def _get_account_id():
global account_id
if account_id:
return account_id
sts = boto3.client("sts")
identity = sts.get_caller_identity()
account_id = identity['Account']
return account_id
def _get_default_role():
return 'arn:aws:iam:' + _get_account_id() + ':role/unknown_sf_role'
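For reference, the executionArn asserted in test_state_machine_start_execution follows the shape arn:aws:states:<region>:<account>:execution:<state machine name>:<generated name>. A hypothetical illustration of that shape only; the uuid4 suffix is an assumption, not necessarily what moto generates:
import uuid
def make_execution_arn(region, account_id, state_machine_name):
    # arn:aws:states:<region>:<account>:execution:<machine name>:<generated name>
    return 'arn:aws:states:{}:{}:execution:{}:{}'.format(
        region, account_id, state_machine_name, uuid.uuid4())
make_execution_arn('us-east-1', '123456789012', 'name')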