Merge branch 'master' into feature/extend_generic_tagger_to_s3

This commit is contained in:
Steve Pulec 2020-04-25 18:40:50 -05:00 committed by GitHub
commit b24b7cb858
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
80 changed files with 7002 additions and 474 deletions

View File

@ -6,6 +6,9 @@ Moto has a [Code of Conduct](https://github.com/spulec/moto/blob/master/CODE_OF_
Moto has a Makefile which has some helpful commands for getting setup. You should be able to run `make init` to install the dependencies and then `make test` to run the tests.
## Linting
Run `make lint` or `black --check moto tests` to verify whether your code confirms to the guidelines.
## Is there a missing feature?
Moto is easier to contribute to than you probably think. There's [a list of which endpoints have been implemented](https://github.com/spulec/moto/blob/master/IMPLEMENTATION_COVERAGE.md) and we invite you to add new endpoints to existing services or to add new services.

View File

@ -3351,11 +3351,11 @@
- [ ] update_listener
## glue
4% implemented
- [ ] batch_create_partition
11% implemented
- [X] batch_create_partition
- [ ] batch_delete_connection
- [ ] batch_delete_partition
- [ ] batch_delete_table
- [X] batch_delete_partition
- [X] batch_delete_table
- [ ] batch_delete_table_version
- [ ] batch_get_crawlers
- [ ] batch_get_dev_endpoints
@ -3372,7 +3372,7 @@
- [ ] create_dev_endpoint
- [ ] create_job
- [ ] create_ml_transform
- [ ] create_partition
- [X] create_partition
- [ ] create_script
- [ ] create_security_configuration
- [X] create_table
@ -3404,7 +3404,7 @@
- [ ] get_crawlers
- [ ] get_data_catalog_encryption_settings
- [X] get_database
- [ ] get_databases
- [X] get_databases
- [ ] get_dataflow_graph
- [ ] get_dev_endpoint
- [ ] get_dev_endpoints
@ -3418,7 +3418,7 @@
- [ ] get_ml_task_runs
- [ ] get_ml_transform
- [ ] get_ml_transforms
- [ ] get_partition
- [X] get_partition
- [ ] get_partitions
- [ ] get_plan
- [ ] get_resource_policy
@ -3470,8 +3470,8 @@
- [ ] update_dev_endpoint
- [ ] update_job
- [ ] update_ml_transform
- [ ] update_partition
- [ ] update_table
- [X] update_partition
- [X] update_table
- [ ] update_trigger
- [ ] update_user_defined_function
- [ ] update_workflow
@ -7210,13 +7210,13 @@
- [ ] update_vtl_device_type
## sts
50% implemented
62% implemented
- [X] assume_role
- [ ] assume_role_with_saml
- [X] assume_role_with_web_identity
- [ ] decode_authorization_message
- [ ] get_access_key_info
- [ ] get_caller_identity
- [X] get_caller_identity
- [X] get_federation_token
- [X] get_session_token

View File

@ -119,3 +119,57 @@ class ApiKeyAlreadyExists(RESTError):
super(ApiKeyAlreadyExists, self).__init__(
"ConflictException", "API Key already exists"
)
class InvalidDomainName(BadRequestException):
code = 404
def __init__(self):
super(InvalidDomainName, self).__init__(
"BadRequestException", "No Domain Name specified"
)
class DomainNameNotFound(RESTError):
code = 404
def __init__(self):
super(DomainNameNotFound, self).__init__(
"NotFoundException", "Invalid Domain Name specified"
)
class InvalidRestApiId(BadRequestException):
code = 404
def __init__(self):
super(InvalidRestApiId, self).__init__(
"BadRequestException", "No Rest API Id specified"
)
class InvalidModelName(BadRequestException):
code = 404
def __init__(self):
super(InvalidModelName, self).__init__(
"BadRequestException", "No Model Name specified"
)
class RestAPINotFound(RESTError):
code = 404
def __init__(self):
super(RestAPINotFound, self).__init__(
"NotFoundException", "Invalid Rest API Id specified"
)
class ModelNotFound(RESTError):
code = 404
def __init__(self):
super(ModelNotFound, self).__init__(
"NotFoundException", "Invalid Model Name specified"
)

View File

@ -34,6 +34,12 @@ from .exceptions import (
NoIntegrationDefined,
NoMethodDefined,
ApiKeyAlreadyExists,
DomainNameNotFound,
InvalidDomainName,
InvalidRestApiId,
InvalidModelName,
RestAPINotFound,
ModelNotFound,
)
STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}"
@ -455,6 +461,7 @@ class RestAPI(BaseModel):
self.description = description
self.create_date = int(time.time())
self.api_key_source = kwargs.get("api_key_source") or "HEADER"
self.policy = kwargs.get("policy") or None
self.endpoint_configuration = kwargs.get("endpoint_configuration") or {
"types": ["EDGE"]
}
@ -463,8 +470,8 @@ class RestAPI(BaseModel):
self.deployments = {}
self.authorizers = {}
self.stages = {}
self.resources = {}
self.models = {}
self.add_child("/") # Add default child
def __repr__(self):
@ -479,6 +486,7 @@ class RestAPI(BaseModel):
"apiKeySource": self.api_key_source,
"endpointConfiguration": self.endpoint_configuration,
"tags": self.tags,
"policy": self.policy,
}
def add_child(self, path, parent_id=None):
@ -493,6 +501,29 @@ class RestAPI(BaseModel):
self.resources[child_id] = child
return child
def add_model(
self,
name,
description=None,
schema=None,
content_type=None,
cli_input_json=None,
generate_cli_skeleton=None,
):
model_id = create_id()
new_model = Model(
id=model_id,
name=name,
description=description,
schema=schema,
content_type=content_type,
cli_input_json=cli_input_json,
generate_cli_skeleton=generate_cli_skeleton,
)
self.models[name] = new_model
return new_model
def get_resource_for_path(self, path_after_stage_name):
for resource in self.resources.values():
if resource.get_path() == path_after_stage_name:
@ -609,6 +640,58 @@ class RestAPI(BaseModel):
return self.deployments.pop(deployment_id)
class DomainName(BaseModel, dict):
def __init__(self, domain_name, **kwargs):
super(DomainName, self).__init__()
self["domainName"] = domain_name
self["regionalDomainName"] = domain_name
self["distributionDomainName"] = domain_name
self["domainNameStatus"] = "AVAILABLE"
self["domainNameStatusMessage"] = "Domain Name Available"
self["regionalHostedZoneId"] = "Z2FDTNDATAQYW2"
self["distributionHostedZoneId"] = "Z2FDTNDATAQYW2"
self["certificateUploadDate"] = int(time.time())
if kwargs.get("certificate_name"):
self["certificateName"] = kwargs.get("certificate_name")
if kwargs.get("certificate_arn"):
self["certificateArn"] = kwargs.get("certificate_arn")
if kwargs.get("certificate_body"):
self["certificateBody"] = kwargs.get("certificate_body")
if kwargs.get("tags"):
self["tags"] = kwargs.get("tags")
if kwargs.get("security_policy"):
self["securityPolicy"] = kwargs.get("security_policy")
if kwargs.get("certificate_chain"):
self["certificateChain"] = kwargs.get("certificate_chain")
if kwargs.get("regional_certificate_name"):
self["regionalCertificateName"] = kwargs.get("regional_certificate_name")
if kwargs.get("certificate_private_key"):
self["certificatePrivateKey"] = kwargs.get("certificate_private_key")
if kwargs.get("regional_certificate_arn"):
self["regionalCertificateArn"] = kwargs.get("regional_certificate_arn")
if kwargs.get("endpoint_configuration"):
self["endpointConfiguration"] = kwargs.get("endpoint_configuration")
if kwargs.get("generate_cli_skeleton"):
self["generateCliSkeleton"] = kwargs.get("generate_cli_skeleton")
class Model(BaseModel, dict):
def __init__(self, id, name, **kwargs):
super(Model, self).__init__()
self["id"] = id
self["name"] = name
if kwargs.get("description"):
self["description"] = kwargs.get("description")
if kwargs.get("schema"):
self["schema"] = kwargs.get("schema")
if kwargs.get("content_type"):
self["contentType"] = kwargs.get("content_type")
if kwargs.get("cli_input_json"):
self["cliInputJson"] = kwargs.get("cli_input_json")
if kwargs.get("generate_cli_skeleton"):
self["generateCliSkeleton"] = kwargs.get("generate_cli_skeleton")
class APIGatewayBackend(BaseBackend):
def __init__(self, region_name):
super(APIGatewayBackend, self).__init__()
@ -616,6 +699,8 @@ class APIGatewayBackend(BaseBackend):
self.keys = {}
self.usage_plans = {}
self.usage_plan_keys = {}
self.domain_names = {}
self.models = {}
self.region_name = region_name
def reset(self):
@ -630,6 +715,7 @@ class APIGatewayBackend(BaseBackend):
api_key_source=None,
endpoint_configuration=None,
tags=None,
policy=None,
):
api_id = create_id()
rest_api = RestAPI(
@ -640,12 +726,15 @@ class APIGatewayBackend(BaseBackend):
api_key_source=api_key_source,
endpoint_configuration=endpoint_configuration,
tags=tags,
policy=policy,
)
self.apis[api_id] = rest_api
return rest_api
def get_rest_api(self, function_id):
rest_api = self.apis[function_id]
rest_api = self.apis.get(function_id)
if rest_api is None:
raise RestAPINotFound()
return rest_api
def list_apis(self):
@ -1001,6 +1090,98 @@ class APIGatewayBackend(BaseBackend):
except Exception:
return False
def create_domain_name(
self,
domain_name,
certificate_name=None,
tags=None,
certificate_arn=None,
certificate_body=None,
certificate_private_key=None,
certificate_chain=None,
regional_certificate_name=None,
regional_certificate_arn=None,
endpoint_configuration=None,
security_policy=None,
generate_cli_skeleton=None,
):
if not domain_name:
raise InvalidDomainName()
new_domain_name = DomainName(
domain_name=domain_name,
certificate_name=certificate_name,
certificate_private_key=certificate_private_key,
certificate_arn=certificate_arn,
certificate_body=certificate_body,
certificate_chain=certificate_chain,
regional_certificate_name=regional_certificate_name,
regional_certificate_arn=regional_certificate_arn,
endpoint_configuration=endpoint_configuration,
tags=tags,
security_policy=security_policy,
generate_cli_skeleton=generate_cli_skeleton,
)
self.domain_names[domain_name] = new_domain_name
return new_domain_name
def get_domain_names(self):
return list(self.domain_names.values())
def get_domain_name(self, domain_name):
domain_info = self.domain_names.get(domain_name)
if domain_info is None:
raise DomainNameNotFound
else:
return self.domain_names[domain_name]
def create_model(
self,
rest_api_id,
name,
content_type,
description=None,
schema=None,
cli_input_json=None,
generate_cli_skeleton=None,
):
if not rest_api_id:
raise InvalidRestApiId
if not name:
raise InvalidModelName
api = self.get_rest_api(rest_api_id)
new_model = api.add_model(
name=name,
description=description,
schema=schema,
content_type=content_type,
cli_input_json=cli_input_json,
generate_cli_skeleton=generate_cli_skeleton,
)
return new_model
def get_models(self, rest_api_id):
if not rest_api_id:
raise InvalidRestApiId
api = self.get_rest_api(rest_api_id)
models = api.models.values()
return list(models)
def get_model(self, rest_api_id, model_name):
if not rest_api_id:
raise InvalidRestApiId
api = self.get_rest_api(rest_api_id)
model = api.models.get(model_name)
if model is None:
raise ModelNotFound
else:
return model
apigateway_backends = {}
for region_name in Session().get_available_regions("apigateway"):

View File

@ -11,6 +11,12 @@ from .exceptions import (
AuthorizerNotFoundException,
StageNotFoundException,
ApiKeyAlreadyExists,
DomainNameNotFound,
InvalidDomainName,
InvalidRestApiId,
InvalidModelName,
RestAPINotFound,
ModelNotFound,
)
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
@ -53,6 +59,7 @@ class APIGatewayResponse(BaseResponse):
api_key_source = self._get_param("apiKeySource")
endpoint_configuration = self._get_param("endpointConfiguration")
tags = self._get_param("tags")
policy = self._get_param("policy")
# Param validation
if api_key_source and api_key_source not in API_KEY_SOURCES:
@ -88,6 +95,7 @@ class APIGatewayResponse(BaseResponse):
api_key_source=api_key_source,
endpoint_configuration=endpoint_configuration,
tags=tags,
policy=policy,
)
return 200, {}, json.dumps(rest_api.to_dict())
@ -527,3 +535,130 @@ class APIGatewayResponse(BaseResponse):
usage_plan_id, key_id
)
return 200, {}, json.dumps(usage_plan_response)
def domain_names(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
try:
if self.method == "GET":
domain_names = self.backend.get_domain_names()
return 200, {}, json.dumps({"item": domain_names})
elif self.method == "POST":
domain_name = self._get_param("domainName")
certificate_name = self._get_param("certificateName")
tags = self._get_param("tags")
certificate_arn = self._get_param("certificateArn")
certificate_body = self._get_param("certificateBody")
certificate_private_key = self._get_param("certificatePrivateKey")
certificate_chain = self._get_param("certificateChain")
regional_certificate_name = self._get_param("regionalCertificateName")
regional_certificate_arn = self._get_param("regionalCertificateArn")
endpoint_configuration = self._get_param("endpointConfiguration")
security_policy = self._get_param("securityPolicy")
generate_cli_skeleton = self._get_param("generateCliSkeleton")
domain_name_resp = self.backend.create_domain_name(
domain_name,
certificate_name,
tags,
certificate_arn,
certificate_body,
certificate_private_key,
certificate_chain,
regional_certificate_name,
regional_certificate_arn,
endpoint_configuration,
security_policy,
generate_cli_skeleton,
)
return 200, {}, json.dumps(domain_name_resp)
except InvalidDomainName as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
def domain_name_induvidual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
domain_name = url_path_parts[2]
domain_names = {}
try:
if self.method == "GET":
if domain_name is not None:
domain_names = self.backend.get_domain_name(domain_name)
return 200, {}, json.dumps(domain_names)
except DomainNameNotFound as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
def models(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
rest_api_id = self.path.replace("/restapis/", "", 1).split("/")[0]
try:
if self.method == "GET":
models = self.backend.get_models(rest_api_id)
return 200, {}, json.dumps({"item": models})
elif self.method == "POST":
name = self._get_param("name")
description = self._get_param("description")
schema = self._get_param("schema")
content_type = self._get_param("contentType")
cli_input_json = self._get_param("cliInputJson")
generate_cli_skeleton = self._get_param("generateCliSkeleton")
model = self.backend.create_model(
rest_api_id,
name,
content_type,
description,
schema,
cli_input_json,
generate_cli_skeleton,
)
return 200, {}, json.dumps(model)
except (InvalidRestApiId, InvalidModelName, RestAPINotFound) as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
def model_induvidual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
rest_api_id = url_path_parts[2]
model_name = url_path_parts[4]
model_info = {}
try:
if self.method == "GET":
model_info = self.backend.get_model(rest_api_id, model_name)
return 200, {}, json.dumps(model_info)
except (
ModelNotFound,
RestAPINotFound,
InvalidRestApiId,
InvalidModelName,
) as error:
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)

View File

@ -21,6 +21,10 @@ url_paths = {
"{0}/apikeys$": APIGatewayResponse().apikeys,
"{0}/apikeys/(?P<apikey>[^/]+)": APIGatewayResponse().apikey_individual,
"{0}/usageplans$": APIGatewayResponse().usage_plans,
"{0}/domainnames$": APIGatewayResponse().domain_names,
"{0}/restapis/(?P<function_id>[^/]+)/models$": APIGatewayResponse().models,
"{0}/restapis/(?P<function_id>[^/]+)/models/(?P<model_name>[^/]+)/?$": APIGatewayResponse().model_induvidual,
"{0}/domainnames/(?P<domain_name>[^/]+)/?$": APIGatewayResponse().domain_name_induvidual,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/?$": APIGatewayResponse().usage_plan_individual,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys$": APIGatewayResponse().usage_plan_keys,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$": APIGatewayResponse().usage_plan_key_individual,

View File

@ -267,6 +267,9 @@ class FakeAutoScalingGroup(BaseModel):
self.tags = tags if tags else []
self.set_desired_capacity(desired_capacity)
def active_instances(self):
return [x for x in self.instance_states if x.lifecycle_state == "InService"]
def _set_azs_and_vpcs(self, availability_zones, vpc_zone_identifier, update=False):
# for updates, if only AZs are provided, they must not clash with
# the AZs of existing VPCs
@ -413,9 +416,11 @@ class FakeAutoScalingGroup(BaseModel):
else:
self.desired_capacity = new_capacity
curr_instance_count = len(self.instance_states)
curr_instance_count = len(self.active_instances())
if self.desired_capacity == curr_instance_count:
self.autoscaling_backend.update_attached_elbs(self.name)
self.autoscaling_backend.update_attached_target_groups(self.name)
return
if self.desired_capacity > curr_instance_count:
@ -442,6 +447,8 @@ class FakeAutoScalingGroup(BaseModel):
self.instance_states = list(
set(self.instance_states) - set(instances_to_remove)
)
self.autoscaling_backend.update_attached_elbs(self.name)
self.autoscaling_backend.update_attached_target_groups(self.name)
def get_propagated_tags(self):
propagated_tags = {}
@ -655,10 +662,16 @@ class AutoScalingBackend(BaseBackend):
self.set_desired_capacity(group_name, 0)
self.autoscaling_groups.pop(group_name, None)
def describe_auto_scaling_instances(self):
def describe_auto_scaling_instances(self, instance_ids):
instance_states = []
for group in self.autoscaling_groups.values():
instance_states.extend(group.instance_states)
instance_states.extend(
[
x
for x in group.instance_states
if not instance_ids or x.instance.id in instance_ids
]
)
return instance_states
def attach_instances(self, group_name, instance_ids):
@ -697,7 +710,7 @@ class AutoScalingBackend(BaseBackend):
def detach_instances(self, group_name, instance_ids, should_decrement):
group = self.autoscaling_groups[group_name]
original_size = len(group.instance_states)
original_size = group.desired_capacity
detached_instances = [
x for x in group.instance_states if x.instance.id in instance_ids
@ -714,13 +727,8 @@ class AutoScalingBackend(BaseBackend):
if should_decrement:
group.desired_capacity = original_size - len(instance_ids)
else:
count_needed = len(instance_ids)
group.replace_autoscaling_group_instances(
count_needed, group.get_propagated_tags()
)
self.update_attached_elbs(group_name)
group.set_desired_capacity(group.desired_capacity)
return detached_instances
def set_desired_capacity(self, group_name, desired_capacity):
@ -785,7 +793,9 @@ class AutoScalingBackend(BaseBackend):
def update_attached_elbs(self, group_name):
group = self.autoscaling_groups[group_name]
group_instance_ids = set(state.instance.id for state in group.instance_states)
group_instance_ids = set(
state.instance.id for state in group.active_instances()
)
# skip this if group.load_balancers is empty
# otherwise elb_backend.describe_load_balancers returns all available load balancers
@ -902,15 +912,15 @@ class AutoScalingBackend(BaseBackend):
autoscaling_group_name,
autoscaling_group,
) in self.autoscaling_groups.items():
original_instance_count = len(autoscaling_group.instance_states)
original_active_instance_count = len(autoscaling_group.active_instances())
autoscaling_group.instance_states = list(
filter(
lambda i_state: i_state.instance.id not in instance_ids,
autoscaling_group.instance_states,
)
)
difference = original_instance_count - len(
autoscaling_group.instance_states
difference = original_active_instance_count - len(
autoscaling_group.active_instances()
)
if difference > 0:
autoscaling_group.replace_autoscaling_group_instances(
@ -918,6 +928,45 @@ class AutoScalingBackend(BaseBackend):
)
self.update_attached_elbs(autoscaling_group_name)
def enter_standby_instances(self, group_name, instance_ids, should_decrement):
group = self.autoscaling_groups[group_name]
original_size = group.desired_capacity
standby_instances = []
for instance_state in group.instance_states:
if instance_state.instance.id in instance_ids:
instance_state.lifecycle_state = "Standby"
standby_instances.append(instance_state)
if should_decrement:
group.desired_capacity = group.desired_capacity - len(instance_ids)
else:
group.set_desired_capacity(group.desired_capacity)
return standby_instances, original_size, group.desired_capacity
def exit_standby_instances(self, group_name, instance_ids):
group = self.autoscaling_groups[group_name]
original_size = group.desired_capacity
standby_instances = []
for instance_state in group.instance_states:
if instance_state.instance.id in instance_ids:
instance_state.lifecycle_state = "InService"
standby_instances.append(instance_state)
group.desired_capacity = group.desired_capacity + len(instance_ids)
return standby_instances, original_size, group.desired_capacity
def terminate_instance(self, instance_id, should_decrement):
instance = self.ec2_backend.get_instance(instance_id)
instance_state = next(
instance_state
for group in self.autoscaling_groups.values()
for instance_state in group.instance_states
if instance_state.instance.id == instance.id
)
group = instance.autoscaling_group
original_size = group.desired_capacity
self.detach_instances(group.name, [instance.id], should_decrement)
self.ec2_backend.terminate_instances([instance.id])
return instance_state, original_size, group.desired_capacity
autoscaling_backends = {}
for region, ec2_backend in ec2_backends.items():

View File

@ -1,7 +1,12 @@
from __future__ import unicode_literals
import datetime
from moto.core.responses import BaseResponse
from moto.core.utils import amz_crc32, amzn_request_id
from moto.core.utils import (
amz_crc32,
amzn_request_id,
iso_8601_datetime_with_milliseconds,
)
from .models import autoscaling_backends
@ -226,7 +231,9 @@ class AutoScalingResponse(BaseResponse):
return template.render()
def describe_auto_scaling_instances(self):
instance_states = self.autoscaling_backend.describe_auto_scaling_instances()
instance_states = self.autoscaling_backend.describe_auto_scaling_instances(
instance_ids=self._get_multi_param("InstanceIds.member")
)
template = self.response_template(DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
return template.render(instance_states=instance_states)
@ -289,6 +296,50 @@ class AutoScalingResponse(BaseResponse):
template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def enter_standby(self):
group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param("InstanceIds.member")
should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity")
if should_decrement_string == "true":
should_decrement = True
else:
should_decrement = False
(
standby_instances,
original_size,
desired_capacity,
) = self.autoscaling_backend.enter_standby_instances(
group_name, instance_ids, should_decrement
)
template = self.response_template(ENTER_STANDBY_TEMPLATE)
return template.render(
standby_instances=standby_instances,
should_decrement=should_decrement,
original_size=original_size,
desired_capacity=desired_capacity,
timestamp=iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()),
)
@amz_crc32
@amzn_request_id
def exit_standby(self):
group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param("InstanceIds.member")
(
standby_instances,
original_size,
desired_capacity,
) = self.autoscaling_backend.exit_standby_instances(group_name, instance_ids)
template = self.response_template(EXIT_STANDBY_TEMPLATE)
return template.render(
standby_instances=standby_instances,
original_size=original_size,
desired_capacity=desired_capacity,
timestamp=iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()),
)
def suspend_processes(self):
autoscaling_group_name = self._get_param("AutoScalingGroupName")
scaling_processes = self._get_multi_param("ScalingProcesses.member")
@ -308,6 +359,29 @@ class AutoScalingResponse(BaseResponse):
template = self.response_template(SET_INSTANCE_PROTECTION_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def terminate_instance_in_auto_scaling_group(self):
instance_id = self._get_param("InstanceId")
should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity")
if should_decrement_string == "true":
should_decrement = True
else:
should_decrement = False
(
instance,
original_size,
desired_capacity,
) = self.autoscaling_backend.terminate_instance(instance_id, should_decrement)
template = self.response_template(TERMINATE_INSTANCES_TEMPLATE)
return template.render(
instance=instance,
should_decrement=should_decrement,
original_size=original_size,
desired_capacity=desired_capacity,
timestamp=iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()),
)
CREATE_LAUNCH_CONFIGURATION_TEMPLATE = """<CreateLaunchConfigurationResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<ResponseMetadata>
@ -705,3 +779,73 @@ SET_INSTANCE_PROTECTION_TEMPLATE = """<SetInstanceProtectionResponse xmlns="http
<RequestId></RequestId>
</ResponseMetadata>
</SetInstanceProtectionResponse>"""
ENTER_STANDBY_TEMPLATE = """<EnterStandbyResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<EnterStandbyResult>
<Activities>
{% for instance in standby_instances %}
<member>
<ActivityId>12345678-1234-1234-1234-123456789012</ActivityId>
<AutoScalingGroupName>{{ group_name }}</AutoScalingGroupName>
{% if should_decrement %}
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was moved to standby in response to a user request, shrinking the capacity from {{ original_size }} to {{ desired_capacity }}.</Cause>
{% else %}
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was moved to standby in response to a user request.</Cause>
{% endif %}
<Description>Moving EC2 instance to Standby: {{ instance.instance.id }}</Description>
<Progress>50</Progress>
<StartTime>{{ timestamp }}</StartTime>
<Details>{&quot;Subnet ID&quot;:&quot;??&quot;,&quot;Availability Zone&quot;:&quot;{{ instance.instance.placement }}&quot;}</Details>
<StatusCode>InProgress</StatusCode>
</member>
{% endfor %}
</Activities>
</EnterStandbyResult>
<ResponseMetadata>
<RequestId>7c6e177f-f082-11e1-ac58-3714bEXAMPLE</RequestId>
</ResponseMetadata>
</EnterStandbyResponse>"""
EXIT_STANDBY_TEMPLATE = """<ExitStandbyResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<ExitStandbyResult>
<Activities>
{% for instance in standby_instances %}
<member>
<ActivityId>12345678-1234-1234-1234-123456789012</ActivityId>
<AutoScalingGroupName>{{ group_name }}</AutoScalingGroupName>
<Description>Moving EC2 instance out of Standby: {{ instance.instance.id }}</Description>
<Progress>30</Progress>
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was moved out of standby in response to a user request, increasing the capacity from {{ original_size }} to {{ desired_capacity }}.</Cause>
<StartTime>{{ timestamp }}</StartTime>
<Details>{&quot;Subnet ID&quot;:&quot;??&quot;,&quot;Availability Zone&quot;:&quot;{{ instance.instance.placement }}&quot;}</Details>
<StatusCode>PreInService</StatusCode>
</member>
{% endfor %}
</Activities>
</ExitStandbyResult>
<ResponseMetadata>
<RequestId>7c6e177f-f082-11e1-ac58-3714bEXAMPLE</RequestId>
</ResponseMetadata>
</ExitStandbyResponse>"""
TERMINATE_INSTANCES_TEMPLATE = """<TerminateInstanceInAutoScalingGroupResponse xmlns="http://autoscaling.amazonaws.com/doc/2011-01-01/">
<TerminateInstanceInAutoScalingGroupResult>
<Activity>
<ActivityId>35b5c464-0b63-2fc7-1611-467d4a7f2497EXAMPLE</ActivityId>
<AutoScalingGroupName>{{ group_name }}</AutoScalingGroupName>
{% if should_decrement %}
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was taken out of service in response to a user request, shrinking the capacity from {{ original_size }} to {{ desired_capacity }}.</Cause>
{% else %}
<Cause>At {{ timestamp }} instance {{ instance.instance.id }} was taken out of service in response to a user request.</Cause>
{% endif %}
<Description>Terminating EC2 instance: {{ instance.instance.id }}</Description>
<Progress>0</Progress>
<StartTime>{{ timestamp }}</StartTime>
<Details>{&quot;Subnet ID&quot;:&quot;??&quot;,&quot;Availability Zone&quot;:&quot;{{ instance.instance.placement }}&quot;}</Details>
<StatusCode>InProgress</StatusCode>
</Activity>
</TerminateInstanceInAutoScalingGroupResult>
<ResponseMetadata>
<RequestId>a1ba8fb9-31d6-4d9a-ace1-a7f76749df11EXAMPLE</RequestId>
</ResponseMetadata>
</TerminateInstanceInAutoScalingGroupResponse>"""

View File

@ -1006,11 +1006,11 @@ class LambdaBackend(BaseBackend):
return True
return False
def add_policy_statement(self, function_name, raw):
def add_permission(self, function_name, raw):
fn = self.get_function(function_name)
fn.policy.add_statement(raw)
def del_policy_statement(self, function_name, sid, revision=""):
def remove_permission(self, function_name, sid, revision=""):
fn = self.get_function(function_name)
fn.policy.del_statement(sid, revision)

View File

@ -146,7 +146,7 @@ class LambdaResponse(BaseResponse):
function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name):
statement = self.body
self.lambda_backend.add_policy_statement(function_name, statement)
self.lambda_backend.add_permission(function_name, statement)
return 200, {}, json.dumps({"Statement": statement})
else:
return 404, {}, "{}"
@ -166,9 +166,7 @@ class LambdaResponse(BaseResponse):
statement_id = path.split("/")[-1].split("?")[0]
revision = querystring.get("RevisionId", "")
if self.lambda_backend.get_function(function_name):
self.lambda_backend.del_policy_statement(
function_name, statement_id, revision
)
self.lambda_backend.remove_permission(function_name, statement_id, revision)
return 204, {}, "{}"
else:
return 404, {}, "{}"
@ -184,9 +182,9 @@ class LambdaResponse(BaseResponse):
function_name, qualifier, self.body, self.headers, response_headers
)
if payload:
if request.headers["X-Amz-Invocation-Type"] == "Event":
if request.headers.get("X-Amz-Invocation-Type") == "Event":
status_code = 202
elif request.headers["X-Amz-Invocation-Type"] == "DryRun":
elif request.headers.get("X-Amz-Invocation-Type") == "DryRun":
status_code = 204
else:
status_code = 200

View File

@ -22,6 +22,14 @@ class Dimension(object):
self.name = name
self.value = value
def __eq__(self, item):
if isinstance(item, Dimension):
return self.name == item.name and self.value == item.value
return False
def __ne__(self, item): # Only needed on Py2; Py3 defines it implicitly
return self != item
def daterange(start, stop, step=timedelta(days=1), inclusive=False):
"""
@ -124,6 +132,17 @@ class MetricDatum(BaseModel):
Dimension(dimension["Name"], dimension["Value"]) for dimension in dimensions
]
def filter(self, namespace, name, dimensions):
if namespace and namespace != self.namespace:
return False
if name and name != self.name:
return False
if dimensions and any(
Dimension(d["Name"], d["Value"]) not in self.dimensions for d in dimensions
):
return False
return True
class Dashboard(BaseModel):
def __init__(self, name, body):
@ -202,6 +221,15 @@ class CloudWatchBackend(BaseBackend):
self.metric_data = []
self.paged_metric_data = {}
@property
# Retrieve a list of all OOTB metrics that are provided by metrics providers
# Computed on the fly
def aws_metric_data(self):
md = []
for name, service in metric_providers.items():
md.extend(service.get_cloudwatch_metrics())
return md
def put_metric_alarm(
self,
name,
@ -295,6 +323,43 @@ class CloudWatchBackend(BaseBackend):
)
)
def get_metric_data(self, queries, start_time, end_time):
period_data = [
md for md in self.metric_data if start_time <= md.timestamp <= end_time
]
results = []
for query in queries:
query_ns = query["metric_stat._metric._namespace"]
query_name = query["metric_stat._metric._metric_name"]
query_data = [
md
for md in period_data
if md.namespace == query_ns and md.name == query_name
]
metric_values = [m.value for m in query_data]
result_vals = []
stat = query["metric_stat._stat"]
if len(metric_values) > 0:
if stat == "Average":
result_vals.append(sum(metric_values) / len(metric_values))
elif stat == "Minimum":
result_vals.append(min(metric_values))
elif stat == "Maximum":
result_vals.append(max(metric_values))
elif stat == "Sum":
result_vals.append(sum(metric_values))
label = query["metric_stat._metric._metric_name"] + " " + stat
results.append(
{
"id": query["id"],
"label": label,
"vals": result_vals,
"timestamps": [datetime.now() for _ in result_vals],
}
)
return results
def get_metric_statistics(
self, namespace, metric_name, start_time, end_time, period, stats
):
@ -334,7 +399,7 @@ class CloudWatchBackend(BaseBackend):
return data
def get_all_metrics(self):
return self.metric_data
return self.metric_data + self.aws_metric_data
def put_dashboard(self, name, body):
self.dashboards[name] = Dashboard(name, body)
@ -386,7 +451,7 @@ class CloudWatchBackend(BaseBackend):
self.alarms[alarm_name].update_state(reason, reason_data, state_value)
def list_metrics(self, next_token, namespace, metric_name):
def list_metrics(self, next_token, namespace, metric_name, dimensions):
if next_token:
if next_token not in self.paged_metric_data:
raise RESTError(
@ -397,15 +462,16 @@ class CloudWatchBackend(BaseBackend):
del self.paged_metric_data[next_token] # Cant reuse same token twice
return self._get_paginated(metrics)
else:
metrics = self.get_filtered_metrics(metric_name, namespace)
metrics = self.get_filtered_metrics(metric_name, namespace, dimensions)
return self._get_paginated(metrics)
def get_filtered_metrics(self, metric_name, namespace):
def get_filtered_metrics(self, metric_name, namespace, dimensions):
metrics = self.get_all_metrics()
if namespace:
metrics = [md for md in metrics if md.namespace == namespace]
if metric_name:
metrics = [md for md in metrics if md.name == metric_name]
metrics = [
md
for md in metrics
if md.filter(namespace=namespace, name=metric_name, dimensions=dimensions)
]
return metrics
def _get_paginated(self, metrics):
@ -431,7 +497,9 @@ class LogGroup(BaseModel):
properties = cloudformation_json["Properties"]
log_group_name = properties["LogGroupName"]
tags = properties.get("Tags", {})
return logs_backends[region_name].create_log_group(log_group_name, tags)
return logs_backends[region_name].create_log_group(
log_group_name, tags, **properties
)
cloudwatch_backends = {}
@ -443,3 +511,8 @@ for region in Session().get_available_regions(
cloudwatch_backends[region] = CloudWatchBackend()
for region in Session().get_available_regions("cloudwatch", partition_name="aws-cn"):
cloudwatch_backends[region] = CloudWatchBackend()
# List of services that provide OOTB CW metrics
# See the S3Backend constructor for an example
# TODO: We might have to separate this out per region for non-global services
metric_providers = {}

View File

@ -92,6 +92,18 @@ class CloudWatchResponse(BaseResponse):
template = self.response_template(PUT_METRIC_DATA_TEMPLATE)
return template.render()
@amzn_request_id
def get_metric_data(self):
start = dtparse(self._get_param("StartTime"))
end = dtparse(self._get_param("EndTime"))
queries = self._get_list_prefix("MetricDataQueries.member")
results = self.cloudwatch_backend.get_metric_data(
start_time=start, end_time=end, queries=queries
)
template = self.response_template(GET_METRIC_DATA_TEMPLATE)
return template.render(results=results)
@amzn_request_id
def get_metric_statistics(self):
namespace = self._get_param("Namespace")
@ -124,9 +136,10 @@ class CloudWatchResponse(BaseResponse):
def list_metrics(self):
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
dimensions = self._get_multi_param("Dimensions.member")
next_token = self._get_param("NextToken")
next_token, metrics = self.cloudwatch_backend.list_metrics(
next_token, namespace, metric_name
next_token, namespace, metric_name, dimensions
)
template = self.response_template(LIST_METRICS_TEMPLATE)
return template.render(metrics=metrics, next_token=next_token)
@ -285,6 +298,35 @@ PUT_METRIC_DATA_TEMPLATE = """<PutMetricDataResponse xmlns="http://monitoring.am
</ResponseMetadata>
</PutMetricDataResponse>"""
GET_METRIC_DATA_TEMPLATE = """<GetMetricDataResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ResponseMetadata>
<RequestId>
{{ request_id }}
</RequestId>
</ResponseMetadata>
<GetMetricDataResult>
<MetricDataResults>
{% for result in results %}
<member>
<Id>{{ result.id }}</Id>
<Label>{{ result.label }}</Label>
<StatusCode>Complete</StatusCode>
<Timestamps>
{% for val in result.timestamps %}
<member>{{ val }}</member>
{% endfor %}
</Timestamps>
<Values>
{% for val in result.vals %}
<member>{{ val }}</member>
{% endfor %}
</Values>
</member>
{% endfor %}
</MetricDataResults>
</GetMetricDataResult>
</GetMetricDataResponse>"""
GET_METRIC_STATISTICS_TEMPLATE = """<GetMetricStatisticsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<ResponseMetadata>
<RequestId>
@ -342,7 +384,7 @@ LIST_METRICS_TEMPLATE = """<ListMetricsResponse xmlns="http://monitoring.amazona
</member>
{% endfor %}
</Dimensions>
<MetricName>{{ metric.name }}</MetricName>
<MetricName>Metric:{{ metric.name }}</MetricName>
<Namespace>{{ metric.namespace }}</Namespace>
</member>
{% endfor %}

View File

@ -1,5 +1,5 @@
from moto.core.utils import get_random_hex
from uuid import uuid4
def get_random_identity_id(region):
return "{0}:{1}".format(region, get_random_hex(length=19))
return "{0}:{1}".format(region, uuid4())

View File

@ -12,6 +12,8 @@ from io import BytesIO
from collections import defaultdict
from botocore.handlers import BUILTIN_HANDLERS
from botocore.awsrequest import AWSResponse
from six.moves.urllib.parse import urlparse
from werkzeug.wrappers import Request
import mock
from moto import settings
@ -175,6 +177,26 @@ class CallbackResponse(responses.CallbackResponse):
"""
Need to override this so we can pass decode_content=False
"""
if not isinstance(request, Request):
url = urlparse(request.url)
if request.body is None:
body = None
elif isinstance(request.body, six.text_type):
body = six.BytesIO(six.b(request.body))
else:
body = six.BytesIO(request.body)
req = Request.from_values(
path="?".join([url.path, url.query]),
input_stream=body,
content_length=request.headers.get("Content-Length"),
content_type=request.headers.get("Content-Type"),
method=request.method,
base_url="{scheme}://{netloc}".format(
scheme=url.scheme, netloc=url.netloc
),
headers=[(k, v) for k, v in six.iteritems(request.headers)],
)
request = req
headers = self.get_headers()
result = self.callback(request)

View File

@ -1,5 +1,5 @@
from __future__ import unicode_literals
from .models import dynamodb_backends as dynamodb_backends2
from moto.dynamodb2.models import dynamodb_backends as dynamodb_backends2
from ..core.models import base_decorator, deprecated_base_decorator
dynamodb_backend2 = dynamodb_backends2["us-east-1"]

View File

@ -2,9 +2,132 @@ class InvalidIndexNameError(ValueError):
pass
class InvalidUpdateExpression(ValueError):
pass
class MockValidationException(ValueError):
def __init__(self, message):
self.exception_msg = message
class ItemSizeTooLarge(Exception):
message = "Item size has exceeded the maximum allowed size"
class InvalidUpdateExpressionInvalidDocumentPath(MockValidationException):
invalid_update_expression_msg = (
"The document path provided in the update expression is invalid for update"
)
def __init__(self):
super(InvalidUpdateExpressionInvalidDocumentPath, self).__init__(
self.invalid_update_expression_msg
)
class InvalidUpdateExpression(MockValidationException):
invalid_update_expr_msg = "Invalid UpdateExpression: {update_expression_error}"
def __init__(self, update_expression_error):
self.update_expression_error = update_expression_error
super(InvalidUpdateExpression, self).__init__(
self.invalid_update_expr_msg.format(
update_expression_error=update_expression_error
)
)
class AttributeDoesNotExist(MockValidationException):
attr_does_not_exist_msg = (
"The provided expression refers to an attribute that does not exist in the item"
)
def __init__(self):
super(AttributeDoesNotExist, self).__init__(self.attr_does_not_exist_msg)
class ExpressionAttributeNameNotDefined(InvalidUpdateExpression):
name_not_defined_msg = "An expression attribute name used in the document path is not defined; attribute name: {n}"
def __init__(self, attribute_name):
self.not_defined_attribute_name = attribute_name
super(ExpressionAttributeNameNotDefined, self).__init__(
self.name_not_defined_msg.format(n=attribute_name)
)
class AttributeIsReservedKeyword(InvalidUpdateExpression):
attribute_is_keyword_msg = (
"Attribute name is a reserved keyword; reserved keyword: {keyword}"
)
def __init__(self, keyword):
self.keyword = keyword
super(AttributeIsReservedKeyword, self).__init__(
self.attribute_is_keyword_msg.format(keyword=keyword)
)
class ExpressionAttributeValueNotDefined(InvalidUpdateExpression):
attr_value_not_defined_msg = "An expression attribute value used in expression is not defined; attribute value: {attribute_value}"
def __init__(self, attribute_value):
self.attribute_value = attribute_value
super(ExpressionAttributeValueNotDefined, self).__init__(
self.attr_value_not_defined_msg.format(attribute_value=attribute_value)
)
class UpdateExprSyntaxError(InvalidUpdateExpression):
update_expr_syntax_error_msg = "Syntax error; {error_detail}"
def __init__(self, error_detail):
self.error_detail = error_detail
super(UpdateExprSyntaxError, self).__init__(
self.update_expr_syntax_error_msg.format(error_detail=error_detail)
)
class InvalidTokenException(UpdateExprSyntaxError):
token_detail_msg = 'token: "{token}", near: "{near}"'
def __init__(self, token, near):
self.token = token
self.near = near
super(InvalidTokenException, self).__init__(
self.token_detail_msg.format(token=token, near=near)
)
class InvalidExpressionAttributeNameKey(MockValidationException):
invalid_expr_attr_name_msg = (
'ExpressionAttributeNames contains invalid key: Syntax error; key: "{key}"'
)
def __init__(self, key):
self.key = key
super(InvalidExpressionAttributeNameKey, self).__init__(
self.invalid_expr_attr_name_msg.format(key=key)
)
class ItemSizeTooLarge(MockValidationException):
item_size_too_large_msg = "Item size has exceeded the maximum allowed size"
def __init__(self):
super(ItemSizeTooLarge, self).__init__(self.item_size_too_large_msg)
class ItemSizeToUpdateTooLarge(MockValidationException):
item_size_to_update_too_large_msg = (
"Item size to update has exceeded the maximum allowed size"
)
def __init__(self):
super(ItemSizeToUpdateTooLarge, self).__init__(
self.item_size_to_update_too_large_msg
)
class IncorrectOperandType(InvalidUpdateExpression):
inv_operand_msg = "Incorrect operand type for operator or function; operator or function: {f}, operand type: {t}"
def __init__(self, operator_or_function, operand_type):
self.operator_or_function = operator_or_function
self.operand_type = operand_type
super(IncorrectOperandType, self).__init__(
self.inv_operand_msg.format(f=operator_or_function, t=operand_type)
)

View File

@ -6,7 +6,6 @@ import decimal
import json
import re
import uuid
import six
from boto3 import Session
from botocore.exceptions import ParamValidationError
@ -14,10 +13,17 @@ from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time
from moto.core.exceptions import JsonRESTError
from .comparisons import get_comparison_func
from .comparisons import get_filter_expression
from .comparisons import get_expected
from .exceptions import InvalidIndexNameError, InvalidUpdateExpression, ItemSizeTooLarge
from moto.dynamodb2.comparisons import get_filter_expression
from moto.dynamodb2.comparisons import get_expected
from moto.dynamodb2.exceptions import (
InvalidIndexNameError,
ItemSizeTooLarge,
ItemSizeToUpdateTooLarge,
)
from moto.dynamodb2.models.utilities import bytesize, attribute_is_list
from moto.dynamodb2.models.dynamo_type import DynamoType
from moto.dynamodb2.parsing.expressions import UpdateExpressionParser
from moto.dynamodb2.parsing.validators import UpdateExpressionValidator
class DynamoJsonEncoder(json.JSONEncoder):
@ -30,223 +36,6 @@ def dynamo_json_dump(dynamo_object):
return json.dumps(dynamo_object, cls=DynamoJsonEncoder)
def bytesize(val):
return len(str(val).encode("utf-8"))
def attribute_is_list(attr):
"""
Checks if attribute denotes a list, and returns the name of the list and the given list index if so
:param attr: attr or attr[index]
:return: attr, index or None
"""
list_index_update = re.match("(.+)\\[([0-9]+)\\]", attr)
if list_index_update:
attr = list_index_update.group(1)
return attr, list_index_update.group(2) if list_index_update else None
class DynamoType(object):
"""
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
"""
def __init__(self, type_as_dict):
if type(type_as_dict) == DynamoType:
self.type = type_as_dict.type
self.value = type_as_dict.value
else:
self.type = list(type_as_dict)[0]
self.value = list(type_as_dict.values())[0]
if self.is_list():
self.value = [DynamoType(val) for val in self.value]
elif self.is_map():
self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
def get(self, key):
if not key:
return self
else:
key_head = key.split(".")[0]
key_tail = ".".join(key.split(".")[1:])
if key_head not in self.value:
self.value[key_head] = DynamoType({"NONE": None})
return self.value[key_head].get(key_tail)
def set(self, key, new_value, index=None):
if index:
index = int(index)
if type(self.value) is not list:
raise InvalidUpdateExpression
if index >= len(self.value):
self.value.append(new_value)
# {'L': [DynamoType, ..]} ==> DynamoType.set()
self.value[min(index, len(self.value) - 1)].set(key, new_value)
else:
attr = (key or "").split(".").pop(0)
attr, list_index = attribute_is_list(attr)
if not key:
# {'S': value} ==> {'S': new_value}
self.type = new_value.type
self.value = new_value.value
else:
if attr not in self.value: # nonexistingattribute
type_of_new_attr = "M" if "." in key else new_value.type
self.value[attr] = DynamoType({type_of_new_attr: {}})
# {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value)
self.value[attr].set(
".".join(key.split(".")[1:]), new_value, list_index
)
def delete(self, key, index=None):
if index:
if not key:
if int(index) < len(self.value):
del self.value[int(index)]
elif "." in key:
self.value[int(index)].delete(".".join(key.split(".")[1:]))
else:
self.value[int(index)].delete(key)
else:
attr = key.split(".")[0]
attr, list_index = attribute_is_list(attr)
if list_index:
self.value[attr].delete(".".join(key.split(".")[1:]), list_index)
elif "." in key:
self.value[attr].delete(".".join(key.split(".")[1:]))
else:
self.value.pop(key)
def filter(self, projection_expressions):
nested_projections = [
expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr
]
if self.is_map():
expressions_to_delete = []
for attr in self.value:
if (
attr not in projection_expressions
and attr not in nested_projections
):
expressions_to_delete.append(attr)
elif attr in nested_projections:
relevant_expressions = [
expr[len(attr + ".") :]
for expr in projection_expressions
if expr.startswith(attr + ".")
]
self.value[attr].filter(relevant_expressions)
for expr in expressions_to_delete:
self.value.pop(expr)
def __hash__(self):
return hash((self.type, self.value))
def __eq__(self, other):
return self.type == other.type and self.value == other.value
def __ne__(self, other):
return self.type != other.type or self.value != other.value
def __lt__(self, other):
return self.cast_value < other.cast_value
def __le__(self, other):
return self.cast_value <= other.cast_value
def __gt__(self, other):
return self.cast_value > other.cast_value
def __ge__(self, other):
return self.cast_value >= other.cast_value
def __repr__(self):
return "DynamoType: {0}".format(self.to_json())
@property
def cast_value(self):
if self.is_number():
try:
return int(self.value)
except ValueError:
return float(self.value)
elif self.is_set():
sub_type = self.type[0]
return set([DynamoType({sub_type: v}).cast_value for v in self.value])
elif self.is_list():
return [DynamoType(v).cast_value for v in self.value]
elif self.is_map():
return dict([(k, DynamoType(v).cast_value) for k, v in self.value.items()])
else:
return self.value
def child_attr(self, key):
"""
Get Map or List children by key. str for Map, int for List.
Returns DynamoType or None.
"""
if isinstance(key, six.string_types) and self.is_map():
if "." in key and key.split(".")[0] in self.value:
return self.value[key.split(".")[0]].child_attr(
".".join(key.split(".")[1:])
)
elif "." not in key and key in self.value:
return DynamoType(self.value[key])
if isinstance(key, int) and self.is_list():
idx = key
if 0 <= idx < len(self.value):
return DynamoType(self.value[idx])
return None
def size(self):
if self.is_number():
value_size = len(str(self.value))
elif self.is_set():
sub_type = self.type[0]
value_size = sum([DynamoType({sub_type: v}).size() for v in self.value])
elif self.is_list():
value_size = sum([v.size() for v in self.value])
elif self.is_map():
value_size = sum(
[bytesize(k) + DynamoType(v).size() for k, v in self.value.items()]
)
elif type(self.value) == bool:
value_size = 1
else:
value_size = bytesize(self.value)
return value_size
def to_json(self):
return {self.type: self.value}
def compare(self, range_comparison, range_objs):
"""
Compares this type against comparison filters
"""
range_values = [obj.cast_value for obj in range_objs]
comparison_func = get_comparison_func(range_comparison)
return comparison_func(self.cast_value, *range_values)
def is_number(self):
return self.type == "N"
def is_set(self):
return self.type == "SS" or self.type == "NS" or self.type == "BS"
def is_list(self):
return self.type == "L"
def is_map(self):
return self.type == "M"
def same_type(self, other):
return self.type == other.type
# https://github.com/spulec/moto/issues/1874
# Ensure that the total size of an item does not exceed 400kb
class LimitedSizeDict(dict):
@ -285,6 +74,9 @@ class Item(BaseModel):
def __repr__(self):
return "Item: {0}".format(self.to_json())
def size(self):
return sum(bytesize(key) + value.size() for key, value in self.attrs.items())
def to_json(self):
attributes = {}
for attribute_key, attribute in self.attrs.items():
@ -367,7 +159,10 @@ class Item(BaseModel):
if "." in key and attr not in self.attrs:
raise ValueError # Setting nested attr not allowed if first attr does not exist yet
elif attr not in self.attrs:
try:
self.attrs[attr] = dyn_value # set new top-level attribute
except ItemSizeTooLarge:
raise ItemSizeToUpdateTooLarge()
else:
self.attrs[attr].set(
".".join(key.split(".")[1:]), dyn_value, list_index
@ -1129,6 +924,14 @@ class Table(BaseModel):
break
last_evaluated_key = None
size_limit = 1000000 # DynamoDB has a 1MB size limit
item_size = sum(res.size() for res in results)
if item_size > size_limit:
item_size = idx = 0
while item_size + results[idx].size() < size_limit:
item_size += results[idx].size()
idx += 1
limit = min(limit, idx) if limit else idx
if limit and len(results) > limit:
results = results[:limit]
last_evaluated_key = {self.hash_key_attr: results[-1].hash_key}
@ -1414,6 +1217,13 @@ class DynamoDBBackend(BaseBackend):
):
table = self.get_table(table_name)
# Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c`
if update_expression:
# Parse expression to get validation errors
update_expression_ast = UpdateExpressionParser.make(update_expression)
update_expression = re.sub(r"\s*([=\+-])\s*", "\\1", update_expression)
if all([table.hash_key_attr in key, table.range_key_attr in key]):
# Covers cases where table has hash and range keys, ``key`` param
# will be a dict
@ -1456,6 +1266,12 @@ class DynamoDBBackend(BaseBackend):
item = table.get_item(hash_value, range_value)
if update_expression:
UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=expression_attribute_names,
expression_attribute_values=expression_attribute_values,
item=item,
).validate()
item.update(
update_expression,
expression_attribute_names,

View File

@ -0,0 +1,237 @@
import six
from moto.dynamodb2.comparisons import get_comparison_func
from moto.dynamodb2.exceptions import InvalidUpdateExpression
from moto.dynamodb2.models.utilities import attribute_is_list, bytesize
class DynamoType(object):
"""
http://docs.aws.amazon.com/amazondynamodb/latest/developerguide/DataModel.html#DataModelDataTypes
"""
def __init__(self, type_as_dict):
if type(type_as_dict) == DynamoType:
self.type = type_as_dict.type
self.value = type_as_dict.value
else:
self.type = list(type_as_dict)[0]
self.value = list(type_as_dict.values())[0]
if self.is_list():
self.value = [DynamoType(val) for val in self.value]
elif self.is_map():
self.value = dict((k, DynamoType(v)) for k, v in self.value.items())
def get(self, key):
if not key:
return self
else:
key_head = key.split(".")[0]
key_tail = ".".join(key.split(".")[1:])
if key_head not in self.value:
self.value[key_head] = DynamoType({"NONE": None})
return self.value[key_head].get(key_tail)
def set(self, key, new_value, index=None):
if index:
index = int(index)
if type(self.value) is not list:
raise InvalidUpdateExpression
if index >= len(self.value):
self.value.append(new_value)
# {'L': [DynamoType, ..]} ==> DynamoType.set()
self.value[min(index, len(self.value) - 1)].set(key, new_value)
else:
attr = (key or "").split(".").pop(0)
attr, list_index = attribute_is_list(attr)
if not key:
# {'S': value} ==> {'S': new_value}
self.type = new_value.type
self.value = new_value.value
else:
if attr not in self.value: # nonexistingattribute
type_of_new_attr = "M" if "." in key else new_value.type
self.value[attr] = DynamoType({type_of_new_attr: {}})
# {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value)
self.value[attr].set(
".".join(key.split(".")[1:]), new_value, list_index
)
def delete(self, key, index=None):
if index:
if not key:
if int(index) < len(self.value):
del self.value[int(index)]
elif "." in key:
self.value[int(index)].delete(".".join(key.split(".")[1:]))
else:
self.value[int(index)].delete(key)
else:
attr = key.split(".")[0]
attr, list_index = attribute_is_list(attr)
if list_index:
self.value[attr].delete(".".join(key.split(".")[1:]), list_index)
elif "." in key:
self.value[attr].delete(".".join(key.split(".")[1:]))
else:
self.value.pop(key)
def filter(self, projection_expressions):
nested_projections = [
expr[0 : expr.index(".")] for expr in projection_expressions if "." in expr
]
if self.is_map():
expressions_to_delete = []
for attr in self.value:
if (
attr not in projection_expressions
and attr not in nested_projections
):
expressions_to_delete.append(attr)
elif attr in nested_projections:
relevant_expressions = [
expr[len(attr + ".") :]
for expr in projection_expressions
if expr.startswith(attr + ".")
]
self.value[attr].filter(relevant_expressions)
for expr in expressions_to_delete:
self.value.pop(expr)
def __hash__(self):
return hash((self.type, self.value))
def __eq__(self, other):
return self.type == other.type and self.value == other.value
def __ne__(self, other):
return self.type != other.type or self.value != other.value
def __lt__(self, other):
return self.cast_value < other.cast_value
def __le__(self, other):
return self.cast_value <= other.cast_value
def __gt__(self, other):
return self.cast_value > other.cast_value
def __ge__(self, other):
return self.cast_value >= other.cast_value
def __repr__(self):
return "DynamoType: {0}".format(self.to_json())
def __add__(self, other):
if self.type != other.type:
raise TypeError("Different types of operandi is not allowed.")
if self.type == "N":
return DynamoType({"N": "{v}".format(v=int(self.value) + int(other.value))})
else:
raise TypeError("Sum only supported for Numbers.")
def __sub__(self, other):
if self.type != other.type:
raise TypeError("Different types of operandi is not allowed.")
if self.type == "N":
return DynamoType({"N": "{v}".format(v=int(self.value) - int(other.value))})
else:
raise TypeError("Sum only supported for Numbers.")
def __getitem__(self, item):
if isinstance(item, six.string_types):
# If our DynamoType is a map it should be subscriptable with a key
if self.type == "M":
return self.value[item]
elif isinstance(item, int):
# If our DynamoType is a list is should be subscriptable with an index
if self.type == "L":
return self.value[item]
raise TypeError(
"This DynamoType {dt} is not subscriptable by a {it}".format(
dt=self.type, it=type(item)
)
)
@property
def cast_value(self):
if self.is_number():
try:
return int(self.value)
except ValueError:
return float(self.value)
elif self.is_set():
sub_type = self.type[0]
return set([DynamoType({sub_type: v}).cast_value for v in self.value])
elif self.is_list():
return [DynamoType(v).cast_value for v in self.value]
elif self.is_map():
return dict([(k, DynamoType(v).cast_value) for k, v in self.value.items()])
else:
return self.value
def child_attr(self, key):
"""
Get Map or List children by key. str for Map, int for List.
Returns DynamoType or None.
"""
if isinstance(key, six.string_types) and self.is_map():
if "." in key and key.split(".")[0] in self.value:
return self.value[key.split(".")[0]].child_attr(
".".join(key.split(".")[1:])
)
elif "." not in key and key in self.value:
return DynamoType(self.value[key])
if isinstance(key, int) and self.is_list():
idx = key
if 0 <= idx < len(self.value):
return DynamoType(self.value[idx])
return None
def size(self):
if self.is_number():
value_size = len(str(self.value))
elif self.is_set():
sub_type = self.type[0]
value_size = sum([DynamoType({sub_type: v}).size() for v in self.value])
elif self.is_list():
value_size = sum([v.size() for v in self.value])
elif self.is_map():
value_size = sum(
[bytesize(k) + DynamoType(v).size() for k, v in self.value.items()]
)
elif type(self.value) == bool:
value_size = 1
else:
value_size = bytesize(self.value)
return value_size
def to_json(self):
return {self.type: self.value}
def compare(self, range_comparison, range_objs):
"""
Compares this type against comparison filters
"""
range_values = [obj.cast_value for obj in range_objs]
comparison_func = get_comparison_func(range_comparison)
return comparison_func(self.cast_value, *range_values)
def is_number(self):
return self.type == "N"
def is_set(self):
return self.type == "SS" or self.type == "NS" or self.type == "BS"
def is_list(self):
return self.type == "L"
def is_map(self):
return self.type == "M"
def same_type(self, other):
return self.type == other.type

View File

@ -0,0 +1,17 @@
import re
def bytesize(val):
return len(str(val).encode("utf-8"))
def attribute_is_list(attr):
"""
Checks if attribute denotes a list, and returns the name of the list and the given list index if so
:param attr: attr or attr[index]
:return: attr, index or None
"""
list_index_update = re.match("(.+)\\[([0-9]+)\\]", attr)
if list_index_update:
attr = list_index_update.group(1)
return attr, list_index_update.group(2) if list_index_update else None

View File

@ -0,0 +1,23 @@
# Parsing dev documentation
Parsing happens in a structured manner and happens in different phases.
This document explains these phases.
## 1) Expression gets parsed into a tokenlist (tokenized)
A string gets parsed from left to right and gets converted into a list of tokens.
The tokens are available in `tokens.py`.
## 2) Tokenlist get transformed to expression tree (AST)
This is the parsing of the token list. This parsing will result in an Abstract Syntax Tree (AST).
The different node types are available in `ast_nodes.py`. The AST is a representation that has all
the information that is in the expression but its tree form allows processing it in a structured manner.
## 3) The AST gets validated (full semantic correctness)
The AST is used for validation. The paths and attributes are validated to be correct. At the end of the
validation all the values will be resolved.
## 4) Update Expression gets executed using the validated AST
Finally the AST is used to execute the update expression. There should be no reason for this step to fail
since validation has completed. Due to this we have the update expressions behaving atomically (i.e. all the
actions of the update expresion are performed or none of them are performed).

View File

View File

@ -0,0 +1,360 @@
import abc
from abc import abstractmethod
from collections import deque
import six
from moto.dynamodb2.models import DynamoType
@six.add_metaclass(abc.ABCMeta)
class Node:
def __init__(self, children=None):
self.type = self.__class__.__name__
assert children is None or isinstance(children, list)
self.children = children
self.parent = None
if isinstance(children, list):
for child in children:
if isinstance(child, Node):
child.set_parent(self)
def set_parent(self, parent_node):
self.parent = parent_node
class LeafNode(Node):
"""A LeafNode is a Node where none of the children are Nodes themselves."""
def __init__(self, children=None):
super(LeafNode, self).__init__(children)
@six.add_metaclass(abc.ABCMeta)
class Expression(Node):
"""
Abstract Syntax Tree representing the expression
For the Grammar start here and jump down into the classes at the righ-hand side to look further. Nodes marked with
a star are abstract and won't appear in the final AST.
Expression* => UpdateExpression
Expression* => ConditionExpression
"""
class UpdateExpression(Expression):
"""
UpdateExpression => UpdateExpressionClause*
UpdateExpression => UpdateExpressionClause* UpdateExpression
"""
@six.add_metaclass(abc.ABCMeta)
class UpdateExpressionClause(UpdateExpression):
"""
UpdateExpressionClause* => UpdateExpressionSetClause
UpdateExpressionClause* => UpdateExpressionRemoveClause
UpdateExpressionClause* => UpdateExpressionAddClause
UpdateExpressionClause* => UpdateExpressionDeleteClause
"""
class UpdateExpressionSetClause(UpdateExpressionClause):
"""
UpdateExpressionSetClause => SET SetActions
"""
class UpdateExpressionSetActions(UpdateExpressionClause):
"""
UpdateExpressionSetClause => SET SetActions
SetActions => SetAction
SetActions => SetAction , SetActions
"""
class UpdateExpressionSetAction(UpdateExpressionClause):
"""
SetAction => Path = Value
"""
class UpdateExpressionRemoveActions(UpdateExpressionClause):
"""
UpdateExpressionSetClause => REMOVE RemoveActions
RemoveActions => RemoveAction
RemoveActions => RemoveAction , RemoveActions
"""
class UpdateExpressionRemoveAction(UpdateExpressionClause):
"""
RemoveAction => Path
"""
class UpdateExpressionAddActions(UpdateExpressionClause):
"""
UpdateExpressionAddClause => ADD RemoveActions
AddActions => AddAction
AddActions => AddAction , AddActions
"""
class UpdateExpressionAddAction(UpdateExpressionClause):
"""
AddAction => Path Value
"""
class UpdateExpressionDeleteActions(UpdateExpressionClause):
"""
UpdateExpressionDeleteClause => DELETE RemoveActions
DeleteActions => DeleteAction
DeleteActions => DeleteAction , DeleteActions
"""
class UpdateExpressionDeleteAction(UpdateExpressionClause):
"""
DeleteAction => Path Value
"""
class UpdateExpressionPath(UpdateExpressionClause):
pass
class UpdateExpressionValue(UpdateExpressionClause):
"""
Value => Operand
Value => Operand + Value
Value => Operand - Value
"""
class UpdateExpressionGroupedValue(UpdateExpressionClause):
"""
GroupedValue => ( Value )
"""
class UpdateExpressionRemoveClause(UpdateExpressionClause):
"""
UpdateExpressionRemoveClause => REMOVE RemoveActions
"""
class UpdateExpressionAddClause(UpdateExpressionClause):
"""
UpdateExpressionAddClause => ADD AddActions
"""
class UpdateExpressionDeleteClause(UpdateExpressionClause):
"""
UpdateExpressionDeleteClause => DELETE DeleteActions
"""
class ExpressionPathDescender(Node):
"""Node identifying descender into nested structure (.) in expression"""
class ExpressionSelector(LeafNode):
"""Node identifying selector [selection_index] in expresion"""
def __init__(self, selection_index):
try:
super(ExpressionSelector, self).__init__(children=[int(selection_index)])
except ValueError:
assert (
False
), "Expression selector must be an int, this is a bug in the moto library."
def get_index(self):
return self.children[0]
class ExpressionAttribute(LeafNode):
"""An attribute identifier as used in the DDB item"""
def __init__(self, attribute):
super(ExpressionAttribute, self).__init__(children=[attribute])
def get_attribute_name(self):
return self.children[0]
class ExpressionAttributeName(LeafNode):
"""An ExpressionAttributeName is an alias for an attribute identifier"""
def __init__(self, attribute_name):
super(ExpressionAttributeName, self).__init__(children=[attribute_name])
def get_attribute_name_placeholder(self):
return self.children[0]
class ExpressionAttributeValue(LeafNode):
"""An ExpressionAttributeValue is an alias for an value"""
def __init__(self, value):
super(ExpressionAttributeValue, self).__init__(children=[value])
def get_value_name(self):
return self.children[0]
class ExpressionValueOperator(LeafNode):
"""An ExpressionValueOperator is an operation that works on 2 values"""
def __init__(self, value):
super(ExpressionValueOperator, self).__init__(children=[value])
def get_operator(self):
return self.children[0]
class UpdateExpressionFunction(Node):
"""
A Node representing a function of an Update Expression. The first child is the function name the others are the
arguments.
"""
def get_function_name(self):
return self.children[0]
def get_nth_argument(self, n=1):
"""Return nth element where n is a 1-based index."""
assert n >= 1
return self.children[n]
class DDBTypedValue(Node):
"""
A node representing a DDBTyped value. This can be any structure as supported by DyanmoDB. The node only has 1 child
which is the value of type `DynamoType`.
"""
def __init__(self, value):
assert isinstance(value, DynamoType), "DDBTypedValue must be of DynamoType"
super(DDBTypedValue, self).__init__(children=[value])
def get_value(self):
return self.children[0]
class NoneExistingPath(LeafNode):
"""A placeholder for Paths that did not exist in the Item."""
def __init__(self, creatable=False):
super(NoneExistingPath, self).__init__(children=[creatable])
def is_creatable(self):
"""Can this path be created if need be. For example path creating element in a dictionary or creating a new
attribute under root level of an item."""
return self.children[0]
class DepthFirstTraverser(object):
"""
Helper class that allows depth first traversal and to implement custom processing for certain AST nodes. The
processor of a node must return the new resulting node. This node will be placed in the tree. Processing of a
node using this traverser should therefore only transform child nodes. The returned node will get the same parent
as the node before processing had.
"""
@abstractmethod
def _processing_map(self):
"""
A map providing a processing function per node class type to a function that takes in a Node object and
processes it. A Node can only be processed by a single function and they are considered in order. Therefore if
multiple classes from a single class hierarchy strain are used the more specific classes have to be put before
the less specific ones. That requires overriding `nodes_to_be_processed`. If no multiple classes form a single
class hierarchy strain are used the default implementation of `nodes_to_be_processed` should be OK.
Returns:
dict: Mapping a Node Class to a processing function.
"""
pass
def nodes_to_be_processed(self):
"""Cached accessor for getting Node types that need to be processed."""
return tuple(k for k in self._processing_map().keys())
def process(self, node):
"""Process a Node"""
for class_key, processor in self._processing_map().items():
if isinstance(node, class_key):
return processor(node)
def pre_processing_of_child(self, parent_node, child_id):
"""Hook that is called pre-processing of the child at position `child_id`"""
pass
def traverse_node_recursively(self, node, child_id=-1):
"""
Traverse nodes depth first processing nodes bottom up (if root node is considered the top).
Args:
node(Node): The node which is the last node to be processed but which allows to identify all the
work (which is in the children)
child_id(int): The index in the list of children from the parent that this node corresponds to
Returns:
Node: The node of the new processed AST
"""
if isinstance(node, Node):
parent_node = node.parent
if node.children is not None:
for i, child_node in enumerate(node.children):
self.pre_processing_of_child(node, i)
self.traverse_node_recursively(child_node, i)
# noinspection PyTypeChecker
if isinstance(node, self.nodes_to_be_processed()):
node = self.process(node)
node.parent = parent_node
parent_node.children[child_id] = node
return node
def traverse(self, node):
return self.traverse_node_recursively(node)
class NodeDepthLeftTypeFetcher(object):
"""Helper class to fetch a node of a specific type. Depth left-first traversal"""
def __init__(self, node_type, root_node):
assert issubclass(node_type, Node)
self.node_type = node_type
self.root_node = root_node
self.queue = deque()
self.add_nodes_left_to_right_depth_first(self.root_node)
def add_nodes_left_to_right_depth_first(self, node):
if isinstance(node, Node) and node.children is not None:
for child_node in node.children:
self.add_nodes_left_to_right_depth_first(child_node)
self.queue.append(child_node)
self.queue.append(node)
def __iter__(self):
return self
def next(self):
return self.__next__()
def __next__(self):
while len(self.queue) > 0:
candidate = self.queue.popleft()
if isinstance(candidate, self.node_type):
return candidate
else:
raise StopIteration

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,29 @@
class ReservedKeywords(list):
"""
DynamoDB has an extensive list of keywords. Keywords are considered when validating the expression Tree.
Not earlier since an update expression like "SET path = VALUE 1" fails with:
'Invalid UpdateExpression: Syntax error; token: "1", near: "VALUE 1"'
"""
KEYWORDS = None
@classmethod
def get_reserved_keywords(cls):
if cls.KEYWORDS is None:
cls.KEYWORDS = cls._get_reserved_keywords()
return cls.KEYWORDS
@classmethod
def _get_reserved_keywords(cls):
"""
Get a list of reserved keywords of DynamoDB
"""
try:
import importlib.resources as pkg_resources
except ImportError:
import importlib_resources as pkg_resources
reserved_keywords = pkg_resources.read_text(
"moto.dynamodb2.parsing", "reserved_keywords.txt"
)
return reserved_keywords.split()

View File

@ -0,0 +1,573 @@
ABORT
ABSOLUTE
ACTION
ADD
AFTER
AGENT
AGGREGATE
ALL
ALLOCATE
ALTER
ANALYZE
AND
ANY
ARCHIVE
ARE
ARRAY
AS
ASC
ASCII
ASENSITIVE
ASSERTION
ASYMMETRIC
AT
ATOMIC
ATTACH
ATTRIBUTE
AUTH
AUTHORIZATION
AUTHORIZE
AUTO
AVG
BACK
BACKUP
BASE
BATCH
BEFORE
BEGIN
BETWEEN
BIGINT
BINARY
BIT
BLOB
BLOCK
BOOLEAN
BOTH
BREADTH
BUCKET
BULK
BY
BYTE
CALL
CALLED
CALLING
CAPACITY
CASCADE
CASCADED
CASE
CAST
CATALOG
CHAR
CHARACTER
CHECK
CLASS
CLOB
CLOSE
CLUSTER
CLUSTERED
CLUSTERING
CLUSTERS
COALESCE
COLLATE
COLLATION
COLLECTION
COLUMN
COLUMNS
COMBINE
COMMENT
COMMIT
COMPACT
COMPILE
COMPRESS
CONDITION
CONFLICT
CONNECT
CONNECTION
CONSISTENCY
CONSISTENT
CONSTRAINT
CONSTRAINTS
CONSTRUCTOR
CONSUMED
CONTINUE
CONVERT
COPY
CORRESPONDING
COUNT
COUNTER
CREATE
CROSS
CUBE
CURRENT
CURSOR
CYCLE
DATA
DATABASE
DATE
DATETIME
DAY
DEALLOCATE
DEC
DECIMAL
DECLARE
DEFAULT
DEFERRABLE
DEFERRED
DEFINE
DEFINED
DEFINITION
DELETE
DELIMITED
DEPTH
DEREF
DESC
DESCRIBE
DESCRIPTOR
DETACH
DETERMINISTIC
DIAGNOSTICS
DIRECTORIES
DISABLE
DISCONNECT
DISTINCT
DISTRIBUTE
DO
DOMAIN
DOUBLE
DROP
DUMP
DURATION
DYNAMIC
EACH
ELEMENT
ELSE
ELSEIF
EMPTY
ENABLE
END
EQUAL
EQUALS
ERROR
ESCAPE
ESCAPED
EVAL
EVALUATE
EXCEEDED
EXCEPT
EXCEPTION
EXCEPTIONS
EXCLUSIVE
EXEC
EXECUTE
EXISTS
EXIT
EXPLAIN
EXPLODE
EXPORT
EXPRESSION
EXTENDED
EXTERNAL
EXTRACT
FAIL
FALSE
FAMILY
FETCH
FIELDS
FILE
FILTER
FILTERING
FINAL
FINISH
FIRST
FIXED
FLATTERN
FLOAT
FOR
FORCE
FOREIGN
FORMAT
FORWARD
FOUND
FREE
FROM
FULL
FUNCTION
FUNCTIONS
GENERAL
GENERATE
GET
GLOB
GLOBAL
GO
GOTO
GRANT
GREATER
GROUP
GROUPING
HANDLER
HASH
HAVE
HAVING
HEAP
HIDDEN
HOLD
HOUR
IDENTIFIED
IDENTITY
IF
IGNORE
IMMEDIATE
IMPORT
IN
INCLUDING
INCLUSIVE
INCREMENT
INCREMENTAL
INDEX
INDEXED
INDEXES
INDICATOR
INFINITE
INITIALLY
INLINE
INNER
INNTER
INOUT
INPUT
INSENSITIVE
INSERT
INSTEAD
INT
INTEGER
INTERSECT
INTERVAL
INTO
INVALIDATE
IS
ISOLATION
ITEM
ITEMS
ITERATE
JOIN
KEY
KEYS
LAG
LANGUAGE
LARGE
LAST
LATERAL
LEAD
LEADING
LEAVE
LEFT
LENGTH
LESS
LEVEL
LIKE
LIMIT
LIMITED
LINES
LIST
LOAD
LOCAL
LOCALTIME
LOCALTIMESTAMP
LOCATION
LOCATOR
LOCK
LOCKS
LOG
LOGED
LONG
LOOP
LOWER
MAP
MATCH
MATERIALIZED
MAX
MAXLEN
MEMBER
MERGE
METHOD
METRICS
MIN
MINUS
MINUTE
MISSING
MOD
MODE
MODIFIES
MODIFY
MODULE
MONTH
MULTI
MULTISET
NAME
NAMES
NATIONAL
NATURAL
NCHAR
NCLOB
NEW
NEXT
NO
NONE
NOT
NULL
NULLIF
NUMBER
NUMERIC
OBJECT
OF
OFFLINE
OFFSET
OLD
ON
ONLINE
ONLY
OPAQUE
OPEN
OPERATOR
OPTION
OR
ORDER
ORDINALITY
OTHER
OTHERS
OUT
OUTER
OUTPUT
OVER
OVERLAPS
OVERRIDE
OWNER
PAD
PARALLEL
PARAMETER
PARAMETERS
PARTIAL
PARTITION
PARTITIONED
PARTITIONS
PATH
PERCENT
PERCENTILE
PERMISSION
PERMISSIONS
PIPE
PIPELINED
PLAN
POOL
POSITION
PRECISION
PREPARE
PRESERVE
PRIMARY
PRIOR
PRIVATE
PRIVILEGES
PROCEDURE
PROCESSED
PROJECT
PROJECTION
PROPERTY
PROVISIONING
PUBLIC
PUT
QUERY
QUIT
QUORUM
RAISE
RANDOM
RANGE
RANK
RAW
READ
READS
REAL
REBUILD
RECORD
RECURSIVE
REDUCE
REF
REFERENCE
REFERENCES
REFERENCING
REGEXP
REGION
REINDEX
RELATIVE
RELEASE
REMAINDER
RENAME
REPEAT
REPLACE
REQUEST
RESET
RESIGNAL
RESOURCE
RESPONSE
RESTORE
RESTRICT
RESULT
RETURN
RETURNING
RETURNS
REVERSE
REVOKE
RIGHT
ROLE
ROLES
ROLLBACK
ROLLUP
ROUTINE
ROW
ROWS
RULE
RULES
SAMPLE
SATISFIES
SAVE
SAVEPOINT
SCAN
SCHEMA
SCOPE
SCROLL
SEARCH
SECOND
SECTION
SEGMENT
SEGMENTS
SELECT
SELF
SEMI
SENSITIVE
SEPARATE
SEQUENCE
SERIALIZABLE
SESSION
SET
SETS
SHARD
SHARE
SHARED
SHORT
SHOW
SIGNAL
SIMILAR
SIZE
SKEWED
SMALLINT
SNAPSHOT
SOME
SOURCE
SPACE
SPACES
SPARSE
SPECIFIC
SPECIFICTYPE
SPLIT
SQL
SQLCODE
SQLERROR
SQLEXCEPTION
SQLSTATE
SQLWARNING
START
STATE
STATIC
STATUS
STORAGE
STORE
STORED
STREAM
STRING
STRUCT
STYLE
SUB
SUBMULTISET
SUBPARTITION
SUBSTRING
SUBTYPE
SUM
SUPER
SYMMETRIC
SYNONYM
SYSTEM
TABLE
TABLESAMPLE
TEMP
TEMPORARY
TERMINATED
TEXT
THAN
THEN
THROUGHPUT
TIME
TIMESTAMP
TIMEZONE
TINYINT
TO
TOKEN
TOTAL
TOUCH
TRAILING
TRANSACTION
TRANSFORM
TRANSLATE
TRANSLATION
TREAT
TRIGGER
TRIM
TRUE
TRUNCATE
TTL
TUPLE
TYPE
UNDER
UNDO
UNION
UNIQUE
UNIT
UNKNOWN
UNLOGGED
UNNEST
UNPROCESSED
UNSIGNED
UNTIL
UPDATE
UPPER
URL
USAGE
USE
USER
USERS
USING
UUID
VACUUM
VALUE
VALUED
VALUES
VARCHAR
VARIABLE
VARIANCE
VARINT
VARYING
VIEW
VIEWS
VIRTUAL
VOID
WAIT
WHEN
WHENEVER
WHERE
WHILE
WINDOW
WITH
WITHIN
WITHOUT
WORK
WRAPPED
WRITE
YEAR
ZONE

View File

@ -0,0 +1,223 @@
import re
import sys
from moto.dynamodb2.exceptions import (
InvalidTokenException,
InvalidExpressionAttributeNameKey,
)
class Token(object):
_TOKEN_INSTANCE = None
MINUS_SIGN = "-"
PLUS_SIGN = "+"
SPACE_SIGN = " "
EQUAL_SIGN = "="
OPEN_ROUND_BRACKET = "("
CLOSE_ROUND_BRACKET = ")"
COMMA = ","
SPACE = " "
DOT = "."
OPEN_SQUARE_BRACKET = "["
CLOSE_SQUARE_BRACKET = "]"
SPECIAL_CHARACTERS = [
MINUS_SIGN,
PLUS_SIGN,
SPACE_SIGN,
EQUAL_SIGN,
OPEN_ROUND_BRACKET,
CLOSE_ROUND_BRACKET,
COMMA,
SPACE,
DOT,
OPEN_SQUARE_BRACKET,
CLOSE_SQUARE_BRACKET,
]
# Attribute: an identifier that is an attribute
ATTRIBUTE = 0
# Place holder for attribute name
ATTRIBUTE_NAME = 1
# Placeholder for attribute value starts with :
ATTRIBUTE_VALUE = 2
# WhiteSpace shall be grouped together
WHITESPACE = 3
# Placeholder for a number
NUMBER = 4
PLACEHOLDER_NAMES = {
ATTRIBUTE: "Attribute",
ATTRIBUTE_NAME: "AttributeName",
ATTRIBUTE_VALUE: "AttributeValue",
WHITESPACE: "Whitespace",
NUMBER: "Number",
}
def __init__(self, token_type, value):
assert (
token_type in self.SPECIAL_CHARACTERS
or token_type in self.PLACEHOLDER_NAMES
)
self.type = token_type
self.value = value
def __repr__(self):
if isinstance(self.type, int):
return 'Token("{tt}", "{tv}")'.format(
tt=self.PLACEHOLDER_NAMES[self.type], tv=self.value
)
else:
return 'Token("{tt}", "{tv}")'.format(tt=self.type, tv=self.value)
def __eq__(self, other):
return self.type == other.type and self.value == other.value
class ExpressionTokenizer(object):
"""
Takes a string and returns a list of tokens. While attribute names in DynamoDB must be between 1 and 255 characters
long there are no other restrictions for attribute names. For expressions however there are additional rules. If an
attribute name does not adhere then it must be passed via an ExpressionAttributeName. This tokenizer is aware of the
rules of Expression attributes.
We consider a Token as a tuple which has the tokenType
From https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.ExpressionAttributeNames.html
1) If an attribute name begins with a number or contains a space, a special character, or a reserved word, you
must use an expression attribute name to replace that attribute's name in the expression.
=> So spaces,+,- or other special characters do identify tokens in update expressions
2) When using a dot (.) in an attribute name you must use expression-attribute-names. A dot in an expression
will be interpreted as a separator in a document path
3) For a nested structure if you want to use expression_attribute_names you must specify one per part of the
path. Since for members of expression_attribute_names the . is part of the name
"""
@classmethod
def is_simple_token_character(cls, character):
return character.isalnum() or character in ("_", ":", "#")
@classmethod
def is_possible_token_boundary(cls, character):
return (
character in Token.SPECIAL_CHARACTERS
or not cls.is_simple_token_character(character)
)
@classmethod
def is_expression_attribute(cls, input_string):
return re.compile("^[a-zA-Z][a-zA-Z0-9_]*$").match(input_string) is not None
@classmethod
def is_expression_attribute_name(cls, input_string):
"""
https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/Expressions.ExpressionAttributeNames.html
An expression attribute name must begin with a pound sign (#), and be followed by one or more alphanumeric
characters.
"""
return input_string.startswith("#") and cls.is_expression_attribute(
input_string[1:]
)
@classmethod
def is_expression_attribute_value(cls, input_string):
return re.compile("^:[a-zA-Z0-9_]*$").match(input_string) is not None
def raise_unexpected_token(self):
"""If during parsing an unexpected token is encountered"""
if len(self.token_list) == 0:
near = ""
else:
if len(self.token_list) == 1:
near = self.token_list[-1].value
else:
if self.token_list[-1].type == Token.WHITESPACE:
# Last token was whitespace take 2nd last token value as well to help User orientate
near = self.token_list[-2].value + self.token_list[-1].value
else:
near = self.token_list[-1].value
problematic_token = self.staged_characters[0]
raise InvalidTokenException(problematic_token, near + self.staged_characters)
def __init__(self, input_expression_str):
self.input_expression_str = input_expression_str
self.token_list = []
self.staged_characters = ""
@classmethod
def is_py2(cls):
return sys.version_info[0] == 2
@classmethod
def make_list(cls, input_expression_str):
if cls.is_py2():
pass
else:
assert isinstance(input_expression_str, str)
return ExpressionTokenizer(input_expression_str)._make_list()
def add_token(self, token_type, token_value):
self.token_list.append(Token(token_type, token_value))
def add_token_from_stage(self, token_type):
self.add_token(token_type, self.staged_characters)
self.staged_characters = ""
@classmethod
def is_numeric(cls, input_str):
return re.compile("[0-9]+").match(input_str) is not None
def process_staged_characters(self):
if len(self.staged_characters) == 0:
return
if self.staged_characters.startswith("#"):
if self.is_expression_attribute_name(self.staged_characters):
self.add_token_from_stage(Token.ATTRIBUTE_NAME)
else:
raise InvalidExpressionAttributeNameKey(self.staged_characters)
elif self.is_numeric(self.staged_characters):
self.add_token_from_stage(Token.NUMBER)
elif self.is_expression_attribute(self.staged_characters):
self.add_token_from_stage(Token.ATTRIBUTE)
elif self.is_expression_attribute_value(self.staged_characters):
self.add_token_from_stage(Token.ATTRIBUTE_VALUE)
else:
self.raise_unexpected_token()
def _make_list(self):
"""
Just go through characters if a character is not a token boundary stage it for adding it as a grouped token
later if it is a tokenboundary process staged characters and then process the token boundary as well.
"""
for character in self.input_expression_str:
if not self.is_possible_token_boundary(character):
self.staged_characters += character
else:
self.process_staged_characters()
if character == Token.SPACE:
if (
len(self.token_list) > 0
and self.token_list[-1].type == Token.WHITESPACE
):
self.token_list[-1].value = (
self.token_list[-1].value + character
)
else:
self.add_token(Token.WHITESPACE, character)
elif character in Token.SPECIAL_CHARACTERS:
self.add_token(character, character)
elif not self.is_simple_token_character(character):
self.staged_characters += character
self.raise_unexpected_token()
else:
raise NotImplementedError(
"Encountered character which was not implemented : " + character
)
self.process_staged_characters()
return self.token_list

View File

@ -0,0 +1,341 @@
"""
See docstring class Validator below for more details on validation
"""
from abc import abstractmethod
from copy import deepcopy
from moto.dynamodb2.exceptions import (
AttributeIsReservedKeyword,
ExpressionAttributeValueNotDefined,
AttributeDoesNotExist,
ExpressionAttributeNameNotDefined,
IncorrectOperandType,
InvalidUpdateExpressionInvalidDocumentPath,
)
from moto.dynamodb2.models import DynamoType
from moto.dynamodb2.parsing.ast_nodes import (
ExpressionAttribute,
UpdateExpressionPath,
UpdateExpressionSetAction,
UpdateExpressionAddAction,
UpdateExpressionDeleteAction,
UpdateExpressionRemoveAction,
DDBTypedValue,
ExpressionAttributeValue,
ExpressionAttributeName,
DepthFirstTraverser,
NoneExistingPath,
UpdateExpressionFunction,
ExpressionPathDescender,
UpdateExpressionValue,
ExpressionValueOperator,
ExpressionSelector,
)
from moto.dynamodb2.parsing.reserved_keywords import ReservedKeywords
class ExpressionAttributeValueProcessor(DepthFirstTraverser):
def __init__(self, expression_attribute_values):
self.expression_attribute_values = expression_attribute_values
def _processing_map(self):
return {
ExpressionAttributeValue: self.replace_expression_attribute_value_with_value
}
def replace_expression_attribute_value_with_value(self, node):
"""A node representing an Expression Attribute Value. Resolve and replace value"""
assert isinstance(node, ExpressionAttributeValue)
attribute_value_name = node.get_value_name()
try:
target = self.expression_attribute_values[attribute_value_name]
except KeyError:
raise ExpressionAttributeValueNotDefined(
attribute_value=attribute_value_name
)
return DDBTypedValue(DynamoType(target))
class ExpressionAttributeResolvingProcessor(DepthFirstTraverser):
def _processing_map(self):
return {
UpdateExpressionSetAction: self.disable_resolving,
UpdateExpressionPath: self.process_expression_path_node,
}
def __init__(self, expression_attribute_names, item):
self.expression_attribute_names = expression_attribute_names
self.item = item
self.resolving = False
def pre_processing_of_child(self, parent_node, child_id):
"""
We have to enable resolving if we are processing a child of UpdateExpressionSetAction that is not first.
Because first argument is path to be set, 2nd argument would be the value.
"""
if isinstance(
parent_node,
(
UpdateExpressionSetAction,
UpdateExpressionRemoveAction,
UpdateExpressionDeleteAction,
UpdateExpressionAddAction,
),
):
if child_id == 0:
self.resolving = False
else:
self.resolving = True
def disable_resolving(self, node=None):
self.resolving = False
return node
def process_expression_path_node(self, node):
"""Resolve ExpressionAttribute if not part of a path and resolving is enabled."""
if self.resolving:
return self.resolve_expression_path(node)
else:
# Still resolve but return original note to make sure path is correct Just make sure nodes are creatable.
result_node = self.resolve_expression_path(node)
if (
isinstance(result_node, NoneExistingPath)
and not result_node.is_creatable()
):
raise InvalidUpdateExpressionInvalidDocumentPath()
return node
def resolve_expression_path(self, node):
assert isinstance(node, UpdateExpressionPath)
target = deepcopy(self.item.attrs)
for child in node.children:
# First replace placeholder with attribute_name
attr_name = None
if isinstance(child, ExpressionAttributeName):
attr_placeholder = child.get_attribute_name_placeholder()
try:
attr_name = self.expression_attribute_names[attr_placeholder]
except KeyError:
raise ExpressionAttributeNameNotDefined(attr_placeholder)
elif isinstance(child, ExpressionAttribute):
attr_name = child.get_attribute_name()
self.raise_exception_if_keyword(attr_name)
if attr_name is not None:
# Resolv attribute_name
try:
target = target[attr_name]
except (KeyError, TypeError):
if child == node.children[-1]:
return NoneExistingPath(creatable=True)
return NoneExistingPath()
else:
if isinstance(child, ExpressionPathDescender):
continue
elif isinstance(child, ExpressionSelector):
index = child.get_index()
if target.is_list():
try:
target = target[index]
except IndexError:
# When a list goes out of bounds when assigning that is no problem when at the assignment
# side. It will just append to the list.
if child == node.children[-1]:
return NoneExistingPath(creatable=True)
return NoneExistingPath()
else:
raise InvalidUpdateExpressionInvalidDocumentPath
else:
raise NotImplementedError(
"Path resolution for {t}".format(t=type(child))
)
return DDBTypedValue(DynamoType(target))
@classmethod
def raise_exception_if_keyword(cls, attribute):
if attribute.upper() in ReservedKeywords.get_reserved_keywords():
raise AttributeIsReservedKeyword(attribute)
class UpdateExpressionFunctionEvaluator(DepthFirstTraverser):
"""
At time of writing there are only 2 functions for DDB UpdateExpressions. They both are specific to the SET
expression as per the official AWS docs:
https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/
Expressions.UpdateExpressions.html#Expressions.UpdateExpressions.SET
"""
def _processing_map(self):
return {UpdateExpressionFunction: self.process_function}
def process_function(self, node):
assert isinstance(node, UpdateExpressionFunction)
function_name = node.get_function_name()
first_arg = node.get_nth_argument(1)
second_arg = node.get_nth_argument(2)
if function_name == "if_not_exists":
if isinstance(first_arg, NoneExistingPath):
result = second_arg
else:
result = first_arg
assert isinstance(result, (DDBTypedValue, NoneExistingPath))
return result
elif function_name == "list_append":
first_arg = self.get_list_from_ddb_typed_value(first_arg, function_name)
second_arg = self.get_list_from_ddb_typed_value(second_arg, function_name)
for list_element in second_arg.value:
first_arg.value.append(list_element)
return DDBTypedValue(first_arg)
else:
raise NotImplementedError(
"Unsupported function for moto {name}".format(name=function_name)
)
@classmethod
def get_list_from_ddb_typed_value(cls, node, function_name):
assert isinstance(node, DDBTypedValue)
dynamo_value = node.get_value()
assert isinstance(dynamo_value, DynamoType)
if not dynamo_value.is_list():
raise IncorrectOperandType(function_name, dynamo_value.type)
return dynamo_value
class NoneExistingPathChecker(DepthFirstTraverser):
"""
Pass through the AST and make sure there are no none-existing paths.
"""
def _processing_map(self):
return {NoneExistingPath: self.raise_none_existing_path}
def raise_none_existing_path(self, node):
raise AttributeDoesNotExist
class ExecuteOperations(DepthFirstTraverser):
def _processing_map(self):
return {UpdateExpressionValue: self.process_update_expression_value}
def process_update_expression_value(self, node):
"""
If an UpdateExpressionValue only has a single child the node will be replaced with the childe.
Otherwise it has 3 children and the middle one is an ExpressionValueOperator which details how to combine them
Args:
node(Node):
Returns:
Node: The resulting node of the operation if present or the child.
"""
assert isinstance(node, UpdateExpressionValue)
if len(node.children) == 1:
return node.children[0]
elif len(node.children) == 3:
operator_node = node.children[1]
assert isinstance(operator_node, ExpressionValueOperator)
operator = operator_node.get_operator()
left_operand = self.get_dynamo_value_from_ddb_typed_value(node.children[0])
right_operand = self.get_dynamo_value_from_ddb_typed_value(node.children[2])
if operator == "+":
return self.get_sum(left_operand, right_operand)
elif operator == "-":
return self.get_subtraction(left_operand, right_operand)
else:
raise NotImplementedError(
"Moto does not support operator {operator}".format(
operator=operator
)
)
else:
raise NotImplementedError(
"UpdateExpressionValue only has implementations for 1 or 3 children."
)
@classmethod
def get_dynamo_value_from_ddb_typed_value(cls, node):
assert isinstance(node, DDBTypedValue)
dynamo_value = node.get_value()
assert isinstance(dynamo_value, DynamoType)
return dynamo_value
@classmethod
def get_sum(cls, left_operand, right_operand):
"""
Args:
left_operand(DynamoType):
right_operand(DynamoType):
Returns:
DDBTypedValue:
"""
try:
return DDBTypedValue(left_operand + right_operand)
except TypeError:
raise IncorrectOperandType("+", left_operand.type)
@classmethod
def get_subtraction(cls, left_operand, right_operand):
"""
Args:
left_operand(DynamoType):
right_operand(DynamoType):
Returns:
DDBTypedValue:
"""
try:
return DDBTypedValue(left_operand - right_operand)
except TypeError:
raise IncorrectOperandType("-", left_operand.type)
class Validator(object):
"""
A validator is used to validate expressions which are passed in as an AST.
"""
def __init__(
self, expression, expression_attribute_names, expression_attribute_values, item
):
"""
Besides validation the Validator should also replace referenced parts of an item which is cheapest upon
validation.
Args:
expression(Node): The root node of the AST representing the expression to be validated
expression_attribute_names(ExpressionAttributeNames):
expression_attribute_values(ExpressionAttributeValues):
item(Item): The item which will be updated (pointed to by Key of update_item)
"""
self.expression_attribute_names = expression_attribute_names
self.expression_attribute_values = expression_attribute_values
self.item = item
self.processors = self.get_ast_processors()
self.node_to_validate = deepcopy(expression)
@abstractmethod
def get_ast_processors(self):
"""Get the different processors that go through the AST tree and processes the nodes."""
def validate(self):
n = self.node_to_validate
for processor in self.processors:
n = processor.traverse(n)
return n
class UpdateExpressionValidator(Validator):
def get_ast_processors(self):
"""Get the different processors that go through the AST tree and processes the nodes."""
processors = [
ExpressionAttributeValueProcessor(self.expression_attribute_values),
ExpressionAttributeResolvingProcessor(
self.expression_attribute_names, self.item
),
UpdateExpressionFunctionEvaluator(),
NoneExistingPathChecker(),
ExecuteOperations(),
]
return processors

View File

@ -9,8 +9,8 @@ import six
from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores, amzn_request_id
from .exceptions import InvalidIndexNameError, InvalidUpdateExpression, ItemSizeTooLarge
from .models import dynamodb_backends, dynamo_json_dump
from .exceptions import InvalidIndexNameError, ItemSizeTooLarge, MockValidationException
from moto.dynamodb2.models import dynamodb_backends, dynamo_json_dump
TRANSACTION_MAX_ITEMS = 25
@ -92,16 +92,24 @@ class DynamoHandler(BaseResponse):
def list_tables(self):
body = self.body
limit = body.get("Limit", 100)
if body.get("ExclusiveStartTableName"):
last = body.get("ExclusiveStartTableName")
start = list(self.dynamodb_backend.tables.keys()).index(last) + 1
all_tables = list(self.dynamodb_backend.tables.keys())
exclusive_start_table_name = body.get("ExclusiveStartTableName")
if exclusive_start_table_name:
try:
last_table_index = all_tables.index(exclusive_start_table_name)
except ValueError:
start = len(all_tables)
else:
start = last_table_index + 1
else:
start = 0
all_tables = list(self.dynamodb_backend.tables.keys())
if limit:
tables = all_tables[start : start + limit]
else:
tables = all_tables[start:]
response = {"TableNames": tables}
if limit and len(all_tables) > start + limit:
response["LastEvaluatedTableName"] = tables[-1]
@ -298,7 +306,7 @@ class DynamoHandler(BaseResponse):
)
except ItemSizeTooLarge:
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
return self.error(er, ItemSizeTooLarge.message)
return self.error(er, ItemSizeTooLarge.item_size_too_large_msg)
except KeyError as ke:
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
return self.error(er, ke.args[0])
@ -462,8 +470,10 @@ class DynamoHandler(BaseResponse):
for k, v in six.iteritems(self.body.get("ExpressionAttributeNames", {}))
)
if " AND " in key_condition_expression:
expressions = key_condition_expression.split(" AND ", 1)
if " and " in key_condition_expression.lower():
expressions = re.split(
" AND ", key_condition_expression, maxsplit=1, flags=re.IGNORECASE
)
index_hash_key = [key for key in index if key["KeyType"] == "HASH"][0]
hash_key_var = reverse_attribute_lookup.get(
@ -748,11 +758,6 @@ class DynamoHandler(BaseResponse):
expression_attribute_names = self.body.get("ExpressionAttributeNames", {})
expression_attribute_values = self.body.get("ExpressionAttributeValues", {})
# Support spaces between operators in an update expression
# E.g. `a = b + c` -> `a=b+c`
if update_expression:
update_expression = re.sub(r"\s*([=\+-])\s*", "\\1", update_expression)
try:
item = self.dynamodb_backend.update_item(
name,
@ -764,15 +769,9 @@ class DynamoHandler(BaseResponse):
expected,
condition_expression,
)
except InvalidUpdateExpression:
except MockValidationException as mve:
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
return self.error(
er,
"The document path provided in the update expression is invalid for update",
)
except ItemSizeTooLarge:
er = "com.amazonaws.dynamodb.v20111205#ValidationException"
return self.error(er, ItemSizeTooLarge.message)
return self.error(er, mve.exception_msg)
except ValueError:
er = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException"
return self.error(

View File

@ -7,7 +7,7 @@ import base64
from boto3 import Session
from moto.core import BaseBackend, BaseModel
from moto.dynamodb2.models import dynamodb_backends
from moto.dynamodb2.models import dynamodb_backends, DynamoJsonEncoder
class ShardIterator(BaseModel):
@ -137,7 +137,7 @@ class DynamoDBStreamsBackend(BaseBackend):
def get_records(self, iterator_arn, limit):
shard_iterator = self.shard_iterators[iterator_arn]
return json.dumps(shard_iterator.get(limit))
return json.dumps(shard_iterator.get(limit), cls=DynamoJsonEncoder)
dynamodbstreams_backends = {}

View File

@ -231,6 +231,14 @@ class InvalidVolumeAttachmentError(EC2ClientError):
)
class VolumeInUseError(EC2ClientError):
def __init__(self, volume_id, instance_id):
super(VolumeInUseError, self).__init__(
"VolumeInUse",
"Volume {0} is currently attached to {1}".format(volume_id, instance_id),
)
class InvalidDomainError(EC2ClientError):
def __init__(self, domain):
super(InvalidDomainError, self).__init__(

View File

@ -70,6 +70,7 @@ from .exceptions import (
InvalidSubnetIdError,
InvalidSubnetRangeError,
InvalidVolumeIdError,
VolumeInUseError,
InvalidVolumeAttachmentError,
InvalidVpcCidrBlockAssociationIdError,
InvalidVPCPeeringConnectionIdError,
@ -556,6 +557,10 @@ class Instance(TaggedEC2Resource, BotoInstance):
# worst case we'll get IP address exaustion... rarely
pass
def add_block_device(self, size, device_path):
volume = self.ec2_backend.create_volume(size, self.region_name)
self.ec2_backend.attach_volume(volume.id, self.id, device_path)
def setup_defaults(self):
# Default have an instance with root volume should you not wish to
# override with attach volume cmd.
@ -563,6 +568,7 @@ class Instance(TaggedEC2Resource, BotoInstance):
self.ec2_backend.attach_volume(volume.id, self.id, "/dev/sda1")
def teardown_defaults(self):
if "/dev/sda1" in self.block_device_mapping:
volume_id = self.block_device_mapping["/dev/sda1"].volume_id
self.ec2_backend.detach_volume(volume_id, self.id, "/dev/sda1")
self.ec2_backend.delete_volume(volume_id)
@ -620,6 +626,7 @@ class Instance(TaggedEC2Resource, BotoInstance):
subnet_id=properties.get("SubnetId"),
key_name=properties.get("KeyName"),
private_ip=properties.get("PrivateIpAddress"),
block_device_mappings=properties.get("BlockDeviceMappings", {}),
)
instance = reservation.instances[0]
for tag in properties.get("Tags", []):
@ -775,7 +782,14 @@ class Instance(TaggedEC2Resource, BotoInstance):
if "SubnetId" in nic:
subnet = self.ec2_backend.get_subnet(nic["SubnetId"])
else:
subnet = None
# Get default Subnet
subnet = [
subnet
for subnet in self.ec2_backend.get_all_subnets(
filters={"availabilityZone": self._placement.zone}
)
if subnet.default_for_az
][0]
group_id = nic.get("SecurityGroupId")
group_ids = [group_id] if group_id else []
@ -872,7 +886,14 @@ class InstanceBackend(object):
)
new_reservation.instances.append(new_instance)
new_instance.add_tags(instance_tags)
if "block_device_mappings" in kwargs:
for block_device in kwargs["block_device_mappings"]:
new_instance.add_block_device(
block_device["Ebs"]["VolumeSize"], block_device["DeviceName"]
)
else:
new_instance.setup_defaults()
return new_reservation
def start_instances(self, instance_ids):
@ -936,6 +957,12 @@ class InstanceBackend(object):
value = getattr(instance, key)
return instance, value
def describe_instance_credit_specifications(self, instance_ids):
queried_instances = []
for instance in self.get_multi_instances_by_id(instance_ids):
queried_instances.append(instance)
return queried_instances
def all_instances(self, filters=None):
instances = []
for reservation in self.all_reservations():
@ -1498,6 +1525,11 @@ class RegionsAndZonesBackend(object):
regions.append(Region(region, "ec2.{}.amazonaws.com.cn".format(region)))
zones = {
"af-south-1": [
Zone(region_name="af-south-1", name="af-south-1a", zone_id="afs1-az1"),
Zone(region_name="af-south-1", name="af-south-1b", zone_id="afs1-az2"),
Zone(region_name="af-south-1", name="af-south-1c", zone_id="afs1-az3"),
],
"ap-south-1": [
Zone(region_name="ap-south-1", name="ap-south-1a", zone_id="aps1-az1"),
Zone(region_name="ap-south-1", name="ap-south-1b", zone_id="aps1-az3"),
@ -2385,6 +2417,9 @@ class EBSBackend(object):
def delete_volume(self, volume_id):
if volume_id in self.volumes:
volume = self.volumes[volume_id]
if volume.attachment:
raise VolumeInUseError(volume_id, volume.attachment.instance.id)
return self.volumes.pop(volume_id)
raise InvalidVolumeIdError(volume_id)

File diff suppressed because one or more lines are too long

View File

@ -35,6 +35,7 @@ DESCRIBE_ZONES_RESPONSE = """<DescribeAvailabilityZonesResponse xmlns="http://ec
<zoneName>{{ zone.name }}</zoneName>
<zoneState>available</zoneState>
<regionName>{{ zone.region_name }}</regionName>
<zoneId>{{ zone.zone_id }}</zoneId>
<messageSet/>
</item>
{% endfor %}

View File

@ -168,6 +168,14 @@ class InstanceResponse(BaseResponse):
return template.render(instance=instance, attribute=attribute, value=value)
def describe_instance_credit_specifications(self):
instance_ids = self._get_multi_param("InstanceId")
instance = self.ec2_backend.describe_instance_credit_specifications(
instance_ids
)
template = self.response_template(EC2_DESCRIBE_INSTANCE_CREDIT_SPECIFICATIONS)
return template.render(instances=instance)
def modify_instance_attribute(self):
handlers = [
self._dot_value_instance_attribute_handler,
@ -671,6 +679,18 @@ EC2_DESCRIBE_INSTANCE_ATTRIBUTE = """<DescribeInstanceAttributeResponse xmlns="h
</{{ attribute }}>
</DescribeInstanceAttributeResponse>"""
EC2_DESCRIBE_INSTANCE_CREDIT_SPECIFICATIONS = """<DescribeInstanceCreditSpecificationsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>1b234b5c-d6ef-7gh8-90i1-j2345678901</requestId>
<instanceCreditSpecificationSet>
{% for instance in instances %}
<item>
<instanceId>{{ instance.id }}</instanceId>
<cpuCredits>standard</cpuCredits>
</item>
{% endfor %}
</instanceCreditSpecificationSet>
</DescribeInstanceCreditSpecificationsResponse>"""
EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE = """<DescribeInstanceAttributeResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<instanceId>{{ instance.id }}</instanceId>

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .responses import EC2Response
url_bases = ["https?://ec2\.(.+)\.amazonaws\.com(|\.cn)"]
url_bases = [r"https?://ec2\.(.+)\.amazonaws\.com(|\.cn)"]
url_paths = {"{0}/": EC2Response.dispatch}

View File

@ -604,7 +604,10 @@ class EC2ContainerServiceBackend(BaseBackend):
raise Exception("{0} is not a task_definition".format(task_definition_name))
def run_task(self, cluster_str, task_definition_str, count, overrides, started_by):
if cluster_str:
cluster_name = cluster_str.split("/")[-1]
else:
cluster_name = "default"
if cluster_name in self.clusters:
cluster = self.clusters[cluster_name]
else:

View File

@ -1,6 +1,9 @@
from __future__ import unicode_literals
import datetime
import pytz
from boto.ec2.elb.attributes import (
LbAttributes,
ConnectionSettingAttribute,
@ -83,7 +86,7 @@ class FakeLoadBalancer(BaseModel):
self.zones = zones
self.listeners = []
self.backends = []
self.created_time = datetime.datetime.now()
self.created_time = datetime.datetime.now(pytz.utc)
self.scheme = scheme
self.attributes = FakeLoadBalancer.get_default_attributes()
self.policies = Policies()

View File

@ -442,7 +442,7 @@ DESCRIBE_LOAD_BALANCERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http
{% endfor %}
</SecurityGroups>
<LoadBalancerName>{{ load_balancer.name }}</LoadBalancerName>
<CreatedTime>{{ load_balancer.created_time }}</CreatedTime>
<CreatedTime>{{ load_balancer.created_time.isoformat() }}</CreatedTime>
<HealthCheck>
{% if load_balancer.health_check %}
<Interval>{{ load_balancer.health_check.interval }}</Interval>

View File

@ -26,6 +26,10 @@ class Rule(BaseModel):
self.role_arn = kwargs.get("RoleArn")
self.targets = []
@property
def physical_resource_id(self):
return self.name
# This song and dance for targets is because we need order for Limits and NextTokens, but can't use OrderedDicts
# with Python 2.6, so tracking it with an array it is.
def _check_target_exists(self, target_id):
@ -59,6 +63,14 @@ class Rule(BaseModel):
if index is not None:
self.targets.pop(index)
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn":
return self.arn
raise UnformattedGetAttTemplateException()
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name

View File

@ -34,6 +34,9 @@ class GlueBackend(BaseBackend):
except KeyError:
raise DatabaseNotFoundException(database_name)
def get_databases(self):
return [self.databases[key] for key in self.databases] if self.databases else []
def create_table(self, database_name, table_name, table_input):
database = self.get_database(database_name)

View File

@ -30,6 +30,12 @@ class GlueResponse(BaseResponse):
database = self.glue_backend.get_database(database_name)
return json.dumps({"Database": {"Name": database.name}})
def get_databases(self):
database_list = self.glue_backend.get_databases()
return json.dumps(
{"DatabaseList": [{"Name": database.name} for database in database_list]}
)
def create_table(self):
database_name = self.parameters.get("DatabaseName")
table_input = self.parameters.get("TableInput")

View File

@ -7,10 +7,10 @@ class IoTClientError(JsonRESTError):
class ResourceNotFoundException(IoTClientError):
def __init__(self):
def __init__(self, msg=None):
self.code = 404
super(ResourceNotFoundException, self).__init__(
"ResourceNotFoundException", "The specified resource does not exist"
"ResourceNotFoundException", msg or "The specified resource does not exist"
)

View File

@ -805,6 +805,14 @@ class IoTBackend(BaseBackend):
return thing_names
def list_thing_principals(self, thing_name):
things = [_ for _ in self.things.values() if _.thing_name == thing_name]
if len(things) == 0:
raise ResourceNotFoundException(
"Failed to list principals for thing %s because the thing does not exist in your account"
% thing_name
)
principals = [
k[0] for k, v in self.principal_things.items() if k[1] == thing_name
]

View File

@ -134,7 +134,7 @@ class LogStream:
return None, 0
events = sorted(
filter(filter_func, self.events), key=lambda event: event.timestamp,
filter(filter_func, self.events), key=lambda event: event.timestamp
)
direction, index = get_index_and_direction_from_token(next_token)
@ -169,11 +169,7 @@ class LogStream:
if end_index > final_index:
end_index = final_index
elif end_index < 0:
return (
[],
"b/{:056d}".format(0),
"f/{:056d}".format(0),
)
return ([], "b/{:056d}".format(0), "f/{:056d}".format(0))
events_page = [
event.to_response_dict() for event in events[start_index : end_index + 1]
@ -219,7 +215,7 @@ class LogStream:
class LogGroup:
def __init__(self, region, name, tags):
def __init__(self, region, name, tags, **kwargs):
self.name = name
self.region = region
self.arn = "arn:aws:logs:{region}:1:log-group:{log_group}".format(
@ -228,9 +224,9 @@ class LogGroup:
self.creationTime = int(unix_time_millis())
self.tags = tags
self.streams = dict() # {name: LogStream}
self.retentionInDays = (
None # AWS defaults to Never Expire for log group retention
)
self.retention_in_days = kwargs.get(
"RetentionInDays"
) # AWS defaults to Never Expire for log group retention
def create_log_stream(self, log_stream_name):
if log_stream_name in self.streams:
@ -368,12 +364,12 @@ class LogGroup:
"storedBytes": sum(s.storedBytes for s in self.streams.values()),
}
# AWS only returns retentionInDays if a value is set for the log group (ie. not Never Expire)
if self.retentionInDays:
log_group["retentionInDays"] = self.retentionInDays
if self.retention_in_days:
log_group["retentionInDays"] = self.retention_in_days
return log_group
def set_retention_policy(self, retention_in_days):
self.retentionInDays = retention_in_days
self.retention_in_days = retention_in_days
def list_tags(self):
return self.tags if self.tags else {}
@ -401,10 +397,12 @@ class LogsBackend(BaseBackend):
self.__dict__ = {}
self.__init__(region_name)
def create_log_group(self, log_group_name, tags):
def create_log_group(self, log_group_name, tags, **kwargs):
if log_group_name in self.groups:
raise ResourceAlreadyExistsException()
self.groups[log_group_name] = LogGroup(self.region_name, log_group_name, tags)
self.groups[log_group_name] = LogGroup(
self.region_name, log_group_name, tags, **kwargs
)
return self.groups[log_group_name]
def ensure_log_group(self, log_group_name, tags):

View File

@ -865,7 +865,10 @@ class RDS2Backend(BaseBackend):
def stop_database(self, db_instance_identifier, db_snapshot_identifier=None):
database = self.describe_databases(db_instance_identifier)[0]
# todo: certain rds types not allowed to be stopped at this time.
if database.is_replica or database.multi_az:
# https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/USER_StopInstance.html#USER_StopInstance.Limitations
if database.is_replica or (
database.multi_az and database.engine.lower().startswith("sqlserver")
):
# todo: more db types not supported by stop/start instance api
raise InvalidDBClusterStateFaultError(db_instance_identifier)
if database.status != "available":

View File

@ -22,6 +22,7 @@ import six
from bisect import insort
from moto.core import ACCOUNT_ID, BaseBackend, BaseModel
from moto.core.utils import iso_8601_datetime_with_milliseconds, rfc_1123_datetime
from moto.cloudwatch.models import metric_providers, MetricDatum
from moto.utilities.tagging_service import TaggingService
from .exceptions import (
BucketAlreadyExists,
@ -1158,6 +1159,39 @@ class S3Backend(BaseBackend):
self.account_public_access_block = None
self.tagger = TaggingService()
# Register this class as a CloudWatch Metric Provider
# Must provide a method 'get_cloudwatch_metrics' that will return a list of metrics, based on the data available
metric_providers["S3"] = self
def get_cloudwatch_metrics(self):
metrics = []
for name, bucket in self.buckets.items():
metrics.append(
MetricDatum(
namespace="AWS/S3",
name="BucketSizeBytes",
value=bucket.keys.item_size(),
dimensions=[
{"Name": "StorageType", "Value": "StandardStorage"},
{"Name": "BucketName", "Value": name},
],
timestamp=datetime.datetime.now(),
)
)
metrics.append(
MetricDatum(
namespace="AWS/S3",
name="NumberOfObjects",
value=len(bucket.keys),
dimensions=[
{"Name": "StorageType", "Value": "AllStorageTypes"},
{"Name": "BucketName", "Value": name},
],
timestamp=datetime.datetime.now(),
)
)
return metrics
def create_bucket(self, bucket_name, region_name):
if bucket_name in self.buckets:
raise BucketAlreadyExists(bucket=bucket_name)

View File

@ -7,7 +7,7 @@ import six
from botocore.awsrequest import AWSPreparedRequest
from moto.core.utils import str_to_rfc_1123_datetime, py2_strip_unicode_keys
from six.moves.urllib.parse import parse_qs, urlparse, unquote
from six.moves.urllib.parse import parse_qs, urlparse, unquote, parse_qsl
import xmltodict
@ -775,6 +775,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return 409, {}, template.render(bucket=removed_bucket)
def _bucket_response_post(self, request, body, bucket_name):
response_headers = {}
if not request.headers.get("Content-Length"):
return 411, {}, "Content-Length required"
@ -796,11 +797,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else:
# HTTPretty, build new form object
body = body.decode()
form = {}
for kv in body.split("&"):
k, v = kv.split("=")
form[k] = v
form = dict(parse_qsl(body))
key = form["key"]
if "file" in form:
@ -808,13 +805,23 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
else:
f = request.files["file"].stream.read()
if "success_action_redirect" in form:
response_headers["Location"] = form["success_action_redirect"]
if "success_action_status" in form:
status_code = form["success_action_status"]
elif "success_action_redirect" in form:
status_code = 303
else:
status_code = 204
new_key = self.backend.set_key(bucket_name, key, f)
# Metadata
metadata = metadata_from_headers(form)
new_key.set_metadata(metadata)
return 200, {}, ""
return status_code, response_headers, ""
@staticmethod
def _get_path(request):
@ -1232,9 +1239,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
)
self.backend.set_key_tags(new_key, tagging)
template = self.response_template(S3_OBJECT_RESPONSE)
response_headers.update(new_key.response_dict)
return 200, response_headers, template.render(key=new_key)
return 200, response_headers, ""
def _key_response_head(self, bucket_name, query, key_name, headers):
response_headers = {}
@ -1542,8 +1548,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return 204, {}, ""
version_id = query.get("versionId", [None])[0]
self.backend.delete_key(bucket_name, key_name, version_id=version_id)
template = self.response_template(S3_DELETE_OBJECT_SUCCESS)
return 204, {}, template.render()
return 204, {}, ""
def _complete_multipart_body(self, body):
ps = minidom.parseString(body).getElementsByTagName("Part")
@ -1858,20 +1863,6 @@ S3_DELETE_KEYS_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
{% endfor %}
</DeleteResult>"""
S3_DELETE_OBJECT_SUCCESS = """<DeleteObjectResponse xmlns="http://s3.amazonaws.com/doc/2006-03-01">
<DeleteObjectResponse>
<Code>200</Code>
<Description>OK</Description>
</DeleteObjectResponse>
</DeleteObjectResponse>"""
S3_OBJECT_RESPONSE = """<PutObjectResponse xmlns="http://s3.amazonaws.com/doc/2006-03-01">
<PutObjectResponse>
<ETag>{{ key.etag }}</ETag>
<LastModified>{{ key.last_modified_ISO8601 }}</LastModified>
</PutObjectResponse>
</PutObjectResponse>"""
S3_OBJECT_ACL_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<AccessControlPolicy xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
<Owner>

View File

@ -146,6 +146,12 @@ class _VersionedKeyStore(dict):
for key in self:
yield key, self.getlist(key)
def item_size(self):
size = 0
for val in self.values():
size += sys.getsizeof(val)
return size
items = iteritems = _iteritems
lists = iterlists = _iterlists
values = itervalues = _itervalues

View File

@ -107,6 +107,34 @@ class SecretsManagerBackend(BaseBackend):
return response
def update_secret(
self, secret_id, secret_string=None, secret_binary=None, **kwargs
):
# error if secret does not exist
if secret_id not in self.secrets.keys():
raise SecretNotFoundException()
if "deleted_date" in self.secrets[secret_id]:
raise InvalidRequestException(
"An error occurred (InvalidRequestException) when calling the UpdateSecret operation: "
"You can't perform this operation on the secret because it was marked for deletion."
)
version_id = self._add_secret(
secret_id, secret_string=secret_string, secret_binary=secret_binary
)
response = json.dumps(
{
"ARN": secret_arn(self.region, secret_id),
"Name": secret_id,
"VersionId": version_id,
}
)
return response
def create_secret(
self, name, secret_string=None, secret_binary=None, tags=[], **kwargs
):

View File

@ -29,6 +29,16 @@ class SecretsManagerResponse(BaseResponse):
tags=tags,
)
def update_secret(self):
secret_id = self._get_param("SecretId")
secret_string = self._get_param("SecretString")
secret_binary = self._get_param("SecretBinary")
return secretsmanager_backends[self.region].update_secret(
secret_id=secret_id,
secret_string=secret_string,
secret_binary=secret_binary,
)
def get_random_password(self):
password_length = self._get_param("PasswordLength", if_none=32)
exclude_characters = self._get_param("ExcludeCharacters", if_none="")

View File

@ -651,7 +651,7 @@ class SimpleSystemManagerBackend(BaseBackend):
label.startswith("aws")
or label.startswith("ssm")
or label[:1].isdigit()
or not re.match("^[a-zA-z0-9_\.\-]*$", label)
or not re.match(r"^[a-zA-z0-9_\.\-]*$", label)
):
invalid_labels.append(label)
continue

View File

@ -1,4 +1,5 @@
#!/usr/bin/env python
import json
import os
import subprocess
@ -11,128 +12,142 @@ class Instance(object):
self.instance = instance
def _get_td(self, td):
return self.instance.find('td', attrs={'class': td})
return self.instance.find("td", attrs={"class": td})
def _get_sort(self, td):
return float(self.instance.find('td', attrs={'class': td}).find('span')['sort'])
return float(self.instance.find("td", attrs={"class": td}).find("span")["sort"])
@property
def name(self):
return self._get_td('name').text.strip()
return self._get_td("name").text.strip()
@property
def apiname(self):
return self._get_td('apiname').text.strip()
return self._get_td("apiname").text.strip()
@property
def memory(self):
return self._get_sort('memory')
return self._get_sort("memory")
@property
def computeunits(self):
return self._get_sort('computeunits')
return self._get_sort("computeunits")
@property
def vcpus(self):
return self._get_sort('vcpus')
return self._get_sort("vcpus")
@property
def gpus(self):
return int(self._get_td('gpus').text.strip())
return int(self._get_td("gpus").text.strip())
@property
def fpga(self):
return int(self._get_td('fpga').text.strip())
return int(self._get_td("fpga").text.strip())
@property
def ecu_per_vcpu(self):
return self._get_sort('ecu-per-vcpu')
return self._get_sort("ecu-per-vcpu")
@property
def physical_processor(self):
return self._get_td('physical_processor').text.strip()
return self._get_td("physical_processor").text.strip()
@property
def clock_speed_ghz(self):
return self._get_td('clock_speed_ghz').text.strip()
return self._get_td("clock_speed_ghz").text.strip()
@property
def intel_avx(self):
return self._get_td('intel_avx').text.strip()
return self._get_td("intel_avx").text.strip()
@property
def intel_avx2(self):
return self._get_td('intel_avx2').text.strip()
return self._get_td("intel_avx2").text.strip()
@property
def intel_turbo(self):
return self._get_td('intel_turbo').text.strip()
return self._get_td("intel_turbo").text.strip()
@property
def storage(self):
return self._get_sort('storage')
return self._get_sort("storage")
@property
def architecture(self):
return self._get_td('architecture').text.strip()
return self._get_td("architecture").text.strip()
@property
def network_perf(self): # 2 == low
return self._get_sort('networkperf')
return self._get_sort("networkperf")
@property
def ebs_max_bandwidth(self):
return self._get_sort('ebs-max-bandwidth')
return self._get_sort("ebs-max-bandwidth")
@property
def ebs_throughput(self):
return self._get_sort('ebs-throughput')
return self._get_sort("ebs-throughput")
@property
def ebs_iops(self):
return self._get_sort('ebs-iops')
return self._get_sort("ebs-iops")
@property
def max_ips(self):
return int(self._get_td('maxips').text.strip())
return int(self._get_td("maxips").text.strip())
@property
def enhanced_networking(self):
return self._get_td('enhanced-networking').text.strip() != 'No'
return self._get_td("enhanced-networking").text.strip() != "No"
@property
def vpc_only(self):
return self._get_td('vpc-only').text.strip() != 'No'
return self._get_td("vpc-only").text.strip() != "No"
@property
def ipv6_support(self):
return self._get_td('ipv6-support').text.strip() != 'No'
return self._get_td("ipv6-support").text.strip() != "No"
@property
def placement_group_support(self):
return self._get_td('placement-group-support').text.strip() != 'No'
return self._get_td("placement-group-support").text.strip() != "No"
@property
def linux_virtualization(self):
return self._get_td('linux-virtualization').text.strip()
return self._get_td("linux-virtualization").text.strip()
def to_dict(self):
result = {}
for attr in [x for x in self.__class__.__dict__.keys() if not x.startswith('_') and x != 'to_dict']:
for attr in [
x
for x in self.__class__.__dict__.keys()
if not x.startswith("_") and x != "to_dict"
]:
try:
result[attr] = getattr(self, attr)
except ValueError as ex:
if "'N/A'" in str(ex):
print(
"Skipping attribute '{0}' for instance type '{1}' (not found)".format(
attr, self.name
)
)
else:
raise
return self.apiname, result
def main():
print("Getting HTML from http://www.ec2instances.info")
page_request = requests.get('http://www.ec2instances.info')
soup = BeautifulSoup(page_request.text, 'html.parser')
data_table = soup.find(id='data')
page_request = requests.get("http://www.ec2instances.info")
soup = BeautifulSoup(page_request.text, "html.parser")
data_table = soup.find(id="data")
print("Finding data in table")
instances = data_table.find('tbody').find_all('tr')
instances = data_table.find("tbody").find_all("tr")
print("Parsing data")
result = {}
@ -140,11 +155,16 @@ def main():
instance_id, instance_data = Instance(instance).to_dict()
result[instance_id] = instance_data
root_dir = subprocess.check_output(['git', 'rev-parse', '--show-toplevel']).decode().strip()
dest = os.path.join(root_dir, 'moto/ec2/resources/instance_types.json')
root_dir = (
subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
.decode()
.strip()
)
dest = os.path.join(root_dir, "moto/ec2/resources/instance_types.json")
print("Writing data to {0}".format(dest))
with open(dest, 'w') as open_file:
json.dump(result, open_file)
with open(dest, "w") as open_file:
json.dump(result, open_file, sort_keys=True)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -94,10 +94,12 @@ setup(
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"License :: OSI Approved :: Apache Software License",
"Topic :: Software Development :: Testing",
],
project_urls={
"Documentation": "http://docs.getmoto.org/en/latest/",
},
data_files=[('', ['moto/dynamodb2/parsing/reserved_keywords.txt'])],
)

View File

@ -69,6 +69,22 @@ def test_create_rest_api_with_tags():
response["tags"].should.equal({"MY_TAG1": "MY_VALUE1"})
@mock_apigateway
def test_create_rest_api_with_policy():
client = boto3.client("apigateway", region_name="us-west-2")
policy = '{"Version": "2012-10-17","Statement": []}'
response = client.create_rest_api(
name="my_api", description="this is my api", policy=policy
)
api_id = response["id"]
response = client.get_rest_api(restApiId=api_id)
assert "policy" in response
response["policy"].should.equal(policy)
@mock_apigateway
def test_create_rest_api_invalid_apikeysource():
client = boto3.client("apigateway", region_name="us-west-2")
@ -1483,6 +1499,181 @@ def test_deployment():
stage["description"].should.equal("_new_description_")
@mock_apigateway
def test_create_domain_names():
client = boto3.client("apigateway", region_name="us-west-2")
domain_name = "testDomain"
test_certificate_name = "test.certificate"
test_certificate_private_key = "testPrivateKey"
# success case with valid params
response = client.create_domain_name(
domainName=domain_name,
certificateName=test_certificate_name,
certificatePrivateKey=test_certificate_private_key,
)
response["domainName"].should.equal(domain_name)
response["certificateName"].should.equal(test_certificate_name)
# without domain name it should throw BadRequestException
with assert_raises(ClientError) as ex:
client.create_domain_name(domainName="")
ex.exception.response["Error"]["Message"].should.equal("No Domain Name specified")
ex.exception.response["Error"]["Code"].should.equal("BadRequestException")
@mock_apigateway
def test_get_domain_names():
client = boto3.client("apigateway", region_name="us-west-2")
# without any domain names already present
result = client.get_domain_names()
result["items"].should.equal([])
domain_name = "testDomain"
test_certificate_name = "test.certificate"
response = client.create_domain_name(
domainName=domain_name, certificateName=test_certificate_name
)
response["domainName"].should.equal(domain_name)
response["certificateName"].should.equal(test_certificate_name)
response["domainNameStatus"].should.equal("AVAILABLE")
# after adding a new domain name
result = client.get_domain_names()
result["items"][0]["domainName"].should.equal(domain_name)
result["items"][0]["certificateName"].should.equal(test_certificate_name)
result["items"][0]["domainNameStatus"].should.equal("AVAILABLE")
@mock_apigateway
def test_get_domain_name():
client = boto3.client("apigateway", region_name="us-west-2")
domain_name = "testDomain"
# quering an invalid domain name which is not present
with assert_raises(ClientError) as ex:
client.get_domain_name(domainName=domain_name)
ex.exception.response["Error"]["Message"].should.equal(
"Invalid Domain Name specified"
)
ex.exception.response["Error"]["Code"].should.equal("NotFoundException")
# adding a domain name
client.create_domain_name(domainName=domain_name)
# retrieving the data of added domain name.
result = client.get_domain_name(domainName=domain_name)
result["domainName"].should.equal(domain_name)
result["domainNameStatus"].should.equal("AVAILABLE")
@mock_apigateway
def test_create_model():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
rest_api_id = response["id"]
dummy_rest_api_id = "a12b3c4d"
model_name = "testModel"
description = "test model"
content_type = "application/json"
# success case with valid params
response = client.create_model(
restApiId=rest_api_id,
name=model_name,
description=description,
contentType=content_type,
)
response["name"].should.equal(model_name)
response["description"].should.equal(description)
# with an invalid rest_api_id it should throw NotFoundException
with assert_raises(ClientError) as ex:
client.create_model(
restApiId=dummy_rest_api_id,
name=model_name,
description=description,
contentType=content_type,
)
ex.exception.response["Error"]["Message"].should.equal(
"Invalid Rest API Id specified"
)
ex.exception.response["Error"]["Code"].should.equal("NotFoundException")
with assert_raises(ClientError) as ex:
client.create_model(
restApiId=rest_api_id,
name="",
description=description,
contentType=content_type,
)
ex.exception.response["Error"]["Message"].should.equal("No Model Name specified")
ex.exception.response["Error"]["Code"].should.equal("BadRequestException")
@mock_apigateway
def test_get_api_models():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
rest_api_id = response["id"]
model_name = "testModel"
description = "test model"
content_type = "application/json"
# when no models are present
result = client.get_models(restApiId=rest_api_id)
result["items"].should.equal([])
# add a model
client.create_model(
restApiId=rest_api_id,
name=model_name,
description=description,
contentType=content_type,
)
# get models after adding
result = client.get_models(restApiId=rest_api_id)
result["items"][0]["name"] = model_name
result["items"][0]["description"] = description
@mock_apigateway
def test_get_model_by_name():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
rest_api_id = response["id"]
dummy_rest_api_id = "a12b3c4d"
model_name = "testModel"
description = "test model"
content_type = "application/json"
# add a model
client.create_model(
restApiId=rest_api_id,
name=model_name,
description=description,
contentType=content_type,
)
# get models after adding
result = client.get_model(restApiId=rest_api_id, modelName=model_name)
result["name"] = model_name
result["description"] = description
with assert_raises(ClientError) as ex:
client.get_model(restApiId=dummy_rest_api_id, modelName=model_name)
ex.exception.response["Error"]["Message"].should.equal(
"Invalid Rest API Id specified"
)
ex.exception.response["Error"]["Code"].should.equal("NotFoundException")
@mock_apigateway
def test_get_model_with_invalid_name():
client = boto3.client("apigateway", region_name="us-west-2")
response = client.create_rest_api(name="my_api", description="this is my api")
rest_api_id = response["id"]
# test with an invalid model name
with assert_raises(ClientError) as ex:
client.get_model(restApiId=rest_api_id, modelName="fake")
ex.exception.response["Error"]["Message"].should.equal(
"Invalid Model Name specified"
)
ex.exception.response["Error"]["Code"].should.equal("NotFoundException")
@mock_apigateway
def test_http_proxying_integration():
responses.add(

View File

@ -843,13 +843,41 @@ def test_describe_autoscaling_instances_boto3():
NewInstancesProtectedFromScaleIn=True,
)
response = client.describe_auto_scaling_instances()
len(response["AutoScalingInstances"]).should.equal(5)
for instance in response["AutoScalingInstances"]:
instance["AutoScalingGroupName"].should.equal("test_asg")
instance["AvailabilityZone"].should.equal("us-east-1a")
instance["ProtectedFromScaleIn"].should.equal(True)
@mock_autoscaling
def test_describe_autoscaling_instances_instanceid_filter():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
LaunchConfigurationName="test_launch_configuration"
)
_ = client.create_auto_scaling_group(
AutoScalingGroupName="test_asg",
LaunchConfigurationName="test_launch_configuration",
MinSize=0,
MaxSize=20,
DesiredCapacity=5,
VPCZoneIdentifier=mocked_networking["subnet1"],
NewInstancesProtectedFromScaleIn=True,
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
instance_ids = [
instance["InstanceId"]
for instance in response["AutoScalingGroups"][0]["Instances"]
]
response = client.describe_auto_scaling_instances(InstanceIds=instance_ids)
response = client.describe_auto_scaling_instances(
InstanceIds=instance_ids[0:2]
) # Filter by first 2 of 5
len(response["AutoScalingInstances"]).should.equal(2)
for instance in response["AutoScalingInstances"]:
instance["AutoScalingGroupName"].should.equal("test_asg")
instance["AvailabilityZone"].should.equal("us-east-1a")
@ -1074,8 +1102,6 @@ def test_detach_one_instance_decrement():
ec2_client = boto3.client("ec2", region_name="us-east-1")
response = ec2_client.describe_instances(InstanceIds=[instance_to_detach])
response = client.detach_instances(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_detach],
@ -1128,8 +1154,6 @@ def test_detach_one_instance():
ec2_client = boto3.client("ec2", region_name="us-east-1")
response = ec2_client.describe_instances(InstanceIds=[instance_to_detach])
response = client.detach_instances(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_detach],
@ -1150,6 +1174,516 @@ def test_detach_one_instance():
tags.should.have.length_of(2)
@mock_autoscaling
@mock_ec2
def test_standby_one_instance_decrement():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
LaunchConfigurationName="test_launch_configuration"
)
client.create_auto_scaling_group(
AutoScalingGroupName="test_asg",
LaunchConfigurationName="test_launch_configuration",
MinSize=0,
MaxSize=2,
DesiredCapacity=2,
Tags=[
{
"ResourceId": "test_asg",
"ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key",
"Value": "propagate-tag-value",
"PropagateAtLaunch": True,
}
],
VPCZoneIdentifier=mocked_networking["subnet1"],
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
instance_to_standby = response["AutoScalingGroups"][0]["Instances"][0]["InstanceId"]
instance_to_keep = response["AutoScalingGroups"][0]["Instances"][1]["InstanceId"]
ec2_client = boto3.client("ec2", region_name="us-east-1")
response = client.enter_standby(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_standby],
ShouldDecrementDesiredCapacity=True,
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
response["AutoScalingGroups"][0]["Instances"].should.have.length_of(2)
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(1)
response = client.describe_auto_scaling_instances(InstanceIds=[instance_to_standby])
response["AutoScalingInstances"][0]["LifecycleState"].should.equal("Standby")
# test to ensure tag has been retained (standby instance is still part of the ASG)
response = ec2_client.describe_instances()
for reservation in response["Reservations"]:
for instance in reservation["Instances"]:
tags = instance["Tags"]
tags.should.have.length_of(2)
@mock_autoscaling
@mock_ec2
def test_standby_one_instance():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
LaunchConfigurationName="test_launch_configuration"
)
client.create_auto_scaling_group(
AutoScalingGroupName="test_asg",
LaunchConfigurationName="test_launch_configuration",
MinSize=0,
MaxSize=2,
DesiredCapacity=2,
Tags=[
{
"ResourceId": "test_asg",
"ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key",
"Value": "propagate-tag-value",
"PropagateAtLaunch": True,
}
],
VPCZoneIdentifier=mocked_networking["subnet1"],
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
instance_to_standby = response["AutoScalingGroups"][0]["Instances"][0]["InstanceId"]
instance_to_keep = response["AutoScalingGroups"][0]["Instances"][1]["InstanceId"]
ec2_client = boto3.client("ec2", region_name="us-east-1")
response = client.enter_standby(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_standby],
ShouldDecrementDesiredCapacity=False,
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
response["AutoScalingGroups"][0]["Instances"].should.have.length_of(3)
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(2)
response = client.describe_auto_scaling_instances(InstanceIds=[instance_to_standby])
response["AutoScalingInstances"][0]["LifecycleState"].should.equal("Standby")
# test to ensure tag has been retained (standby instance is still part of the ASG)
response = ec2_client.describe_instances()
for reservation in response["Reservations"]:
for instance in reservation["Instances"]:
tags = instance["Tags"]
tags.should.have.length_of(2)
@mock_elb
@mock_autoscaling
@mock_ec2
def test_standby_elb_update():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
LaunchConfigurationName="test_launch_configuration"
)
client.create_auto_scaling_group(
AutoScalingGroupName="test_asg",
LaunchConfigurationName="test_launch_configuration",
MinSize=0,
MaxSize=2,
DesiredCapacity=2,
Tags=[
{
"ResourceId": "test_asg",
"ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key",
"Value": "propagate-tag-value",
"PropagateAtLaunch": True,
}
],
VPCZoneIdentifier=mocked_networking["subnet1"],
)
elb_client = boto3.client("elb", region_name="us-east-1")
elb_client.create_load_balancer(
LoadBalancerName="my-lb",
Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}],
AvailabilityZones=["us-east-1a", "us-east-1b"],
)
response = client.attach_load_balancers(
AutoScalingGroupName="test_asg", LoadBalancerNames=["my-lb"]
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
instance_to_standby = response["AutoScalingGroups"][0]["Instances"][0]["InstanceId"]
response = client.enter_standby(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_standby],
ShouldDecrementDesiredCapacity=False,
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
response["AutoScalingGroups"][0]["Instances"].should.have.length_of(3)
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(2)
response = client.describe_auto_scaling_instances(InstanceIds=[instance_to_standby])
response["AutoScalingInstances"][0]["LifecycleState"].should.equal("Standby")
response = elb_client.describe_load_balancers(LoadBalancerNames=["my-lb"])
list(response["LoadBalancerDescriptions"][0]["Instances"]).should.have.length_of(2)
@mock_autoscaling
@mock_ec2
def test_standby_terminate_instance_decrement():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
LaunchConfigurationName="test_launch_configuration"
)
client.create_auto_scaling_group(
AutoScalingGroupName="test_asg",
LaunchConfigurationName="test_launch_configuration",
MinSize=0,
MaxSize=3,
DesiredCapacity=2,
Tags=[
{
"ResourceId": "test_asg",
"ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key",
"Value": "propagate-tag-value",
"PropagateAtLaunch": True,
}
],
VPCZoneIdentifier=mocked_networking["subnet1"],
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
instance_to_standby_terminate = response["AutoScalingGroups"][0]["Instances"][0][
"InstanceId"
]
ec2_client = boto3.client("ec2", region_name="us-east-1")
response = client.enter_standby(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_standby_terminate],
ShouldDecrementDesiredCapacity=False,
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
response["AutoScalingGroups"][0]["Instances"].should.have.length_of(3)
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(2)
response = client.describe_auto_scaling_instances(
InstanceIds=[instance_to_standby_terminate]
)
response["AutoScalingInstances"][0]["LifecycleState"].should.equal("Standby")
response = client.terminate_instance_in_auto_scaling_group(
InstanceId=instance_to_standby_terminate, ShouldDecrementDesiredCapacity=True
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
# AWS still decrements desired capacity ASG if requested, even if the terminated instance is in standby
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
response["AutoScalingGroups"][0]["Instances"].should.have.length_of(1)
response["AutoScalingGroups"][0]["Instances"][0]["InstanceId"].should_not.equal(
instance_to_standby_terminate
)
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(1)
response = ec2_client.describe_instances(
InstanceIds=[instance_to_standby_terminate]
)
response["Reservations"][0]["Instances"][0]["State"]["Name"].should.equal(
"terminated"
)
@mock_autoscaling
@mock_ec2
def test_standby_terminate_instance_no_decrement():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
LaunchConfigurationName="test_launch_configuration"
)
client.create_auto_scaling_group(
AutoScalingGroupName="test_asg",
LaunchConfigurationName="test_launch_configuration",
MinSize=0,
MaxSize=3,
DesiredCapacity=2,
Tags=[
{
"ResourceId": "test_asg",
"ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key",
"Value": "propagate-tag-value",
"PropagateAtLaunch": True,
}
],
VPCZoneIdentifier=mocked_networking["subnet1"],
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
instance_to_standby_terminate = response["AutoScalingGroups"][0]["Instances"][0][
"InstanceId"
]
ec2_client = boto3.client("ec2", region_name="us-east-1")
response = client.enter_standby(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_standby_terminate],
ShouldDecrementDesiredCapacity=False,
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
response["AutoScalingGroups"][0]["Instances"].should.have.length_of(3)
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(2)
response = client.describe_auto_scaling_instances(
InstanceIds=[instance_to_standby_terminate]
)
response["AutoScalingInstances"][0]["LifecycleState"].should.equal("Standby")
response = client.terminate_instance_in_auto_scaling_group(
InstanceId=instance_to_standby_terminate, ShouldDecrementDesiredCapacity=False
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
group = response["AutoScalingGroups"][0]
group["Instances"].should.have.length_of(2)
instance_to_standby_terminate.shouldnt.be.within(
[x["InstanceId"] for x in group["Instances"]]
)
group["DesiredCapacity"].should.equal(2)
response = ec2_client.describe_instances(
InstanceIds=[instance_to_standby_terminate]
)
response["Reservations"][0]["Instances"][0]["State"]["Name"].should.equal(
"terminated"
)
@mock_autoscaling
@mock_ec2
def test_standby_detach_instance_decrement():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
LaunchConfigurationName="test_launch_configuration"
)
client.create_auto_scaling_group(
AutoScalingGroupName="test_asg",
LaunchConfigurationName="test_launch_configuration",
MinSize=0,
MaxSize=3,
DesiredCapacity=2,
Tags=[
{
"ResourceId": "test_asg",
"ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key",
"Value": "propagate-tag-value",
"PropagateAtLaunch": True,
}
],
VPCZoneIdentifier=mocked_networking["subnet1"],
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
instance_to_standby_detach = response["AutoScalingGroups"][0]["Instances"][0][
"InstanceId"
]
ec2_client = boto3.client("ec2", region_name="us-east-1")
response = client.enter_standby(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_standby_detach],
ShouldDecrementDesiredCapacity=False,
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
response["AutoScalingGroups"][0]["Instances"].should.have.length_of(3)
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(2)
response = client.describe_auto_scaling_instances(
InstanceIds=[instance_to_standby_detach]
)
response["AutoScalingInstances"][0]["LifecycleState"].should.equal("Standby")
response = client.detach_instances(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_standby_detach],
ShouldDecrementDesiredCapacity=True,
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
# AWS still decrements desired capacity ASG if requested, even if the detached instance was in standby
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
response["AutoScalingGroups"][0]["Instances"].should.have.length_of(1)
response["AutoScalingGroups"][0]["Instances"][0]["InstanceId"].should_not.equal(
instance_to_standby_detach
)
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(1)
response = ec2_client.describe_instances(InstanceIds=[instance_to_standby_detach])
response["Reservations"][0]["Instances"][0]["State"]["Name"].should.equal("running")
@mock_autoscaling
@mock_ec2
def test_standby_detach_instance_no_decrement():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
LaunchConfigurationName="test_launch_configuration"
)
client.create_auto_scaling_group(
AutoScalingGroupName="test_asg",
LaunchConfigurationName="test_launch_configuration",
MinSize=0,
MaxSize=3,
DesiredCapacity=2,
Tags=[
{
"ResourceId": "test_asg",
"ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key",
"Value": "propagate-tag-value",
"PropagateAtLaunch": True,
}
],
VPCZoneIdentifier=mocked_networking["subnet1"],
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
instance_to_standby_detach = response["AutoScalingGroups"][0]["Instances"][0][
"InstanceId"
]
ec2_client = boto3.client("ec2", region_name="us-east-1")
response = client.enter_standby(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_standby_detach],
ShouldDecrementDesiredCapacity=False,
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
response["AutoScalingGroups"][0]["Instances"].should.have.length_of(3)
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(2)
response = client.describe_auto_scaling_instances(
InstanceIds=[instance_to_standby_detach]
)
response["AutoScalingInstances"][0]["LifecycleState"].should.equal("Standby")
response = client.detach_instances(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_standby_detach],
ShouldDecrementDesiredCapacity=False,
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
group = response["AutoScalingGroups"][0]
group["Instances"].should.have.length_of(2)
instance_to_standby_detach.shouldnt.be.within(
[x["InstanceId"] for x in group["Instances"]]
)
group["DesiredCapacity"].should.equal(2)
response = ec2_client.describe_instances(InstanceIds=[instance_to_standby_detach])
response["Reservations"][0]["Instances"][0]["State"]["Name"].should.equal("running")
@mock_autoscaling
@mock_ec2
def test_standby_exit_standby():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
LaunchConfigurationName="test_launch_configuration"
)
client.create_auto_scaling_group(
AutoScalingGroupName="test_asg",
LaunchConfigurationName="test_launch_configuration",
MinSize=0,
MaxSize=3,
DesiredCapacity=2,
Tags=[
{
"ResourceId": "test_asg",
"ResourceType": "auto-scaling-group",
"Key": "propogated-tag-key",
"Value": "propagate-tag-value",
"PropagateAtLaunch": True,
}
],
VPCZoneIdentifier=mocked_networking["subnet1"],
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
instance_to_standby_exit_standby = response["AutoScalingGroups"][0]["Instances"][0][
"InstanceId"
]
ec2_client = boto3.client("ec2", region_name="us-east-1")
response = client.enter_standby(
AutoScalingGroupName="test_asg",
InstanceIds=[instance_to_standby_exit_standby],
ShouldDecrementDesiredCapacity=False,
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
response["AutoScalingGroups"][0]["Instances"].should.have.length_of(3)
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(2)
response = client.describe_auto_scaling_instances(
InstanceIds=[instance_to_standby_exit_standby]
)
response["AutoScalingInstances"][0]["LifecycleState"].should.equal("Standby")
response = client.exit_standby(
AutoScalingGroupName="test_asg", InstanceIds=[instance_to_standby_exit_standby],
)
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
group = response["AutoScalingGroups"][0]
group["Instances"].should.have.length_of(3)
instance_to_standby_exit_standby.should.be.within(
[x["InstanceId"] for x in group["Instances"]]
)
group["DesiredCapacity"].should.equal(3)
response = ec2_client.describe_instances(
InstanceIds=[instance_to_standby_exit_standby]
)
response["Reservations"][0]["Instances"][0]["State"]["Name"].should.equal("running")
@mock_autoscaling
@mock_ec2
def test_attach_one_instance():
@ -1383,7 +1917,7 @@ def test_set_desired_capacity_down_boto3():
@mock_autoscaling
@mock_ec2
def test_terminate_instance_in_autoscaling_group():
def test_terminate_instance_via_ec2_in_autoscaling_group():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
@ -1412,3 +1946,71 @@ def test_terminate_instance_in_autoscaling_group():
for instance in response["AutoScalingGroups"][0]["Instances"]
)
replaced_instance_id.should_not.equal(original_instance_id)
@mock_autoscaling
@mock_ec2
def test_terminate_instance_in_auto_scaling_group_decrement():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
LaunchConfigurationName="test_launch_configuration"
)
_ = client.create_auto_scaling_group(
AutoScalingGroupName="test_asg",
LaunchConfigurationName="test_launch_configuration",
MinSize=0,
DesiredCapacity=1,
MaxSize=2,
VPCZoneIdentifier=mocked_networking["subnet1"],
NewInstancesProtectedFromScaleIn=False,
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
original_instance_id = next(
instance["InstanceId"]
for instance in response["AutoScalingGroups"][0]["Instances"]
)
client.terminate_instance_in_auto_scaling_group(
InstanceId=original_instance_id, ShouldDecrementDesiredCapacity=True
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
response["AutoScalingGroups"][0]["Instances"].should.equal([])
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(0)
@mock_autoscaling
@mock_ec2
def test_terminate_instance_in_auto_scaling_group_no_decrement():
mocked_networking = setup_networking()
client = boto3.client("autoscaling", region_name="us-east-1")
_ = client.create_launch_configuration(
LaunchConfigurationName="test_launch_configuration"
)
_ = client.create_auto_scaling_group(
AutoScalingGroupName="test_asg",
LaunchConfigurationName="test_launch_configuration",
MinSize=0,
DesiredCapacity=1,
MaxSize=2,
VPCZoneIdentifier=mocked_networking["subnet1"],
NewInstancesProtectedFromScaleIn=False,
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
original_instance_id = next(
instance["InstanceId"]
for instance in response["AutoScalingGroups"][0]["Instances"]
)
client.terminate_instance_in_auto_scaling_group(
InstanceId=original_instance_id, ShouldDecrementDesiredCapacity=False
)
response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"])
replaced_instance_id = next(
instance["InstanceId"]
for instance in response["AutoScalingGroups"][0]["Instances"]
)
replaced_instance_id.should_not.equal(original_instance_id)
response["AutoScalingGroups"][0]["DesiredCapacity"].should.equal(1)

View File

@ -1677,6 +1677,42 @@ def test_create_function_with_unknown_arn():
)
@mock_lambda
def test_remove_function_permission():
conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1()
conn.create_function(
FunctionName="testFunction",
Runtime="python2.7",
Role=(get_role_name()),
Handler="lambda_function.handler",
Code={"ZipFile": zip_content},
Description="test lambda function",
Timeout=3,
MemorySize=128,
Publish=True,
)
conn.add_permission(
FunctionName="testFunction",
StatementId="1",
Action="lambda:InvokeFunction",
Principal="432143214321",
SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld",
SourceAccount="123412341234",
EventSourceToken="blah",
Qualifier="2",
)
remove = conn.remove_permission(
FunctionName="testFunction", StatementId="1", Qualifier="2",
)
remove["ResponseMetadata"]["HTTPStatusCode"].should.equal(204)
policy = conn.get_policy(FunctionName="testFunction", Qualifier="2")["Policy"]
policy = json.loads(policy)
policy["Statement"].should.equal([])
def create_invalid_lambda(role):
conn = boto3.client("lambda", _lambda_region)
zip_content = get_test_zip_file1()

View File

@ -495,7 +495,7 @@ def test_autoscaling_group_with_elb():
"my-as-group": {
"Type": "AWS::AutoScaling::AutoScalingGroup",
"Properties": {
"AvailabilityZones": ["us-east1"],
"AvailabilityZones": ["us-east-1a"],
"LaunchConfigurationName": {"Ref": "my-launch-config"},
"MinSize": "2",
"MaxSize": "2",
@ -522,7 +522,7 @@ def test_autoscaling_group_with_elb():
"my-elb": {
"Type": "AWS::ElasticLoadBalancing::LoadBalancer",
"Properties": {
"AvailabilityZones": ["us-east1"],
"AvailabilityZones": ["us-east-1a"],
"Listeners": [
{
"LoadBalancerPort": "80",
@ -545,10 +545,10 @@ def test_autoscaling_group_with_elb():
web_setup_template_json = json.dumps(web_setup_template)
conn = boto.cloudformation.connect_to_region("us-west-1")
conn = boto.cloudformation.connect_to_region("us-east-1")
conn.create_stack("web_stack", template_body=web_setup_template_json)
autoscale_conn = boto.ec2.autoscale.connect_to_region("us-west-1")
autoscale_conn = boto.ec2.autoscale.connect_to_region("us-east-1")
autoscale_group = autoscale_conn.get_all_groups()[0]
autoscale_group.launch_config_name.should.contain("my-launch-config")
autoscale_group.load_balancers[0].should.equal("my-elb")
@ -557,7 +557,7 @@ def test_autoscaling_group_with_elb():
autoscale_conn.get_all_launch_configurations().should.have.length_of(1)
# Confirm the ELB was actually created
elb_conn = boto.ec2.elb.connect_to_region("us-west-1")
elb_conn = boto.ec2.elb.connect_to_region("us-east-1")
elb_conn.get_all_load_balancers().should.have.length_of(1)
stack = conn.describe_stacks()[0]
@ -584,7 +584,7 @@ def test_autoscaling_group_with_elb():
elb_resource.physical_resource_id.should.contain("my-elb")
# confirm the instances were created with the right tags
ec2_conn = boto.ec2.connect_to_region("us-west-1")
ec2_conn = boto.ec2.connect_to_region("us-east-1")
reservations = ec2_conn.get_all_reservations()
len(reservations).should.equal(1)
reservation = reservations[0]
@ -604,7 +604,7 @@ def test_autoscaling_group_update():
"my-as-group": {
"Type": "AWS::AutoScaling::AutoScalingGroup",
"Properties": {
"AvailabilityZones": ["us-west-1"],
"AvailabilityZones": ["us-west-1a"],
"LaunchConfigurationName": {"Ref": "my-launch-config"},
"MinSize": "2",
"MaxSize": "2",
@ -2373,13 +2373,12 @@ def test_create_log_group_using_fntransform():
}
cf_conn = boto3.client("cloudformation", "us-west-2")
cf_conn.create_stack(
StackName="test_stack", TemplateBody=json.dumps(template),
)
cf_conn.create_stack(StackName="test_stack", TemplateBody=json.dumps(template))
logs_conn = boto3.client("logs", region_name="us-west-2")
log_group = logs_conn.describe_log_groups()["logGroups"][0]
log_group["logGroupName"].should.equal("some-log-group")
log_group["retentionInDays"].should.be.equal(90)
@mock_cloudformation
@ -2400,7 +2399,7 @@ def test_stack_events_create_rule_integration():
}
cf_conn = boto3.client("cloudformation", "us-west-2")
cf_conn.create_stack(
StackName="test_stack", TemplateBody=json.dumps(events_template),
StackName="test_stack", TemplateBody=json.dumps(events_template)
)
rules = boto3.client("events", "us-west-2").list_rules()
@ -2428,7 +2427,7 @@ def test_stack_events_delete_rule_integration():
}
cf_conn = boto3.client("cloudformation", "us-west-2")
cf_conn.create_stack(
StackName="test_stack", TemplateBody=json.dumps(events_template),
StackName="test_stack", TemplateBody=json.dumps(events_template)
)
rules = boto3.client("events", "us-west-2").list_rules()
@ -2457,8 +2456,45 @@ def test_stack_events_create_rule_without_name_integration():
}
cf_conn = boto3.client("cloudformation", "us-west-2")
cf_conn.create_stack(
StackName="test_stack", TemplateBody=json.dumps(events_template),
StackName="test_stack", TemplateBody=json.dumps(events_template)
)
rules = boto3.client("events", "us-west-2").list_rules()
rules["Rules"][0]["Name"].should.contain("test_stack-Event-")
@mock_cloudformation
@mock_events
@mock_logs
def test_stack_events_create_rule_as_target():
events_template = {
"AWSTemplateFormatVersion": "2010-09-09",
"Resources": {
"SecurityGroup": {
"Type": "AWS::Logs::LogGroup",
"Properties": {
"LogGroupName": {"Fn::GetAtt": ["Event", "Arn"]},
"RetentionInDays": 3,
},
},
"Event": {
"Type": "AWS::Events::Rule",
"Properties": {
"State": "ENABLED",
"ScheduleExpression": "rate(5 minutes)",
},
},
},
}
cf_conn = boto3.client("cloudformation", "us-west-2")
cf_conn.create_stack(
StackName="test_stack", TemplateBody=json.dumps(events_template)
)
rules = boto3.client("events", "us-west-2").list_rules()
log_groups = boto3.client("logs", "us-west-2").describe_log_groups()
rules["Rules"][0]["Name"].should.contain("test_stack-Event-")
log_groups["logGroups"][0]["logGroupName"].should.equal(rules["Rules"][0]["Arn"])
log_groups["logGroups"][0]["retentionInDays"].should.equal(3)

View File

@ -1,9 +1,10 @@
import boto
from boto.ec2.cloudwatch.alarm import MetricAlarm
from boto.s3.key import Key
from datetime import datetime
import sure # noqa
from moto import mock_cloudwatch_deprecated
from moto import mock_cloudwatch_deprecated, mock_s3_deprecated
def alarm_fixture(name="tester", action=None):
@ -83,10 +84,11 @@ def test_put_metric_data():
)
metrics = conn.list_metrics()
metrics.should.have.length_of(1)
metric_names = [m for m in metrics if m.name == "metric"]
metric_names.should.have(1)
metric = metrics[0]
metric.namespace.should.equal("tester")
metric.name.should.equal("metric")
metric.name.should.equal("Metric:metric")
dict(metric.dimensions).should.equal({"InstanceId": ["i-0123456,i-0123457"]})
@ -153,3 +155,35 @@ def test_get_metric_statistics():
datapoint = datapoints[0]
datapoint.should.have.key("Minimum").which.should.equal(1.5)
datapoint.should.have.key("Timestamp").which.should.equal(metric_timestamp)
@mock_s3_deprecated
@mock_cloudwatch_deprecated
def test_cloudwatch_return_s3_metrics():
region = "us-east-1"
cw = boto.ec2.cloudwatch.connect_to_region(region)
s3 = boto.s3.connect_to_region(region)
bucket_name_1 = "test-bucket-1"
bucket_name_2 = "test-bucket-2"
bucket1 = s3.create_bucket(bucket_name=bucket_name_1)
key = Key(bucket1)
key.key = "the-key"
key.set_contents_from_string("foobar" * 4)
s3.create_bucket(bucket_name=bucket_name_2)
metrics_s3_bucket_1 = cw.list_metrics(dimensions={"BucketName": bucket_name_1})
# Verify that the OOTB S3 metrics are available for the created buckets
len(metrics_s3_bucket_1).should.be(2)
metric_names = [m.name for m in metrics_s3_bucket_1]
sorted(metric_names).should.equal(
["Metric:BucketSizeBytes", "Metric:NumberOfObjects"]
)
# Explicit clean up - the metrics for these buckets are messing with subsequent tests
key.delete()
s3.delete_bucket(bucket_name_1)
s3.delete_bucket(bucket_name_2)

View File

@ -3,6 +3,7 @@
import boto3
from botocore.exceptions import ClientError
from datetime import datetime, timedelta
from freezegun import freeze_time
from nose.tools import assert_raises
from uuid import uuid4
import pytz
@ -154,7 +155,7 @@ def test_put_metric_data_no_dimensions():
metrics.should.have.length_of(1)
metric = metrics[0]
metric["Namespace"].should.equal("tester")
metric["MetricName"].should.equal("metric")
metric["MetricName"].should.equal("Metric:metric")
@mock_cloudwatch
@ -182,7 +183,7 @@ def test_put_metric_data_with_statistics():
metrics.should.have.length_of(1)
metric = metrics[0]
metric["Namespace"].should.equal("tester")
metric["MetricName"].should.equal("statmetric")
metric["MetricName"].should.equal("Metric:statmetric")
# TODO: test statistics - https://github.com/spulec/moto/issues/1615
@ -211,6 +212,35 @@ def test_get_metric_statistics():
datapoint["Sum"].should.equal(1.5)
@mock_cloudwatch
@freeze_time("2020-02-10 18:44:05")
def test_custom_timestamp():
utc_now = datetime.now(tz=pytz.utc)
time = "2020-02-10T18:44:09Z"
cw = boto3.client("cloudwatch", "eu-west-1")
cw.put_metric_data(
Namespace="tester",
MetricData=[dict(MetricName="metric1", Value=1.5, Timestamp=time)],
)
cw.put_metric_data(
Namespace="tester",
MetricData=[
dict(MetricName="metric2", Value=1.5, Timestamp=datetime(2020, 2, 10))
],
)
stats = cw.get_metric_statistics(
Namespace="tester",
MetricName="metric",
StartTime=utc_now - timedelta(seconds=60),
EndTime=utc_now + timedelta(seconds=60),
Period=60,
Statistics=["SampleCount", "Sum"],
)
@mock_cloudwatch
def test_list_metrics():
cloudwatch = boto3.client("cloudwatch", "eu-west-1")
@ -233,8 +263,16 @@ def test_list_metrics():
# Verify format
res.should.equal(
[
{u"Namespace": "list_test_1/", u"Dimensions": [], u"MetricName": "metric1"},
{u"Namespace": "list_test_1/", u"Dimensions": [], u"MetricName": "metric1"},
{
u"Namespace": "list_test_1/",
u"Dimensions": [],
u"MetricName": "Metric:metric1",
},
{
u"Namespace": "list_test_1/",
u"Dimensions": [],
u"MetricName": "Metric:metric1",
},
]
)
# Verify unknown namespace still has no results
@ -292,3 +330,232 @@ def create_metrics(cloudwatch, namespace, metrics=5, data_points=5):
Namespace=namespace,
MetricData=[{"MetricName": metric_name, "Value": j, "Unit": "Seconds"}],
)
@mock_cloudwatch
def test_get_metric_data_within_timeframe():
utc_now = datetime.now(tz=pytz.utc)
cloudwatch = boto3.client("cloudwatch", "eu-west-1")
namespace1 = "my_namespace/"
# put metric data
values = [0, 2, 4, 3.5, 7, 100]
cloudwatch.put_metric_data(
Namespace=namespace1,
MetricData=[
{"MetricName": "metric1", "Value": val, "Unit": "Seconds"} for val in values
],
)
# get_metric_data
stats = ["Average", "Sum", "Minimum", "Maximum"]
response = cloudwatch.get_metric_data(
MetricDataQueries=[
{
"Id": "result_" + stat,
"MetricStat": {
"Metric": {"Namespace": namespace1, "MetricName": "metric1"},
"Period": 60,
"Stat": stat,
},
}
for stat in stats
],
StartTime=utc_now - timedelta(seconds=60),
EndTime=utc_now + timedelta(seconds=60),
)
#
# Assert Average/Min/Max/Sum is returned as expected
avg = [
res for res in response["MetricDataResults"] if res["Id"] == "result_Average"
][0]
avg["Label"].should.equal("metric1 Average")
avg["StatusCode"].should.equal("Complete")
[int(val) for val in avg["Values"]].should.equal([19])
sum_ = [res for res in response["MetricDataResults"] if res["Id"] == "result_Sum"][
0
]
sum_["Label"].should.equal("metric1 Sum")
sum_["StatusCode"].should.equal("Complete")
[val for val in sum_["Values"]].should.equal([sum(values)])
min_ = [
res for res in response["MetricDataResults"] if res["Id"] == "result_Minimum"
][0]
min_["Label"].should.equal("metric1 Minimum")
min_["StatusCode"].should.equal("Complete")
[int(val) for val in min_["Values"]].should.equal([0])
max_ = [
res for res in response["MetricDataResults"] if res["Id"] == "result_Maximum"
][0]
max_["Label"].should.equal("metric1 Maximum")
max_["StatusCode"].should.equal("Complete")
[int(val) for val in max_["Values"]].should.equal([100])
@mock_cloudwatch
def test_get_metric_data_partially_within_timeframe():
utc_now = datetime.now(tz=pytz.utc)
yesterday = utc_now - timedelta(days=1)
last_week = utc_now - timedelta(days=7)
cloudwatch = boto3.client("cloudwatch", "eu-west-1")
namespace1 = "my_namespace/"
# put metric data
values = [0, 2, 4, 3.5, 7, 100]
cloudwatch.put_metric_data(
Namespace=namespace1,
MetricData=[
{
"MetricName": "metric1",
"Value": 10,
"Unit": "Seconds",
"Timestamp": utc_now,
}
],
)
cloudwatch.put_metric_data(
Namespace=namespace1,
MetricData=[
{
"MetricName": "metric1",
"Value": 20,
"Unit": "Seconds",
"Timestamp": yesterday,
}
],
)
cloudwatch.put_metric_data(
Namespace=namespace1,
MetricData=[
{
"MetricName": "metric1",
"Value": 50,
"Unit": "Seconds",
"Timestamp": last_week,
}
],
)
# get_metric_data
response = cloudwatch.get_metric_data(
MetricDataQueries=[
{
"Id": "result",
"MetricStat": {
"Metric": {"Namespace": namespace1, "MetricName": "metric1"},
"Period": 60,
"Stat": "Sum",
},
}
],
StartTime=yesterday - timedelta(seconds=60),
EndTime=utc_now + timedelta(seconds=60),
)
#
# Assert Last week's data is not returned
len(response["MetricDataResults"]).should.equal(1)
sum_ = response["MetricDataResults"][0]
sum_["Label"].should.equal("metric1 Sum")
sum_["StatusCode"].should.equal("Complete")
sum_["Values"].should.equal([30.0])
@mock_cloudwatch
def test_get_metric_data_outside_timeframe():
utc_now = datetime.now(tz=pytz.utc)
last_week = utc_now - timedelta(days=7)
cloudwatch = boto3.client("cloudwatch", "eu-west-1")
namespace1 = "my_namespace/"
# put metric data
cloudwatch.put_metric_data(
Namespace=namespace1,
MetricData=[
{
"MetricName": "metric1",
"Value": 50,
"Unit": "Seconds",
"Timestamp": last_week,
}
],
)
# get_metric_data
response = cloudwatch.get_metric_data(
MetricDataQueries=[
{
"Id": "result",
"MetricStat": {
"Metric": {"Namespace": namespace1, "MetricName": "metric1"},
"Period": 60,
"Stat": "Sum",
},
}
],
StartTime=utc_now - timedelta(seconds=60),
EndTime=utc_now + timedelta(seconds=60),
)
#
# Assert Last week's data is not returned
len(response["MetricDataResults"]).should.equal(1)
response["MetricDataResults"][0]["Id"].should.equal("result")
response["MetricDataResults"][0]["StatusCode"].should.equal("Complete")
response["MetricDataResults"][0]["Values"].should.equal([])
@mock_cloudwatch
def test_get_metric_data_for_multiple_metrics():
utc_now = datetime.now(tz=pytz.utc)
cloudwatch = boto3.client("cloudwatch", "eu-west-1")
namespace = "my_namespace/"
# put metric data
cloudwatch.put_metric_data(
Namespace=namespace,
MetricData=[
{
"MetricName": "metric1",
"Value": 50,
"Unit": "Seconds",
"Timestamp": utc_now,
}
],
)
cloudwatch.put_metric_data(
Namespace=namespace,
MetricData=[
{
"MetricName": "metric2",
"Value": 25,
"Unit": "Seconds",
"Timestamp": utc_now,
}
],
)
# get_metric_data
response = cloudwatch.get_metric_data(
MetricDataQueries=[
{
"Id": "result1",
"MetricStat": {
"Metric": {"Namespace": namespace, "MetricName": "metric1"},
"Period": 60,
"Stat": "Sum",
},
},
{
"Id": "result2",
"MetricStat": {
"Metric": {"Namespace": namespace, "MetricName": "metric2"},
"Period": 60,
"Stat": "Sum",
},
},
],
StartTime=utc_now - timedelta(seconds=60),
EndTime=utc_now + timedelta(seconds=60),
)
#
len(response["MetricDataResults"]).should.equal(2)
res1 = [res for res in response["MetricDataResults"] if res["Id"] == "result1"][0]
res1["Values"].should.equal([50.0])
res2 = [res for res in response["MetricDataResults"] if res["Id"] == "result2"][0]
res2["Values"].should.equal([25.0])

View File

@ -7,6 +7,7 @@ from nose.tools import assert_raises
from moto import mock_cognitoidentity
from moto.cognitoidentity.utils import get_random_identity_id
from moto.core import ACCOUNT_ID
from uuid import UUID
@mock_cognitoidentity
@ -83,8 +84,10 @@ def test_describe_identity_pool_with_invalid_id_raises_error():
# testing a helper function
def test_get_random_identity_id():
assert len(get_random_identity_id("us-west-2")) > 0
assert len(get_random_identity_id("us-west-2").split(":")[1]) == 19
identity_id = get_random_identity_id("us-west-2")
region, id = identity_id.split(":")
region.should.equal("us-west-2")
UUID(id, version=4) # Will throw an error if it's not a valid UUID
@mock_cognitoidentity
@ -96,7 +99,6 @@ def test_get_id():
IdentityPoolId="us-west-2:12345",
Logins={"someurl": "12345"},
)
print(result)
assert (
result.get("IdentityId", "").startswith("us-west-2")
or result.get("ResponseMetadata").get("HTTPStatusCode") == 200

View File

@ -48,6 +48,5 @@ def test_get_id():
},
)
print(res.data)
json_data = json.loads(res.data.decode("utf-8"))
assert ":" in json_data["IdentityId"]

View File

@ -1,5 +1,6 @@
from __future__ import unicode_literals, print_function
import re
from decimal import Decimal
import six
@ -1453,6 +1454,13 @@ def test_filter_expression():
filter_expr.expr(row1).should.be(True)
filter_expr.expr(row2).should.be(False)
# lowercase AND test
filter_expr = moto.dynamodb2.comparisons.get_filter_expression(
"Id > :v0 and Subs < :v1", {}, {":v0": {"N": "5"}, ":v1": {"N": "7"}}
)
filter_expr.expr(row1).should.be(True)
filter_expr.expr(row2).should.be(False)
# OR test
filter_expr = moto.dynamodb2.comparisons.get_filter_expression(
"Id = :v0 OR Id=:v1", {}, {":v0": {"N": "5"}, ":v1": {"N": "8"}}
@ -2146,13 +2154,33 @@ def test_update_item_on_map():
# Nonexistent nested attributes are supported for existing top-level attributes.
table.update_item(
Key={"forum_name": "the-key", "subject": "123"},
UpdateExpression="SET body.#nested.#data = :tb, body.nested.#nonexistentnested.#data = :tb2",
UpdateExpression="SET body.#nested.#data = :tb",
ExpressionAttributeNames={"#nested": "nested", "#data": "data",},
ExpressionAttributeValues={":tb": "new_value"},
)
# Running this against AWS DDB gives an exception so make sure it also fails.:
with assert_raises(client.exceptions.ClientError):
# botocore.exceptions.ClientError: An error occurred (ValidationException) when calling the UpdateItem
# operation: The document path provided in the update expression is invalid for update
table.update_item(
Key={"forum_name": "the-key", "subject": "123"},
UpdateExpression="SET body.#nested.#nonexistentnested.#data = :tb2",
ExpressionAttributeNames={
"#nested": "nested",
"#nonexistentnested": "nonexistentnested",
"#data": "data",
},
ExpressionAttributeValues={":tb": "new_value", ":tb2": "other_value"},
ExpressionAttributeValues={":tb2": "other_value"},
)
table.update_item(
Key={"forum_name": "the-key", "subject": "123"},
UpdateExpression="SET body.#nested.#nonexistentnested = :tb2",
ExpressionAttributeNames={
"#nested": "nested",
"#nonexistentnested": "nonexistentnested",
},
ExpressionAttributeValues={":tb2": {"data": "other_value"}},
)
resp = table.scan()
@ -2160,8 +2188,8 @@ def test_update_item_on_map():
{"nested": {"data": "new_value", "nonexistentnested": {"data": "other_value"}}}
)
# Test nested value for a nonexistent attribute.
with assert_raises(client.exceptions.ConditionalCheckFailedException):
# Test nested value for a nonexistent attribute throws a ClientError.
with assert_raises(client.exceptions.ClientError):
table.update_item(
Key={"forum_name": "the-key", "subject": "123"},
UpdateExpression="SET nonexistent.#nested = :tb",
@ -2764,7 +2792,7 @@ def test_query_gsi_with_range_key():
res = dynamodb.query(
TableName="test",
IndexName="test_gsi",
KeyConditionExpression="gsi_hash_key = :gsi_hash_key AND gsi_range_key = :gsi_range_key",
KeyConditionExpression="gsi_hash_key = :gsi_hash_key and gsi_range_key = :gsi_range_key",
ExpressionAttributeValues={
":gsi_hash_key": {"S": "key1"},
":gsi_range_key": {"S": "range1"},
@ -3183,7 +3211,10 @@ def test_remove_top_level_attribute():
TableName=table_name, Item={"id": {"S": "foo"}, "item": {"S": "bar"}}
)
client.update_item(
TableName=table_name, Key={"id": {"S": "foo"}}, UpdateExpression="REMOVE item"
TableName=table_name,
Key={"id": {"S": "foo"}},
UpdateExpression="REMOVE #i",
ExpressionAttributeNames={"#i": "item"},
)
#
result = client.get_item(TableName=table_name, Key={"id": {"S": "foo"}})["Item"]
@ -3358,21 +3389,21 @@ def test_item_size_is_under_400KB():
assert_failure_due_to_item_size(
func=client.put_item,
TableName="moto-test",
Item={"id": {"S": "foo"}, "item": {"S": large_item}},
Item={"id": {"S": "foo"}, "cont": {"S": large_item}},
)
assert_failure_due_to_item_size(
func=table.put_item, Item={"id": "bar", "item": large_item}
func=table.put_item, Item={"id": "bar", "cont": large_item}
)
assert_failure_due_to_item_size(
assert_failure_due_to_item_size_to_update(
func=client.update_item,
TableName="moto-test",
Key={"id": {"S": "foo2"}},
UpdateExpression="set item=:Item",
UpdateExpression="set cont=:Item",
ExpressionAttributeValues={":Item": {"S": large_item}},
)
# Assert op fails when updating a nested item
assert_failure_due_to_item_size(
func=table.put_item, Item={"id": "bar", "itemlist": [{"item": large_item}]}
func=table.put_item, Item={"id": "bar", "itemlist": [{"cont": large_item}]}
)
assert_failure_due_to_item_size(
func=client.put_item,
@ -3393,6 +3424,15 @@ def assert_failure_due_to_item_size(func, **kwargs):
)
def assert_failure_due_to_item_size_to_update(func, **kwargs):
with assert_raises(ClientError) as ex:
func(**kwargs)
ex.exception.response["Error"]["Code"].should.equal("ValidationException")
ex.exception.response["Error"]["Message"].should.equal(
"Item size to update has exceeded the maximum allowed size"
)
@mock_dynamodb2
# https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_Query.html#DDB-Query-request-KeyConditionExpression
def test_hash_key_cannot_use_begins_with_operations():
@ -4177,3 +4217,117 @@ def test_gsi_verify_negative_number_order():
[float(item["gsiK1SortKey"]) for item in resp["Items"]].should.equal(
[-0.7, -0.6, 0.7]
)
@mock_dynamodb2
def test_dynamodb_max_1mb_limit():
ddb = boto3.resource("dynamodb", region_name="eu-west-1")
table_name = "populated-mock-table"
table = ddb.create_table(
TableName=table_name,
KeySchema=[
{"AttributeName": "partition_key", "KeyType": "HASH"},
{"AttributeName": "sort_key", "KeyType": "RANGE"},
],
AttributeDefinitions=[
{"AttributeName": "partition_key", "AttributeType": "S"},
{"AttributeName": "sort_key", "AttributeType": "S"},
],
BillingMode="PAY_PER_REQUEST",
)
# Populate the table
items = [
{
"partition_key": "partition_key_val", # size=30
"sort_key": "sort_key_value____" + str(i), # size=30
}
for i in range(10000, 29999)
]
with table.batch_writer() as batch:
for item in items:
batch.put_item(Item=item)
response = table.query(
KeyConditionExpression=Key("partition_key").eq("partition_key_val")
)
# We shouldn't get everything back - the total result set is well over 1MB
len(items).should.be.greater_than(response["Count"])
response["LastEvaluatedKey"].shouldnt.be(None)
def assert_raise_syntax_error(client_error, token, near):
"""
Assert whether a client_error is as expected Syntax error. Syntax error looks like: `syntax_error_template`
Args:
client_error(ClientError): The ClientError exception that was raised
token(str): The token that ws unexpected
near(str): The part in the expression that shows where the error occurs it generally has the preceding token the
optional separation and the problematic token.
"""
syntax_error_template = (
'Invalid UpdateExpression: Syntax error; token: "{token}", near: "{near}"'
)
expected_syntax_error = syntax_error_template.format(token=token, near=near)
assert client_error.response["Error"]["Code"] == "ValidationException"
assert expected_syntax_error == client_error.response["Error"]["Message"]
@mock_dynamodb2
def test_update_expression_with_numeric_literal_instead_of_value():
"""
DynamoDB requires literals to be passed in as values. If they are put literally in the expression a token error will
be raised
"""
dynamodb = boto3.client("dynamodb", region_name="eu-west-1")
dynamodb.create_table(
TableName="moto-test",
KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}],
)
try:
dynamodb.update_item(
TableName="moto-test",
Key={"id": {"S": "1"}},
UpdateExpression="SET MyStr = myNum + 1",
)
assert False, "Validation exception not thrown"
except dynamodb.exceptions.ClientError as e:
assert_raise_syntax_error(e, "1", "+ 1")
@mock_dynamodb2
def test_update_expression_with_multiple_set_clauses_must_be_comma_separated():
"""
An UpdateExpression can have multiple set clauses but if they are passed in without the separating comma.
"""
dynamodb = boto3.client("dynamodb", region_name="eu-west-1")
dynamodb.create_table(
TableName="moto-test",
KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}],
)
try:
dynamodb.update_item(
TableName="moto-test",
Key={"id": {"S": "1"}},
UpdateExpression="SET MyStr = myNum Mystr2 myNum2",
)
assert False, "Validation exception not thrown"
except dynamodb.exceptions.ClientError as e:
assert_raise_syntax_error(e, "Mystr2", "myNum Mystr2 myNum2")
@mock_dynamodb2
def test_list_tables_exclusive_start_table_name_empty():
client = boto3.client("dynamodb", region_name="us-east-1")
resp = client.list_tables(Limit=1, ExclusiveStartTableName="whatever")
len(resp["TableNames"]).should.equal(0)

View File

@ -0,0 +1,259 @@
from moto.dynamodb2.exceptions import (
InvalidTokenException,
InvalidExpressionAttributeNameKey,
)
from moto.dynamodb2.parsing.tokens import ExpressionTokenizer, Token
def test_expression_tokenizer_single_set_action():
set_action = "SET attrName = :attrValue"
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "SET"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE, "attrName"),
Token(Token.WHITESPACE, " "),
Token(Token.EQUAL_SIGN, "="),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE_VALUE, ":attrValue"),
]
def test_expression_tokenizer_single_set_action_leading_space():
set_action = "Set attrName = :attrValue"
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "Set"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE, "attrName"),
Token(Token.WHITESPACE, " "),
Token(Token.EQUAL_SIGN, "="),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE_VALUE, ":attrValue"),
]
def test_expression_tokenizer_single_set_action_attribute_name_leading_space():
set_action = "SET #a = :attrValue"
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "SET"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE_NAME, "#a"),
Token(Token.WHITESPACE, " "),
Token(Token.EQUAL_SIGN, "="),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE_VALUE, ":attrValue"),
]
def test_expression_tokenizer_single_set_action_trailing_space():
set_action = "SET attrName = :attrValue "
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "SET"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE, "attrName"),
Token(Token.WHITESPACE, " "),
Token(Token.EQUAL_SIGN, "="),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE_VALUE, ":attrValue"),
Token(Token.WHITESPACE, " "),
]
def test_expression_tokenizer_single_set_action_multi_spaces():
set_action = "SET attrName = :attrValue "
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "SET"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE, "attrName"),
Token(Token.WHITESPACE, " "),
Token(Token.EQUAL_SIGN, "="),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE_VALUE, ":attrValue"),
Token(Token.WHITESPACE, " "),
]
def test_expression_tokenizer_single_set_action_with_numbers_in_identifiers():
set_action = "SET attrName3 = :attr3Value"
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "SET"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE, "attrName3"),
Token(Token.WHITESPACE, " "),
Token(Token.EQUAL_SIGN, "="),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE_VALUE, ":attr3Value"),
]
def test_expression_tokenizer_single_set_action_with_underscore_in_identifier():
set_action = "SET attr_Name = :attr_Value"
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "SET"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE, "attr_Name"),
Token(Token.WHITESPACE, " "),
Token(Token.EQUAL_SIGN, "="),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE_VALUE, ":attr_Value"),
]
def test_expression_tokenizer_leading_underscore_in_attribute_name_expression():
"""Leading underscore is not allowed for an attribute name"""
set_action = "SET attrName = _idid"
try:
ExpressionTokenizer.make_list(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "_"
assert te.near == "= _idid"
def test_expression_tokenizer_leading_underscore_in_attribute_value_expression():
"""Leading underscore is allowed in an attribute value"""
set_action = "SET attrName = :_attrValue"
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "SET"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE, "attrName"),
Token(Token.WHITESPACE, " "),
Token(Token.EQUAL_SIGN, "="),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE_VALUE, ":_attrValue"),
]
def test_expression_tokenizer_single_set_action_nested_attribute():
set_action = "SET attrName.elem = :attrValue"
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "SET"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE, "attrName"),
Token(Token.DOT, "."),
Token(Token.ATTRIBUTE, "elem"),
Token(Token.WHITESPACE, " "),
Token(Token.EQUAL_SIGN, "="),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE_VALUE, ":attrValue"),
]
def test_expression_tokenizer_list_index_with_sub_attribute():
set_action = "SET itemmap.itemlist[1].foos=:Item"
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "SET"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE, "itemmap"),
Token(Token.DOT, "."),
Token(Token.ATTRIBUTE, "itemlist"),
Token(Token.OPEN_SQUARE_BRACKET, "["),
Token(Token.NUMBER, "1"),
Token(Token.CLOSE_SQUARE_BRACKET, "]"),
Token(Token.DOT, "."),
Token(Token.ATTRIBUTE, "foos"),
Token(Token.EQUAL_SIGN, "="),
Token(Token.ATTRIBUTE_VALUE, ":Item"),
]
def test_expression_tokenizer_list_index_surrounded_with_whitespace():
set_action = "SET itemlist[ 1 ]=:Item"
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "SET"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE, "itemlist"),
Token(Token.OPEN_SQUARE_BRACKET, "["),
Token(Token.WHITESPACE, " "),
Token(Token.NUMBER, "1"),
Token(Token.WHITESPACE, " "),
Token(Token.CLOSE_SQUARE_BRACKET, "]"),
Token(Token.EQUAL_SIGN, "="),
Token(Token.ATTRIBUTE_VALUE, ":Item"),
]
def test_expression_tokenizer_single_set_action_attribute_name_invalid_key():
"""
ExpressionAttributeNames contains invalid key: Syntax error; key: "#va#l2"
"""
set_action = "SET #va#l2 = 3"
try:
ExpressionTokenizer.make_list(set_action)
assert False, "Exception not raised correctly"
except InvalidExpressionAttributeNameKey as e:
assert e.key == "#va#l2"
def test_expression_tokenizer_single_set_action_attribute_name_invalid_key_double_hash():
"""
ExpressionAttributeNames contains invalid key: Syntax error; key: "#va#l"
"""
set_action = "SET #va#l = 3"
try:
ExpressionTokenizer.make_list(set_action)
assert False, "Exception not raised correctly"
except InvalidExpressionAttributeNameKey as e:
assert e.key == "#va#l"
def test_expression_tokenizer_single_set_action_attribute_name_valid_key():
set_action = "SET attr=#val2"
token_list = ExpressionTokenizer.make_list(set_action)
assert token_list == [
Token(Token.ATTRIBUTE, "SET"),
Token(Token.WHITESPACE, " "),
Token(Token.ATTRIBUTE, "attr"),
Token(Token.EQUAL_SIGN, "="),
Token(Token.ATTRIBUTE_NAME, "#val2"),
]
def test_expression_tokenizer_just_a_pipe():
set_action = "|"
try:
ExpressionTokenizer.make_list(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "|"
assert te.near == "|"
def test_expression_tokenizer_just_a_pipe_with_leading_white_spaces():
set_action = " |"
try:
ExpressionTokenizer.make_list(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "|"
assert te.near == " |"
def test_expression_tokenizer_just_a_pipe_for_set_expression():
set_action = "SET|"
try:
ExpressionTokenizer.make_list(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "|"
assert te.near == "SET|"
def test_expression_tokenizer_just_an_attribute_and_a_pipe_for_set_expression():
set_action = "SET a|"
try:
ExpressionTokenizer.make_list(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "|"
assert te.near == "a|"

View File

@ -0,0 +1,405 @@
from moto.dynamodb2.exceptions import InvalidTokenException
from moto.dynamodb2.parsing.expressions import UpdateExpressionParser
from moto.dynamodb2.parsing.reserved_keywords import ReservedKeywords
def test_get_reserved_keywords():
reserved_keywords = ReservedKeywords.get_reserved_keywords()
assert "SET" in reserved_keywords
assert "DELETE" in reserved_keywords
assert "ADD" in reserved_keywords
# REMOVE is not part of the list of reserved keywords.
assert "REMOVE" not in reserved_keywords
def test_update_expression_numeric_literal_in_expression():
set_action = "SET attrName = 3"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "3"
assert te.near == "= 3"
def test_expression_tokenizer_multi_number_numeric_literal_in_expression():
set_action = "SET attrName = 34"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "34"
assert te.near == "= 34"
def test_expression_tokenizer_numeric_literal_unclosed_square_bracket():
set_action = "SET MyStr[ 3"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "<EOF>"
assert te.near == "3"
def test_expression_tokenizer_wrong_closing_bracket_with_space():
set_action = "SET MyStr[3 )"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == ")"
assert te.near == "3 )"
def test_expression_tokenizer_wrong_closing_bracket():
set_action = "SET MyStr[3)"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == ")"
assert te.near == "3)"
def test_expression_tokenizer_only_numeric_literal_for_set():
set_action = "SET 2"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "2"
assert te.near == "SET 2"
def test_expression_tokenizer_only_numeric_literal():
set_action = "2"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "2"
assert te.near == "2"
def test_expression_tokenizer_set_closing_round_bracket():
set_action = "SET )"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == ")"
assert te.near == "SET )"
def test_expression_tokenizer_set_closing_followed_by_numeric_literal():
set_action = "SET ) 3"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == ")"
assert te.near == "SET ) 3"
def test_expression_tokenizer_numeric_literal_unclosed_square_bracket_trailing_space():
set_action = "SET MyStr[ 3 "
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "<EOF>"
assert te.near == "3 "
def test_expression_tokenizer_unbalanced_round_brackets_only_opening():
set_action = "SET MyStr = (:_val"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "<EOF>"
assert te.near == ":_val"
def test_expression_tokenizer_unbalanced_round_brackets_only_opening_trailing_space():
set_action = "SET MyStr = (:_val "
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "<EOF>"
assert te.near == ":_val "
def test_expression_tokenizer_unbalanced_square_brackets_only_opening():
set_action = "SET MyStr = [:_val"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "["
assert te.near == "= [:_val"
def test_expression_tokenizer_unbalanced_square_brackets_only_opening_trailing_spaces():
set_action = "SET MyStr = [:_val "
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "["
assert te.near == "= [:_val"
def test_expression_tokenizer_unbalanced_round_brackets_multiple_opening():
set_action = "SET MyStr = (:_val + (:val2"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "<EOF>"
assert te.near == ":val2"
def test_expression_tokenizer_unbalanced_round_brackets_only_closing():
set_action = "SET MyStr = ):_val"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == ")"
assert te.near == "= ):_val"
def test_expression_tokenizer_unbalanced_square_brackets_only_closing():
set_action = "SET MyStr = ]:_val"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "]"
assert te.near == "= ]:_val"
def test_expression_tokenizer_unbalanced_round_brackets_only_closing_followed_by_other_parts():
set_action = "SET MyStr = ):_val + :val2"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == ")"
assert te.near == "= ):_val"
def test_update_expression_starts_with_keyword_reset_followed_by_identifier():
update_expression = "RESET NonExistent"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "RESET"
assert te.near == "RESET NonExistent"
def test_update_expression_starts_with_keyword_reset_followed_by_identifier_and_value():
update_expression = "RESET NonExistent value"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "RESET"
assert te.near == "RESET NonExistent"
def test_update_expression_starts_with_leading_spaces_and_keyword_reset_followed_by_identifier_and_value():
update_expression = " RESET NonExistent value"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "RESET"
assert te.near == " RESET NonExistent"
def test_update_expression_with_only_keyword_reset():
update_expression = "RESET"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "RESET"
assert te.near == "RESET"
def test_update_nested_expression_with_selector_just_should_fail_parsing_at_numeric_literal_value():
update_expression = "SET a[0].b = 5"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "5"
assert te.near == "= 5"
def test_update_nested_expression_with_selector_and_spaces_should_only_fail_parsing_at_numeric_literal_value():
update_expression = "SET a [ 2 ]. b = 5"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "5"
assert te.near == "= 5"
def test_update_nested_expression_with_double_selector_and_spaces_should_only_fail_parsing_at_numeric_literal_value():
update_expression = "SET a [2][ 3 ]. b = 5"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "5"
assert te.near == "= 5"
def test_update_nested_expression_should_only_fail_parsing_at_numeric_literal_value():
update_expression = "SET a . b = 5"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "5"
assert te.near == "= 5"
def test_nested_selectors_in_update_expression_should_fail_at_nesting():
update_expression = "SET a [ [2] ]. b = 5"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "["
assert te.near == "[ [2"
def test_update_expression_number_in_selector_cannot_be_splite():
update_expression = "SET a [2 1]. b = 5"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "1"
assert te.near == "2 1]"
def test_update_expression_cannot_have_successive_attributes():
update_expression = "SET #a a = 5"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "a"
assert te.near == "#a a ="
def test_update_expression_path_with_both_attribute_and_attribute_name_should_only_fail_at_numeric_value():
update_expression = "SET #a.a = 5"
try:
UpdateExpressionParser.make(update_expression)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "5"
assert te.near == "= 5"
def test_expression_tokenizer_2_same_operators_back_to_back():
set_action = "SET MyStr = NoExist + + :_val "
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "+"
assert te.near == "+ + :_val"
def test_expression_tokenizer_2_different_operators_back_to_back():
set_action = "SET MyStr = NoExist + - :_val "
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "-"
assert te.near == "+ - :_val"
def test_update_expression_remove_does_not_allow_operations():
remove_action = "REMOVE NoExist + "
try:
UpdateExpressionParser.make(remove_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "+"
assert te.near == "NoExist + "
def test_update_expression_add_does_not_allow_attribute_after_path():
"""value here is not really a value since a value starts with a colon (:)"""
add_expr = "ADD attr val foobar"
try:
UpdateExpressionParser.make(add_expr)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "val"
assert te.near == "attr val foobar"
def test_update_expression_add_does_not_allow_attribute_foobar_after_value():
add_expr = "ADD attr :val foobar"
try:
UpdateExpressionParser.make(add_expr)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "foobar"
assert te.near == ":val foobar"
def test_update_expression_delete_does_not_allow_attribute_after_path():
"""value here is not really a value since a value starts with a colon (:)"""
delete_expr = "DELETE attr val"
try:
UpdateExpressionParser.make(delete_expr)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "val"
assert te.near == "attr val"
def test_update_expression_delete_does_not_allow_attribute_foobar_after_value():
delete_expr = "DELETE attr :val foobar"
try:
UpdateExpressionParser.make(delete_expr)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "foobar"
assert te.near == ":val foobar"
def test_update_expression_parsing_is_not_keyword_aware():
"""path and VALUE are keywords. Yet a token error will be thrown for the numeric literal 1."""
delete_expr = "SET path = VALUE 1"
try:
UpdateExpressionParser.make(delete_expr)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "1"
assert te.near == "VALUE 1"
def test_expression_if_not_exists_is_not_valid_in_remove_statement():
set_action = "REMOVE if_not_exists(a,b)"
try:
UpdateExpressionParser.make(set_action)
assert False, "Exception not raised correctly"
except InvalidTokenException as te:
assert te.token == "("
assert te.near == "if_not_exists(a"

View File

@ -1254,14 +1254,22 @@ def test_update_item_with_expression():
item_key = {"forum_name": "the-key", "subject": "123"}
table.update_item(Key=item_key, UpdateExpression="SET field=2")
table.update_item(
Key=item_key,
UpdateExpression="SET field = :field_value",
ExpressionAttributeValues={":field_value": 2},
)
dict(table.get_item(Key=item_key)["Item"]).should.equal(
{"field": "2", "forum_name": "the-key", "subject": "123"}
{"field": Decimal("2"), "forum_name": "the-key", "subject": "123"}
)
table.update_item(Key=item_key, UpdateExpression="SET field = 3")
table.update_item(
Key=item_key,
UpdateExpression="SET field = :field_value",
ExpressionAttributeValues={":field_value": 3},
)
dict(table.get_item(Key=item_key)["Item"]).should.equal(
{"field": "3", "forum_name": "the-key", "subject": "123"}
{"field": Decimal("3"), "forum_name": "the-key", "subject": "123"}
)

View File

@ -443,23 +443,40 @@ def test_update_item_nested_remove():
dict(returned_item).should.equal({"username": "steve", "Meta": {}})
@mock_dynamodb2_deprecated
@mock_dynamodb2
def test_update_item_double_nested_remove():
conn = boto.dynamodb2.connect_to_region("us-east-1")
table = Table.create("messages", schema=[HashKey("username")])
conn = boto3.client("dynamodb", region_name="us-east-1")
conn.create_table(
TableName="messages",
KeySchema=[{"AttributeName": "username", "KeyType": "HASH"}],
AttributeDefinitions=[{"AttributeName": "username", "AttributeType": "S"}],
BillingMode="PAY_PER_REQUEST",
)
data = {"username": "steve", "Meta": {"Name": {"First": "Steve", "Last": "Urkel"}}}
table.put_item(data=data)
item = {
"username": {"S": "steve"},
"Meta": {
"M": {"Name": {"M": {"First": {"S": "Steve"}, "Last": {"S": "Urkel"}}}}
},
}
conn.put_item(TableName="messages", Item=item)
key_map = {"username": {"S": "steve"}}
# Then remove the Meta.FullName field
conn.update_item("messages", key_map, update_expression="REMOVE Meta.Name.First")
returned_item = table.get_item(username="steve")
dict(returned_item).should.equal(
{"username": "steve", "Meta": {"Name": {"Last": "Urkel"}}}
conn.update_item(
TableName="messages",
Key=key_map,
UpdateExpression="REMOVE Meta.#N.#F",
ExpressionAttributeNames={"#N": "Name", "#F": "First"},
)
returned_item = conn.get_item(TableName="messages", Key=key_map)
expected_item = {
"username": {"S": "steve"},
"Meta": {"M": {"Name": {"M": {"Last": {"S": "Urkel"}}}}},
}
dict(returned_item["Item"]).should.equal(expected_item)
@mock_dynamodb2_deprecated
def test_update_item_set():
@ -471,7 +488,10 @@ def test_update_item_set():
key_map = {"username": {"S": "steve"}}
conn.update_item(
"messages", key_map, update_expression="SET foo=bar, blah=baz REMOVE SentBy"
"messages",
key_map,
update_expression="SET foo=:bar, blah=:baz REMOVE SentBy",
expression_attribute_values={":bar": {"S": "bar"}, ":baz": {"S": "baz"}},
)
returned_item = table.get_item(username="steve")
@ -616,8 +636,9 @@ def test_boto3_update_item_conditions_fail():
table.put_item(Item={"username": "johndoe", "foo": "baz"})
table.update_item.when.called_with(
Key={"username": "johndoe"},
UpdateExpression="SET foo=bar",
UpdateExpression="SET foo=:bar",
Expected={"foo": {"Value": "bar"}},
ExpressionAttributeValues={":bar": "bar"},
).should.throw(botocore.client.ClientError)
@ -627,8 +648,9 @@ def test_boto3_update_item_conditions_fail_because_expect_not_exists():
table.put_item(Item={"username": "johndoe", "foo": "baz"})
table.update_item.when.called_with(
Key={"username": "johndoe"},
UpdateExpression="SET foo=bar",
UpdateExpression="SET foo=:bar",
Expected={"foo": {"Exists": False}},
ExpressionAttributeValues={":bar": "bar"},
).should.throw(botocore.client.ClientError)
@ -638,8 +660,9 @@ def test_boto3_update_item_conditions_fail_because_expect_not_exists_by_compare_
table.put_item(Item={"username": "johndoe", "foo": "baz"})
table.update_item.when.called_with(
Key={"username": "johndoe"},
UpdateExpression="SET foo=bar",
UpdateExpression="SET foo=:bar",
Expected={"foo": {"ComparisonOperator": "NULL"}},
ExpressionAttributeValues={":bar": "bar"},
).should.throw(botocore.client.ClientError)
@ -649,8 +672,9 @@ def test_boto3_update_item_conditions_pass():
table.put_item(Item={"username": "johndoe", "foo": "bar"})
table.update_item(
Key={"username": "johndoe"},
UpdateExpression="SET foo=baz",
UpdateExpression="SET foo=:baz",
Expected={"foo": {"Value": "bar"}},
ExpressionAttributeValues={":baz": "baz"},
)
returned_item = table.get_item(Key={"username": "johndoe"})
assert dict(returned_item)["Item"]["foo"].should.equal("baz")
@ -662,8 +686,9 @@ def test_boto3_update_item_conditions_pass_because_expect_not_exists():
table.put_item(Item={"username": "johndoe", "foo": "bar"})
table.update_item(
Key={"username": "johndoe"},
UpdateExpression="SET foo=baz",
UpdateExpression="SET foo=:baz",
Expected={"whatever": {"Exists": False}},
ExpressionAttributeValues={":baz": "baz"},
)
returned_item = table.get_item(Key={"username": "johndoe"})
assert dict(returned_item)["Item"]["foo"].should.equal("baz")
@ -675,8 +700,9 @@ def test_boto3_update_item_conditions_pass_because_expect_not_exists_by_compare_
table.put_item(Item={"username": "johndoe", "foo": "bar"})
table.update_item(
Key={"username": "johndoe"},
UpdateExpression="SET foo=baz",
UpdateExpression="SET foo=:baz",
Expected={"whatever": {"ComparisonOperator": "NULL"}},
ExpressionAttributeValues={":baz": "baz"},
)
returned_item = table.get_item(Key={"username": "johndoe"})
assert dict(returned_item)["Item"]["foo"].should.equal("baz")
@ -688,8 +714,9 @@ def test_boto3_update_item_conditions_pass_because_expect_exists_by_compare_to_n
table.put_item(Item={"username": "johndoe", "foo": "bar"})
table.update_item(
Key={"username": "johndoe"},
UpdateExpression="SET foo=baz",
UpdateExpression="SET foo=:baz",
Expected={"foo": {"ComparisonOperator": "NOT_NULL"}},
ExpressionAttributeValues={":baz": "baz"},
)
returned_item = table.get_item(Key={"username": "johndoe"})
assert dict(returned_item)["Item"]["foo"].should.equal("baz")

View File

@ -0,0 +1,464 @@
from moto.dynamodb2.exceptions import (
AttributeIsReservedKeyword,
ExpressionAttributeValueNotDefined,
AttributeDoesNotExist,
ExpressionAttributeNameNotDefined,
IncorrectOperandType,
InvalidUpdateExpressionInvalidDocumentPath,
)
from moto.dynamodb2.models import Item, DynamoType
from moto.dynamodb2.parsing.ast_nodes import (
NodeDepthLeftTypeFetcher,
UpdateExpressionSetAction,
UpdateExpressionValue,
DDBTypedValue,
)
from moto.dynamodb2.parsing.expressions import UpdateExpressionParser
from moto.dynamodb2.parsing.validators import UpdateExpressionValidator
from parameterized import parameterized
def test_validation_of_update_expression_with_keyword():
try:
update_expression = "SET myNum = path + :val"
update_expression_values = {":val": {"N": "3"}}
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "path": {"N": "3"}},
)
UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values=update_expression_values,
item=item,
).validate()
assert False, "No exception raised"
except AttributeIsReservedKeyword as e:
assert e.keyword == "path"
@parameterized(
["SET a = #b + :val2", "SET a = :val2 + #b",]
)
def test_validation_of_a_set_statement_with_incorrect_passed_value(update_expression):
"""
By running permutations it shows that values are replaced prior to resolving attributes.
An error occurred (ValidationException) when calling the UpdateItem operation: Invalid UpdateExpression:
An expression attribute value used in expression is not defined; attribute value: :val2
"""
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "b": {"N": "3"}},
)
try:
UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names={"#b": "ok"},
expression_attribute_values={":val": {"N": "3"}},
item=item,
).validate()
except ExpressionAttributeValueNotDefined as e:
assert e.attribute_value == ":val2"
def test_validation_of_update_expression_with_attribute_that_does_not_exist_in_item():
"""
When an update expression tries to get an attribute that does not exist it must throw the appropriate exception.
An error occurred (ValidationException) when calling the UpdateItem operation:
The provided expression refers to an attribute that does not exist in the item
"""
try:
update_expression = "SET a = nonexistent"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "path": {"N": "3"}},
)
UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values=None,
item=item,
).validate()
assert False, "No exception raised"
except AttributeDoesNotExist:
assert True
@parameterized(
["SET a = #c", "SET a = #c + #d",]
)
def test_validation_of_update_expression_with_attribute_name_that_is_not_defined(
update_expression,
):
"""
When an update expression tries to get an attribute name that is not provided it must throw an exception.
An error occurred (ValidationException) when calling the UpdateItem operation: Invalid UpdateExpression:
An expression attribute name used in the document path is not defined; attribute name: #c
"""
try:
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "path": {"N": "3"}},
)
UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names={"#b": "ok"},
expression_attribute_values=None,
item=item,
).validate()
assert False, "No exception raised"
except ExpressionAttributeNameNotDefined as e:
assert e.not_defined_attribute_name == "#c"
def test_validation_of_if_not_exists_not_existing_invalid_replace_value():
try:
update_expression = "SET a = if_not_exists(b, a.c)"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "a": {"S": "A"}},
)
UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values=None,
item=item,
).validate()
assert False, "No exception raised"
except AttributeDoesNotExist:
assert True
def get_first_node_of_type(ast, node_type):
return next(NodeDepthLeftTypeFetcher(node_type, ast))
def get_set_action_value(ast):
"""
Helper that takes an AST and gets the first UpdateExpressionSetAction and retrieves the value of that action.
This should only be called on validated expressions.
Args:
ast(Node):
Returns:
DynamoType: The DynamoType object representing the Dynamo value.
"""
set_action = get_first_node_of_type(ast, UpdateExpressionSetAction)
typed_value = set_action.children[1]
assert isinstance(typed_value, DDBTypedValue)
dynamo_value = typed_value.children[0]
assert isinstance(dynamo_value, DynamoType)
return dynamo_value
def test_validation_of_if_not_exists_not_existing_value():
update_expression = "SET a = if_not_exists(b, a)"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "a": {"S": "A"}},
)
validated_ast = UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values=None,
item=item,
).validate()
dynamo_value = get_set_action_value(validated_ast)
assert dynamo_value == DynamoType({"S": "A"})
def test_validation_of_if_not_exists_with_existing_attribute_should_return_attribute():
update_expression = "SET a = if_not_exists(b, a)"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "a": {"S": "A"}, "b": {"S": "B"}},
)
validated_ast = UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values=None,
item=item,
).validate()
dynamo_value = get_set_action_value(validated_ast)
assert dynamo_value == DynamoType({"S": "B"})
def test_validation_of_if_not_exists_with_existing_attribute_should_return_value():
update_expression = "SET a = if_not_exists(b, :val)"
update_expression_values = {":val": {"N": "4"}}
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "b": {"N": "3"}},
)
validated_ast = UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values=update_expression_values,
item=item,
).validate()
dynamo_value = get_set_action_value(validated_ast)
assert dynamo_value == DynamoType({"N": "3"})
def test_validation_of_if_not_exists_with_non_existing_attribute_should_return_value():
update_expression = "SET a = if_not_exists(b, :val)"
update_expression_values = {":val": {"N": "4"}}
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}},
)
validated_ast = UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values=update_expression_values,
item=item,
).validate()
dynamo_value = get_set_action_value(validated_ast)
assert dynamo_value == DynamoType({"N": "4"})
def test_validation_of_sum_operation():
update_expression = "SET a = a + b"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "a": {"N": "3"}, "b": {"N": "4"}},
)
validated_ast = UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values=None,
item=item,
).validate()
dynamo_value = get_set_action_value(validated_ast)
assert dynamo_value == DynamoType({"N": "7"})
def test_validation_homogeneous_list_append_function():
update_expression = "SET ri = list_append(ri, :vals)"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "ri": {"L": [{"S": "i1"}, {"S": "i2"}]}},
)
validated_ast = UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values={":vals": {"L": [{"S": "i3"}, {"S": "i4"}]}},
item=item,
).validate()
dynamo_value = get_set_action_value(validated_ast)
assert dynamo_value == DynamoType(
{"L": [{"S": "i1"}, {"S": "i2"}, {"S": "i3"}, {"S": "i4"}]}
)
def test_validation_hetereogenous_list_append_function():
update_expression = "SET ri = list_append(ri, :vals)"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "ri": {"L": [{"S": "i1"}, {"S": "i2"}]}},
)
validated_ast = UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values={":vals": {"L": [{"N": "3"}]}},
item=item,
).validate()
dynamo_value = get_set_action_value(validated_ast)
assert dynamo_value == DynamoType({"L": [{"S": "i1"}, {"S": "i2"}, {"N": "3"}]})
def test_validation_list_append_function_with_non_list_arg():
"""
Must error out:
Invalid UpdateExpression: Incorrect operand type for operator or function;
operator or function: list_append, operand type: S'
Returns:
"""
try:
update_expression = "SET ri = list_append(ri, :vals)"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "ri": {"L": [{"S": "i1"}, {"S": "i2"}]}},
)
UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values={":vals": {"S": "N"}},
item=item,
).validate()
except IncorrectOperandType as e:
assert e.operand_type == "S"
assert e.operator_or_function == "list_append"
def test_sum_with_incompatible_types():
"""
Must error out:
Invalid UpdateExpression: Incorrect operand type for operator or function; operator or function: +, operand type: S'
Returns:
"""
try:
update_expression = "SET ri = :val + :val2"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "ri": {"L": [{"S": "i1"}, {"S": "i2"}]}},
)
UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values={":val": {"S": "N"}, ":val2": {"N": "3"}},
item=item,
).validate()
except IncorrectOperandType as e:
assert e.operand_type == "S"
assert e.operator_or_function == "+"
def test_validation_of_subraction_operation():
update_expression = "SET ri = :val - :val2"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "1"}, "a": {"N": "3"}, "b": {"N": "4"}},
)
validated_ast = UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values={":val": {"N": "1"}, ":val2": {"N": "3"}},
item=item,
).validate()
dynamo_value = get_set_action_value(validated_ast)
assert dynamo_value == DynamoType({"N": "-2"})
def test_cannot_index_into_a_string():
"""
Must error out:
The document path provided in the update expression is invalid for update'
"""
try:
update_expression = "set itemstr[1]=:Item"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "foo2"}, "itemstr": {"S": "somestring"}},
)
UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values={":Item": {"S": "string_update"}},
item=item,
).validate()
assert False, "Must raise exception"
except InvalidUpdateExpressionInvalidDocumentPath:
assert True
def test_validation_set_path_does_not_need_to_be_resolvable_when_setting_a_new_attribute():
"""If this step just passes we are happy enough"""
update_expression = "set d=a"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "foo2"}, "a": {"N": "3"}},
)
validated_ast = UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values=None,
item=item,
).validate()
dynamo_value = get_set_action_value(validated_ast)
assert dynamo_value == DynamoType({"N": "3"})
def test_validation_set_path_does_not_need_to_be_resolvable_but_must_be_creatable_when_setting_a_new_attribute():
try:
update_expression = "set d.e=a"
update_expression_ast = UpdateExpressionParser.make(update_expression)
item = Item(
hash_key=DynamoType({"S": "id"}),
hash_key_type="TYPE",
range_key=None,
range_key_type=None,
attrs={"id": {"S": "foo2"}, "a": {"N": "3"}},
)
UpdateExpressionValidator(
update_expression_ast,
expression_attribute_names=None,
expression_attribute_values=None,
item=item,
).validate()
assert False, "Must raise exception"
except InvalidUpdateExpressionInvalidDocumentPath:
assert True

View File

@ -134,6 +134,7 @@ class TestCore:
"id": {"S": "entry1"},
"first_col": {"S": "bar"},
"second_col": {"S": "baz"},
"a": {"L": [{"M": {"b": {"S": "bar1"}}}]},
},
)
conn.delete_item(TableName="test-streams", Key={"id": {"S": "entry1"}})

View File

@ -52,3 +52,15 @@ def test_boto3_availability_zones():
resp = conn.describe_availability_zones()
for rec in resp["AvailabilityZones"]:
rec["ZoneName"].should.contain(region)
@mock_ec2
def test_boto3_zoneId_in_availability_zones():
conn = boto3.client("ec2", "us-east-1")
resp = conn.describe_availability_zones()
for rec in resp["AvailabilityZones"]:
rec.get("ZoneId").should.contain("use1")
conn = boto3.client("ec2", "us-west-1")
resp = conn.describe_availability_zones()
for rec in resp["AvailabilityZones"]:
rec.get("ZoneId").should.contain("usw1")

View File

@ -53,6 +53,45 @@ def test_create_and_delete_volume():
cm.exception.request_id.should_not.be.none
@mock_ec2_deprecated
def test_delete_attached_volume():
conn = boto.ec2.connect_to_region("us-east-1")
reservation = conn.run_instances("ami-1234abcd")
# create an instance
instance = reservation.instances[0]
# create a volume
volume = conn.create_volume(80, "us-east-1a")
# attach volume to instance
volume.attach(instance.id, "/dev/sdh")
volume.update()
volume.volume_state().should.equal("in-use")
volume.attachment_state().should.equal("attached")
volume.attach_data.instance_id.should.equal(instance.id)
# attempt to delete volume
# assert raises VolumeInUseError
with assert_raises(EC2ResponseError) as ex:
volume.delete()
ex.exception.error_code.should.equal("VolumeInUse")
ex.exception.status.should.equal(400)
ex.exception.message.should.equal(
"Volume {0} is currently attached to {1}".format(volume.id, instance.id)
)
volume.detach()
volume.update()
volume.volume_state().should.equal("available")
volume.delete()
all_volumes = conn.get_all_volumes()
my_volume = [item for item in all_volumes if item.id == volume.id]
my_volume.should.have.length_of(0)
@mock_ec2_deprecated
def test_create_encrypted_volume_dryrun():
conn = boto.ec2.connect_to_region("us-east-1")

View File

@ -9,6 +9,7 @@ from nose.tools import assert_raises
import base64
import datetime
import ipaddress
import json
import six
import boto
@ -18,7 +19,7 @@ from boto.exception import EC2ResponseError, EC2ResponseError
from freezegun import freeze_time
import sure # noqa
from moto import mock_ec2_deprecated, mock_ec2
from moto import mock_ec2_deprecated, mock_ec2, mock_cloudformation
from tests.helpers import requires_boto_gte
@ -71,7 +72,7 @@ def test_instance_launch_and_terminate():
instance.id.should.equal(instance.id)
instance.state.should.equal("running")
instance.launch_time.should.equal("2014-01-01T05:00:00.000Z")
instance.vpc_id.should.equal(None)
instance.vpc_id.shouldnt.equal(None)
instance.placement.should.equal("us-east-1a")
root_device_name = instance.root_device_name
@ -1166,6 +1167,21 @@ def test_describe_instance_status_with_instance_filter_deprecated():
cm.exception.request_id.should_not.be.none
@mock_ec2
def test_describe_instance_credit_specifications():
conn = boto3.client("ec2", region_name="us-west-1")
# We want to filter based on this one
reservation = conn.run_instances(ImageId="ami-1234abcd", MinCount=1, MaxCount=1)
result = conn.describe_instance_credit_specifications(
InstanceIds=[reservation["Instances"][0]["InstanceId"]]
)
assert (
result["InstanceCreditSpecifications"][0]["InstanceId"]
== reservation["Instances"][0]["InstanceId"]
)
@mock_ec2
def test_describe_instance_status_with_instance_filter():
conn = boto3.client("ec2", region_name="us-west-1")
@ -1399,3 +1415,40 @@ def test_describe_instance_attribute():
invalid_instance_attribute=invalid_instance_attribute
)
ex.exception.response["Error"]["Message"].should.equal(message)
@mock_ec2
@mock_cloudformation
def test_volume_size_through_cloudformation():
ec2 = boto3.client("ec2", region_name="us-east-1")
cf = boto3.client("cloudformation", region_name="us-east-1")
volume_template = {
"AWSTemplateFormatVersion": "2010-09-09",
"Resources": {
"testInstance": {
"Type": "AWS::EC2::Instance",
"Properties": {
"ImageId": "ami-d3adb33f",
"KeyName": "dummy",
"InstanceType": "t2.micro",
"BlockDeviceMappings": [
{"DeviceName": "/dev/sda2", "Ebs": {"VolumeSize": "50"}}
],
"Tags": [
{"Key": "foo", "Value": "bar"},
{"Key": "blah", "Value": "baz"},
],
},
}
},
}
template_json = json.dumps(volume_template)
cf.create_stack(StackName="test_stack", TemplateBody=template_json)
instances = ec2.describe_instances()
volume = instances["Reservations"][0]["Instances"][0]["BlockDeviceMappings"][0][
"Ebs"
]
volumes = ec2.describe_volumes(VolumeIds=[volume["VolumeId"]])
volumes["Volumes"][0]["Size"].should.equal(50)

View File

@ -599,3 +599,20 @@ def validate_subnet_details_after_creating_eni(
for eni in enis_created:
client.delete_network_interface(NetworkInterfaceId=eni["NetworkInterfaceId"])
client.delete_subnet(SubnetId=subnet["SubnetId"])
@mock_ec2
def test_run_instances_should_attach_to_default_subnet():
# https://github.com/spulec/moto/issues/2877
ec2 = boto3.resource("ec2", region_name="us-west-1")
client = boto3.client("ec2", region_name="us-west-1")
ec2.create_security_group(GroupName="sg01", Description="Test security group sg01")
# run_instances
instances = client.run_instances(MinCount=1, MaxCount=1, SecurityGroups=["sg01"],)
# Assert subnet is created appropriately
subnets = client.describe_subnets()["Subnets"]
default_subnet_id = subnets[0]["SubnetId"]
instances["Instances"][0]["NetworkInterfaces"][0]["SubnetId"].should.equal(
default_subnet_id
)
subnets[0]["AvailableIpAddressCount"].should.equal(4090)

View File

@ -1122,6 +1122,71 @@ def test_run_task():
response["tasks"][0]["stoppedReason"].should.equal("")
@mock_ec2
@mock_ecs
def test_run_task_default_cluster():
client = boto3.client("ecs", region_name="us-east-1")
ec2 = boto3.resource("ec2", region_name="us-east-1")
test_cluster_name = "default"
_ = client.create_cluster(clusterName=test_cluster_name)
test_instance = ec2.create_instances(
ImageId="ami-1234abcd", MinCount=1, MaxCount=1
)[0]
instance_id_document = json.dumps(
ec2_utils.generate_instance_identity_document(test_instance)
)
response = client.register_container_instance(
cluster=test_cluster_name, instanceIdentityDocument=instance_id_document
)
_ = client.register_task_definition(
family="test_ecs_task",
containerDefinitions=[
{
"name": "hello_world",
"image": "docker/hello-world:latest",
"cpu": 1024,
"memory": 400,
"essential": True,
"environment": [
{"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"}
],
"logConfiguration": {"logDriver": "json-file"},
}
],
)
response = client.run_task(
launchType="FARGATE",
overrides={},
taskDefinition="test_ecs_task",
count=2,
startedBy="moto",
)
len(response["tasks"]).should.equal(2)
response["tasks"][0]["taskArn"].should.contain(
"arn:aws:ecs:us-east-1:012345678910:task/"
)
response["tasks"][0]["clusterArn"].should.equal(
"arn:aws:ecs:us-east-1:012345678910:cluster/default"
)
response["tasks"][0]["taskDefinitionArn"].should.equal(
"arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1"
)
response["tasks"][0]["containerInstanceArn"].should.contain(
"arn:aws:ecs:us-east-1:012345678910:container-instance/"
)
response["tasks"][0]["overrides"].should.equal({})
response["tasks"][0]["lastStatus"].should.equal("RUNNING")
response["tasks"][0]["desiredStatus"].should.equal("RUNNING")
response["tasks"][0]["startedBy"].should.equal("moto")
response["tasks"][0]["stoppedReason"].should.equal("")
@mock_ec2
@mock_ecs
def test_start_task():

View File

@ -79,13 +79,23 @@ def generate_environment():
@mock_events
def test_put_rule():
client = boto3.client("events", "us-west-2")
client.list_rules()["Rules"].should.have.length_of(0)
rule_data = get_random_rule()
rule_data = {
"Name": "my-event",
"ScheduleExpression": "rate(5 minutes)",
"EventPattern": '{"source": ["test-source"]}',
}
client.put_rule(**rule_data)
client.list_rules()["Rules"].should.have.length_of(1)
rules = client.list_rules()["Rules"]
rules.should.have.length_of(1)
rules[0]["Name"].should.equal(rule_data["Name"])
rules[0]["ScheduleExpression"].should.equal(rule_data["ScheduleExpression"])
rules[0]["EventPattern"].should.equal(rule_data["EventPattern"])
rules[0]["State"].should.equal("ENABLED")
@mock_events

View File

@ -52,6 +52,29 @@ def test_get_database_not_exits():
)
@mock_glue
def test_get_databases_empty():
client = boto3.client("glue", region_name="us-east-1")
response = client.get_databases()
response["DatabaseList"].should.have.length_of(0)
@mock_glue
def test_get_databases_several_items():
client = boto3.client("glue", region_name="us-east-1")
database_name_1, database_name_2 = "firstdatabase", "seconddatabase"
helpers.create_database(client, database_name_1)
helpers.create_database(client, database_name_2)
database_list = sorted(
client.get_databases()["DatabaseList"], key=lambda x: x["Name"]
)
database_list.should.have.length_of(2)
database_list[0].should.equal({"Name": database_name_1})
database_list[1].should.equal({"Name": database_name_2})
@mock_glue
def test_create_table():
client = boto3.client("glue", region_name="us-east-1")

View File

@ -728,6 +728,14 @@ def test_principal_thing():
res = client.list_thing_principals(thingName=thing_name)
res.should.have.key("principals").which.should.have.length_of(0)
with assert_raises(ClientError) as e:
client.list_thing_principals(thingName="xxx")
e.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException")
e.exception.response["Error"]["Message"].should.equal(
"Failed to list principals for thing xxx because the thing does not exist in your account"
)
@mock_iot
def test_delete_principal_thing():

View File

@ -12,17 +12,14 @@ _logs_region = "us-east-1" if settings.TEST_SERVER_MODE else "us-west-2"
@mock_logs
def test_log_group_create():
def test_create_log_group():
conn = boto3.client("logs", "us-west-2")
log_group_name = "dummy"
response = conn.create_log_group(logGroupName=log_group_name)
response = conn.describe_log_groups(logGroupNamePrefix=log_group_name)
assert len(response["logGroups"]) == 1
# AWS defaults to Never Expire for log group retention
assert response["logGroups"][0].get("retentionInDays") == None
response = conn.create_log_group(logGroupName="dummy")
response = conn.describe_log_groups()
response = conn.delete_log_group(logGroupName=log_group_name)
response["logGroups"].should.have.length_of(1)
response["logGroups"][0].should_not.have.key("retentionInDays")
@mock_logs

View File

@ -183,12 +183,12 @@ def test_start_database():
@mock_rds2
def test_fail_to_stop_multi_az():
def test_fail_to_stop_multi_az_and_sqlserver():
conn = boto3.client("rds", region_name="us-west-2")
database = conn.create_db_instance(
DBInstanceIdentifier="db-master-1",
AllocatedStorage=10,
Engine="postgres",
Engine="sqlserver-ee",
DBName="staging-postgres",
DBInstanceClass="db.m1.small",
LicenseModel="license-included",
@ -213,6 +213,33 @@ def test_fail_to_stop_multi_az():
).should.throw(ClientError)
@mock_rds2
def test_stop_multi_az_postgres():
conn = boto3.client("rds", region_name="us-west-2")
database = conn.create_db_instance(
DBInstanceIdentifier="db-master-1",
AllocatedStorage=10,
Engine="postgres",
DBName="staging-postgres",
DBInstanceClass="db.m1.small",
LicenseModel="license-included",
MasterUsername="root",
MasterUserPassword="hunter2",
Port=1234,
DBSecurityGroups=["my_sg"],
MultiAZ=True,
)
mydb = conn.describe_db_instances(
DBInstanceIdentifier=database["DBInstance"]["DBInstanceIdentifier"]
)["DBInstances"][0]
mydb["DBInstanceStatus"].should.equal("available")
response = conn.stop_db_instance(DBInstanceIdentifier=mydb["DBInstanceIdentifier"])
response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
response["DBInstance"]["DBInstanceStatus"].should.equal("stopped")
@mock_rds2
def test_fail_to_stop_readreplica():
conn = boto3.client("rds", region_name="us-west-2")

View File

@ -14,6 +14,7 @@ from io import BytesIO
import mimetypes
import zlib
import pickle
import uuid
import json
import boto
@ -4424,3 +4425,41 @@ def test_s3_config_dict():
assert not logging_bucket["supplementaryConfiguration"].get(
"BucketTaggingConfiguration"
)
@mock_s3
def test_creating_presigned_post():
bucket = "presigned-test"
s3 = boto3.client("s3", region_name="us-east-1")
s3.create_bucket(Bucket=bucket)
success_url = "http://localhost/completed"
fdata = b"test data\n"
file_uid = uuid.uuid4()
conditions = [
{"Content-Type": "text/plain"},
{"x-amz-server-side-encryption": "AES256"},
{"success_action_redirect": success_url},
]
conditions.append(["content-length-range", 1, 30])
data = s3.generate_presigned_post(
Bucket=bucket,
Key="{file_uid}.txt".format(file_uid=file_uid),
Fields={
"content-type": "text/plain",
"success_action_redirect": success_url,
"x-amz-server-side-encryption": "AES256",
},
Conditions=conditions,
ExpiresIn=1000,
)
resp = requests.post(
data["url"], data=data["fields"], files={"file": fdata}, allow_redirects=False
)
assert resp.headers["Location"] == success_url
assert resp.status_code == 303
assert (
s3.get_object(Bucket=bucket, Key="{file_uid}.txt".format(file_uid=file_uid))[
"Body"
].read()
== fdata
)

View File

@ -711,3 +711,79 @@ def test_can_list_secret_version_ids():
returned_version_ids = [v["VersionId"] for v in versions_list["Versions"]]
assert [first_version_id, second_version_id].sort() == returned_version_ids.sort()
@mock_secretsmanager
def test_update_secret():
conn = boto3.client("secretsmanager", region_name="us-west-2")
created_secret = conn.create_secret(Name="test-secret", SecretString="foosecret")
assert created_secret["ARN"]
assert created_secret["Name"] == "test-secret"
assert created_secret["VersionId"] != ""
secret = conn.get_secret_value(SecretId="test-secret")
assert secret["SecretString"] == "foosecret"
updated_secret = conn.update_secret(
SecretId="test-secret", SecretString="barsecret"
)
assert updated_secret["ARN"]
assert updated_secret["Name"] == "test-secret"
assert updated_secret["VersionId"] != ""
secret = conn.get_secret_value(SecretId="test-secret")
assert secret["SecretString"] == "barsecret"
assert created_secret["VersionId"] != updated_secret["VersionId"]
@mock_secretsmanager
def test_update_secret_which_does_not_exit():
conn = boto3.client("secretsmanager", region_name="us-west-2")
with assert_raises(ClientError) as cm:
updated_secret = conn.update_secret(
SecretId="test-secret", SecretString="barsecret"
)
assert_equal(
"Secrets Manager can't find the specified secret.",
cm.exception.response["Error"]["Message"],
)
@mock_secretsmanager
def test_update_secret_marked_as_deleted():
conn = boto3.client("secretsmanager", region_name="us-west-2")
created_secret = conn.create_secret(Name="test-secret", SecretString="foosecret")
deleted_secret = conn.delete_secret(SecretId="test-secret")
with assert_raises(ClientError) as cm:
updated_secret = conn.update_secret(
SecretId="test-secret", SecretString="barsecret"
)
assert (
"because it was marked for deletion."
in cm.exception.response["Error"]["Message"]
)
@mock_secretsmanager
def test_update_secret_marked_as_deleted_after_restoring():
conn = boto3.client("secretsmanager", region_name="us-west-2")
created_secret = conn.create_secret(Name="test-secret", SecretString="foosecret")
deleted_secret = conn.delete_secret(SecretId="test-secret")
restored_secret = conn.restore_secret(SecretId="test-secret")
updated_secret = conn.update_secret(
SecretId="test-secret", SecretString="barsecret"
)
assert updated_secret["ARN"]
assert updated_secret["Name"] == "test-secret"
assert updated_secret["VersionId"] != ""