Merge pull request #48 from spulec/master

Merge upstream
Bert Blommers 2020-06-20 09:38:37 +01:00 committed by GitHub
commit c5f8fa4e1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
40 changed files with 1548 additions and 186 deletions


@@ -1,32 +1,41 @@
-### Contributing code
+# Contributing code
 Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_CONDUCT.md), you can expect to be treated with respect at all times when interacting with this project.
 ## Running the tests locally
-Moto has a Makefile which has some helpful commands for getting set up. You should be able to run `make init` to install the dependencies and then `make test` to run the tests.
+Moto has a [Makefile](./Makefile) which has some helpful commands for getting set up.
+You should be able to run `make init` to install the dependencies and then `make test` to run the tests.
+*NB. On first run, some tests might take a while to execute, especially the Lambda ones, because they may need to download a Docker image before they can execute.*
 ## Linting
 Run `make lint` or `black --check moto tests` to verify whether your code conforms to the guidelines.
-## Is there a missing feature?
+## Getting to grips with the codebase
+Moto maintains a list of [good first issues](https://github.com/spulec/moto/contribute) which you may want to look at before
+implementing a whole new endpoint.
+
+## Missing features
 Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services.
 How to teach Moto to support a new AWS endpoint:
-* Create an issue describing what's missing. This is where we'll all talk about the new addition and help you get it done.
+* Search for an existing [issue](https://github.com/spulec/moto/issues) that matches what you want to achieve.
+* If one doesn't already exist, create a new issue describing what's missing. This is where we'll all talk about the new addition and help you get it done.
 * Create a [pull request](https://help.github.com/articles/using-pull-requests/) and mention the issue # in the PR description.
 * Try to add a failing test case. For example, if you're trying to implement `boto3.client('acm').import_certificate()` you'll want to add a new method called `def test_import_certificate` to `tests/test_acm/test_acm.py`.
 * If you can also implement the code that gets that test passing that's great. If not, just ask the community for a hand and somebody will assist you.
-# Maintainers
+## Maintainers
-## Releasing a new version of Moto
+### Releasing a new version of Moto
-You'll need a PyPi account and a Dockerhub account to release Moto. After we release a new PyPi package we build and push the [motoserver/moto](https://hub.docker.com/r/motoserver/moto/) Docker image.
+You'll need a PyPi account and a DockerHub account to release Moto. After we release a new PyPi package we build and push the [motoserver/moto](https://hub.docker.com/r/motoserver/moto/) Docker image.
 * First, `scripts/bump_version` modifies the version and opens a PR
 * Then, merge the new pull request
 * Finally, generate and ship the new artifacts with `make publish`
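For illustration, the failing-test workflow the guide describes might start with a skeleton like this (a hedged sketch; the test name follows the example in the text, and the PEM placeholders must be replaced with real material):

```python
# tests/test_acm/test_acm.py -- skeleton of the failing test the guide suggests
import boto3
from moto import mock_acm


@mock_acm
def test_import_certificate():
    client = boto3.client("acm", region_name="us-east-1")
    # This call is expected to fail until the endpoint is implemented -- that is the point.
    resp = client.import_certificate(
        Certificate=b"-----BEGIN CERTIFICATE-----...",  # placeholder PEM
        PrivateKey=b"-----BEGIN PRIVATE KEY-----...",   # placeholder PEM
    )
    assert "CertificateArn" in resp
```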

File diff suppressed because it is too large


@@ -56,13 +56,21 @@ class Deployment(BaseModel, dict):
 class IntegrationResponse(BaseModel, dict):
-    def __init__(self, status_code, selection_pattern=None, response_templates=None):
+    def __init__(
+        self,
+        status_code,
+        selection_pattern=None,
+        response_templates=None,
+        content_handling=None,
+    ):
         if response_templates is None:
             response_templates = {"application/json": None}
         self["responseTemplates"] = response_templates
         self["statusCode"] = status_code
         if selection_pattern:
             self["selectionPattern"] = selection_pattern
+        if content_handling:
+            self["contentHandling"] = content_handling

 class Integration(BaseModel, dict):
@@ -75,12 +83,12 @@ class Integration(BaseModel, dict):
         self["integrationResponses"] = {"200": IntegrationResponse(200)}

     def create_integration_response(
-        self, status_code, selection_pattern, response_templates
+        self, status_code, selection_pattern, response_templates, content_handling
     ):
         if response_templates == {}:
             response_templates = None
         integration_response = IntegrationResponse(
-            status_code, selection_pattern, response_templates
+            status_code, selection_pattern, response_templates, content_handling
         )
         self["integrationResponses"][status_code] = integration_response
         return integration_response
@@ -959,12 +967,13 @@ class APIGatewayBackend(BaseBackend):
         status_code,
         selection_pattern,
         response_templates,
+        content_handling,
     ):
         if response_templates is None:
             raise InvalidRequestInput()
         integration = self.get_integration(function_id, resource_id, method_type)
         integration_response = integration.create_integration_response(
-            status_code, selection_pattern, response_templates
+            status_code, selection_pattern, response_templates, content_handling
         )
         return integration_response
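The effect of the new `content_handling` plumbing, as a hedged boto3 sketch (resource and method setup abbreviated; parameter names follow the public API Gateway API):

```python
import boto3
from moto import mock_apigateway


@mock_apigateway
def test_integration_response_content_handling():
    client = boto3.client("apigateway", region_name="us-east-1")
    api_id = client.create_rest_api(name="demo")["id"]
    root_id = client.get_resources(restApiId=api_id)["items"][0]["id"]
    client.put_method(restApiId=api_id, resourceId=root_id,
                      httpMethod="GET", authorizationType="NONE")
    client.put_integration(restApiId=api_id, resourceId=root_id, httpMethod="GET",
                           type="HTTP", integrationHttpMethod="GET",
                           uri="http://example.com")
    client.put_integration_response(
        restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200",
        responseTemplates={},                 # moto rejects a missing responseTemplates
        contentHandling="CONVERT_TO_BINARY",  # the newly supported field
    )
    resp = client.get_integration_response(restApiId=api_id, resourceId=root_id,
                                           httpMethod="GET", statusCode="200")
    assert resp["contentHandling"] == "CONVERT_TO_BINARY"
```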


@@ -387,6 +387,7 @@ class APIGatewayResponse(BaseResponse):
         elif self.method == "PUT":
             selection_pattern = self._get_param("selectionPattern")
             response_templates = self._get_param("responseTemplates")
+            content_handling = self._get_param("contentHandling")
             integration_response = self.backend.create_integration_response(
                 function_id,
                 resource_id,
@@ -394,6 +395,7 @@ class APIGatewayResponse(BaseResponse):
                 status_code,
                 selection_pattern,
                 response_templates,
+                content_handling,
             )
         elif self.method == "DELETE":
             integration_response = self.backend.delete_integration_response(


@@ -60,6 +60,16 @@ class Execution(BaseModel):
         self.status = "QUEUED"

+class NamedQuery(BaseModel):
+    def __init__(self, name, description, database, query_string, workgroup):
+        self.id = str(uuid4())
+        self.name = name
+        self.description = description
+        self.database = database
+        self.query_string = query_string
+        self.workgroup = workgroup
+
 class AthenaBackend(BaseBackend):
     region_name = None
@@ -68,6 +78,7 @@ class AthenaBackend(BaseBackend):
         self.region_name = region_name
         self.work_groups = {}
         self.executions = {}
+        self.named_queries = {}

     def create_work_group(self, name, configuration, description, tags):
         if name in self.work_groups:
@@ -113,6 +124,20 @@ class AthenaBackend(BaseBackend):
         execution = self.executions[exec_id]
         execution.status = "CANCELLED"

+    def create_named_query(self, name, description, database, query_string, workgroup):
+        nq = NamedQuery(
+            name=name,
+            description=description,
+            database=database,
+            query_string=query_string,
+            workgroup=workgroup,
+        )
+        self.named_queries[nq.id] = nq
+        return nq.id
+
+    def get_named_query(self, query_id):
+        return self.named_queries[query_id] if query_id in self.named_queries else None

 athena_backends = {}
 for region in Session().get_available_regions("athena"):


@@ -85,3 +85,32 @@ class AthenaResponse(BaseResponse):
             json.dumps({"__type": "InvalidRequestException", "Message": msg,}),
             dict(status=status),
         )
+
+    def create_named_query(self):
+        name = self._get_param("Name")
+        description = self._get_param("Description")
+        database = self._get_param("Database")
+        query_string = self._get_param("QueryString")
+        workgroup = self._get_param("WorkGroup")
+        if workgroup and not self.athena_backend.get_work_group(workgroup):
+            return self.error("WorkGroup does not exist", 400)
+        query_id = self.athena_backend.create_named_query(
+            name, description, database, query_string, workgroup
+        )
+        return json.dumps({"NamedQueryId": query_id})
+
+    def get_named_query(self):
+        query_id = self._get_param("NamedQueryId")
+        nq = self.athena_backend.get_named_query(query_id)
+        return json.dumps(
+            {
+                "NamedQuery": {
+                    "Name": nq.name,
+                    "Description": nq.description,
+                    "Database": nq.database,
+                    "QueryString": nq.query_string,
+                    "NamedQueryId": nq.id,
+                    "WorkGroup": nq.workgroup,
+                }
+            }
+        )
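Round-tripping the new named-query support through boto3 might look like this (a sketch; `WorkGroup` is optional, and passing an unknown workgroup returns the 400 above):

```python
import boto3
from moto import mock_athena


@mock_athena
def test_named_query_round_trip():
    client = boto3.client("athena", region_name="us-east-1")
    query_id = client.create_named_query(
        Name="weekly-report",
        Database="events",
        QueryString="SELECT * FROM events LIMIT 10",
    )["NamedQueryId"]
    nq = client.get_named_query(NamedQueryId=query_id)["NamedQuery"]
    assert nq["QueryString"] == "SELECT * FROM events LIMIT 10"
```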


@@ -218,7 +218,7 @@ class LambdaFunction(BaseModel):
         key = None
         try:
             # FIXME: does not validate bucket region
-            key = s3_backend.get_key(self.code["S3Bucket"], self.code["S3Key"])
+            key = s3_backend.get_object(self.code["S3Bucket"], self.code["S3Key"])
         except MissingBucket:
             if do_validate_s3():
                 raise InvalidParameterValueException(
@@ -344,7 +344,7 @@ class LambdaFunction(BaseModel):
             key = None
             try:
                 # FIXME: does not validate bucket region
-                key = s3_backend.get_key(
+                key = s3_backend.get_object(
                     updated_spec["S3Bucket"], updated_spec["S3Key"]
                 )
             except MissingBucket:
@@ -555,40 +555,63 @@ class LambdaFunction(BaseModel):
 class EventSourceMapping(BaseModel):
     def __init__(self, spec):
         # required
-        self.function_arn = spec["FunctionArn"]
+        self.function_name = spec["FunctionName"]
         self.event_source_arn = spec["EventSourceArn"]
+
+        # optional
+        self.batch_size = spec.get("BatchSize")
+        self.starting_position = spec.get("StartingPosition", "TRIM_HORIZON")
+        self.enabled = spec.get("Enabled", True)
+        self.starting_position_timestamp = spec.get("StartingPositionTimestamp", None)
+
+        self.function_arn = spec["FunctionArn"]
         self.uuid = str(uuid.uuid4())
         self.last_modified = time.mktime(datetime.datetime.utcnow().timetuple())

-        # BatchSize service default/max mapping
-        batch_size_map = {
+    def _get_service_source_from_arn(self, event_source_arn):
+        return event_source_arn.split(":")[2].lower()
+
+    def _validate_event_source(self, event_source_arn):
+        valid_services = ("dynamodb", "kinesis", "sqs")
+        service = self._get_service_source_from_arn(event_source_arn)
+        return True if service in valid_services else False
+
+    @property
+    def event_source_arn(self):
+        return self._event_source_arn
+
+    @event_source_arn.setter
+    def event_source_arn(self, event_source_arn):
+        if not self._validate_event_source(event_source_arn):
+            raise ValueError(
+                "InvalidParameterValueException", "Unsupported event source type"
+            )
+        self._event_source_arn = event_source_arn
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    @batch_size.setter
+    def batch_size(self, batch_size):
+        batch_size_service_map = {
             "kinesis": (100, 10000),
             "dynamodb": (100, 1000),
             "sqs": (10, 10),
         }
-        source_type = self.event_source_arn.split(":")[2].lower()
-        batch_size_entry = batch_size_map.get(source_type)
-        if batch_size_entry:
-            # Use service default if not provided
-            batch_size = int(spec.get("BatchSize", batch_size_entry[0]))
-            if batch_size > batch_size_entry[1]:
-                raise ValueError(
-                    "InvalidParameterValueException",
-                    "BatchSize {} exceeds the max of {}".format(
-                        batch_size, batch_size_entry[1]
-                    ),
-                )
-            else:
-                self.batch_size = batch_size
-        else:
-            raise ValueError(
-                "InvalidParameterValueException", "Unsupported event source type"
-            )

-        # optional
-        self.starting_position = spec.get("StartingPosition", "TRIM_HORIZON")
-        self.enabled = spec.get("Enabled", True)
-        self.starting_position_timestamp = spec.get("StartingPositionTimestamp", None)
+        source_type = self._get_service_source_from_arn(self.event_source_arn)
+        batch_size_for_source = batch_size_service_map[source_type]
+
+        if batch_size is None:
+            self._batch_size = batch_size_for_source[0]
+        elif batch_size > batch_size_for_source[1]:
+            error_message = "BatchSize {} exceeds the max of {}".format(
+                batch_size, batch_size_for_source[1]
+            )
+            raise ValueError("InvalidParameterValueException", error_message)
+        else:
+            self._batch_size = int(batch_size)

     def get_configuration(self):
         return {
@@ -602,23 +625,42 @@ class EventSourceMapping(BaseModel):
             "StateTransitionReason": "User initiated",
         }

+    def delete(self, region_name):
+        lambda_backend = lambda_backends[region_name]
+        lambda_backend.delete_event_source_mapping(self.uuid)
+
     @classmethod
     def create_from_cloudformation_json(
         cls, resource_name, cloudformation_json, region_name
     ):
         properties = cloudformation_json["Properties"]
-        func = lambda_backends[region_name].get_function(properties["FunctionName"])
-        spec = {
-            "FunctionArn": func.function_arn,
-            "EventSourceArn": properties["EventSourceArn"],
-            "StartingPosition": properties["StartingPosition"],
-            "BatchSize": properties.get("BatchSize", 100),
-        }
-        optional_properties = "BatchSize Enabled StartingPositionTimestamp".split()
-        for prop in optional_properties:
-            if prop in properties:
-                spec[prop] = properties[prop]
-        return EventSourceMapping(spec)
+        lambda_backend = lambda_backends[region_name]
+        return lambda_backend.create_event_source_mapping(properties)
+
+    @classmethod
+    def update_from_cloudformation_json(
+        cls, new_resource_name, cloudformation_json, original_resource, region_name
+    ):
+        properties = cloudformation_json["Properties"]
+        event_source_uuid = original_resource.uuid
+        lambda_backend = lambda_backends[region_name]
+        return lambda_backend.update_event_source_mapping(event_source_uuid, properties)
+
+    @classmethod
+    def delete_from_cloudformation_json(
+        cls, resource_name, cloudformation_json, region_name
+    ):
+        properties = cloudformation_json["Properties"]
+        lambda_backend = lambda_backends[region_name]
+        esms = lambda_backend.list_event_source_mappings(
+            event_source_arn=properties["EventSourceArn"],
+            function_name=properties["FunctionName"],
+        )
+        for esm in esms:
+            if esm.logical_resource_id in resource_name:
+                lambda_backend.delete_event_source_mapping
+                esm.delete(region_name)

 class LambdaVersion(BaseModel):
@@ -819,7 +861,7 @@ class LambdaBackend(BaseBackend):
         )

         # Validate function name
-        func = self._lambdas.get_function_by_name_or_arn(spec.pop("FunctionName", ""))
+        func = self._lambdas.get_function_by_name_or_arn(spec.get("FunctionName", ""))
         if not func:
             raise RESTError("ResourceNotFoundException", "Invalid FunctionName")
@@ -877,18 +919,20 @@ class LambdaBackend(BaseBackend):
     def update_event_source_mapping(self, uuid, spec):
         esm = self.get_event_source_mapping(uuid)
-        if esm:
-            if spec.get("FunctionName"):
-                func = self._lambdas.get_function_by_name_or_arn(
-                    spec.get("FunctionName")
-                )
+        if not esm:
+            return False
+
+        for key, value in spec.items():
+            if key == "FunctionName":
+                func = self._lambdas.get_function_by_name_or_arn(spec[key])
                 esm.function_arn = func.function_arn
-            if "BatchSize" in spec:
-                esm.batch_size = spec["BatchSize"]
-            if "Enabled" in spec:
-                esm.enabled = spec["Enabled"]
-            return esm
-        return False
+            elif key == "BatchSize":
+                esm.batch_size = spec[key]
+            elif key == "Enabled":
+                esm.enabled = spec[key]
+
+        esm.last_modified = time.mktime(datetime.datetime.utcnow().timetuple())
+        return esm

     def list_event_source_mappings(self, event_source_arn, function_name):
         esms = list(self._event_source_mappings.values())
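The batch-size rules the new properties encode, sketched directly against the model class (an internal API, so treat this as illustrative only; the names are hypothetical):

```python
from moto.awslambda.models import EventSourceMapping

spec = {
    "FunctionName": "my-func",
    "FunctionArn": "arn:aws:lambda:us-east-1:123456789012:function:my-func",
    "EventSourceArn": "arn:aws:sqs:us-east-1:123456789012:my-queue",
}
esm = EventSourceMapping(spec)
print(esm.batch_size)  # 10 -- the SQS service default applies when BatchSize is omitted

try:
    esm.batch_size = 50  # exceeds the SQS maximum of 10
except ValueError as err:
    print(err.args)  # ('InvalidParameterValueException', 'BatchSize 50 exceeds the max of 10')
```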


@@ -315,8 +315,8 @@ class FakeStack(BaseModel):
         yaml.add_multi_constructor("", yaml_tag_constructor)
         try:
             self.template_dict = yaml.load(self.template, Loader=yaml.Loader)
-        except yaml.parser.ParserError:
-            self.template_dict = json.loads(self.template, Loader=yaml.Loader)
+        except (yaml.parser.ParserError, yaml.scanner.ScannerError):
+            self.template_dict = json.loads(self.template)

     @property
     def stack_parameters(self):
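The broadened `except` matters for JSON templates that PyYAML cannot even scan, e.g. ones containing hard tabs; previously a `ScannerError` escaped the JSON fallback. A minimal sketch:

```python
import json

import yaml

template = '{\t"Resources": {}}'  # a hard tab is valid JSON whitespace but trips the YAML scanner
try:
    template_dict = yaml.load(template, Loader=yaml.Loader)
except (yaml.parser.ParserError, yaml.scanner.ScannerError):
    template_dict = json.loads(template)  # now reached for ScannerError too
print(template_dict)  # {'Resources': {}}
```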


@@ -541,7 +541,7 @@ class ResourceMap(collections_abc.Mapping):
             if name == "AWS::Include":
                 location = params["Location"]
                 bucket_name, name = bucket_and_name_from_url(location)
-                key = s3_backend.get_key(bucket_name, name)
+                key = s3_backend.get_object(bucket_name, name)
                 self._parsed_resources.update(json.loads(key.value))

     def load_parameters(self):


@@ -36,7 +36,7 @@ class CloudFormationResponse(BaseResponse):
             bucket_name = template_url_parts.netloc.split(".")[0]
             key_name = template_url_parts.path.lstrip("/")

-        key = s3_backend.get_key(bucket_name, key_name)
+        key = s3_backend.get_object(bucket_name, key_name)
         return key.value.decode("utf-8")

     def create_stack(self):
@@ -50,6 +50,12 @@ class CloudFormationResponse(BaseResponse):
             for item in self._get_list_prefix("Tags.member")
         )

+        if self.stack_name_exists(new_stack_name=stack_name):
+            template = self.response_template(
+                CREATE_STACK_NAME_EXISTS_RESPONSE_TEMPLATE
+            )
+            return 400, {"status": 400}, template.render(name=stack_name)
+
         # Hack dict-comprehension
         parameters = dict(
             [
@@ -82,6 +88,12 @@ class CloudFormationResponse(BaseResponse):
             template = self.response_template(CREATE_STACK_RESPONSE_TEMPLATE)
             return template.render(stack=stack)

+    def stack_name_exists(self, new_stack_name):
+        for stack in self.cloudformation_backend.stacks.values():
+            if stack.name == new_stack_name:
+                return True
+        return False
+
     @amzn_request_id
     def create_change_set(self):
         stack_name = self._get_param("StackName")
@@ -564,6 +576,15 @@ CREATE_STACK_RESPONSE_TEMPLATE = """<CreateStackResponse>
 </CreateStackResponse>
 """

+CREATE_STACK_NAME_EXISTS_RESPONSE_TEMPLATE = """<ErrorResponse xmlns="http://cloudformation.amazonaws.com/doc/2010-05-15/">
+  <Error>
+    <Type>Sender</Type>
+    <Code>AlreadyExistsException</Code>
+    <Message>Stack [{{ name }}] already exists</Message>
+  </Error>
+  <RequestId>950ff8d7-812a-44b3-bb0c-9b271b954104</RequestId>
+</ErrorResponse>"""
+
 UPDATE_STACK_RESPONSE_TEMPLATE = """<UpdateStackResponse xmlns="http://cloudformation.amazonaws.com/doc/2010-05-15/">
   <UpdateStackResult>
     <StackId>{{ stack.stack_id }}</StackId>
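With the new guard in place, creating a stack under a name that already exists surfaces as an `AlreadyExistsException`; a hedged sketch with a deliberately trivial template:

```python
import json

import boto3
from botocore.exceptions import ClientError
from moto import mock_cloudformation

TEMPLATE = json.dumps({
    "AWSTemplateFormatVersion": "2010-09-09",
    "Resources": {"Queue": {"Type": "AWS::SQS::Queue",
                            "Properties": {"QueueName": "demo-queue"}}},
})


@mock_cloudformation
def test_duplicate_stack_name_is_rejected():
    cf = boto3.client("cloudformation", region_name="us-east-1")
    cf.create_stack(StackName="demo", TemplateBody=TEMPLATE)
    try:
        cf.create_stack(StackName="demo", TemplateBody=TEMPLATE)
    except ClientError as err:
        assert err.response["Error"]["Code"] == "AlreadyExistsException"
```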


@@ -184,6 +184,8 @@ class CallbackResponse(responses.CallbackResponse):
             body = None
         elif isinstance(request.body, six.text_type):
             body = six.BytesIO(six.b(request.body))
+        elif hasattr(request.body, "read"):
+            body = six.BytesIO(request.body.read())
         else:
             body = six.BytesIO(request.body)
         req = Request.from_values(
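The new branch covers callers that hand the responses shim a file-like body rather than bytes or text, for instance uploads that stream from a file object (an illustrative guess at one affected path, not an exhaustive one):

```python
import io

import boto3
from moto import mock_s3


@mock_s3
def test_file_like_body_is_replayed():
    s3 = boto3.client("s3", region_name="us-east-1")
    s3.create_bucket(Bucket="demo")
    # upload_fileobj can surface a file-like body to the mocked transport
    s3.upload_fileobj(io.BytesIO(b"hello"), "demo", "greeting.txt")
    assert s3.get_object(Bucket="demo", Key="greeting.txt")["Body"].read() == b"hello"
```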


@@ -272,6 +272,66 @@ class StreamShard(BaseModel):
         return [i.to_json() for i in self.items[start:end]]

+class LocalSecondaryIndex(BaseModel):
+    def __init__(self, index_name, schema, projection):
+        self.name = index_name
+        self.schema = schema
+        self.projection = projection
+
+    def describe(self):
+        return {
+            "IndexName": self.name,
+            "KeySchema": self.schema,
+            "Projection": self.projection,
+        }
+
+    @staticmethod
+    def create(dct):
+        return LocalSecondaryIndex(
+            index_name=dct["IndexName"],
+            schema=dct["KeySchema"],
+            projection=dct["Projection"],
+        )
+
+
+class GlobalSecondaryIndex(BaseModel):
+    def __init__(
+        self, index_name, schema, projection, status="ACTIVE", throughput=None
+    ):
+        self.name = index_name
+        self.schema = schema
+        self.projection = projection
+        self.status = status
+        self.throughput = throughput or {
+            "ReadCapacityUnits": 0,
+            "WriteCapacityUnits": 0,
+        }
+
+    def describe(self):
+        return {
+            "IndexName": self.name,
+            "KeySchema": self.schema,
+            "Projection": self.projection,
+            "IndexStatus": self.status,
+            "ProvisionedThroughput": self.throughput,
+        }
+
+    @staticmethod
+    def create(dct):
+        return GlobalSecondaryIndex(
+            index_name=dct["IndexName"],
+            schema=dct["KeySchema"],
+            projection=dct["Projection"],
+            throughput=dct.get("ProvisionedThroughput", None),
+        )
+
+    def update(self, u):
+        self.name = u.get("IndexName", self.name)
+        self.schema = u.get("KeySchema", self.schema)
+        self.projection = u.get("Projection", self.projection)
+        self.throughput = u.get("ProvisionedThroughput", self.throughput)
+
+
 class Table(BaseModel):
     def __init__(
         self,
@@ -302,12 +362,13 @@ class Table(BaseModel):
         else:
             self.throughput = throughput
         self.throughput["NumberOfDecreasesToday"] = 0
-        self.indexes = indexes
-        self.global_indexes = global_indexes if global_indexes else []
-        for index in self.global_indexes:
-            index[
-                "IndexStatus"
-            ] = "ACTIVE"  # One of 'CREATING'|'UPDATING'|'DELETING'|'ACTIVE'
+        self.indexes = [
+            LocalSecondaryIndex.create(i) for i in (indexes if indexes else [])
+        ]
+        self.global_indexes = [
+            GlobalSecondaryIndex.create(i)
+            for i in (global_indexes if global_indexes else [])
+        ]
         self.created_at = datetime.datetime.utcnow()
         self.items = defaultdict(dict)
         self.table_arn = self._generate_arn(table_name)
@@ -325,6 +386,16 @@ class Table(BaseModel):
             },
         }

+    def get_cfn_attribute(self, attribute_name):
+        from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
+
+        if attribute_name == "Arn":
+            return self.table_arn
+        elif attribute_name == "StreamArn" and self.stream_specification:
+            return self.describe()["TableDescription"]["LatestStreamArn"]
+
+        raise UnformattedGetAttTemplateException()
+
     @classmethod
     def create_from_cloudformation_json(
         cls, resource_name, cloudformation_json, region_name
@@ -342,6 +413,8 @@ class Table(BaseModel):
             params["throughput"] = properties["ProvisionedThroughput"]
         if "LocalSecondaryIndexes" in properties:
             params["indexes"] = properties["LocalSecondaryIndexes"]
+        if "StreamSpecification" in properties:
+            params["streams"] = properties["StreamSpecification"]

         table = dynamodb_backends[region_name].create_table(
             name=properties["TableName"], **params
@@ -374,8 +447,10 @@ class Table(BaseModel):
                 "KeySchema": self.schema,
                 "ItemCount": len(self),
                 "CreationDateTime": unix_time(self.created_at),
-                "GlobalSecondaryIndexes": [index for index in self.global_indexes],
-                "LocalSecondaryIndexes": [index for index in self.indexes],
+                "GlobalSecondaryIndexes": [
+                    index.describe() for index in self.global_indexes
+                ],
+                "LocalSecondaryIndexes": [index.describe() for index in self.indexes],
             }
         }
         if self.stream_specification and self.stream_specification["StreamEnabled"]:
@@ -401,7 +476,7 @@ class Table(BaseModel):
         keys = [self.hash_key_attr]
         for index in self.global_indexes:
             hash_key = None
-            for key in index["KeySchema"]:
+            for key in index.schema:
                 if key["KeyType"] == "HASH":
                     hash_key = key["AttributeName"]
             keys.append(hash_key)
@@ -412,7 +487,7 @@ class Table(BaseModel):
         keys = [self.range_key_attr]
         for index in self.global_indexes:
             range_key = None
-            for key in index["KeySchema"]:
+            for key in index.schema:
                 if key["KeyType"] == "RANGE":
                     range_key = keys.append(key["AttributeName"])
             keys.append(range_key)
@@ -545,7 +620,7 @@ class Table(BaseModel):
         if index_name:
             all_indexes = self.all_indexes()
-            indexes_by_name = dict((i["IndexName"], i) for i in all_indexes)
+            indexes_by_name = dict((i.name, i) for i in all_indexes)
             if index_name not in indexes_by_name:
                 raise ValueError(
                     "Invalid index: %s for table: %s. Available indexes are: %s"
@@ -555,14 +630,14 @@ class Table(BaseModel):
             index = indexes_by_name[index_name]
             try:
                 index_hash_key = [
-                    key for key in index["KeySchema"] if key["KeyType"] == "HASH"
+                    key for key in index.schema if key["KeyType"] == "HASH"
                 ][0]
             except IndexError:
-                raise ValueError("Missing Hash Key. KeySchema: %s" % index["KeySchema"])
+                raise ValueError("Missing Hash Key. KeySchema: %s" % index.name)

             try:
                 index_range_key = [
-                    key for key in index["KeySchema"] if key["KeyType"] == "RANGE"
+                    key for key in index.schema if key["KeyType"] == "RANGE"
                 ][0]
             except IndexError:
                 index_range_key = None
@@ -667,9 +742,9 @@ class Table(BaseModel):
     def has_idx_items(self, index_name):
         all_indexes = self.all_indexes()
-        indexes_by_name = dict((i["IndexName"], i) for i in all_indexes)
+        indexes_by_name = dict((i.name, i) for i in all_indexes)
         idx = indexes_by_name[index_name]
-        idx_col_set = set([i["AttributeName"] for i in idx["KeySchema"]])
+        idx_col_set = set([i["AttributeName"] for i in idx.schema])

         for hash_set in self.items.values():
             if self.range_key_attr:
@@ -692,7 +767,7 @@ class Table(BaseModel):
         results = []
         scanned_count = 0
         all_indexes = self.all_indexes()
-        indexes_by_name = dict((i["IndexName"], i) for i in all_indexes)
+        indexes_by_name = dict((i.name, i) for i in all_indexes)

         if index_name:
             if index_name not in indexes_by_name:
@@ -773,9 +848,9 @@ class Table(BaseModel):
         if scanned_index:
             all_indexes = self.all_indexes()
-            indexes_by_name = dict((i["IndexName"], i) for i in all_indexes)
+            indexes_by_name = dict((i.name, i) for i in all_indexes)
             idx = indexes_by_name[scanned_index]
-            idx_col_list = [i["AttributeName"] for i in idx["KeySchema"]]
+            idx_col_list = [i["AttributeName"] for i in idx.schema]
             for col in idx_col_list:
                 last_evaluated_key[col] = results[-1].attrs[col]
@@ -885,7 +960,7 @@ class DynamoDBBackend(BaseBackend):
     def update_table_global_indexes(self, name, global_index_updates):
         table = self.tables[name]
-        gsis_by_name = dict((i["IndexName"], i) for i in table.global_indexes)
+        gsis_by_name = dict((i.name, i) for i in table.global_indexes)
         for gsi_update in global_index_updates:
             gsi_to_create = gsi_update.get("Create")
             gsi_to_update = gsi_update.get("Update")
@@ -906,7 +981,7 @@ class DynamoDBBackend(BaseBackend):
                 if index_name not in gsis_by_name:
                     raise ValueError(
                         "Global Secondary Index does not exist, but tried to update: %s"
-                        % gsi_to_update["IndexName"]
+                        % index_name
                     )
                 gsis_by_name[index_name].update(gsi_to_update)
@@ -917,7 +992,9 @@ class DynamoDBBackend(BaseBackend):
                         % gsi_to_create["IndexName"]
                     )

-                gsis_by_name[gsi_to_create["IndexName"]] = gsi_to_create
+                gsis_by_name[gsi_to_create["IndexName"]] = GlobalSecondaryIndex.create(
+                    gsi_to_create
+                )

         # in python 3.6, dict.values() returns a dict_values object, but we expect it to be a list in other
         # parts of the codebase
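Exercising the new index models end to end with boto3 (a sketch; `mock_dynamodb2` is the mock name in this moto generation):

```python
import boto3
from moto import mock_dynamodb2


@mock_dynamodb2
def test_gsi_is_described_via_the_new_model():
    client = boto3.client("dynamodb", region_name="us-east-1")
    client.create_table(
        TableName="users",
        KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}],
        AttributeDefinitions=[
            {"AttributeName": "id", "AttributeType": "S"},
            {"AttributeName": "email", "AttributeType": "S"},
        ],
        GlobalSecondaryIndexes=[{
            "IndexName": "by-email",
            "KeySchema": [{"AttributeName": "email", "KeyType": "HASH"}],
            "Projection": {"ProjectionType": "ALL"},
        }],
        BillingMode="PAY_PER_REQUEST",
    )
    gsi = client.describe_table(TableName="users")["Table"]["GlobalSecondaryIndexes"][0]
    assert gsi["IndexStatus"] == "ACTIVE"  # default status set by GlobalSecondaryIndex
```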


@@ -371,6 +371,26 @@ class DynamoHandler(BaseResponse):
         results = {"ConsumedCapacity": [], "Responses": {}, "UnprocessedKeys": {}}

+        # Validation: Can only request up to 100 items at the same time
+        # Scenario 1: We're requesting more than a 100 keys from a single table
+        for table_name, table_request in table_batches.items():
+            if len(table_request["Keys"]) > 100:
+                return self.error(
+                    "com.amazonaws.dynamodb.v20111205#ValidationException",
+                    "1 validation error detected: Value at 'requestItems."
+                    + table_name
+                    + ".member.keys' failed to satisfy constraint: Member must have length less than or equal to 100",
+                )
+        # Scenario 2: We're requesting more than a 100 keys across all tables
+        nr_of_keys_across_all_tables = sum(
+            [len(req["Keys"]) for _, req in table_batches.items()]
+        )
+        if nr_of_keys_across_all_tables > 100:
+            return self.error(
+                "com.amazonaws.dynamodb.v20111205#ValidationException",
+                "Too many items requested for the BatchGetItem call",
+            )
+
         for table_name, table_request in table_batches.items():
             keys = table_request["Keys"]
             if self._contains_duplicates(keys):
@@ -411,7 +431,6 @@ class DynamoHandler(BaseResponse):
     def query(self):
         name = self.body["TableName"]
-        # {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}}
         key_condition_expression = self.body.get("KeyConditionExpression")
         projection_expression = self.body.get("ProjectionExpression")
         expression_attribute_names = self.body.get("ExpressionAttributeNames", {})
@@ -439,7 +458,7 @@ class DynamoHandler(BaseResponse):
         index_name = self.body.get("IndexName")
         if index_name:
             all_indexes = (table.global_indexes or []) + (table.indexes or [])
-            indexes_by_name = dict((i["IndexName"], i) for i in all_indexes)
+            indexes_by_name = dict((i.name, i) for i in all_indexes)
             if index_name not in indexes_by_name:
                 er = "com.amazonaws.dynamodb.v20120810#ResourceNotFoundException"
                 return self.error(
@@ -449,7 +468,7 @@ class DynamoHandler(BaseResponse):
                     ),
                 )

-            index = indexes_by_name[index_name]["KeySchema"]
+            index = indexes_by_name[index_name].schema
         else:
             index = table.schema
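The first validation scenario, driven through boto3 (a hedged sketch; boto3 does not cap the key count client-side here, so the request reaches moto):

```python
import boto3
from botocore.exceptions import ClientError
from moto import mock_dynamodb2


@mock_dynamodb2
def test_batch_get_item_rejects_more_than_100_keys():
    client = boto3.client("dynamodb", region_name="us-east-1")
    client.create_table(
        TableName="users",
        KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}],
        AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}],
        BillingMode="PAY_PER_REQUEST",
    )
    keys = [{"id": {"S": str(i)}} for i in range(101)]  # 101 distinct keys
    try:
        client.batch_get_item(RequestItems={"users": {"Keys": keys}})
    except ClientError as err:
        assert "ValidationException" in err.response["Error"]["Code"]
```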


@@ -3639,26 +3639,31 @@ class RouteBackend(object):
         interface_id=None,
         vpc_peering_connection_id=None,
     ):
+        gateway = None
+        nat_gateway = None
+
         route_table = self.get_route_table(route_table_id)

         if interface_id:
-            self.raise_not_implemented_error("CreateRoute to NetworkInterfaceId")
+            # for validating interface Id whether it is valid or not.
+            self.get_network_interface(interface_id)

-        gateway = None
-        if gateway_id:
-            if EC2_RESOURCE_TO_PREFIX["vpn-gateway"] in gateway_id:
-                gateway = self.get_vpn_gateway(gateway_id)
-            elif EC2_RESOURCE_TO_PREFIX["internet-gateway"] in gateway_id:
-                gateway = self.get_internet_gateway(gateway_id)
+        else:
+            if gateway_id:
+                if EC2_RESOURCE_TO_PREFIX["vpn-gateway"] in gateway_id:
+                    gateway = self.get_vpn_gateway(gateway_id)
+                elif EC2_RESOURCE_TO_PREFIX["internet-gateway"] in gateway_id:
+                    gateway = self.get_internet_gateway(gateway_id)

-        try:
-            ipaddress.IPv4Network(six.text_type(destination_cidr_block), strict=False)
-        except ValueError:
-            raise InvalidDestinationCIDRBlockParameterError(destination_cidr_block)
+            try:
+                ipaddress.IPv4Network(
+                    six.text_type(destination_cidr_block), strict=False
+                )
+            except ValueError:
+                raise InvalidDestinationCIDRBlockParameterError(destination_cidr_block)

-        nat_gateway = None
-        if nat_gateway_id is not None:
-            nat_gateway = self.nat_gateways.get(nat_gateway_id)
+            if nat_gateway_id is not None:
+                nat_gateway = self.nat_gateways.get(nat_gateway_id)

         route = Route(
             route_table,
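Routes to a network interface, previously a hard not-implemented error, can now be created; a sketch:

```python
import boto3
from moto import mock_ec2


@mock_ec2
def test_create_route_to_network_interface():
    ec2 = boto3.client("ec2", region_name="us-east-1")
    vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]
    subnet = ec2.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.0.0/24")["Subnet"]
    eni = ec2.create_network_interface(SubnetId=subnet["SubnetId"])["NetworkInterface"]
    rt = ec2.create_route_table(VpcId=vpc["VpcId"])["RouteTable"]
    ec2.create_route(  # no longer rejected as "CreateRoute to NetworkInterfaceId"
        RouteTableId=rt["RouteTableId"],
        DestinationCidrBlock="10.0.1.0/24",
        NetworkInterfaceId=eni["NetworkInterfaceId"],
    )
```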


@@ -125,7 +125,7 @@ DESCRIBE_IMAGES_RESPONSE = """<DescribeImagesResponse xmlns="http://ec2.amazonaw
               <snapshotId>{{ image.ebs_snapshot.id }}</snapshotId>
               <volumeSize>15</volumeSize>
               <deleteOnTermination>false</deleteOnTermination>
-              <volumeType>{{ image.root_device_type }}</volumeType>
+              <volumeType>standard</volumeType>
            </ebs>
          </item>
       </blockDeviceMapping>


@@ -13,6 +13,7 @@ from moto.elbv2 import elbv2_backends
 from moto.core import ACCOUNT_ID

 from copy import deepcopy
+import six


 class InstanceResponse(BaseResponse):
@@ -283,15 +284,15 @@ class InstanceResponse(BaseResponse):
                 device_template["Ebs"]["VolumeSize"] = device_mapping.get(
                     "ebs._volume_size"
                 )
-                device_template["Ebs"]["DeleteOnTermination"] = device_mapping.get(
-                    "ebs._delete_on_termination", False
+                device_template["Ebs"]["DeleteOnTermination"] = self._convert_to_bool(
+                    device_mapping.get("ebs._delete_on_termination", False)
                 )
                 device_template["Ebs"]["VolumeType"] = device_mapping.get(
                     "ebs._volume_type"
                 )
                 device_template["Ebs"]["Iops"] = device_mapping.get("ebs._iops")
-                device_template["Ebs"]["Encrypted"] = device_mapping.get(
-                    "ebs._encrypted", False
+                device_template["Ebs"]["Encrypted"] = self._convert_to_bool(
+                    device_mapping.get("ebs._encrypted", False)
                 )
                 mappings.append(device_template)
@@ -308,6 +309,16 @@ class InstanceResponse(BaseResponse):
         ):
             raise MissingParameterError("size or snapshotId")

+    @staticmethod
+    def _convert_to_bool(bool_str):
+        if isinstance(bool_str, bool):
+            return bool_str
+
+        if isinstance(bool_str, six.text_type):
+            return str(bool_str).lower() == "true"
+
+        return False
+

 BLOCK_DEVICE_MAPPING_TEMPLATE = {
     "VirtualName": None,

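Expected behaviour of the helper, assuming it stays importable at this path (an assumption for illustration; on Python 3, `str` is `six.text_type`):

```python
from moto.ec2.responses.instances import InstanceResponse

assert InstanceResponse._convert_to_bool("true") is True
assert InstanceResponse._convert_to_bool("True") is True   # case-insensitive
assert InstanceResponse._convert_to_bool("false") is False
assert InstanceResponse._convert_to_bool(True) is True     # real booleans pass through
assert InstanceResponse._convert_to_bool(None) is False    # anything else is falsy
```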

@@ -2083,6 +2083,16 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
           <UserName>{{ user.name }}</UserName>
           <Arn>{{ user.arn }}</Arn>
           <CreateDate>{{ user.created_iso_8601 }}</CreateDate>
+          {% if user.policies %}
+          <UserPolicyList>
+            {% for policy in user.policies %}
+            <member>
+              <PolicyName>{{ policy }}</PolicyName>
+              <PolicyDocument>{{ user.policies[policy] }}</PolicyDocument>
+            </member>
+            {% endfor %}
+          </UserPolicyList>
+          {% endif %}
         </member>
       {% endfor %}
     </UserDetailList>
@@ -2106,7 +2116,7 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
           {% for policy in group.policies %}
           <member>
             <PolicyName>{{ policy }}</PolicyName>
-            <PolicyDocument>{{ group.get_policy(policy) }}</PolicyDocument>
+            <PolicyDocument>{{ group.policies[policy] }}</PolicyDocument>
           </member>
           {% endfor %}
         </GroupPolicyList>
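Inline user policies now show up in `get_account_authorization_details`; a hedged sketch:

```python
import json

import boto3
from moto import mock_iam


@mock_iam
def test_user_inline_policies_are_reported():
    iam = boto3.client("iam", region_name="us-east-1")
    iam.create_user(UserName="alice")
    iam.put_user_policy(
        UserName="alice",
        PolicyName="allow-s3",
        PolicyDocument=json.dumps({
            "Version": "2012-10-17",
            "Statement": [{"Effect": "Allow", "Action": "s3:*", "Resource": "*"}],
        }),
    )
    details = iam.get_account_authorization_details(Filter=["User"])
    user = details["UserDetailList"][0]
    assert user["UserPolicyList"][0]["PolicyName"] == "allow-s3"
```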


@@ -5,6 +5,7 @@ import json
 import os
 import base64
 import datetime
+import pytz
 import hashlib
 import copy
 import itertools
@@ -776,7 +777,7 @@ class FakeBucket(BaseModel):
         self.notification_configuration = None
         self.accelerate_configuration = None
         self.payer = "BucketOwner"
-        self.creation_date = datetime.datetime.utcnow()
+        self.creation_date = datetime.datetime.now(tz=pytz.utc)
         self.public_access_block = None
         self.encryption = None
@@ -1315,7 +1316,7 @@ class S3Backend(BaseBackend):
         return self.account_public_access_block

-    def set_key(
+    def set_object(
         self, bucket_name, key_name, value, storage=None, etag=None, multipart=None
     ):
         key_name = clean_key_name(key_name)
@@ -1346,11 +1347,11 @@ class S3Backend(BaseBackend):
     def append_to_key(self, bucket_name, key_name, value):
         key_name = clean_key_name(key_name)

-        key = self.get_key(bucket_name, key_name)
+        key = self.get_object(bucket_name, key_name)
         key.append_to_value(value)
         return key

-    def get_key(self, bucket_name, key_name, version_id=None, part_number=None):
+    def get_object(self, bucket_name, key_name, version_id=None, part_number=None):
         key_name = clean_key_name(key_name)
         bucket = self.get_bucket(bucket_name)
         key = None
@@ -1385,11 +1386,11 @@ class S3Backend(BaseBackend):
         )
         return key

-    def get_bucket_tags(self, bucket_name):
+    def get_bucket_tagging(self, bucket_name):
         bucket = self.get_bucket(bucket_name)
         return self.tagger.list_tags_for_resource(bucket.arn)

-    def put_bucket_tags(self, bucket_name, tags):
+    def put_bucket_tagging(self, bucket_name, tags):
         bucket = self.get_bucket(bucket_name)
         self.tagger.delete_all_tags_for_resource(bucket.arn)
         self.tagger.tag_resource(
@@ -1481,7 +1482,7 @@ class S3Backend(BaseBackend):
             return
         del bucket.multiparts[multipart_id]

-        key = self.set_key(
+        key = self.set_object(
             bucket_name, multipart.key_name, value, etag=etag, multipart=multipart
         )
         key.set_metadata(multipart.metadata)
@@ -1521,7 +1522,7 @@ class S3Backend(BaseBackend):
         dest_bucket = self.get_bucket(dest_bucket_name)
         multipart = dest_bucket.multiparts[multipart_id]

-        src_value = self.get_key(
+        src_value = self.get_object(
             src_bucket_name, src_key_name, version_id=src_version_id
         ).value
         if start_byte is not None:
@@ -1565,7 +1566,7 @@ class S3Backend(BaseBackend):
         bucket = self.get_bucket(bucket_name)
         bucket.keys[key_name] = FakeDeleteMarker(key=bucket.keys[key_name])

-    def delete_key(self, bucket_name, key_name, version_id=None):
+    def delete_object(self, bucket_name, key_name, version_id=None):
         key_name = clean_key_name(key_name)
         bucket = self.get_bucket(bucket_name)
@@ -1606,7 +1607,7 @@ class S3Backend(BaseBackend):
         src_key_name = clean_key_name(src_key_name)
         dest_key_name = clean_key_name(dest_key_name)
         dest_bucket = self.get_bucket(dest_bucket_name)
-        key = self.get_key(src_bucket_name, src_key_name, version_id=src_version_id)
+        key = self.get_object(src_bucket_name, src_key_name, version_id=src_version_id)

         new_key = key.copy(dest_key_name, dest_bucket.is_versioned)
         self.tagger.copy_tags(key.arn, new_key.arn)
@@ -1626,5 +1627,17 @@ class S3Backend(BaseBackend):
         bucket = self.get_bucket(bucket_name)
         return bucket.acl

+    def get_bucket_cors(self, bucket_name):
+        bucket = self.get_bucket(bucket_name)
+        return bucket.cors
+
+    def get_bucket_logging(self, bucket_name):
+        bucket = self.get_bucket(bucket_name)
+        return bucket.logging
+
+    def get_bucket_notification_configuration(self, bucket_name):
+        bucket = self.get_bucket(bucket_name)
+        return bucket.notification_configuration
+

 s3_backend = S3Backend()
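These are internal renames, so boto3-facing behaviour is unchanged; only code reaching into the backend directly needs the new names. A hypothetical sketch against the module-level backend:

```python
from moto import mock_s3
from moto.s3.models import s3_backend


@mock_s3
def demo_backend_rename():
    s3_backend.create_bucket("demo-bucket", "us-east-1")
    s3_backend.set_object("demo-bucket", "hello.txt", b"hi")  # formerly set_key
    key = s3_backend.get_object("demo-bucket", "hello.txt")   # formerly get_key
    assert key.value == b"hi"
    s3_backend.delete_object("demo-bucket", "hello.txt")      # formerly delete_key
```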


@ -382,7 +382,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
template = self.response_template(S3_OBJECT_ACL_RESPONSE) template = self.response_template(S3_OBJECT_ACL_RESPONSE)
return template.render(obj=bucket) return template.render(obj=bucket)
elif "tagging" in querystring: elif "tagging" in querystring:
tags = self.backend.get_bucket_tags(bucket_name)["Tags"] tags = self.backend.get_bucket_tagging(bucket_name)["Tags"]
# "Special Error" if no tags: # "Special Error" if no tags:
if len(tags) == 0: if len(tags) == 0:
template = self.response_template(S3_NO_BUCKET_TAGGING) template = self.response_template(S3_NO_BUCKET_TAGGING)
@ -390,25 +390,27 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
template = self.response_template(S3_OBJECT_TAGGING_RESPONSE) template = self.response_template(S3_OBJECT_TAGGING_RESPONSE)
return template.render(tags=tags) return template.render(tags=tags)
elif "logging" in querystring: elif "logging" in querystring:
bucket = self.backend.get_bucket(bucket_name) logging = self.backend.get_bucket_logging(bucket_name)
if not bucket.logging: if not logging:
template = self.response_template(S3_NO_LOGGING_CONFIG) template = self.response_template(S3_NO_LOGGING_CONFIG)
return 200, {}, template.render() return 200, {}, template.render()
template = self.response_template(S3_LOGGING_CONFIG) template = self.response_template(S3_LOGGING_CONFIG)
return 200, {}, template.render(logging=bucket.logging) return 200, {}, template.render(logging=logging)
elif "cors" in querystring: elif "cors" in querystring:
bucket = self.backend.get_bucket(bucket_name) cors = self.backend.get_bucket_cors(bucket_name)
if len(bucket.cors) == 0: if len(cors) == 0:
template = self.response_template(S3_NO_CORS_CONFIG) template = self.response_template(S3_NO_CORS_CONFIG)
return 404, {}, template.render(bucket_name=bucket_name) return 404, {}, template.render(bucket_name=bucket_name)
template = self.response_template(S3_BUCKET_CORS_RESPONSE) template = self.response_template(S3_BUCKET_CORS_RESPONSE)
return template.render(bucket=bucket) return template.render(cors=cors)
elif "notification" in querystring: elif "notification" in querystring:
bucket = self.backend.get_bucket(bucket_name) notification_configuration = self.backend.get_bucket_notification_configuration(
if not bucket.notification_configuration: bucket_name
)
if not notification_configuration:
return 200, {}, "" return 200, {}, ""
template = self.response_template(S3_GET_BUCKET_NOTIFICATION_CONFIG) template = self.response_template(S3_GET_BUCKET_NOTIFICATION_CONFIG)
return template.render(bucket=bucket) return template.render(config=notification_configuration)
elif "accelerate" in querystring: elif "accelerate" in querystring:
bucket = self.backend.get_bucket(bucket_name) bucket = self.backend.get_bucket(bucket_name)
if bucket.accelerate_configuration is None: if bucket.accelerate_configuration is None:
@ -613,6 +615,19 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
pass pass
return False return False
def _create_bucket_configuration_is_empty(self, body):
if body:
try:
create_bucket_configuration = xmltodict.parse(body)[
"CreateBucketConfiguration"
]
del create_bucket_configuration["@xmlns"]
if len(create_bucket_configuration) == 0:
return True
except KeyError:
pass
return False
def _parse_pab_config(self, body): def _parse_pab_config(self, body):
parsed_xml = xmltodict.parse(body) parsed_xml = xmltodict.parse(body)
parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None) parsed_xml["PublicAccessBlockConfiguration"].pop("@xmlns", None)
@ -663,7 +678,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return "" return ""
elif "tagging" in querystring: elif "tagging" in querystring:
tagging = self._bucket_tagging_from_xml(body) tagging = self._bucket_tagging_from_xml(body)
self.backend.put_bucket_tags(bucket_name, tagging) self.backend.put_bucket_tagging(bucket_name, tagging)
return "" return ""
elif "website" in querystring: elif "website" in querystring:
self.backend.set_bucket_website_configuration(bucket_name, body) self.backend.set_bucket_website_configuration(bucket_name, body)
@ -731,6 +746,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
): ):
raise IllegalLocationConstraintException() raise IllegalLocationConstraintException()
if body: if body:
if self._create_bucket_configuration_is_empty(body):
raise MalformedXML()
try: try:
forced_region = xmltodict.parse(body)["CreateBucketConfiguration"][ forced_region = xmltodict.parse(body)["CreateBucketConfiguration"][
"LocationConstraint" "LocationConstraint"
@ -840,7 +858,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else: else:
status_code = 204 status_code = 204
new_key = self.backend.set_key(bucket_name, key, f) new_key = self.backend.set_object(bucket_name, key, f)
# Metadata # Metadata
metadata = metadata_from_headers(form) metadata = metadata_from_headers(form)
@ -879,7 +897,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
key_name = object_["Key"] key_name = object_["Key"]
version_id = object_.get("VersionId", None) version_id = object_.get("VersionId", None)
success = self.backend.delete_key( success = self.backend.delete_object(
bucket_name, undo_clean_key_name(key_name), version_id=version_id bucket_name, undo_clean_key_name(key_name), version_id=version_id
) )
if success: if success:
@ -1056,7 +1074,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
signed_url = "Signature=" in request.url signed_url = "Signature=" in request.url
elif hasattr(request, "requestline"): elif hasattr(request, "requestline"):
signed_url = "Signature=" in request.path signed_url = "Signature=" in request.path
key = self.backend.get_key(bucket_name, key_name) key = self.backend.get_object(bucket_name, key_name)
if key: if key:
if not key.acl.public_read and not signed_url: if not key.acl.public_read and not signed_url:
@ -1118,7 +1136,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
) )
version_id = query.get("versionId", [None])[0] version_id = query.get("versionId", [None])[0]
if_modified_since = headers.get("If-Modified-Since", None) if_modified_since = headers.get("If-Modified-Since", None)
key = self.backend.get_key(bucket_name, key_name, version_id=version_id) key = self.backend.get_object(bucket_name, key_name, version_id=version_id)
if key is None: if key is None:
raise MissingKey(key_name) raise MissingKey(key_name)
if if_modified_since: if if_modified_since:
@ -1164,7 +1182,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
except ValueError: except ValueError:
start_byte, end_byte = None, None start_byte, end_byte = None, None
if self.backend.get_key(src_bucket, src_key, version_id=src_version_id): if self.backend.get_object(
src_bucket, src_key, version_id=src_version_id
):
key = self.backend.copy_part( key = self.backend.copy_part(
bucket_name, bucket_name,
upload_id, upload_id,
@ -1193,7 +1213,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
tagging = self._tagging_from_headers(request.headers) tagging = self._tagging_from_headers(request.headers)
if "acl" in query: if "acl" in query:
key = self.backend.get_key(bucket_name, key_name) key = self.backend.get_object(bucket_name, key_name)
# TODO: Support the XML-based ACL format # TODO: Support the XML-based ACL format
key.set_acl(acl) key.set_acl(acl)
return 200, response_headers, "" return 200, response_headers, ""
@ -1203,7 +1223,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
version_id = query["versionId"][0] version_id = query["versionId"][0]
else: else:
version_id = None version_id = None
key = self.backend.get_key(bucket_name, key_name, version_id=version_id) key = self.backend.get_object(bucket_name, key_name, version_id=version_id)
tagging = self._tagging_from_xml(body) tagging = self._tagging_from_xml(body)
self.backend.set_key_tags(key, tagging, key_name) self.backend.set_key_tags(key, tagging, key_name)
return 200, response_headers, "" return 200, response_headers, ""
@@ -1221,7 +1241,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
)
src_version_id = parse_qs(src_key_parsed.query).get("versionId", [None])[0]
- key = self.backend.get_key(src_bucket, src_key, version_id=src_version_id)
+ key = self.backend.get_object(
+     src_bucket, src_key, version_id=src_version_id
+ )
if key is not None:
if key.storage_class in ["GLACIER", "DEEP_ARCHIVE"]:
@@ -1238,7 +1260,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else:
return 404, response_headers, ""
- new_key = self.backend.get_key(bucket_name, key_name)
+ new_key = self.backend.get_object(bucket_name, key_name)
mdirective = request.headers.get("x-amz-metadata-directive")
if mdirective is not None and mdirective == "REPLACE":
metadata = metadata_from_headers(request.headers)
@@ -1254,13 +1276,13 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
closing_connection = headers.get("connection") == "close"
if closing_connection and streaming_request:
# Closing the connection of a streaming request. No more data
- new_key = self.backend.get_key(bucket_name, key_name)
+ new_key = self.backend.get_object(bucket_name, key_name)
elif streaming_request:
# Streaming request, more data
new_key = self.backend.append_to_key(bucket_name, key_name, body)
else:
# Initial data
- new_key = self.backend.set_key(
+ new_key = self.backend.set_object(
bucket_name, key_name, body, storage=storage_class
)
request.streaming = True
@@ -1286,7 +1308,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if if_modified_since:
if_modified_since = str_to_rfc_1123_datetime(if_modified_since)
- key = self.backend.get_key(
+ key = self.backend.get_object(
bucket_name, key_name, version_id=version_id, part_number=part_number
)
if key:
@@ -1596,7 +1618,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.backend.cancel_multipart(bucket_name, upload_id)
return 204, {}, ""
version_id = query.get("versionId", [None])[0]
- self.backend.delete_key(bucket_name, key_name, version_id=version_id)
+ self.backend.delete_object(bucket_name, key_name, version_id=version_id)
return 204, {}, ""

def _complete_multipart_body(self, body):
@@ -1633,7 +1655,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
elif "restore" in query:
es = minidom.parseString(body).getElementsByTagName("Days")
days = es[0].childNodes[0].wholeText
- key = self.backend.get_key(bucket_name, key_name)
+ key = self.backend.get_object(bucket_name, key_name)
r = 202
if key.expiry_date is not None:
r = 200
@@ -1959,7 +1981,7 @@ S3_OBJECT_TAGGING_RESPONSE = """\
S3_BUCKET_CORS_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<CORSConfiguration>
- {% for cors in bucket.cors %}
+ {% for cors in cors %}
<CORSRule>
{% for origin in cors.allowed_origins %}
<AllowedOrigin>{{ origin }}</AllowedOrigin>
@@ -2192,7 +2214,7 @@ S3_NO_ENCRYPTION = """<?xml version="1.0" encoding="UTF-8"?>
S3_GET_BUCKET_NOTIFICATION_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
<NotificationConfiguration xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
- {% for topic in bucket.notification_configuration.topic %}
+ {% for topic in config.topic %}
<TopicConfiguration>
<Id>{{ topic.id }}</Id>
<Topic>{{ topic.arn }}</Topic>
@@ -2213,7 +2235,7 @@ S3_GET_BUCKET_NOTIFICATION_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
{% endif %}
</TopicConfiguration>
{% endfor %}
- {% for queue in bucket.notification_configuration.queue %}
+ {% for queue in config.queue %}
<QueueConfiguration>
<Id>{{ queue.id }}</Id>
<Queue>{{ queue.arn }}</Queue>
@@ -2234,7 +2256,7 @@ S3_GET_BUCKET_NOTIFICATION_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
{% endif %}
</QueueConfiguration>
{% endfor %}
- {% for cf in bucket.notification_configuration.cloud_function %}
+ {% for cf in config.cloud_function %}
<CloudFunctionConfiguration>
<Id>{{ cf.id }}</Id>
<CloudFunction>{{ cf.arn }}</CloudFunction>
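The rename from `get_key`/`set_key`/`delete_key` to `get_object`/`set_object`/`delete_object` aligns the backend with S3's own vocabulary. A standalone sketch (not part of the diff) of the renamed API, assuming `s3_backend` is importable from `moto.s3.models` and that the returned key object exposes a `.value` property; bucket and key names are hypothetical:

```python
from moto.s3.models import s3_backend

s3_backend.create_bucket("example-bucket", "us-east-1")
s3_backend.set_object("example-bucket", "hello.txt", b"hello world")
key = s3_backend.get_object("example-bucket", "hello.txt")
print(key.value)  # b'hello world'
s3_backend.delete_object("example-bucket", "hello.txt")
```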
View File
@@ -38,6 +38,10 @@ class SecretsStore(dict):
new_key = get_secret_name_from_arn(key)
return dict.__contains__(self, new_key)
def pop(self, key, *args, **kwargs):
new_key = get_secret_name_from_arn(key)
return super(SecretsStore, self).pop(new_key, *args, **kwargs)
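The new `pop` override mirrors the existing `__contains__`: both normalize an ARN to the plain secret name before touching the underlying dict, which is what lets `delete_secret` accept an ARN (see the new test further down). A self-contained sketch of the pattern, with a deliberately simplified stand-in for `get_secret_name_from_arn` (the real helper lives in `moto.secretsmanager.utils`):

```python
def get_secret_name_from_arn(key):
    # Simplified: arn:aws:secretsmanager:region:account:secret:name-SUFFIX -> name
    if key.startswith("arn:aws:secretsmanager:"):
        return key.split(":")[-1].rsplit("-", 1)[0]
    return key

class SecretsStore(dict):
    def __contains__(self, key):
        return dict.__contains__(self, get_secret_name_from_arn(key))

    def pop(self, key, *args, **kwargs):
        return super(SecretsStore, self).pop(
            get_secret_name_from_arn(key), *args, **kwargs
        )

store = SecretsStore()
store["test-secret"] = {"SecretString": "foosecret"}
arn = "arn:aws:secretsmanager:us-west-2:123456789012:secret:test-secret-AbC123"
assert arn in store      # __contains__ normalizes the ARN
store.pop(arn)           # pop now normalizes too
assert "test-secret" not in store
```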
class SecretsManagerBackend(BaseBackend):
def __init__(self, region_name=None, **kwargs):
View File
@@ -41,3 +41,26 @@ class TemplateDoesNotExist(RESTError):
def __init__(self, message):
super(TemplateDoesNotExist, self).__init__("TemplateDoesNotExist", message)
class RuleSetNameAlreadyExists(RESTError):
code = 400
def __init__(self, message):
super(RuleSetNameAlreadyExists, self).__init__(
"RuleSetNameAlreadyExists", message
)
class RuleAlreadyExists(RESTError):
code = 400
def __init__(self, message):
super(RuleAlreadyExists, self).__init__("RuleAlreadyExists", message)
class RuleSetDoesNotExist(RESTError):
code = 400
def __init__(self, message):
super(RuleSetDoesNotExist, self).__init__("RuleSetDoesNotExist", message)
View File
@@ -12,6 +12,9 @@ from .exceptions import (
EventDestinationAlreadyExists,
TemplateNameAlreadyExists,
TemplateDoesNotExist,
RuleSetNameAlreadyExists,
RuleSetDoesNotExist,
RuleAlreadyExists,
)
from .utils import get_random_message_id
from .feedback import COMMON_MAIL, BOUNCE, COMPLAINT, DELIVERY
@@ -94,6 +97,7 @@ class SESBackend(BaseBackend):
self.config_set_event_destination = {}
self.event_destinations = {}
self.templates = {}
self.receipt_rule_set = {}
def _is_verified_address(self, source):
_, address = parseaddr(source)
@@ -294,5 +298,19 @@ class SESBackend(BaseBackend):
def list_templates(self):
return list(self.templates.values())
def create_receipt_rule_set(self, rule_set_name):
if self.receipt_rule_set.get(rule_set_name) is not None:
raise RuleSetNameAlreadyExists("Duplicate receipt rule set Name.")
self.receipt_rule_set[rule_set_name] = []
def create_receipt_rule(self, rule_set_name, rule):
rule_set = self.receipt_rule_set.get(rule_set_name)
if rule_set is None:
raise RuleSetDoesNotExist("Invalid Rule Set Name.")
if rule in rule_set:
raise RuleAlreadyExists("Duplicate Rule Name.")
rule_set.append(rule)
self.receipt_rule_set[rule_set_name] = rule_set
ses_backend = SESBackend()
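A quick standalone sketch (not part of the diff) of the new receipt-rule state: rule sets are plain lists of rule dicts keyed by name, and the duplicate/missing checks raise the exception classes added above.

```python
from moto.ses.models import ses_backend
from moto.ses.exceptions import RuleSetDoesNotExist

ses_backend.create_receipt_rule_set("default-rule-set")
ses_backend.create_receipt_rule("default-rule-set", {"Name": "rule1", "Enabled": True})

try:
    ses_backend.create_receipt_rule("no-such-set", {"Name": "rule1"})
except RuleSetDoesNotExist as err:
    print(err)  # surfaces over the wire as a 400 RuleSetDoesNotExist error
```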
View File
@@ -199,6 +199,19 @@ class EmailResponse(BaseResponse):
template = self.response_template(LIST_TEMPLATES)
return template.render(templates=email_templates)
def create_receipt_rule_set(self):
rule_set_name = self._get_param("RuleSetName")
ses_backend.create_receipt_rule_set(rule_set_name)
template = self.response_template(CREATE_RECEIPT_RULE_SET)
return template.render()
def create_receipt_rule(self):
rule_set_name = self._get_param("RuleSetName")
rule = self._get_dict_param("Rule")
ses_backend.create_receipt_rule(rule_set_name, rule)
template = self.response_template(CREATE_RECEIPT_RULE)
return template.render()
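`_get_dict_param("Rule")` reassembles the flattened AWS query-protocol parameters into a nested dict. Roughly the idea, as an illustration only (this is not moto's actual parsing code; the parameter names are taken from the tests below):

```python
# Flattened form parameters as they arrive in a CreateReceiptRule request.
params = {
    "RuleSetName": "testRuleSet",
    "Rule.Name": "testRule",
    "Rule.Enabled": "false",
    "Rule.TlsPolicy": "Optional",
}
# Collapse the "Rule." prefix into a dict, which is what the handler passes on.
rule = {
    key.split(".", 1)[1]: value
    for key, value in params.items()
    if key.startswith("Rule.")
}
print(rule)  # {'Name': 'testRule', 'Enabled': 'false', 'TlsPolicy': 'Optional'}
```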
VERIFY_EMAIL_IDENTITY = """<VerifyEmailIdentityResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<VerifyEmailIdentityResult/>
@@ -385,3 +398,17 @@ LIST_TEMPLATES = """<ListTemplatesResponse xmlns="http://ses.amazonaws.com/doc/2
<RequestId>47e0ef1a-9bf2-11e1-9279-0100e8cf12ba</RequestId>
</ResponseMetadata>
</ListTemplatesResponse>"""
CREATE_RECEIPT_RULE_SET = """<CreateReceiptRuleSetResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<CreateReceiptRuleSetResult/>
<ResponseMetadata>
<RequestId>47e0ef1a-9bf2-11e1-9279-01ab88cf109a</RequestId>
</ResponseMetadata>
</CreateReceiptRuleSetResponse>"""
CREATE_RECEIPT_RULE = """<CreateReceiptRuleResponse xmlns="http://ses.amazonaws.com/doc/2010-12-01/">
<CreateReceiptRuleResult/>
<ResponseMetadata>
<RequestId>15e0ef1a-9bf2-11e1-9279-01ab88cf109a</RequestId>
</ResponseMetadata>
</CreateReceiptRuleResponse>"""
View File
@@ -544,6 +544,7 @@ def test_integration_response():
selectionPattern="foobar",
responseTemplates={},
)
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
@@ -592,6 +593,63 @@ def test_integration_response():
response = client.get_method(restApiId=api_id, resourceId=root_id, httpMethod="GET")
response["methodIntegration"]["integrationResponses"].should.equal({})
# adding a new method and performing put integration with contentHandling as CONVERT_TO_BINARY
client.put_method(
restApiId=api_id, resourceId=root_id, httpMethod="PUT", authorizationType="none"
)
client.put_method_response(
restApiId=api_id, resourceId=root_id, httpMethod="PUT", statusCode="200"
)
client.put_integration(
restApiId=api_id,
resourceId=root_id,
httpMethod="PUT",
type="HTTP",
uri="http://httpbin.org/robots.txt",
integrationHttpMethod="POST",
)
response = client.put_integration_response(
restApiId=api_id,
resourceId=root_id,
httpMethod="PUT",
statusCode="200",
selectionPattern="foobar",
responseTemplates={},
contentHandling="CONVERT_TO_BINARY",
)
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response.should.equal(
{
"statusCode": "200",
"selectionPattern": "foobar",
"ResponseMetadata": {"HTTPStatusCode": 200},
"responseTemplates": {"application/json": None},
"contentHandling": "CONVERT_TO_BINARY",
}
)
response = client.get_integration_response(
restApiId=api_id, resourceId=root_id, httpMethod="PUT", statusCode="200"
)
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response.should.equal(
{
"statusCode": "200",
"selectionPattern": "foobar",
"ResponseMetadata": {"HTTPStatusCode": 200},
"responseTemplates": {"application/json": None},
"contentHandling": "CONVERT_TO_BINARY",
}
)
@mock_apigateway
@mock_cognitoidp
View File
@@ -172,6 +172,44 @@ def test_stop_query_execution():
details["Status"]["State"].should.equal("CANCELLED")
@mock_athena
def test_create_named_query():
client = boto3.client("athena", region_name="us-east-1")
# create named query
res = client.create_named_query(
Name="query-name", Database="target_db", QueryString="SELECT * FROM table1",
)
assert "NamedQueryId" in res
@mock_athena
def test_get_named_query():
client = boto3.client("athena", region_name="us-east-1")
query_name = "query-name"
database = "target_db"
query_string = "SELECT * FROM tbl1"
description = "description of this query"
# create named query
res_create = client.create_named_query(
Name=query_name,
Database=database,
QueryString=query_string,
Description=description,
)
query_id = res_create["NamedQueryId"]
# get named query
res_get = client.get_named_query(NamedQueryId=query_id)["NamedQuery"]
res_get["Name"].should.equal(query_name)
res_get["Description"].should.equal(description)
res_get["Database"].should.equal(database)
res_get["QueryString"].should.equal(query_string)
res_get["NamedQueryId"].should.equal(query_id)
def create_basic_workgroup(client, name):
client.create_work_group(
Name=name,
View File
@@ -1446,11 +1446,12 @@ def test_update_event_source_mapping():
assert response["State"] == "Enabled"
mapping = conn.update_event_source_mapping(
- UUID=response["UUID"], Enabled=False, BatchSize=15, FunctionName="testFunction2"
+ UUID=response["UUID"], Enabled=False, BatchSize=2, FunctionName="testFunction2"
)
assert mapping["UUID"] == response["UUID"]
assert mapping["FunctionArn"] == func2["FunctionArn"]
assert mapping["State"] == "Disabled"
assert mapping["BatchSize"] == 2
@mock_lambda
View File
@@ -3,7 +3,7 @@ import io
import sure  # noqa
import zipfile
from botocore.exceptions import ClientError
- from moto import mock_cloudformation, mock_iam, mock_lambda, mock_s3
+ from moto import mock_cloudformation, mock_iam, mock_lambda, mock_s3, mock_sqs
from nose.tools import assert_raises
from string import Template
from uuid import uuid4
@@ -48,6 +48,23 @@ template = Template(
}"""
)
event_source_mapping_template = Template(
"""{
"AWSTemplateFormatVersion": "2010-09-09",
"Resources": {
"$resource_name": {
"Type": "AWS::Lambda::EventSourceMapping",
"Properties": {
"BatchSize": $batch_size,
"EventSourceArn": $event_source_arn,
"FunctionName": $function_name,
"Enabled": $enabled
}
}
}
}"""
)
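The `$`-prefixed placeholders above are filled with `string.Template.substitute`, exactly as the new tests below do. A small illustration with hypothetical values (quoted here so the result is strict JSON; the tests pass bare strings, which moto's YAML-first template parser also tolerates):

```python
body = event_source_mapping_template.substitute(
    {
        "resource_name": "Foo",
        "batch_size": 1,
        "event_source_arn": '"arn:aws:sqs:us-east-1:123456789012:test-queue"',
        "function_name": '"my-function"',
        "enabled": "true",
    }
)
# body is now a complete CloudFormation template containing one
# AWS::Lambda::EventSourceMapping resource named "Foo".
```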
@mock_cloudformation
@mock_lambda
@@ -97,6 +114,194 @@ def test_lambda_can_be_deleted_by_cloudformation():
e.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException")
@mock_cloudformation
@mock_lambda
@mock_s3
@mock_sqs
def test_event_source_mapping_create_from_cloudformation_json():
sqs = boto3.resource("sqs", region_name="us-east-1")
s3 = boto3.client("s3", "us-east-1")
cf = boto3.client("cloudformation", region_name="us-east-1")
lmbda = boto3.client("lambda", region_name="us-east-1")
queue = sqs.create_queue(QueueName="test-sqs-queue1")
# Creates lambda
_, lambda_stack = create_stack(cf, s3)
created_fn_name = get_created_function_name(cf, lambda_stack)
created_fn_arn = lmbda.get_function(FunctionName=created_fn_name)["Configuration"][
"FunctionArn"
]
template = event_source_mapping_template.substitute(
{
"resource_name": "Foo",
"batch_size": 1,
"event_source_arn": queue.attributes["QueueArn"],
"function_name": created_fn_name,
"enabled": True,
}
)
cf.create_stack(StackName="test-event-source", TemplateBody=template)
event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
event_sources["EventSourceMappings"].should.have.length_of(1)
event_source = event_sources["EventSourceMappings"][0]
event_source["EventSourceArn"].should.be.equal(queue.attributes["QueueArn"])
event_source["FunctionArn"].should.be.equal(created_fn_arn)
@mock_cloudformation
@mock_lambda
@mock_s3
@mock_sqs
def test_event_source_mapping_delete_stack():
sqs = boto3.resource("sqs", region_name="us-east-1")
s3 = boto3.client("s3", "us-east-1")
cf = boto3.client("cloudformation", region_name="us-east-1")
lmbda = boto3.client("lambda", region_name="us-east-1")
queue = sqs.create_queue(QueueName="test-sqs-queue1")
# Creates lambda
_, lambda_stack = create_stack(cf, s3)
created_fn_name = get_created_function_name(cf, lambda_stack)
template = event_source_mapping_template.substitute(
{
"resource_name": "Foo",
"batch_size": 1,
"event_source_arn": queue.attributes["QueueArn"],
"function_name": created_fn_name,
"enabled": True,
}
)
esm_stack = cf.create_stack(StackName="test-event-source", TemplateBody=template)
event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
event_sources["EventSourceMappings"].should.have.length_of(1)
cf.delete_stack(StackName=esm_stack["StackId"])
event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
event_sources["EventSourceMappings"].should.have.length_of(0)
@mock_cloudformation
@mock_lambda
@mock_s3
@mock_sqs
def test_event_source_mapping_update_from_cloudformation_json():
sqs = boto3.resource("sqs", region_name="us-east-1")
s3 = boto3.client("s3", "us-east-1")
cf = boto3.client("cloudformation", region_name="us-east-1")
lmbda = boto3.client("lambda", region_name="us-east-1")
queue = sqs.create_queue(QueueName="test-sqs-queue1")
# Creates lambda
_, lambda_stack = create_stack(cf, s3)
created_fn_name = get_created_function_name(cf, lambda_stack)
created_fn_arn = lmbda.get_function(FunctionName=created_fn_name)["Configuration"][
"FunctionArn"
]
original_template = event_source_mapping_template.substitute(
{
"resource_name": "Foo",
"batch_size": 1,
"event_source_arn": queue.attributes["QueueArn"],
"function_name": created_fn_name,
"enabled": True,
}
)
cf.create_stack(StackName="test-event-source", TemplateBody=original_template)
event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
original_esm = event_sources["EventSourceMappings"][0]
original_esm["State"].should.equal("Enabled")
original_esm["BatchSize"].should.equal(1)
# Update
new_template = event_source_mapping_template.substitute(
{
"resource_name": "Foo",
"batch_size": 10,
"event_source_arn": queue.attributes["QueueArn"],
"function_name": created_fn_name,
"enabled": False,
}
)
cf.update_stack(StackName="test-event-source", TemplateBody=new_template)
event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
updated_esm = event_sources["EventSourceMappings"][0]
updated_esm["State"].should.equal("Disabled")
updated_esm["BatchSize"].should.equal(10)
@mock_cloudformation
@mock_lambda
@mock_s3
@mock_sqs
def test_event_source_mapping_delete_from_cloudformation_json():
sqs = boto3.resource("sqs", region_name="us-east-1")
s3 = boto3.client("s3", "us-east-1")
cf = boto3.client("cloudformation", region_name="us-east-1")
lmbda = boto3.client("lambda", region_name="us-east-1")
queue = sqs.create_queue(QueueName="test-sqs-queue1")
# Creates lambda
_, lambda_stack = create_stack(cf, s3)
created_fn_name = get_created_function_name(cf, lambda_stack)
created_fn_arn = lmbda.get_function(FunctionName=created_fn_name)["Configuration"][
"FunctionArn"
]
original_template = event_source_mapping_template.substitute(
{
"resource_name": "Foo",
"batch_size": 1,
"event_source_arn": queue.attributes["QueueArn"],
"function_name": created_fn_name,
"enabled": True,
}
)
cf.create_stack(StackName="test-event-source", TemplateBody=original_template)
event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
original_esm = event_sources["EventSourceMappings"][0]
original_esm["State"].should.equal("Enabled")
original_esm["BatchSize"].should.equal(1)
# Update with deletion of old resources
new_template = event_source_mapping_template.substitute(
{
"resource_name": "Bar", # changed name
"batch_size": 10,
"event_source_arn": queue.attributes["QueueArn"],
"function_name": created_fn_name,
"enabled": False,
}
)
cf.update_stack(StackName="test-event-source", TemplateBody=new_template)
event_sources = lmbda.list_event_source_mappings(FunctionName=created_fn_name)
event_sources["EventSourceMappings"].should.have.length_of(1)
updated_esm = event_sources["EventSourceMappings"][0]
updated_esm["State"].should.equal("Disabled")
updated_esm["BatchSize"].should.equal(10)
updated_esm["UUID"].shouldnt.equal(original_esm["UUID"])
def create_stack(cf, s3):
bucket_name = str(uuid4())
s3.create_bucket(Bucket=bucket_name)
View File
@@ -98,12 +98,12 @@ def test_create_stack_hosted_zone_by_id():
},
}
conn.create_stack(
- "test_stack", template_body=json.dumps(dummy_template), parameters={}.items()
+ "test_stack1", template_body=json.dumps(dummy_template), parameters={}.items()
)
r53_conn = boto.connect_route53()
zone_id = r53_conn.get_zones()[0].id
conn.create_stack(
- "test_stack",
+ "test_stack2",
template_body=json.dumps(dummy_template2),
parameters={"ZoneId": zone_id}.items(),
)
@@ -541,13 +541,14 @@ def test_create_stack_lambda_and_dynamodb():
"ReadCapacityUnits": 10,
"WriteCapacityUnits": 10,
},
+ "StreamSpecification": {"StreamViewType": "KEYS_ONLY"},
},
},
"func1mapping": {
"Type": "AWS::Lambda::EventSourceMapping",
"Properties": {
"FunctionName": {"Ref": "func1"},
- "EventSourceArn": "arn:aws:dynamodb:region:XXXXXX:table/tab1/stream/2000T00:00:00.000",
+ "EventSourceArn": {"Fn::GetAtt": ["tab1", "StreamArn"]},
"StartingPosition": "0",
"BatchSize": 100,
"Enabled": True,
View File
@@ -919,7 +919,9 @@ def test_execute_change_set_w_name():
def test_describe_stack_pagination():
conn = boto3.client("cloudformation", region_name="us-east-1")
for i in range(100):
- conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json)
+ conn.create_stack(
+     StackName="test_stack_{}".format(i), TemplateBody=dummy_template_json
+ )
resp = conn.describe_stacks()
stacks = resp["Stacks"]
@@ -1211,7 +1213,8 @@ def test_list_exports_with_token():
# Add index to ensure name is unique
dummy_output_template["Outputs"]["StackVPC"]["Export"]["Name"] += str(i)
cf.create_stack(
- StackName="test_stack", TemplateBody=json.dumps(dummy_output_template)
+ StackName="test_stack_{}".format(i),
+ TemplateBody=json.dumps(dummy_output_template),
)
exports = cf.list_exports()
exports["Exports"].should.have.length_of(100)
@@ -1273,3 +1276,16 @@ def test_non_json_redrive_policy():
stack.Resource("MainQueue").resource_status.should.equal("CREATE_COMPLETE")
stack.Resource("DeadLetterQueue").resource_status.should.equal("CREATE_COMPLETE")
@mock_cloudformation
def test_boto3_create_duplicate_stack():
cf_conn = boto3.client("cloudformation", region_name="us-east-1")
cf_conn.create_stack(
StackName="test_stack", TemplateBody=dummy_template_json,
)
with assert_raises(ClientError):
cf_conn.create_stack(
StackName="test_stack", TemplateBody=dummy_template_json,
)
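A follow-up sketch (not in the diff) of inspecting that failure; the exact error code is an assumption on our side, modelled on CloudFormation's real `AlreadyExistsException`:

```python
try:
    cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json)
except ClientError as err:
    # Assumed shape; the test above only asserts that a ClientError is raised.
    print(err.response["Error"]["Code"])
```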
View File
@@ -2303,6 +2303,7 @@ def test_stack_dynamodb_resources_integration():
},
}
],
+ "StreamSpecification": {"StreamViewType": "KEYS_ONLY"},
},
}
},
@@ -2315,6 +2316,12 @@ def test_stack_dynamodb_resources_integration():
StackName="dynamodb_stack", TemplateBody=dynamodb_template_json
)
dynamodb_client = boto3.client("dynamodb", region_name="us-east-1")
table_desc = dynamodb_client.describe_table(TableName="myTableName")["Table"]
table_desc["StreamSpecification"].should.equal(
{"StreamEnabled": True, "StreamViewType": "KEYS_ONLY",}
)
dynamodb_conn = boto3.resource("dynamodb", region_name="us-east-1")
table = dynamodb_conn.Table("myTableName")
table.name.should.equal("myTableName")
View File
@@ -38,6 +38,16 @@ name_type_template = {
},
}
name_type_template_with_tabs_json = """
\t{
\t\t"AWSTemplateFormatVersion": "2010-09-09",
\t\t"Description": "Create a multi-az, load balanced, Auto Scaled sample web site. The Auto Scaling trigger is based on the CPU utilization of the web servers. The AMI is chosen based on the region in which the stack is run. This example creates a web service running across all availability zones in a region. The instances are load balanced with a simple health check. The web site is available on port 80, however, the instances can be configured to listen on any port (8888 by default). **WARNING** This template creates one or more Amazon EC2 instances. You will be billed for the AWS resources used if you create a stack from this template.",
\t\t"Resources": {
\t\t\t"Queue": {"Type": "AWS::SQS::Queue", "Properties": {"VisibilityTimeout": 60}}
\t\t}
\t}
"""
output_dict = {
"Outputs": {
"Output1": {"Value": {"Ref": "Queue"}, "Description": "This is a description."}
@@ -186,6 +196,21 @@ def test_parse_stack_with_name_type_resource():
queue.should.be.a(Queue)
def test_parse_stack_with_tabbed_json_template():
stack = FakeStack(
stack_id="test_id",
name="test_stack",
template=name_type_template_with_tabs_json,
parameters={},
region_name="us-west-1",
)
stack.resource_map.should.have.length_of(1)
list(stack.resource_map.keys())[0].should.equal("Queue")
queue = list(stack.resource_map.values())[0]
queue.should.be.a(Queue)
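Why tab-indented templates need a test at all: tabs are legal whitespace in JSON but illegal as YAML indentation, and moto's stack parser tries YAML before falling back to JSON. A minimal demonstration (assumes PyYAML, which moto already depends on):

```python
import json

import yaml

doc = '\t{\n\t\t"Resources": {}\n\t}'
json.loads(doc)  # fine: JSON treats tabs as insignificant whitespace
try:
    yaml.safe_load(doc)
except yaml.YAMLError as err:
    print("YAML rejects tab indentation:", err)
```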
def test_parse_stack_with_yaml_template():
stack = FakeStack(
stack_id="test_id",
View File
@@ -3038,6 +3038,54 @@ def test_batch_items_returns_all():
]
@mock_dynamodb2
def test_batch_items_throws_exception_when_requesting_100_items_for_single_table():
dynamodb = _create_user_table()
with assert_raises(ClientError) as ex:
dynamodb.batch_get_item(
RequestItems={
"users": {
"Keys": [
{"username": {"S": "user" + str(i)}} for i in range(0, 104)
],
"ConsistentRead": True,
}
}
)
ex.exception.response["Error"]["Code"].should.equal("ValidationException")
msg = ex.exception.response["Error"]["Message"]
msg.should.contain("1 validation error detected: Value")
msg.should.contain(
"at 'requestItems.users.member.keys' failed to satisfy constraint: Member must have length less than or equal to 100"
)
@mock_dynamodb2
def test_batch_items_throws_exception_when_requesting_100_items_across_all_tables():
dynamodb = _create_user_table()
with assert_raises(ClientError) as ex:
dynamodb.batch_get_item(
RequestItems={
"users": {
"Keys": [
{"username": {"S": "user" + str(i)}} for i in range(0, 75)
],
"ConsistentRead": True,
},
"users2": {
"Keys": [
{"username": {"S": "user" + str(i)}} for i in range(0, 75)
],
"ConsistentRead": True,
},
}
)
ex.exception.response["Error"]["Code"].should.equal("ValidationException")
ex.exception.response["Error"]["Message"].should.equal(
"Too many items requested for the BatchGetItem call"
)
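The two tests above mirror real DynamoDB limits: at most 100 keys per table and per call. Callers usually chunk client-side; a rough sketch (the helper name and retry handling are ours, not moto's):

```python
def batch_get_all(client, table_name, keys, chunk_size=100):
    """Fetch `keys` from one table in BatchGetItem-sized chunks."""
    items = []
    for start in range(0, len(keys), chunk_size):
        chunk = keys[start : start + chunk_size]
        resp = client.batch_get_item(
            RequestItems={table_name: {"Keys": chunk, "ConsistentRead": True}}
        )
        items.extend(resp["Responses"].get(table_name, []))
        # A production version would also drain resp["UnprocessedKeys"].
    return items
```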
@mock_dynamodb2
def test_batch_items_with_basic_projection_expression():
dynamodb = _create_user_table()
View File
@@ -931,6 +931,83 @@ boto3
"""
@mock_dynamodb2
def test_boto3_create_table_with_gsi():
dynamodb = boto3.client("dynamodb", region_name="us-east-1")
table = dynamodb.create_table(
TableName="users",
KeySchema=[
{"AttributeName": "forum_name", "KeyType": "HASH"},
{"AttributeName": "subject", "KeyType": "RANGE"},
],
AttributeDefinitions=[
{"AttributeName": "forum_name", "AttributeType": "S"},
{"AttributeName": "subject", "AttributeType": "S"},
],
BillingMode="PAY_PER_REQUEST",
GlobalSecondaryIndexes=[
{
"IndexName": "test_gsi",
"KeySchema": [{"AttributeName": "subject", "KeyType": "HASH"}],
"Projection": {"ProjectionType": "ALL"},
}
],
)
table["TableDescription"]["GlobalSecondaryIndexes"].should.equal(
[
{
"KeySchema": [{"KeyType": "HASH", "AttributeName": "subject"}],
"IndexName": "test_gsi",
"Projection": {"ProjectionType": "ALL"},
"IndexStatus": "ACTIVE",
"ProvisionedThroughput": {
"ReadCapacityUnits": 0,
"WriteCapacityUnits": 0,
},
}
]
)
table = dynamodb.create_table(
TableName="users2",
KeySchema=[
{"AttributeName": "forum_name", "KeyType": "HASH"},
{"AttributeName": "subject", "KeyType": "RANGE"},
],
AttributeDefinitions=[
{"AttributeName": "forum_name", "AttributeType": "S"},
{"AttributeName": "subject", "AttributeType": "S"},
],
BillingMode="PAY_PER_REQUEST",
GlobalSecondaryIndexes=[
{
"IndexName": "test_gsi",
"KeySchema": [{"AttributeName": "subject", "KeyType": "HASH"}],
"Projection": {"ProjectionType": "ALL"},
"ProvisionedThroughput": {
"ReadCapacityUnits": 3,
"WriteCapacityUnits": 5,
},
}
],
)
table["TableDescription"]["GlobalSecondaryIndexes"].should.equal(
[
{
"KeySchema": [{"KeyType": "HASH", "AttributeName": "subject"}],
"IndexName": "test_gsi",
"Projection": {"ProjectionType": "ALL"},
"IndexStatus": "ACTIVE",
"ProvisionedThroughput": {
"ReadCapacityUnits": 3,
"WriteCapacityUnits": 5,
},
}
]
)
@mock_dynamodb2
def test_boto3_conditions():
dynamodb = boto3.resource("dynamodb", region_name="us-east-1")
View File
@@ -843,7 +843,11 @@ def test_ami_snapshots_have_correct_owner():
]
existing_snapshot_ids = owner_id_to_snapshot_ids.get(owner_id, [])
owner_id_to_snapshot_ids[owner_id] = existing_snapshot_ids + snapshot_ids
# also assert the VolumeType of the image's block device mapping
assert (
image.get("BlockDeviceMappings", {})[0].get("Ebs", {}).get("VolumeType")
== "standard"
)
for owner_id in owner_id_to_snapshot_ids:
snapshots_rseponse = ec2_client.describe_snapshots(
SnapshotIds=owner_id_to_snapshot_ids[owner_id]
View File
@@ -128,7 +128,35 @@ def test_instance_terminate_discard_volumes():
@mock_ec2
- def test_instance_terminate_keep_volumes():
+ def test_instance_terminate_keep_volumes_explicit():
ec2_resource = boto3.resource("ec2", "us-west-1")
result = ec2_resource.create_instances(
ImageId="ami-d3adb33f",
MinCount=1,
MaxCount=1,
BlockDeviceMappings=[
{
"DeviceName": "/dev/sda1",
"Ebs": {"VolumeSize": 50, "DeleteOnTermination": False},
}
],
)
instance = result[0]
instance_volume_ids = []
for volume in instance.volumes.all():
instance_volume_ids.append(volume.volume_id)
instance.terminate()
instance.wait_until_terminated()
assert len(list(ec2_resource.volumes.all())) == 1
@mock_ec2
def test_instance_terminate_keep_volumes_implicit():
ec2_resource = boto3.resource("ec2", "us-west-1")
result = ec2_resource.create_instances(
View File
@@ -462,7 +462,7 @@ def test_routes_not_supported():
# Create
conn.create_route.when.called_with(
main_route_table.id, ROUTE_CIDR, interface_id="eni-1234abcd"
- ).should.throw(NotImplementedError)
+ ).should.throw("InvalidNetworkInterfaceID.NotFound")

# Replace
igw = conn.create_internet_gateway()
@@ -583,6 +583,32 @@ def test_create_route_with_invalid_destination_cidr_block_parameter():
)
@mock_ec2
def test_create_route_with_network_interface_id():
ec2 = boto3.resource("ec2", region_name="us-west-2")
ec2_client = boto3.client("ec2", region_name="us-west-2")
vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16")
subnet = ec2.create_subnet(
VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a"
)
route_table = ec2_client.create_route_table(VpcId=vpc.id)
route_table_id = route_table["RouteTable"]["RouteTableId"]
eni1 = ec2_client.create_network_interface(
SubnetId=subnet.id, PrivateIpAddress="10.0.10.5"
)
route = ec2_client.create_route(
NetworkInterfaceId=eni1["NetworkInterface"]["NetworkInterfaceId"],
RouteTableId=route_table_id,
)
route["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
@mock_ec2
def test_describe_route_tables_with_nat_gateway():
ec2 = boto3.client("ec2", region_name="us-west-1")
View File
@@ -1690,11 +1690,15 @@ def test_get_account_authorization_details():
assert result["RoleDetailList"][0]["AttachedManagedPolicies"][0][
"PolicyArn"
] == "arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID)
assert result["RoleDetailList"][0]["RolePolicyList"][0][
"PolicyDocument"
] == json.loads(test_policy)
result = conn.get_account_authorization_details(Filter=["User"])
assert len(result["RoleDetailList"]) == 0
assert len(result["UserDetailList"]) == 1
assert len(result["UserDetailList"][0]["GroupList"]) == 1
+ assert len(result["UserDetailList"][0]["UserPolicyList"]) == 1
assert len(result["UserDetailList"][0]["AttachedManagedPolicies"]) == 1
assert len(result["GroupDetailList"]) == 0
assert len(result["Policies"]) == 0
@@ -1705,6 +1709,9 @@ def test_get_account_authorization_details():
assert result["UserDetailList"][0]["AttachedManagedPolicies"][0][
"PolicyArn"
] == "arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID)
assert result["UserDetailList"][0]["UserPolicyList"][0][
"PolicyDocument"
] == json.loads(test_policy)
result = conn.get_account_authorization_details(Filter=["Group"])
assert len(result["RoleDetailList"]) == 0
@@ -1720,6 +1727,9 @@ def test_get_account_authorization_details():
assert result["GroupDetailList"][0]["AttachedManagedPolicies"][0][
"PolicyArn"
] == "arn:aws:iam::{}:policy/testPolicy".format(ACCOUNT_ID)
assert result["GroupDetailList"][0]["GroupPolicyList"][0][
"PolicyDocument"
] == json.loads(test_policy)
result = conn.get_account_authorization_details(Filter=["LocalManagedPolicy"])
assert len(result["RoleDetailList"]) == 0
View File
@@ -1040,6 +1040,22 @@ def test_s3_object_in_public_bucket_using_multiple_presigned_urls():
assert response.status_code == 200, "Failed on req number {}".format(i)
@mock_s3
def test_streaming_upload_from_file_to_presigned_url():
s3 = boto3.resource("s3", region_name="us-east-1")
bucket = s3.Bucket("test-bucket")
bucket.create()
bucket.put_object(Body=b"ABCD", Key="file.txt")
params = {"Bucket": "test-bucket", "Key": "file.txt"}
presigned_url = boto3.client("s3").generate_presigned_url(
"put_object", params, ExpiresIn=900
)
with open(__file__, "rb") as f:
response = requests.get(presigned_url, data=f)
assert response.status_code == 200
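Outside of this regression test, a presigned `put_object` URL would normally be exercised with an HTTP PUT; passing a file object makes `requests` stream the body instead of buffering it. A sketch with hypothetical bucket and key names:

```python
import boto3
import requests

url = boto3.client("s3").generate_presigned_url(
    "put_object", {"Bucket": "test-bucket", "Key": "file.txt"}, ExpiresIn=900
)
with open("file.txt", "rb") as f:
    resp = requests.put(url, data=f)  # streamed, chunked upload
resp.raise_for_status()
```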
@mock_s3
def test_s3_object_in_private_bucket():
s3 = boto3.resource("s3")
@@ -1960,6 +1976,15 @@ def test_boto3_bucket_create_eu_central():
)
@mock_s3
def test_bucket_create_empty_bucket_configuration_should_return_malformed_xml_error():
s3 = boto3.resource("s3", region_name="us-east-1")
with assert_raises(ClientError) as e:
s3.create_bucket(Bucket="whatever", CreateBucketConfiguration={})
e.exception.response["Error"]["Code"].should.equal("MalformedXML")
e.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400)
@mock_s3
def test_boto3_head_object():
s3 = boto3.resource("s3", region_name=DEFAULT_REGION_NAME)
@@ -4364,7 +4389,7 @@ def test_s3_config_dict():
# With 1 bucket in us-west-2:
s3_config_query.backends["global"].create_bucket("bucket1", "us-west-2")
- s3_config_query.backends["global"].put_bucket_tags("bucket1", tags)
+ s3_config_query.backends["global"].put_bucket_tagging("bucket1", tags)
# With a log bucket:
s3_config_query.backends["global"].create_bucket("logbucket", "us-west-2")
View File
@@ -211,6 +211,24 @@ def test_delete_secret_force():
result = conn.get_secret_value(SecretId="test-secret")
@mock_secretsmanager
def test_delete_secret_force_with_arn():
conn = boto3.client("secretsmanager", region_name="us-west-2")
create_secret = conn.create_secret(Name="test-secret", SecretString="foosecret")
result = conn.delete_secret(
SecretId=create_secret["ARN"], ForceDeleteWithoutRecovery=True
)
assert result["ARN"]
assert result["DeletionDate"] > datetime.fromtimestamp(1, pytz.utc)
assert result["Name"] == "test-secret"
with assert_raises(ClientError):
result = conn.get_secret_value(SecretId="test-secret")
@mock_secretsmanager
def test_delete_secret_that_does_not_exist():
conn = boto3.client("secretsmanager", region_name="us-west-2")
View File
@@ -300,6 +300,118 @@ def test_create_configuration_set():
ex.exception.response["Error"]["Code"].should.equal("EventDestinationAlreadyExists")
@mock_ses
def test_create_receipt_rule_set():
conn = boto3.client("ses", region_name="us-east-1")
result = conn.create_receipt_rule_set(RuleSetName="testRuleSet")
result["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
with assert_raises(ClientError) as ex:
conn.create_receipt_rule_set(RuleSetName="testRuleSet")
ex.exception.response["Error"]["Code"].should.equal("RuleSetNameAlreadyExists")
@mock_ses
def test_create_receipt_rule():
conn = boto3.client("ses", region_name="us-east-1")
rule_set_name = "testRuleSet"
conn.create_receipt_rule_set(RuleSetName=rule_set_name)
result = conn.create_receipt_rule(
RuleSetName=rule_set_name,
Rule={
"Name": "testRule",
"Enabled": False,
"TlsPolicy": "Optional",
"Recipients": ["string"],
"Actions": [
{
"S3Action": {
"TopicArn": "string",
"BucketName": "string",
"ObjectKeyPrefix": "string",
"KmsKeyArn": "string",
},
"BounceAction": {
"TopicArn": "string",
"SmtpReplyCode": "string",
"StatusCode": "string",
"Message": "string",
"Sender": "string",
},
}
],
"ScanEnabled": False,
},
)
result["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
with assert_raises(ClientError) as ex:
conn.create_receipt_rule(
RuleSetName=rule_set_name,
Rule={
"Name": "testRule",
"Enabled": False,
"TlsPolicy": "Optional",
"Recipients": ["string"],
"Actions": [
{
"S3Action": {
"TopicArn": "string",
"BucketName": "string",
"ObjectKeyPrefix": "string",
"KmsKeyArn": "string",
},
"BounceAction": {
"TopicArn": "string",
"SmtpReplyCode": "string",
"StatusCode": "string",
"Message": "string",
"Sender": "string",
},
}
],
"ScanEnabled": False,
},
)
ex.exception.response["Error"]["Code"].should.equal("RuleAlreadyExists")
with assert_raises(ClientError) as ex:
conn.create_receipt_rule(
RuleSetName="InvalidRuleSetaName",
Rule={
"Name": "testRule",
"Enabled": False,
"TlsPolicy": "Optional",
"Recipients": ["string"],
"Actions": [
{
"S3Action": {
"TopicArn": "string",
"BucketName": "string",
"ObjectKeyPrefix": "string",
"KmsKeyArn": "string",
},
"BounceAction": {
"TopicArn": "string",
"SmtpReplyCode": "string",
"StatusCode": "string",
"Message": "string",
"Sender": "string",
},
}
],
"ScanEnabled": False,
},
)
ex.exception.response["Error"]["Code"].should.equal("RuleSetDoesNotExist")
@mock_ses
def test_create_ses_template():
conn = boto3.client("ses", region_name="us-east-1")