Merge LocalStack changes into upstream moto (#4082)

* fix OPTIONS requests on non-existing API GW integrations

* add cloudformation models for API Gateway deployments

* bump version

* add backdoor to return CloudWatch metrics

* Updating implementation coverage

* Updating implementation coverage

* add cloudformation models for API Gateway deployments

* Updating implementation coverage

* Updating implementation coverage

* Implemented get-caller-identity returning real data depending on the access key used.

* bump version

* minor fixes

* fix Number data_type for SQS message attribute

* fix handling of encoding errors

* bump version

* make CF stack queryable before starting to initialize its resources

* bump version

* fix integration_method for API GW method integrations

* fix undefined status in CF FakeStack

* Fix apigateway issues with terraform v0.12.21
* resource_methods -> add handle for "DELETE" method
* integrations -> fix issue that "httpMethod" wasn't included in body request (this value was set as the value from refer method resource)

* bump version

* Fix setting http method for API gateway integrations (#6)

* bump version

* remove duplicate methods

* add storage class to S3 Key when completing multipart upload (#7)

* fix SQS performance issues; bump version

* add pagination to SecretsManager list-secrets (#9)

* fix default parameter groups in RDS

* fix adding S3 metadata headers with names containing dots (#13)

* Updating implementation coverage

* Updating implementation coverage

* add cloudformation models for API Gateway deployments

* Updating implementation coverage

* Updating implementation coverage

* Implemented get-caller-identity returning real data depending on the access key used.

* make CF stack queryable before starting to initialize its resources

* bump version

* remove duplicate methods

* fix adding S3 metadata headers with names containing dots (#13)

* Update amis.json to support EKS AMI mocks (#15)

* fix PascalCase for boolean value in ListMultipartUploads response (#17); fix _get_multi_param to parse nested list/dict query params

* determine non-zero container exit code in Batch API

* support filtering by dimensions in CW get_metric_statistics

* fix storing attributes for ELBv2 Route entities; API GW refactorings for TF tests

* add missing fields for API GW resources

* fix error messages for Route53 (TF-compat)

* various fixes for IAM resources (tf-compat)

* minor fixes for API GW models (tf-compat)

* minor fixes for API GW responses (tf-compat)

* add s3 exception for bucket notification filter rule validation

* change the way RESTErrors generate the response body and content-type header

* fix lint errors and disable "black" syntax enforcement

* remove return type hint in RESTError.get_body

* add RESTError XML template for IAM exceptions

* add support for API GW minimumCompressionSize

* fix casing getting PrivateDnsEnabled API GW attribute

* minor fixes for error responses

* fix escaping special chars for IAM role descriptions (tf-compat)

* minor fixes and tagging support for API GW and ELB v2 (tf-compat)

* Merge branch 'master' into localstack

* add "AlarmRule" attribute to enable support for composite CloudWatch metrics

* fix recursive parsing of complex/nested query params

* bump version

* add API to delete S3 website configurations (#18)

* use dict copy to allow parallelism and avoid concurrent modification exceptions in S3

* fix precondition check for etags in S3 (#19)

* minor fix for user filtering in Cognito

* fix API Gateway error response; avoid returning empty response templates (tf-compat)

* support tags and tracingEnabled attribute for API GW stages

* fix boolean value in S3 encryption response (#20)

* fix connection arn structure

* fix api destination arn structure

* black format

* release 2.0.3.37

* fix s3 exception tests

see botocore/parsers.py:1002 where RequestId is removed from parsed

* remove python 2 from build action

* add test failure annotations in build action

* fix events test arn comparisons

* fix s3 encryption response test

* return default value "0" if EC2 availableIpAddressCount is empty

* fix extracting SecurityGroupIds for EC2 VPC endpoints

* support deleting/updating API Gateway DomainNames

* fix(events): Return empty string instead of null when no pattern is specified in EventPattern (tf-compat) (#22)

* fix logic and revert CF changes to get tests running again (#21)

* add support for EC2 customer gateway API (#25)

* add support for EC2 Transit Gateway APIs (#24)

* feat(logs): add `kmsKeyId` into `LogGroup` entity (#23)

* minor change in ELBv2 logic to fix tests

* feat(events): add APIs to describe and delete CloudWatch Events connections (#26)

* add support for EC2 transit gateway route tables (#27)

* pass transit gateway route table ID in Describe API, minor refactoring (#29)

* add support for EC2 Transit Gateway Routes (#28)

* fix region on ACM certificate import (#31)

* add support for EC2 transit gateway attachments (#30)

* add support for EC2 Transit Gateway VPN attachments (#32)

* fix account ID for logs API

* add support for DeleteOrganization API

* feat(events): store raw filter representation for CloudWatch events patterns (tf-compat) (#36)

* feat(events): add support to describe/update/delete CloudWatch API destinations (#35)

* add Cognito UpdateIdentityPool, CW Logs PutResourcePolicy

* feat(events): add support for tags in EventBus API (#38)

* fix parameter validation for Batch compute environments (tf-compat)

* revert merge conflicts in IMPLEMENTATION_COVERAGE.md

* format code using black

* restore original README; re-enable and fix CloudFormation tests

* restore tests and old logic for CF stack parameters from SSM

* parameterize RequestId/RequestID in response messages and revert related test changes

* undo LocalStack-specific adaptations

* minor fix

* Update CodeCov config to reflect removal of Py2

* undo change related to CW metric filtering; add additional test for CW metric statistics with dimensions

* Terraform - Extend whitelist of running tests

Co-authored-by: acsbendi <acsbendi28@gmail.com>
Co-authored-by: Phan Duong <duongpv@outlook.com>
Co-authored-by: Thomas Rausch <thomas@thrau.at>
Co-authored-by: Macwan Nevil <macnev2013@gmail.com>
Co-authored-by: Dominik Schubert <dominik.schubert91@gmail.com>
Co-authored-by: Gonzalo Saad <saad.gonzalo.ale@gmail.com>
Co-authored-by: Mohit Alonja <monty16597@users.noreply.github.com>
Co-authored-by: Miguel Gagliardo <migag9@gmail.com>
Co-authored-by: Bert Blommers <info@bertblommers.nl>
This commit is contained in:
Waldemar Hummer 2021-07-26 16:21:17 +02:00 committed by GitHub
parent 7693d77333
commit f4f8527955
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
93 changed files with 2947 additions and 631 deletions

View File

@ -90,7 +90,7 @@ jobs:
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}-4
# Update PIP - recent version does not support PY2 though
# Update PIP
- name: Update pip
run: |
# https://github.com/pypa/pip/issues/10201
@ -136,6 +136,7 @@ jobs:
run: |
pip install -r requirements-dev.txt
pip install pytest-cov
pip install pytest-github-actions-annotate-failures
- name: Test with pytest
run: |
make test-coverage
@ -229,7 +230,7 @@ jobs:
- name: Run Terraform Tests
run: |
cd moto-terraform-tests
bin/run-tests -i ../tests/terraform-tests.success.txt
bin/run-tests -i ../tests/terraform-tests.success.txt -e ../tests/terraform-tests.failures.txt
cd ..
- name: "Create report"
run: |

2
.gitignore vendored
View File

@ -1,4 +1,4 @@
moto.egg-info/*
moto*.egg-info/*
dist/*
.cache
.tox

View File

@ -3,8 +3,7 @@ SHELL := /bin/bash
ifeq ($(TEST_SERVER_MODE), true)
# exclude test_kinesisvideoarchivedmedia
# because testing with moto_server is difficult with data-endpoint
TEST_EXCLUDE := -k 'not test_kinesisvideoarchivedmedia'
TEST_EXCLUDE := -k 'not test_kinesisvideoarchivedmedia'
else
TEST_EXCLUDE :=
endif

View File

@ -448,11 +448,13 @@ class AWSCertificateManagerBackend(BaseBackend):
else:
# Will reuse provided ARN
bundle = CertBundle(
certificate, private_key, chain=chain, region=region, arn=arn
certificate, private_key, chain=chain, region=self.region, arn=arn
)
else:
# Will generate a random ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region)
bundle = CertBundle(
certificate, private_key, chain=chain, region=self.region
)
self._certificates[bundle.arn] = bundle

View File

@ -1,8 +1,16 @@
from __future__ import unicode_literals
from moto.core.exceptions import RESTError
from moto.core.exceptions import JsonRESTError
class BadRequestException(RESTError):
class BadRequestException(JsonRESTError):
pass
class NotFoundException(JsonRESTError):
pass
class AccessDeniedException(JsonRESTError):
pass
@ -14,7 +22,7 @@ class AwsProxyNotAllowed(BadRequestException):
)
class CrossAccountNotAllowed(RESTError):
class CrossAccountNotAllowed(AccessDeniedException):
def __init__(self):
super(CrossAccountNotAllowed, self).__init__(
"AccessDeniedException", "Cross-account pass role is not allowed."
@ -71,10 +79,19 @@ class InvalidRequestInput(BadRequestException):
)
class NoIntegrationDefined(BadRequestException):
class NoIntegrationDefined(NotFoundException):
def __init__(self):
super(NoIntegrationDefined, self).__init__(
"BadRequestException", "No integration defined for method"
"NotFoundException", "No integration defined for method"
)
class NoIntegrationResponseDefined(NotFoundException):
code = 404
def __init__(self, code=None):
super(NoIntegrationResponseDefined, self).__init__(
"NotFoundException", "No integration defined for method, code '%s'" % code
)
@ -85,7 +102,7 @@ class NoMethodDefined(BadRequestException):
)
class AuthorizerNotFoundException(RESTError):
class AuthorizerNotFoundException(NotFoundException):
code = 404
def __init__(self):
@ -94,7 +111,7 @@ class AuthorizerNotFoundException(RESTError):
)
class StageNotFoundException(RESTError):
class StageNotFoundException(NotFoundException):
code = 404
def __init__(self):
@ -103,7 +120,7 @@ class StageNotFoundException(RESTError):
)
class ApiKeyNotFoundException(RESTError):
class ApiKeyNotFoundException(NotFoundException):
code = 404
def __init__(self):
@ -112,7 +129,7 @@ class ApiKeyNotFoundException(RESTError):
)
class UsagePlanNotFoundException(RESTError):
class UsagePlanNotFoundException(NotFoundException):
code = 404
def __init__(self):
@ -121,7 +138,7 @@ class UsagePlanNotFoundException(RESTError):
)
class ApiKeyAlreadyExists(RESTError):
class ApiKeyAlreadyExists(JsonRESTError):
code = 409
def __init__(self):
@ -139,7 +156,7 @@ class InvalidDomainName(BadRequestException):
)
class DomainNameNotFound(RESTError):
class DomainNameNotFound(NotFoundException):
code = 404
def __init__(self):
@ -166,7 +183,7 @@ class InvalidModelName(BadRequestException):
)
class RestAPINotFound(RESTError):
class RestAPINotFound(NotFoundException):
code = 404
def __init__(self):
@ -175,7 +192,7 @@ class RestAPINotFound(RESTError):
)
class ModelNotFound(RESTError):
class ModelNotFound(NotFoundException):
code = 404
def __init__(self):
@ -184,10 +201,19 @@ class ModelNotFound(RESTError):
)
class ApiKeyValueMinLength(RESTError):
class ApiKeyValueMinLength(BadRequestException):
code = 400
def __init__(self):
super(ApiKeyValueMinLength, self).__init__(
"BadRequestException", "API Key value should be at least 20 characters"
)
class MethodNotFoundException(NotFoundException):
code = 404
def __init__(self):
super(MethodNotFoundException, self).__init__(
"NotFoundException", "Invalid method properties specified"
)

View File

@ -33,6 +33,7 @@ from .exceptions import (
StageNotFoundException,
RoleNotSpecified,
NoIntegrationDefined,
NoIntegrationResponseDefined,
NoMethodDefined,
ApiKeyAlreadyExists,
DomainNameNotFound,
@ -44,6 +45,7 @@ from .exceptions import (
ApiKeyValueMinLength,
)
from ..core.models import responses_mock
from moto.apigateway.exceptions import MethodNotFoundException
STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}"
@ -87,7 +89,12 @@ class IntegrationResponse(BaseModel, dict):
content_handling=None,
):
if response_templates is None:
response_templates = {"application/json": None}
# response_templates = {"application/json": None} # Note: removed for compatibility with TF
response_templates = {}
for key in response_templates.keys():
response_templates[key] = (
response_templates[key] or None
) # required for compatibility with TF
self["responseTemplates"] = response_templates
self["statusCode"] = status_code
if selection_pattern:
@ -97,13 +104,26 @@ class IntegrationResponse(BaseModel, dict):
class Integration(BaseModel, dict):
def __init__(self, integration_type, uri, http_method, request_templates=None):
def __init__(
self,
integration_type,
uri,
http_method,
request_templates=None,
tls_config=None,
cache_namespace=None,
):
super(Integration, self).__init__()
self["type"] = integration_type
self["uri"] = uri
self["httpMethod"] = http_method
self["requestTemplates"] = request_templates
self["integrationResponses"] = {"200": IntegrationResponse(200)}
# self["integrationResponses"] = {"200": IntegrationResponse(200)} # commented out (tf-compat)
self[
"integrationResponses"
] = None # prevent json serialization from including them if none provided
self["tlsConfig"] = tls_config
self["cacheNamespace"] = cache_namespace
def create_integration_response(
self, status_code, selection_pattern, response_templates, content_handling
@ -113,20 +133,27 @@ class Integration(BaseModel, dict):
integration_response = IntegrationResponse(
status_code, selection_pattern, response_templates, content_handling
)
if self.get("integrationResponses") is None:
self["integrationResponses"] = {}
self["integrationResponses"][status_code] = integration_response
return integration_response
def get_integration_response(self, status_code):
return self["integrationResponses"][status_code]
result = self.get("integrationResponses", {}).get(status_code)
if not result:
raise NoIntegrationResponseDefined(status_code)
return result
def delete_integration_response(self, status_code):
return self["integrationResponses"].pop(status_code)
return self.get("integrationResponses", {}).pop(status_code, None)
class MethodResponse(BaseModel, dict):
def __init__(self, status_code):
def __init__(self, status_code, response_models=None, response_parameters=None):
super(MethodResponse, self).__init__()
self["statusCode"] = status_code
self["responseModels"] = response_models
self["responseParameters"] = response_parameters
class Method(CloudFormationModel, dict):
@ -136,11 +163,14 @@ class Method(CloudFormationModel, dict):
dict(
httpMethod=method_type,
authorizationType=authorization_type,
authorizerId=None,
authorizerId=kwargs.get("authorizer_id"),
authorizationScopes=kwargs.get("authorization_scopes"),
apiKeyRequired=kwargs.get("api_key_required") or False,
requestParameters=None,
requestModels=None,
requestModels=kwargs.get("request_models"),
methodIntegration=None,
operationName=kwargs.get("operation_name"),
requestValidatorId=kwargs.get("request_validator_id"),
)
)
self.method_responses = {}
@ -184,16 +214,18 @@ class Method(CloudFormationModel, dict):
)
return m
def create_response(self, response_code):
method_response = MethodResponse(response_code)
def create_response(self, response_code, response_models, response_parameters):
method_response = MethodResponse(
response_code, response_models, response_parameters
)
self.method_responses[response_code] = method_response
return method_response
def get_response(self, response_code):
return self.method_responses[response_code]
return self.method_responses.get(response_code)
def delete_response(self, response_code):
return self.method_responses.pop(response_code)
return self.method_responses.pop(response_code, None)
class Resource(CloudFormationModel):
@ -279,29 +311,62 @@ class Resource(CloudFormationModel):
)
return response.status_code, response.text
def add_method(self, method_type, authorization_type, api_key_required):
def add_method(
self,
method_type,
authorization_type,
api_key_required,
request_models=None,
operation_name=None,
authorizer_id=None,
authorization_scopes=None,
request_validator_id=None,
):
if authorization_scopes and not isinstance(authorization_scopes, list):
authorization_scopes = [authorization_scopes]
method = Method(
method_type=method_type,
authorization_type=authorization_type,
api_key_required=api_key_required,
request_models=request_models,
operation_name=operation_name,
authorizer_id=authorizer_id,
authorization_scopes=authorization_scopes,
request_validator_id=request_validator_id,
)
self.resource_methods[method_type] = method
return method
def get_method(self, method_type):
return self.resource_methods[method_type]
method = self.resource_methods.get(method_type)
if not method:
raise MethodNotFoundException()
return method
def add_integration(
self, method_type, integration_type, uri, request_templates=None
self,
method_type,
integration_type,
uri,
request_templates=None,
integration_method=None,
tls_config=None,
cache_namespace=None,
):
integration_method = integration_method or method_type
integration = Integration(
integration_type, uri, method_type, request_templates=request_templates
integration_type,
uri,
integration_method,
request_templates=request_templates,
tls_config=tls_config,
cache_namespace=cache_namespace,
)
self.resource_methods[method_type]["methodIntegration"] = integration
return integration
def get_integration(self, method_type):
return self.resource_methods[method_type]["methodIntegration"]
return self.resource_methods.get(method_type, {}).get("methodIntegration", {})
def delete_integration(self, method_type):
return self.resource_methods[method_type].pop("methodIntegration")
@ -364,6 +429,8 @@ class Stage(BaseModel, dict):
description="",
cacheClusterEnabled=False,
cacheClusterSize=None,
tags=None,
tracing_enabled=None,
):
super(Stage, self).__init__()
if variables is None:
@ -376,9 +443,12 @@ class Stage(BaseModel, dict):
self["cacheClusterEnabled"] = cacheClusterEnabled
if self["cacheClusterEnabled"]:
self["cacheClusterSize"] = str(0.5)
if cacheClusterSize is not None:
self["cacheClusterSize"] = str(cacheClusterSize)
if tags is not None:
self["tags"] = tags
if tracing_enabled is not None:
self["tracingEnabled"] = tracing_enabled
def apply_operations(self, patch_operations):
for op in patch_operations:
@ -607,6 +677,7 @@ class RestAPI(CloudFormationModel):
self.disableExecuteApiEndpoint = (
kwargs.get("disableExecuteApiEndpoint") or False
)
self.minimum_compression_size = kwargs.get("minimum_compression_size")
self.deployments = {}
self.authorizers = {}
self.stages = {}
@ -624,12 +695,13 @@ class RestAPI(CloudFormationModel):
"description": self.description,
"version": self.version,
"binaryMediaTypes": self.binaryMediaTypes,
"createdDate": int(time.time()),
"createdDate": self.create_date,
"apiKeySource": self.api_key_source,
"endpointConfiguration": self.endpoint_configuration,
"tags": self.tags,
"policy": self.policy,
"disableExecuteApiEndpoint": self.disableExecuteApiEndpoint,
"minimumCompressionSize": self.minimum_compression_size,
}
def apply_patch_operations(self, patch_operations):
@ -652,7 +724,10 @@ class RestAPI(CloudFormationModel):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "RootResourceId":
return self.id
for res_id, res_obj in self.resources.items():
if res_obj.path_part == "/" and not res_obj.parent_id:
return res_id
raise Exception("Unable to find root resource for API %s" % self)
raise UnformattedGetAttTemplateException()
@property
@ -787,6 +862,8 @@ class RestAPI(CloudFormationModel):
description="",
cacheClusterEnabled=None,
cacheClusterSize=None,
tags=None,
tracing_enabled=None,
):
if variables is None:
variables = {}
@ -797,6 +874,8 @@ class RestAPI(CloudFormationModel):
description=description,
cacheClusterSize=cacheClusterSize,
cacheClusterEnabled=cacheClusterEnabled,
tags=tags,
tracing_enabled=tracing_enabled,
)
self.stages[name] = stage
self.update_integration_mocks(name)
@ -835,8 +914,11 @@ class DomainName(BaseModel, dict):
def __init__(self, domain_name, **kwargs):
super(DomainName, self).__init__()
self["domainName"] = domain_name
self["regionalDomainName"] = domain_name
self["distributionDomainName"] = domain_name
self["regionalDomainName"] = "d-%s.execute-api.%s.amazonaws.com" % (
create_id(),
kwargs.get("region_name") or "us-east-1",
)
self["distributionDomainName"] = "d%s.cloudfront.net" % create_id()
self["domainNameStatus"] = "AVAILABLE"
self["domainNameStatusMessage"] = "Domain Name Available"
self["regionalHostedZoneId"] = "Z2FDTNDATAQYW2"
@ -907,6 +989,7 @@ class APIGatewayBackend(BaseBackend):
endpoint_configuration=None,
tags=None,
policy=None,
minimum_compression_size=None,
):
api_id = create_id()
rest_api = RestAPI(
@ -918,6 +1001,7 @@ class APIGatewayBackend(BaseBackend):
endpoint_configuration=endpoint_configuration,
tags=tags,
policy=policy,
minimum_compression_size=minimum_compression_size,
)
self.apis[api_id] = rest_api
return rest_api
@ -974,13 +1058,30 @@ class APIGatewayBackend(BaseBackend):
method_type,
authorization_type,
api_key_required=None,
request_models=None,
operation_name=None,
authorizer_id=None,
authorization_scopes=None,
request_validator_id=None,
):
resource = self.get_resource(function_id, resource_id)
method = resource.add_method(
method_type, authorization_type, api_key_required=api_key_required
method_type,
authorization_type,
api_key_required=api_key_required,
request_models=request_models,
operation_name=operation_name,
authorizer_id=authorizer_id,
authorization_scopes=authorization_scopes,
request_validator_id=request_validator_id,
)
return method
def update_method(self, function_id, resource_id, method_type, patch_operations):
resource = self.get_resource(function_id, resource_id)
method = resource.get_method(method_type)
return method.apply_operations(patch_operations)
def get_authorizer(self, restapi_id, authorizer_id):
api = self.get_rest_api(restapi_id)
authorizer = api.authorizers.get(authorizer_id)
@ -1026,8 +1127,7 @@ class APIGatewayBackend(BaseBackend):
stage = api.stages.get(stage_name)
if stage is None:
raise StageNotFoundException()
else:
return stage
return stage
def get_stages(self, function_id):
api = self.get_rest_api(function_id)
@ -1042,6 +1142,8 @@ class APIGatewayBackend(BaseBackend):
description="",
cacheClusterEnabled=None,
cacheClusterSize=None,
tags=None,
tracing_enabled=None,
):
if variables is None:
variables = {}
@ -1053,6 +1155,8 @@ class APIGatewayBackend(BaseBackend):
description=description,
cacheClusterEnabled=cacheClusterEnabled,
cacheClusterSize=cacheClusterSize,
tags=tags,
tracing_enabled=tracing_enabled,
)
return api.stages.get(stage_name)
@ -1065,7 +1169,9 @@ class APIGatewayBackend(BaseBackend):
def delete_stage(self, function_id, stage_name):
api = self.get_rest_api(function_id)
del api.stages[stage_name]
deleted = api.stages.pop(stage_name, None)
if not deleted:
raise StageNotFoundException()
def get_method_response(self, function_id, resource_id, method_type, response_code):
method = self.get_method(function_id, resource_id, method_type)
@ -1073,10 +1179,26 @@ class APIGatewayBackend(BaseBackend):
return method_response
def create_method_response(
self, function_id, resource_id, method_type, response_code
self,
function_id,
resource_id,
method_type,
response_code,
response_models,
response_parameters,
):
method = self.get_method(function_id, resource_id, method_type)
method_response = method.create_response(response_code)
method_response = method.create_response(
response_code, response_models, response_parameters
)
return method_response
def update_method_response(
self, function_id, resource_id, method_type, response_code, patch_operations
):
method = self.get_method(function_id, resource_id, method_type)
method_response = method.get_response(response_code)
method_response.apply_operations(patch_operations)
return method_response
def delete_method_response(
@ -1096,6 +1218,8 @@ class APIGatewayBackend(BaseBackend):
integration_method=None,
credentials=None,
request_templates=None,
tls_config=None,
cache_namespace=None,
):
resource = self.get_resource(function_id, resource_id)
if credentials and not re.match(
@ -1128,7 +1252,13 @@ class APIGatewayBackend(BaseBackend):
):
raise InvalidIntegrationArn()
integration = resource.add_integration(
method_type, integration_type, uri, request_templates=request_templates
method_type,
integration_type,
uri,
integration_method=integration_method,
request_templates=request_templates,
tls_config=tls_config,
cache_namespace=cache_namespace,
)
return integration
@ -1205,7 +1335,7 @@ class APIGatewayBackend(BaseBackend):
return api.delete_deployment(deployment_id)
def create_api_key(self, payload):
if payload.get("value") is not None:
if payload.get("value"):
if len(payload.get("value", [])) < 20:
raise ApiKeyValueMinLength()
for api_key in self.get_api_keys(include_values=True):
@ -1229,7 +1359,9 @@ class APIGatewayBackend(BaseBackend):
return api_keys
def get_api_key(self, api_key_id, include_value=False):
api_key = self.keys[api_key_id]
api_key = self.keys.get(api_key_id)
if not api_key:
raise ApiKeyNotFoundException()
if not include_value:
new_key = copy(api_key)
@ -1322,7 +1454,7 @@ class APIGatewayBackend(BaseBackend):
def _uri_validator(self, uri):
try:
result = urlparse(uri)
return all([result.scheme, result.netloc, result.path])
return all([result.scheme, result.netloc, result.path or "/"])
except Exception:
return False
@ -1358,6 +1490,7 @@ class APIGatewayBackend(BaseBackend):
tags=tags,
security_policy=security_policy,
generate_cli_skeleton=generate_cli_skeleton,
region_name=self.region_name,
)
self.domain_names[domain_name] = new_domain_name
@ -1369,10 +1502,22 @@ class APIGatewayBackend(BaseBackend):
def get_domain_name(self, domain_name):
domain_info = self.domain_names.get(domain_name)
if domain_info is None:
raise DomainNameNotFound
raise DomainNameNotFound()
else:
return self.domain_names[domain_name]
def delete_domain_name(self, domain_name):
domain_info = self.domain_names.pop(domain_name, None)
if domain_info is None:
raise DomainNameNotFound()
def update_domain_name(self, domain_name, patch_operations):
domain_info = self.domain_names.get(domain_name)
if not domain_info:
raise DomainNameNotFound()
domain_info.apply_patch_operations(patch_operations)
return domain_info
def create_model(
self,
rest_api_id,

View File

@ -20,6 +20,8 @@ from .exceptions import (
ModelNotFound,
ApiKeyValueMinLength,
InvalidRequestInput,
NoIntegrationDefined,
NotFoundException,
)
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
@ -29,9 +31,11 @@ ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"]
class APIGatewayResponse(BaseResponse):
def error(self, type_, message, status=400):
headers = self.response_headers or {}
headers["X-Amzn-Errortype"] = type_
return (
status,
self.response_headers,
headers,
json.dumps({"__type": type_, "message": message}),
)
@ -80,6 +84,7 @@ class APIGatewayResponse(BaseResponse):
endpoint_configuration = self._get_param("endpointConfiguration")
tags = self._get_param("tags")
policy = self._get_param("policy")
minimum_compression_size = self._get_param("minimumCompressionSize")
# Param validation
response = self.__validate_api_key_source(api_key_source)
@ -97,6 +102,7 @@ class APIGatewayResponse(BaseResponse):
endpoint_configuration=endpoint_configuration,
tags=tags,
policy=policy,
minimum_compression_size=minimum_compression_size,
)
return 200, {}, json.dumps(rest_api.to_dict())
@ -162,9 +168,7 @@ class APIGatewayResponse(BaseResponse):
resource = self.backend.delete_resource(function_id, resource_id)
return 200, {}, json.dumps(resource.to_dict())
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
return self.error("BadRequestException", e.message)
def resource_methods(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -179,15 +183,37 @@ class APIGatewayResponse(BaseResponse):
elif self.method == "PUT":
authorization_type = self._get_param("authorizationType")
api_key_required = self._get_param("apiKeyRequired")
request_models = self._get_param("requestModels")
operation_name = self._get_param("operationName")
authorizer_id = self._get_param("authorizerId")
authorization_scopes = self._get_param("authorizationScopes")
request_validator_id = self._get_param("requestValidatorId")
method = self.backend.create_method(
function_id,
resource_id,
method_type,
authorization_type,
api_key_required,
request_models=request_models,
operation_name=operation_name,
authorizer_id=authorizer_id,
authorization_scopes=authorization_scopes,
request_validator_id=request_validator_id,
)
return 200, {}, json.dumps(method)
elif self.method == "DELETE":
self.backend.delete_method(function_id, resource_id, method_type)
return 200, {}, ""
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
self.backend.update_method(
function_id, resource_id, method_type, patch_operations
)
return 200, {}, ""
def resource_method_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/")
@ -201,13 +227,27 @@ class APIGatewayResponse(BaseResponse):
function_id, resource_id, method_type, response_code
)
elif self.method == "PUT":
response_models = self._get_param("responseModels")
response_parameters = self._get_param("responseParameters")
method_response = self.backend.create_method_response(
function_id, resource_id, method_type, response_code
function_id,
resource_id,
method_type,
response_code,
response_models,
response_parameters,
)
elif self.method == "DELETE":
method_response = self.backend.delete_method_response(
function_id, resource_id, method_type, response_code
)
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
method_response = self.backend.update_method_response(
function_id, resource_id, method_type, response_code, patch_operations
)
else:
raise Exception('Unexpected HTTP method "%s"' % self.method)
return 200, {}, json.dumps(method_response)
def restapis_authorizers(self, request, full_url, headers):
@ -302,6 +342,8 @@ class APIGatewayResponse(BaseResponse):
description = self._get_param("description", if_none="")
cacheClusterEnabled = self._get_param("cacheClusterEnabled", if_none=False)
cacheClusterSize = self._get_param("cacheClusterSize")
tags = self._get_param("tags")
tracing_enabled = self._get_param("tracingEnabled")
stage_response = self.backend.create_stage(
function_id,
@ -311,6 +353,8 @@ class APIGatewayResponse(BaseResponse):
description=description,
cacheClusterEnabled=cacheClusterEnabled,
cacheClusterSize=cacheClusterSize,
tags=tags,
tracing_enabled=tracing_enabled,
)
elif self.method == "GET":
stages = self.backend.get_stages(function_id)
@ -353,6 +397,8 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6]
try:
integration_response = {}
if self.method == "GET":
integration_response = self.backend.get_integration(
function_id, resource_id, method_type
@ -360,32 +406,39 @@ class APIGatewayResponse(BaseResponse):
elif self.method == "PUT":
integration_type = self._get_param("type")
uri = self._get_param("uri")
integration_http_method = self._get_param("httpMethod")
creds = self._get_param("credentials")
credentials = self._get_param("credentials")
request_templates = self._get_param("requestTemplates")
tls_config = self._get_param("tlsConfig")
cache_namespace = self._get_param("cacheNamespace")
self.backend.get_method(function_id, resource_id, method_type)
integration_http_method = self._get_param(
"httpMethod"
) # default removed because it's a required parameter
integration_response = self.backend.create_integration(
function_id,
resource_id,
method_type,
integration_type,
uri,
credentials=creds,
credentials=credentials,
integration_method=integration_http_method,
request_templates=request_templates,
tls_config=tls_config,
cache_namespace=cache_namespace,
)
elif self.method == "DELETE":
integration_response = self.backend.delete_integration(
function_id, resource_id, method_type
)
return 200, {}, json.dumps(integration_response)
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
return self.error("BadRequestException", e.message)
except CrossAccountNotAllowed as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#AccessDeniedException", e.message
)
return self.error("AccessDeniedException", e.message)
def integration_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -422,9 +475,9 @@ class APIGatewayResponse(BaseResponse):
)
return 200, {}, json.dumps(integration_response)
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
return self.error("BadRequestException", e.message)
except NoIntegrationDefined as e:
return self.error("NotFoundException", e.message)
def deployments(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -443,9 +496,9 @@ class APIGatewayResponse(BaseResponse):
)
return 200, {}, json.dumps(deployment)
except BadRequestException as e:
return self.error(
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
return self.error("BadRequestException", e.message)
except NotFoundException as e:
return self.error("NotFoundException", e.message)
def individual_deployment(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -453,6 +506,7 @@ class APIGatewayResponse(BaseResponse):
function_id = url_path_parts[2]
deployment_id = url_path_parts[4]
deployment = None
if self.method == "GET":
deployment = self.backend.get_deployment(function_id, deployment_id)
elif self.method == "DELETE":
@ -652,6 +706,18 @@ class APIGatewayResponse(BaseResponse):
if self.method == "GET":
if domain_name is not None:
domain_names = self.backend.get_domain_name(domain_name)
elif self.method == "DELETE":
if domain_name is not None:
self.backend.delete_domain_name(domain_name)
elif self.method == "PATCH":
if domain_name is not None:
patch_operations = self._get_param("patchOperations")
self.backend.update_domain_name(domain_name, patch_operations)
else:
msg = (
'Method "%s" for API GW domain names not implemented' % self.method
)
return 404, {}, json.dumps({"error": msg})
return 200, {}, json.dumps(domain_names)
except DomainNameNotFound as error:
return (

View File

@ -155,7 +155,7 @@ class ApplicationAutoscalingBackend(BaseBackend):
service_namespace, resource_id, scalable_dimension, policy_name
)
if policy_key in self.policies:
old_policy = self.policies[policy_name]
old_policy = self.policies[policy_key]
policy = FakeApplicationAutoscalingPolicy(
region_name=self.region,
policy_name=policy_name,

View File

@ -82,7 +82,7 @@ class AthenaResponse(BaseResponse):
def error(self, msg, status):
return (
json.dumps({"__type": "InvalidRequestException", "Message": msg,}),
json.dumps({"__type": "InvalidRequestException", "Message": msg}),
dict(status=status),
)

View File

@ -26,7 +26,7 @@ import requests.exceptions
from boto3 import Session
from moto.awslambda.policy import Policy
from moto.core import BaseBackend, CloudFormationModel
from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.exceptions import RESTError
from moto.iam.models import iam_backend
from moto.iam.exceptions import IAMNotFoundException
@ -1072,6 +1072,32 @@ class LayerStorage(object):
return None
class LambdaPermission(BaseModel):
def __init__(self, spec):
self.action = spec["Action"]
self.function_name = spec["FunctionName"]
self.principal = spec["Principal"]
# optional
self.source_account = spec.get("SourceAccount")
self.source_arn = spec.get("SourceArn")
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
spec = {
"Action": properties["Action"],
"FunctionName": properties["FunctionName"],
"Principal": properties["Principal"],
}
optional_properties = "SourceAccount SourceArn".split()
for prop in optional_properties:
if prop in properties:
spec[prop] = properties[prop]
return LambdaPermission(spec)
class LambdaBackend(BaseBackend):
def __init__(self, region_name):
self._lambdas = LambdaStorage()

View File

@ -16,7 +16,7 @@ from moto.ec2 import ec2_backends
from moto.ecs import ecs_backends
from moto.logs import logs_backends
from .exceptions import InvalidParameterValueException, InternalFailure, ClientException
from .exceptions import InvalidParameterValueException, ClientException
from .utils import (
make_arn_for_compute_env,
make_arn_for_job_queue,
@ -463,7 +463,8 @@ class Job(threading.Thread, BaseModel, DockerModel):
self.job_state = "STARTING"
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
image_repository, image_tag = parse_image_ref(image)
self.docker_client.images.pull(image_repository, image_tag)
# avoid explicit pulling here, to allow using cached images
# self.docker_client.images.pull(image_repository, image_tag)
container = self.docker_client.containers.run(
image,
cmd,
@ -479,6 +480,7 @@ class Job(threading.Thread, BaseModel, DockerModel):
container.reload()
while container.status == "running" and not self.stop:
container.reload()
time.sleep(0.5)
# Container should be stopped by this point... unless asked to stop
if container.status == "running":
@ -531,11 +533,9 @@ class Job(threading.Thread, BaseModel, DockerModel):
self._log_backend.create_log_stream(log_group, stream_name)
self._log_backend.put_log_events(log_group, stream_name, logs, None)
result = container.wait()
if self.stop or result["StatusCode"] != 0:
self._mark_stopped(success=False)
else:
self._mark_stopped(success=True)
result = container.wait() or {}
job_failed = self.stop or result.get("StatusCode", 0) > 0
self._mark_stopped(success=not job_failed)
except Exception as err:
logger.error(
@ -782,6 +782,7 @@ class BatchBackend(BaseBackend):
"state": environment.state,
"type": environment.env_type,
"status": "VALID",
"statusReason": "Compute environment is available",
}
if environment.env_type == "MANAGED":
json_part["computeResources"] = environment.compute_resources
@ -898,9 +899,10 @@ class BatchBackend(BaseBackend):
"type",
):
if param not in cr:
raise InvalidParameterValueException(
"computeResources must contain {0}".format(param)
)
pass # commenting out invalid check below - values may be missing (tf-compat)
# raise InvalidParameterValueException(
# "computeResources must contain {0}".format(param)
# )
for profile in self.iam_backend.get_instance_profiles():
if profile.arn == cr["instanceRole"]:
break
@ -955,9 +957,6 @@ class BatchBackend(BaseBackend):
"computeResources.type must be either EC2 | SPOT"
)
if cr["type"] == "SPOT":
raise InternalFailure("SPOT NOT SUPPORTED YET")
@staticmethod
def find_min_instances_to_meet_vcpus(instance_types, target):
"""
@ -1027,7 +1026,8 @@ class BatchBackend(BaseBackend):
if compute_env.env_type == "MANAGED":
# Delete compute environment
instance_ids = [instance.id for instance in compute_env.instances]
self.ec2_backend.terminate_instances(instance_ids)
if instance_ids:
self.ec2_backend.terminate_instances(instance_ids)
def update_compute_environment(
self, compute_environment_name, state, compute_resources, service_role

View File

@ -505,9 +505,8 @@ class ResourceMap(collections_abc.Mapping):
self._parsed_resources.update(json.loads(key.value))
def parse_ssm_parameter(self, value, value_type):
# The Value in SSM parameters is the SSM parameter path
# we need to use ssm_backend to retreive the
# we need to use ssm_backend to retrieve the
# actual value from parameter store
parameter = ssm_backends[self._region_name].get_parameter(value, False)
actual_value = parameter.value

View File

@ -44,6 +44,8 @@ def yaml_tag_constructor(loader, tag, node):
def _f(loader, tag, node):
if tag == "!GetAtt":
if isinstance(node.value, list):
return node.value
return node.value.split(".")
elif type(node) == yaml.SequenceNode:
return loader.construct_sequence(node)

View File

@ -107,6 +107,7 @@ class FakeAlarm(BaseModel):
unit,
actions_enabled,
region="us-east-1",
rule=None,
):
self.name = name
self.alarm_arn = make_arn_for_alarm(region, DEFAULT_ACCOUNT_ID, name)
@ -123,7 +124,7 @@ class FakeAlarm(BaseModel):
self.dimensions = [
Dimension(dimension["name"], dimension["value"]) for dimension in dimensions
]
self.actions_enabled = actions_enabled
self.actions_enabled = True if actions_enabled is None else actions_enabled
self.alarm_actions = alarm_actions
self.ok_actions = ok_actions
self.insufficient_data_actions = insufficient_data_actions
@ -137,6 +138,9 @@ class FakeAlarm(BaseModel):
self.state_value = "OK"
self.state_updated_timestamp = datetime.utcnow()
# only used for composite alarms
self.rule = rule
def update_state(self, reason, reason_data, state_value):
# History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action
self.history.append(
@ -156,6 +160,8 @@ class FakeAlarm(BaseModel):
def are_dimensions_same(metric_dimensions, dimensions):
if len(metric_dimensions) != len(dimensions):
return False
for dimension in metric_dimensions:
for new_dimension in dimensions:
if (
@ -163,7 +169,6 @@ def are_dimensions_same(metric_dimensions, dimensions):
or dimension.value != new_dimension.value
):
return False
return True
@ -178,11 +183,12 @@ class MetricDatum(BaseModel):
]
self.unit = unit
def filter(self, namespace, name, dimensions, already_present_metrics):
def filter(self, namespace, name, dimensions, already_present_metrics=[]):
if namespace and namespace != self.namespace:
return False
if name and name != self.name:
return False
for metric in already_present_metrics:
if self.dimensions and are_dimensions_same(
metric.dimensions, self.dimensions
@ -302,6 +308,7 @@ class CloudWatchBackend(BaseBackend):
unit,
actions_enabled,
region="us-east-1",
rule=None,
):
alarm = FakeAlarm(
name,
@ -322,6 +329,7 @@ class CloudWatchBackend(BaseBackend):
unit,
actions_enabled,
region,
rule=rule,
)
self.alarms[name] = alarm
@ -451,7 +459,15 @@ class CloudWatchBackend(BaseBackend):
return results
def get_metric_statistics(
self, namespace, metric_name, start_time, end_time, period, stats, unit=None
self,
namespace,
metric_name,
start_time,
end_time,
period,
stats,
unit=None,
dimensions=None,
):
period_delta = timedelta(seconds=period)
filtered_data = [
@ -464,6 +480,10 @@ class CloudWatchBackend(BaseBackend):
if unit:
filtered_data = [md for md in filtered_data if md.unit == unit]
if dimensions:
filtered_data = [
md for md in filtered_data if md.filter(None, None, dimensions)
]
# earliest to oldest
filtered_data = sorted(filtered_data, key=lambda x: x.timestamp)

View File

@ -1,9 +1,11 @@
import json
from moto.core.utils import amzn_request_id
from moto.core.responses import BaseResponse
from .models import cloudwatch_backends, MetricDataQuery, MetricStat, Metric, Dimension
from dateutil.parser import parse as dtparse
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from .models import cloudwatch_backends, MetricDataQuery, MetricStat, Metric, Dimension
class CloudWatchResponse(BaseResponse):
@property
@ -19,34 +21,49 @@ class CloudWatchResponse(BaseResponse):
name = self._get_param("AlarmName")
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
metrics = self._get_multi_param("Metrics.member")
metrics = self._get_multi_param("Metrics.member", skip_result_conversion=True)
metric_data_queries = None
if metrics:
metric_data_queries = [
MetricDataQuery(
id=metric.get("Id"),
label=metric.get("Label"),
period=metric.get("Period"),
return_data=metric.get("ReturnData"),
expression=metric.get("Expression"),
metric_stat=MetricStat(
metric=Metric(
metric_name=metric.get("MetricStat.Metric.MetricName"),
namespace=metric.get("MetricStat.Metric.Namespace"),
dimensions=[
Dimension(name=dim["Name"], value=dim["Value"])
for dim in metric["MetricStat.Metric.Dimensions.member"]
],
),
period=metric.get("MetricStat.Period"),
stat=metric.get("MetricStat.Stat"),
unit=metric.get("MetricStat.Unit"),
)
if "MetricStat.Metric.MetricName" in metric
else None,
metric_data_queries = []
for metric in metrics:
dimensions = []
dims = (
metric.get("MetricStat", {})
.get("Metric", {})
.get("Dimensions.member", [])
)
for metric in metrics
]
for dim in dims:
dimensions.append(
Dimension(name=dim.get("Name"), value=dim.get("Value"))
)
metric_stat = None
stat_metric_name = (
metric.get("MetricStat", {}).get("Metric", {}).get("MetricName")
)
if stat_metric_name:
stat_details = metric.get("MetricStat", {})
stat_metric_ns = stat_details.get("Metric", {}).get("Namespace")
metric_stat = MetricStat(
metric=Metric(
metric_name=stat_metric_name,
namespace=stat_metric_ns,
dimensions=dimensions,
),
period=stat_details.get("Period"),
stat=stat_details.get("Stat"),
unit=stat_details.get("Unit"),
)
metric_data_queries.append(
MetricDataQuery(
id=metric.get("Id"),
label=metric.get("Label"),
period=metric.get("Period"),
return_data=metric.get("ReturnData"),
expression=metric.get("Expression"),
metric_stat=metric_stat,
)
)
comparison_operator = self._get_param("ComparisonOperator")
evaluation_periods = self._get_param("EvaluationPeriods")
datapoints_to_alarm = self._get_param("DatapointsToAlarm")
@ -62,6 +79,8 @@ class CloudWatchResponse(BaseResponse):
"InsufficientDataActions.member"
)
unit = self._get_param("Unit")
# fetch AlarmRule to re-use this method for composite alarms as well
rule = self._get_param("AlarmRule")
alarm = self.cloudwatch_backend.put_metric_alarm(
name,
namespace,
@ -81,6 +100,7 @@ class CloudWatchResponse(BaseResponse):
unit,
actions_enabled,
self.region,
rule=rule,
)
template = self.response_template(PUT_METRIC_ALARM_TEMPLATE)
return template.render(alarm=alarm)
@ -105,8 +125,13 @@ class CloudWatchResponse(BaseResponse):
else:
alarms = self.cloudwatch_backend.get_all_alarms()
metric_alarms = [a for a in alarms if a.rule is None]
composite_alarms = [a for a in alarms if a.rule is not None]
template = self.response_template(DESCRIBE_ALARMS_TEMPLATE)
return template.render(alarms=alarms)
return template.render(
metric_alarms=metric_alarms, composite_alarms=composite_alarms
)
@amzn_request_id
def delete_alarms(self):
@ -145,12 +170,12 @@ class CloudWatchResponse(BaseResponse):
end_time = dtparse(self._get_param("EndTime"))
period = int(self._get_param("Period"))
statistics = self._get_multi_param("Statistics.member")
dimensions = self._get_multi_param("Dimensions.member")
# Unsupported Parameters (To Be Implemented)
unit = self._get_param("Unit")
extended_statistics = self._get_param("ExtendedStatistics")
dimensions = self._get_param("Dimensions")
if extended_statistics or dimensions:
if extended_statistics:
raise NotImplementedError()
# TODO: this should instead throw InvalidParameterCombination
@ -160,7 +185,14 @@ class CloudWatchResponse(BaseResponse):
)
datapoints = self.cloudwatch_backend.get_metric_statistics(
namespace, metric_name, start_time, end_time, period, statistics, unit
namespace,
metric_name,
start_time,
end_time,
period,
statistics,
unit,
dimensions=dimensions,
)
template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE)
return template.render(label=metric_name, datapoints=datapoints)
@ -280,7 +312,8 @@ PUT_METRIC_ALARM_TEMPLATE = """<PutMetricAlarmResponse xmlns="http://monitoring.
DESCRIBE_ALARMS_TEMPLATE = """<DescribeAlarmsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<DescribeAlarmsResult>
<MetricAlarms>
{% for tag_name, alarms in (('MetricAlarms', metric_alarms), ('CompositeAlarms', composite_alarms)) %}
<{{tag_name}}>
{% for alarm in alarms %}
<member>
<ActionsEnabled>{{ alarm.actions_enabled }}</ActionsEnabled>
@ -291,7 +324,7 @@ DESCRIBE_ALARMS_TEMPLATE = """<DescribeAlarmsResponse xmlns="http://monitoring.a
</AlarmActions>
<AlarmArn>{{ alarm.alarm_arn }}</AlarmArn>
<AlarmConfigurationUpdatedTimestamp>{{ alarm.configuration_updated_timestamp }}</AlarmConfigurationUpdatedTimestamp>
<AlarmDescription>{{ alarm.description }}</AlarmDescription>
<AlarmDescription>{{ alarm.description or '' }}</AlarmDescription>
<AlarmName>{{ alarm.name }}</AlarmName>
<ComparisonOperator>{{ alarm.comparison_operator }}</ComparisonOperator>
{% if alarm.dimensions is not none %}
@ -376,13 +409,19 @@ DESCRIBE_ALARMS_TEMPLATE = """<DescribeAlarmsResponse xmlns="http://monitoring.a
{% if alarm.statistic is not none %}
<Statistic>{{ alarm.statistic }}</Statistic>
{% endif %}
{% if alarm.threshold is not none %}
<Threshold>{{ alarm.threshold }}</Threshold>
{% endif %}
{% if alarm.unit is not none %}
<Unit>{{ alarm.unit }}</Unit>
{% endif %}
{% if alarm.rule is not none %}
<AlarmRule>{{ alarm.rule }}</AlarmRule>
{% endif %}
</member>
{% endfor %}
</MetricAlarms>
</{{tag_name}}>
{% endfor %}
</DescribeAlarmsResult>
</DescribeAlarmsResponse>"""
@ -429,7 +468,9 @@ DESCRIBE_METRIC_ALARMS_TEMPLATE = """<DescribeAlarmsForMetricResponse xmlns="htt
<StateUpdatedTimestamp>{{ alarm.state_updated_timestamp }}</StateUpdatedTimestamp>
<StateValue>{{ alarm.state_value }}</StateValue>
<Statistic>{{ alarm.statistic }}</Statistic>
{% if alarm.threshold is not none %}
<Threshold>{{ alarm.threshold }}</Threshold>
{% endif %}
<Unit>{{ alarm.unit }}</Unit>
</member>
{% endfor %}

View File

@ -29,6 +29,20 @@ class CognitoIdentity(BaseModel):
self.identity_pool_id = get_random_identity_id(region)
self.creation_time = datetime.datetime.utcnow()
def to_json(self):
return json.dumps(
{
"IdentityPoolId": self.identity_pool_id,
"IdentityPoolName": self.identity_pool_name,
"AllowUnauthenticatedIdentities": self.allow_unauthenticated_identities,
"SupportedLoginProviders": self.supported_login_providers,
"DeveloperProviderName": self.developer_provider_name,
"OpenIdConnectProviderARNs": self.open_id_connect_provider_arns,
"CognitoIdentityProviders": self.cognito_identity_providers,
"SamlProviderARNs": self.saml_provider_arns,
}
)
class CognitoIdentityBackend(BaseBackend):
def __init__(self, region):
@ -54,7 +68,7 @@ class CognitoIdentityBackend(BaseBackend):
"DeveloperProviderName": identity_pool.developer_provider_name,
"IdentityPoolId": identity_pool.identity_pool_id,
"IdentityPoolName": identity_pool.identity_pool_name,
"IdentityPoolTags": {},
"IdentityPoolTags": {}, # TODO: add tags
"OpenIdConnectProviderARNs": identity_pool.open_id_connect_provider_arns,
"SamlProviderARNs": identity_pool.saml_provider_arns,
"SupportedLoginProviders": identity_pool.supported_login_providers,
@ -85,19 +99,38 @@ class CognitoIdentityBackend(BaseBackend):
)
self.identity_pools[new_identity.identity_pool_id] = new_identity
response = json.dumps(
{
"IdentityPoolId": new_identity.identity_pool_id,
"IdentityPoolName": new_identity.identity_pool_name,
"AllowUnauthenticatedIdentities": new_identity.allow_unauthenticated_identities,
"SupportedLoginProviders": new_identity.supported_login_providers,
"DeveloperProviderName": new_identity.developer_provider_name,
"OpenIdConnectProviderARNs": new_identity.open_id_connect_provider_arns,
"CognitoIdentityProviders": new_identity.cognito_identity_providers,
"SamlProviderARNs": new_identity.saml_provider_arns,
}
)
response = new_identity.to_json()
return response
def update_identity_pool(
self,
identity_pool_id,
identity_pool_name,
allow_unauthenticated,
allow_classic,
login_providers,
provider_name,
provider_arns,
identity_providers,
saml_providers,
pool_tags,
):
pool = self.identity_pools[identity_pool_id]
pool.identity_pool_name = pool.identity_pool_name or identity_pool_name
if allow_unauthenticated is not None:
pool.allow_unauthenticated_identities = allow_unauthenticated
if login_providers is not None:
pool.supported_login_providers = login_providers
if provider_name:
pool.developer_provider_name = provider_name
if provider_arns is not None:
pool.open_id_connect_provider_arns = provider_arns
if identity_providers is not None:
pool.cognito_identity_providers = identity_providers
if saml_providers is not None:
pool.saml_provider_arns = saml_providers
response = pool.to_json()
return response
def get_id(self):

View File

@ -27,6 +27,31 @@ class CognitoIdentityResponse(BaseResponse):
saml_provider_arns=saml_provider_arns,
)
def update_identity_pool(self):
pool_id = self._get_param("IdentityPoolId")
pool_name = self._get_param("IdentityPoolName")
allow_unauthenticated = self._get_bool_param("AllowUnauthenticatedIdentities")
allow_classic = self._get_bool_param("AllowClassicFlow")
login_providers = self._get_multi_param_dict("SupportedLoginProviders")
provider_name = self._get_param("DeveloperProviderName")
provider_arns = self._get_multi_param("OpenIdConnectProviderARNs")
identity_providers = self._get_multi_param_dict("CognitoIdentityProviders")
saml_providers = self._get_multi_param("SamlProviderARNs")
pool_tags = self._get_multi_param_dict("IdentityPoolTags")
return cognitoidentity_backends[self.region].update_identity_pool(
identity_pool_id=pool_id,
identity_pool_name=pool_name,
allow_unauthenticated=allow_unauthenticated,
allow_classic=allow_classic,
login_providers=login_providers,
provider_name=provider_name,
provider_arns=provider_arns,
identity_providers=identity_providers,
saml_providers=saml_providers,
pool_tags=pool_tags,
)
def get_id(self):
return cognitoidentity_backends[self.region].get_id()

View File

@ -8,10 +8,8 @@ import json
import os
import time
import uuid
from boto3 import Session
from jose import jws
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID

View File

@ -331,8 +331,11 @@ class CognitoIdpResponse(BaseResponse):
users = [
user
for user in users
for attribute in user.attributes
if attribute["Name"] == name and attribute["Value"] == value
if [
attr
for attr in user.attributes
if attr["Name"] == name and attr["Value"] == value
]
]
response = {"Users": [user.to_json(extended=True) for user in users]}
if token:

View File

@ -4,16 +4,28 @@ from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment
import json
# TODO: add "<Type>Sender</Type>" to error responses below?
SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>{{error_type}}</Code>
<Message>{{message}}</Message>
{% block extra %}{% endblock %}
<RequestID>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestID>
<{{request_id_tag}}>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</{{request_id_tag}}>
</Error>
"""
WRAPPED_SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse{% if xmlns is defined %} xmlns="{{xmlns}}"{% endif %}>
<Error>
<Code>{{error_type}}</Code>
<Message>{{message}}</Message>
{% block extra %}{% endblock %}
<{{request_id_tag}}>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</{{request_id_tag}}>
</Error>
</ErrorResponse>"""
ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse>
<Errors>
@ -23,7 +35,7 @@ ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
{% block extra %}{% endblock %}
</Error>
</Errors>
<RequestID>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestID>
<{{request_id_tag}}>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</{{request_id_tag}}>
</ErrorResponse>
"""
@ -36,9 +48,12 @@ ERROR_JSON_RESPONSE = """{
class RESTError(HTTPException):
code = 400
# most APIs use <RequestId>, but some APIs (including EC2, S3) use <RequestID>
request_id_tag_name = "RequestId"
templates = {
"single_error": SINGLE_ERROR_RESPONSE,
"wrapped_single_error": WRAPPED_SINGLE_ERROR_RESPONSE,
"error": ERROR_RESPONSE,
"error_json": ERROR_JSON_RESPONSE,
}
@ -49,9 +64,23 @@ class RESTError(HTTPException):
self.error_type = error_type
self.message = message
self.description = env.get_template(template).render(
error_type=error_type, message=message, **kwargs
error_type=error_type,
message=message,
request_id_tag=self.request_id_tag_name,
**kwargs
)
self.content_type = "application/xml"
def get_headers(self, *args, **kwargs):
return [
("X-Amzn-ErrorType", self.error_type or "UnknownError"),
("Content-Type", self.content_type),
]
def get_body(self, *args, **kwargs):
return self.description
class DryRunClientError(RESTError):
code = 400
@ -60,9 +89,7 @@ class DryRunClientError(RESTError):
class JsonRESTError(RESTError):
def __init__(self, error_type, message, template="error_json", **kwargs):
super(JsonRESTError, self).__init__(error_type, message, template, **kwargs)
def get_headers(self, *args, **kwargs):
return [("Content-Type", "application/json")]
self.content_type = "application/json"
def get_body(self, *args, **kwargs):
return self.description

View File

@ -562,6 +562,7 @@ class CloudFormationModel(BaseModel):
# See for example https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html
return "AWS::SERVICE::RESOURCE"
@classmethod
@abstractmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
@ -572,6 +573,7 @@ class CloudFormationModel(BaseModel):
# and return an instance of the resource class
pass
@classmethod
@abstractmethod
def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
@ -583,6 +585,7 @@ class CloudFormationModel(BaseModel):
# the change in parameters and no-op when nothing has changed.
pass
@classmethod
@abstractmethod
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name

View File

@ -192,7 +192,8 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
region_from_useragent_regex = re.compile(
r"region/(?P<region>[a-z]{2}-[a-z]+-\d{1})"
)
param_list_regex = re.compile(r"(.*)\.(\d+)\.")
param_list_regex = re.compile(r"^(\.?[^.]*(\.member)?)\.(\d+)\.")
param_regex = re.compile(r"([^\.]*)\.(\w+)(\..+)?")
access_key_regex = re.compile(
r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]"
)
@ -252,7 +253,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
)
)
)
except UnicodeEncodeError:
except (UnicodeEncodeError, UnicodeDecodeError):
pass # ignore encoding errors, as the body may not contain a legitimate querystring
if not querystring:
querystring.update(headers)
@ -402,7 +403,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
try:
response = method()
except HTTPException as http_error:
response = http_error.description, dict(status=http_error.code)
response_headers = dict(http_error.get_headers() or [])
response_headers["status"] = http_error.code
response = http_error.description, response_headers
if isinstance(response, str):
return 200, headers, response
@ -460,15 +463,23 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def _get_bool_param(self, param_name, if_none=None):
val = self._get_param(param_name)
if val is not None:
val = str(val)
if val.lower() == "true":
return True
elif val.lower() == "false":
return False
return if_none
def _get_multi_param_helper(self, param_prefix):
def _get_multi_param_dict(self, param_prefix):
return self._get_multi_param_helper(param_prefix, skip_result_conversion=True)
def _get_multi_param_helper(
self, param_prefix, skip_result_conversion=False, tracked_prefixes=None
):
value_dict = dict()
tracked_prefixes = set() # prefixes which have already been processed
tracked_prefixes = (
tracked_prefixes or set()
) # prefixes which have already been processed
def is_tracked(name_param):
for prefix_loop in tracked_prefixes:
@ -497,23 +508,46 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
name = prefix
value_dict[name] = value
else:
value_dict[name] = value[0]
match = self.param_regex.search(name[len(param_prefix) :])
if match:
# enable access to params that are lists of dicts, e.g., "TagSpecification.1.ResourceType=.."
sub_attr = "%s%s.%s" % (
name[: len(param_prefix)],
match.group(1),
match.group(2),
)
if match.group(3):
value = self._get_multi_param_helper(
sub_attr,
tracked_prefixes=tracked_prefixes,
skip_result_conversion=skip_result_conversion,
)
else:
value = self._get_param(sub_attr)
tracked_prefixes.add(sub_attr)
value_dict[name] = value
else:
value_dict[name] = value[0]
if not value_dict:
return None
if len(value_dict) > 1:
if skip_result_conversion or len(value_dict) > 1:
# strip off period prefix
value_dict = {
name[len(param_prefix) + 1 :]: value
for name, value in value_dict.items()
}
for k in list(value_dict.keys()):
parts = k.split(".")
if len(parts) != 2 or parts[1] != "member":
value_dict[parts[0]] = value_dict.pop(k)
else:
value_dict = list(value_dict.values())[0]
return value_dict
def _get_multi_param(self, param_prefix):
def _get_multi_param(self, param_prefix, skip_result_conversion=False):
"""
Given a querystring of ?LaunchConfigurationNames.member.1=my-test-1&LaunchConfigurationNames.member.2=my-test-2
this will return ['my-test-1', 'my-test-2']
@ -525,7 +559,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
values = []
index = 1
while True:
value_dict = self._get_multi_param_helper(prefix + str(index))
value_dict = self._get_multi_param_helper(
prefix + str(index), skip_result_conversion=skip_result_conversion
)
if not value_dict and value_dict != "":
break

View File

@ -4,6 +4,8 @@ from moto.core.exceptions import RESTError
class EC2ClientError(RESTError):
code = 400
# EC2 uses <RequestID> as tag name in the XML response
request_id_tag_name = "RequestID"
class DependencyViolationError(EC2ClientError):
@ -612,7 +614,7 @@ class InvalidAssociationIDIamProfileAssociationError(EC2ClientError):
class InvalidVpcEndPointIdError(EC2ClientError):
def __init__(self, vpc_end_point_id):
super(InvalidVpcEndPointIdError, self).__init__(
"InvalidVpcEndPointId.NotFound",
"InvalidVpcEndpointId.NotFound",
"The VpcEndPoint ID '{0}' does not exist".format(vpc_end_point_id),
)

View File

@ -33,7 +33,7 @@ from moto.core.utils import (
)
from moto.core import ACCOUNT_ID
from moto.kms import kms_backends
from moto.utilities.utils import load_resource
from moto.utilities.utils import load_resource, merge_multiple_dicts
from os import listdir
from .exceptions import (
@ -124,9 +124,12 @@ from .utils import (
random_internet_gateway_id,
random_ip,
random_ipv6_cidr,
random_transit_gateway_attachment_id,
random_transit_gateway_route_table_id,
randor_ipv4_cidr,
random_launch_template_id,
random_nat_gateway_id,
random_transit_gateway_id,
random_key_pair,
random_private_ip,
random_public_ip,
@ -249,8 +252,12 @@ class TaggedEC2Resource(BaseModel):
return [tag["key"] for tag in tags]
elif filter_name == "tag-value":
return [tag["value"] for tag in tags]
else:
raise FilterNotImplementedError(filter_name, method_name)
value = getattr(self, filter_name.lower().replace("-", "_"), None)
if value is not None:
return [value]
raise FilterNotImplementedError(filter_name, method_name)
class NetworkInterface(TaggedEC2Resource, CloudFormationModel):
@ -2015,7 +2022,9 @@ class SecurityRule(object):
class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
def __init__(self, ec2_backend, group_id, name, description, vpc_id=None):
def __init__(
self, ec2_backend, group_id, name, description, vpc_id=None, tags=None
):
self.ec2_backend = ec2_backend
self.id = group_id
self.name = name
@ -2027,6 +2036,7 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
self.enis = {}
self.vpc_id = vpc_id
self.owner_id = OWNER_ID
self.add_tags(tags or {})
# Append default IPv6 egress rule for VPCs with IPv6 support
if vpc_id:
@ -2188,7 +2198,9 @@ class SecurityGroupBackend(object):
super(SecurityGroupBackend, self).__init__()
def create_security_group(self, name, description, vpc_id=None, force=False):
def create_security_group(
self, name, description, vpc_id=None, tags=None, force=False
):
if not description:
raise MissingParameterError("GroupDescription")
@ -2197,7 +2209,9 @@ class SecurityGroupBackend(object):
existing_group = self.get_security_group_from_name(name, vpc_id)
if existing_group:
raise InvalidSecurityGroupDuplicateError(name)
group = SecurityGroup(self, group_id, name, description, vpc_id=vpc_id)
group = SecurityGroup(
self, group_id, name, description, vpc_id=vpc_id, tags=tags
)
self.groups[vpc_id][group_id] = group
return group
@ -3051,10 +3065,12 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
).get("cidr_block"):
raise OperationNotPermitted(association_id)
response = self.cidr_block_association_set.pop(association_id, {})
if response:
entry = response = self.cidr_block_association_set.get(association_id, {})
if entry:
response = json.loads(json.dumps(entry))
response["vpc_id"] = self.id
response["cidr_block_state"]["state"] = "disassociating"
entry["cidr_block_state"]["state"] = "disassociated"
return response
def get_cidr_block_association_set(self, ipv6=False):
@ -3229,7 +3245,7 @@ class VPCBackend(object):
network_interface_ids=[],
dns_entries=None,
client_token=None,
security_group=None,
security_group_ids=None,
tag_specifications=None,
private_dns_enabled=None,
):
@ -3259,6 +3275,7 @@ class VPCBackend(object):
dns_entries = [dns_entries]
vpc_end_point = VPCEndPoint(
self,
vpc_endpoint_id,
vpc_id,
service_name,
@ -3269,7 +3286,7 @@ class VPCBackend(object):
network_interface_ids,
dns_entries,
client_token,
security_group,
security_group_ids,
tag_specifications,
private_dns_enabled,
)
@ -4308,6 +4325,7 @@ class Route(CloudFormationModel):
class VPCEndPoint(TaggedEC2Resource):
def __init__(
self,
ec2_backend,
id,
vpc_id,
service_name,
@ -4318,10 +4336,11 @@ class VPCEndPoint(TaggedEC2Resource):
network_interface_ids=None,
dns_entries=None,
client_token=None,
security_group=None,
security_group_ids=None,
tag_specifications=None,
private_dns_enabled=None,
):
self.ec2_backend = ec2_backend
self.id = id
self.vpc_id = vpc_id
self.service_name = service_name
@ -4331,7 +4350,7 @@ class VPCEndPoint(TaggedEC2Resource):
self.network_interface_ids = network_interface_ids
self.subnet_ids = subnet_ids
self.client_token = client_token
self.security_group = security_group
self.security_group_ids = security_group_ids
self.tag_specifications = tag_specifications
self.private_dns_enabled = private_dns_enabled
self.created_at = datetime.utcnow()
@ -5395,7 +5414,16 @@ class DHCPOptionsSetBackend(object):
class VPNConnection(TaggedEC2Resource):
def __init__(self, ec2_backend, id, type, customer_gateway_id, vpn_gateway_id):
def __init__(
self,
ec2_backend,
id,
type,
customer_gateway_id,
vpn_gateway_id=None,
transit_gateway_id=None,
tags={},
):
self.ec2_backend = ec2_backend
self.id = id
self.state = "available"
@ -5403,9 +5431,11 @@ class VPNConnection(TaggedEC2Resource):
self.type = type
self.customer_gateway_id = customer_gateway_id
self.vpn_gateway_id = vpn_gateway_id
self.transit_gateway_id = transit_gateway_id
self.tunnels = None
self.options = None
self.static_routes = None
self.add_tags(tags or {})
def get_filter_value(self, filter_name):
return super(VPNConnection, self).get_filter_value(
@ -5419,7 +5449,13 @@ class VPNConnectionBackend(object):
super(VPNConnectionBackend, self).__init__()
def create_vpn_connection(
self, type, customer_gateway_id, vpn_gateway_id, static_routes_only=None
self,
type,
customer_gateway_id,
vpn_gateway_id=None,
transit_gateway_id=None,
static_routes_only=None,
tags={},
):
vpn_connection_id = random_vpn_connection_id()
if static_routes_only:
@ -5430,6 +5466,8 @@ class VPNConnectionBackend(object):
type=type,
customer_gateway_id=customer_gateway_id,
vpn_gateway_id=vpn_gateway_id,
transit_gateway_id=transit_gateway_id,
tags=tags,
)
self.vpn_connections[vpn_connection.id] = vpn_connection
return vpn_connection
@ -5437,10 +5475,10 @@ class VPNConnectionBackend(object):
def delete_vpn_connection(self, vpn_connection_id):
if vpn_connection_id in self.vpn_connections:
self.vpn_connections.pop(vpn_connection_id)
self.vpn_connections[vpn_connection_id].state = "deleted"
else:
raise InvalidVpnConnectionIdError(vpn_connection_id)
return True
return self.vpn_connections[vpn_connection_id]
def describe_vpn_connections(self, vpn_connection_ids=None):
vpn_connections = []
@ -5723,10 +5761,23 @@ class NetworkAclEntry(TaggedEC2Resource):
class VpnGateway(TaggedEC2Resource):
def __init__(self, ec2_backend, id, type):
def __init__(
self,
ec2_backend,
id,
type,
amazon_side_asn,
availability_zone,
tags=None,
state="available",
):
self.ec2_backend = ec2_backend
self.id = id
self.type = type
self.amazon_side_asn = amazon_side_asn
self.availability_zone = availability_zone
self.state = state
self.add_tags(tags or {})
self.attachments = {}
super(VpnGateway, self).__init__()
@ -5756,9 +5807,13 @@ class VpnGatewayBackend(object):
self.vpn_gateways = {}
super(VpnGatewayBackend, self).__init__()
def create_vpn_gateway(self, type="ipsec.1"):
def create_vpn_gateway(
self, type="ipsec.1", amazon_side_asn=None, availability_zone=None, tags=None
):
vpn_gateway_id = random_vpn_gateway_id()
vpn_gateway = VpnGateway(self, vpn_gateway_id, type)
vpn_gateway = VpnGateway(
self, vpn_gateway_id, type, amazon_side_asn, availability_zone, tags
)
self.vpn_gateways[vpn_gateway_id] = vpn_gateway
return vpn_gateway
@ -5795,13 +5850,17 @@ class VpnGatewayBackend(object):
class CustomerGateway(TaggedEC2Resource):
def __init__(self, ec2_backend, id, type, ip_address, bgp_asn):
def __init__(
self, ec2_backend, id, type, ip_address, bgp_asn, state="available", tags=None
):
self.ec2_backend = ec2_backend
self.id = id
self.type = type
self.ip_address = ip_address
self.bgp_asn = bgp_asn
self.attachments = {}
self.state = state
self.add_tags(tags or {})
super(CustomerGateway, self).__init__()
def get_filter_value(self, filter_name):
@ -5815,17 +5874,44 @@ class CustomerGatewayBackend(object):
self.customer_gateways = {}
super(CustomerGatewayBackend, self).__init__()
def create_customer_gateway(self, type="ipsec.1", ip_address=None, bgp_asn=None):
def create_customer_gateway(
self, type="ipsec.1", ip_address=None, bgp_asn=None, tags=None
):
customer_gateway_id = random_customer_gateway_id()
customer_gateway = CustomerGateway(
self, customer_gateway_id, type, ip_address, bgp_asn
self, customer_gateway_id, type, ip_address, bgp_asn, tags=tags
)
self.customer_gateways[customer_gateway_id] = customer_gateway
return customer_gateway
def get_all_customer_gateways(self, filters=None):
customer_gateways = self.customer_gateways.values()
return generic_filter(filters, customer_gateways)
if filters is not None:
if filters.get("customer-gateway-id") is not None:
customer_gateways = [
customer_gateway
for customer_gateway in customer_gateways
if customer_gateway.id in filters["customer-gateway-id"]
]
if filters.get("type") is not None:
customer_gateways = [
customer_gateway
for customer_gateway in customer_gateways
if customer_gateway.type in filters["type"]
]
if filters.get("bgp-asn") is not None:
customer_gateways = [
customer_gateway
for customer_gateway in customer_gateways
if customer_gateway.bgp_asn in filters["bgp-asn"]
]
if filters.get("ip-address") is not None:
customer_gateways = [
customer_gateway
for customer_gateway in customer_gateways
if customer_gateway.ip_address in filters["ip-address"]
]
return customer_gateways
def get_customer_gateway(self, customer_gateway_id):
customer_gateway = self.customer_gateways.get(customer_gateway_id, None)
@ -5834,12 +5920,425 @@ class CustomerGatewayBackend(object):
return customer_gateway
def delete_customer_gateway(self, customer_gateway_id):
deleted = self.customer_gateways.pop(customer_gateway_id, None)
customer_gateway = self.get_customer_gateway(customer_gateway_id)
customer_gateway.state = "deleted"
# deleted = self.customer_gateways.pop(customer_gateway_id, None)
deleted = True
if not deleted:
raise InvalidCustomerGatewayIdError(customer_gateway_id)
return deleted
class TransitGateway(TaggedEC2Resource, CloudFormationModel):
DEFAULT_OPTIONS = {
"AmazonSideAsn": "64512",
"AssociationDefaultRouteTableId": "tgw-rtb-0d571391e50cf8514",
"AutoAcceptSharedAttachments": "disable",
"DefaultRouteTableAssociation": "enable",
"DefaultRouteTablePropagation": "enable",
"DnsSupport": "enable",
"MulticastSupport": "disable",
"PropagationDefaultRouteTableId": "tgw-rtb-0d571391e50cf8514",
"TransitGatewayCidrBlocks": None,
"VpnEcmpSupport": "enable",
}
def __init__(self, backend, description=None, options=None, tags=None):
self.ec2_backend = backend
self.id = random_transit_gateway_id()
self.description = description
self.state = "available"
self.add_tags(tags or {})
self.options = merge_multiple_dicts(self.DEFAULT_OPTIONS, options or {})
self._created_at = datetime.utcnow()
@property
def physical_resource_id(self):
return self.id
@property
def create_time(self):
return iso_8601_datetime_with_milliseconds(self._created_at)
@property
def owner_id(self):
return ACCOUNT_ID
@staticmethod
def cloudformation_name_type():
return None
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-natgateway.html
return "AWS::EC2::TransitGateway"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
ec2_backend = ec2_backends[region_name]
transit_gateway = ec2_backend.create_transit_gateway(
cloudformation_json["Properties"]["Description"],
cloudformation_json["Properties"]["Options"],
)
return transit_gateway
class TransitGatewayBackend(object):
def __init__(self):
self.transit_gateways = {}
super(TransitGatewayBackend, self).__init__()
def create_transit_gateway(self, description=None, options=None, tags=None):
transit_gateway = TransitGateway(self, description, options, tags)
self.transit_gateways[transit_gateway.id] = transit_gateway
return transit_gateway
def get_all_transit_gateways(self, filters):
transit_gateways = self.transit_gateways.values()
if filters is not None:
if filters.get("transit-gateway-id") is not None:
transit_gateways = [
transit_gateway
for transit_gateway in transit_gateways
if transit_gateway.id in filters["transit-gateway-id"]
]
if filters.get("state") is not None:
transit_gateways = [
transit_gateway
for transit_gateway in transit_gateways
if transit_gateway.state in filters["state"]
]
if filters.get("owner-id") is not None:
transit_gateways = [
transit_gateway
for transit_gateway in transit_gateways
if transit_gateway.owner_id in filters["owner-id"]
]
return transit_gateways
def delete_transit_gateway(self, transit_gateway_id):
return self.transit_gateways.pop(transit_gateway_id)
def modify_transit_gateway(
self, transit_gateway_id, description=None, options=None
):
transit_gateway = self.transit_gateways.get(transit_gateway_id)
if description:
transit_gateway.description = description
if options:
transit_gateway.options.update(options)
return transit_gateway
class TransitGatewayRouteTable(TaggedEC2Resource):
def __init__(
self,
backend,
transit_gateway_id,
tags=None,
default_association_route_table=False,
default_propagation_route_table=False,
):
self.ec2_backend = backend
self.id = random_transit_gateway_route_table_id()
self.transit_gateway_id = transit_gateway_id
self._created_at = datetime.utcnow()
self.default_association_route_table = default_association_route_table
self.default_propagation_route_table = default_propagation_route_table
self.state = "available"
self.routes = {}
self.add_tags(tags or {})
@property
def physical_resource_id(self):
return self.id
@property
def create_time(self):
return iso_8601_datetime_with_milliseconds(self._created_at)
class TransitGatewayRouteTableBackend(object):
def __init__(self):
self.transit_gateways_route_tables = {}
super(TransitGatewayRouteTableBackend, self).__init__()
def create_transit_gateway_route_table(
self,
transit_gateway_id,
tags=None,
default_association_route_table=False,
default_propagation_route_table=False,
):
transit_gateways_route_table = TransitGatewayRouteTable(
self,
transit_gateway_id=transit_gateway_id,
tags=tags,
default_association_route_table=default_association_route_table,
default_propagation_route_table=default_propagation_route_table,
)
self.transit_gateways_route_tables[
transit_gateways_route_table.id
] = transit_gateways_route_table
return transit_gateways_route_table
def get_all_transit_gateway_route_tables(
self, transit_gateway_ids=None, filters=None
):
transit_gateway_route_tables = self.transit_gateways_route_tables.values()
attr_pairs = (
("default-association-route-table", "default_association_route_table"),
("default-propagation-route-table", "default_propagation_route_table"),
("state", "state"),
("transit-gateway-id", "transit_gateway_id"),
("transit-gateway-route-table-id", "id"),
)
if transit_gateway_ids:
transit_gateway_route_tables = [
transit_gateway_route_table
for transit_gateway_route_table in transit_gateway_route_tables
if transit_gateway_route_table.id in transit_gateway_ids
]
if filters:
for attrs in attr_pairs:
values = filters.get(attrs[0]) or None
if values is not None:
transit_gateway_route_tables = [
transit_gateway_route_table
for transit_gateway_route_table in transit_gateway_route_tables
if not values
or getattr(transit_gateway_route_table, attrs[1]) in values
]
return transit_gateway_route_tables
def delete_transit_gateway_route_table(self, transit_gateway_route_table_id):
return self.transit_gateways_route_tables.pop(transit_gateway_route_table_id)
def create_transit_gateway_route(
self,
transit_gateway_route_table_id,
destination_cidr_block,
transit_gateway_attachment_id=None,
blackhole=False,
):
transit_gateways_route_table = self.transit_gateways_route_tables[
transit_gateway_route_table_id
]
transit_gateways_route_table.routes[destination_cidr_block] = {
"destinationCidrBlock": destination_cidr_block,
"prefixListId": "",
"state": "blackhole" if blackhole else "active",
# TODO: needs to be fixed once we have support for transit gateway attachments
"transitGatewayAttachments": {
"resourceId": "TODO",
"resourceType": "TODO",
"transitGatewayAttachmentId": transit_gateway_attachment_id,
},
"type": "TODO",
}
return transit_gateways_route_table
def delete_transit_gateway_route(
self, transit_gateway_route_table_id, destination_cidr_block,
):
transit_gateways_route_table = self.transit_gateways_route_tables[
transit_gateway_route_table_id
]
transit_gateways_route_table.routes[destination_cidr_block]["state"] = "deleted"
return transit_gateways_route_table
def search_transit_gateway_routes(
self, transit_gateway_route_table_id, filters, max_results=None
):
transit_gateway_route_table = self.transit_gateways_route_tables[
transit_gateway_route_table_id
]
attr_pairs = (
("type", "type"),
("state", "state"),
)
for attrs in attr_pairs:
values = filters.get(attrs[0]) or None
if values:
routes = [
transit_gateway_route_table.routes[key]
for key in transit_gateway_route_table.routes
if transit_gateway_route_table.routes[key][attrs[1]] in values
]
if max_results:
routes = routes[: int(max_results)]
return routes
class TransitGatewayAttachment(TaggedEC2Resource):
def __init__(
self, backend, resource_id, resource_type, transit_gateway_id, tags=None
):
self.ec2_backend = backend
self.association = {}
self.resource_id = resource_id
self.resource_type = resource_type
self.id = random_transit_gateway_attachment_id()
self.transit_gateway_id = transit_gateway_id
self.state = "available"
self.add_tags(tags or {})
self._created_at = datetime.utcnow()
@property
def create_time(self):
return iso_8601_datetime_with_milliseconds(self._created_at)
@property
def resource_owner_id(self):
return ACCOUNT_ID
@property
def transit_gateway_owner_id(self):
return ACCOUNT_ID
class TransitGatewayVpcAttachment(TransitGatewayAttachment):
DEFAULT_OPTIONS = {
"ApplianceModeSupport": "disable",
"DnsSupport": "enable",
"Ipv6Support": "disable",
}
def __init__(
self, backend, transit_gateway_id, vpc_id, subnet_ids, tags=None, options=None
):
super().__init__(
backend=backend,
transit_gateway_id=transit_gateway_id,
resource_id=vpc_id,
resource_type="vpc",
tags=tags,
)
self.vpc_id = vpc_id
self.subnet_ids = subnet_ids
self.options = merge_multiple_dicts(self.DEFAULT_OPTIONS, options or {})
class TransitGatewayAttachmentBackend(object):
def __init__(self):
self.transit_gateways_attachments = {}
super(TransitGatewayAttachmentBackend, self).__init__()
def create_transit_gateway_vpn_attachment(
self, vpn_id, transit_gateway_id, tags=[]
):
transit_gateway_vpn_attachment = TransitGatewayAttachment(
self,
resource_id=vpn_id,
resource_type="vpn",
transit_gateway_id=transit_gateway_id,
tags=tags,
)
self.transit_gateways_attachments[
transit_gateway_vpn_attachment.id
] = transit_gateway_vpn_attachment
return transit_gateway_vpn_attachment
def create_transit_gateway_vpc_attachment(
self, transit_gateway_id, vpc_id, subnet_ids, tags=None, options=None
):
transit_gateway_vpc_attachment = TransitGatewayVpcAttachment(
self,
transit_gateway_id=transit_gateway_id,
tags=tags,
vpc_id=vpc_id,
subnet_ids=subnet_ids,
options=options,
)
self.transit_gateways_attachments[
transit_gateway_vpc_attachment.id
] = transit_gateway_vpc_attachment
return transit_gateway_vpc_attachment
def describe_transit_gateway_attachments(
self, transit_gateways_attachment_ids=None, filters=None, max_results=0
):
transit_gateways_attachments = self.transit_gateways_attachments.values()
attr_pairs = (
("resource-id", "resource_id"),
("resource-type", "resource_type"),
("transit-gateway-id", "transit_gateway_id"),
)
if transit_gateways_attachment_ids:
transit_gateways_attachments = [
transit_gateways_attachment
for transit_gateways_attachment in transit_gateways_attachments
if transit_gateways_attachment.id in transit_gateways_attachment_ids
]
if filters:
for attrs in attr_pairs:
values = filters.get(attrs[0]) or None
if values is not None:
transit_gateways_attachments = [
transit_gateways_attachment
for transit_gateways_attachment in transit_gateways_attachments
if getattr(transit_gateways_attachment, attrs[1]) in values
]
return transit_gateways_attachments
def describe_transit_gateway_vpc_attachments(
self, transit_gateways_attachment_ids=None, filters=None, max_results=0
):
transit_gateways_attachments = self.transit_gateways_attachments.values()
attr_pairs = (
("state", "state"),
("transit-gateway-attachment-id", "id"),
("transit-gateway-id", "transit_gateway_id"),
("vpc-id", "resource_id"),
)
if (
not transit_gateways_attachment_ids == []
and transit_gateways_attachment_ids is not None
):
transit_gateways_attachments = [
transit_gateways_attachment
for transit_gateways_attachment in transit_gateways_attachments
if transit_gateways_attachment.id in transit_gateways_attachment_ids
]
if filters:
for attrs in attr_pairs:
values = filters.get(attrs[0]) or None
if values is not None:
transit_gateways_attachments = [
transit_gateways_attachment
for transit_gateways_attachment in transit_gateways_attachments
if getattr(transit_gateways_attachment, attrs[1]) in values
]
return transit_gateways_attachments
class NatGateway(CloudFormationModel):
def __init__(self, backend, subnet_id, allocation_id, tags=[]):
# public properties
@ -6199,6 +6698,9 @@ class EC2Backend(
VpnGatewayBackend,
CustomerGatewayBackend,
NatGatewayBackend,
TransitGatewayBackend,
TransitGatewayRouteTableBackend,
TransitGatewayAttachmentBackend,
LaunchTemplateBackend,
IamInstanceProfileAssociationBackend,
):

View File

@ -576,5 +576,39 @@
"name": "suse-sles-11-sp4-v20151207-pv-ssd-x86_64",
"virtualization_type": "paravirtual",
"hypervisor": "xen"
},
{
"ami_id": "ami-ekswin",
"state": "available",
"public": true,
"owner_id": "801119661308",
"image_location": "amazon/amazon-eks",
"sriov": "simple",
"root_device_type": "ebs",
"root_device_name": "/dev/sda1",
"description": "Microsoft Windows Server 2019 Core optimized for EKS and provided by Amazon",
"image_type": "machine",
"platform": "windows",
"architecture": "x86_64",
"name": "Windows_Server-2019-English-Core-EKS_Optimized",
"virtualization_type": "hvm",
"hypervisor": "xen"
},
{
"ami_id": "ami-ekslinux",
"state": "available",
"public": true,
"owner_id": "801119661308",
"image_location": "amazon/amazon-eks",
"sriov": "simple",
"root_device_type": "ebs",
"root_device_name": "/dev/sda1",
"description": "EKS Kubernetes Worker AMI with AmazonLinux2 image",
"image_type": "machine",
"platform": "Linux/UNIX",
"architecture": "x86_64",
"name": "amazon-eks-node-linux",
"virtualization_type": "hvm",
"hypervisor": "xen"
}
]
]

View File

@ -34,6 +34,9 @@ from .vpc_peering_connections import VPCPeeringConnections
from .vpn_connections import VPNConnections
from .windows import Windows
from .nat_gateways import NatGateways
from .transit_gateways import TransitGateways
from .transit_gateway_route_tables import TransitGatewayRouteTable
from .transit_gateway_attachments import TransitGatewayAttachment
from .iam_instance_profiles import IamInstanceProfiles
@ -72,6 +75,9 @@ class EC2Response(
VPNConnections,
Windows,
NatGateways,
TransitGateways,
TransitGatewayRouteTable,
TransitGatewayAttachment,
IamInstanceProfiles,
):
@property

0
moto/ec2/responses/amis.py Executable file → Normal file
View File

View File

@ -9,8 +9,12 @@ class CustomerGateways(BaseResponse):
type = self._get_param("Type")
ip_address = self._get_param("IpAddress")
bgp_asn = self._get_param("BgpAsn")
tags = self._get_multi_param("TagSpecification")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
customer_gateway = self.ec2_backend.create_customer_gateway(
type, ip_address=ip_address, bgp_asn=bgp_asn
type, ip_address=ip_address, bgp_asn=bgp_asn, tags=tags
)
template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=customer_gateway)
@ -19,7 +23,7 @@ class CustomerGateways(BaseResponse):
customer_gateway_id = self._get_param("CustomerGatewayId")
delete_status = self.ec2_backend.delete_customer_gateway(customer_gateway_id)
template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=delete_status)
return template.render(delete_status=delete_status)
def describe_customer_gateways(self):
filters = filters_from_querystring(self.querystring)
@ -33,20 +37,18 @@ CREATE_CUSTOMER_GATEWAY_RESPONSE = """
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<customerGateway>
<customerGatewayId>{{ customer_gateway.id }}</customerGatewayId>
<state>pending</state>
<state>{{ customer_gateway.state }}</state>
<type>{{ customer_gateway.type }}</type>
<ipAddress>{{ customer_gateway.ip_address }}</ipAddress>
<bgpAsn>{{ customer_gateway.bgp_asn }}</bgpAsn>
<tagSet>
{% for tag in customer_gateway.get_tags() %}
<item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<tagSet>
{% for tag in customer_gateway.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</customerGateway>
</CreateCustomerGatewayResponse>"""
@ -64,19 +66,17 @@ DESCRIBE_CUSTOMER_GATEWAYS_RESPONSE = """
<item>
<customerGatewayId>{{ customer_gateway.id }}</customerGatewayId>
<state>{{ customer_gateway.state }}</state>
<type>available</type>
<type>{{ customer_gateway.type }}</type>
<ipAddress>{{ customer_gateway.ip_address }}</ipAddress>
<bgpAsn>{{ customer_gateway.bgp_asn }}</bgpAsn>
<tagSet>
{% for tag in customer_gateway.get_tags() %}
<item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<tagSet>
{% for tag in customer_gateway.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</item>
{% endfor %}
</customerGatewaySet>

View File

@ -14,9 +14,11 @@ class InternetGateways(BaseResponse):
def create_internet_gateway(self):
if self.is_not_dryrun("CreateInternetGateway"):
tags = self._get_multi_param("TagSpecification")
tags = self._get_multi_param(
"TagSpecification", skip_result_conversion=True
)
if tags:
tags = tags[0].get("Tag")
tags = tags[0].get("Tag") or []
igw = self.ec2_backend.create_internet_gateway(tags=tags)
template = self.response_template(CREATE_INTERNET_GATEWAY_RESPONSE)
return template.render(internet_gateway=igw)

View File

@ -40,9 +40,9 @@ class RouteTables(BaseResponse):
def create_route_table(self):
vpc_id = self._get_param("VpcId")
tags = self._get_multi_param("TagSpecification")
tags = self._get_multi_param("TagSpecification", skip_result_conversion=True)
if tags:
tags = tags[0].get("Tag")
tags = tags[0].get("Tag") or []
route_table = self.ec2_backend.create_route_table(vpc_id, tags)
template = self.response_template(CREATE_ROUTE_TABLE_RESPONSE)
return template.render(route_table=route_table)

View File

@ -115,10 +115,14 @@ class SecurityGroups(BaseResponse):
name = self._get_param("GroupName")
description = self._get_param("GroupDescription")
vpc_id = self._get_param("VpcId")
tags = self._get_multi_param("TagSpecification")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
if self.is_not_dryrun("CreateSecurityGroup"):
group = self.ec2_backend.create_security_group(
name, description, vpc_id=vpc_id
name, description, vpc_id=vpc_id, tags=tags
)
template = self.response_template(CREATE_SECURITY_GROUP_RESPONSE)
return template.render(group=group)

View File

@ -64,13 +64,13 @@ CREATE_SUBNET_RESPONSE = """
<state>pending</state>
<vpcId>{{ subnet.vpc_id }}</vpcId>
<cidrBlock>{{ subnet.cidr_block }}</cidrBlock>
<availableIpAddressCount>{{ subnet.available_ip_addresses }}</availableIpAddressCount>
<availableIpAddressCount>{{ subnet.available_ip_addresses or '0' }}</availableIpAddressCount>
<availabilityZone>{{ subnet._availability_zone.name }}</availabilityZone>
<availabilityZoneId>{{ subnet._availability_zone.zone_id }}</availabilityZoneId>
<defaultForAz>{{ subnet.default_for_az }}</defaultForAz>
<mapPublicIpOnLaunch>{{ subnet.map_public_ip_on_launch }}</mapPublicIpOnLaunch>
<ownerId>{{ subnet.owner_id }}</ownerId>
<assignIpv6AddressOnCreation>{{ subnet.assign_ipv6_address_on_creation }}</assignIpv6AddressOnCreation>
<assignIpv6AddressOnCreation>{{ 'false' if not subnet.assign_ipv6_address_on_creation or subnet.assign_ipv6_address_on_creation == 'false' else 'true'}}</assignIpv6AddressOnCreation>
<ipv6CidrBlockAssociationSet>{{ subnet.ipv6_cidr_block_associations }}</ipv6CidrBlockAssociationSet>
<subnetArn>arn:aws:ec2:{{ subnet._availability_zone.name[0:-1] }}:{{ subnet.owner_id }}:subnet/{{ subnet.id }}</subnetArn>
<tagSet>
@ -102,13 +102,13 @@ DESCRIBE_SUBNETS_RESPONSE = """
<state>{{ subnet.state }}</state>
<vpcId>{{ subnet.vpc_id }}</vpcId>
<cidrBlock>{{ subnet.cidr_block }}</cidrBlock>
<availableIpAddressCount>{{ subnet.available_ip_addresses }}</availableIpAddressCount>
<availableIpAddressCount>{{ subnet.available_ip_addresses or '0' }}</availableIpAddressCount>
<availabilityZone>{{ subnet._availability_zone.name }}</availabilityZone>
<availabilityZoneId>{{ subnet._availability_zone.zone_id }}</availabilityZoneId>
<defaultForAz>{{ subnet.default_for_az }}</defaultForAz>
<mapPublicIpOnLaunch>{{ subnet.map_public_ip_on_launch }}</mapPublicIpOnLaunch>
<ownerId>{{ subnet.owner_id }}</ownerId>
<assignIpv6AddressOnCreation>{{ subnet.assign_ipv6_address_on_creation }}</assignIpv6AddressOnCreation>
<assignIpv6AddressOnCreation>{{ 'false' if not subnet.assign_ipv6_address_on_creation or subnet.assign_ipv6_address_on_creation == 'false' else 'true'}}</assignIpv6AddressOnCreation>
<ipv6CidrBlockAssociationSet>{{ subnet.ipv6_cidr_block_associations }}</ipv6CidrBlockAssociationSet>
<subnetArn>arn:aws:ec2:{{ subnet._availability_zone.name[0:-1] }}:{{ subnet.owner_id }}:subnet/{{ subnet.id }}</subnetArn>
{% if subnet.get_tags() %}

View File

@ -0,0 +1,155 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring
class TransitGatewayAttachment(BaseResponse):
def create_transit_gateway_vpc_attachment(self):
options = self._get_multi_param_dict("Options")
subnet_ids = self._get_multi_param("SubnetIds")
transit_gateway_id = self._get_param("TransitGatewayId")
vpc_id = self._get_param("VpcId")
tags = self._get_multi_param("TagSpecifications")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
transit_gateway_attachment = self.ec2_backend.create_transit_gateway_vpc_attachment(
transit_gateway_id=transit_gateway_id,
tags=tags,
vpc_id=vpc_id,
subnet_ids=subnet_ids,
options=options,
)
template = self.response_template(CREATE_TRANSIT_GATEWAY_VPC_ATTACHMENT)
return template.render(transit_gateway_attachment=transit_gateway_attachment)
def describe_transit_gateway_vpc_attachments(self):
transit_gateways_attachment_ids = self._get_multi_param(
"TransitGatewayAttachmentIds"
)
filters = filters_from_querystring(self.querystring)
max_results = self._get_param("MaxResults")
transit_gateway_vpc_attachments = self.ec2_backend.describe_transit_gateway_vpc_attachments(
transit_gateways_attachment_ids=transit_gateways_attachment_ids,
filters=filters,
max_results=max_results,
)
template = self.response_template(DESCRIBE_TRANSIT_GATEWAY_VPC_ATTACHMENTS)
return template.render(
transit_gateway_vpc_attachments=transit_gateway_vpc_attachments
)
def describe_transit_gateway_attachments(self):
transit_gateways_attachment_ids = self._get_multi_param(
"TransitGatewayAttachmentIds"
)
filters = filters_from_querystring(self.querystring)
max_results = self._get_param("MaxResults")
transit_gateway_attachments = self.ec2_backend.describe_transit_gateway_attachments(
transit_gateways_attachment_ids=transit_gateways_attachment_ids,
filters=filters,
max_results=max_results,
)
template = self.response_template(DESCRIBE_TRANSIT_GATEWAY_ATTACHMENTS)
return template.render(transit_gateway_attachments=transit_gateway_attachments)
CREATE_TRANSIT_GATEWAY_VPC_ATTACHMENT = """<CreateTransitGatewayVpcAttachmentResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>9b5766ac-2af6-4b92-9a8a-4d74ae46ae79</requestId>
<transitGatewayVpcAttachment>
<createTime>{{ transit_gateway_attachment.create_time }}</createTime>
<options>
<applianceModeSupport>{{ transit_gateway_attachment.options.ApplianceModeSupport }}</applianceModeSupport>
<dnsSupport>{{ transit_gateway_attachment.options.DnsSupport }}</dnsSupport>
<ipv6Support>{{ transit_gateway_attachment.options.Ipv6Support }}</ipv6Support>
</options>
<state>{{ transit_gateway_attachment.state }}</state>
<subnetIds>
{% for subnet_id in transit_gateway_attachment.subnet_ids %}
<item>{{ subnet_id }}</item>
{% endfor %}
</subnetIds>
<tagSet>
{% for tag in transit_gateway_attachment.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayAttachmentId>{{ transit_gateway_attachment.id }}</transitGatewayAttachmentId>
<transitGatewayId>{{ transit_gateway_attachment.transit_gateway_id }}</transitGatewayId>
<vpcId>{{ transit_gateway_attachment.vpc_id }}</vpcId>
<vpcOwnerId>{{ transit_gateway_attachment.resource_owner_id }}</vpcOwnerId>
</transitGatewayVpcAttachment>
</CreateTransitGatewayVpcAttachmentResponse>"""
DESCRIBE_TRANSIT_GATEWAY_ATTACHMENTS = """<DescribeTransitGatewayAttachmentsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>92aa7885-74c0-42d1-a846-e59bd07488a7</requestId>
<transitGatewayAttachments>
{% for transit_gateway_attachment in transit_gateway_attachments %}
<item>
<association>
<state>associated</state>
<transitGatewayRouteTableId>tgw-rtb-0b36edb9b88f0d5e3</transitGatewayRouteTableId>
</association>
<creationTime>2021-07-18T08:57:21.000Z</creationTime>
<resourceId>{{ transit_gateway_attachment.resource_id }}</resourceId>
<resourceOwnerId>{{ transit_gateway_attachment.resource_owner_id }}</resourceOwnerId>
<resourceType>{{ transit_gateway_attachment.resource_type }}</resourceType>
<state>{{ transit_gateway_attachment.state }}</state>
<tagSet>
{% for tag in transit_gateway_attachment.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayAttachmentId>{{ transit_gateway_attachment.id }}</transitGatewayAttachmentId>
<transitGatewayId>{{ transit_gateway_attachment.transit_gateway_id }}</transitGatewayId>
<transitGatewayOwnerId>074255357339</transitGatewayOwnerId>
</item>
{% endfor %}
</transitGatewayAttachments>
</DescribeTransitGatewayAttachmentsResponse>
"""
DESCRIBE_TRANSIT_GATEWAY_VPC_ATTACHMENTS = """<DescribeTransitGatewayVpcAttachmentsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>bebc9670-0205-4f28-ad89-049c97e46633</requestId>
<transitGatewayVpcAttachments>
{% for transit_gateway_vpc_attachment in transit_gateway_vpc_attachments %}
<item>
<creationTime>2021-07-18T08:57:21.000Z</creationTime>
<options>
<applianceModeSupport>{{ transit_gateway_vpc_attachment.options.ApplianceModeSupport }}</applianceModeSupport>
<dnsSupport>{{ transit_gateway_vpc_attachment.options.DnsSupport }}</dnsSupport>
<ipv6Support>{{ transit_gateway_vpc_attachment.options.Ipv6Support }}</ipv6Support>
</options>
<state>{{ transit_gateway_vpc_attachment.state }}</state>
<subnetIds>
{% for id in transit_gateway_vpc_attachment.subnet_ids %}
<item>id</item>
{% endfor %}
</subnetIds>
<tagSet>
{% for tag in transit_gateway_vpc_attachment.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayAttachmentId>{{ transit_gateway_vpc_attachment.id }}</transitGatewayAttachmentId>
<transitGatewayId>{{ transit_gateway_vpc_attachment.transit_gateway_id }}</transitGatewayId>
<vpcId>{{ transit_gateway_vpc_attachment.vpc_id }}</vpcId>
<vpcOwnerId>074255357339</vpcOwnerId>
</item>
{% endfor %}
</transitGatewayVpcAttachments>
</DescribeTransitGatewayVpcAttachmentsResponse>
"""

View File

@ -0,0 +1,187 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring
from moto.utilities.utils import str2bool
class TransitGatewayRouteTable(BaseResponse):
def create_transit_gateway_route_table(self):
transit_gateway_id = self._get_param("TransitGatewayId")
tags = self._get_multi_param("TagSpecifications")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
transit_gateway_route_table = self.ec2_backend.create_transit_gateway_route_table(
transit_gateway_id=transit_gateway_id, tags=tags
)
template = self.response_template(CREATE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE)
return template.render(transit_gateway_route_table=transit_gateway_route_table)
def describe_transit_gateway_route_tables(self):
filters = filters_from_querystring(self.querystring)
transit_gateway_ids = (
self._get_multi_param("TransitGatewayRouteTableIds") or None
)
transit_gateway_route_tables = self.ec2_backend.get_all_transit_gateway_route_tables(
transit_gateway_ids, filters
)
template = self.response_template(DESCRIBE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE)
return template.render(
transit_gateway_route_tables=transit_gateway_route_tables
)
def delete_transit_gateway_route_table(self):
transit_gateway_route_table_id = self._get_param("TransitGatewayRouteTableId")
transit_gateway_route_table = self.ec2_backend.delete_transit_gateway_route_table(
transit_gateway_route_table_id
)
template = self.response_template(DELETE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE)
return template.render(transit_gateway_route_table=transit_gateway_route_table)
def create_transit_gateway_route(self):
transit_gateway_attachment_id = self._get_param("TransitGatewayAttachmentId")
destination_cidr_block = self._get_param("DestinationCidrBlock")
transit_gateway_route_table_id = self._get_param("TransitGatewayRouteTableId")
blackhole = str2bool(self._get_param("Blackhole"))
transit_gateways_route_table = self.ec2_backend.create_transit_gateway_route(
destination_cidr_block=destination_cidr_block,
transit_gateway_route_table_id=transit_gateway_route_table_id,
transit_gateway_attachment_id=transit_gateway_attachment_id,
blackhole=blackhole,
)
template = self.response_template(CREATE_TRANSIT_GATEWAY_ROUTE_RESPONSE)
return template.render(
transit_gateway_route_table=transit_gateways_route_table,
destination_cidr_block=destination_cidr_block,
)
def delete_transit_gateway_route(self):
destination_cidr_block = self._get_param("DestinationCidrBlock")
transit_gateway_route_table_id = self._get_param("TransitGatewayRouteTableId")
transit_gateway_route_table = self.ec2_backend.delete_transit_gateway_route(
destination_cidr_block=destination_cidr_block,
transit_gateway_route_table_id=transit_gateway_route_table_id,
)
template = self.response_template(DELETE_TRANSIT_GATEWAY_ROUTE_RESPONSE)
rendered_template = template.render(
transit_gateway_route_table=transit_gateway_route_table,
destination_cidr_block=destination_cidr_block,
)
del transit_gateway_route_table.routes[destination_cidr_block]
return rendered_template
def search_transit_gateway_routes(self):
transit_gateway_route_table_id = self._get_param("TransitGatewayRouteTableId")
filters = filters_from_querystring(self.querystring)
max_results = self._get_param("MaxResults")
transit_gateway_routes = self.ec2_backend.search_transit_gateway_routes(
transit_gateway_route_table_id=transit_gateway_route_table_id,
filters=filters,
max_results=max_results,
)
template = self.response_template(SEARCH_TRANSIT_GATEWAY_ROUTES_RESPONSE)
return template.render(transit_gateway_routes=transit_gateway_routes)
CREATE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE = """<CreateTransitGatewayRouteTableResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>3a495d25-08d4-466d-822e-477c9b1fc606</requestId>
<transitGatewayRouteTable>
<creationTime>{{ transit_gateway_route_table.create_time }}</creationTime>
<defaultAssociationRouteTable>{{ transit_gateway_route_table.default_association_route_table }}</defaultAssociationRouteTable>
<defaultPropagationRouteTable>{{ transit_gateway_route_table.default_propagation_route_table }}</defaultPropagationRouteTable>
<state>{{ transit_gateway_route_table.state }}</state>
<transitGatewayId>{{ transit_gateway_route_table.transit_gateway_id }}</transitGatewayId>
<transitGatewayRouteTableId>{{ transit_gateway_route_table.id }}</transitGatewayRouteTableId>
<tagSet>
{% for tag in transit_gateway_route_table.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</transitGatewayRouteTable>
</CreateTransitGatewayRouteTableResponse>
"""
DESCRIBE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE = """<DescribeTransitGatewayRouteTablesResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>f9dea58a-7bb3-458b-a40d-0b7ae32eefdb</requestId>
<transitGatewayRouteTables>
{% for transit_gateway_route_table in transit_gateway_route_tables %}
<item>
<creationTime>{{ transit_gateway_route_table.create_time }}</creationTime>
<defaultAssociationRouteTable>{{ transit_gateway_route_table.default_association_route_table }}</defaultAssociationRouteTable>
<defaultPropagationRouteTable>{{ transit_gateway_route_table.default_propagation_route_table }}</defaultPropagationRouteTable>
<state>{{ transit_gateway_route_table.state }}</state>
<tagSet>
{% for tag in transit_gateway_route_table.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayId>{{ transit_gateway_route_table.transit_gateway_id }}</transitGatewayId>
<transitGatewayRouteTableId>{{ transit_gateway_route_table.id }}</transitGatewayRouteTableId>
</item>
{% endfor %}
</transitGatewayRouteTables>
</DescribeTransitGatewayRouteTablesResponse>
"""
DELETE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE = """<DeleteTransitGatewayRouteTableResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>a9a07226-c7b1-4305-9934-0bcfc3ef1c5e</requestId>
<transitGatewayRouteTable>
{% for transit_gateway_route_table in transit_gateway_route_tables %}
<item>
<creationTime>{{ transit_gateway_route_table.create_time }}</creationTime>
<defaultAssociationRouteTable>{{ transit_gateway_route_table.default_association_route_table }}</defaultAssociationRouteTable>
<defaultPropagationRouteTable>{{ transit_gateway_route_table.default_propagation_route_table }}</defaultPropagationRouteTable>
<state>{{ transit_gateway_route_table.state }}</state>
<transitGatewayId>{{ transit_gateway_route_table.transit_gateway_id }}</transitGatewayId>
<transitGatewayRouteTableId>{{ transit_gateway_route_table.id }}</transitGatewayRouteTableId>
</item>
{% endfor %}
</transitGatewayRouteTable>
</DeleteTransitGatewayRouteTableResponse>
"""
CREATE_TRANSIT_GATEWAY_ROUTE_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<CreateTransitGatewayRouteResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>072b02ce-df3a-4de6-a20b-6653ae4b91a4</requestId>
<route>
<destinationCidrBlock>{{ transit_gateway_route_table.routes[destination_cidr_block]['destinationCidrBlock'] }}</destinationCidrBlock>
<state>{{ transit_gateway_route_table.routes[destination_cidr_block]['state'] }}</state>
<type>{{ transit_gateway_route_table.routes[destination_cidr_block]['type'] }}</type>
</route>
</CreateTransitGatewayRouteResponse>
"""
DELETE_TRANSIT_GATEWAY_ROUTE_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<DeleteTransitGatewayRouteResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>2109d5bb-f874-4f35-b419-4723792a638f</requestId>
<route>
<destinationCidrBlock>{{ transit_gateway_route_table.routes[destination_cidr_block]['destinationCidrBlock'] }}</destinationCidrBlock>
<state>{{ transit_gateway_route_table.routes[destination_cidr_block]['state'] }}</state>
<type>{{ transit_gateway_route_table.routes[destination_cidr_block]['type'] }}</type>
</route>
</DeleteTransitGatewayRouteResponse>
"""
SEARCH_TRANSIT_GATEWAY_ROUTES_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<SearchTransitGatewayRoutesResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>04b46ad2-5a0e-46db-afe4-68679a193b48</requestId>
<routeSet>
{% for route in transit_gateway_routes %}
<item>
<destinationCidrBlock>{{ route['destinationCidrBlock'] }}</destinationCidrBlock>
<state>{{ route['state'] }}</state>
<type>{{ route['type'] }}</type>
</item>
{% endfor %}
</routeSet>
<additionalRoutesAvailable>false</additionalRoutesAvailable>
</SearchTransitGatewayRoutesResponse>
"""

View File

@ -0,0 +1,157 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring
class TransitGateways(BaseResponse):
def create_transit_gateway(self):
description = self._get_param("Description") or None
options = self._get_multi_param_dict("Options")
tags = self._get_multi_param("TagSpecification")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
transit_gateway = self.ec2_backend.create_transit_gateway(
description=description, options=options, tags=tags
)
template = self.response_template(CREATE_TRANSIT_GATEWAY_RESPONSE)
return template.render(transit_gateway=transit_gateway)
def delete_transit_gateway(self):
transit_gateway_id = self._get_param("TransitGatewayId")
transit_gateway = self.ec2_backend.delete_transit_gateway(transit_gateway_id)
template = self.response_template(DELETE_TRANSIT_GATEWAY_RESPONSE)
return template.render(transit_gateway=transit_gateway)
def describe_transit_gateways(self):
filters = filters_from_querystring(self.querystring)
transit_gateways = self.ec2_backend.get_all_transit_gateways(filters)
template = self.response_template(DESCRIBE_TRANSIT_GATEWAY_RESPONSE)
return template.render(transit_gateways=transit_gateways)
def modify_transit_gateway(self):
transit_gateway_id = self._get_param("TransitGatewayId")
description = self._get_param("Description") or None
options = self._get_multi_param_dict("Options")
transit_gateway = self.ec2_backend.modify_transit_gateway(
transit_gateway_id=transit_gateway_id,
description=description,
options=options,
)
template = self.response_template(MODIFY_TRANSIT_GATEWAY_RESPONSE)
return template.render(transit_gateway=transit_gateway)
CREATE_TRANSIT_GATEWAY_RESPONSE = """<CreateTransitGatewayResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>151283df-f7dc-4317-89b4-01c9888b1d45</requestId>
<transitGateway>
<transitGatewayId>{{ transit_gateway.id }}</transitGatewayId>
<ownerId>{{ transit_gateway.owner_id }}</ownerId>
<description>{{ transit_gateway.description or '' }}</description>
<createTime>{{ transit_gateway.create_time }}</createTime>
<state>{{ transit_gateway.state }}</state>
{% if transit_gateway.options %}
<options>
<amazonSideAsn>{{ transit_gateway.options.AmazonSideAsn }}</amazonSideAsn>
<autoAcceptSharedAttachments>{{ transit_gateway.options.AutoAcceptSharedAttachments }}</autoAcceptSharedAttachments>
<defaultRouteTableAssociation>{{ transit_gateway.options.DefaultRouteTableAssociation }}</defaultRouteTableAssociation>
<defaultRouteTablePropagation>{{ transit_gateway.options.DefaultRouteTablePropagation }}</defaultRouteTablePropagation>
<dnsSupport>{{ transit_gateway.options.DnsSupport }}</dnsSupport>
<propagationDefaultRouteTableId>{{ transit_gateway.options.PropagationDefaultRouteTableId }}</propagationDefaultRouteTableId>
<vpnEcmpSupport>{{ transit_gateway.options.VpnEcmpSupport }}</vpnEcmpSupport>
<transitGatewayCidrBlocks>{{ transit_gateway.options.TransitGatewayCidrBlocks }}</transitGatewayCidrBlocks>
</options>
{% endif %}
<tagSet>
{% for tag in transit_gateway.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</transitGateway>
</CreateTransitGatewayResponse>
"""
DESCRIBE_TRANSIT_GATEWAY_RESPONSE = """<DescribeTransitGatewaysResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>151283df-f7dc-4317-89b4-01c9888b1d45</requestId>
<transitGatewaySet>
{% for transit_gateway in transit_gateways %}
<item>
<creationTime>{{ transit_gateway.create_time }}</creationTime>
<description>{{ transit_gateway.description or '' }}</description>
{% if transit_gateway.options %}
<options>
<amazonSideAsn>{{ transit_gateway.options.AmazonSideAsn }}</amazonSideAsn>
<associationDefaultRouteTableId>{{ transit_gateway.options.AssociationDefaultRouteTableId }}</associationDefaultRouteTableId>
<autoAcceptSharedAttachments>{{ transit_gateway.options.AutoAcceptSharedAttachments }}</autoAcceptSharedAttachments>
<defaultRouteTableAssociation>{{ transit_gateway.options.DefaultRouteTableAssociation }}</defaultRouteTableAssociation>
<defaultRouteTablePropagation>{{ transit_gateway.options.DefaultRouteTablePropagation }}</defaultRouteTablePropagation>
<dnsSupport>{{ transit_gateway.options.DnsSupport }}</dnsSupport>
<propagationDefaultRouteTableId>{{ transit_gateway.options.PropagationDefaultRouteTableId }}</propagationDefaultRouteTableId>
<vpnEcmpSupport>{{ transit_gateway.options.VpnEcmpSupport }}</vpnEcmpSupport>
<transitGatewayCidrBlocks>{{ transit_gateway.options.TransitGatewayCidrBlocks }}</transitGatewayCidrBlocks>
</options>
{% endif %}
<ownerId>{{ transit_gateway.owner_id }}</ownerId>
<state>{{ transit_gateway.state }}</state>
<tagSet>
{% for tag in transit_gateway.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayArn>arn:aws:ec2:us-east-1:{{ transit_gateway.owner_id }}:transit-gateway/{{ transit_gateway.id }}</transitGatewayArn>
<transitGatewayId>{{ transit_gateway.id }}</transitGatewayId>
</item>
{% endfor %}
</transitGatewaySet>
</DescribeTransitGatewaysResponse>
"""
DELETE_TRANSIT_GATEWAY_RESPONSE = """<DeleteTransitGatewayResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>151283df-f7dc-4317-89b4-01c9888b1d45</requestId>
<transitGatewayId>{{ transit_gateway.id }}</transitGatewayId>
</DeleteTransitGatewayResponse>
"""
MODIFY_TRANSIT_GATEWAY_RESPONSE = """<ModifyTransitGatewaysResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>151283df-f7dc-4317-89b4-01c9888b1d45</requestId>
<transitGatewaySet>
<item>
<creationTime>{{ transit_gateway.create_time }}</creationTime>
<description>{{ transit_gateway.description or '' }}</description>
{% if transit_gateway.options %}
<options>
<amazonSideAsn>{{ transit_gateway.options.AmazonSideAsn }}</amazonSideAsn>
<associationDefaultRouteTableId>{{ transit_gateway.options.AssociationDefaultRouteTableId }}</associationDefaultRouteTableId>
<autoAcceptSharedAttachments>{{ transit_gateway.options.AutoAcceptSharedAttachments }}</autoAcceptSharedAttachments>
<defaultRouteTableAssociation>{{ transit_gateway.options.DefaultRouteTableAssociation }}</defaultRouteTableAssociation>
<defaultRouteTablePropagation>{{ transit_gateway.options.DefaultRouteTablePropagation }}</defaultRouteTablePropagation>
<dnsSupport>{{ transit_gateway.options.DnsSupport }}</dnsSupport>
<propagationDefaultRouteTableId>{{ transit_gateway.options.PropagationDefaultRouteTableId }}</propagationDefaultRouteTableId>
<vpnEcmpSupport>{{ transit_gateway.options.VpnEcmpSupport }}</vpnEcmpSupport>
<transitGatewayCidrBlocks>{{ transit_gateway.options.TransitGatewayCidrBlocks }}</transitGatewayCidrBlocks>
</options>
{% endif %}
<ownerId>{{ transit_gateway.owner_id }}</ownerId>
<state>{{ transit_gateway.state }}</state>
<tagSet>
{% for tag in transit_gateway.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayArn>arn:aws:ec2:us-east-1:{{ transit_gateway.owner_id }}:transit-gateway/{{ transit_gateway.id }}</transitGatewayArn>
<transitGatewayId>{{ transit_gateway.id }}</transitGatewayId>
</item>
</transitGatewaySet>
</ModifyTransitGatewaysResponse>
"""

View File

@ -13,7 +13,18 @@ class VirtualPrivateGateways(BaseResponse):
def create_vpn_gateway(self):
type = self._get_param("Type")
vpn_gateway = self.ec2_backend.create_vpn_gateway(type)
amazon_side_asn = self._get_param("AmazonSideAsn")
availability_zone = self._get_param("AvailabilityZone")
tags = self._get_multi_param("TagSpecification")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
vpn_gateway = self.ec2_backend.create_vpn_gateway(
type=type,
amazon_side_asn=amazon_side_asn,
availability_zone=availability_zone,
tags=tags,
)
template = self.response_template(CREATE_VPN_GATEWAY_RESPONSE)
return template.render(vpn_gateway=vpn_gateway)
@ -44,7 +55,7 @@ CREATE_VPN_GATEWAY_RESPONSE = """
<vpnGatewayId>{{ vpn_gateway.id }}</vpnGatewayId>
<state>available</state>
<type>{{ vpn_gateway.type }}</type>
<availabilityZone>us-east-1a</availabilityZone>
<availabilityZone>{{ vpn_gateway.availability_zone }}</availabilityZone>
<attachments/>
<tagSet>
{% for tag in vpn_gateway.get_tags() %}

View File

@ -17,16 +17,16 @@ class VPCs(BaseResponse):
cidr_block = self._get_param("CidrBlock")
tags = self._get_multi_param("TagSpecification")
instance_tenancy = self._get_param("InstanceTenancy", if_none="default")
amazon_provided_ipv6_cidr_blocks = self._get_param(
amazon_provided_ipv6_cidr_block = self._get_param(
"AmazonProvidedIpv6CidrBlock"
)
) in ["true", "True"]
if tags:
tags = tags[0].get("Tag")
vpc = self.ec2_backend.create_vpc(
cidr_block,
instance_tenancy,
amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_blocks,
amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_block,
tags=tags,
)
doc_date = self._get_doc_date()
@ -178,8 +178,8 @@ class VPCs(BaseResponse):
policy_document = self._get_param("PolicyDocument")
client_token = self._get_param("ClientToken")
tag_specifications = self._get_param("TagSpecifications")
private_dns_enabled = self._get_bool_param("PrivateDNSEnabled", if_none=True)
security_group = self._get_param("SecurityGroup")
private_dns_enabled = self._get_bool_param("PrivateDnsEnabled", if_none=True)
security_group_ids = self._get_multi_param("SecurityGroupId")
vpc_end_point = self.ec2_backend.create_vpc_endpoint(
vpc_id=vpc_id,
@ -189,7 +189,7 @@ class VPCs(BaseResponse):
route_table_ids=route_table_ids,
subnet_ids=subnet_ids,
client_token=client_token,
security_group=security_group,
security_group_ids=security_group_ids,
tag_specifications=tag_specifications,
private_dns_enabled=private_dns_enabled,
)
@ -479,8 +479,8 @@ DESCRIBE_VPC_ENDPOINT_SERVICES_RESPONSE = """<DescribeVpcEndpointServicesRespons
{% endfor %}
</serviceNameSet>
<serviceDetailSet>
{% for service in vpc_end_points.servicesDetails %}
<item>
{% for service in vpc_end_points.servicesDetails %}
<owner>amazon</owner>
<serviceType>
<item>
@ -498,8 +498,8 @@ DESCRIBE_VPC_ENDPOINT_SERVICES_RESPONSE = """<DescribeVpcEndpointServicesRespons
</availabilityZoneSet>
<serviceName>{{ service.service_name }}</serviceName>
<vpcEndpointPolicySupported>true</vpcEndpointPolicySupported>
{% endfor %}
</item>
{% endfor %}
</serviceDetailSet>
</DescribeVpcEndpointServicesResponse>"""
@ -545,12 +545,15 @@ DESCRIBE_VPC_ENDPOINT_RESPONSE = """<DescribeVpcEndpointsResponse xmlns="http://
{% endfor %}
</dnsEntries>
{% endif %}
{% if vpc_end_point.groups %}
<groups>
{% for group in vpc_end_point.groups %}
<item>{{ group }}</item>
{% if vpc_end_point.security_group_ids %}
<groupSet>
{% for group_id in vpc_end_point.security_group_ids %}
<item>
<groupId>{{ group_id }}</groupId>
<groupName>TODO</groupName>
</item>
{% endfor %}
</groups>
</groupSet>
{% endif %}
{% if vpc_end_point.tag_specifications %}
<tagSet>

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring
from moto.ec2.utils import filters_from_querystring, add_tag_specification
from xml.sax.saxutils import escape
class VPNConnections(BaseResponse):
@ -8,16 +9,34 @@ class VPNConnections(BaseResponse):
type = self._get_param("Type")
cgw_id = self._get_param("CustomerGatewayId")
vgw_id = self._get_param("VpnGatewayId")
tgw_id = self._get_param("TransitGatewayId")
static_routes = self._get_param("StaticRoutesOnly")
tags = add_tag_specification(self._get_multi_param("TagSpecification"))
vpn_connection = self.ec2_backend.create_vpn_connection(
type, cgw_id, vgw_id, static_routes_only=static_routes
type,
cgw_id,
vpn_gateway_id=vgw_id,
transit_gateway_id=tgw_id,
static_routes_only=static_routes,
tags=tags,
)
if vpn_connection.transit_gateway_id:
self.ec2_backend.create_transit_gateway_vpn_attachment(
vpn_id=vpn_connection.id, transit_gateway_id=tgw_id
)
template = self.response_template(CREATE_VPN_CONNECTION_RESPONSE)
return template.render(vpn_connection=vpn_connection)
def delete_vpn_connection(self):
vpn_connection_id = self._get_param("VpnConnectionId")
vpn_connection = self.ec2_backend.delete_vpn_connection(vpn_connection_id)
if vpn_connection.transit_gateway_id:
transit_gateway_attachments = (
self.ec2_backend.describe_transit_gateway_attachments()
)
for attachment in transit_gateway_attachments:
if attachment.resource_id == vpn_connection.id:
attachment.state = "deleted"
template = self.response_template(DELETE_VPN_CONNECTION_RESPONSE)
return template.render(vpn_connection=vpn_connection)
@ -31,17 +50,11 @@ class VPNConnections(BaseResponse):
return template.render(vpn_connections=vpn_connections)
CREATE_VPN_CONNECTION_RESPONSE = """
<CreateVpnConnectionResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<vpnConnection>
<vpnConnectionId>{{ vpn_connection.id }}</vpnConnectionId>
<state>pending</state>
<customerGatewayConfiguration>
<vpn_connection id="{{ vpn_connection.id }}">
CUSTOMER_GATEWAY_CONFIGURATION_TEMPLATE = """
<vpn_connection id="{{ vpn_connection.id }}">
<customer_gateway_id>{{ vpn_connection.customer_gateway_id }}</customer_gateway_id>
<vpn_gateway_id>{{ vpn_connection.vpn_gateway_id }}</vpn_gateway_id>
<vpn_connection_type>ipsec.1</vpn_connection_type>
<vpn_gateway_id> {{ vpn_connection.vpn_gateway_id if vpn_connection.vpn_gateway_id is not none }} </vpn_gateway_id>
<vpn_connection_type>{{ vpn_connection.type }}</vpn_connection_type>
<ipsec_tunnel>
<customer_gateway>
<tunnel_outside_address>
@ -149,15 +162,29 @@ CREATE_VPN_CONNECTION_RESPONSE = """
</ipsec>
</ipsec_tunnel>
</vpn_connection>
"""
CREATE_VPN_CONNECTION_RESPONSE = (
"""
<CreateVpnConnectionResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<vpnConnection>
<vpnConnectionId>{{ vpn_connection.id }}</vpnConnectionId>
<state>{{ vpn_connection.state }}</state>
<customerGatewayConfiguration>
"""
+ escape(CUSTOMER_GATEWAY_CONFIGURATION_TEMPLATE)
+ """
</customerGatewayConfiguration>
<type>ipsec.1</type>
<customerGatewayId>{{ vpn_connection.customer_gateway_id }}</customerGatewayId>
<vpnGatewayId>{{ vpn_connection.vpn_gateway_id }}</vpnGatewayId>
<vpnGatewayId>{{ vpn_connection.vpn_gateway_id or '' }}</vpnGatewayId>
{% if vpn_connection.transit_gateway_id %}
<transitGatewayId>{{ vpn_connection.transit_gateway_id }}</transitGatewayId>
{% endif %}
<tagSet>
{% for tag in vpn_connection.get_tags() %}
<item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
@ -165,6 +192,8 @@ CREATE_VPN_CONNECTION_RESPONSE = """
</tagSet>
</vpnConnection>
</CreateVpnConnectionResponse>"""
)
CREATE_VPN_CONNECTION_ROUTE_RESPONSE = """
<CreateVpnConnectionRouteResponse xmlns="http://ec2.amazonaws.com/doc/2013-10- 15/">
@ -184,135 +213,29 @@ DELETE_VPN_CONNECTION_ROUTE_RESPONSE = """
<return>true</return>
</DeleteVpnConnectionRouteResponse>"""
DESCRIBE_VPN_CONNECTION_RESPONSE = """
DESCRIBE_VPN_CONNECTION_RESPONSE = (
"""
<DescribeVpnConnectionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<vpnConnectionSet>
{% for vpn_connection in vpn_connections %}
<item>
<vpnConnectionId>{{ vpn_connection.id }}</vpnConnectionId>
<state>available</state>
<state>{{ vpn_connection.state }}</state>
<customerGatewayConfiguration>
<vpn_connection id="{{ vpn_connection.id }}">
<customer_gateway_id>{{ vpn_connection.customer_gateway_id }}</customer_gateway_id>
<vpn_gateway_id>{{ vpn_connection.vpn_gateway_id }}</vpn_gateway_id>
<vpn_connection_type>ipsec.1</vpn_connection_type>
<ipsec_tunnel>
<customer_gateway>
<tunnel_outside_address>
<ip_address>12.1.2.3</ip_address>
</tunnel_outside_address>
<tunnel_inside_address>
<ip_address>169.254.44.42</ip_address>
<network_mask>255.255.255.252</network_mask>
<network_cidr>30</network_cidr>
</tunnel_inside_address>
<bgp>
<asn>65000</asn>
<hold_time>30</hold_time>
</bgp>
</customer_gateway>
<vpn_gateway>
<tunnel_outside_address>
<ip_address>52.2.144.13</ip_address>
</tunnel_outside_address>
<tunnel_inside_address>
<ip_address>169.254.44.41</ip_address>
<network_mask>255.255.255.252</network_mask>
<network_cidr>30</network_cidr>
</tunnel_inside_address>
<bgp>
<asn>7224</asn>
<hold_time>30</hold_time>
</bgp>
</vpn_gateway>
<ike>
<authentication_protocol>sha1</authentication_protocol>
<encryption_protocol>aes-128-cbc</encryption_protocol>
<lifetime>28800</lifetime>
<perfect_forward_secrecy>group2</perfect_forward_secrecy>
<mode>main</mode>
<pre_shared_key>Iw2IAN9XUsQeYUrkMGP3kP59ugFDkfHg</pre_shared_key>
</ike>
<ipsec>
<protocol>esp</protocol>
<authentication_protocol>hmac-sha1-96</authentication_protocol>
<encryption_protocol>aes-128-cbc</encryption_protocol>
<lifetime>3600</lifetime>
<perfect_forward_secrecy>group2</perfect_forward_secrecy>
<mode>tunnel</mode>
<clear_df_bit>true</clear_df_bit>
<fragmentation_before_encryption>true</fragmentation_before_encryption>
<tcp_mss_adjustment>1387</tcp_mss_adjustment>
<dead_peer_detection>
<interval>10</interval>
<retries>3</retries>
</dead_peer_detection>
</ipsec>
</ipsec_tunnel>
<ipsec_tunnel>
<customer_gateway>
<tunnel_outside_address>
<ip_address>12.1.2.3</ip_address>
</tunnel_outside_address>
<tunnel_inside_address>
<ip_address>169.254.44.42</ip_address>
<network_mask>255.255.255.252</network_mask>
<network_cidr>30</network_cidr>
</tunnel_inside_address>
<bgp>
<asn>65000</asn>
<hold_time>30</hold_time>
</bgp>
</customer_gateway>
<vpn_gateway>
<tunnel_outside_address>
<ip_address>52.2.144.13</ip_address>
</tunnel_outside_address>
<tunnel_inside_address>
<ip_address>169.254.44.41</ip_address>
<network_mask>255.255.255.252</network_mask>
<network_cidr>30</network_cidr>
</tunnel_inside_address>
<bgp>
<asn>7224</asn>
<hold_time>30</hold_time>
</bgp>
</vpn_gateway>
<ike>
<authentication_protocol>sha1</authentication_protocol>
<encryption_protocol>aes-128-cbc</encryption_protocol>
<lifetime>28800</lifetime>
<perfect_forward_secrecy>group2</perfect_forward_secrecy>
<mode>main</mode>
<pre_shared_key>Iw2IAN9XUsQeYUrkMGP3kP59ugFDkfHg</pre_shared_key>
</ike>
<ipsec>
<protocol>esp</protocol>
<authentication_protocol>hmac-sha1-96</authentication_protocol>
<encryption_protocol>aes-128-cbc</encryption_protocol>
<lifetime>3600</lifetime>
<perfect_forward_secrecy>group2</perfect_forward_secrecy>
<mode>tunnel</mode>
<clear_df_bit>true</clear_df_bit>
<fragmentation_before_encryption>true</fragmentation_before_encryption>
<tcp_mss_adjustment>1387</tcp_mss_adjustment>
<dead_peer_detection>
<interval>10</interval>
<retries>3</retries>
</dead_peer_detection>
</ipsec>
</ipsec_tunnel>
</vpn_connection>
"""
+ escape(CUSTOMER_GATEWAY_CONFIGURATION_TEMPLATE)
+ """
</customerGatewayConfiguration>
<type>ipsec.1</type>
<customerGatewayId>{{ vpn_connection.customer_gateway_id }}</customerGatewayId>
<vpnGatewayId>{{ vpn_connection.vpn_gateway_id }}</vpnGatewayId>
<vpnGatewayId>{{ vpn_connection.vpn_gateway_id or '' }}</vpnGatewayId>
{% if vpn_connection.transit_gateway_id %}
<transitGatewayId>{{ vpn_connection.transit_gateway_id }}</transitGatewayId>
{% endif %}
<tagSet>
{% for tag in vpn_connection.get_tags() %}
<item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
@ -322,3 +245,4 @@ DESCRIBE_VPN_CONNECTION_RESPONSE = """
{% endfor %}
</vpnConnectionSet>
</DescribeVpnConnectionsResponse>"""
)

View File

@ -15,6 +15,9 @@ from moto.iam import iam_backends
EC2_RESOURCE_TO_PREFIX = {
"customer-gateway": "cgw",
"transit-gateway": "tgw",
"transit-gateway-route-table": "tgw-rtb",
"transit-gateway-attachment": "tgw-attach",
"dhcp-options": "dopt",
"flow-logs": "fl",
"image": "ami",
@ -168,6 +171,22 @@ def random_nat_gateway_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["nat-gateway"], size=17)
def random_transit_gateway_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["transit-gateway"], size=17)
def random_transit_gateway_route_table_id():
return random_id(
prefix=EC2_RESOURCE_TO_PREFIX["transit-gateway-route-table"], size=17
)
def random_transit_gateway_attachment_id():
return random_id(
prefix=EC2_RESOURCE_TO_PREFIX["transit-gateway-attachment"], size=17
)
def random_launch_template_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["launch-template"], size=17)
@ -207,7 +226,7 @@ def generate_route_id(route_table_id, cidr_block, ipv6_cidr_block=None):
def generate_vpc_end_point_id(vpc_id):
return "%s-%s" % ("vpce", vpc_id[4:])
return "%s-%s%s" % ("vpce", vpc_id[4:], random_resource_id(4))
def create_dns_entries(service_name, vpc_endpoint_id):
@ -342,6 +361,13 @@ def get_obj_tag_values(obj):
return tags
def add_tag_specification(tags):
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
return tags
def tag_filter_matches(obj, filter_name, filter_values):
regex_filters = [re.compile(simple_aws_filter_to_re(f)) for f in filter_values]
if filter_name == "tag-key":
@ -515,6 +541,11 @@ def random_key_pair():
def get_prefix(resource_id):
resource_id_prefix, separator, after = resource_id.partition("-")
if resource_id_prefix == EC2_RESOURCE_TO_PREFIX["transit-gateway"]:
if after.startswith("rtb"):
resource_id_prefix = EC2_RESOURCE_TO_PREFIX["transit-gateway-route-table"]
if after.startswith("attach"):
resource_id_prefix = EC2_RESOURCE_TO_PREFIX["transit-gateway-attachment"]
if resource_id_prefix == EC2_RESOURCE_TO_PREFIX["network-interface"]:
if after.startswith("attach"):
resource_id_prefix = EC2_RESOURCE_TO_PREFIX["network-interface-attachment"]

View File

@ -307,6 +307,7 @@ class Service(BaseObject, CloudFormationModel):
tags=None,
deployment_controller=None,
launch_type=None,
service_registries=None,
):
self.cluster_arn = cluster.arn
self.arn = "arn:aws:ecs:{0}:{1}:service/{2}".format(
@ -324,6 +325,7 @@ class Service(BaseObject, CloudFormationModel):
self.deployment_controller = deployment_controller or {"type": "ECS"}
self.events = []
self.launch_type = launch_type
self.service_registries = service_registries or []
if self.deployment_controller["type"] == "ECS":
self.deployments = [
{
@ -1076,6 +1078,7 @@ class EC2ContainerServiceBackend(BaseBackend):
tags=None,
deployment_controller=None,
launch_type=None,
service_registries=None,
):
cluster = self._get_cluster(cluster_str)
@ -1101,6 +1104,7 @@ class EC2ContainerServiceBackend(BaseBackend):
tags,
deployment_controller,
launch_type,
service_registries=service_registries,
)
cluster_service_pair = "{0}:{1}".format(cluster.name, service_name)
self.services[cluster_service_pair] = service

View File

@ -176,6 +176,7 @@ class EC2ContainerServiceResponse(BaseResponse):
desired_count = self._get_int_param("desiredCount")
load_balancers = self._get_param("loadBalancers")
scheduling_strategy = self._get_param("schedulingStrategy")
service_registries = self._get_param("serviceRegistries")
tags = self._get_param("tags")
deployment_controller = self._get_param("deploymentController")
launch_type = self._get_param("launchType")
@ -189,6 +190,7 @@ class EC2ContainerServiceResponse(BaseResponse):
tags,
deployment_controller,
launch_type,
service_registries=service_registries,
)
return json.dumps({"service": service.response_object})

View File

@ -103,11 +103,20 @@ class PriorityInUseError(ELBClientError):
class InvalidConditionFieldError(ELBClientError):
VALID_FIELDS = [
"path-pattern",
"host-header",
"http-header",
"http-request-method",
"query-string",
"source-ip",
]
def __init__(self, invalid_name):
super(InvalidConditionFieldError, self).__init__(
"ValidationError",
"Condition field '%s' must be one of '[path-pattern, host-header]"
% (invalid_name),
"Condition field '%s' must be one of '[%s]'"
% (invalid_name, ",".join(self.VALID_FIELDS)),
)
@ -132,6 +141,14 @@ class ActionTargetGroupNotFoundError(ELBClientError):
)
class ListenerOrBalancerMissingError(ELBClientError):
def __init__(self, arn):
super(ListenerOrBalancerMissingError, self).__init__(
"ValidationError",
"You must specify either listener ARNs or a load balancer ARN",
)
class InvalidDescribeRulesRequest(ELBClientError):
def __init__(self, msg):
super(InvalidDescribeRulesRequest, self).__init__("ValidationError", msg)

View File

@ -70,6 +70,7 @@ class FakeTargetGroup(CloudFormationModel):
healthcheck_path=None,
healthcheck_interval_seconds=None,
healthcheck_timeout_seconds=None,
healthcheck_enabled=None,
healthy_threshold_count=None,
unhealthy_threshold_count=None,
matcher=None,
@ -82,19 +83,21 @@ class FakeTargetGroup(CloudFormationModel):
self.vpc_id = vpc_id
self.protocol = protocol
self.port = port
self.healthcheck_protocol = healthcheck_protocol or "HTTP"
self.healthcheck_port = healthcheck_port or str(self.port)
self.healthcheck_path = healthcheck_path or "/"
self.healthcheck_protocol = healthcheck_protocol or self.protocol
self.healthcheck_port = healthcheck_port
self.healthcheck_path = healthcheck_path
self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30
self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or 5
self.healthcheck_enabled = healthcheck_enabled
self.healthy_threshold_count = healthy_threshold_count or 5
self.unhealthy_threshold_count = unhealthy_threshold_count or 2
self.load_balancer_arns = []
self.tags = {}
if matcher is None:
self.matcher = {"HttpCode": "200"}
else:
self.matcher = matcher
self.matcher = matcher
if self.protocol != "TCP":
self.matcher = self.matcher or {"HttpCode": "200"}
self.healthcheck_path = self.healthcheck_path or "/"
self.healthcheck_port = self.healthcheck_port or str(self.port)
self.target_type = target_type
self.attributes = {
@ -209,7 +212,7 @@ class FakeListener(CloudFormationModel):
):
self.load_balancer_arn = load_balancer_arn
self.arn = arn
self.protocol = protocol.upper()
self.protocol = (protocol or "").upper()
self.port = port
self.ssl_policy = ssl_policy
self.certificate = certificate
@ -224,6 +227,7 @@ class FakeListener(CloudFormationModel):
actions=default_actions,
is_default=True,
)
self.tags = {}
@property
def physical_resource_id(self):
@ -437,6 +441,7 @@ class FakeLoadBalancer(CloudFormationModel):
dns_name,
state,
scheme="internet-facing",
loadbalancer_type=None,
):
self.name = name
self.created_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
@ -449,14 +454,15 @@ class FakeLoadBalancer(CloudFormationModel):
self.arn = arn
self.dns_name = dns_name
self.state = state
self.loadbalancer_type = loadbalancer_type or "application"
self.stack = "ipv4"
self.attrs = {
"access_logs.s3.enabled": "false",
# "access_logs.s3.enabled": "false", # commented out for TF compatibility
"access_logs.s3.bucket": None,
"access_logs.s3.prefix": None,
"deletion_protection.enabled": "false",
"idle_timeout.timeout_seconds": "60",
# "idle_timeout.timeout_seconds": "60", # commented out for TF compatibility
}
@property
@ -573,7 +579,12 @@ class ELBv2Backend(BaseBackend):
self.__init__(region_name)
def create_load_balancer(
self, name, security_groups, subnet_ids, scheme="internet-facing"
self,
name,
security_groups,
subnet_ids,
scheme="internet-facing",
loadbalancer_type=None,
):
vpc_id = None
subnets = []
@ -605,6 +616,7 @@ class ELBv2Backend(BaseBackend):
vpc_id=vpc_id,
dns_name=dns_name,
state=state,
loadbalancer_type=loadbalancer_type,
)
self.load_balancers[arn] = new_load_balancer
return new_load_balancer
@ -634,7 +646,7 @@ class ELBv2Backend(BaseBackend):
)
elif action_type == "forward" and "ForwardConfig" not in action:
default_actions.append(
{"type": action_type, "target_group_arn": action["TargetGroupArn"],}
{"type": action_type, "target_group_arn": action["TargetGroupArn"]}
)
elif action_type in [
"redirect",
@ -666,8 +678,12 @@ class ELBv2Backend(BaseBackend):
listener = listeners[0]
# validate conditions
# see: https://docs.aws.amazon.com/cli/latest/reference/elbv2/create-rule.html
self._validate_conditions(conditions)
# TODO: check QueryStringConfig condition
# TODO: check HttpRequestMethodConfig condition
# TODO: check SourceIpConfig condition
# TODO: check pattern of value for 'host-header'
# TODO: check pattern of value for 'path-pattern'
@ -778,6 +794,10 @@ class ELBv2Backend(BaseBackend):
)
if values is None or len(values) == 0:
raise InvalidConditionValueError("A condition value must be specified")
if condition.get("Values") and condition.get("PathPatternConfig"):
raise InvalidConditionValueError(
"You cannot provide both Values and 'PathPatternConfig' for a condition of type 'path-pattern'"
)
for value in values:
if len(value) > 128:
raise InvalidConditionValueError(
@ -955,7 +975,7 @@ Member must satisfy regular expression pattern: {}".format(
action_type = action["Type"]
if action_type == "forward":
default_actions.append(
{"type": action_type, "target_group_arn": action["TargetGroupArn"],}
{"type": action_type, "target_group_arn": action["TargetGroupArn"]}
)
elif action_type in [
"redirect",
@ -1340,6 +1360,7 @@ Member must satisfy regular expression pattern: {}".format(
healthy_threshold_count=None,
unhealthy_threshold_count=None,
http_codes=None,
health_check_enabled=None,
):
target_group = self.target_groups.get(arn)
if target_group is None:
@ -1366,6 +1387,8 @@ Member must satisfy regular expression pattern: {}".format(
target_group.healthcheck_protocol = health_check_proto
if health_check_timeout is not None:
target_group.healthcheck_timeout_seconds = health_check_timeout
if health_check_enabled is not None:
target_group.healthcheck_enabled = health_check_enabled
if healthy_threshold_count is not None:
target_group.healthy_threshold_count = healthy_threshold_count
if unhealthy_threshold_count is not None:

View File

@ -6,7 +6,8 @@ from .models import elbv2_backends
from .exceptions import DuplicateTagKeysError
from .exceptions import LoadBalancerNotFoundError
from .exceptions import TargetGroupNotFoundError
from .exceptions import ListenerNotFoundError
from .exceptions import ListenerOrBalancerMissingError
SSL_POLICIES = [
{
@ -138,12 +139,14 @@ class ELBV2Response(BaseResponse):
subnet_ids = self._get_multi_param("Subnets.member")
security_groups = self._get_multi_param("SecurityGroups.member")
scheme = self._get_param("Scheme")
loadbalancer_type = self._get_param("Type")
load_balancer = self.elbv2_backend.create_load_balancer(
name=load_balancer_name,
security_groups=security_groups,
subnet_ids=subnet_ids,
scheme=scheme,
loadbalancer_type=loadbalancer_type,
)
self._add_tags(load_balancer)
template = self.response_template(CREATE_LOAD_BALANCER_TEMPLATE)
@ -173,9 +176,11 @@ class ELBV2Response(BaseResponse):
healthcheck_path = self._get_param("HealthCheckPath")
healthcheck_interval_seconds = self._get_param("HealthCheckIntervalSeconds")
healthcheck_timeout_seconds = self._get_param("HealthCheckTimeoutSeconds")
healthcheck_enabled = self._get_param("HealthCheckEnabled")
healthy_threshold_count = self._get_param("HealthyThresholdCount")
unhealthy_threshold_count = self._get_param("UnhealthyThresholdCount")
matcher = self._get_param("Matcher")
target_type = self._get_param("TargetType")
target_group = self.elbv2_backend.create_target_group(
name,
@ -187,9 +192,11 @@ class ELBV2Response(BaseResponse):
healthcheck_path=healthcheck_path,
healthcheck_interval_seconds=healthcheck_interval_seconds,
healthcheck_timeout_seconds=healthcheck_timeout_seconds,
healthcheck_enabled=healthcheck_enabled,
healthy_threshold_count=healthy_threshold_count,
unhealthy_threshold_count=unhealthy_threshold_count,
matcher=matcher,
target_type=target_type,
)
template = self.response_template(CREATE_TARGET_GROUP_TEMPLATE)
@ -299,7 +306,7 @@ class ELBV2Response(BaseResponse):
load_balancer_arn = self._get_param("LoadBalancerArn")
listener_arns = self._get_multi_param("ListenerArns.member")
if not load_balancer_arn and not listener_arns:
raise LoadBalancerNotFoundError()
raise ListenerOrBalancerMissingError()
listeners = self.elbv2_backend.describe_listeners(
load_balancer_arn, listener_arns
@ -453,6 +460,14 @@ class ELBV2Response(BaseResponse):
resource = self.elbv2_backend.load_balancers.get(arn)
if not resource:
raise LoadBalancerNotFoundError()
elif ":listener" in arn:
lb_arn, _, _ = arn.replace(":listener", ":loadbalancer").rpartition("/")
balancer = self.elbv2_backend.load_balancers.get(lb_arn)
if not balancer:
raise LoadBalancerNotFoundError()
resource = balancer.listeners.get(arn)
if not resource:
raise ListenerNotFoundError()
else:
raise LoadBalancerNotFoundError()
resources.append(resource)
@ -555,6 +570,7 @@ class ELBV2Response(BaseResponse):
health_check_path = self._get_param("HealthCheckPath")
health_check_interval = self._get_param("HealthCheckIntervalSeconds")
health_check_timeout = self._get_param("HealthCheckTimeoutSeconds")
health_check_enabled = self._get_param("HealthCheckEnabled")
healthy_threshold_count = self._get_param("HealthyThresholdCount")
unhealthy_threshold_count = self._get_param("UnhealthyThresholdCount")
http_codes = self._get_param("Matcher.HttpCode")
@ -569,6 +585,7 @@ class ELBV2Response(BaseResponse):
healthy_threshold_count,
unhealthy_threshold_count,
http_codes,
health_check_enabled=health_check_enabled,
)
template = self.response_template(MODIFY_TARGET_GROUP_TEMPLATE)
@ -687,7 +704,7 @@ CREATE_LOAD_BALANCER_TEMPLATE = """<CreateLoadBalancerResponse xmlns="http://ela
<State>
<Code>{{ load_balancer.state }}</Code>
</State>
<Type>application</Type>
<Type>{{ load_balancer.loadbalancer_type }}</Type>
</member>
</LoadBalancers>
</CreateLoadBalancerResult>
@ -817,10 +834,11 @@ CREATE_TARGET_GROUP_TEMPLATE = """<CreateTargetGroupResponse xmlns="http://elast
<Port>{{ target_group.port }}</Port>
<VpcId>{{ target_group.vpc_id }}</VpcId>
<HealthCheckProtocol>{{ target_group.health_check_protocol }}</HealthCheckProtocol>
<HealthCheckPort>{{ target_group.healthcheck_port }}</HealthCheckPort>
<HealthCheckPath>{{ target_group.healthcheck_path }}</HealthCheckPath>
<HealthCheckPort>{{ target_group.healthcheck_port or '' }}</HealthCheckPort>
<HealthCheckPath>{{ target_group.healthcheck_path or '' }}</HealthCheckPath>
<HealthCheckIntervalSeconds>{{ target_group.healthcheck_interval_seconds }}</HealthCheckIntervalSeconds>
<HealthCheckTimeoutSeconds>{{ target_group.healthcheck_timeout_seconds }}</HealthCheckTimeoutSeconds>
<HealthCheckEnabled>{{ target_group.healthcheck_enabled and 'true' or 'false' }}</HealthCheckEnabled>
<HealthyThresholdCount>{{ target_group.healthy_threshold_count }}</HealthyThresholdCount>
<UnhealthyThresholdCount>{{ target_group.unhealthy_threshold_count }}</UnhealthyThresholdCount>
{% if target_group.matcher %}
@ -928,7 +946,7 @@ DESCRIBE_LOAD_BALANCERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http
<State>
<Code>{{ load_balancer.state }}</Code>
</State>
<Type>application</Type>
<Type>{{ load_balancer.loadbalancer_type }}</Type>
<IpAddressType>ipv4</IpAddressType>
</member>
{% endfor %}
@ -1052,10 +1070,11 @@ DESCRIBE_TARGET_GROUPS_TEMPLATE = """<DescribeTargetGroupsResponse xmlns="http:/
<Port>{{ target_group.port }}</Port>
<VpcId>{{ target_group.vpc_id }}</VpcId>
<HealthCheckProtocol>{{ target_group.healthcheck_protocol }}</HealthCheckProtocol>
<HealthCheckPort>{{ target_group.healthcheck_port }}</HealthCheckPort>
<HealthCheckPath>{{ target_group.healthcheck_path }}</HealthCheckPath>
<HealthCheckPort>{{ target_group.healthcheck_port or '' }}</HealthCheckPort>
<HealthCheckPath>{{ target_group.healthcheck_path or '' }}</HealthCheckPath>
<HealthCheckIntervalSeconds>{{ target_group.healthcheck_interval_seconds }}</HealthCheckIntervalSeconds>
<HealthCheckTimeoutSeconds>{{ target_group.healthcheck_timeout_seconds }}</HealthCheckTimeoutSeconds>
<HealthCheckEnabled>{{ target_group.healthcheck_enabled and 'true' or 'false' }}</HealthCheckEnabled>
<HealthyThresholdCount>{{ target_group.healthy_threshold_count }}</HealthyThresholdCount>
<UnhealthyThresholdCount>{{ target_group.unhealthy_threshold_count }}</UnhealthyThresholdCount>
{% if target_group.matcher %}

View File

@ -246,9 +246,10 @@ class Rule(CloudFormationModel):
class EventBus(CloudFormationModel):
def __init__(self, region_name, name):
def __init__(self, region_name, name, tags=None):
self.region = region_name
self.name = name
self.tags = tags or []
self._permissions = {}
@ -545,6 +546,7 @@ class Connection(BaseModel):
def __init__(
self, name, region_name, description, authorization_type, auth_parameters,
):
self.uuid = uuid4()
self.name = name
self.region = region_name
self.description = description
@ -555,10 +557,62 @@ class Connection(BaseModel):
@property
def arn(self):
return "arn:aws:events:{0}:{1}:connection/{2}".format(
self.region, ACCOUNT_ID, self.name
return "arn:aws:events:{0}:{1}:connection/{2}/{3}".format(
self.region, ACCOUNT_ID, self.name, self.uuid
)
def describe_short(self):
"""
Create the short description for the Connection object.
Taken our from the Response Syntax of this API doc:
- https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DeleteConnection.html
Something to consider:
- The original response also has
- LastAuthorizedTime (number)
- LastModifiedTime (number)
- At the time of implemeting this, there was no place where to set/get
those attributes. That is why they are not in the response.
Returns:
dict
"""
return {
"ConnectionArn": self.arn,
"ConnectionState": self.state,
"CreationTime": self.creation_time,
}
def describe(self):
"""
Create a complete description for the Connection object.
Taken our from the Response Syntax of this API doc:
- https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DescribeConnection.html
Something to consider:
- The original response also has:
- LastAuthorizedTime (number)
- LastModifiedTime (number)
- SecretArn (string)
- StateReason (string)
- At the time of implemeting this, there was no place where to set/get
those attributes. That is why they are not in the response.
Returns:
dict
"""
return {
"AuthorizationType": self.authorization_type,
"AuthParameters": self.auth_parameters,
"ConnectionArn": self.arn,
"ConnectionState": self.state,
"CreationTime": self.creation_time,
"Description": self.description,
"Name": self.name,
}
class Destination(BaseModel):
def __init__(
@ -568,32 +622,71 @@ class Destination(BaseModel):
description,
connection_arn,
invocation_endpoint,
invocation_rate_limit_per_second,
http_method,
):
self.uuid = uuid4()
self.name = name
self.region = region_name
self.description = description
self.connection_arn = connection_arn
self.invocation_endpoint = invocation_endpoint
self.invocation_rate_limit_per_second = invocation_rate_limit_per_second
self.creation_time = unix_time(datetime.utcnow())
self.http_method = http_method
self.state = "ACTIVE"
@property
def arn(self):
return "arn:aws:events:{0}:{1}:destination/{2}".format(
self.region, ACCOUNT_ID, self.name
return "arn:aws:events:{0}:{1}:api-destination/{2}/{3}".format(
self.region, ACCOUNT_ID, self.name, self.uuid
)
def describe(self):
"""
Describes the Destination object as a dict
Docs:
Response Syntax in
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DescribeApiDestination.html
Something to consider:
- The response also has [InvocationRateLimitPerSecond] which was not
available when implementing this method
Returns:
dict
"""
return {
"ApiDestinationArn": self.arn,
"ApiDestinationState": self.state,
"ConnectionArn": self.connection_arn,
"CreationTime": self.creation_time,
"Description": self.description,
"HttpMethod": self.http_method,
"InvocationEndpoint": self.invocation_endpoint,
"InvocationRateLimitPerSecond": self.invocation_rate_limit_per_second,
"LastModifiedTime": self.creation_time,
"Name": self.name,
}
def describe_short(self):
return {
"ApiDestinationArn": self.arn,
"ApiDestinationState": self.state,
"CreationTime": self.creation_time,
"LastModifiedTime": self.creation_time,
}
class EventPattern:
def __init__(self, filter):
self._filter = self._load_event_pattern(filter)
self._filter_raw = filter
if not self._validate_event_pattern(self._filter):
raise InvalidEventPatternException
def __str__(self):
return json.dumps(self._filter)
return self._filter_raw or str()
def _load_event_pattern(self, pattern):
try:
@ -1032,7 +1125,7 @@ class EventsBackend(BaseBackend):
return event_bus
def create_event_bus(self, name, event_source_name=None):
def create_event_bus(self, name, event_source_name=None, tags=None):
if name in self.event_buses:
raise JsonRESTError(
"ResourceAlreadyExistsException",
@ -1050,7 +1143,10 @@ class EventsBackend(BaseBackend):
"Event source {} does not exist.".format(event_source_name),
)
self.event_buses[name] = EventBus(self.region_name, name)
event_bus = EventBus(self.region_name, name, tags=tags)
self.event_buses[name] = event_bus
if tags:
self.tagger.tag_resource(event_bus.arn, tags)
return self.event_buses[name]
@ -1069,30 +1165,38 @@ class EventsBackend(BaseBackend):
raise JsonRESTError(
"ValidationException", "Cannot delete event bus default."
)
self.event_buses.pop(name, None)
event_bus = self.event_buses.pop(name, None)
if event_bus:
self.tagger.delete_all_tags_for_resource(event_bus.arn)
def list_tags_for_resource(self, arn):
name = arn.split("/")[-1]
if name in self.rules:
return self.tagger.list_tags_for_resource(self.rules[name].arn)
registries = [self.rules, self.event_buses]
for registry in registries:
if name in registry:
return self.tagger.list_tags_for_resource(registry[name].arn)
raise ResourceNotFoundException(
"Rule {0} does not exist on EventBus default.".format(name)
)
def tag_resource(self, arn, tags):
name = arn.split("/")[-1]
if name in self.rules:
self.tagger.tag_resource(self.rules[name].arn, tags)
return {}
registries = [self.rules, self.event_buses]
for registry in registries:
if name in registry:
self.tagger.tag_resource(registry[name].arn, tags)
return {}
raise ResourceNotFoundException(
"Rule {0} does not exist on EventBus default.".format(name)
)
def untag_resource(self, arn, tag_names):
name = arn.split("/")[-1]
if name in self.rules:
self.tagger.untag_resource_using_names(self.rules[name].arn, tag_names)
return {}
registries = [self.rules, self.event_buses]
for registry in registries:
if name in registry:
self.tagger.untag_resource_using_names(registry[name].arn, tag_names)
return {}
raise ResourceNotFoundException(
"Rule {0} does not exist on EventBus default.".format(name)
)
@ -1337,27 +1441,145 @@ class EventsBackend(BaseBackend):
def list_connections(self):
return self.connections.values()
def create_api_destination(
self, name, description, connection_arn, invocation_endpoint, http_method
):
def describe_connection(self, name):
"""
Retrieves details about a connection.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DescribeConnection.html
Args:
name: The name of the connection to retrieve.
Raises:
ResourceNotFoundException: When the connection is not present.
Returns:
dict
"""
connection = self.connections.get(name)
if not connection:
raise ResourceNotFoundException(
"Connection '{}' does not exist.".format(name)
)
return connection.describe()
def delete_connection(self, name):
"""
Deletes a connection.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DeleteConnection.html
Args:
name: The name of the connection to delete.
Raises:
ResourceNotFoundException: When the connection is not present.
Returns:
dict
"""
connection = self.connections.pop(name, None)
if not connection:
raise ResourceNotFoundException(
"Connection '{}' does not exist.".format(name)
)
return connection.describe_short()
def create_api_destination(
self,
name,
description,
connection_arn,
invocation_endpoint,
invocation_rate_limit_per_second,
http_method,
):
"""
Creates an API destination, which is an HTTP invocation endpoint configured as a target for events.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_CreateApiDestination.html
Returns:
dict
"""
destination = Destination(
name=name,
region_name=self.region_name,
description=description,
connection_arn=connection_arn,
invocation_endpoint=invocation_endpoint,
invocation_rate_limit_per_second=invocation_rate_limit_per_second,
http_method=http_method,
)
self.destinations[name] = destination
return destination
return destination.describe_short()
def list_api_destinations(self):
return self.destinations.values()
def describe_api_destination(self, name):
return self.destinations.get(name)
"""
Retrieves details about an API destination.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DescribeApiDestination.html
Args:
name: The name of the API destination to retrieve.
Returns:
dict
"""
destination = self.destinations.get(name)
if not destination:
raise ResourceNotFoundException(
"An api-destination '{}' does not exist.".format(name)
)
return destination.describe()
def update_api_destination(self, *, name, **kwargs):
"""
Creates an API destination, which is an HTTP invocation endpoint configured as a target for events.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_UpdateApiDestination.html
Returns:
dict
"""
destination = self.destinations.get(name)
if not destination:
raise ResourceNotFoundException(
"An api-destination '{}' does not exist.".format(name)
)
for attr, value in kwargs.items():
if value is not None and hasattr(destination, attr):
setattr(destination, attr, value)
return destination.describe_short()
def delete_api_destination(self, name):
"""
Deletes the specified API destination.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DeleteApiDestination.html
Args:
name: The name of the destination to delete.
Raises:
ResourceNotFoundException: When the destination is not present.
Returns:
dict
"""
destination = self.destinations.pop(name, None)
if not destination:
raise ResourceNotFoundException(
"An api-destination '{}' does not exist.".format(name)
)
return {}
events_backends = {}

View File

@ -42,6 +42,20 @@ class EventsHandler(BaseResponse):
def _get_param(self, param, if_none=None):
return self.request_params.get(param, if_none)
def _create_response(self, result):
"""
Creates a proper response for the API.
It basically transforms a dict-like result from the backend
into a tuple (str, dict) properly formatted.
Args:
result (dict): result from backend
Returns:
(str, dict): dumped result and headers
"""
return json.dumps(result), self.headers
def error(self, type_, message="", status=400):
headers = self.response_headers
headers["status"] = status
@ -266,9 +280,9 @@ class EventsHandler(BaseResponse):
def create_event_bus(self):
name = self._get_param("Name")
event_source_name = self._get_param("EventSourceName")
tags = self._get_param("Tags")
event_bus = self.events_backend.create_event_bus(name, event_source_name)
event_bus = self.events_backend.create_event_bus(name, event_source_name, tags)
return json.dumps({"EventBusArn": event_bus.arn}), self.response_headers
def list_event_buses(self):
@ -448,27 +462,35 @@ class EventsHandler(BaseResponse):
return json.dumps({"Connections": result}), self.response_headers
def describe_connection(self):
name = self._get_param("Name")
result = self.events_backend.describe_connection(name)
return json.dumps(result), self.response_headers
def delete_connection(self):
name = self._get_param("Name")
result = self.events_backend.delete_connection(name)
return json.dumps(result), self.response_headers
def create_api_destination(self):
name = self._get_param("Name")
description = self._get_param("Description")
connection_arn = self._get_param("ConnectionArn")
invocation_endpoint = self._get_param("InvocationEndpoint")
invocation_rate_limit_per_second = self._get_param(
"InvocationRateLimitPerSecond"
)
http_method = self._get_param("HttpMethod")
destination = self.events_backend.create_api_destination(
name, description, connection_arn, invocation_endpoint, http_method
)
return (
json.dumps(
{
"ApiDestinationArn": destination.arn,
"ApiDestinationState": "ACTIVE",
"CreationTime": destination.creation_time,
"LastModifiedTime": destination.creation_time,
}
),
self.response_headers,
result = self.events_backend.create_api_destination(
name,
description,
connection_arn,
invocation_endpoint,
invocation_rate_limit_per_second,
http_method,
)
return self._create_response(result)
def list_api_destinations(self):
destinations = self.events_backend.list_api_destinations()
@ -491,20 +513,25 @@ class EventsHandler(BaseResponse):
def describe_api_destination(self):
name = self._get_param("Name")
destination = self.events_backend.describe_api_destination(name)
result = self.events_backend.describe_api_destination(name)
return self._create_response(result)
return (
json.dumps(
{
"ApiDestinationArn": destination.arn,
"Name": destination.name,
"ApiDestinationState": destination.state,
"ConnectionArn": destination.connection_arn,
"InvocationEndpoint": destination.invocation_endpoint,
"HttpMethod": destination.http_method,
"CreationTime": destination.creation_time,
"LastModifiedTime": destination.creation_time,
}
def update_api_destination(self):
updates = dict(
connection_arn=self._get_param("ConnectionArn"),
description=self._get_param("Description"),
http_method=self._get_param("HttpMethod"),
invocation_endpoint=self._get_param("InvocationEndpoint"),
invocation_rate_limit_per_second=self._get_param(
"InvocationRateLimitPerSecond"
),
self.response_headers,
name=self._get_param("Name"),
)
result = self.events_backend.update_api_destination(**updates)
return self._create_response(result)
def delete_api_destination(self):
name = self._get_param("Name")
result = self.events_backend.delete_api_destination(name)
return self._create_response(result)

View File

@ -1,12 +1,16 @@
from __future__ import unicode_literals
from moto.core.exceptions import RESTError
XMLNS_IAM = "https://iam.amazonaws.com/doc/2010-05-08/"
class IAMNotFoundException(RESTError):
code = 404
def __init__(self, message):
super(IAMNotFoundException, self).__init__("NoSuchEntity", message)
super(IAMNotFoundException, self).__init__(
"NoSuchEntity", message, xmlns=XMLNS_IAM, template="wrapped_single_error"
)
class IAMConflictException(RESTError):
@ -134,4 +138,6 @@ class NoSuchEntity(RESTError):
code = 404
def __init__(self, message):
super(NoSuchEntity, self).__init__("NoSuchEntity", message)
super(NoSuchEntity, self).__init__(
"NoSuchEntity", message, xmlns=XMLNS_IAM, template="wrapped_single_error"
)

32
moto/iam/models.py Executable file → Normal file
View File

@ -101,6 +101,7 @@ class Policy(CloudFormationModel):
path=None,
create_date=None,
update_date=None,
tags=None,
):
self.name = name
@ -108,6 +109,7 @@ class Policy(CloudFormationModel):
self.description = description or ""
self.id = random_policy_id()
self.path = path or "/"
self.tags = {tag["Key"]: tag["Value"] for tag in tags or []}
if default_version_id:
self.default_version_id = default_version_id
@ -337,9 +339,10 @@ class AWSManagedPolicy(ManagedPolicy):
# AWS defines some of its own managed policies and we periodically
# import them via `make aws_managed_policies`
# FIXME: Takes about 40ms at import time
aws_managed_policies_data_parsed = json.loads(aws_managed_policies_data)
aws_managed_policies = [
AWSManagedPolicy.from_data(name, d)
for name, d in json.loads(aws_managed_policies_data).items()
for name, d in aws_managed_policies_data_parsed.items()
]
@ -646,14 +649,21 @@ class Role(CloudFormationModel):
def get_tags(self):
return [self.tags[tag] for tag in self.tags]
@property
def description_escaped(self):
import html
return html.escape(self.description or "")
class InstanceProfile(CloudFormationModel):
def __init__(self, instance_profile_id, name, path, roles):
def __init__(self, instance_profile_id, name, path, roles, tags=None):
self.id = instance_profile_id
self.name = name
self.path = path or "/"
self.roles = roles if roles else []
self.create_date = datetime.utcnow()
self.tags = {tag["Key"]: tag["Value"] for tag in tags or []}
@property
def created_iso_8601(self):
@ -1410,7 +1420,7 @@ class IAMBackend(BaseBackend):
self.account_aliases = []
self.saml_providers = {}
self.open_id_providers = {}
self.policy_arn_regex = re.compile(r"^arn:aws:iam::[0-9]*:policy/.*$")
self.policy_arn_regex = re.compile(r"^arn:aws:iam::(aws|[0-9]*):policy/.*$")
self.virtual_mfa_devices = {}
self.account_password_policy = None
self.account_summary = AccountSummary(self)
@ -1496,12 +1506,16 @@ class IAMBackend(BaseBackend):
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
policy.detach_from(self.get_user(user_name))
def create_policy(self, description, path, policy_document, policy_name):
def create_policy(self, description, path, policy_document, policy_name, tags=None):
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_document)
iam_policy_document_validator.validate()
policy = ManagedPolicy(
policy_name, description=description, document=policy_document, path=path
policy_name,
description=description,
document=policy_document,
path=path,
tags=tags,
)
if policy.arn in self.managed_policies:
raise EntityAlreadyExists(
@ -1551,9 +1565,9 @@ class IAMBackend(BaseBackend):
def set_default_policy_version(self, policy_arn, version_id):
import re
if re.match("v[1-9][0-9]*(\.[A-Za-z0-9-]*)?", version_id) is None:
if re.match(r"v[1-9][0-9]*(\.[A-Za-z0-9-]*)?", version_id) is None:
raise ValidationError(
"Value '{0}' at 'versionId' failed to satisfy constraint: Member must satisfy regular expression pattern: v[1-9][0-9]*(\.[A-Za-z0-9-]*)?".format(
"Value '{0}' at 'versionId' failed to satisfy constraint: Member must satisfy regular expression pattern: v[1-9][0-9]*(\\.[A-Za-z0-9-]*)?".format(
version_id
)
)
@ -1823,7 +1837,7 @@ class IAMBackend(BaseBackend):
return
raise IAMNotFoundException("Policy not found")
def create_instance_profile(self, name, path, role_ids):
def create_instance_profile(self, name, path, role_ids, tags=None):
if self.instance_profiles.get(name):
raise IAMConflictException(
code="EntityAlreadyExists",
@ -1833,7 +1847,7 @@ class IAMBackend(BaseBackend):
instance_profile_id = random_resource_id()
roles = [iam_backend.get_role_by_id(role_id) for role_id in role_ids]
instance_profile = InstanceProfile(instance_profile_id, name, path, roles)
instance_profile = InstanceProfile(instance_profile_id, name, path, roles, tags)
self.instance_profiles[name] = instance_profile
return instance_profile

View File

@ -347,7 +347,7 @@ class IAMPolicyDocumentValidator:
return
resource_partitions = resource_partitions[2].partition(":")
if resource_partitions[0] != "aws":
if resource_partitions[0] not in ["aws", "*"]:
remaining_resource_parts = resource_partitions[2].split(":")
arn1 = (

View File

@ -53,8 +53,9 @@ class IamResponse(BaseResponse):
path = self._get_param("Path")
policy_document = self._get_param("PolicyDocument")
policy_name = self._get_param("PolicyName")
tags = self._get_multi_param("Tags.member")
policy = iam_backend.create_policy(
description, path, policy_document, policy_name
description, path, policy_document, policy_name, tags
)
template = self.response_template(CREATE_POLICY_TEMPLATE)
return template.render(policy=policy)
@ -320,8 +321,11 @@ class IamResponse(BaseResponse):
def create_instance_profile(self):
profile_name = self._get_param("InstanceProfileName")
path = self._get_param("Path", "/")
tags = self._get_multi_param("Tags.member")
profile = iam_backend.create_instance_profile(profile_name, path, role_ids=[])
profile = iam_backend.create_instance_profile(
profile_name, path, role_ids=[], tags=tags
)
template = self.response_template(CREATE_INSTANCE_PROFILE_TEMPLATE)
return template.render(profile=profile)
@ -461,6 +465,12 @@ class IamResponse(BaseResponse):
template = self.response_template(GET_GROUP_POLICY_TEMPLATE)
return template.render(name="GetGroupPolicyResponse", **policy_result)
def delete_group_policy(self):
group_name = self._get_param("GroupName")
policy_name = self._get_param("PolicyName")
iam_backend.delete_group_policy(group_name, policy_name)
return ""
def delete_group(self):
group_name = self._get_param("GroupName")
iam_backend.delete_group(group_name)
@ -1093,6 +1103,14 @@ CREATE_POLICY_TEMPLATE = """<CreatePolicyResponse>
<PolicyId>{{ policy.id }}</PolicyId>
<PolicyName>{{ policy.name }}</PolicyName>
<UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
<Tags>
{% for tag_key, tag_value in policy.tags.items() %}
<member>
<Key>{{ tag_key }}</Key>
<Value>{{ tag_value }}</Value>
</member>
{% endfor %}
</Tags>
</Policy>
</CreatePolicyResult>
<ResponseMetadata>
@ -1112,6 +1130,14 @@ GET_POLICY_TEMPLATE = """<GetPolicyResponse>
<AttachmentCount>{{ policy.attachment_count }}</AttachmentCount>
<CreateDate>{{ policy.created_iso_8601 }}</CreateDate>
<UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
<Tags>
{% for tag_key, tag_value in policy.tags.items() %}
<member>
<Key>{{ tag_key }}</Key>
<Value>{{ tag_value }}</Value>
</member>
{% endfor %}
</Tags>
</Policy>
</GetPolicyResult>
<ResponseMetadata>
@ -1223,11 +1249,30 @@ CREATE_INSTANCE_PROFILE_TEMPLATE = """<CreateInstanceProfileResponse xmlns="http
<CreateInstanceProfileResult>
<InstanceProfile>
<InstanceProfileId>{{ profile.id }}</InstanceProfileId>
<Roles/>
<Roles>
{% for role in profile.roles %}
<member>
<Path>{{ role.path }}</Path>
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
</member>
{% endfor %}
</Roles>
<InstanceProfileName>{{ profile.name }}</InstanceProfileName>
<Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
<Tags>
{% for tag_key, tag_value in profile.tags.items() %}
<member>
<Key>{{ tag_key }}</Key>
<Value>{{ tag_value }}</Value>
</member>
{% endfor %}
</Tags>
</InstanceProfile>
</CreateInstanceProfileResult>
<ResponseMetadata>
@ -1261,6 +1306,14 @@ GET_INSTANCE_PROFILE_TEMPLATE = """<GetInstanceProfileResponse xmlns="https://ia
<Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
<Tags>
{% for tag_key, tag_value in profile.tags.items() %}
<member>
<Key>{{ tag_key }}</Key>
<Value>{{ tag_value }}</Value>
</member>
{% endfor %}
</Tags>
</InstanceProfile>
</GetInstanceProfileResult>
<ResponseMetadata>
@ -1276,7 +1329,7 @@ CREATE_ROLE_TEMPLATE = """<CreateRoleResponse xmlns="https://iam.amazonaws.com/d
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
{% if role.description is not none %}
<Description>{{role.description}}</Description>
<Description>{{ role.description_escaped }}</Description>
{% endif %}
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
@ -1330,7 +1383,9 @@ UPDATE_ROLE_DESCRIPTION_TEMPLATE = """<UpdateRoleDescriptionResponse xmlns="http
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<Description>{{role.description}}</Description>
{% if role.description is not none %}
<Description>{{ role.description_escaped }}</Description>
{% endif %}
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
<MaxSessionDuration>{{ role.max_session_duration }}</MaxSessionDuration>
@ -1358,8 +1413,8 @@ GET_ROLE_TEMPLATE = """<GetRoleResponse xmlns="https://iam.amazonaws.com/doc/201
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
{% if role.description %}
<Description>{{role.description}}</Description>
{% if role.description is not none %}
<Description>{{ role.description_escaped }}</Description>
{% endif %}
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
@ -1422,7 +1477,7 @@ LIST_ROLES_TEMPLATE = """<ListRolesResponse xmlns="https://iam.amazonaws.com/doc
</PermissionsBoundary>
{% endif %}
{% if role.description is not none %}
<Description>{{ role.description }}</Description>
<Description>{{ role.description_escaped }}</Description>
{% endif %}
</member>
{% endfor %}
@ -2243,7 +2298,9 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<Description>{{role.description}}</Description>
{% if role.description is not none %}
<Description>{{ role.description_escaped }}</Description>
{% endif %}
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
{% if role.permissions_boundary %}

View File

@ -1,5 +1,6 @@
from boto3 import Session
from moto import core as moto_core
from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time_millis
from .exceptions import (
@ -53,7 +54,7 @@ class LogStream(BaseModel):
self.region = region
self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format(
region=region,
id=self.__class__._log_ids,
id=moto_core.ACCOUNT_ID,
log_group=log_group,
log_stream=name,
)
@ -262,6 +263,11 @@ class LogGroup(BaseModel):
) # AWS defaults to Never Expire for log group retention
self.subscription_filters = []
# The Amazon Resource Name (ARN) of the CMK to use when encrypting log data. It is optional.
# Docs:
# https://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/API_CreateLogGroup.html
self.kms_key_id = kwargs.get("kmsKeyId")
def create_log_stream(self, log_stream_name):
if log_stream_name in self.streams:
raise ResourceAlreadyExistsException()
@ -442,6 +448,8 @@ class LogGroup(BaseModel):
# AWS only returns retentionInDays if a value is set for the log group (ie. not Never Expire)
if self.retention_in_days:
log_group["retentionInDays"] = self.retention_in_days
if self.kms_key_id:
log_group["kmsKeyId"] = self.kms_key_id
return log_group
def set_retention_policy(self, retention_in_days):
@ -510,6 +518,7 @@ class LogsBackend(BaseBackend):
self.region_name = region_name
self.groups = dict() # { logGroupName: LogGroup}
self.queries = dict()
self.resource_policies = dict()
def reset(self):
region_name = self.region_name
@ -677,6 +686,10 @@ class LogsBackend(BaseBackend):
log_group = self.groups[log_group_name]
return log_group.set_retention_policy(None)
def put_resource_policy(self, policy_name, policy_doc):
policy = {"policyName": policy_name, "policyDocument": policy_doc}
self.resource_policies[policy_name] = policy
def list_tags_log_group(self, log_group_name):
if log_group_name not in self.groups:
raise ResourceNotFoundException()

View File

@ -25,9 +25,10 @@ class LogsResponse(BaseResponse):
def create_log_group(self):
log_group_name = self._get_param("logGroupName")
tags = self._get_param("tags")
kms_key_id = self._get_param("kmsKeyId")
assert 1 <= len(log_group_name) <= 512 # TODO: assert pattern
self.logs_backend.create_log_group(log_group_name, tags)
self.logs_backend.create_log_group(log_group_name, tags, kmsKeyId=kms_key_id)
return ""
def delete_log_group(self):
@ -166,6 +167,12 @@ class LogsResponse(BaseResponse):
self.logs_backend.delete_retention_policy(log_group_name)
return ""
def put_resource_policy(self):
policy_name = self._get_param("policyName")
policy_doc = self._get_param("policyDocument")
self.logs_backend.put_resource_policy(policy_name, policy_doc)
return ""
def list_tags_log_group(self):
log_group_name = self._get_param("logGroupName")
tags = self.logs_backend.list_tags_log_group(log_group_name)

View File

@ -32,6 +32,7 @@ class FakeOrganization(BaseModel):
self.master_account_id = utils.MASTER_ACCOUNT_ID
self.master_account_email = utils.MASTER_ACCOUNT_EMAIL
self.available_policy_types = [
# TODO: verify if this should be enabled by default (breaks TF tests for CloudTrail)
{"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"}
]
@ -141,7 +142,10 @@ class FakeRoot(FakeOrganizationalUnit):
self.type = "ROOT"
self.id = organization.root_id
self.name = "Root"
self.policy_types = [{"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"}]
self.policy_types = [
# TODO: verify if this should be enabled by default (breaks TF tests for CloudTrail)
{"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"}
]
self._arn_format = utils.ROOT_ARN_FORMAT
self.attached_policies = []
self.tags = {tag["Key"]: tag["Value"] for tag in kwargs.get("Tags", [])}
@ -328,6 +332,9 @@ class FakeDelegatedAdministrator(BaseModel):
class OrganizationsBackend(BaseBackend):
def __init__(self):
self._reset()
def _reset(self):
self.org = None
self.accounts = []
self.ou = []
@ -375,6 +382,10 @@ class OrganizationsBackend(BaseBackend):
raise AWSOrganizationsNotInUseException
return self.org.describe()
def delete_organization(self, **kwargs):
self._reset()
return {}
def list_roots(self):
return dict(Roots=[ou.describe() for ou in self.ou if isinstance(ou, FakeRoot)])

View File

@ -28,6 +28,11 @@ class OrganizationsResponse(BaseResponse):
def describe_organization(self):
return json.dumps(self.organizations_backend.describe_organization())
def delete_organization(self):
return json.dumps(
self.organizations_backend.delete_organization(**self.request_params)
)
def list_roots(self):
return json.dumps(self.organizations_backend.list_roots())

View File

@ -43,6 +43,19 @@ class Database(CloudFormationModel):
"engine": FilterDef(["engine"], "Engine Names"),
}
default_engine_versions = {
"MySQL": "5.6.21",
"mysql": "5.6.21",
"oracle-se1": "11.2.0.4.v3",
"oracle-se": "11.2.0.4.v3",
"oracle-ee": "11.2.0.4.v3",
"sqlserver-ee": "11.00.2100.60.v1",
"sqlserver-se": "11.00.2100.60.v1",
"sqlserver-ex": "11.00.2100.60.v1",
"sqlserver-web": "11.00.2100.60.v1",
"postgres": "9.3.3",
}
def __init__(self, **kwargs):
self.status = "available"
self.is_replica = False
@ -50,18 +63,6 @@ class Database(CloudFormationModel):
self.region = kwargs.get("region")
self.engine = kwargs.get("engine")
self.engine_version = kwargs.get("engine_version", None)
self.default_engine_versions = {
"MySQL": "5.6.21",
"mysql": "5.6.21",
"oracle-se1": "11.2.0.4.v3",
"oracle-se": "11.2.0.4.v3",
"oracle-ee": "11.2.0.4.v3",
"sqlserver-ee": "11.00.2100.60.v1",
"sqlserver-se": "11.00.2100.60.v1",
"sqlserver-ex": "11.00.2100.60.v1",
"sqlserver-web": "11.00.2100.60.v1",
"postgres": "9.3.3",
}
if not self.engine_version and self.engine in self.default_engine_versions:
self.engine_version = self.default_engine_versions[self.engine]
self.iops = kwargs.get("iops")
@ -120,6 +121,7 @@ class Database(CloudFormationModel):
self.db_parameter_group_name = kwargs.get("db_parameter_group_name")
if (
self.db_parameter_group_name
and not self.is_default_parameter_group(self.db_parameter_group_name)
and self.db_parameter_group_name
not in rds2_backends[self.region].db_parameter_groups
):
@ -160,7 +162,9 @@ class Database(CloudFormationModel):
return self.db_instance_identifier
def db_parameter_groups(self):
if not self.db_parameter_group_name:
if not self.db_parameter_group_name or self.is_default_parameter_group(
self.db_parameter_group_name
):
(
db_family,
db_parameter_group_name,
@ -182,6 +186,9 @@ class Database(CloudFormationModel):
]
]
def is_default_parameter_group(self, param_group_name):
return param_group_name.startswith("default.%s" % self.engine.lower())
def default_db_parameter_group_details(self):
if not self.engine_version:
return (None, None)

View File

@ -97,6 +97,8 @@ class FakeResourceGroup(BaseModel):
return True
def _validate_resource_query(self, value):
if not value:
return True
errors = []
if value["Type"] not in {"CLOUDFORMATION_STACK_1_0", "TAG_FILTERS_1_0"}:
errors.append(
@ -229,6 +231,8 @@ class ResourceGroupsBackend(BaseBackend):
@staticmethod
def _validate_resource_query(resource_query):
if not resource_query:
return
type = resource_query["Type"]
query = json.loads(resource_query["Query"])
query_keys = set(query.keys())

View File

@ -44,7 +44,7 @@ class ResourceGroupsResponse(BaseResponse):
)
def delete_group(self):
group_name = self._get_param("GroupName")
group_name = self._get_param("GroupName") or self._get_param("Group")
group = self.resourcegroups_backend.delete_group(group_name=group_name)
return json.dumps(
{

View File

@ -6,6 +6,8 @@ from moto.core.responses import BaseResponse
from .models import route53_backend
import xmltodict
XMLNS = "https://route53.amazonaws.com/doc/2013-04-01/"
class Route53(BaseResponse):
def list_or_create_hostzone_response(self, request, full_url, headers):
@ -83,7 +85,7 @@ class Route53(BaseResponse):
zoneid = parsed_url.path.rstrip("/").rsplit("/", 1)[1]
the_zone = route53_backend.get_hosted_zone(zoneid)
if not the_zone:
return 404, headers, "Zone %s not Found" % zoneid
return no_such_hosted_zone_error(zoneid, headers)
if request.method == "GET":
template = Template(GET_HOSTED_ZONE_RESPONSE)
@ -102,7 +104,7 @@ class Route53(BaseResponse):
zoneid = parsed_url.path.rstrip("/").rsplit("/", 2)[1]
the_zone = route53_backend.get_hosted_zone(zoneid)
if not the_zone:
return 404, headers, "Zone %s Not Found" % zoneid
return no_such_hosted_zone_error(zoneid, headers)
if method == "POST":
elements = xmltodict.parse(self.body)
@ -256,6 +258,20 @@ class Route53(BaseResponse):
return 200, headers, template.render(change_id=change_id)
def no_such_hosted_zone_error(zoneid, headers={}):
headers["X-Amzn-ErrorType"] = "NoSuchHostedZone"
headers["Content-Type"] = "text/xml"
message = "Zone %s Not Found" % zoneid
error_response = (
"<Error><Code>NoSuchHostedZone</Code><Message>%s</Message></Error>" % message
)
error_response = '<ErrorResponse xmlns="%s">%s</ErrorResponse>' % (
XMLNS,
error_response,
)
return 404, headers, error_response
LIST_TAGS_FOR_RESOURCE_RESPONSE = """
<ListTagsForResourceResponse xmlns="https://route53.amazonaws.com/doc/2015-01-01/">
<ResourceTagSet>

View File

@ -30,12 +30,28 @@ ERROR_WITH_RANGE = """{% extends 'single_error' %}
class S3ClientError(RESTError):
# S3 API uses <RequestID> as the XML tag in response messages
request_id_tag_name = "RequestID"
def __init__(self, *args, **kwargs):
kwargs.setdefault("template", "single_error")
self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME
super(S3ClientError, self).__init__(*args, **kwargs)
class InvalidArgumentError(S3ClientError):
code = 400
def __init__(self, message, name, value, *args, **kwargs):
kwargs.setdefault("template", "argument_error")
kwargs["name"] = name
kwargs["value"] = value
self.templates["argument_error"] = ERROR_WITH_ARGUMENT
super(InvalidArgumentError, self).__init__(
"InvalidArgument", message, *args, **kwargs
)
class BucketError(S3ClientError):
def __init__(self, *args, **kwargs):
kwargs.setdefault("template", "bucket_error")
@ -473,3 +489,16 @@ class InvalidContinuationToken(S3ClientError):
*args,
**kwargs
)
class InvalidFilterRuleName(InvalidArgumentError):
code = 400
def __init__(self, value, *args, **kwargs):
super(InvalidFilterRuleName, self).__init__(
"filter rule name must be either prefix or suffix",
"FilterRule.Name",
value,
*args,
**kwargs
)

View File

@ -534,7 +534,7 @@ class LifecycleAndFilter(BaseModel):
for key, value in self.tags.items():
data.append(
{"type": "LifecycleTagPredicate", "tag": {"key": key, "value": value},}
{"type": "LifecycleTagPredicate", "tag": {"key": key, "value": value}}
)
return data
@ -1058,9 +1058,6 @@ class FakeBucket(CloudFormationModel):
self.accelerate_configuration = accelerate_config
def set_website_configuration(self, website_configuration):
self.website_configuration = website_configuration
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
@ -1382,12 +1379,16 @@ class S3Backend(BaseBackend):
def set_bucket_website_configuration(self, bucket_name, website_configuration):
bucket = self.get_bucket(bucket_name)
bucket.set_website_configuration(website_configuration)
bucket.website_configuration = website_configuration
def get_bucket_website_configuration(self, bucket_name):
bucket = self.get_bucket(bucket_name)
return bucket.website_configuration
def delete_bucket_website(self, bucket_name):
bucket = self.get_bucket(bucket_name)
bucket.website_configuration = None
def get_bucket_public_access_block(self, bucket_name):
bucket = self.get_bucket(bucket_name)

View File

@ -60,6 +60,7 @@ from .models import (
FakeGrant,
FakeAcl,
FakeKey,
FakeMultipart,
)
from .utils import (
bucket_name_from_url,
@ -109,6 +110,7 @@ ACTION_MAP = {
"DELETE": {
"lifecycle": "PutLifecycleConfiguration",
"policy": "DeleteBucketPolicy",
"website": "DeleteBucketWebsite",
"tagging": "PutBucketTagging",
"cors": "PutBucketCORS",
"public_access_block": "DeletePublicAccessBlock",
@ -815,6 +817,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
elif "tagging" in querystring:
self.backend.delete_bucket_tagging(bucket_name)
return 204, {}, ""
elif "website" in querystring:
self.backend.delete_bucket_website(bucket_name)
return 204, {}, ""
elif "cors" in querystring:
self.backend.delete_bucket_cors(bucket_name)
return 204, {}, ""
@ -1212,7 +1217,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if_unmodified_since = str_to_rfc_1123_datetime(if_unmodified_since)
if key.last_modified > if_unmodified_since:
raise PreconditionFailed("If-Unmodified-Since")
if if_match and key.etag != if_match:
if if_match and key.etag not in [if_match, '"{0}"'.format(if_match)]:
raise PreconditionFailed("If-Match")
if if_modified_since:
@ -1509,6 +1514,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
grants = []
for header, value in headers.items():
header = header.lower()
if not header.startswith("x-amz-grant-"):
continue
@ -1523,7 +1529,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
grantees = []
for key_and_value in value.split(","):
key, value = re.match(
'([^=]+)="([^"]+)"', key_and_value.strip()
'([^=]+)="?([^"]+)"?', key_and_value.strip()
).groups()
if key.lower() == "id":
grantees.append(FakeGrantee(id=value))
@ -1765,7 +1771,11 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if body == b"" and "uploads" in query:
metadata = metadata_from_headers(request.headers)
multipart = self.backend.initiate_multipart(bucket_name, key_name, metadata)
multipart = FakeMultipart(key_name, metadata)
multipart.storage = request.headers.get("x-amz-storage-class", "STANDARD")
bucket = self.backend.get_bucket(bucket_name)
bucket.multiparts[multipart.id] = multipart
template = self.response_template(S3_MULTIPART_INITIATE_RESPONSE)
response = template.render(
@ -1775,8 +1785,26 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if query.get("uploadId"):
body = self._complete_multipart_body(body)
upload_id = query["uploadId"][0]
key = self.backend.complete_multipart(bucket_name, upload_id, body)
multipart_id = query["uploadId"][0]
bucket = self.backend.get_bucket(bucket_name)
multipart = bucket.multiparts[multipart_id]
value, etag = multipart.complete(body)
if value is None:
return 400, {}, ""
del bucket.multiparts[multipart_id]
key = self.backend.set_object(
bucket_name,
multipart.key_name,
value,
storage=multipart.storage,
etag=etag,
multipart=multipart,
)
key.set_metadata(multipart.metadata)
template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE)
headers = {}
if key.version_id:
@ -1788,6 +1816,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
bucket_name=bucket_name, key_name=key.name, etag=key.etag
),
)
elif "restore" in query:
es = minidom.parseString(body).getElementsByTagName("Days")
days = es[0].childNodes[0].wholeText
@ -1797,6 +1826,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
r = 200
key.restore(int(days))
return r, {}, ""
else:
raise NotImplementedError(
"Method POST had only been implemented for multipart uploads and restore operations, so far"
@ -2237,7 +2267,7 @@ S3_ALL_MULTIPARTS = (
<KeyMarker></KeyMarker>
<UploadIdMarker></UploadIdMarker>
<MaxUploads>1000</MaxUploads>
<IsTruncated>False</IsTruncated>
<IsTruncated>false</IsTruncated>
{% for upload in uploads %}
<Upload>
<Key>{{ upload.key_name }}</Key>
@ -2355,7 +2385,7 @@ S3_NO_LOGGING_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
"""
S3_ENCRYPTION_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
<BucketEncryptionStatus xmlns="http://doc.s3.amazonaws.com/2006-03-01">
<ServerSideEncryptionConfiguration xmlns="http://doc.s3.amazonaws.com/2006-03-01">
{% for entry in encryption %}
<Rule>
<ApplyServerSideEncryptionByDefault>
@ -2364,9 +2394,10 @@ S3_ENCRYPTION_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
<KMSMasterKeyID>{{ entry["Rule"]["ApplyServerSideEncryptionByDefault"]["KMSMasterKeyID"] }}</KMSMasterKeyID>
{% endif %}
</ApplyServerSideEncryptionByDefault>
<BucketKeyEnabled>{{ 'true' if entry["Rule"].get("BucketKeyEnabled") == 'true' else 'false' }}</BucketKeyEnabled>
</Rule>
{% endfor %}
</BucketEncryptionStatus>
</ServerSideEncryptionConfiguration>
"""
S3_INVALID_PRESIGNED_PARAMETERS = """<?xml version="1.0" encoding="UTF-8"?>

View File

@ -144,23 +144,28 @@ class _VersionedKeyStore(dict):
super(_VersionedKeyStore, self).__setitem__(key, list_)
def _iteritems(self):
for key in self:
for key in self._self_iterable():
yield key, self[key]
def _itervalues(self):
for key in self:
for key in self._self_iterable():
yield self[key]
def _iterlists(self):
for key in self:
for key in self._self_iterable():
yield key, self.getlist(key)
def item_size(self):
size = 0
for val in self.values():
for val in self._self_iterable().values():
size += sys.getsizeof(val)
return size
def _self_iterable(self):
# to enable concurrency, return a copy, to avoid "dictionary changed size during iteration"
# TODO: look into replacing with a locking mechanism, potentially
return dict(self)
items = iteritems = _iteritems
lists = iterlists = _iterlists
values = itervalues = _itervalues

View File

@ -7,6 +7,7 @@ import uuid
import datetime
from boto3 import Session
from typing import List, Tuple
from moto.core import BaseBackend, BaseModel
from .exceptions import (
@ -566,15 +567,46 @@ class SecretsManagerBackend(BaseBackend):
return response
def list_secrets(self, filters, max_results, next_token):
# TODO implement pagination and limits
def list_secrets(
self, filters: List, max_results: int = 100, next_token: str = None
) -> Tuple[List, str]:
"""
Returns secrets from secretsmanager.
The result is paginated and page items depends on the token value, because token contains start element
number of secret list.
Response example:
{
SecretList: [
{
ARN: 'arn:aws:secretsmanager:us-east-1:1234567890:secret:test1-gEcah',
Name: 'test1',
...
},
{
ARN: 'arn:aws:secretsmanager:us-east-1:1234567890:secret:test2-KZwml',
Name: 'test2',
...
}
],
NextToken: '2'
}
:param filters: (List) Filter parameters.
:param max_results: (int) Max number of results per page.
:param next_token: (str) Page token.
:return: (Tuple[List,str]) Returns result list and next token.
"""
secret_list = []
for secret in self.secrets.values():
if _matches(secret, filters):
secret_list.append(secret.to_dict())
return secret_list, None
starting_point = int(next_token or 0)
ending_point = starting_point + int(max_results or 100)
secret_page = secret_list[starting_point:ending_point]
new_next_token = str(ending_point) if ending_point < len(secret_list) else None
return secret_page, new_next_token
def delete_secret(
self, secret_id, recovery_window_in_days, force_delete_without_recovery

View File

@ -441,7 +441,7 @@ class SNSBackend(BaseBackend):
return self._get_values_nexttoken(self.topics, next_token)
def delete_topic_subscriptions(self, topic):
for key, value in self.subscriptions.items():
for key, value in dict(self.subscriptions).items():
if value.topic == topic:
self.subscriptions.pop(key)
@ -585,7 +585,10 @@ class SNSBackend(BaseBackend):
):
for endpoint in self.platform_endpoints.values():
if token == endpoint.token:
if attributes["Enabled"].lower() == endpoint.attributes["Enabled"]:
if (
attributes.get("Enabled", "").lower()
== endpoint.attributes["Enabled"]
):
return endpoint
raise DuplicateSnsEndpointError(
"Duplicate endpoint token with different attributes: %s" % token

View File

@ -477,6 +477,7 @@ class Queue(CloudFormationModel):
@property
def messages(self):
# TODO: This can become very inefficient if a large number of messages are in-flight
return [
message
for message in self._messages
@ -832,6 +833,7 @@ class SQSBackend(BaseBackend):
if (
queue.dead_letter_queue is not None
and queue.redrive_policy
and message.approximate_receive_count
>= queue.redrive_policy["maxReceiveCount"]
):

View File

@ -3,6 +3,13 @@ import random
import string
def str2bool(v):
if v in ("yes", True, "true", "True", "TRUE", "t", "1"):
return True
elif v in ("no", False, "false", "False", "FALSE", "f", "0"):
return False
def random_string(length=None):
n = length or 20
random_str = "".join(
@ -20,3 +27,10 @@ def load_resource(filename, as_json=True):
"""
with open(filename, "r", encoding="utf-8") as f:
return json.load(f) if as_json else f.read()
def merge_multiple_dicts(*args):
result = {}
for d in args:
result.update(d)
return result

View File

@ -1,7 +0,0 @@
FROM python:3.7-buster
ADD . /moto/
ENV PYTHONUNBUFFERED 1
WORKDIR /moto/
RUN make init
RUN make test

View File

@ -0,0 +1,11 @@
TestAccAWSEc2TransitGatewayDxGatewayAttachmentDataSource
TestAccAWSEc2TransitGatewayPeeringAttachment
TestAccAWSEc2TransitGatewayPeeringAttachmentAccepter
TestAccAWSEc2TransitGatewayPeeringAttachmentDataSource
TestAccAWSEc2TransitGatewayRoute
TestAccAWSEc2TransitGatewayRouteTableAssociation
TestAccAWSEc2TransitGatewayRouteTablePropagation
TestAccAWSEc2TransitGatewayVpcAttachment
TestAccAWSEc2TransitGatewayVpcAttachmentDataSource
TestAccAWSFms
TestAccAWSIAMRolePolicy

View File

@ -4,29 +4,51 @@ TestAccAWSBillingServiceAccount
TestAccAWSCallerIdentity
TestAccAWSCloudTrailServiceAccount
TestAccAWSCloudWatchDashboard
TestAccAWSCloudWatchEventApiDestination
TestAccAWSCloudWatchEventArchive
TestAccAWSCloudWatchEventBus
TestAccAWSCloudwatchLogGroupDataSource
TestAccAWSDataSourceCloudwatch
TestAccAWSDataSourceElasticBeanstalkHostedZone
TestAccAWSDataSourceIAMGroup
TestAccAWSDataSourceIAMInstanceProfile
TestAccAWSDataSourceIAMPolicy
TestAccAWSDataSourceIAMPolicyDocument
TestAccAWSDataSourceIAMRole
TestAccAWSDataSourceIAMSessionContext
TestAccAWSDataSourceIAMUser
TestAccAWSDefaultSecurityGroup
TestAccAWSDefaultSubnet
TestAccAWSDefaultTagsDataSource
TestAccAWSDynamoDbTableItem
TestAccAWSEc2InstanceTypeOfferingDataSource
TestAccAWSEc2InstanceTypeOfferingsDataSource
TestAccAWSEc2Tag
TestAccAWSEc2TransitGateway
TestAccAWSEc2TransitGatewayDataSource
TestAccAWSEc2TransitGatewayRouteTable
TestAccAWSEc2TransitGatewayRouteTableDataSource
TestAccAWSEc2TransitGatewayVpcAttachmentAccepter
TestAccAWSEc2TransitGatewayVpnAttachmentDataSource
TestAccAWSElasticBeanstalkSolutionStackDataSource
TestAccAWSElbHostedZoneId
TestAccAWSElbServiceAccount
TestAccAWSFms
TestAccAWSGroupMembership
TestAccAWSIAMAccountAlias
TestAccAWSIAMGroupPolicy
TestAccAWSIAMGroupPolicyAttachment
TestAccAWSIAMRole
TestAccAWSIAMUserPolicy
TestAccAWSIPRanges
TestAccAWSKmsSecretDataSource
TestAccAWSPartition
TestAccAWSProvider
TestAccAWSRedshiftServiceAccount
TestAccAWSRolePolicyAttachment
TestAccAWSSNSSMSPreferences
TestAccAWSSageMakerPrebuiltECRImage
TestAccAWSSsmParameterDataSource
TestAccAWSUserPolicyAttachment
TestAccAWSUserGroupMembership
TestAccAWSUserPolicyAttachment
TestAccAWSUserSSHKey

View File

@ -9,6 +9,7 @@ import sure # noqa
from botocore.exceptions import ClientError
from moto import mock_apigateway, mock_cognitoidp, settings
from moto.apigateway.exceptions import NoIntegrationDefined
from moto.core import ACCOUNT_ID
from moto.core.models import responses_mock
import pytest
@ -474,13 +475,7 @@ def test_integrations():
response.should.equal(
{
"ResponseMetadata": {"HTTPStatusCode": 200},
"httpMethod": "GET",
"integrationResponses": {
"200": {
"responseTemplates": {"application/json": None},
"statusCode": 200,
}
},
"httpMethod": "POST",
"type": "HTTP",
"uri": "http://httpbin.org/robots.txt",
}
@ -495,13 +490,7 @@ def test_integrations():
response.should.equal(
{
"ResponseMetadata": {"HTTPStatusCode": 200},
"httpMethod": "GET",
"integrationResponses": {
"200": {
"responseTemplates": {"application/json": None},
"statusCode": 200,
}
},
"httpMethod": "POST",
"type": "HTTP",
"uri": "http://httpbin.org/robots.txt",
}
@ -511,18 +500,10 @@ def test_integrations():
# this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None)
response["resourceMethods"]["GET"]["httpMethod"].should.equal("GET")
response["resourceMethods"]["GET"]["authorizationType"].should.equal("none")
response["resourceMethods"]["GET"]["methodIntegration"].should.equal(
{
"httpMethod": "GET",
"integrationResponses": {
"200": {
"responseTemplates": {"application/json": None},
"statusCode": 200,
}
},
"type": "HTTP",
"uri": "http://httpbin.org/robots.txt",
}
{"httpMethod": "POST", "type": "HTTP", "uri": "http://httpbin.org/robots.txt",}
)
client.delete_integration(restApiId=api_id, resourceId=root_id, httpMethod="GET")
@ -611,7 +592,7 @@ def test_integration_response():
"statusCode": "200",
"selectionPattern": "foobar",
"ResponseMetadata": {"HTTPStatusCode": 200},
"responseTemplates": {"application/json": None},
"responseTemplates": {}, # Note: TF compatibility
}
)
@ -626,7 +607,7 @@ def test_integration_response():
"statusCode": "200",
"selectionPattern": "foobar",
"ResponseMetadata": {"HTTPStatusCode": 200},
"responseTemplates": {"application/json": None},
"responseTemplates": {}, # Note: TF compatibility
}
)
@ -637,7 +618,7 @@ def test_integration_response():
response["methodIntegration"]["integrationResponses"].should.equal(
{
"200": {
"responseTemplates": {"application/json": None},
"responseTemplates": {}, # Note: TF compatibility
"selectionPattern": "foobar",
"statusCode": "200",
}
@ -687,7 +668,7 @@ def test_integration_response():
"statusCode": "200",
"selectionPattern": "foobar",
"ResponseMetadata": {"HTTPStatusCode": 200},
"responseTemplates": {"application/json": None},
"responseTemplates": {}, # Note: TF compatibility
"contentHandling": "CONVERT_TO_BINARY",
}
)
@ -703,7 +684,7 @@ def test_integration_response():
"statusCode": "200",
"selectionPattern": "foobar",
"ResponseMetadata": {"HTTPStatusCode": 200},
"responseTemplates": {"application/json": None},
"responseTemplates": {}, # Note: TF compatibility
"contentHandling": "CONVERT_TO_BINARY",
}
)
@ -1277,7 +1258,7 @@ def test_create_deployment_requires_REST_method_integrations():
with pytest.raises(ClientError) as ex:
client.create_deployment(restApiId=api_id, stageName=stage_name)["id"]
ex.value.response["Error"]["Code"].should.equal("BadRequestException")
ex.value.response["Error"]["Code"].should.equal("NotFoundException")
ex.value.response["Error"]["Message"].should.equal(
"No integration defined for method"
)
@ -1873,7 +1854,7 @@ def test_http_proxying_integration():
httpMethod="GET",
type="HTTP",
uri="http://httpbin.org/robots.txt",
integrationHttpMethod="POST",
integrationHttpMethod="GET",
)
stage_name = "staging"

View File

@ -2542,7 +2542,8 @@ def test_stack_elbv2_resources_integration():
]
},
}
]
],
[{"Type": "forward", "TargetGroupArn": target_groups[1]["TargetGroupArn"]}],
)
listener_rule[0]["Conditions"].should.equal(
[{"Field": "path-pattern", "Values": ["/*"]}]

View File

@ -349,6 +349,75 @@ def test_get_metric_statistics():
datapoint["Sum"].should.equal(1.5)
@mock_cloudwatch
def test_get_metric_statistics_dimensions():
conn = boto3.client("cloudwatch", region_name="us-east-1")
utc_now = datetime.now(tz=pytz.utc)
# put metric data with different dimensions
dimensions1 = [{"Name": "dim1", "Value": "v1"}]
dimensions2 = dimensions1 + [{"Name": "dim2", "Value": "v2"}]
metric_name = "metr-stats-dims"
conn.put_metric_data(
Namespace="tester",
MetricData=[
dict(
MetricName=metric_name,
Value=1,
Timestamp=utc_now,
Dimensions=dimensions1,
)
],
)
conn.put_metric_data(
Namespace="tester",
MetricData=[
dict(
MetricName=metric_name,
Value=2,
Timestamp=utc_now,
Dimensions=dimensions1,
)
],
)
conn.put_metric_data(
Namespace="tester",
MetricData=[
dict(
MetricName=metric_name,
Value=6,
Timestamp=utc_now,
Dimensions=dimensions2,
)
],
)
# list of (<kwargs>, <expectedSum>, <expectedAverage>)
params_list = (
# get metric stats with no restriction on dimensions
({}, 9, 3),
# get metric stats for dimensions1 (should also cover dimensions2)
({"Dimensions": dimensions1}, 9, 3),
# get metric stats for dimensions2 only
({"Dimensions": dimensions2}, 6, 6),
)
for params in params_list:
stats = conn.get_metric_statistics(
Namespace="tester",
MetricName=metric_name,
StartTime=utc_now - timedelta(seconds=60),
EndTime=utc_now + timedelta(seconds=60),
Period=60,
Statistics=["Average", "Sum"],
**params[0],
)
stats["Datapoints"].should.have.length_of(1)
datapoint = stats["Datapoints"][0]
datapoint["Sum"].should.equal(params[1])
datapoint["Average"].should.equal(params[2])
@mock_cloudwatch
def test_duplicate_put_metric_data():
conn = boto3.client("cloudwatch", region_name="us-east-1")
@ -501,16 +570,8 @@ def test_list_metrics():
# Verify format
res.should.equal(
[
{
u"Namespace": "list_test_1/",
u"Dimensions": [],
u"MetricName": "metric1",
},
{
u"Namespace": "list_test_1/",
u"Dimensions": [],
u"MetricName": "metric1",
},
{"Namespace": "list_test_1/", "Dimensions": [], "MetricName": "metric1",},
{"Namespace": "list_test_1/", "Dimensions": [], "MetricName": "metric1",},
]
)
# Verify unknown namespace still has no results

View File

@ -156,7 +156,7 @@ def test_create_task():
@mock_datasync
def test_create_task_fail():
""" Test that Locations must exist before a Task can be created """
"""Test that Locations must exist before a Task can be created"""
client = boto3.client("datasync", region_name="us-east-1")
locations = create_locations(client, create_smb=True, create_s3=True)
with pytest.raises(ClientError) as e:

View File

@ -5970,7 +5970,7 @@ def test_dynamodb_update_item_fails_on_string_sets():
BillingMode="PAY_PER_REQUEST",
)
table.meta.client.get_waiter("table_exists").wait(TableName="test")
attribute = {"test_field": {"Value": {"SS": ["test1", "test2"],}, "Action": "PUT"}}
attribute = {"test_field": {"Value": {"SS": ["test1", "test2"],}, "Action": "PUT",}}
client.update_item(
TableName="test",

View File

@ -88,7 +88,7 @@ def test_validation_of_update_expression_with_keyword(table):
@pytest.mark.parametrize(
"update_expression", ["SET a = #b + :val2", "SET a = :val2 + #b",]
"update_expression", ["SET a = #b + :val2", "SET a = :val2 + #b",],
)
def test_validation_of_a_set_statement_with_incorrect_passed_value(
update_expression, table
@ -150,7 +150,9 @@ def test_validation_of_update_expression_with_attribute_that_does_not_exist_in_i
assert True
@pytest.mark.parametrize("update_expression", ["SET a = #c", "SET a = #c + #d",])
@pytest.mark.parametrize(
"update_expression", ["SET a = #c", "SET a = #c + #d",],
)
def test_validation_of_update_expression_with_attribute_name_that_is_not_defined(
update_expression, table,
):

View File

@ -290,7 +290,7 @@ def test_ami_filters():
amis_by_architecture = conn.get_all_images(filters={"architecture": "x86_64"})
set([ami.id for ami in amis_by_architecture]).should.contain(imageB.id)
len(amis_by_architecture).should.equal(35)
len(amis_by_architecture).should.equal(37)
amis_by_kernel = conn.get_all_images(filters={"kernel-id": "k-abcd1234"})
set([ami.id for ami in amis_by_kernel]).should.equal(set([imageB.id]))
@ -303,7 +303,7 @@ def test_ami_filters():
amis_by_platform = conn.get_all_images(filters={"platform": "windows"})
set([ami.id for ami in amis_by_platform]).should.contain(imageA.id)
len(amis_by_platform).should.equal(24)
len(amis_by_platform).should.equal(25)
amis_by_id = conn.get_all_images(filters={"image-id": imageA.id})
set([ami.id for ami in amis_by_id]).should.equal(set([imageA.id]))
@ -312,14 +312,14 @@ def test_ami_filters():
ami_ids_by_state = [ami.id for ami in amis_by_state]
ami_ids_by_state.should.contain(imageA.id)
ami_ids_by_state.should.contain(imageB.id)
len(amis_by_state).should.equal(36)
len(amis_by_state).should.equal(38)
amis_by_name = conn.get_all_images(filters={"name": imageA.name})
set([ami.id for ami in amis_by_name]).should.equal(set([imageA.id]))
amis_by_public = conn.get_all_images(filters={"is-public": "true"})
set([ami.id for ami in amis_by_public]).should.contain(imageB.id)
len(amis_by_public).should.equal(35)
len(amis_by_public).should.equal(37)
amis_by_nonpublic = conn.get_all_images(filters={"is-public": "false"})
set([ami.id for ami in amis_by_nonpublic]).should.contain(imageA.id)

View File

@ -38,7 +38,8 @@ def test_delete_customer_gateways():
cgws[0].id.should.match(customer_gateway.id)
deleted = conn.delete_customer_gateway(customer_gateway.id)
cgws = conn.get_all_customer_gateways()
cgws.should.have.length_of(0)
cgws[0].state.should.equal("deleted")
cgws.should.have.length_of(1)
@mock_ec2_deprecated

View File

@ -16,7 +16,7 @@ SAMPLE_NAME_SERVERS = ["10.0.0.6", "10.0.0.7"]
@mock_ec2_deprecated
def test_dhcp_options_associate():
""" associate dhcp option """
"""associate dhcp option"""
conn = boto.connect_vpc("the_key", "the_secret")
dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS)
vpc = conn.create_vpc("10.0.0.0/16")
@ -27,7 +27,7 @@ def test_dhcp_options_associate():
@mock_ec2_deprecated
def test_dhcp_options_associate_invalid_dhcp_id():
""" associate dhcp option bad dhcp options id """
"""associate dhcp option bad dhcp options id"""
conn = boto.connect_vpc("the_key", "the_secret")
vpc = conn.create_vpc("10.0.0.0/16")
@ -40,7 +40,7 @@ def test_dhcp_options_associate_invalid_dhcp_id():
@mock_ec2_deprecated
def test_dhcp_options_associate_invalid_vpc_id():
""" associate dhcp option invalid vpc id """
"""associate dhcp option invalid vpc id"""
conn = boto.connect_vpc("the_key", "the_secret")
dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS)

View File

@ -205,7 +205,7 @@ def test_eip_boto3_vpc_association():
address.association_id.should.be.none
address.instance_id.should.be.empty
address.network_interface_id.should.be.empty
association_id = client.associate_address(
client.associate_address(
InstanceId=instance.id, AllocationId=allocation_id, AllowReassociation=False
)
instance.load()
@ -287,7 +287,6 @@ def test_eip_reassociate():
).should_not.throw(EC2ResponseError)
eip.release()
eip = None
instance1.terminate()
instance2.terminate()
@ -326,7 +325,7 @@ def test_eip_reassociate_nic():
@mock_ec2_deprecated
def test_eip_associate_invalid_args():
"""Associate EIP, invalid args """
"""Associate EIP, invalid args"""
conn = boto.connect_ec2("the_key", "the_secret")
reservation = conn.run_instances(EXAMPLE_AMI_ID)

View File

@ -27,7 +27,7 @@ def test_elastic_network_interfaces():
"An error occurred (DryRunOperation) when calling the CreateNetworkInterface operation: Request would have succeeded, but DryRun flag is set"
)
eni = conn.create_network_interface(subnet.id)
conn.create_network_interface(subnet.id)
all_enis = conn.get_all_network_interfaces()
all_enis.should.have.length_of(1)

View File

@ -6,6 +6,7 @@ import boto
import boto3
from boto.exception import EC2ResponseError
import sure # noqa
from botocore.exceptions import ClientError
from moto import mock_ec2_deprecated, mock_ec2
from tests import EXAMPLE_AMI_ID
@ -26,7 +27,7 @@ def test_console_output_without_instance():
with pytest.raises(EC2ResponseError) as cm:
conn.get_console_output("i-1234abcd")
cm.value.code.should.equal("InvalidInstanceID.NotFound")
cm.value.error_code.should.equal("InvalidInstanceID.NotFound")
cm.value.status.should.equal(400)
cm.value.request_id.should_not.be.none

View File

@ -21,7 +21,7 @@ BAD_IGW = "igw-deadbeef"
@mock_ec2_deprecated
def test_igw_create():
""" internet gateway create """
"""internet gateway create"""
conn = boto.connect_vpc("the_key", "the_secret")
conn.get_all_internet_gateways().should.have.length_of(0)
@ -44,7 +44,7 @@ def test_igw_create():
@mock_ec2_deprecated
def test_igw_attach():
""" internet gateway attach """
"""internet gateway attach"""
conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway()
vpc = conn.create_vpc(VPC_CIDR)
@ -65,7 +65,7 @@ def test_igw_attach():
@mock_ec2_deprecated
def test_igw_attach_bad_vpc():
""" internet gateway fail to attach w/ bad vpc """
"""internet gateway fail to attach w/ bad vpc"""
conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway()
@ -78,7 +78,7 @@ def test_igw_attach_bad_vpc():
@mock_ec2_deprecated
def test_igw_attach_twice():
""" internet gateway fail to attach twice """
"""internet gateway fail to attach twice"""
conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway()
vpc1 = conn.create_vpc(VPC_CIDR)
@ -94,7 +94,7 @@ def test_igw_attach_twice():
@mock_ec2_deprecated
def test_igw_detach():
""" internet gateway detach"""
"""internet gateway detach"""
conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway()
vpc = conn.create_vpc(VPC_CIDR)
@ -115,7 +115,7 @@ def test_igw_detach():
@mock_ec2_deprecated
def test_igw_detach_wrong_vpc():
""" internet gateway fail to detach w/ wrong vpc """
"""internet gateway fail to detach w/ wrong vpc"""
conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway()
vpc1 = conn.create_vpc(VPC_CIDR)
@ -131,7 +131,7 @@ def test_igw_detach_wrong_vpc():
@mock_ec2_deprecated
def test_igw_detach_invalid_vpc():
""" internet gateway fail to detach w/ invalid vpc """
"""internet gateway fail to detach w/ invalid vpc"""
conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway()
vpc = conn.create_vpc(VPC_CIDR)
@ -146,7 +146,7 @@ def test_igw_detach_invalid_vpc():
@mock_ec2_deprecated
def test_igw_detach_unattached():
""" internet gateway fail to detach unattached """
"""internet gateway fail to detach unattached"""
conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway()
vpc = conn.create_vpc(VPC_CIDR)
@ -160,7 +160,7 @@ def test_igw_detach_unattached():
@mock_ec2_deprecated
def test_igw_delete():
""" internet gateway delete"""
"""internet gateway delete"""
conn = boto.connect_vpc("the_key", "the_secret")
vpc = conn.create_vpc(VPC_CIDR)
conn.get_all_internet_gateways().should.have.length_of(0)
@ -181,7 +181,7 @@ def test_igw_delete():
@mock_ec2_deprecated
def test_igw_delete_attached():
""" internet gateway fail to delete attached """
"""internet gateway fail to delete attached"""
conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway()
vpc = conn.create_vpc(VPC_CIDR)
@ -196,7 +196,7 @@ def test_igw_delete_attached():
@mock_ec2_deprecated
def test_igw_desribe():
""" internet gateway fetch by id """
"""internet gateway fetch by id"""
conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway()
igw_by_search = conn.get_all_internet_gateways([igw.id])[0]
@ -205,7 +205,7 @@ def test_igw_desribe():
@mock_ec2_deprecated
def test_igw_describe_bad_id():
""" internet gateway fail to fetch by bad id """
"""internet gateway fail to fetch by bad id"""
conn = boto.connect_vpc("the_key", "the_secret")
with pytest.raises(EC2ResponseError) as cm:
conn.get_all_internet_gateways([BAD_IGW])
@ -216,7 +216,7 @@ def test_igw_describe_bad_id():
@mock_ec2_deprecated
def test_igw_filter_by_vpc_id():
""" internet gateway filter by vpc id """
"""internet gateway filter by vpc id"""
conn = boto.connect_vpc("the_key", "the_secret")
igw1 = conn.create_internet_gateway()
@ -231,7 +231,7 @@ def test_igw_filter_by_vpc_id():
@mock_ec2_deprecated
def test_igw_filter_by_tags():
""" internet gateway filter by vpc id """
"""internet gateway filter by vpc id"""
conn = boto.connect_vpc("the_key", "the_secret")
igw1 = conn.create_internet_gateway()
@ -245,7 +245,7 @@ def test_igw_filter_by_tags():
@mock_ec2_deprecated
def test_igw_filter_by_internet_gateway_id():
""" internet gateway filter by internet gateway id """
"""internet gateway filter by internet gateway id"""
conn = boto.connect_vpc("the_key", "the_secret")
igw1 = conn.create_internet_gateway()
@ -258,7 +258,7 @@ def test_igw_filter_by_internet_gateway_id():
@mock_ec2_deprecated
def test_igw_filter_by_attachment_state():
""" internet gateway filter by attachment state """
"""internet gateway filter by attachment state"""
conn = boto.connect_vpc("the_key", "the_secret")
igw1 = conn.create_internet_gateway()

View File

@ -175,6 +175,7 @@ def test_modify_subnet_attribute_assign_ipv6_address_on_creation():
# 'map_public_ip_on_launch' is set when calling 'DescribeSubnets' action
subnet.reload()
subnets = client.describe_subnets()
# For non default subnet, attribute value should be 'False'
subnet.assign_ipv6_address_on_creation.shouldnt.be.ok

View File

@ -36,7 +36,7 @@ def test_describe_vpn_gateway():
@mock_ec2
def test_describe_vpn_connections_attachment_vpc_id_filter():
""" describe_vpn_gateways attachment.vpc-id filter """
"""describe_vpn_gateways attachment.vpc-id filter"""
ec2 = boto3.client("ec2", region_name="us-east-1")
@ -60,7 +60,7 @@ def test_describe_vpn_connections_attachment_vpc_id_filter():
@mock_ec2
def test_describe_vpn_connections_state_filter_attached():
""" describe_vpn_gateways attachment.state filter - match attached """
"""describe_vpn_gateways attachment.state filter - match attached"""
ec2 = boto3.client("ec2", region_name="us-east-1")
@ -84,7 +84,7 @@ def test_describe_vpn_connections_state_filter_attached():
@mock_ec2
def test_describe_vpn_connections_state_filter_deatched():
""" describe_vpn_gateways attachment.state filter - don't match detatched """
"""describe_vpn_gateways attachment.state filter - don't match detatched"""
ec2 = boto3.client("ec2", region_name="us-east-1")
@ -104,7 +104,7 @@ def test_describe_vpn_connections_state_filter_deatched():
@mock_ec2
def test_describe_vpn_connections_id_filter_match():
""" describe_vpn_gateways vpn-gateway-id filter - match correct id """
"""describe_vpn_gateways vpn-gateway-id filter - match correct id"""
ec2 = boto3.client("ec2", region_name="us-east-1")
@ -121,7 +121,7 @@ def test_describe_vpn_connections_id_filter_match():
@mock_ec2
def test_describe_vpn_connections_id_filter_miss():
""" describe_vpn_gateways vpn-gateway-id filter - don't match """
"""describe_vpn_gateways vpn-gateway-id filter - don't match"""
ec2 = boto3.client("ec2", region_name="us-east-1")
@ -136,7 +136,7 @@ def test_describe_vpn_connections_id_filter_miss():
@mock_ec2
def test_describe_vpn_connections_type_filter_match():
""" describe_vpn_gateways type filter - match """
"""describe_vpn_gateways type filter - match"""
ec2 = boto3.client("ec2", region_name="us-east-1")
@ -153,7 +153,7 @@ def test_describe_vpn_connections_type_filter_match():
@mock_ec2
def test_describe_vpn_connections_type_filter_miss():
""" describe_vpn_gateways type filter - don't match """
"""describe_vpn_gateways type filter - don't match"""
ec2 = boto3.client("ec2", region_name="us-east-1")

View File

@ -8,7 +8,7 @@ import boto3
import boto
from boto.exception import EC2ResponseError
# import sure # noqa
import sure # noqa
from moto import mock_ec2, mock_ec2_deprecated
@ -919,4 +919,4 @@ def test_describe_vpc_end_points():
VpcEndpointIds=[route_table.get("RouteTable").get("RouteTableId")]
)
except ClientError as err:
assert err.response["Error"]["Code"] == "InvalidVpcEndPointId.NotFound"
assert err.response["Error"]["Code"] == "InvalidVpcEndpointId.NotFound"

View File

@ -29,7 +29,7 @@ def test_delete_vpn_connections():
list_of_vpn_connections.should.have.length_of(1)
conn.delete_vpn_connection(vpn_connection.id)
list_of_vpn_connections = conn.get_all_vpn_connections()
list_of_vpn_connections.should.have.length_of(0)
list_of_vpn_connections[0].state.should.equal("deleted")
@mock_ec2_deprecated

View File

@ -88,3 +88,12 @@ def test_event_pattern_with_multi_numeric_event_filter():
assert two_or_three.matches_event(events[2])
assert two_or_three.matches_event(events[3])
assert not two_or_three.matches_event(events[4])
@pytest.mark.parametrize(
"pattern, expected_str",
[('{"source": ["foo", "bar"]}', '{"source": ["foo", "bar"]}'), (None, ""),],
)
def test_event_pattern_str(pattern, expected_str):
event_pattern = EventPattern(pattern)
assert str(event_pattern) == expected_str

View File

@ -4,11 +4,10 @@ import unittest
from datetime import datetime
import boto3
import pytest
import pytz
import sure # noqa
from botocore.exceptions import ClientError
import pytest
from moto import mock_logs
from moto.core import ACCOUNT_ID
@ -2157,15 +2156,15 @@ def test_create_and_list_connections():
},
)
assert response.get(
"ConnectionArn"
) == "arn:aws:events:eu-central-1:{0}:connection/test".format(ACCOUNT_ID)
response.get("ConnectionArn").should.contain(
"arn:aws:events:eu-central-1:{0}:connection/test/".format(ACCOUNT_ID)
)
response = client.list_connections()
assert response.get("Connections")[0].get(
"ConnectionArn"
) == "arn:aws:events:eu-central-1:{0}:connection/test".format(ACCOUNT_ID)
response.get("Connections")[0].get("ConnectionArn").should.contain(
"arn:aws:events:eu-central-1:{0}:connection/test/".format(ACCOUNT_ID)
)
@mock_events
@ -2189,27 +2188,116 @@ def test_create_and_list_api_destinations():
HttpMethod="GET",
)
assert destination_response.get(
"ApiDestinationArn"
) == "arn:aws:events:eu-central-1:{0}:destination/test".format(ACCOUNT_ID)
arn_without_uuid = f"arn:aws:events:eu-central-1:{ACCOUNT_ID}:api-destination/test/"
assert destination_response.get("ApiDestinationArn").startswith(arn_without_uuid)
assert destination_response.get("ApiDestinationState") == "ACTIVE"
destination_response = client.describe_api_destination(Name="test")
assert destination_response.get(
"ApiDestinationArn"
) == "arn:aws:events:eu-central-1:{0}:destination/test".format(ACCOUNT_ID)
assert destination_response.get("ApiDestinationArn").startswith(arn_without_uuid)
assert destination_response.get("Name") == "test"
assert destination_response.get("ApiDestinationState") == "ACTIVE"
destination_response = client.list_api_destinations()
assert destination_response.get("ApiDestinations")[0].get(
"ApiDestinationArn"
) == "arn:aws:events:eu-central-1:{0}:destination/test".format(ACCOUNT_ID)
assert (
destination_response.get("ApiDestinations")[0]
.get("ApiDestinationArn")
.startswith(arn_without_uuid)
)
assert destination_response.get("ApiDestinations")[0].get("Name") == "test"
assert (
destination_response.get("ApiDestinations")[0].get("ApiDestinationState")
== "ACTIVE"
)
# Scenarios for describe_connection
# Scenario 01: Success
# Scenario 02: Failure - Connection not present
@mock_events
def test_describe_connection_success():
# Given
conn_name = "test_conn_name"
conn_description = "test_conn_description"
auth_type = "API_KEY"
auth_params = {
"ApiKeyAuthParameters": {"ApiKeyName": "test", "ApiKeyValue": "test"}
}
client = boto3.client("events", "eu-central-1")
_ = client.create_connection(
Name=conn_name,
Description=conn_description,
AuthorizationType=auth_type,
AuthParameters=auth_params,
)
# When
response = client.describe_connection(Name=conn_name)
# Then
assert response["Name"] == conn_name
assert response["Description"] == conn_description
assert response["AuthorizationType"] == auth_type
expected_auth_param = {"ApiKeyAuthParameters": {"ApiKeyName": "test"}}
assert response["AuthParameters"] == expected_auth_param
@mock_events
def test_describe_connection_not_present():
conn_name = "test_conn_name"
client = boto3.client("events", "eu-central-1")
# When/Then
with pytest.raises(ClientError):
_ = client.describe_connection(Name=conn_name)
# Scenarios for delete_connection
# Scenario 01: Success
# Scenario 02: Failure - Connection not present
@mock_events
def test_delete_connection_success():
# Given
conn_name = "test_conn_name"
conn_description = "test_conn_description"
auth_type = "API_KEY"
auth_params = {
"ApiKeyAuthParameters": {"ApiKeyName": "test", "ApiKeyValue": "test"}
}
client = boto3.client("events", "eu-central-1")
created_connection = client.create_connection(
Name=conn_name,
Description=conn_description,
AuthorizationType=auth_type,
AuthParameters=auth_params,
)
# When
response = client.delete_connection(Name=conn_name)
# Then
expected_arn = f"arn:aws:events:eu-central-1:{ACCOUNT_ID}:connection/{conn_name}/"
assert response["ConnectionArn"] == created_connection["ConnectionArn"]
assert response["ConnectionState"] == created_connection["ConnectionState"]
assert response["CreationTime"] == created_connection["CreationTime"]
with pytest.raises(ClientError):
_ = client.describe_connection(Name=conn_name)
@mock_events
def test_delete_connection_not_present():
conn_name = "test_conn_name"
client = boto3.client("events", "eu-central-1")
# When/Then
with pytest.raises(ClientError):
_ = client.delete_connection(Name=conn_name)

View File

@ -1,10 +1,11 @@
import os
import time
from unittest import SkipTest
import boto3
from botocore.exceptions import ClientError
import pytest
import sure # noqa
from botocore.exceptions import ClientError
from moto import mock_logs, settings
@ -12,14 +13,34 @@ _logs_region = "us-east-1" if settings.TEST_SERVER_MODE else "us-west-2"
@mock_logs
def test_create_log_group():
@pytest.mark.parametrize(
"kms_key_id",
[
"arn:aws:kms:us-east-1:000000000000:key/51d81fab-b138-4bd2-8a09-07fd6d37224d",
None,
],
)
def test_create_log_group(kms_key_id):
# Given
conn = boto3.client("logs", "us-west-2")
response = conn.create_log_group(logGroupName="dummy")
create_logs_params = dict(logGroupName="dummy")
if kms_key_id:
create_logs_params["kmsKeyId"] = kms_key_id
# When
response = conn.create_log_group(**create_logs_params)
response = conn.describe_log_groups()
# Then
response["logGroups"].should.have.length_of(1)
response["logGroups"][0].should_not.have.key("retentionInDays")
log_group = response["logGroups"][0]
log_group.should_not.have.key("retentionInDays")
if kms_key_id:
log_group.should.have.key("kmsKeyId")
log_group["kmsKeyId"].should.equal(kms_key_id)
@mock_logs

View File

@ -0,0 +1,25 @@
import sure # noqa
from moto.logs.models import LogGroup
def test_log_group_to_describe_dict():
# Given
region = "us-east-1"
name = "test-log-group"
tags = {"TestTag": "TestValue"}
kms_key_id = (
"arn:aws:kms:us-east-1:000000000000:key/51d81fab-b138-4bd2-8a09-07fd6d37224d"
)
kwargs = dict(kmsKeyId=kms_key_id,)
# When
log_group = LogGroup(region, name, tags, **kwargs)
describe_dict = log_group.to_describe_dict()
# Then
expected_dict = dict(logGroupName=name, kmsKeyId=kms_key_id)
for attr, value in expected_dict.items():
describe_dict.should.have.key(attr)
describe_dict[attr].should.equal(value)

View File

@ -4972,7 +4972,9 @@ def test_encryption():
resp = conn.get_bucket_encryption(Bucket="mybucket")
assert "ServerSideEncryptionConfiguration" in resp
assert resp["ServerSideEncryptionConfiguration"] == sse_config
return_config = sse_config.copy()
return_config["Rules"][0]["BucketKeyEnabled"] = False
assert resp["ServerSideEncryptionConfiguration"].should.equal(return_config)
conn.delete_bucket_encryption(Bucket="mybucket")
with pytest.raises(ClientError) as exc:

View File

@ -9,7 +9,6 @@ from freezegun import freeze_time
import pytest
import sure # noqa
from moto import mock_sts, mock_sts_deprecated, mock_iam, settings
from moto.core import ACCOUNT_ID
from moto.sts.responses import MAX_FEDERATION_TOKEN_POLICY_LENGTH
@ -467,17 +466,17 @@ def test_assume_role_with_saml_should_retrieve_attribute_value_from_text_when_xm
<saml:AttributeStatement>
<saml:Attribute Name="https://aws.amazon.com/SAML/Attributes/RoleSessionName">
<saml:AttributeValue xmlns:xs="http://www.w3.org/2001/XMLSchema"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:type="xs:string">{fed_name}</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="https://aws.amazon.com/SAML/Attributes/Role">
<saml:AttributeValue xmlns:xs="http://www.w3.org/2001/XMLSchema"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:type="xs:string">arn:aws:iam::{account_id}:saml-provider/{provider_name},arn:aws:iam::{account_id}:role/{role_name}</saml:AttributeValue>
</saml:Attribute>
<saml:Attribute Name="https://aws.amazon.com/SAML/Attributes/SessionDuration">
<saml:AttributeValue xmlns:xs="http://www.w3.org/2001/XMLSchema"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:type="xs:string">900</saml:AttributeValue>
</saml:Attribute>
</saml:AttributeStatement>
@ -749,3 +748,61 @@ def test_sts_regions(region):
client = boto3.client("sts", region_name=region)
resp = client.get_caller_identity()
resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
@mock_sts
@mock_iam
def test_get_caller_identity_with_iam_user_credentials():
iam_client = boto3.client("iam", region_name="us-east-1")
iam_user_name = "new-user"
iam_user = iam_client.create_user(UserName=iam_user_name)["User"]
access_key = iam_client.create_access_key(UserName=iam_user_name)["AccessKey"]
identity = boto3.client(
"sts",
region_name="us-east-1",
aws_access_key_id=access_key["AccessKeyId"],
aws_secret_access_key=access_key["SecretAccessKey"],
).get_caller_identity()
identity["Arn"].should.equal(iam_user["Arn"])
identity["UserId"].should.equal(iam_user["UserId"])
identity["Account"].should.equal(str(ACCOUNT_ID))
@mock_sts
@mock_iam
def test_get_caller_identity_with_assumed_role_credentials():
iam_client = boto3.client("iam", region_name="us-east-1")
sts_client = boto3.client("sts", region_name="us-east-1")
iam_role_name = "new-user"
trust_policy_document = {
"Version": "2012-10-17",
"Statement": {
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID)
},
"Action": "sts:AssumeRole",
},
}
iam_role_arn = iam_client.role_arn = iam_client.create_role(
RoleName=iam_role_name,
AssumeRolePolicyDocument=json.dumps(trust_policy_document),
)["Role"]["Arn"]
session_name = "new-session"
assumed_role = sts_client.assume_role(
RoleArn=iam_role_arn, RoleSessionName=session_name
)
access_key = assumed_role["Credentials"]
identity = boto3.client(
"sts",
region_name="us-east-1",
aws_access_key_id=access_key["AccessKeyId"],
aws_secret_access_key=access_key["SecretAccessKey"],
).get_caller_identity()
identity["Arn"].should.equal(assumed_role["AssumedRoleUser"]["Arn"])
identity["UserId"].should.equal(assumed_role["AssumedRoleUser"]["AssumedRoleId"])
identity["Account"].should.equal(str(ACCOUNT_ID))

View File

@ -9,7 +9,7 @@ from moto import mock_support
@mock_support
def test_describe_trusted_advisor_checks_returns_amount_of_checks():
"""
test that the 104 checks that are listed under trusted advisor currently
test that the 104 checks that are listed under trusted advisor currently
are returned
"""
client = boto3.client("support", "us-east-1")
@ -49,7 +49,7 @@ def test_describe_trusted_advisor_checks_returns_an_expected_check_name():
@mock_support
def test_refresh_trusted_advisor_check_returns_expected_check():
"""
A refresh of a trusted advisor check returns the check id
A refresh of a trusted advisor check returns the check id
in the response
"""
client = boto3.client("support", "us-east-1")
@ -627,7 +627,7 @@ def test_support_created_case_can_be_described_and_contains_communications_when_
@mock_support
def test_support_created_case_can_be_described_and_does_not_contain_communications_when_false():
"""
On creating a support request it does not include
On creating a support request it does not include
comms when includeCommunications=False
"""
@ -665,7 +665,7 @@ def test_support_created_case_can_be_described_and_does_not_contain_communicatio
@mock_support
def test_support_created_case_can_be_described_and_contains_resolved_cases_when_true():
"""
On creating a support request it does contain resolved cases when
On creating a support request it does contain resolved cases when
includeResolvedCases=true
"""
@ -705,7 +705,7 @@ def test_support_created_case_can_be_described_and_contains_resolved_cases_when_
@mock_support
def test_support_created_case_can_be_described_and_does_not_contain_resolved_cases_when_false():
"""
On creating a support request it does not contain resolved cases when
On creating a support request it does not contain resolved cases when
includeResolvedCases=false
"""