Merge pull request #2 from spulec/master

Upstream changes
This commit is contained in:
Bert Blommers 2019-10-03 09:18:17 +01:00 committed by GitHub
commit 82ce5ad430
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
66 changed files with 3740 additions and 790 deletions

View File

@ -1237,7 +1237,7 @@
- [ ] delete_identities
- [ ] delete_identity_pool
- [ ] describe_identity
- [ ] describe_identity_pool
- [X] describe_identity_pool
- [X] get_credentials_for_identity
- [X] get_id
- [ ] get_identity_pool_roles
@ -3801,14 +3801,14 @@
- [ ] update_stream
## kms
41% implemented
54% implemented
- [X] cancel_key_deletion
- [ ] connect_custom_key_store
- [ ] create_alias
- [ ] create_custom_key_store
- [ ] create_grant
- [X] create_key
- [ ] decrypt
- [X] decrypt
- [X] delete_alias
- [ ] delete_custom_key_store
- [ ] delete_imported_key_material
@ -3819,10 +3819,10 @@
- [ ] disconnect_custom_key_store
- [X] enable_key
- [X] enable_key_rotation
- [ ] encrypt
- [X] encrypt
- [X] generate_data_key
- [ ] generate_data_key_without_plaintext
- [ ] generate_random
- [X] generate_data_key_without_plaintext
- [X] generate_random
- [X] get_key_policy
- [X] get_key_rotation_status
- [ ] get_parameters_for_import
@ -3834,7 +3834,7 @@
- [X] list_resource_tags
- [ ] list_retirable_grants
- [X] put_key_policy
- [ ] re_encrypt
- [X] re_encrypt
- [ ] retire_grant
- [ ] revoke_grant
- [X] schedule_key_deletion
@ -6050,24 +6050,24 @@
## stepfunctions
0% implemented
- [ ] create_activity
- [ ] create_state_machine
- [X] create_state_machine
- [ ] delete_activity
- [ ] delete_state_machine
- [X] delete_state_machine
- [ ] describe_activity
- [ ] describe_execution
- [ ] describe_state_machine
- [ ] describe_state_machine_for_execution
- [X] describe_execution
- [X] describe_state_machine
- [x] describe_state_machine_for_execution
- [ ] get_activity_task
- [ ] get_execution_history
- [ ] list_activities
- [ ] list_executions
- [ ] list_state_machines
- [ ] list_tags_for_resource
- [X] list_executions
- [X] list_state_machines
- [X] list_tags_for_resource
- [ ] send_task_failure
- [ ] send_task_heartbeat
- [ ] send_task_success
- [ ] start_execution
- [ ] stop_execution
- [X] start_execution
- [X] stop_execution
- [ ] tag_resource
- [ ] untag_resource
- [ ] update_state_machine

View File

@ -94,6 +94,8 @@ Currently implemented Services:
+---------------------------+-----------------------+------------------------------------+
| SES | @mock_ses | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SFN | @mock_stepfunctions | basic endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SNS | @mock_sns | all endpoints done |
+---------------------------+-----------------------+------------------------------------+
| SQS | @mock_sqs | core endpoints done |

View File

@ -42,6 +42,7 @@ from .ses import mock_ses, mock_ses_deprecated # flake8: noqa
from .secretsmanager import mock_secretsmanager # flake8: noqa
from .sns import mock_sns, mock_sns_deprecated # flake8: noqa
from .sqs import mock_sqs, mock_sqs_deprecated # flake8: noqa
from .stepfunctions import mock_stepfunctions # flake8: noqa
from .sts import mock_sts, mock_sts_deprecated # flake8: noqa
from .ssm import mock_ssm # flake8: noqa
from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa

View File

@ -298,7 +298,7 @@ class Stage(BaseModel, dict):
class ApiKey(BaseModel, dict):
def __init__(self, name=None, description=None, enabled=True,
generateDistinctId=False, value=None, stageKeys=None, customerId=None):
generateDistinctId=False, value=None, stageKeys=None, tags=None, customerId=None):
super(ApiKey, self).__init__()
self['id'] = create_id()
self['value'] = value if value else ''.join(random.sample(string.ascii_letters + string.digits, 40))
@ -308,6 +308,7 @@ class ApiKey(BaseModel, dict):
self['enabled'] = enabled
self['createdDate'] = self['lastUpdatedDate'] = int(time.time())
self['stageKeys'] = stageKeys
self['tags'] = tags
def update_operations(self, patch_operations):
for op in patch_operations:
@ -407,10 +408,16 @@ class RestAPI(BaseModel):
stage_url_upper = STAGE_URL.format(api_id=self.id.upper(),
region_name=self.region_name, stage_name=stage_name)
responses.add_callback(responses.GET, stage_url_lower,
callback=self.resource_callback)
responses.add_callback(responses.GET, stage_url_upper,
callback=self.resource_callback)
for url in [stage_url_lower, stage_url_upper]:
responses._default_mock._matches.insert(0,
responses.CallbackResponse(
url=url,
method=responses.GET,
callback=self.resource_callback,
content_type="text/plain",
match_querystring=False,
)
)
def create_stage(self, name, deployment_id, variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None):
if variables is None:

View File

@ -40,6 +40,7 @@ from moto.secretsmanager import secretsmanager_backends
from moto.sns import sns_backends
from moto.sqs import sqs_backends
from moto.ssm import ssm_backends
from moto.stepfunctions import stepfunction_backends
from moto.sts import sts_backends
from moto.swf import swf_backends
from moto.xray import xray_backends
@ -91,6 +92,7 @@ BACKENDS = {
'sns': sns_backends,
'sqs': sqs_backends,
'ssm': ssm_backends,
'stepfunctions': stepfunction_backends,
'sts': sts_backends,
'swf': swf_backends,
'route53': route53_backends,

View File

@ -0,0 +1,15 @@
from __future__ import unicode_literals
import json
from werkzeug.exceptions import BadRequest
class ResourceNotFoundError(BadRequest):
def __init__(self, message):
super(ResourceNotFoundError, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'ResourceNotFoundException',
})

View File

@ -8,7 +8,7 @@ import boto.cognito.identity
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds
from .exceptions import ResourceNotFoundError
from .utils import get_random_identity_id
@ -39,10 +39,29 @@ class CognitoIdentityBackend(BaseBackend):
self.__dict__ = {}
self.__init__(region)
def describe_identity_pool(self, identity_pool_id):
identity_pool = self.identity_pools.get(identity_pool_id, None)
if not identity_pool:
raise ResourceNotFoundError(identity_pool)
response = json.dumps({
'AllowUnauthenticatedIdentities': identity_pool.allow_unauthenticated_identities,
'CognitoIdentityProviders': identity_pool.cognito_identity_providers,
'DeveloperProviderName': identity_pool.developer_provider_name,
'IdentityPoolId': identity_pool.identity_pool_id,
'IdentityPoolName': identity_pool.identity_pool_name,
'IdentityPoolTags': {},
'OpenIdConnectProviderARNs': identity_pool.open_id_connect_provider_arns,
'SamlProviderARNs': identity_pool.saml_provider_arns,
'SupportedLoginProviders': identity_pool.supported_login_providers
})
return response
def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities,
supported_login_providers, developer_provider_name, open_id_connect_provider_arns,
cognito_identity_providers, saml_provider_arns):
new_identity = CognitoIdentity(self.region, identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities,
supported_login_providers=supported_login_providers,

View File

@ -1,7 +1,6 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from .models import cognitoidentity_backends
from .utils import get_random_identity_id
@ -16,6 +15,7 @@ class CognitoIdentityResponse(BaseResponse):
open_id_connect_provider_arns = self._get_param('OpenIdConnectProviderARNs')
cognito_identity_providers = self._get_param('CognitoIdentityProviders')
saml_provider_arns = self._get_param('SamlProviderARNs')
return cognitoidentity_backends[self.region].create_identity_pool(
identity_pool_name=identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities,
@ -28,6 +28,9 @@ class CognitoIdentityResponse(BaseResponse):
def get_id(self):
return cognitoidentity_backends[self.region].get_id()
def describe_identity_pool(self):
return cognitoidentity_backends[self.region].describe_identity_pool(self._get_param('IdentityPoolId'))
def get_credentials_for_identity(self):
return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId'))

View File

@ -14,7 +14,7 @@ SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
"""
ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
<Response>
<ErrorResponse>
<Errors>
<Error>
<Code>{{error_type}}</Code>
@ -23,7 +23,7 @@ ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</Error>
</Errors>
<RequestID>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestID>
</Response>
</ErrorResponse>
"""
ERROR_JSON_RESPONSE = u"""{

View File

@ -197,53 +197,9 @@ class CallbackResponse(responses.CallbackResponse):
botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send')
responses_mock = responses._default_mock
class ResponsesMockAWS(BaseMockAWS):
def reset(self):
botocore_mock.reset()
responses_mock.reset()
def enable_patching(self):
if not hasattr(botocore_mock, '_patcher') or not hasattr(botocore_mock._patcher, 'target'):
# Check for unactivated patcher
botocore_mock.start()
if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'):
responses_mock.start()
for method in RESPONSES_METHODS:
for backend in self.backends_for_urls.values():
for key, value in backend.urls.items():
responses_mock.add(
CallbackResponse(
method=method,
url=re.compile(key),
callback=convert_flask_to_responses_response(value),
stream=True,
match_querystring=False,
)
)
botocore_mock.add(
CallbackResponse(
method=method,
url=re.compile(key),
callback=convert_flask_to_responses_response(value),
stream=True,
match_querystring=False,
)
)
def disable_patching(self):
try:
botocore_mock.stop()
except RuntimeError:
pass
try:
responses_mock.stop()
except RuntimeError:
pass
# Add passthrough to allow any other requests to work
# Since this uses .startswith, it applies to http and https requests.
responses_mock.add_passthru("http")
BOTOCORE_HTTP_METHODS = [
@ -310,6 +266,14 @@ botocore_stubber = BotocoreStubber()
BUILTIN_HANDLERS.append(('before-send', botocore_stubber))
def not_implemented_callback(request):
status = 400
headers = {}
response = "The method is not implemented"
return status, headers, response
class BotocoreEventMockAWS(BaseMockAWS):
def reset(self):
botocore_stubber.reset()
@ -339,6 +303,24 @@ class BotocoreEventMockAWS(BaseMockAWS):
match_querystring=False,
)
)
responses_mock.add(
CallbackResponse(
method=method,
url=re.compile("https?://.+.amazonaws.com/.*"),
callback=not_implemented_callback,
stream=True,
match_querystring=False,
)
)
botocore_mock.add(
CallbackResponse(
method=method,
url=re.compile("https?://.+.amazonaws.com/.*"),
callback=not_implemented_callback,
stream=True,
match_querystring=False,
)
)
def disable_patching(self):
botocore_stubber.enabled = False

View File

@ -941,8 +941,7 @@ class OpAnd(Op):
def expr(self, item):
lhs = self.lhs.expr(item)
rhs = self.rhs.expr(item)
return lhs and rhs
return lhs and self.rhs.expr(item)
class OpLessThan(Op):

View File

@ -363,7 +363,7 @@ class StreamRecord(BaseModel):
'dynamodb': {
'StreamViewType': stream_type,
'ApproximateCreationDateTime': datetime.datetime.utcnow().isoformat(),
'SequenceNumber': seq,
'SequenceNumber': str(seq),
'SizeBytes': 1,
'Keys': keys
}

View File

@ -356,9 +356,18 @@ class DynamoHandler(BaseResponse):
if projection_expression and expression_attribute_names:
expressions = [x.strip() for x in projection_expression.split(',')]
projection_expression = None
for expression in expressions:
if projection_expression is not None:
projection_expression = projection_expression + ", "
else:
projection_expression = ""
if expression in expression_attribute_names:
projection_expression = projection_expression.replace(expression, expression_attribute_names[expression])
projection_expression = projection_expression + \
expression_attribute_names[expression]
else:
projection_expression = projection_expression + expression
filter_kwargs = {}

View File

@ -39,7 +39,7 @@ class ShardIterator(BaseModel):
def get(self, limit=1000):
items = self.stream_shard.get(self.sequence_number, limit)
try:
last_sequence_number = max(i['dynamodb']['SequenceNumber'] for i in items)
last_sequence_number = max(int(i['dynamodb']['SequenceNumber']) for i in items)
new_shard_iterator = ShardIterator(self.streams_backend,
self.stream_shard,
'AFTER_SEQUENCE_NUMBER',

View File

@ -3,6 +3,7 @@ from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from .models import dynamodbstreams_backends
from six import string_types
class DynamoDBStreamsHandler(BaseResponse):
@ -23,8 +24,13 @@ class DynamoDBStreamsHandler(BaseResponse):
arn = self._get_param('StreamArn')
shard_id = self._get_param('ShardId')
shard_iterator_type = self._get_param('ShardIteratorType')
sequence_number = self._get_param('SequenceNumber')
# according to documentation sequence_number param should be string
if isinstance(sequence_number, string_types):
sequence_number = int(sequence_number)
return self.backend.get_shard_iterator(arn, shard_id,
shard_iterator_type)
shard_iterator_type, sequence_number)
def get_records(self):
arn = self._get_param('ShardIterator')

View File

@ -6,6 +6,7 @@ from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import filters_from_querystring, \
dict_from_querystring
from moto.elbv2 import elbv2_backends
class InstanceResponse(BaseResponse):
@ -68,6 +69,7 @@ class InstanceResponse(BaseResponse):
if self.is_not_dryrun('TerminateInstance'):
instances = self.ec2_backend.terminate_instances(instance_ids)
autoscaling_backends[self.region].notify_terminate_instances(instance_ids)
elbv2_backends[self.region].notify_terminate_instances(instance_ids)
template = self.response_template(EC2_TERMINATE_INSTANCES)
return template.render(instances=instances)

View File

@ -131,7 +131,7 @@ class InvalidActionTypeError(ELBClientError):
def __init__(self, invalid_name, index):
super(InvalidActionTypeError, self).__init__(
"ValidationError",
"1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect]" % (invalid_name, index)
"1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect, fixed-response]" % (invalid_name, index)
)
@ -190,3 +190,18 @@ class InvalidModifyRuleArgumentsError(ELBClientError):
"ValidationError",
"Either conditions or actions must be specified"
)
class InvalidStatusCodeActionTypeError(ELBClientError):
def __init__(self, msg):
super(InvalidStatusCodeActionTypeError, self).__init__(
"ValidationError", msg
)
class InvalidLoadBalancerActionException(ELBClientError):
def __init__(self, msg):
super(InvalidLoadBalancerActionException, self).__init__(
"InvalidLoadBalancerAction", msg
)

View File

@ -3,10 +3,11 @@ from __future__ import unicode_literals
import datetime
import re
from jinja2 import Template
from botocore.exceptions import ParamValidationError
from moto.compat import OrderedDict
from moto.core.exceptions import RESTError
from moto.core import BaseBackend, BaseModel
from moto.core.utils import camelcase_to_underscores
from moto.core.utils import camelcase_to_underscores, underscores_to_camelcase
from moto.ec2.models import ec2_backends
from moto.acm.models import acm_backends
from .utils import make_arn_for_target_group
@ -31,8 +32,8 @@ from .exceptions import (
RuleNotFoundError,
DuplicatePriorityError,
InvalidTargetGroupNameError,
InvalidModifyRuleArgumentsError
)
InvalidModifyRuleArgumentsError,
InvalidStatusCodeActionTypeError, InvalidLoadBalancerActionException)
class FakeHealthStatus(BaseModel):
@ -110,6 +111,11 @@ class FakeTargetGroup(BaseModel):
if not t:
raise InvalidTargetError()
def deregister_terminated_instances(self, instance_ids):
for target_id in list(self.targets.keys()):
if target_id in instance_ids:
del self.targets[target_id]
def add_tag(self, key, value):
if len(self.tags) >= 10 and key not in self.tags:
raise TooManyTagsError()
@ -215,9 +221,9 @@ class FakeListener(BaseModel):
action_type = action['Type']
if action_type == 'forward':
default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']})
elif action_type in ['redirect', 'authenticate-cognito']:
elif action_type in ['redirect', 'authenticate-cognito', 'fixed-response']:
redirect_action = {'type': action_type}
key = 'RedirectConfig' if action_type == 'redirect' else 'AuthenticateCognitoConfig'
key = underscores_to_camelcase(action_type.capitalize().replace('-', '_')) + 'Config'
for redirect_config_key, redirect_config_value in action[key].items():
# need to match the output of _get_list_prefix
redirect_action[camelcase_to_underscores(key) + '._' + camelcase_to_underscores(redirect_config_key)] = redirect_config_value
@ -253,6 +259,12 @@ class FakeAction(BaseModel):
<UserPoolClientId>{{ action.data["authenticate_cognito_config._user_pool_client_id"] }}</UserPoolClientId>
<UserPoolDomain>{{ action.data["authenticate_cognito_config._user_pool_domain"] }}</UserPoolDomain>
</AuthenticateCognitoConfig>
{% elif action.type == "fixed-response" %}
<FixedResponseConfig>
<ContentType>{{ action.data["fixed_response_config._content_type"] }}</ContentType>
<MessageBody>{{ action.data["fixed_response_config._message_body"] }}</MessageBody>
<StatusCode>{{ action.data["fixed_response_config._status_code"] }}</StatusCode>
</FixedResponseConfig>
{% endif %}
""")
return template.render(action=self)
@ -477,11 +489,30 @@ class ELBv2Backend(BaseBackend):
action_target_group_arn = action.data['target_group_arn']
if action_target_group_arn not in target_group_arns:
raise ActionTargetGroupNotFoundError(action_target_group_arn)
elif action_type == 'fixed-response':
self._validate_fixed_response_action(action, i, index)
elif action_type in ['redirect', 'authenticate-cognito']:
pass
else:
raise InvalidActionTypeError(action_type, index)
def _validate_fixed_response_action(self, action, i, index):
status_code = action.data.get('fixed_response_config._status_code')
if status_code is None:
raise ParamValidationError(
report='Missing required parameter in Actions[%s].FixedResponseConfig: "StatusCode"' % i)
if not re.match(r'^(2|4|5)\d\d$', status_code):
raise InvalidStatusCodeActionTypeError(
"1 validation error detected: Value '%s' at 'actions.%s.member.fixedResponseConfig.statusCode' failed to satisfy constraint: \
Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, index)
)
content_type = action.data['fixed_response_config._content_type']
if content_type and content_type not in ['text/plain', 'text/css', 'text/html', 'application/javascript',
'application/json']:
raise InvalidLoadBalancerActionException(
"The ContentType must be one of:'text/html', 'application/json', 'application/javascript', 'text/css', 'text/plain'"
)
def create_target_group(self, name, **kwargs):
if len(name) > 32:
raise InvalidTargetGroupNameError(
@ -936,6 +967,10 @@ class ELBv2Backend(BaseBackend):
return True
return False
def notify_terminate_instances(self, instance_ids):
for target_group in self.target_groups.values():
target_group.deregister_terminated_instances(instance_ids)
elbv2_backends = {}
for region in ec2_backends.keys():

View File

@ -152,8 +152,10 @@ class IAMPolicyDocumentValidator:
sids = []
for statement in self._statements:
if "Sid" in statement:
assert statement["Sid"] not in sids
sids.append(statement["Sid"])
statementId = statement["Sid"]
if statementId:
assert statementId not in sids
sids.append(statementId)
def _validate_statements_syntax(self):
assert "Statement" in self._policy_json

View File

@ -21,3 +21,11 @@ class InvalidRequestException(IoTDataPlaneClientError):
super(InvalidRequestException, self).__init__(
"InvalidRequestException", message
)
class ConflictException(IoTDataPlaneClientError):
def __init__(self, message):
self.code = 409
super(ConflictException, self).__init__(
"ConflictException", message
)

View File

@ -6,6 +6,7 @@ import jsondiff
from moto.core import BaseBackend, BaseModel
from moto.iot import iot_backends
from .exceptions import (
ConflictException,
ResourceNotFoundException,
InvalidRequestException
)
@ -161,6 +162,8 @@ class IoTDataPlaneBackend(BaseBackend):
if any(_ for _ in payload['state'].keys() if _ not in ['desired', 'reported']):
raise InvalidRequestException('State contains an invalid node')
if 'version' in payload and thing.thing_shadow.version != payload['version']:
raise ConflictException('Version conflict')
new_shadow = FakeShadow.create_from_previous_version(thing.thing_shadow, payload)
thing.thing_shadow = new_shadow
return thing.thing_shadow

View File

@ -34,3 +34,23 @@ class NotAuthorizedException(JsonRESTError):
"NotAuthorizedException", None)
self.description = '{"__type":"NotAuthorizedException"}'
class AccessDeniedException(JsonRESTError):
code = 400
def __init__(self, message):
super(AccessDeniedException, self).__init__(
"AccessDeniedException", message)
self.description = '{"__type":"AccessDeniedException"}'
class InvalidCiphertextException(JsonRESTError):
code = 400
def __init__(self):
super(InvalidCiphertextException, self).__init__(
"InvalidCiphertextException", None)
self.description = '{"__type":"InvalidCiphertextException"}'

View File

@ -1,16 +1,18 @@
from __future__ import unicode_literals
import os
import boto.kms
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from .utils import generate_key_id
from collections import defaultdict
from datetime import datetime, timedelta
import boto.kms
from moto.core import BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_without_milliseconds
from .utils import decrypt, encrypt, generate_key_id, generate_master_key
class Key(BaseModel):
def __init__(self, policy, key_usage, description, tags, region):
self.id = generate_key_id()
self.policy = policy
@ -19,10 +21,11 @@ class Key(BaseModel):
self.description = description
self.enabled = True
self.region = region
self.account_id = "0123456789012"
self.account_id = "012345678912"
self.key_rotation_status = False
self.deletion_date = None
self.tags = tags or {}
self.key_material = generate_master_key()
@property
def physical_resource_id(self):
@ -45,8 +48,8 @@ class Key(BaseModel):
"KeyState": self.key_state,
}
}
if self.key_state == 'PendingDeletion':
key_dict['KeyMetadata']['DeletionDate'] = iso_8601_datetime_without_milliseconds(self.deletion_date)
if self.key_state == "PendingDeletion":
key_dict["KeyMetadata"]["DeletionDate"] = iso_8601_datetime_without_milliseconds(self.deletion_date)
return key_dict
def delete(self, region_name):
@ -55,28 +58,28 @@ class Key(BaseModel):
@classmethod
def create_from_cloudformation_json(self, resource_name, cloudformation_json, region_name):
kms_backend = kms_backends[region_name]
properties = cloudformation_json['Properties']
properties = cloudformation_json["Properties"]
key = kms_backend.create_key(
policy=properties['KeyPolicy'],
key_usage='ENCRYPT_DECRYPT',
description=properties['Description'],
tags=properties.get('Tags'),
policy=properties["KeyPolicy"],
key_usage="ENCRYPT_DECRYPT",
description=properties["Description"],
tags=properties.get("Tags"),
region=region_name,
)
key.key_rotation_status = properties['EnableKeyRotation']
key.enabled = properties['Enabled']
key.key_rotation_status = properties["EnableKeyRotation"]
key.enabled = properties["Enabled"]
return key
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == 'Arn':
if attribute_name == "Arn":
return self.arn
raise UnformattedGetAttTemplateException()
class KmsBackend(BaseBackend):
def __init__(self):
self.keys = {}
self.key_to_aliases = defaultdict(set)
@ -109,16 +112,43 @@ class KmsBackend(BaseBackend):
# allow the different methods (alias, ARN :key/, keyId, ARN alias) to
# describe key not just KeyId
key_id = self.get_key_id(key_id)
if r'alias/' in str(key_id).lower():
key_id = self.get_key_id_from_alias(key_id.split('alias/')[1])
if r"alias/" in str(key_id).lower():
key_id = self.get_key_id_from_alias(key_id.split("alias/")[1])
return self.keys[self.get_key_id(key_id)]
def list_keys(self):
return self.keys.values()
def get_key_id(self, key_id):
@staticmethod
def get_key_id(key_id):
# Allow use of ARN as well as pure KeyId
return str(key_id).split(r':key/')[1] if r':key/' in str(key_id).lower() else key_id
if key_id.startswith("arn:") and ":key/" in key_id:
return key_id.split(":key/")[1]
return key_id
@staticmethod
def get_alias_name(alias_name):
# Allow use of ARN as well as alias name
if alias_name.startswith("arn:") and ":alias/" in alias_name:
return alias_name.split(":alias/")[1]
return alias_name
def any_id_to_key_id(self, key_id):
"""Go from any valid key ID to the raw key ID.
Acceptable inputs:
- raw key ID
- key ARN
- alias name
- alias ARN
"""
key_id = self.get_alias_name(key_id)
key_id = self.get_key_id(key_id)
if key_id.startswith("alias/"):
key_id = self.get_key_id_from_alias(key_id)
return key_id
def alias_exists(self, alias_name):
for aliases in self.key_to_aliases.values():
@ -162,37 +192,69 @@ class KmsBackend(BaseBackend):
def disable_key(self, key_id):
self.keys[key_id].enabled = False
self.keys[key_id].key_state = 'Disabled'
self.keys[key_id].key_state = "Disabled"
def enable_key(self, key_id):
self.keys[key_id].enabled = True
self.keys[key_id].key_state = 'Enabled'
self.keys[key_id].key_state = "Enabled"
def cancel_key_deletion(self, key_id):
self.keys[key_id].key_state = 'Disabled'
self.keys[key_id].key_state = "Disabled"
self.keys[key_id].deletion_date = None
def schedule_key_deletion(self, key_id, pending_window_in_days):
if 7 <= pending_window_in_days <= 30:
self.keys[key_id].enabled = False
self.keys[key_id].key_state = 'PendingDeletion'
self.keys[key_id].key_state = "PendingDeletion"
self.keys[key_id].deletion_date = datetime.now() + timedelta(days=pending_window_in_days)
return iso_8601_datetime_without_milliseconds(self.keys[key_id].deletion_date)
def encrypt(self, key_id, plaintext, encryption_context):
key_id = self.any_id_to_key_id(key_id)
ciphertext_blob = encrypt(
master_keys=self.keys, key_id=key_id, plaintext=plaintext, encryption_context=encryption_context
)
arn = self.keys[key_id].arn
return ciphertext_blob, arn
def decrypt(self, ciphertext_blob, encryption_context):
plaintext, key_id = decrypt(
master_keys=self.keys, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context
)
arn = self.keys[key_id].arn
return plaintext, arn
def re_encrypt(
self, ciphertext_blob, source_encryption_context, destination_key_id, destination_encryption_context
):
destination_key_id = self.any_id_to_key_id(destination_key_id)
plaintext, decrypting_arn = self.decrypt(
ciphertext_blob=ciphertext_blob, encryption_context=source_encryption_context
)
new_ciphertext_blob, encrypting_arn = self.encrypt(
key_id=destination_key_id, plaintext=plaintext, encryption_context=destination_encryption_context
)
return new_ciphertext_blob, decrypting_arn, encrypting_arn
def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens):
key = self.keys[self.get_key_id(key_id)]
key_id = self.any_id_to_key_id(key_id)
if key_spec:
if key_spec == 'AES_128':
bytes = 16
# Note: Actual validation of key_spec is done in kms.responses
if key_spec == "AES_128":
plaintext_len = 16
else:
bytes = 32
plaintext_len = 32
else:
bytes = number_of_bytes
plaintext_len = number_of_bytes
plaintext = os.urandom(bytes)
plaintext = os.urandom(plaintext_len)
return plaintext, key.arn
ciphertext_blob, arn = self.encrypt(key_id=key_id, plaintext=plaintext, encryption_context=encryption_context)
return plaintext, ciphertext_blob, arn
kms_backends = {}

View File

@ -2,13 +2,16 @@ from __future__ import unicode_literals
import base64
import json
import os
import re
import six
from moto.core.responses import BaseResponse
from .models import kms_backends
from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException
ACCOUNT_ID = "012345678912"
reserved_aliases = [
'alias/aws/ebs',
'alias/aws/s3',
@ -21,13 +24,86 @@ class KmsResponse(BaseResponse):
@property
def parameters(self):
return json.loads(self.body)
params = json.loads(self.body)
for key in ("Plaintext", "CiphertextBlob"):
if key in params:
params[key] = base64.b64decode(params[key].encode("utf-8"))
return params
@property
def kms_backend(self):
return kms_backends[self.region]
def _display_arn(self, key_id):
if key_id.startswith("arn:"):
return key_id
if key_id.startswith("alias/"):
id_type = ""
else:
id_type = "key/"
return "arn:aws:kms:{region}:{account}:{id_type}{key_id}".format(
region=self.region, account=ACCOUNT_ID, id_type=id_type, key_id=key_id
)
def _validate_cmk_id(self, key_id):
"""Determine whether a CMK ID exists.
- raw key ID
- key ARN
"""
is_arn = key_id.startswith("arn:") and ":key/" in key_id
is_raw_key_id = re.match(r"^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$", key_id, re.IGNORECASE)
if not is_arn and not is_raw_key_id:
raise NotFoundException("Invalid keyId {key_id}".format(key_id=key_id))
cmk_id = self.kms_backend.get_key_id(key_id)
if cmk_id not in self.kms_backend.keys:
raise NotFoundException("Key '{key_id}' does not exist".format(key_id=self._display_arn(key_id)))
def _validate_alias(self, key_id):
"""Determine whether an alias exists.
- alias name
- alias ARN
"""
error = NotFoundException("Alias {key_id} is not found.".format(key_id=self._display_arn(key_id)))
is_arn = key_id.startswith("arn:") and ":alias/" in key_id
is_name = key_id.startswith("alias/")
if not is_arn and not is_name:
raise error
alias_name = self.kms_backend.get_alias_name(key_id)
cmk_id = self.kms_backend.get_key_id_from_alias(alias_name)
if cmk_id is None:
raise error
def _validate_key_id(self, key_id):
"""Determine whether or not a key ID exists.
- raw key ID
- key ARN
- alias name
- alias ARN
"""
is_alias_arn = key_id.startswith("arn:") and ":alias/" in key_id
is_alias_name = key_id.startswith("alias/")
if is_alias_arn or is_alias_name:
self._validate_alias(key_id)
return
self._validate_cmk_id(key_id)
def create_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html"""
policy = self.parameters.get('Policy')
key_usage = self.parameters.get('KeyUsage')
description = self.parameters.get('Description')
@ -38,20 +114,31 @@ class KmsResponse(BaseResponse):
return json.dumps(key.to_dict())
def update_key_description(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html"""
key_id = self.parameters.get('KeyId')
description = self.parameters.get('Description')
self._validate_cmk_id(key_id)
self.kms_backend.update_key_description(key_id, description)
return json.dumps(None)
def tag_resource(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html"""
key_id = self.parameters.get('KeyId')
tags = self.parameters.get('Tags')
self._validate_cmk_id(key_id)
self.kms_backend.tag_resource(key_id, tags)
return json.dumps({})
def list_resource_tags(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html"""
key_id = self.parameters.get('KeyId')
self._validate_cmk_id(key_id)
tags = self.kms_backend.list_resource_tags(key_id)
return json.dumps({
"Tags": tags,
@ -60,17 +147,19 @@ class KmsResponse(BaseResponse):
})
def describe_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html"""
key_id = self.parameters.get('KeyId')
try:
self._validate_key_id(key_id)
key = self.kms_backend.describe_key(
self.kms_backend.get_key_id(key_id))
except KeyError:
headers = dict(self.headers)
headers['status'] = 404
return "{}", headers
self.kms_backend.get_key_id(key_id)
)
return json.dumps(key.to_dict())
def list_keys(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html"""
keys = self.kms_backend.list_keys()
return json.dumps({
@ -85,6 +174,7 @@ class KmsResponse(BaseResponse):
})
def create_alias(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html"""
alias_name = self.parameters['AliasName']
target_key_id = self.parameters['TargetKeyId']
@ -110,27 +200,31 @@ class KmsResponse(BaseResponse):
raise AlreadyExistsException('An alias with the name arn:aws:kms:{region}:012345678912:{alias_name} '
'already exists'.format(region=self.region, alias_name=alias_name))
self._validate_cmk_id(target_key_id)
self.kms_backend.add_alias(target_key_id, alias_name)
return json.dumps(None)
def delete_alias(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html"""
alias_name = self.parameters['AliasName']
if not alias_name.startswith('alias/'):
raise ValidationException('Invalid identifier')
if not self.kms_backend.alias_exists(alias_name):
raise NotFoundException('Alias arn:aws:kms:{region}:012345678912:'
'{alias_name} is not found.'.format(region=self.region, alias_name=alias_name))
self._validate_alias(alias_name)
self.kms_backend.delete_alias(alias_name)
return json.dumps(None)
def list_aliases(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListAliases.html"""
region = self.region
# TODO: The actual API can filter on KeyId.
response_aliases = [
{
'AliasArn': u'arn:aws:kms:{region}:012345678912:{reserved_alias}'.format(region=region,
@ -155,191 +249,239 @@ class KmsResponse(BaseResponse):
})
def enable_key_rotation(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
self._validate_cmk_id(key_id)
self.kms_backend.enable_key_rotation(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None)
def disable_key_rotation(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html"""
key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
self._validate_cmk_id(key_id)
self.kms_backend.disable_key_rotation(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None)
def get_key_rotation_status(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html"""
key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
self._validate_cmk_id(key_id)
rotation_enabled = self.kms_backend.get_key_rotation_status(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps({'KeyRotationEnabled': rotation_enabled})
def put_key_policy(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html"""
key_id = self.parameters.get('KeyId')
policy_name = self.parameters.get('PolicyName')
policy = self.parameters.get('Policy')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
_assert_default_policy(policy_name)
try:
self._validate_cmk_id(key_id)
self.kms_backend.put_key_policy(key_id, policy)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None)
def get_key_policy(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html"""
key_id = self.parameters.get('KeyId')
policy_name = self.parameters.get('PolicyName')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
_assert_default_policy(policy_name)
try:
self._validate_cmk_id(key_id)
return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)})
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
def list_key_policies(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html"""
key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
self._validate_cmk_id(key_id)
self.kms_backend.describe_key(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps({'Truncated': False, 'PolicyNames': ['default']})
def encrypt(self):
"""
We perform no encryption, we just encode the value as base64 and then
decode it in decrypt().
"""
value = self.parameters.get("Plaintext")
if isinstance(value, six.text_type):
value = value.encode('utf-8')
return json.dumps({"CiphertextBlob": base64.b64encode(value).decode("utf-8"), 'KeyId': 'key_id'})
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html"""
key_id = self.parameters.get("KeyId")
encryption_context = self.parameters.get('EncryptionContext', {})
plaintext = self.parameters.get("Plaintext")
self._validate_key_id(key_id)
if isinstance(plaintext, six.text_type):
plaintext = plaintext.encode('utf-8')
ciphertext_blob, arn = self.kms_backend.encrypt(
key_id=key_id,
plaintext=plaintext,
encryption_context=encryption_context,
)
ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8")
return json.dumps({"CiphertextBlob": ciphertext_blob_response, "KeyId": arn})
def decrypt(self):
# TODO refuse decode if EncryptionContext is not the same as when it was encrypted / generated
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html"""
ciphertext_blob = self.parameters.get("CiphertextBlob")
encryption_context = self.parameters.get('EncryptionContext', {})
value = self.parameters.get("CiphertextBlob")
try:
return json.dumps({"Plaintext": base64.b64decode(value).decode("utf-8"), 'KeyId': 'key_id'})
except UnicodeDecodeError:
# Generate data key will produce random bytes which when decrypted is still returned as base64
return json.dumps({"Plaintext": value})
plaintext, arn = self.kms_backend.decrypt(
ciphertext_blob=ciphertext_blob,
encryption_context=encryption_context,
)
plaintext_response = base64.b64encode(plaintext).decode("utf-8")
return json.dumps({"Plaintext": plaintext_response, 'KeyId': arn})
def re_encrypt(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html"""
ciphertext_blob = self.parameters.get("CiphertextBlob")
source_encryption_context = self.parameters.get("SourceEncryptionContext", {})
destination_key_id = self.parameters.get("DestinationKeyId")
destination_encryption_context = self.parameters.get("DestinationEncryptionContext", {})
self._validate_cmk_id(destination_key_id)
new_ciphertext_blob, decrypting_arn, encrypting_arn = self.kms_backend.re_encrypt(
ciphertext_blob=ciphertext_blob,
source_encryption_context=source_encryption_context,
destination_key_id=destination_key_id,
destination_encryption_context=destination_encryption_context,
)
response_ciphertext_blob = base64.b64encode(new_ciphertext_blob).decode("utf-8")
return json.dumps(
{"CiphertextBlob": response_ciphertext_blob, "KeyId": encrypting_arn, "SourceKeyId": decrypting_arn}
)
def disable_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html"""
key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
self._validate_cmk_id(key_id)
self.kms_backend.disable_key(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None)
def enable_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html"""
key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
self._validate_cmk_id(key_id)
self.kms_backend.enable_key(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps(None)
def cancel_key_deletion(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html"""
key_id = self.parameters.get('KeyId')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
self._validate_cmk_id(key_id)
self.kms_backend.cancel_key_deletion(key_id)
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
return json.dumps({'KeyId': key_id})
def schedule_key_deletion(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html"""
key_id = self.parameters.get('KeyId')
if self.parameters.get('PendingWindowInDays') is None:
pending_window_in_days = 30
else:
pending_window_in_days = self.parameters.get('PendingWindowInDays')
_assert_valid_key_id(self.kms_backend.get_key_id(key_id))
try:
self._validate_cmk_id(key_id)
return json.dumps({
'KeyId': key_id,
'DeletionDate': self.kms_backend.schedule_key_deletion(key_id, pending_window_in_days)
})
except KeyError:
raise NotFoundException("Key 'arn:aws:kms:{region}:012345678912:key/"
"{key_id}' does not exist".format(region=self.region, key_id=key_id))
def generate_data_key(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html"""
key_id = self.parameters.get('KeyId')
encryption_context = self.parameters.get('EncryptionContext')
encryption_context = self.parameters.get('EncryptionContext', {})
number_of_bytes = self.parameters.get('NumberOfBytes')
key_spec = self.parameters.get('KeySpec')
grant_tokens = self.parameters.get('GrantTokens')
# Param validation
if key_id.startswith('alias'):
if self.kms_backend.get_key_id_from_alias(key_id) is None:
raise NotFoundException('Alias arn:aws:kms:{region}:012345678912:{alias_name} is not found.'.format(
region=self.region, alias_name=key_id))
else:
if self.kms_backend.get_key_id(key_id) not in self.kms_backend.keys:
raise NotFoundException('Invalid keyId')
self._validate_key_id(key_id)
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 0):
raise ValidationException("1 validation error detected: Value '2048' at 'numberOfBytes' failed "
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1):
raise ValidationException((
"1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed "
"to satisfy constraint: Member must have value less than or "
"equal to 1024")
"equal to 1024"
).format(number_of_bytes=number_of_bytes))
if key_spec and key_spec not in ('AES_256', 'AES_128'):
raise ValidationException("1 validation error detected: Value 'AES_257' at 'keySpec' failed "
raise ValidationException((
"1 validation error detected: Value '{key_spec}' at 'keySpec' failed "
"to satisfy constraint: Member must satisfy enum value set: "
"[AES_256, AES_128]")
"[AES_256, AES_128]"
).format(key_spec=key_spec))
if not key_spec and not number_of_bytes:
raise ValidationException("Please specify either number of bytes or key spec.")
if key_spec and number_of_bytes:
raise ValidationException("Please specify either number of bytes or key spec.")
plaintext, key_arn = self.kms_backend.generate_data_key(key_id, encryption_context,
number_of_bytes, key_spec, grant_tokens)
plaintext, ciphertext_blob, key_arn = self.kms_backend.generate_data_key(
key_id=key_id,
encryption_context=encryption_context,
number_of_bytes=number_of_bytes,
key_spec=key_spec,
grant_tokens=grant_tokens
)
plaintext = base64.b64encode(plaintext).decode()
plaintext_response = base64.b64encode(plaintext).decode("utf-8")
ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8")
return json.dumps({
'CiphertextBlob': plaintext,
'Plaintext': plaintext,
'CiphertextBlob': ciphertext_blob_response,
'Plaintext': plaintext_response,
'KeyId': key_arn # not alias
})
def generate_data_key_without_plaintext(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html"""
result = json.loads(self.generate_data_key())
del result['Plaintext']
return json.dumps(result)
def generate_random(self):
"""https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateRandom.html"""
number_of_bytes = self.parameters.get("NumberOfBytes")
def _assert_valid_key_id(key_id):
if not re.match(r'^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$', key_id, re.IGNORECASE):
raise NotFoundException('Invalid keyId')
if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1):
raise ValidationException((
"1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed "
"to satisfy constraint: Member must have value less than or "
"equal to 1024"
).format(number_of_bytes=number_of_bytes))
entropy = os.urandom(number_of_bytes)
response_entropy = base64.b64encode(entropy).decode("utf-8")
return json.dumps({"Plaintext": response_entropy})
def _assert_default_policy(policy_name):

View File

@ -1,7 +1,142 @@
from __future__ import unicode_literals
from collections import namedtuple
import io
import os
import struct
import uuid
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes
from .exceptions import InvalidCiphertextException, AccessDeniedException, NotFoundException
MASTER_KEY_LEN = 32
KEY_ID_LEN = 36
IV_LEN = 12
TAG_LEN = 16
HEADER_LEN = KEY_ID_LEN + IV_LEN + TAG_LEN
# NOTE: This is just a simple binary format. It is not what KMS actually does.
CIPHERTEXT_HEADER_FORMAT = ">{key_id_len}s{iv_len}s{tag_len}s".format(
key_id_len=KEY_ID_LEN, iv_len=IV_LEN, tag_len=TAG_LEN
)
Ciphertext = namedtuple("Ciphertext", ("key_id", "iv", "ciphertext", "tag"))
def generate_key_id():
return str(uuid.uuid4())
def generate_data_key(number_of_bytes):
"""Generate a data key."""
return os.urandom(number_of_bytes)
def generate_master_key():
"""Generate a master key."""
return generate_data_key(MASTER_KEY_LEN)
def _serialize_ciphertext_blob(ciphertext):
"""Serialize Ciphertext object into a ciphertext blob.
NOTE: This is just a simple binary format. It is not what KMS actually does.
"""
header = struct.pack(CIPHERTEXT_HEADER_FORMAT, ciphertext.key_id.encode("utf-8"), ciphertext.iv, ciphertext.tag)
return header + ciphertext.ciphertext
def _deserialize_ciphertext_blob(ciphertext_blob):
"""Deserialize ciphertext blob into a Ciphertext object.
NOTE: This is just a simple binary format. It is not what KMS actually does.
"""
header = ciphertext_blob[:HEADER_LEN]
ciphertext = ciphertext_blob[HEADER_LEN:]
key_id, iv, tag = struct.unpack(CIPHERTEXT_HEADER_FORMAT, header)
return Ciphertext(key_id=key_id.decode("utf-8"), iv=iv, ciphertext=ciphertext, tag=tag)
def _serialize_encryption_context(encryption_context):
"""Serialize encryption context for use a AAD.
NOTE: This is not necessarily what KMS does, but it retains the same properties.
"""
aad = io.BytesIO()
for key, value in sorted(encryption_context.items(), key=lambda x: x[0]):
aad.write(key.encode("utf-8"))
aad.write(value.encode("utf-8"))
return aad.getvalue()
def encrypt(master_keys, key_id, plaintext, encryption_context):
"""Encrypt data using a master key material.
NOTE: This is not necessarily what KMS does, but it retains the same properties.
NOTE: This function is NOT compatible with KMS APIs.
:param dict master_keys: Mapping of a KmsBackend's known master keys
:param str key_id: Key ID of moto master key
:param bytes plaintext: Plaintext data to encrypt
:param dict[str, str] encryption_context: KMS-style encryption context
:returns: Moto-structured ciphertext blob encrypted under a moto master key in master_keys
:rtype: bytes
"""
try:
key = master_keys[key_id]
except KeyError:
is_alias = key_id.startswith("alias/") or ":alias/" in key_id
raise NotFoundException(
"{id_type} {key_id} is not found.".format(id_type="Alias" if is_alias else "keyId", key_id=key_id)
)
iv = os.urandom(IV_LEN)
aad = _serialize_encryption_context(encryption_context=encryption_context)
encryptor = Cipher(algorithms.AES(key.key_material), modes.GCM(iv), backend=default_backend()).encryptor()
encryptor.authenticate_additional_data(aad)
ciphertext = encryptor.update(plaintext) + encryptor.finalize()
return _serialize_ciphertext_blob(
ciphertext=Ciphertext(key_id=key_id, iv=iv, ciphertext=ciphertext, tag=encryptor.tag)
)
def decrypt(master_keys, ciphertext_blob, encryption_context):
"""Decrypt a ciphertext blob using a master key material.
NOTE: This is not necessarily what KMS does, but it retains the same properties.
NOTE: This function is NOT compatible with KMS APIs.
:param dict master_keys: Mapping of a KmsBackend's known master keys
:param bytes ciphertext_blob: moto-structured ciphertext blob encrypted under a moto master key in master_keys
:param dict[str, str] encryption_context: KMS-style encryption context
:returns: plaintext bytes and moto key ID
:rtype: bytes and str
"""
try:
ciphertext = _deserialize_ciphertext_blob(ciphertext_blob=ciphertext_blob)
except Exception:
raise InvalidCiphertextException()
aad = _serialize_encryption_context(encryption_context=encryption_context)
try:
key = master_keys[ciphertext.key_id]
except KeyError:
raise AccessDeniedException(
"The ciphertext refers to a customer master key that does not exist, "
"does not exist in this region, or you are not allowed to access."
)
try:
decryptor = Cipher(
algorithms.AES(key.key_material), modes.GCM(ciphertext.iv, ciphertext.tag), backend=default_backend()
).decryptor()
decryptor.authenticate_additional_data(aad)
plaintext = decryptor.update(ciphertext.ciphertext) + decryptor.finalize()
except Exception:
raise InvalidCiphertextException()
return plaintext, ciphertext.key_id

View File

@ -41,7 +41,7 @@ class LogStream:
self.region = region
self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format(
region=region, id=self.__class__._log_ids, log_group=log_group, log_stream=name)
self.creationTime = unix_time_millis()
self.creationTime = int(unix_time_millis())
self.firstEventTimestamp = None
self.lastEventTimestamp = None
self.lastIngestionTime = None
@ -80,7 +80,7 @@ class LogStream:
def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token):
# TODO: ensure sequence_token
# TODO: to be thread safe this would need a lock
self.lastIngestionTime = unix_time_millis()
self.lastIngestionTime = int(unix_time_millis())
# TODO: make this match AWS if possible
self.storedBytes += sum([len(log_event["message"]) for log_event in log_events])
self.events += [LogEvent(self.lastIngestionTime, log_event) for log_event in log_events]
@ -115,6 +115,8 @@ class LogStream:
events_page = [event.to_response_dict() for event in events[next_index: next_index + limit]]
if next_index + limit < len(self.events):
next_index += limit
else:
next_index = len(self.events)
back_index -= limit
if back_index <= 0:
@ -146,7 +148,7 @@ class LogGroup:
self.region = region
self.arn = "arn:aws:logs:{region}:1:log-group:{log_group}".format(
region=region, log_group=name)
self.creationTime = unix_time_millis()
self.creationTime = int(unix_time_millis())
self.tags = tags
self.streams = dict() # {name: LogStream}
self.retentionInDays = None # AWS defaults to Never Expire for log group retention

View File

@ -74,7 +74,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
automated_snapshot_retention_period, port, cluster_version,
allow_version_upgrade, number_of_nodes, publicly_accessible,
encrypted, region_name, tags=None, iam_roles_arn=None,
restored_from_snapshot=False):
enhanced_vpc_routing=None, restored_from_snapshot=False):
super(Cluster, self).__init__(region_name, tags)
self.redshift_backend = redshift_backend
self.cluster_identifier = cluster_identifier
@ -85,6 +85,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
self.master_user_password = master_user_password
self.db_name = db_name if db_name else "dev"
self.vpc_security_group_ids = vpc_security_group_ids
self.enhanced_vpc_routing = enhanced_vpc_routing if enhanced_vpc_routing is not None else False
self.cluster_subnet_group_name = cluster_subnet_group_name
self.publicly_accessible = publicly_accessible
self.encrypted = encrypted
@ -154,6 +155,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
port=properties.get('Port'),
cluster_version=properties.get('ClusterVersion'),
allow_version_upgrade=properties.get('AllowVersionUpgrade'),
enhanced_vpc_routing=properties.get('EnhancedVpcRouting'),
number_of_nodes=properties.get('NumberOfNodes'),
publicly_accessible=properties.get("PubliclyAccessible"),
encrypted=properties.get("Encrypted"),
@ -241,6 +243,7 @@ class Cluster(TaggableResourceMixin, BaseModel):
'ClusterCreateTime': self.create_time,
"PendingModifiedValues": [],
"Tags": self.tags,
"EnhancedVpcRouting": self.enhanced_vpc_routing,
"IamRoles": [{
"ApplyStatus": "in-sync",
"IamRoleArn": iam_role_arn
@ -427,6 +430,7 @@ class Snapshot(TaggableResourceMixin, BaseModel):
'NumberOfNodes': self.cluster.number_of_nodes,
'DBName': self.cluster.db_name,
'Tags': self.tags,
'EnhancedVpcRouting': self.cluster.enhanced_vpc_routing,
"IamRoles": [{
"ApplyStatus": "in-sync",
"IamRoleArn": iam_role_arn
@ -678,7 +682,8 @@ class RedshiftBackend(BaseBackend):
"number_of_nodes": snapshot.cluster.number_of_nodes,
"encrypted": snapshot.cluster.encrypted,
"tags": snapshot.cluster.tags,
"restored_from_snapshot": True
"restored_from_snapshot": True,
"enhanced_vpc_routing": snapshot.cluster.enhanced_vpc_routing
}
create_kwargs.update(kwargs)
return self.create_cluster(**create_kwargs)

View File

@ -135,6 +135,7 @@ class RedshiftResponse(BaseResponse):
"region_name": self.region,
"tags": self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')),
"iam_roles_arn": self._get_iam_roles(),
"enhanced_vpc_routing": self._get_param('EnhancedVpcRouting'),
}
cluster = self.redshift_backend.create_cluster(**cluster_kwargs).to_json()
cluster['ClusterStatus'] = 'creating'
@ -150,6 +151,7 @@ class RedshiftResponse(BaseResponse):
})
def restore_from_cluster_snapshot(self):
enhanced_vpc_routing = self._get_bool_param('EnhancedVpcRouting')
restore_kwargs = {
"snapshot_identifier": self._get_param('SnapshotIdentifier'),
"cluster_identifier": self._get_param('ClusterIdentifier'),
@ -171,6 +173,8 @@ class RedshiftResponse(BaseResponse):
"region_name": self.region,
"iam_roles_arn": self._get_iam_roles(),
}
if enhanced_vpc_routing is not None:
restore_kwargs['enhanced_vpc_routing'] = enhanced_vpc_routing
cluster = self.redshift_backend.restore_from_cluster_snapshot(**restore_kwargs).to_json()
cluster['ClusterStatus'] = 'creating'
return self.get_response({
@ -218,6 +222,7 @@ class RedshiftResponse(BaseResponse):
"publicly_accessible": self._get_param("PubliclyAccessible"),
"encrypted": self._get_param("Encrypted"),
"iam_roles_arn": self._get_iam_roles(),
"enhanced_vpc_routing": self._get_param("EnhancedVpcRouting")
}
cluster_kwargs = {}
# We only want parameters that were actually passed in, otherwise

View File

@ -305,6 +305,7 @@ class Route53Backend(BaseBackend):
def list_tags_for_resource(self, resource_id):
if resource_id in self.resource_tags:
return self.resource_tags[resource_id]
return {}
def get_all_hosted_zones(self):
return self.zones.values()

View File

@ -20,7 +20,7 @@ from .exceptions import BucketAlreadyExists, S3ClientError, MissingBucket, Missi
MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent, ObjectNotInActiveTierError
from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeTagging, FakeTagSet, \
FakeTag
from .utils import bucket_name_from_url, clean_key_name, metadata_from_headers, parse_region_from_url
from .utils import bucket_name_from_url, clean_key_name, undo_clean_key_name, metadata_from_headers, parse_region_from_url
from xml.dom import minidom
@ -451,17 +451,16 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
continuation_token = querystring.get('continuation-token', [None])[0]
start_after = querystring.get('start-after', [None])[0]
# sort the combination of folders and keys into lexicographical order
all_keys = result_keys + result_folders
all_keys.sort(key=self._get_name)
if continuation_token or start_after:
limit = continuation_token or start_after
if not delimiter:
result_keys = self._get_results_from_token(result_keys, limit)
else:
result_folders = self._get_results_from_token(result_folders, limit)
all_keys = self._get_results_from_token(all_keys, limit)
if not delimiter:
result_keys, is_truncated, next_continuation_token = self._truncate_result(result_keys, max_keys)
else:
result_folders, is_truncated, next_continuation_token = self._truncate_result(result_folders, max_keys)
truncated_keys, is_truncated, next_continuation_token = self._truncate_result(all_keys, max_keys)
result_keys, result_folders = self._split_truncated_keys(truncated_keys)
key_count = len(result_keys) + len(result_folders)
@ -479,6 +478,24 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
start_after=None if continuation_token else start_after
)
@staticmethod
def _get_name(key):
if isinstance(key, FakeKey):
return key.name
else:
return key
@staticmethod
def _split_truncated_keys(truncated_keys):
result_keys = []
result_folders = []
for key in truncated_keys:
if isinstance(key, FakeKey):
result_keys.append(key)
else:
result_folders.append(key)
return result_keys, result_folders
def _get_results_from_token(self, result_keys, token):
continuation_index = 0
for key in result_keys:
@ -694,7 +711,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
for k in keys:
key_name = k.firstChild.nodeValue
success = self.backend.delete_key(bucket_name, key_name)
success = self.backend.delete_key(bucket_name, undo_clean_key_name(key_name))
if success:
deleted_names.append(key_name)
else:

View File

@ -15,4 +15,6 @@ url_paths = {
'{0}/(?P<key_or_bucket_name>[^/]+)/?$': S3ResponseInstance.ambiguous_response,
# path-based bucket + key
'{0}/(?P<bucket_name_path>[^/]+)/(?P<key_name>.+)': S3ResponseInstance.key_response,
# subdomain bucket + key with empty first part of path
'{0}//(?P<key_name>.*)$': S3ResponseInstance.key_response,
}

View File

@ -5,7 +5,7 @@ import os
from boto.s3.key import Key
import re
import six
from six.moves.urllib.parse import urlparse, unquote
from six.moves.urllib.parse import urlparse, unquote, quote
import sys
@ -71,10 +71,15 @@ def metadata_from_headers(headers):
def clean_key_name(key_name):
if six.PY2:
return unquote(key_name.encode('utf-8')).decode('utf-8')
return unquote(key_name)
def undo_clean_key_name(key_name):
if six.PY2:
return quote(key_name.encode('utf-8')).decode('utf-8')
return quote(key_name)
class _VersionedKeyStore(dict):
""" A simplified/modified version of Django's `MultiValueDict` taken from:

View File

@ -154,9 +154,9 @@ class SecretsManagerBackend(BaseBackend):
return version_id
def put_secret_value(self, secret_id, secret_string, version_stages):
def put_secret_value(self, secret_id, secret_string, secret_binary, version_stages):
version_id = self._add_secret(secret_id, secret_string, version_stages=version_stages)
version_id = self._add_secret(secret_id, secret_string, secret_binary, version_stages=version_stages)
response = json.dumps({
'ARN': secret_arn(self.region, secret_id),

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.secretsmanager.exceptions import InvalidRequestException
from .models import secretsmanager_backends
@ -71,10 +72,14 @@ class SecretsManagerResponse(BaseResponse):
def put_secret_value(self):
secret_id = self._get_param('SecretId', if_none='')
secret_string = self._get_param('SecretString', if_none='')
secret_string = self._get_param('SecretString')
secret_binary = self._get_param('SecretBinary')
if not secret_binary and not secret_string:
raise InvalidRequestException('You must provide either SecretString or SecretBinary.')
version_stages = self._get_param('VersionStages', if_none=['AWSCURRENT'])
return secretsmanager_backends[self.region].put_secret_value(
secret_id=secret_id,
secret_binary=secret_binary,
secret_string=secret_string,
version_stages=version_stages,
)

View File

@ -174,10 +174,11 @@ def create_backend_app(service):
backend_app.url_map.converters['regex'] = RegexConverter
backend = list(BACKENDS[service].values())[0]
for url_path, handler in backend.flask_paths.items():
view_func = convert_flask_to_httpretty_response(handler)
if handler.__name__ == 'dispatch':
endpoint = '{0}.dispatch'.format(handler.__self__.__name__)
else:
endpoint = None
endpoint = view_func.__name__
original_endpoint = endpoint
index = 2
@ -191,7 +192,7 @@ def create_backend_app(service):
url_path,
endpoint=endpoint,
methods=HTTP_METHODS,
view_func=convert_flask_to_httpretty_response(handler),
view_func=view_func,
strict_slashes=False,
)

View File

@ -40,3 +40,11 @@ class InvalidParameterValue(RESTError):
def __init__(self, message):
super(InvalidParameterValue, self).__init__(
"InvalidParameterValue", message)
class InternalError(RESTError):
code = 500
def __init__(self, message):
super(InternalError, self).__init__(
"InternalFailure", message)

View File

@ -18,7 +18,7 @@ from moto.awslambda import lambda_backends
from .exceptions import (
SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled, SNSInvalidParameter,
InvalidParameterValue
InvalidParameterValue, InternalError
)
from .utils import make_arn_for_topic, make_arn_for_subscription
@ -131,13 +131,47 @@ class Subscription(BaseModel):
message_attributes = {}
def _field_match(field, rules, message_attributes):
for rule in rules:
# TODO: boolean value matching is not supported, SNS behavior unknown
if isinstance(rule, six.string_types):
if field not in message_attributes:
return False
for rule in rules:
if isinstance(rule, six.string_types):
# only string value matching is supported
if message_attributes[field]['Value'] == rule:
return True
try:
json_data = json.loads(message_attributes[field]['Value'])
if rule in json_data:
return True
except (ValueError, TypeError):
pass
if isinstance(rule, (six.integer_types, float)):
if field not in message_attributes:
return False
if message_attributes[field]['Type'] == 'Number':
attribute_values = [message_attributes[field]['Value']]
elif message_attributes[field]['Type'] == 'String.Array':
try:
attribute_values = json.loads(message_attributes[field]['Value'])
if not isinstance(attribute_values, list):
attribute_values = [attribute_values]
except (ValueError, TypeError):
return False
else:
return False
for attribute_values in attribute_values:
# Even the offical documentation states a 5 digits of accuracy after the decimal point for numerics, in reality it is 6
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
if int(attribute_values * 1000000) == int(rule * 1000000):
return True
if isinstance(rule, dict):
keyword = list(rule.keys())[0]
attributes = list(rule.values())[0]
if keyword == 'exists':
if attributes and field in message_attributes:
return True
elif not attributes and field not in message_attributes:
return True
return False
return all(_field_match(field, rules, message_attributes)
@ -421,7 +455,49 @@ class SNSBackend(BaseBackend):
subscription.attributes[name] = value
if name == 'FilterPolicy':
subscription._filter_policy = json.loads(value)
filter_policy = json.loads(value)
self._validate_filter_policy(filter_policy)
subscription._filter_policy = filter_policy
def _validate_filter_policy(self, value):
# TODO: extend validation checks
combinations = 1
for rules in six.itervalues(value):
combinations *= len(rules)
# Even the offical documentation states the total combination of values must not exceed 100, in reality it is 150
# https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints
if combinations > 150:
raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Filter policy is too complex")
for field, rules in six.iteritems(value):
for rule in rules:
if rule is None:
continue
if isinstance(rule, six.string_types):
continue
if isinstance(rule, bool):
continue
if isinstance(rule, (six.integer_types, float)):
if rule <= -1000000000 or rule >= 1000000000:
raise InternalError("Unknown")
continue
if isinstance(rule, dict):
keyword = list(rule.keys())[0]
attributes = list(rule.values())[0]
if keyword == 'anything-but':
continue
elif keyword == 'exists':
if not isinstance(attributes, bool):
raise SNSInvalidParameter("Invalid parameter: FilterPolicy: exists match pattern must be either true or false.")
continue
elif keyword == 'numeric':
continue
elif keyword == 'prefix':
continue
else:
raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Unrecognized match type {type}".format(type=keyword))
raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null")
sns_backends = {}

View File

@ -57,6 +57,15 @@ class SNSResponse(BaseResponse):
transform_value = None
if 'StringValue' in value:
if data_type == 'Number':
try:
transform_value = float(value['StringValue'])
except ValueError:
raise InvalidParameterValue(
"An error occurred (ParameterValueInvalid) "
"when calling the Publish operation: "
"Could not cast message attribute '{0}' value to number.".format(name))
else:
transform_value = value['StringValue']
elif 'BinaryValue' in value:
transform_value = value['BinaryValue']

View File

@ -265,6 +265,9 @@ class Queue(BaseModel):
if 'maxReceiveCount' not in self.redrive_policy:
raise RESTError('InvalidParameterValue', 'Redrive policy does not contain maxReceiveCount')
# 'maxReceiveCount' is stored as int
self.redrive_policy['maxReceiveCount'] = int(self.redrive_policy['maxReceiveCount'])
for queue in sqs_backends[self.region].queues.values():
if queue.queue_arn == self.redrive_policy['deadLetterTargetArn']:
self.dead_letter_queue = queue
@ -424,13 +427,26 @@ class SQSBackend(BaseBackend):
queue_attributes = queue.attributes
new_queue_attributes = new_queue.attributes
static_attributes = (
'DelaySeconds',
'MaximumMessageSize',
'MessageRetentionPeriod',
'Policy',
'QueueArn',
'ReceiveMessageWaitTimeSeconds',
'RedrivePolicy',
'VisibilityTimeout',
'KmsMasterKeyId',
'KmsDataKeyReusePeriodSeconds',
'FifoQueue',
'ContentBasedDeduplication',
)
for key in ['CreatedTimestamp', 'LastModifiedTimestamp']:
queue_attributes.pop(key)
new_queue_attributes.pop(key)
if queue_attributes != new_queue_attributes:
raise QueueAlreadyExists("The specified queue already exists.")
for key in static_attributes:
if queue_attributes.get(key) != new_queue_attributes.get(key):
raise QueueAlreadyExists(
"The specified queue already exists.",
)
else:
try:
kwargs.pop('region')

View File

@ -0,0 +1,6 @@
from __future__ import unicode_literals
from .models import stepfunction_backends
from ..core.models import base_decorator
stepfunction_backend = stepfunction_backends['us-east-1']
mock_stepfunctions = base_decorator(stepfunction_backends)

View File

@ -0,0 +1,35 @@
from __future__ import unicode_literals
import json
class AWSError(Exception):
TYPE = None
STATUS = 400
def __init__(self, message, type=None, status=None):
self.message = message
self.type = type if type is not None else self.TYPE
self.status = status if status is not None else self.STATUS
def response(self):
return json.dumps({'__type': self.type, 'message': self.message}), dict(status=self.status)
class ExecutionDoesNotExist(AWSError):
TYPE = 'ExecutionDoesNotExist'
STATUS = 400
class InvalidArn(AWSError):
TYPE = 'InvalidArn'
STATUS = 400
class InvalidName(AWSError):
TYPE = 'InvalidName'
STATUS = 400
class StateMachineDoesNotExist(AWSError):
TYPE = 'StateMachineDoesNotExist'
STATUS = 400

View File

@ -0,0 +1,162 @@
import boto
import re
from datetime import datetime
from moto.core import BaseBackend
from moto.core.utils import iso_8601_datetime_without_milliseconds
from moto.sts.models import ACCOUNT_ID
from uuid import uuid4
from .exceptions import ExecutionDoesNotExist, InvalidArn, InvalidName, StateMachineDoesNotExist
class StateMachine():
def __init__(self, arn, name, definition, roleArn, tags=None):
self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now())
self.arn = arn
self.name = name
self.definition = definition
self.roleArn = roleArn
self.tags = tags
class Execution():
def __init__(self, region_name, account_id, state_machine_name, execution_name, state_machine_arn):
execution_arn = 'arn:aws:states:{}:{}:execution:{}:{}'
execution_arn = execution_arn.format(region_name, account_id, state_machine_name, execution_name)
self.execution_arn = execution_arn
self.name = execution_name
self.start_date = iso_8601_datetime_without_milliseconds(datetime.now())
self.state_machine_arn = state_machine_arn
self.status = 'RUNNING'
self.stop_date = None
def stop(self):
self.status = 'SUCCEEDED'
self.stop_date = iso_8601_datetime_without_milliseconds(datetime.now())
class StepFunctionBackend(BaseBackend):
# https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/stepfunctions.html#SFN.Client.create_state_machine
# A name must not contain:
# whitespace
# brackets < > { } [ ]
# wildcard characters ? *
# special characters " # % \ ^ | ~ ` $ & , ; : /
invalid_chars_for_name = [' ', '{', '}', '[', ']', '<', '>',
'?', '*',
'"', '#', '%', '\\', '^', '|', '~', '`', '$', '&', ',', ';', ':', '/']
# control characters (U+0000-001F , U+007F-009F )
invalid_unicodes_for_name = [u'\u0000', u'\u0001', u'\u0002', u'\u0003', u'\u0004',
u'\u0005', u'\u0006', u'\u0007', u'\u0008', u'\u0009',
u'\u000A', u'\u000B', u'\u000C', u'\u000D', u'\u000E', u'\u000F',
u'\u0010', u'\u0011', u'\u0012', u'\u0013', u'\u0014',
u'\u0015', u'\u0016', u'\u0017', u'\u0018', u'\u0019',
u'\u001A', u'\u001B', u'\u001C', u'\u001D', u'\u001E', u'\u001F',
u'\u007F',
u'\u0080', u'\u0081', u'\u0082', u'\u0083', u'\u0084', u'\u0085',
u'\u0086', u'\u0087', u'\u0088', u'\u0089',
u'\u008A', u'\u008B', u'\u008C', u'\u008D', u'\u008E', u'\u008F',
u'\u0090', u'\u0091', u'\u0092', u'\u0093', u'\u0094', u'\u0095',
u'\u0096', u'\u0097', u'\u0098', u'\u0099',
u'\u009A', u'\u009B', u'\u009C', u'\u009D', u'\u009E', u'\u009F']
accepted_role_arn_format = re.compile('arn:aws:iam:(?P<account_id>[0-9]{12}):role/.+')
accepted_mchn_arn_format = re.compile('arn:aws:states:[-0-9a-zA-Z]+:(?P<account_id>[0-9]{12}):stateMachine:.+')
accepted_exec_arn_format = re.compile('arn:aws:states:[-0-9a-zA-Z]+:(?P<account_id>[0-9]{12}):execution:.+')
def __init__(self, region_name):
self.state_machines = []
self.executions = []
self.region_name = region_name
self._account_id = None
def create_state_machine(self, name, definition, roleArn, tags=None):
self._validate_name(name)
self._validate_role_arn(roleArn)
arn = 'arn:aws:states:' + self.region_name + ':' + str(self._get_account_id()) + ':stateMachine:' + name
try:
return self.describe_state_machine(arn)
except StateMachineDoesNotExist:
state_machine = StateMachine(arn, name, definition, roleArn, tags)
self.state_machines.append(state_machine)
return state_machine
def list_state_machines(self):
return self.state_machines
def describe_state_machine(self, arn):
self._validate_machine_arn(arn)
sm = next((x for x in self.state_machines if x.arn == arn), None)
if not sm:
raise StateMachineDoesNotExist("State Machine Does Not Exist: '" + arn + "'")
return sm
def delete_state_machine(self, arn):
self._validate_machine_arn(arn)
sm = next((x for x in self.state_machines if x.arn == arn), None)
if sm:
self.state_machines.remove(sm)
def start_execution(self, state_machine_arn):
state_machine_name = self.describe_state_machine(state_machine_arn).name
execution = Execution(region_name=self.region_name,
account_id=self._get_account_id(),
state_machine_name=state_machine_name,
execution_name=str(uuid4()),
state_machine_arn=state_machine_arn)
self.executions.append(execution)
return execution
def stop_execution(self, execution_arn):
execution = next((x for x in self.executions if x.execution_arn == execution_arn), None)
if not execution:
raise ExecutionDoesNotExist("Execution Does Not Exist: '" + execution_arn + "'")
execution.stop()
return execution
def list_executions(self, state_machine_arn):
return [execution for execution in self.executions if execution.state_machine_arn == state_machine_arn]
def describe_execution(self, arn):
self._validate_execution_arn(arn)
exctn = next((x for x in self.executions if x.execution_arn == arn), None)
if not exctn:
raise ExecutionDoesNotExist("Execution Does Not Exist: '" + arn + "'")
return exctn
def reset(self):
region_name = self.region_name
self.__dict__ = {}
self.__init__(region_name)
def _validate_name(self, name):
if any(invalid_char in name for invalid_char in self.invalid_chars_for_name):
raise InvalidName("Invalid Name: '" + name + "'")
if any(name.find(char) >= 0 for char in self.invalid_unicodes_for_name):
raise InvalidName("Invalid Name: '" + name + "'")
def _validate_role_arn(self, role_arn):
self._validate_arn(arn=role_arn,
regex=self.accepted_role_arn_format,
invalid_msg="Invalid Role Arn: '" + role_arn + "'")
def _validate_machine_arn(self, machine_arn):
self._validate_arn(arn=machine_arn,
regex=self.accepted_mchn_arn_format,
invalid_msg="Invalid Role Arn: '" + machine_arn + "'")
def _validate_execution_arn(self, execution_arn):
self._validate_arn(arn=execution_arn,
regex=self.accepted_exec_arn_format,
invalid_msg="Execution Does Not Exist: '" + execution_arn + "'")
def _validate_arn(self, arn, regex, invalid_msg):
match = regex.match(arn)
if not arn or not match:
raise InvalidArn(invalid_msg)
def _get_account_id(self):
return ACCOUNT_ID
stepfunction_backends = {_region.name: StepFunctionBackend(_region.name) for _region in boto.awslambda.regions()}

View File

@ -0,0 +1,138 @@
from __future__ import unicode_literals
import json
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from .exceptions import AWSError
from .models import stepfunction_backends
class StepFunctionResponse(BaseResponse):
@property
def stepfunction_backend(self):
return stepfunction_backends[self.region]
@amzn_request_id
def create_state_machine(self):
name = self._get_param('name')
definition = self._get_param('definition')
roleArn = self._get_param('roleArn')
tags = self._get_param('tags')
try:
state_machine = self.stepfunction_backend.create_state_machine(name=name, definition=definition,
roleArn=roleArn,
tags=tags)
response = {
'creationDate': state_machine.creation_date,
'stateMachineArn': state_machine.arn
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def list_state_machines(self):
list_all = self.stepfunction_backend.list_state_machines()
list_all = sorted([{'creationDate': sm.creation_date,
'name': sm.name,
'stateMachineArn': sm.arn} for sm in list_all],
key=lambda x: x['name'])
response = {'stateMachines': list_all}
return 200, {}, json.dumps(response)
@amzn_request_id
def describe_state_machine(self):
arn = self._get_param('stateMachineArn')
return self._describe_state_machine(arn)
@amzn_request_id
def _describe_state_machine(self, state_machine_arn):
try:
state_machine = self.stepfunction_backend.describe_state_machine(state_machine_arn)
response = {
'creationDate': state_machine.creation_date,
'stateMachineArn': state_machine.arn,
'definition': state_machine.definition,
'name': state_machine.name,
'roleArn': state_machine.roleArn,
'status': 'ACTIVE'
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def delete_state_machine(self):
arn = self._get_param('stateMachineArn')
try:
self.stepfunction_backend.delete_state_machine(arn)
return 200, {}, json.dumps('{}')
except AWSError as err:
return err.response()
@amzn_request_id
def list_tags_for_resource(self):
arn = self._get_param('resourceArn')
try:
state_machine = self.stepfunction_backend.describe_state_machine(arn)
tags = state_machine.tags or []
except AWSError:
tags = []
response = {'tags': tags}
return 200, {}, json.dumps(response)
@amzn_request_id
def start_execution(self):
arn = self._get_param('stateMachineArn')
execution = self.stepfunction_backend.start_execution(arn)
response = {'executionArn': execution.execution_arn,
'startDate': execution.start_date}
return 200, {}, json.dumps(response)
@amzn_request_id
def list_executions(self):
arn = self._get_param('stateMachineArn')
state_machine = self.stepfunction_backend.describe_state_machine(arn)
executions = self.stepfunction_backend.list_executions(arn)
executions = [{'executionArn': execution.execution_arn,
'name': execution.name,
'startDate': execution.start_date,
'stateMachineArn': state_machine.arn,
'status': execution.status} for execution in executions]
return 200, {}, json.dumps({'executions': executions})
@amzn_request_id
def describe_execution(self):
arn = self._get_param('executionArn')
try:
execution = self.stepfunction_backend.describe_execution(arn)
response = {
'executionArn': arn,
'input': '{}',
'name': execution.name,
'startDate': execution.start_date,
'stateMachineArn': execution.state_machine_arn,
'status': execution.status,
'stopDate': execution.stop_date
}
return 200, {}, json.dumps(response)
except AWSError as err:
return err.response()
@amzn_request_id
def describe_state_machine_for_execution(self):
arn = self._get_param('executionArn')
try:
execution = self.stepfunction_backend.describe_execution(arn)
return self._describe_state_machine(execution.state_machine_arn)
except AWSError as err:
return err.response()
@amzn_request_id
def stop_execution(self):
arn = self._get_param('executionArn')
execution = self.stepfunction_backend.stop_execution(arn)
response = {'stopDate': execution.stop_date}
return 200, {}, json.dumps(response)

View File

@ -0,0 +1,10 @@
from __future__ import unicode_literals
from .responses import StepFunctionResponse
url_bases = [
"https?://states.(.+).amazonaws.com",
]
url_paths = {
'{0}/$': StepFunctionResponse.dispatch,
}

View File

@ -10,6 +10,7 @@ boto>=2.45.0
boto3>=1.4.4
botocore>=1.12.13
six>=1.9
parameterized>=0.7.0
prompt-toolkit==1.0.14
click==6.7
inflection==0.3.1

View File

@ -981,11 +981,13 @@ def test_api_keys():
apikey['value'].should.equal(apikey_value)
apikey_name = 'TESTKEY2'
payload = {'name': apikey_name }
payload = {'name': apikey_name, 'tags': {'tag1': 'test_tag1', 'tag2': '1'}}
response = client.create_api_key(**payload)
apikey_id = response['id']
apikey = client.get_api_key(apiKey=apikey_id)
apikey['name'].should.equal(apikey_name)
apikey['tags']['tag1'].should.equal('test_tag1')
apikey['tags']['tag2'].should.equal('1')
len(apikey['value']).should.equal(40)
apikey_name = 'TESTKEY3'

View File

@ -563,6 +563,38 @@ def test_reregister_task_definition():
resp2['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn'])
resp3 = batch_client.register_job_definition(
jobDefinitionName='sleep10',
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 42,
'command': ['sleep', '10']
}
)
resp3['revision'].should.equal(3)
resp3['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn'])
resp3['jobDefinitionArn'].should_not.equal(resp2['jobDefinitionArn'])
resp4 = batch_client.register_job_definition(
jobDefinitionName='sleep10',
type='container',
containerProperties={
'image': 'busybox',
'vcpus': 1,
'memory': 41,
'command': ['sleep', '10']
}
)
resp4['revision'].should.equal(4)
resp4['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn'])
resp4['jobDefinitionArn'].should_not.equal(resp2['jobDefinitionArn'])
resp4['jobDefinitionArn'].should_not.equal(resp3['jobDefinitionArn'])
@mock_ec2
@mock_ecs

View File

@ -1,10 +1,10 @@
from __future__ import unicode_literals
import boto3
from botocore.exceptions import ClientError
from nose.tools import assert_raises
from moto import mock_cognitoidentity
import sure # noqa
from moto.cognitoidentity.utils import get_random_identity_id
@ -28,6 +28,47 @@ def test_create_identity_pool():
assert result['IdentityPoolId'] != ''
@mock_cognitoidentity
def test_describe_identity_pool():
conn = boto3.client('cognito-identity', 'us-west-2')
res = conn.create_identity_pool(IdentityPoolName='TestPool',
AllowUnauthenticatedIdentities=False,
SupportedLoginProviders={'graph.facebook.com': '123456789012345'},
DeveloperProviderName='devname',
OpenIdConnectProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db'],
CognitoIdentityProviders=[
{
'ProviderName': 'testprovider',
'ClientId': 'CLIENT12345',
'ServerSideTokenCheck': True
},
],
SamlProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db'])
result = conn.describe_identity_pool(IdentityPoolId=res['IdentityPoolId'])
assert result['IdentityPoolId'] == res['IdentityPoolId']
assert result['AllowUnauthenticatedIdentities'] == res['AllowUnauthenticatedIdentities']
assert result['SupportedLoginProviders'] == res['SupportedLoginProviders']
assert result['DeveloperProviderName'] == res['DeveloperProviderName']
assert result['OpenIdConnectProviderARNs'] == res['OpenIdConnectProviderARNs']
assert result['CognitoIdentityProviders'] == res['CognitoIdentityProviders']
assert result['SamlProviderARNs'] == res['SamlProviderARNs']
@mock_cognitoidentity
def test_describe_identity_pool_with_invalid_id_raises_error():
conn = boto3.client('cognito-identity', 'us-west-2')
with assert_raises(ClientError) as cm:
conn.describe_identity_pool(IdentityPoolId='us-west-2_non-existent')
cm.exception.operation_name.should.equal('DescribeIdentityPool')
cm.exception.response['Error']['Code'].should.equal('ResourceNotFoundException')
cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400)
# testing a helper function
def test_get_random_identity_id():
assert len(get_random_identity_id('us-west-2')) > 0
@ -44,7 +85,8 @@ def test_get_id():
'someurl': '12345'
})
print(result)
assert result.get('IdentityId', "").startswith('us-west-2') or result.get('ResponseMetadata').get('HTTPStatusCode') == 200
assert result.get('IdentityId', "").startswith('us-west-2') or result.get('ResponseMetadata').get(
'HTTPStatusCode') == 200
@mock_cognitoidentity
@ -71,6 +113,7 @@ def test_get_open_id_token_for_developer_identity():
assert len(result['Token']) > 0
assert result['IdentityId'] == '12345'
@mock_cognitoidentity
def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id():
conn = boto3.client('cognito-identity', 'us-west-2')
@ -84,6 +127,7 @@ def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id()
assert len(result['Token']) > 0
assert len(result['IdentityId']) > 0
@mock_cognitoidentity
def test_get_open_id_token():
conn = boto3.client('cognito-identity', 'us-west-2')

View File

@ -0,0 +1,22 @@
import requests
import sure # noqa
import boto3
from moto import mock_sqs, settings
@mock_sqs
def test_passthrough_requests():
conn = boto3.client("sqs", region_name='us-west-1')
conn.create_queue(QueueName="queue1")
res = requests.get("https://httpbin.org/ip")
assert res.status_code == 200
if not settings.TEST_SERVER_MODE:
@mock_sqs
def test_requests_to_amazon_subdomains_dont_work():
res = requests.get("https://fakeservice.amazonaws.com/foo/bar")
assert res.content == b"The method is not implemented"
assert res.status_code == 400

View File

@ -973,6 +973,53 @@ def test_query_filter():
assert response['Count'] == 2
@mock_dynamodb2
def test_query_filter_overlapping_expression_prefixes():
client = boto3.client('dynamodb', region_name='us-east-1')
dynamodb = boto3.resource('dynamodb', region_name='us-east-1')
# Create the DynamoDB table.
client.create_table(
TableName='test1',
AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}],
KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}],
ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123}
)
client.put_item(
TableName='test1',
Item={
'client': {'S': 'client1'},
'app': {'S': 'app1'},
'nested': {'M': {
'version': {'S': 'version1'},
'contents': {'L': [
{'S': 'value1'}, {'S': 'value2'},
]},
}},
})
table = dynamodb.Table('test1')
response = table.query(
KeyConditionExpression=Key('client').eq('client1') & Key('app').eq('app1'),
ProjectionExpression='#1, #10, nested',
ExpressionAttributeNames={
'#1': 'client',
'#10': 'app',
}
)
assert response['Count'] == 1
assert response['Items'][0] == {
'client': 'client1',
'app': 'app1',
'nested': {
'version': 'version1',
'contents': ['value1', 'value2']
}
}
@mock_dynamodb2
def test_scan_filter():
client = boto3.client('dynamodb', region_name='us-east-1')
@ -2034,6 +2081,36 @@ def test_condition_expression__or_order():
)
@mock_dynamodb2
def test_condition_expression__and_order():
client = boto3.client('dynamodb', region_name='us-east-1')
client.create_table(
TableName='test',
KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}],
AttributeDefinitions=[
{'AttributeName': 'forum_name', 'AttributeType': 'S'},
],
ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1},
)
# ensure that the RHS of the AND expression is not evaluated if the LHS
# returns true (as it would result an error)
with assert_raises(client.exceptions.ConditionalCheckFailedException):
client.update_item(
TableName='test',
Key={
'forum_name': {'S': 'the-key'},
},
UpdateExpression='set #ttl=:ttl',
ConditionExpression='attribute_exists(#ttl) AND #ttl <= :old_ttl',
ExpressionAttributeNames={'#ttl': 'ttl'},
ExpressionAttributeValues={
':ttl': {'N': '6'},
':old_ttl': {'N': '5'},
}
)
@mock_dynamodb2
def test_query_gsi_with_range_key():
dynamodb = boto3.client('dynamodb', region_name='us-east-1')

View File

@ -77,6 +77,34 @@ class TestCore():
)
assert 'ShardIterator' in resp
def test_get_shard_iterator_at_sequence_number(self):
conn = boto3.client('dynamodbstreams', region_name='us-east-1')
resp = conn.describe_stream(StreamArn=self.stream_arn)
shard_id = resp['StreamDescription']['Shards'][0]['ShardId']
resp = conn.get_shard_iterator(
StreamArn=self.stream_arn,
ShardId=shard_id,
ShardIteratorType='AT_SEQUENCE_NUMBER',
SequenceNumber=resp['StreamDescription']['Shards'][0]['SequenceNumberRange']['StartingSequenceNumber']
)
assert 'ShardIterator' in resp
def test_get_shard_iterator_after_sequence_number(self):
conn = boto3.client('dynamodbstreams', region_name='us-east-1')
resp = conn.describe_stream(StreamArn=self.stream_arn)
shard_id = resp['StreamDescription']['Shards'][0]['ShardId']
resp = conn.get_shard_iterator(
StreamArn=self.stream_arn,
ShardId=shard_id,
ShardIteratorType='AFTER_SEQUENCE_NUMBER',
SequenceNumber=resp['StreamDescription']['Shards'][0]['SequenceNumberRange']['StartingSequenceNumber']
)
assert 'ShardIterator' in resp
def test_get_records_empty(self):
conn = boto3.client('dynamodbstreams', region_name='us-east-1')
@ -135,11 +163,39 @@ class TestCore():
assert resp['Records'][1]['eventName'] == 'MODIFY'
assert resp['Records'][2]['eventName'] == 'DELETE'
sequence_number_modify = resp['Records'][1]['dynamodb']['SequenceNumber']
# now try fetching from the next shard iterator, it should be
# empty
resp = conn.get_records(ShardIterator=resp['NextShardIterator'])
assert len(resp['Records']) == 0
# check that if we get the shard iterator AT_SEQUENCE_NUMBER will get the MODIFY event
resp = conn.get_shard_iterator(
StreamArn=self.stream_arn,
ShardId=shard_id,
ShardIteratorType='AT_SEQUENCE_NUMBER',
SequenceNumber=sequence_number_modify
)
iterator_id = resp['ShardIterator']
resp = conn.get_records(ShardIterator=iterator_id)
assert len(resp['Records']) == 2
assert resp['Records'][0]['eventName'] == 'MODIFY'
assert resp['Records'][1]['eventName'] == 'DELETE'
# check that if we get the shard iterator AFTER_SEQUENCE_NUMBER will get the DELETE event
resp = conn.get_shard_iterator(
StreamArn=self.stream_arn,
ShardId=shard_id,
ShardIteratorType='AFTER_SEQUENCE_NUMBER',
SequenceNumber=sequence_number_modify
)
iterator_id = resp['ShardIterator']
resp = conn.get_records(ShardIterator=iterator_id)
assert len(resp['Records']) == 1
assert resp['Records'][0]['eventName'] == 'DELETE'
class TestEdges():
mocks = []

View File

@ -4,7 +4,7 @@ import json
import os
import boto3
import botocore
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, ParamValidationError
from nose.tools import assert_raises
import sure # noqa
@ -752,6 +752,83 @@ def test_stopped_instance_target():
})
@mock_ec2
@mock_elbv2
def test_terminated_instance_target():
target_group_port = 8080
conn = boto3.client('elbv2', region_name='us-east-1')
ec2 = boto3.resource('ec2', region_name='us-east-1')
security_group = ec2.create_security_group(
GroupName='a-security-group', Description='First One')
vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default')
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.192/26',
AvailabilityZone='us-east-1a')
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.0/26',
AvailabilityZone='us-east-1b')
conn.create_load_balancer(
Name='my-lb',
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme='internal',
Tags=[{'Key': 'key_name', 'Value': 'a_value'}])
response = conn.create_target_group(
Name='a-target',
Protocol='HTTP',
Port=target_group_port,
VpcId=vpc.id,
HealthCheckProtocol='HTTP',
HealthCheckPath='/',
HealthCheckIntervalSeconds=5,
HealthCheckTimeoutSeconds=5,
HealthyThresholdCount=5,
UnhealthyThresholdCount=2,
Matcher={'HttpCode': '200'})
target_group = response.get('TargetGroups')[0]
# No targets registered yet
response = conn.describe_target_health(
TargetGroupArn=target_group.get('TargetGroupArn'))
response.get('TargetHealthDescriptions').should.have.length_of(0)
response = ec2.create_instances(
ImageId='ami-1234abcd', MinCount=1, MaxCount=1)
instance = response[0]
target_dict = {
'Id': instance.id,
'Port': 500
}
response = conn.register_targets(
TargetGroupArn=target_group.get('TargetGroupArn'),
Targets=[target_dict])
response = conn.describe_target_health(
TargetGroupArn=target_group.get('TargetGroupArn'))
response.get('TargetHealthDescriptions').should.have.length_of(1)
target_health_description = response.get('TargetHealthDescriptions')[0]
target_health_description['Target'].should.equal(target_dict)
target_health_description['HealthCheckPort'].should.equal(str(target_group_port))
target_health_description['TargetHealth'].should.equal({
'State': 'healthy'
})
instance.terminate()
response = conn.describe_target_health(
TargetGroupArn=target_group.get('TargetGroupArn'))
response.get('TargetHealthDescriptions').should.have.length_of(0)
@mock_ec2
@mock_elbv2
def test_target_group_attributes():
@ -1940,3 +2017,279 @@ def test_cognito_action_listener_rule_cloudformation():
'UserPoolDomain': 'testpool',
}
},])
@mock_elbv2
@mock_ec2
def test_fixed_response_action_listener_rule():
conn = boto3.client('elbv2', region_name='us-east-1')
ec2 = boto3.resource('ec2', region_name='us-east-1')
security_group = ec2.create_security_group(
GroupName='a-security-group', Description='First One')
vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default')
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.192/26',
AvailabilityZone='us-east-1a')
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.128/26',
AvailabilityZone='us-east-1b')
response = conn.create_load_balancer(
Name='my-lb',
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme='internal',
Tags=[{'Key': 'key_name', 'Value': 'a_value'}])
load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn')
action = {
'Type': 'fixed-response',
'FixedResponseConfig': {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
'StatusCode': '404',
}
}
response = conn.create_listener(LoadBalancerArn=load_balancer_arn,
Protocol='HTTP',
Port=80,
DefaultActions=[action])
listener = response.get('Listeners')[0]
listener.get('DefaultActions')[0].should.equal(action)
listener_arn = listener.get('ListenerArn')
describe_rules_response = conn.describe_rules(ListenerArn=listener_arn)
describe_rules_response['Rules'][0]['Actions'][0].should.equal(action)
describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn, ])
describe_listener_actions = describe_listener_response['Listeners'][0]['DefaultActions'][0]
describe_listener_actions.should.equal(action)
@mock_elbv2
@mock_cloudformation
def test_fixed_response_action_listener_rule_cloudformation():
cnf_conn = boto3.client('cloudformation', region_name='us-east-1')
elbv2_client = boto3.client('elbv2', region_name='us-east-1')
template = {
"AWSTemplateFormatVersion": "2010-09-09",
"Description": "ECS Cluster Test CloudFormation",
"Resources": {
"testVPC": {
"Type": "AWS::EC2::VPC",
"Properties": {
"CidrBlock": "10.0.0.0/16",
},
},
"subnet1": {
"Type": "AWS::EC2::Subnet",
"Properties": {
"CidrBlock": "10.0.0.0/24",
"VpcId": {"Ref": "testVPC"},
"AvalabilityZone": "us-east-1b",
},
},
"subnet2": {
"Type": "AWS::EC2::Subnet",
"Properties": {
"CidrBlock": "10.0.1.0/24",
"VpcId": {"Ref": "testVPC"},
"AvalabilityZone": "us-east-1b",
},
},
"testLb": {
"Type": "AWS::ElasticLoadBalancingV2::LoadBalancer",
"Properties": {
"Name": "my-lb",
"Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}],
"Type": "application",
"SecurityGroups": [],
}
},
"testListener": {
"Type": "AWS::ElasticLoadBalancingV2::Listener",
"Properties": {
"LoadBalancerArn": {"Ref": "testLb"},
"Port": 80,
"Protocol": "HTTP",
"DefaultActions": [{
"Type": "fixed-response",
"FixedResponseConfig": {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
'StatusCode': '404',
}
}]
}
}
}
}
template_json = json.dumps(template)
cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json)
describe_load_balancers_response = elbv2_client.describe_load_balancers(Names=['my-lb',])
load_balancer_arn = describe_load_balancers_response['LoadBalancers'][0]['LoadBalancerArn']
describe_listeners_response = elbv2_client.describe_listeners(LoadBalancerArn=load_balancer_arn)
describe_listeners_response['Listeners'].should.have.length_of(1)
describe_listeners_response['Listeners'][0]['DefaultActions'].should.equal([{
'Type': 'fixed-response',
"FixedResponseConfig": {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
'StatusCode': '404',
}
},])
@mock_elbv2
@mock_ec2
def test_fixed_response_action_listener_rule_validates_status_code():
conn = boto3.client('elbv2', region_name='us-east-1')
ec2 = boto3.resource('ec2', region_name='us-east-1')
security_group = ec2.create_security_group(
GroupName='a-security-group', Description='First One')
vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default')
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.192/26',
AvailabilityZone='us-east-1a')
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.128/26',
AvailabilityZone='us-east-1b')
response = conn.create_load_balancer(
Name='my-lb',
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme='internal',
Tags=[{'Key': 'key_name', 'Value': 'a_value'}])
load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn')
missing_status_code_action = {
'Type': 'fixed-response',
'FixedResponseConfig': {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
}
}
with assert_raises(ParamValidationError):
conn.create_listener(LoadBalancerArn=load_balancer_arn,
Protocol='HTTP',
Port=80,
DefaultActions=[missing_status_code_action])
invalid_status_code_action = {
'Type': 'fixed-response',
'FixedResponseConfig': {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
'StatusCode': '100'
}
}
@mock_elbv2
@mock_ec2
def test_fixed_response_action_listener_rule_validates_status_code():
conn = boto3.client('elbv2', region_name='us-east-1')
ec2 = boto3.resource('ec2', region_name='us-east-1')
security_group = ec2.create_security_group(
GroupName='a-security-group', Description='First One')
vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default')
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.192/26',
AvailabilityZone='us-east-1a')
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.128/26',
AvailabilityZone='us-east-1b')
response = conn.create_load_balancer(
Name='my-lb',
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme='internal',
Tags=[{'Key': 'key_name', 'Value': 'a_value'}])
load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn')
missing_status_code_action = {
'Type': 'fixed-response',
'FixedResponseConfig': {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
}
}
with assert_raises(ParamValidationError):
conn.create_listener(LoadBalancerArn=load_balancer_arn,
Protocol='HTTP',
Port=80,
DefaultActions=[missing_status_code_action])
invalid_status_code_action = {
'Type': 'fixed-response',
'FixedResponseConfig': {
'ContentType': 'text/plain',
'MessageBody': 'This page does not exist',
'StatusCode': '100'
}
}
with assert_raises(ClientError) as invalid_status_code_exception:
conn.create_listener(LoadBalancerArn=load_balancer_arn,
Protocol='HTTP',
Port=80,
DefaultActions=[invalid_status_code_action])
invalid_status_code_exception.exception.response['Error']['Code'].should.equal('ValidationError')
@mock_elbv2
@mock_ec2
def test_fixed_response_action_listener_rule_validates_content_type():
conn = boto3.client('elbv2', region_name='us-east-1')
ec2 = boto3.resource('ec2', region_name='us-east-1')
security_group = ec2.create_security_group(
GroupName='a-security-group', Description='First One')
vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default')
subnet1 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.192/26',
AvailabilityZone='us-east-1a')
subnet2 = ec2.create_subnet(
VpcId=vpc.id,
CidrBlock='172.28.7.128/26',
AvailabilityZone='us-east-1b')
response = conn.create_load_balancer(
Name='my-lb',
Subnets=[subnet1.id, subnet2.id],
SecurityGroups=[security_group.id],
Scheme='internal',
Tags=[{'Key': 'key_name', 'Value': 'a_value'}])
load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn')
invalid_content_type_action = {
'Type': 'fixed-response',
'FixedResponseConfig': {
'ContentType': 'Fake content type',
'MessageBody': 'This page does not exist',
'StatusCode': '200'
}
}
with assert_raises(ClientError) as invalid_content_type_exception:
conn.create_listener(LoadBalancerArn=load_balancer_arn,
Protocol='HTTP',
Port=80,
DefaultActions=[invalid_content_type_action])
invalid_content_type_exception.exception.response['Error']['Code'].should.equal('InvalidLoadBalancerAction')

View File

@ -1827,6 +1827,23 @@ valid_policy_documents = [
"Resource": ["*"]
}
]
},
{
"Version": "2012-10-17",
"Statement": [
{
"Sid": "",
"Effect": "Allow",
"Action": "rds:*",
"Resource": ["arn:aws:rds:region:*:*"]
},
{
"Sid": "",
"Effect": "Allow",
"Action": ["rds:Describe*"],
"Resource": ["*"]
}
]
}
]

View File

@ -86,6 +86,12 @@ def test_update():
payload.should.have.key('version').which.should.equal(2)
payload.should.have.key('timestamp')
raw_payload = b'{"state": {"desired": {"led": "on"}}, "version": 1}'
with assert_raises(ClientError) as ex:
client.update_thing_shadow(thingName=name, payload=raw_payload)
ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(409)
ex.exception.response['Error']['Message'].should.equal('Version conflict')
@mock_iotdata
def test_publish():

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,157 @@
from __future__ import unicode_literals
import sure # noqa
from nose.tools import assert_raises
from parameterized import parameterized
from moto.kms.exceptions import AccessDeniedException, InvalidCiphertextException, NotFoundException
from moto.kms.models import Key
from moto.kms.utils import (
_deserialize_ciphertext_blob,
_serialize_ciphertext_blob,
_serialize_encryption_context,
generate_data_key,
generate_master_key,
MASTER_KEY_LEN,
encrypt,
decrypt,
Ciphertext,
)
ENCRYPTION_CONTEXT_VECTORS = (
({"this": "is", "an": "encryption", "context": "example"}, b"an" b"encryption" b"context" b"example" b"this" b"is"),
({"a_this": "one", "b_is": "actually", "c_in": "order"}, b"a_this" b"one" b"b_is" b"actually" b"c_in" b"order"),
)
CIPHERTEXT_BLOB_VECTORS = (
(
Ciphertext(
key_id="d25652e4-d2d2-49f7-929a-671ccda580c6",
iv=b"123456789012",
ciphertext=b"some ciphertext",
tag=b"1234567890123456",
),
b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext",
),
(
Ciphertext(
key_id="d25652e4-d2d2-49f7-929a-671ccda580c6",
iv=b"123456789012",
ciphertext=b"some ciphertext that is much longer now",
tag=b"1234567890123456",
),
b"d25652e4-d2d2-49f7-929a-671ccda580c6"
b"123456789012"
b"1234567890123456"
b"some ciphertext that is much longer now",
),
)
def test_generate_data_key():
test = generate_data_key(123)
test.should.be.a(bytes)
len(test).should.equal(123)
def test_generate_master_key():
test = generate_master_key()
test.should.be.a(bytes)
len(test).should.equal(MASTER_KEY_LEN)
@parameterized(ENCRYPTION_CONTEXT_VECTORS)
def test_serialize_encryption_context(raw, serialized):
test = _serialize_encryption_context(raw)
test.should.equal(serialized)
@parameterized(CIPHERTEXT_BLOB_VECTORS)
def test_cycle_ciphertext_blob(raw, _serialized):
test_serialized = _serialize_ciphertext_blob(raw)
test_deserialized = _deserialize_ciphertext_blob(test_serialized)
test_deserialized.should.equal(raw)
@parameterized(CIPHERTEXT_BLOB_VECTORS)
def test_serialize_ciphertext_blob(raw, serialized):
test = _serialize_ciphertext_blob(raw)
test.should.equal(serialized)
@parameterized(CIPHERTEXT_BLOB_VECTORS)
def test_deserialize_ciphertext_blob(raw, serialized):
test = _deserialize_ciphertext_blob(serialized)
test.should.equal(raw)
@parameterized(((ec[0],) for ec in ENCRYPTION_CONTEXT_VECTORS))
def test_encrypt_decrypt_cycle(encryption_context):
plaintext = b"some secret plaintext"
master_key = Key("nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key}
ciphertext_blob = encrypt(
master_keys=master_key_map, key_id=master_key.id, plaintext=plaintext, encryption_context=encryption_context
)
ciphertext_blob.should_not.equal(plaintext)
decrypted, decrypting_key_id = decrypt(
master_keys=master_key_map, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context
)
decrypted.should.equal(plaintext)
decrypting_key_id.should.equal(master_key.id)
def test_encrypt_unknown_key_id():
with assert_raises(NotFoundException):
encrypt(master_keys={}, key_id="anything", plaintext=b"secrets", encryption_context={})
def test_decrypt_invalid_ciphertext_format():
master_key = Key("nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key}
with assert_raises(InvalidCiphertextException):
decrypt(master_keys=master_key_map, ciphertext_blob=b"", encryption_context={})
def test_decrypt_unknwown_key_id():
ciphertext_blob = b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext"
with assert_raises(AccessDeniedException):
decrypt(master_keys={}, ciphertext_blob=ciphertext_blob, encryption_context={})
def test_decrypt_invalid_ciphertext():
master_key = Key("nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key}
ciphertext_blob = master_key.id.encode("utf-8") + b"123456789012" b"1234567890123456" b"some ciphertext"
with assert_raises(InvalidCiphertextException):
decrypt(
master_keys=master_key_map,
ciphertext_blob=ciphertext_blob,
encryption_context={},
)
def test_decrypt_invalid_encryption_context():
plaintext = b"some secret plaintext"
master_key = Key("nop", "nop", "nop", [], "nop")
master_key_map = {master_key.id: master_key}
ciphertext_blob = encrypt(
master_keys=master_key_map,
key_id=master_key.id,
plaintext=plaintext,
encryption_context={"some": "encryption", "context": "here"},
)
with assert_raises(InvalidCiphertextException):
decrypt(
master_keys=master_key_map,
ciphertext_blob=ciphertext_blob,
encryption_context={},
)

View File

@ -190,6 +190,8 @@ def test_get_log_events():
resp['events'].should.have.length_of(10)
resp.should.have.key('nextForwardToken')
resp.should.have.key('nextBackwardToken')
resp['nextForwardToken'].should.equal('f/00000000000000000000000000000000000000000000000000000010')
resp['nextBackwardToken'].should.equal('b/00000000000000000000000000000000000000000000000000000000')
for i in range(10):
resp['events'][i]['timestamp'].should.equal(i)
resp['events'][i]['message'].should.equal(str(i))
@ -205,7 +207,8 @@ def test_get_log_events():
resp['events'].should.have.length_of(10)
resp.should.have.key('nextForwardToken')
resp.should.have.key('nextBackwardToken')
resp['nextForwardToken'].should.equal(next_token)
resp['nextForwardToken'].should.equal('f/00000000000000000000000000000000000000000000000000000020')
resp['nextBackwardToken'].should.equal('b/00000000000000000000000000000000000000000000000000000000')
for i in range(10):
resp['events'][i]['timestamp'].should.equal(i+10)
resp['events'][i]['message'].should.equal(str(i+10))

View File

@ -37,6 +37,25 @@ def test_create_cluster_boto3():
create_time = response['Cluster']['ClusterCreateTime']
create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo))
create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1))
response['Cluster']['EnhancedVpcRouting'].should.equal(False)
@mock_redshift
def test_create_cluster_boto3():
client = boto3.client('redshift', region_name='us-east-1')
response = client.create_cluster(
DBName='test',
ClusterIdentifier='test',
ClusterType='single-node',
NodeType='ds2.xlarge',
MasterUsername='user',
MasterUserPassword='password',
EnhancedVpcRouting=True
)
response['Cluster']['NodeType'].should.equal('ds2.xlarge')
create_time = response['Cluster']['ClusterCreateTime']
create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo))
create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1))
response['Cluster']['EnhancedVpcRouting'].should.equal(True)
@mock_redshift
@ -425,6 +444,58 @@ def test_delete_cluster():
"not-a-cluster").should.throw(ClusterNotFound)
@mock_redshift
def test_modify_cluster_vpc_routing():
iam_roles_arn = ['arn:aws:iam:::role/my-iam-role', ]
client = boto3.client('redshift', region_name='us-east-1')
cluster_identifier = 'my_cluster'
client.create_cluster(
ClusterIdentifier=cluster_identifier,
NodeType="single-node",
MasterUsername="username",
MasterUserPassword="password",
IamRoles=iam_roles_arn
)
cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier)
cluster = cluster_response['Clusters'][0]
cluster['EnhancedVpcRouting'].should.equal(False)
client.create_cluster_security_group(ClusterSecurityGroupName='security_group',
Description='security_group')
client.create_cluster_parameter_group(ParameterGroupName='my_parameter_group',
ParameterGroupFamily='redshift-1.0',
Description='my_parameter_group')
client.modify_cluster(
ClusterIdentifier=cluster_identifier,
ClusterType='multi-node',
NodeType="ds2.8xlarge",
NumberOfNodes=3,
ClusterSecurityGroups=["security_group"],
MasterUserPassword="new_password",
ClusterParameterGroupName="my_parameter_group",
AutomatedSnapshotRetentionPeriod=7,
PreferredMaintenanceWindow="Tue:03:00-Tue:11:00",
AllowVersionUpgrade=False,
NewClusterIdentifier=cluster_identifier,
EnhancedVpcRouting=True
)
cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier)
cluster = cluster_response['Clusters'][0]
cluster['ClusterIdentifier'].should.equal(cluster_identifier)
cluster['NodeType'].should.equal("ds2.8xlarge")
cluster['PreferredMaintenanceWindow'].should.equal("Tue:03:00-Tue:11:00")
cluster['AutomatedSnapshotRetentionPeriod'].should.equal(7)
cluster['AllowVersionUpgrade'].should.equal(False)
# This one should remain unmodified.
cluster['NumberOfNodes'].should.equal(3)
cluster['EnhancedVpcRouting'].should.equal(True)
@mock_redshift_deprecated
def test_modify_cluster():
conn = boto.connect_redshift()
@ -446,6 +517,10 @@ def test_modify_cluster():
master_user_password="password",
)
cluster_response = conn.describe_clusters(cluster_identifier)
cluster = cluster_response['DescribeClustersResponse']['DescribeClustersResult']['Clusters'][0]
cluster['EnhancedVpcRouting'].should.equal(False)
conn.modify_cluster(
cluster_identifier,
cluster_type="multi-node",
@ -456,14 +531,13 @@ def test_modify_cluster():
automated_snapshot_retention_period=7,
preferred_maintenance_window="Tue:03:00-Tue:11:00",
allow_version_upgrade=False,
new_cluster_identifier="new_identifier",
new_cluster_identifier=cluster_identifier,
)
cluster_response = conn.describe_clusters("new_identifier")
cluster_response = conn.describe_clusters(cluster_identifier)
cluster = cluster_response['DescribeClustersResponse'][
'DescribeClustersResult']['Clusters'][0]
cluster['ClusterIdentifier'].should.equal("new_identifier")
cluster['ClusterIdentifier'].should.equal(cluster_identifier)
cluster['NodeType'].should.equal("dw.hs1.xlarge")
cluster['ClusterSecurityGroups'][0][
'ClusterSecurityGroupName'].should.equal("security_group")
@ -674,6 +748,7 @@ def test_create_cluster_snapshot():
NodeType='ds2.xlarge',
MasterUsername='username',
MasterUserPassword='password',
EnhancedVpcRouting=True
)
cluster_response['Cluster']['NodeType'].should.equal('ds2.xlarge')
@ -823,11 +898,14 @@ def test_create_cluster_from_snapshot():
NodeType='ds2.xlarge',
MasterUsername='username',
MasterUserPassword='password',
EnhancedVpcRouting=True,
)
client.create_cluster_snapshot(
SnapshotIdentifier=original_snapshot_identifier,
ClusterIdentifier=original_cluster_identifier
)
response = client.restore_from_cluster_snapshot(
ClusterIdentifier=new_cluster_identifier,
SnapshotIdentifier=original_snapshot_identifier,
@ -842,7 +920,7 @@ def test_create_cluster_from_snapshot():
new_cluster['NodeType'].should.equal('ds2.xlarge')
new_cluster['MasterUsername'].should.equal('username')
new_cluster['Endpoint']['Port'].should.equal(1234)
new_cluster['EnhancedVpcRouting'].should.equal(True)
@mock_redshift
def test_create_cluster_from_snapshot_with_waiter():
@ -857,6 +935,7 @@ def test_create_cluster_from_snapshot_with_waiter():
NodeType='ds2.xlarge',
MasterUsername='username',
MasterUserPassword='password',
EnhancedVpcRouting=True
)
client.create_cluster_snapshot(
SnapshotIdentifier=original_snapshot_identifier,
@ -883,6 +962,7 @@ def test_create_cluster_from_snapshot_with_waiter():
new_cluster = response['Clusters'][0]
new_cluster['NodeType'].should.equal('ds2.xlarge')
new_cluster['MasterUsername'].should.equal('username')
new_cluster['EnhancedVpcRouting'].should.equal(True)
new_cluster['Endpoint']['Port'].should.equal(1234)

View File

@ -404,6 +404,11 @@ def test_list_or_change_tags_for_resource_request():
)
healthcheck_id = health_check['HealthCheck']['Id']
# confirm this works for resources with zero tags
response = conn.list_tags_for_resource(
ResourceType="healthcheck", ResourceId=healthcheck_id)
response["ResourceTagSet"]["Tags"].should.be.empty
tag1 = {"Key": "Deploy", "Value": "True"}
tag2 = {"Key": "Name", "Value": "UnitTest"}

View File

@ -2,6 +2,7 @@
from __future__ import unicode_literals
import datetime
import os
from six.moves.urllib.request import urlopen
from six.moves.urllib.error import HTTPError
from functools import wraps
@ -20,9 +21,11 @@ from botocore.handlers import disable_signing
from boto.s3.connection import S3Connection
from boto.s3.key import Key
from freezegun import freeze_time
from parameterized import parameterized
import six
import requests
import tests.backport_assert_raises # noqa
from nose import SkipTest
from nose.tools import assert_raises
import sure # noqa
@ -1390,6 +1393,34 @@ def test_boto3_list_objects_v2_fetch_owner():
assert len(owner.keys()) == 2
@mock_s3
def test_boto3_list_objects_v2_truncate_combined_keys_and_folders():
s3 = boto3.client('s3', region_name='us-east-1')
s3.create_bucket(Bucket='mybucket')
s3.put_object(Bucket='mybucket', Key='1/2', Body='')
s3.put_object(Bucket='mybucket', Key='2', Body='')
s3.put_object(Bucket='mybucket', Key='3/4', Body='')
s3.put_object(Bucket='mybucket', Key='4', Body='')
resp = s3.list_objects_v2(Bucket='mybucket', Prefix='', MaxKeys=2, Delimiter='/')
assert 'Delimiter' in resp
assert resp['IsTruncated'] is True
assert resp['KeyCount'] == 2
assert len(resp['Contents']) == 1
assert resp['Contents'][0]['Key'] == '2'
assert len(resp['CommonPrefixes']) == 1
assert resp['CommonPrefixes'][0]['Prefix'] == '1/'
last_tail = resp['NextContinuationToken']
resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=2, Prefix='', Delimiter='/', StartAfter=last_tail)
assert resp['KeyCount'] == 2
assert resp['IsTruncated'] is False
assert len(resp['Contents']) == 1
assert resp['Contents'][0]['Key'] == '4'
assert len(resp['CommonPrefixes']) == 1
assert resp['CommonPrefixes'][0]['Prefix'] == '3/'
@mock_s3
def test_boto3_bucket_create():
s3 = boto3.resource('s3', region_name='us-east-1')
@ -2991,3 +3022,64 @@ def test_accelerate_configuration_is_not_supported_when_bucket_name_has_dots():
AccelerateConfiguration={'Status': 'Enabled'},
)
exc.exception.response['Error']['Code'].should.equal('InvalidRequest')
def store_and_read_back_a_key(key):
s3 = boto3.client('s3', region_name='us-east-1')
bucket_name = 'mybucket'
body = b'Some body'
s3.create_bucket(Bucket=bucket_name)
s3.put_object(
Bucket=bucket_name,
Key=key,
Body=body
)
response = s3.get_object(Bucket=bucket_name, Key=key)
response['Body'].read().should.equal(body)
@mock_s3
def test_paths_with_leading_slashes_work():
store_and_read_back_a_key('/a-key')
@mock_s3
def test_root_dir_with_empty_name_works():
if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true':
raise SkipTest('Does not work in server mode due to error in Workzeug')
store_and_read_back_a_key('/')
@parameterized([
('foo/bar/baz',),
('foo',),
('foo/run_dt%3D2019-01-01%252012%253A30%253A00',),
])
@mock_s3
def test_delete_objects_with_url_encoded_key(key):
s3 = boto3.client('s3', region_name='us-east-1')
bucket_name = 'mybucket'
body = b'Some body'
s3.create_bucket(Bucket=bucket_name)
def put_object():
s3.put_object(
Bucket=bucket_name,
Key=key,
Body=body
)
def assert_deleted():
with assert_raises(ClientError) as e:
s3.get_object(Bucket=bucket_name, Key=key)
e.exception.response['Error']['Code'].should.equal('NoSuchKey')
put_object()
s3.delete_object(Bucket=bucket_name, Key=key)
assert_deleted()
put_object()
s3.delete_objects(Bucket=bucket_name, Delete={'Objects': [{'Key': key}]})
assert_deleted()

View File

@ -1,7 +1,8 @@
from __future__ import unicode_literals
import os
from sure import expect
from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url
from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url, clean_key_name, undo_clean_key_name
from parameterized import parameterized
def test_base_url():
@ -78,3 +79,29 @@ def test_parse_region_from_url():
'https://s3.amazonaws.com/bucket',
'https://bucket.s3.amazonaws.com']:
parse_region_from_url(url).should.equal(expected)
@parameterized([
('foo/bar/baz',
'foo/bar/baz'),
('foo',
'foo'),
('foo/run_dt%3D2019-01-01%252012%253A30%253A00',
'foo/run_dt=2019-01-01%2012%3A30%3A00'),
])
def test_clean_key_name(key, expected):
clean_key_name(key).should.equal(expected)
@parameterized([
('foo/bar/baz',
'foo/bar/baz'),
('foo',
'foo'),
('foo/run_dt%3D2019-01-01%252012%253A30%253A00',
'foo/run_dt%253D2019-01-01%25252012%25253A30%25253A00'),
])
def test_undo_clean_key_name(key, expected):
undo_clean_key_name(key).should.equal(expected)

View File

@ -5,9 +5,9 @@ import boto3
from moto import mock_secretsmanager
from botocore.exceptions import ClientError
import string
import unittest
import pytz
from datetime import datetime
import sure # noqa
from nose.tools import assert_raises
from six import b
@ -23,6 +23,7 @@ def test_get_secret_value():
result = conn.get_secret_value(SecretId='java-util-test-password')
assert result['SecretString'] == 'foosecret'
@mock_secretsmanager
def test_get_secret_value_binary():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -32,6 +33,7 @@ def test_get_secret_value_binary():
result = conn.get_secret_value(SecretId='java-util-test-password')
assert result['SecretBinary'] == b('foosecret')
@mock_secretsmanager
def test_get_secret_that_does_not_exist():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -39,6 +41,7 @@ def test_get_secret_that_does_not_exist():
with assert_raises(ClientError):
result = conn.get_secret_value(SecretId='i-dont-exist')
@mock_secretsmanager
def test_get_secret_that_does_not_match():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -72,6 +75,7 @@ def test_create_secret():
secret = conn.get_secret_value(SecretId='test-secret')
assert secret['SecretString'] == 'foosecret'
@mock_secretsmanager
def test_create_secret_with_tags():
conn = boto3.client('secretsmanager', region_name='us-east-1')
@ -216,6 +220,7 @@ def test_get_random_exclude_lowercase():
ExcludeLowercase=True)
assert any(c.islower() for c in random_password['RandomPassword']) == False
@mock_secretsmanager
def test_get_random_exclude_uppercase():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -224,6 +229,7 @@ def test_get_random_exclude_uppercase():
ExcludeUppercase=True)
assert any(c.isupper() for c in random_password['RandomPassword']) == False
@mock_secretsmanager
def test_get_random_exclude_characters_and_symbols():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -232,6 +238,7 @@ def test_get_random_exclude_characters_and_symbols():
ExcludeCharacters='xyzDje@?!.')
assert any(c in 'xyzDje@?!.' for c in random_password['RandomPassword']) == False
@mock_secretsmanager
def test_get_random_exclude_numbers():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -240,6 +247,7 @@ def test_get_random_exclude_numbers():
ExcludeNumbers=True)
assert any(c.isdigit() for c in random_password['RandomPassword']) == False
@mock_secretsmanager
def test_get_random_exclude_punctuation():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -249,6 +257,7 @@ def test_get_random_exclude_punctuation():
assert any(c in string.punctuation
for c in random_password['RandomPassword']) == False
@mock_secretsmanager
def test_get_random_include_space_false():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -256,6 +265,7 @@ def test_get_random_include_space_false():
random_password = conn.get_random_password(PasswordLength=300)
assert any(c.isspace() for c in random_password['RandomPassword']) == False
@mock_secretsmanager
def test_get_random_include_space_true():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -264,6 +274,7 @@ def test_get_random_include_space_true():
IncludeSpace=True)
assert any(c.isspace() for c in random_password['RandomPassword']) == True
@mock_secretsmanager
def test_get_random_require_each_included_type():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -275,6 +286,7 @@ def test_get_random_require_each_included_type():
assert any(c in string.ascii_uppercase for c in random_password['RandomPassword']) == True
assert any(c in string.digits for c in random_password['RandomPassword']) == True
@mock_secretsmanager
def test_get_random_too_short_password():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -282,6 +294,7 @@ def test_get_random_too_short_password():
with assert_raises(ClientError):
random_password = conn.get_random_password(PasswordLength=3)
@mock_secretsmanager
def test_get_random_too_long_password():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -289,6 +302,7 @@ def test_get_random_too_long_password():
with assert_raises(Exception):
random_password = conn.get_random_password(PasswordLength=5555)
@mock_secretsmanager
def test_describe_secret():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -307,6 +321,7 @@ def test_describe_secret():
assert secret_description_2['Name'] == ('test-secret-2')
assert secret_description_2['ARN'] != '' # Test arn not empty
@mock_secretsmanager
def test_describe_secret_that_does_not_exist():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -314,6 +329,7 @@ def test_describe_secret_that_does_not_exist():
with assert_raises(ClientError):
result = conn.get_secret_value(SecretId='i-dont-exist')
@mock_secretsmanager
def test_describe_secret_that_does_not_match():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -500,6 +516,7 @@ def test_rotate_secret_rotation_period_zero():
# test_server actually handles this error.
assert True
@mock_secretsmanager
def test_rotate_secret_rotation_period_too_long():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -511,6 +528,7 @@ def test_rotate_secret_rotation_period_too_long():
result = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME,
RotationRules=rotation_rules)
@mock_secretsmanager
def test_put_secret_value_puts_new_secret():
conn = boto3.client('secretsmanager', region_name='us-west-2')
@ -526,6 +544,45 @@ def test_put_secret_value_puts_new_secret():
assert get_secret_value_dict
assert get_secret_value_dict['SecretString'] == 'foosecret'
@mock_secretsmanager
def test_put_secret_binary_value_puts_new_secret():
conn = boto3.client('secretsmanager', region_name='us-west-2')
put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME,
SecretBinary=b('foosecret'),
VersionStages=['AWSCURRENT'])
version_id = put_secret_value_dict['VersionId']
get_secret_value_dict = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME,
VersionId=version_id,
VersionStage='AWSCURRENT')
assert get_secret_value_dict
assert get_secret_value_dict['SecretBinary'] == b('foosecret')
@mock_secretsmanager
def test_create_and_put_secret_binary_value_puts_new_secret():
conn = boto3.client('secretsmanager', region_name='us-west-2')
conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretBinary=b("foosecret"))
conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, SecretBinary=b('foosecret_update'))
latest_secret = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME)
assert latest_secret
assert latest_secret['SecretBinary'] == b('foosecret_update')
@mock_secretsmanager
def test_put_secret_binary_requires_either_string_or_binary():
conn = boto3.client('secretsmanager', region_name='us-west-2')
with assert_raises(ClientError) as ire:
conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME)
ire.exception.response['Error']['Code'].should.equal('InvalidRequestException')
ire.exception.response['Error']['Message'].should.equal('You must provide either SecretString or SecretBinary.')
@mock_secretsmanager
def test_put_secret_value_can_get_first_version_if_put_twice():
conn = boto3.client('secretsmanager', region_name='us-west-2')

View File

@ -109,6 +109,17 @@ def test_publish_to_sqs_bad():
}})
except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameterValue')
try:
# Test Number DataType, with a non numeric value
conn.publish(
TopicArn=topic_arn, Message=message,
MessageAttributes={'price': {
'DataType': 'Number',
'StringValue': 'error'
}})
except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameterValue')
err.response['Error']['Message'].should.equal("An error occurred (ParameterValueInvalid) when calling the Publish operation: Could not cast message attribute 'price' value to number.")
@mock_sqs
@ -487,3 +498,380 @@ def test_filtering_exact_string_no_attributes_no_match():
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_exact_number_int():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'Number',
'StringValue': '100'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'Number', 'Value': 100}}])
@mock_sqs
@mock_sns
def test_filtering_exact_number_float():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100.1]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'Number',
'StringValue': '100.1'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'Number', 'Value': 100.1}}])
@mock_sqs
@mock_sns
def test_filtering_exact_number_float_accuracy():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100.123456789]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'Number',
'StringValue': '100.1234561'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'Number', 'Value': 100.1234561}}])
@mock_sqs
@mock_sns
def test_filtering_exact_number_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100]})
topic.publish(
Message='no match',
MessageAttributes={'price': {'DataType': 'Number',
'StringValue': '101'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_exact_number_with_string_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100]})
topic.publish(
Message='no match',
MessageAttributes={'price': {'DataType': 'String',
'StringValue': '100'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_string_array_match():
topic, subscription, queue = _setup_filter_policy_test(
{'customer_interests': ['basketball', 'baseball']})
topic.publish(
Message='match',
MessageAttributes={'customer_interests': {'DataType': 'String.Array',
'StringValue': json.dumps(['basketball', 'rugby'])}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'customer_interests': {'Type': 'String.Array', 'Value': json.dumps(['basketball', 'rugby'])}}])
@mock_sqs
@mock_sns
def test_filtering_string_array_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'customer_interests': ['baseball']})
topic.publish(
Message='no_match',
MessageAttributes={'customer_interests': {'DataType': 'String.Array',
'StringValue': json.dumps(['basketball', 'rugby'])}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_string_array_with_number_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100, 500]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'String.Array',
'StringValue': json.dumps([100, 50])}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'String.Array', 'Value': json.dumps([100, 50])}}])
@mock_sqs
@mock_sns
def test_filtering_string_array_with_number_float_accuracy_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100.123456789, 500]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'String.Array',
'StringValue': json.dumps([100.1234561, 50])}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'String.Array', 'Value': json.dumps([100.1234561, 50])}}])
@mock_sqs
@mock_sns
# this is the correct behavior from SNS
def test_filtering_string_array_with_number_no_array_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100, 500]})
topic.publish(
Message='match',
MessageAttributes={'price': {'DataType': 'String.Array',
'StringValue': '100'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'price': {'Type': 'String.Array', 'Value': '100'}}])
@mock_sqs
@mock_sns
def test_filtering_string_array_with_number_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [500]})
topic.publish(
Message='no_match',
MessageAttributes={'price': {'DataType': 'String.Array',
'StringValue': json.dumps([100, 50])}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
# this is the correct behavior from SNS
def test_filtering_string_array_with_string_no_array_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'price': [100]})
topic.publish(
Message='no_match',
MessageAttributes={'price': {'DataType': 'String.Array',
'StringValue': 'one hundread'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_attribute_key_exists_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': True}]})
topic.publish(
Message='match',
MessageAttributes={'store': {'DataType': 'String',
'StringValue': 'example_corp'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'store': {'Type': 'String', 'Value': 'example_corp'}}])
@mock_sqs
@mock_sns
def test_filtering_attribute_key_exists_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': True}]})
topic.publish(
Message='no match',
MessageAttributes={'event': {'DataType': 'String',
'StringValue': 'order_cancelled'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_attribute_key_not_exists_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': False}]})
topic.publish(
Message='match',
MessageAttributes={'event': {'DataType': 'String',
'StringValue': 'order_cancelled'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal(
[{'event': {'Type': 'String', 'Value': 'order_cancelled'}}])
@mock_sqs
@mock_sns
def test_filtering_attribute_key_not_exists_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': False}]})
topic.publish(
Message='no match',
MessageAttributes={'store': {'DataType': 'String',
'StringValue': 'example_corp'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])
@mock_sqs
@mock_sns
def test_filtering_all_AND_matching_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': True}],
'event': ['order_cancelled'],
'customer_interests': ['basketball', 'baseball'],
'price': [100]})
topic.publish(
Message='match',
MessageAttributes={'store': {'DataType': 'String',
'StringValue': 'example_corp'},
'event': {'DataType': 'String',
'StringValue': 'order_cancelled'},
'customer_interests': {'DataType': 'String.Array',
'StringValue': json.dumps(['basketball', 'rugby'])},
'price': {'DataType': 'Number',
'StringValue': '100'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal(
['match'])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([{
'store': {'Type': 'String', 'Value': 'example_corp'},
'event': {'Type': 'String', 'Value': 'order_cancelled'},
'customer_interests': {'Type': 'String.Array', 'Value': json.dumps(['basketball', 'rugby'])},
'price': {'Type': 'Number', 'Value': 100}}])
@mock_sqs
@mock_sns
def test_filtering_all_AND_matching_no_match():
topic, subscription, queue = _setup_filter_policy_test(
{'store': [{'exists': True}],
'event': ['order_cancelled'],
'customer_interests': ['basketball', 'baseball'],
'price': [100],
"encrypted": [False]})
topic.publish(
Message='no match',
MessageAttributes={'store': {'DataType': 'String',
'StringValue': 'example_corp'},
'event': {'DataType': 'String',
'StringValue': 'order_cancelled'},
'customer_interests': {'DataType': 'String.Array',
'StringValue': json.dumps(['basketball', 'rugby'])},
'price': {'DataType': 'Number',
'StringValue': '100'}})
messages = queue.receive_messages(MaxNumberOfMessages=5)
message_bodies = [json.loads(m.body)['Message'] for m in messages]
message_bodies.should.equal([])
message_attributes = [
json.loads(m.body)['MessageAttributes'] for m in messages]
message_attributes.should.equal([])

View File

@ -201,7 +201,9 @@ def test_creating_subscription_with_attributes():
"store": ["example_corp"],
"event": ["order_cancelled"],
"encrypted": [False],
"customer_interests": ["basketball", "baseball"]
"customer_interests": ["basketball", "baseball"],
"price": [100, 100.12],
"error": [None]
})
conn.subscribe(TopicArn=topic_arn,
@ -294,7 +296,9 @@ def test_set_subscription_attributes():
"store": ["example_corp"],
"event": ["order_cancelled"],
"encrypted": [False],
"customer_interests": ["basketball", "baseball"]
"customer_interests": ["basketball", "baseball"],
"price": [100, 100.12],
"error": [None]
})
conn.set_subscription_attributes(
SubscriptionArn=subscription_arn,
@ -332,6 +336,77 @@ def test_set_subscription_attributes():
)
@mock_sns
def test_subscribe_invalid_filter_policy():
conn = boto3.client('sns', region_name = 'us-east-1')
conn.create_topic(Name = 'some-topic')
response = conn.list_topics()
topic_arn = response['Topics'][0]['TopicArn']
try:
conn.subscribe(TopicArn = topic_arn,
Protocol = 'http',
Endpoint = 'http://example.com/',
Attributes = {
'FilterPolicy': json.dumps({
'store': [str(i) for i in range(151)]
})
})
except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameter')
err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Filter policy is too complex')
try:
conn.subscribe(TopicArn = topic_arn,
Protocol = 'http',
Endpoint = 'http://example.com/',
Attributes = {
'FilterPolicy': json.dumps({
'store': [['example_corp']]
})
})
except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameter')
err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null')
try:
conn.subscribe(TopicArn = topic_arn,
Protocol = 'http',
Endpoint = 'http://example.com/',
Attributes = {
'FilterPolicy': json.dumps({
'store': [{'exists': None}]
})
})
except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameter')
err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: exists match pattern must be either true or false.')
try:
conn.subscribe(TopicArn = topic_arn,
Protocol = 'http',
Endpoint = 'http://example.com/',
Attributes = {
'FilterPolicy': json.dumps({
'store': [{'error': True}]
})
})
except ClientError as err:
err.response['Error']['Code'].should.equal('InvalidParameter')
err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Unrecognized match type error')
try:
conn.subscribe(TopicArn = topic_arn,
Protocol = 'http',
Endpoint = 'http://example.com/',
Attributes = {
'FilterPolicy': json.dumps({
'store': [1000000001]
})
})
except ClientError as err:
err.response['Error']['Code'].should.equal('InternalFailure')
@mock_sns
def test_check_not_opted_out():
conn = boto3.client('sns', region_name='us-east-1')

View File

@ -1117,6 +1117,28 @@ def test_redrive_policy_set_attributes():
assert copy_policy == redrive_policy
@mock_sqs
def test_redrive_policy_set_attributes_with_string_value():
sqs = boto3.resource('sqs', region_name='us-east-1')
queue = sqs.create_queue(QueueName='test-queue')
deadletter_queue = sqs.create_queue(QueueName='test-deadletter')
queue.set_attributes(Attributes={
'RedrivePolicy': json.dumps({
'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'],
'maxReceiveCount': '1',
})})
copy = sqs.get_queue_by_name(QueueName='test-queue')
assert 'RedrivePolicy' in copy.attributes
copy_policy = json.loads(copy.attributes['RedrivePolicy'])
assert copy_policy == {
'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'],
'maxReceiveCount': 1,
}
@mock_sqs
def test_receive_messages_with_message_group_id():
sqs = boto3.resource('sqs', region_name='us-east-1')

View File

@ -0,0 +1,378 @@
from __future__ import unicode_literals
import boto3
import sure # noqa
import datetime
from datetime import datetime
from botocore.exceptions import ClientError
from nose.tools import assert_raises
from moto import mock_sts, mock_stepfunctions
region = 'us-east-1'
simple_definition = '{"Comment": "An example of the Amazon States Language using a choice state.",' \
'"StartAt": "DefaultState",' \
'"States": ' \
'{"DefaultState": {"Type": "Fail","Error": "DefaultStateError","Cause": "No Matches!"}}}'
account_id = None
@mock_stepfunctions
@mock_sts
def test_state_machine_creation_succeeds():
client = boto3.client('stepfunctions', region_name=region)
name = 'example_step_function'
#
response = client.create_state_machine(name=name,
definition=str(simple_definition),
roleArn=_get_default_role())
#
response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
response['creationDate'].should.be.a(datetime)
response['stateMachineArn'].should.equal('arn:aws:states:' + region + ':123456789012:stateMachine:' + name)
@mock_stepfunctions
def test_state_machine_creation_fails_with_invalid_names():
client = boto3.client('stepfunctions', region_name=region)
invalid_names = [
'with space',
'with<bracket', 'with>bracket', 'with{bracket', 'with}bracket', 'with[bracket', 'with]bracket',
'with?wildcard', 'with*wildcard',
'special"char', 'special#char', 'special%char', 'special\\char', 'special^char', 'special|char',
'special~char', 'special`char', 'special$char', 'special&char', 'special,char', 'special;char',
'special:char', 'special/char',
u'uni\u0000code', u'uni\u0001code', u'uni\u0002code', u'uni\u0003code', u'uni\u0004code',
u'uni\u0005code', u'uni\u0006code', u'uni\u0007code', u'uni\u0008code', u'uni\u0009code',
u'uni\u000Acode', u'uni\u000Bcode', u'uni\u000Ccode',
u'uni\u000Dcode', u'uni\u000Ecode', u'uni\u000Fcode',
u'uni\u0010code', u'uni\u0011code', u'uni\u0012code', u'uni\u0013code', u'uni\u0014code',
u'uni\u0015code', u'uni\u0016code', u'uni\u0017code', u'uni\u0018code', u'uni\u0019code',
u'uni\u001Acode', u'uni\u001Bcode', u'uni\u001Ccode',
u'uni\u001Dcode', u'uni\u001Ecode', u'uni\u001Fcode',
u'uni\u007Fcode',
u'uni\u0080code', u'uni\u0081code', u'uni\u0082code', u'uni\u0083code', u'uni\u0084code',
u'uni\u0085code', u'uni\u0086code', u'uni\u0087code', u'uni\u0088code', u'uni\u0089code',
u'uni\u008Acode', u'uni\u008Bcode', u'uni\u008Ccode',
u'uni\u008Dcode', u'uni\u008Ecode', u'uni\u008Fcode',
u'uni\u0090code', u'uni\u0091code', u'uni\u0092code', u'uni\u0093code', u'uni\u0094code',
u'uni\u0095code', u'uni\u0096code', u'uni\u0097code', u'uni\u0098code', u'uni\u0099code',
u'uni\u009Acode', u'uni\u009Bcode', u'uni\u009Ccode',
u'uni\u009Dcode', u'uni\u009Ecode', u'uni\u009Fcode']
#
for invalid_name in invalid_names:
with assert_raises(ClientError) as exc:
client.create_state_machine(name=invalid_name,
definition=str(simple_definition),
roleArn=_get_default_role())
@mock_stepfunctions
def test_state_machine_creation_requires_valid_role_arn():
client = boto3.client('stepfunctions', region_name=region)
name = 'example_step_function'
#
with assert_raises(ClientError) as exc:
client.create_state_machine(name=name,
definition=str(simple_definition),
roleArn='arn:aws:iam:1234:role/unknown_role')
@mock_stepfunctions
def test_state_machine_list_returns_empty_list_by_default():
client = boto3.client('stepfunctions', region_name=region)
#
list = client.list_state_machines()
list['stateMachines'].should.be.empty
@mock_stepfunctions
@mock_sts
def test_state_machine_list_returns_created_state_machines():
client = boto3.client('stepfunctions', region_name=region)
#
machine2 = client.create_state_machine(name='name2',
definition=str(simple_definition),
roleArn=_get_default_role())
machine1 = client.create_state_machine(name='name1',
definition=str(simple_definition),
roleArn=_get_default_role(),
tags=[{'key': 'tag_key', 'value': 'tag_value'}])
list = client.list_state_machines()
#
list['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
list['stateMachines'].should.have.length_of(2)
list['stateMachines'][0]['creationDate'].should.be.a(datetime)
list['stateMachines'][0]['creationDate'].should.equal(machine1['creationDate'])
list['stateMachines'][0]['name'].should.equal('name1')
list['stateMachines'][0]['stateMachineArn'].should.equal(machine1['stateMachineArn'])
list['stateMachines'][1]['creationDate'].should.be.a(datetime)
list['stateMachines'][1]['creationDate'].should.equal(machine2['creationDate'])
list['stateMachines'][1]['name'].should.equal('name2')
list['stateMachines'][1]['stateMachineArn'].should.equal(machine2['stateMachineArn'])
@mock_stepfunctions
@mock_sts
def test_state_machine_creation_is_idempotent_by_name():
client = boto3.client('stepfunctions', region_name=region)
#
client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
sm_list = client.list_state_machines()
sm_list['stateMachines'].should.have.length_of(1)
#
client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
sm_list = client.list_state_machines()
sm_list['stateMachines'].should.have.length_of(1)
#
client.create_state_machine(name='diff_name', definition=str(simple_definition), roleArn=_get_default_role())
sm_list = client.list_state_machines()
sm_list['stateMachines'].should.have.length_of(2)
@mock_stepfunctions
@mock_sts
def test_state_machine_creation_can_be_described():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
desc = client.describe_state_machine(stateMachineArn=sm['stateMachineArn'])
desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
desc['creationDate'].should.equal(sm['creationDate'])
desc['definition'].should.equal(str(simple_definition))
desc['name'].should.equal('name')
desc['roleArn'].should.equal(_get_default_role())
desc['stateMachineArn'].should.equal(sm['stateMachineArn'])
desc['status'].should.equal('ACTIVE')
@mock_stepfunctions
@mock_sts
def test_state_machine_throws_error_when_describing_unknown_machine():
client = boto3.client('stepfunctions', region_name=region)
#
with assert_raises(ClientError) as exc:
unknown_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown'
client.describe_state_machine(stateMachineArn=unknown_state_machine)
@mock_stepfunctions
@mock_sts
def test_state_machine_throws_error_when_describing_machine_in_different_account():
client = boto3.client('stepfunctions', region_name=region)
#
with assert_raises(ClientError) as exc:
unknown_state_machine = 'arn:aws:states:' + region + ':000000000000:stateMachine:unknown'
client.describe_state_machine(stateMachineArn=unknown_state_machine)
@mock_stepfunctions
@mock_sts
def test_state_machine_can_be_deleted():
client = boto3.client('stepfunctions', region_name=region)
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
#
response = client.delete_state_machine(stateMachineArn=sm['stateMachineArn'])
response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
#
sm_list = client.list_state_machines()
sm_list['stateMachines'].should.have.length_of(0)
@mock_stepfunctions
@mock_sts
def test_state_machine_can_deleted_nonexisting_machine():
client = boto3.client('stepfunctions', region_name=region)
#
unknown_state_machine = 'arn:aws:states:' + region + ':123456789012:stateMachine:unknown'
response = client.delete_state_machine(stateMachineArn=unknown_state_machine)
response['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
#
sm_list = client.list_state_machines()
sm_list['stateMachines'].should.have.length_of(0)
@mock_stepfunctions
@mock_sts
def test_state_machine_list_tags_for_created_machine():
client = boto3.client('stepfunctions', region_name=region)
#
machine = client.create_state_machine(name='name1',
definition=str(simple_definition),
roleArn=_get_default_role(),
tags=[{'key': 'tag_key', 'value': 'tag_value'}])
response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn'])
tags = response['tags']
tags.should.have.length_of(1)
tags[0].should.equal({'key': 'tag_key', 'value': 'tag_value'})
@mock_stepfunctions
@mock_sts
def test_state_machine_list_tags_for_machine_without_tags():
client = boto3.client('stepfunctions', region_name=region)
#
machine = client.create_state_machine(name='name1',
definition=str(simple_definition),
roleArn=_get_default_role())
response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn'])
tags = response['tags']
tags.should.have.length_of(0)
@mock_stepfunctions
@mock_sts
def test_state_machine_list_tags_for_nonexisting_machine():
client = boto3.client('stepfunctions', region_name=region)
#
non_existing_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown'
response = client.list_tags_for_resource(resourceArn=non_existing_state_machine)
tags = response['tags']
tags.should.have.length_of(0)
@mock_stepfunctions
@mock_sts
def test_state_machine_start_execution():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
execution = client.start_execution(stateMachineArn=sm['stateMachineArn'])
#
execution['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
expected_exec_name = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:name:[a-zA-Z0-9-]+'
execution['executionArn'].should.match(expected_exec_name)
execution['startDate'].should.be.a(datetime)
@mock_stepfunctions
@mock_sts
def test_state_machine_list_executions():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
execution = client.start_execution(stateMachineArn=sm['stateMachineArn'])
execution_arn = execution['executionArn']
execution_name = execution_arn[execution_arn.rindex(':')+1:]
executions = client.list_executions(stateMachineArn=sm['stateMachineArn'])
#
executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
executions['executions'].should.have.length_of(1)
executions['executions'][0]['executionArn'].should.equal(execution_arn)
executions['executions'][0]['name'].should.equal(execution_name)
executions['executions'][0]['startDate'].should.equal(execution['startDate'])
executions['executions'][0]['stateMachineArn'].should.equal(sm['stateMachineArn'])
executions['executions'][0]['status'].should.equal('RUNNING')
executions['executions'][0].shouldnt.have('stopDate')
@mock_stepfunctions
@mock_sts
def test_state_machine_list_executions_when_none_exist():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
executions = client.list_executions(stateMachineArn=sm['stateMachineArn'])
#
executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
executions['executions'].should.have.length_of(0)
@mock_stepfunctions
@mock_sts
def test_state_machine_describe_execution():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
execution = client.start_execution(stateMachineArn=sm['stateMachineArn'])
description = client.describe_execution(executionArn=execution['executionArn'])
#
description['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
description['executionArn'].should.equal(execution['executionArn'])
description['input'].should.equal("{}")
description['name'].shouldnt.be.empty
description['startDate'].should.equal(execution['startDate'])
description['stateMachineArn'].should.equal(sm['stateMachineArn'])
description['status'].should.equal('RUNNING')
description.shouldnt.have('stopDate')
@mock_stepfunctions
@mock_sts
def test_state_machine_throws_error_when_describing_unknown_machine():
client = boto3.client('stepfunctions', region_name=region)
#
with assert_raises(ClientError) as exc:
unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown'
client.describe_execution(executionArn=unknown_execution)
@mock_stepfunctions
@mock_sts
def test_state_machine_can_be_described_by_execution():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
execution = client.start_execution(stateMachineArn=sm['stateMachineArn'])
desc = client.describe_state_machine_for_execution(executionArn=execution['executionArn'])
desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
desc['definition'].should.equal(str(simple_definition))
desc['name'].should.equal('name')
desc['roleArn'].should.equal(_get_default_role())
desc['stateMachineArn'].should.equal(sm['stateMachineArn'])
@mock_stepfunctions
@mock_sts
def test_state_machine_throws_error_when_describing_unknown_execution():
client = boto3.client('stepfunctions', region_name=region)
#
with assert_raises(ClientError) as exc:
unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown'
client.describe_state_machine_for_execution(executionArn=unknown_execution)
@mock_stepfunctions
@mock_sts
def test_state_machine_stop_execution():
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
start = client.start_execution(stateMachineArn=sm['stateMachineArn'])
stop = client.stop_execution(executionArn=start['executionArn'])
#
stop['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
stop['stopDate'].should.be.a(datetime)
@mock_stepfunctions
@mock_sts
def test_state_machine_describe_execution_after_stoppage():
account_id
client = boto3.client('stepfunctions', region_name=region)
#
sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role())
execution = client.start_execution(stateMachineArn=sm['stateMachineArn'])
client.stop_execution(executionArn=execution['executionArn'])
description = client.describe_execution(executionArn=execution['executionArn'])
#
description['ResponseMetadata']['HTTPStatusCode'].should.equal(200)
description['status'].should.equal('SUCCEEDED')
description['stopDate'].should.be.a(datetime)
def _get_account_id():
global account_id
if account_id:
return account_id
sts = boto3.client("sts")
identity = sts.get_caller_identity()
account_id = identity['Account']
return account_id
def _get_default_role():
return 'arn:aws:iam:' + _get_account_id() + ':role/unknown_sf_role'