Merge LocalStack changes into upstream moto (#4082)

* fix OPTIONS requests on non-existing API GW integrations

* add cloudformation models for API Gateway deployments

* bump version

* add backdoor to return CloudWatch metrics

* Updating implementation coverage

* Updating implementation coverage

* add cloudformation models for API Gateway deployments

* Updating implementation coverage

* Updating implementation coverage

* Implemented get-caller-identity returning real data depending on the access key used.

* bump version

* minor fixes

* fix Number data_type for SQS message attribute

* fix handling of encoding errors

* bump version

* make CF stack queryable before starting to initialize its resources

* bump version

* fix integration_method for API GW method integrations

* fix undefined status in CF FakeStack

* Fix apigateway issues with terraform v0.12.21
* resource_methods -> add handle for "DELETE" method
* integrations -> fix issue that "httpMethod" wasn't included in body request (this value was set as the value from refer method resource)

* bump version

* Fix setting http method for API gateway integrations (#6)

* bump version

* remove duplicate methods

* add storage class to S3 Key when completing multipart upload (#7)

* fix SQS performance issues; bump version

* add pagination to SecretsManager list-secrets (#9)

* fix default parameter groups in RDS

* fix adding S3 metadata headers with names containing dots (#13)

* Updating implementation coverage

* Updating implementation coverage

* add cloudformation models for API Gateway deployments

* Updating implementation coverage

* Updating implementation coverage

* Implemented get-caller-identity returning real data depending on the access key used.

* make CF stack queryable before starting to initialize its resources

* bump version

* remove duplicate methods

* fix adding S3 metadata headers with names containing dots (#13)

* Update amis.json to support EKS AMI mocks (#15)

* fix PascalCase for boolean value in ListMultipartUploads response (#17); fix _get_multi_param to parse nested list/dict query params

* determine non-zero container exit code in Batch API

* support filtering by dimensions in CW get_metric_statistics

* fix storing attributes for ELBv2 Route entities; API GW refactorings for TF tests

* add missing fields for API GW resources

* fix error messages for Route53 (TF-compat)

* various fixes for IAM resources (tf-compat)

* minor fixes for API GW models (tf-compat)

* minor fixes for API GW responses (tf-compat)

* add s3 exception for bucket notification filter rule validation

* change the way RESTErrors generate the response body and content-type header

* fix lint errors and disable "black" syntax enforcement

* remove return type hint in RESTError.get_body

* add RESTError XML template for IAM exceptions

* add support for API GW minimumCompressionSize

* fix casing getting PrivateDnsEnabled API GW attribute

* minor fixes for error responses

* fix escaping special chars for IAM role descriptions (tf-compat)

* minor fixes and tagging support for API GW and ELB v2 (tf-compat)

* Merge branch 'master' into localstack

* add "AlarmRule" attribute to enable support for composite CloudWatch metrics

* fix recursive parsing of complex/nested query params

* bump version

* add API to delete S3 website configurations (#18)

* use dict copy to allow parallelism and avoid concurrent modification exceptions in S3

* fix precondition check for etags in S3 (#19)

* minor fix for user filtering in Cognito

* fix API Gateway error response; avoid returning empty response templates (tf-compat)

* support tags and tracingEnabled attribute for API GW stages

* fix boolean value in S3 encryption response (#20)

* fix connection arn structure

* fix api destination arn structure

* black format

* release 2.0.3.37

* fix s3 exception tests

see botocore/parsers.py:1002 where RequestId is removed from parsed

* remove python 2 from build action

* add test failure annotations in build action

* fix events test arn comparisons

* fix s3 encryption response test

* return default value "0" if EC2 availableIpAddressCount is empty

* fix extracting SecurityGroupIds for EC2 VPC endpoints

* support deleting/updating API Gateway DomainNames

* fix(events): Return empty string instead of null when no pattern is specified in EventPattern (tf-compat) (#22)

* fix logic and revert CF changes to get tests running again (#21)

* add support for EC2 customer gateway API (#25)

* add support for EC2 Transit Gateway APIs (#24)

* feat(logs): add `kmsKeyId` into `LogGroup` entity (#23)

* minor change in ELBv2 logic to fix tests

* feat(events): add APIs to describe and delete CloudWatch Events connections (#26)

* add support for EC2 transit gateway route tables (#27)

* pass transit gateway route table ID in Describe API, minor refactoring (#29)

* add support for EC2 Transit Gateway Routes (#28)

* fix region on ACM certificate import (#31)

* add support for EC2 transit gateway attachments (#30)

* add support for EC2 Transit Gateway VPN attachments (#32)

* fix account ID for logs API

* add support for DeleteOrganization API

* feat(events): store raw filter representation for CloudWatch events patterns (tf-compat) (#36)

* feat(events): add support to describe/update/delete CloudWatch API destinations (#35)

* add Cognito UpdateIdentityPool, CW Logs PutResourcePolicy

* feat(events): add support for tags in EventBus API (#38)

* fix parameter validation for Batch compute environments (tf-compat)

* revert merge conflicts in IMPLEMENTATION_COVERAGE.md

* format code using black

* restore original README; re-enable and fix CloudFormation tests

* restore tests and old logic for CF stack parameters from SSM

* parameterize RequestId/RequestID in response messages and revert related test changes

* undo LocalStack-specific adaptations

* minor fix

* Update CodeCov config to reflect removal of Py2

* undo change related to CW metric filtering; add additional test for CW metric statistics with dimensions

* Terraform - Extend whitelist of running tests

Co-authored-by: acsbendi <acsbendi28@gmail.com>
Co-authored-by: Phan Duong <duongpv@outlook.com>
Co-authored-by: Thomas Rausch <thomas@thrau.at>
Co-authored-by: Macwan Nevil <macnev2013@gmail.com>
Co-authored-by: Dominik Schubert <dominik.schubert91@gmail.com>
Co-authored-by: Gonzalo Saad <saad.gonzalo.ale@gmail.com>
Co-authored-by: Mohit Alonja <monty16597@users.noreply.github.com>
Co-authored-by: Miguel Gagliardo <migag9@gmail.com>
Co-authored-by: Bert Blommers <info@bertblommers.nl>
This commit is contained in:
Waldemar Hummer 2021-07-26 16:21:17 +02:00 committed by GitHub
parent 7693d77333
commit f4f8527955
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
93 changed files with 2947 additions and 631 deletions

View File

@ -90,7 +90,7 @@ jobs:
with: with:
path: ${{ steps.pip-cache.outputs.dir }} path: ${{ steps.pip-cache.outputs.dir }}
key: pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}-4 key: pip-${{ matrix.python-version }}-${{ hashFiles('**/setup.py') }}-4
# Update PIP - recent version does not support PY2 though # Update PIP
- name: Update pip - name: Update pip
run: | run: |
# https://github.com/pypa/pip/issues/10201 # https://github.com/pypa/pip/issues/10201
@ -136,6 +136,7 @@ jobs:
run: | run: |
pip install -r requirements-dev.txt pip install -r requirements-dev.txt
pip install pytest-cov pip install pytest-cov
pip install pytest-github-actions-annotate-failures
- name: Test with pytest - name: Test with pytest
run: | run: |
make test-coverage make test-coverage
@ -229,7 +230,7 @@ jobs:
- name: Run Terraform Tests - name: Run Terraform Tests
run: | run: |
cd moto-terraform-tests cd moto-terraform-tests
bin/run-tests -i ../tests/terraform-tests.success.txt bin/run-tests -i ../tests/terraform-tests.success.txt -e ../tests/terraform-tests.failures.txt
cd .. cd ..
- name: "Create report" - name: "Create report"
run: | run: |

2
.gitignore vendored
View File

@ -1,4 +1,4 @@
moto.egg-info/* moto*.egg-info/*
dist/* dist/*
.cache .cache
.tox .tox

View File

@ -3,7 +3,6 @@ SHELL := /bin/bash
ifeq ($(TEST_SERVER_MODE), true) ifeq ($(TEST_SERVER_MODE), true)
# exclude test_kinesisvideoarchivedmedia # exclude test_kinesisvideoarchivedmedia
# because testing with moto_server is difficult with data-endpoint # because testing with moto_server is difficult with data-endpoint
TEST_EXCLUDE := -k 'not test_kinesisvideoarchivedmedia' TEST_EXCLUDE := -k 'not test_kinesisvideoarchivedmedia'
else else
TEST_EXCLUDE := TEST_EXCLUDE :=

View File

@ -448,11 +448,13 @@ class AWSCertificateManagerBackend(BaseBackend):
else: else:
# Will reuse provided ARN # Will reuse provided ARN
bundle = CertBundle( bundle = CertBundle(
certificate, private_key, chain=chain, region=region, arn=arn certificate, private_key, chain=chain, region=self.region, arn=arn
) )
else: else:
# Will generate a random ARN # Will generate a random ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region) bundle = CertBundle(
certificate, private_key, chain=chain, region=self.region
)
self._certificates[bundle.arn] = bundle self._certificates[bundle.arn] = bundle

View File

@ -1,8 +1,16 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.exceptions import RESTError from moto.core.exceptions import JsonRESTError
class BadRequestException(RESTError): class BadRequestException(JsonRESTError):
pass
class NotFoundException(JsonRESTError):
pass
class AccessDeniedException(JsonRESTError):
pass pass
@ -14,7 +22,7 @@ class AwsProxyNotAllowed(BadRequestException):
) )
class CrossAccountNotAllowed(RESTError): class CrossAccountNotAllowed(AccessDeniedException):
def __init__(self): def __init__(self):
super(CrossAccountNotAllowed, self).__init__( super(CrossAccountNotAllowed, self).__init__(
"AccessDeniedException", "Cross-account pass role is not allowed." "AccessDeniedException", "Cross-account pass role is not allowed."
@ -71,10 +79,19 @@ class InvalidRequestInput(BadRequestException):
) )
class NoIntegrationDefined(BadRequestException): class NoIntegrationDefined(NotFoundException):
def __init__(self): def __init__(self):
super(NoIntegrationDefined, self).__init__( super(NoIntegrationDefined, self).__init__(
"BadRequestException", "No integration defined for method" "NotFoundException", "No integration defined for method"
)
class NoIntegrationResponseDefined(NotFoundException):
code = 404
def __init__(self, code=None):
super(NoIntegrationResponseDefined, self).__init__(
"NotFoundException", "No integration defined for method, code '%s'" % code
) )
@ -85,7 +102,7 @@ class NoMethodDefined(BadRequestException):
) )
class AuthorizerNotFoundException(RESTError): class AuthorizerNotFoundException(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
@ -94,7 +111,7 @@ class AuthorizerNotFoundException(RESTError):
) )
class StageNotFoundException(RESTError): class StageNotFoundException(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
@ -103,7 +120,7 @@ class StageNotFoundException(RESTError):
) )
class ApiKeyNotFoundException(RESTError): class ApiKeyNotFoundException(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
@ -112,7 +129,7 @@ class ApiKeyNotFoundException(RESTError):
) )
class UsagePlanNotFoundException(RESTError): class UsagePlanNotFoundException(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
@ -121,7 +138,7 @@ class UsagePlanNotFoundException(RESTError):
) )
class ApiKeyAlreadyExists(RESTError): class ApiKeyAlreadyExists(JsonRESTError):
code = 409 code = 409
def __init__(self): def __init__(self):
@ -139,7 +156,7 @@ class InvalidDomainName(BadRequestException):
) )
class DomainNameNotFound(RESTError): class DomainNameNotFound(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
@ -166,7 +183,7 @@ class InvalidModelName(BadRequestException):
) )
class RestAPINotFound(RESTError): class RestAPINotFound(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
@ -175,7 +192,7 @@ class RestAPINotFound(RESTError):
) )
class ModelNotFound(RESTError): class ModelNotFound(NotFoundException):
code = 404 code = 404
def __init__(self): def __init__(self):
@ -184,10 +201,19 @@ class ModelNotFound(RESTError):
) )
class ApiKeyValueMinLength(RESTError): class ApiKeyValueMinLength(BadRequestException):
code = 400 code = 400
def __init__(self): def __init__(self):
super(ApiKeyValueMinLength, self).__init__( super(ApiKeyValueMinLength, self).__init__(
"BadRequestException", "API Key value should be at least 20 characters" "BadRequestException", "API Key value should be at least 20 characters"
) )
class MethodNotFoundException(NotFoundException):
code = 404
def __init__(self):
super(MethodNotFoundException, self).__init__(
"NotFoundException", "Invalid method properties specified"
)

View File

@ -33,6 +33,7 @@ from .exceptions import (
StageNotFoundException, StageNotFoundException,
RoleNotSpecified, RoleNotSpecified,
NoIntegrationDefined, NoIntegrationDefined,
NoIntegrationResponseDefined,
NoMethodDefined, NoMethodDefined,
ApiKeyAlreadyExists, ApiKeyAlreadyExists,
DomainNameNotFound, DomainNameNotFound,
@ -44,6 +45,7 @@ from .exceptions import (
ApiKeyValueMinLength, ApiKeyValueMinLength,
) )
from ..core.models import responses_mock from ..core.models import responses_mock
from moto.apigateway.exceptions import MethodNotFoundException
STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}" STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}"
@ -87,7 +89,12 @@ class IntegrationResponse(BaseModel, dict):
content_handling=None, content_handling=None,
): ):
if response_templates is None: if response_templates is None:
response_templates = {"application/json": None} # response_templates = {"application/json": None} # Note: removed for compatibility with TF
response_templates = {}
for key in response_templates.keys():
response_templates[key] = (
response_templates[key] or None
) # required for compatibility with TF
self["responseTemplates"] = response_templates self["responseTemplates"] = response_templates
self["statusCode"] = status_code self["statusCode"] = status_code
if selection_pattern: if selection_pattern:
@ -97,13 +104,26 @@ class IntegrationResponse(BaseModel, dict):
class Integration(BaseModel, dict): class Integration(BaseModel, dict):
def __init__(self, integration_type, uri, http_method, request_templates=None): def __init__(
self,
integration_type,
uri,
http_method,
request_templates=None,
tls_config=None,
cache_namespace=None,
):
super(Integration, self).__init__() super(Integration, self).__init__()
self["type"] = integration_type self["type"] = integration_type
self["uri"] = uri self["uri"] = uri
self["httpMethod"] = http_method self["httpMethod"] = http_method
self["requestTemplates"] = request_templates self["requestTemplates"] = request_templates
self["integrationResponses"] = {"200": IntegrationResponse(200)} # self["integrationResponses"] = {"200": IntegrationResponse(200)} # commented out (tf-compat)
self[
"integrationResponses"
] = None # prevent json serialization from including them if none provided
self["tlsConfig"] = tls_config
self["cacheNamespace"] = cache_namespace
def create_integration_response( def create_integration_response(
self, status_code, selection_pattern, response_templates, content_handling self, status_code, selection_pattern, response_templates, content_handling
@ -113,20 +133,27 @@ class Integration(BaseModel, dict):
integration_response = IntegrationResponse( integration_response = IntegrationResponse(
status_code, selection_pattern, response_templates, content_handling status_code, selection_pattern, response_templates, content_handling
) )
if self.get("integrationResponses") is None:
self["integrationResponses"] = {}
self["integrationResponses"][status_code] = integration_response self["integrationResponses"][status_code] = integration_response
return integration_response return integration_response
def get_integration_response(self, status_code): def get_integration_response(self, status_code):
return self["integrationResponses"][status_code] result = self.get("integrationResponses", {}).get(status_code)
if not result:
raise NoIntegrationResponseDefined(status_code)
return result
def delete_integration_response(self, status_code): def delete_integration_response(self, status_code):
return self["integrationResponses"].pop(status_code) return self.get("integrationResponses", {}).pop(status_code, None)
class MethodResponse(BaseModel, dict): class MethodResponse(BaseModel, dict):
def __init__(self, status_code): def __init__(self, status_code, response_models=None, response_parameters=None):
super(MethodResponse, self).__init__() super(MethodResponse, self).__init__()
self["statusCode"] = status_code self["statusCode"] = status_code
self["responseModels"] = response_models
self["responseParameters"] = response_parameters
class Method(CloudFormationModel, dict): class Method(CloudFormationModel, dict):
@ -136,11 +163,14 @@ class Method(CloudFormationModel, dict):
dict( dict(
httpMethod=method_type, httpMethod=method_type,
authorizationType=authorization_type, authorizationType=authorization_type,
authorizerId=None, authorizerId=kwargs.get("authorizer_id"),
authorizationScopes=kwargs.get("authorization_scopes"),
apiKeyRequired=kwargs.get("api_key_required") or False, apiKeyRequired=kwargs.get("api_key_required") or False,
requestParameters=None, requestParameters=None,
requestModels=None, requestModels=kwargs.get("request_models"),
methodIntegration=None, methodIntegration=None,
operationName=kwargs.get("operation_name"),
requestValidatorId=kwargs.get("request_validator_id"),
) )
) )
self.method_responses = {} self.method_responses = {}
@ -184,16 +214,18 @@ class Method(CloudFormationModel, dict):
) )
return m return m
def create_response(self, response_code): def create_response(self, response_code, response_models, response_parameters):
method_response = MethodResponse(response_code) method_response = MethodResponse(
response_code, response_models, response_parameters
)
self.method_responses[response_code] = method_response self.method_responses[response_code] = method_response
return method_response return method_response
def get_response(self, response_code): def get_response(self, response_code):
return self.method_responses[response_code] return self.method_responses.get(response_code)
def delete_response(self, response_code): def delete_response(self, response_code):
return self.method_responses.pop(response_code) return self.method_responses.pop(response_code, None)
class Resource(CloudFormationModel): class Resource(CloudFormationModel):
@ -279,29 +311,62 @@ class Resource(CloudFormationModel):
) )
return response.status_code, response.text return response.status_code, response.text
def add_method(self, method_type, authorization_type, api_key_required): def add_method(
self,
method_type,
authorization_type,
api_key_required,
request_models=None,
operation_name=None,
authorizer_id=None,
authorization_scopes=None,
request_validator_id=None,
):
if authorization_scopes and not isinstance(authorization_scopes, list):
authorization_scopes = [authorization_scopes]
method = Method( method = Method(
method_type=method_type, method_type=method_type,
authorization_type=authorization_type, authorization_type=authorization_type,
api_key_required=api_key_required, api_key_required=api_key_required,
request_models=request_models,
operation_name=operation_name,
authorizer_id=authorizer_id,
authorization_scopes=authorization_scopes,
request_validator_id=request_validator_id,
) )
self.resource_methods[method_type] = method self.resource_methods[method_type] = method
return method return method
def get_method(self, method_type): def get_method(self, method_type):
return self.resource_methods[method_type] method = self.resource_methods.get(method_type)
if not method:
raise MethodNotFoundException()
return method
def add_integration( def add_integration(
self, method_type, integration_type, uri, request_templates=None self,
method_type,
integration_type,
uri,
request_templates=None,
integration_method=None,
tls_config=None,
cache_namespace=None,
): ):
integration_method = integration_method or method_type
integration = Integration( integration = Integration(
integration_type, uri, method_type, request_templates=request_templates integration_type,
uri,
integration_method,
request_templates=request_templates,
tls_config=tls_config,
cache_namespace=cache_namespace,
) )
self.resource_methods[method_type]["methodIntegration"] = integration self.resource_methods[method_type]["methodIntegration"] = integration
return integration return integration
def get_integration(self, method_type): def get_integration(self, method_type):
return self.resource_methods[method_type]["methodIntegration"] return self.resource_methods.get(method_type, {}).get("methodIntegration", {})
def delete_integration(self, method_type): def delete_integration(self, method_type):
return self.resource_methods[method_type].pop("methodIntegration") return self.resource_methods[method_type].pop("methodIntegration")
@ -364,6 +429,8 @@ class Stage(BaseModel, dict):
description="", description="",
cacheClusterEnabled=False, cacheClusterEnabled=False,
cacheClusterSize=None, cacheClusterSize=None,
tags=None,
tracing_enabled=None,
): ):
super(Stage, self).__init__() super(Stage, self).__init__()
if variables is None: if variables is None:
@ -376,9 +443,12 @@ class Stage(BaseModel, dict):
self["cacheClusterEnabled"] = cacheClusterEnabled self["cacheClusterEnabled"] = cacheClusterEnabled
if self["cacheClusterEnabled"]: if self["cacheClusterEnabled"]:
self["cacheClusterSize"] = str(0.5) self["cacheClusterSize"] = str(0.5)
if cacheClusterSize is not None: if cacheClusterSize is not None:
self["cacheClusterSize"] = str(cacheClusterSize) self["cacheClusterSize"] = str(cacheClusterSize)
if tags is not None:
self["tags"] = tags
if tracing_enabled is not None:
self["tracingEnabled"] = tracing_enabled
def apply_operations(self, patch_operations): def apply_operations(self, patch_operations):
for op in patch_operations: for op in patch_operations:
@ -607,6 +677,7 @@ class RestAPI(CloudFormationModel):
self.disableExecuteApiEndpoint = ( self.disableExecuteApiEndpoint = (
kwargs.get("disableExecuteApiEndpoint") or False kwargs.get("disableExecuteApiEndpoint") or False
) )
self.minimum_compression_size = kwargs.get("minimum_compression_size")
self.deployments = {} self.deployments = {}
self.authorizers = {} self.authorizers = {}
self.stages = {} self.stages = {}
@ -624,12 +695,13 @@ class RestAPI(CloudFormationModel):
"description": self.description, "description": self.description,
"version": self.version, "version": self.version,
"binaryMediaTypes": self.binaryMediaTypes, "binaryMediaTypes": self.binaryMediaTypes,
"createdDate": int(time.time()), "createdDate": self.create_date,
"apiKeySource": self.api_key_source, "apiKeySource": self.api_key_source,
"endpointConfiguration": self.endpoint_configuration, "endpointConfiguration": self.endpoint_configuration,
"tags": self.tags, "tags": self.tags,
"policy": self.policy, "policy": self.policy,
"disableExecuteApiEndpoint": self.disableExecuteApiEndpoint, "disableExecuteApiEndpoint": self.disableExecuteApiEndpoint,
"minimumCompressionSize": self.minimum_compression_size,
} }
def apply_patch_operations(self, patch_operations): def apply_patch_operations(self, patch_operations):
@ -652,7 +724,10 @@ class RestAPI(CloudFormationModel):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "RootResourceId": if attribute_name == "RootResourceId":
return self.id for res_id, res_obj in self.resources.items():
if res_obj.path_part == "/" and not res_obj.parent_id:
return res_id
raise Exception("Unable to find root resource for API %s" % self)
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
@property @property
@ -787,6 +862,8 @@ class RestAPI(CloudFormationModel):
description="", description="",
cacheClusterEnabled=None, cacheClusterEnabled=None,
cacheClusterSize=None, cacheClusterSize=None,
tags=None,
tracing_enabled=None,
): ):
if variables is None: if variables is None:
variables = {} variables = {}
@ -797,6 +874,8 @@ class RestAPI(CloudFormationModel):
description=description, description=description,
cacheClusterSize=cacheClusterSize, cacheClusterSize=cacheClusterSize,
cacheClusterEnabled=cacheClusterEnabled, cacheClusterEnabled=cacheClusterEnabled,
tags=tags,
tracing_enabled=tracing_enabled,
) )
self.stages[name] = stage self.stages[name] = stage
self.update_integration_mocks(name) self.update_integration_mocks(name)
@ -835,8 +914,11 @@ class DomainName(BaseModel, dict):
def __init__(self, domain_name, **kwargs): def __init__(self, domain_name, **kwargs):
super(DomainName, self).__init__() super(DomainName, self).__init__()
self["domainName"] = domain_name self["domainName"] = domain_name
self["regionalDomainName"] = domain_name self["regionalDomainName"] = "d-%s.execute-api.%s.amazonaws.com" % (
self["distributionDomainName"] = domain_name create_id(),
kwargs.get("region_name") or "us-east-1",
)
self["distributionDomainName"] = "d%s.cloudfront.net" % create_id()
self["domainNameStatus"] = "AVAILABLE" self["domainNameStatus"] = "AVAILABLE"
self["domainNameStatusMessage"] = "Domain Name Available" self["domainNameStatusMessage"] = "Domain Name Available"
self["regionalHostedZoneId"] = "Z2FDTNDATAQYW2" self["regionalHostedZoneId"] = "Z2FDTNDATAQYW2"
@ -907,6 +989,7 @@ class APIGatewayBackend(BaseBackend):
endpoint_configuration=None, endpoint_configuration=None,
tags=None, tags=None,
policy=None, policy=None,
minimum_compression_size=None,
): ):
api_id = create_id() api_id = create_id()
rest_api = RestAPI( rest_api = RestAPI(
@ -918,6 +1001,7 @@ class APIGatewayBackend(BaseBackend):
endpoint_configuration=endpoint_configuration, endpoint_configuration=endpoint_configuration,
tags=tags, tags=tags,
policy=policy, policy=policy,
minimum_compression_size=minimum_compression_size,
) )
self.apis[api_id] = rest_api self.apis[api_id] = rest_api
return rest_api return rest_api
@ -974,13 +1058,30 @@ class APIGatewayBackend(BaseBackend):
method_type, method_type,
authorization_type, authorization_type,
api_key_required=None, api_key_required=None,
request_models=None,
operation_name=None,
authorizer_id=None,
authorization_scopes=None,
request_validator_id=None,
): ):
resource = self.get_resource(function_id, resource_id) resource = self.get_resource(function_id, resource_id)
method = resource.add_method( method = resource.add_method(
method_type, authorization_type, api_key_required=api_key_required method_type,
authorization_type,
api_key_required=api_key_required,
request_models=request_models,
operation_name=operation_name,
authorizer_id=authorizer_id,
authorization_scopes=authorization_scopes,
request_validator_id=request_validator_id,
) )
return method return method
def update_method(self, function_id, resource_id, method_type, patch_operations):
resource = self.get_resource(function_id, resource_id)
method = resource.get_method(method_type)
return method.apply_operations(patch_operations)
def get_authorizer(self, restapi_id, authorizer_id): def get_authorizer(self, restapi_id, authorizer_id):
api = self.get_rest_api(restapi_id) api = self.get_rest_api(restapi_id)
authorizer = api.authorizers.get(authorizer_id) authorizer = api.authorizers.get(authorizer_id)
@ -1026,7 +1127,6 @@ class APIGatewayBackend(BaseBackend):
stage = api.stages.get(stage_name) stage = api.stages.get(stage_name)
if stage is None: if stage is None:
raise StageNotFoundException() raise StageNotFoundException()
else:
return stage return stage
def get_stages(self, function_id): def get_stages(self, function_id):
@ -1042,6 +1142,8 @@ class APIGatewayBackend(BaseBackend):
description="", description="",
cacheClusterEnabled=None, cacheClusterEnabled=None,
cacheClusterSize=None, cacheClusterSize=None,
tags=None,
tracing_enabled=None,
): ):
if variables is None: if variables is None:
variables = {} variables = {}
@ -1053,6 +1155,8 @@ class APIGatewayBackend(BaseBackend):
description=description, description=description,
cacheClusterEnabled=cacheClusterEnabled, cacheClusterEnabled=cacheClusterEnabled,
cacheClusterSize=cacheClusterSize, cacheClusterSize=cacheClusterSize,
tags=tags,
tracing_enabled=tracing_enabled,
) )
return api.stages.get(stage_name) return api.stages.get(stage_name)
@ -1065,7 +1169,9 @@ class APIGatewayBackend(BaseBackend):
def delete_stage(self, function_id, stage_name): def delete_stage(self, function_id, stage_name):
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
del api.stages[stage_name] deleted = api.stages.pop(stage_name, None)
if not deleted:
raise StageNotFoundException()
def get_method_response(self, function_id, resource_id, method_type, response_code): def get_method_response(self, function_id, resource_id, method_type, response_code):
method = self.get_method(function_id, resource_id, method_type) method = self.get_method(function_id, resource_id, method_type)
@ -1073,10 +1179,26 @@ class APIGatewayBackend(BaseBackend):
return method_response return method_response
def create_method_response( def create_method_response(
self, function_id, resource_id, method_type, response_code self,
function_id,
resource_id,
method_type,
response_code,
response_models,
response_parameters,
): ):
method = self.get_method(function_id, resource_id, method_type) method = self.get_method(function_id, resource_id, method_type)
method_response = method.create_response(response_code) method_response = method.create_response(
response_code, response_models, response_parameters
)
return method_response
def update_method_response(
self, function_id, resource_id, method_type, response_code, patch_operations
):
method = self.get_method(function_id, resource_id, method_type)
method_response = method.get_response(response_code)
method_response.apply_operations(patch_operations)
return method_response return method_response
def delete_method_response( def delete_method_response(
@ -1096,6 +1218,8 @@ class APIGatewayBackend(BaseBackend):
integration_method=None, integration_method=None,
credentials=None, credentials=None,
request_templates=None, request_templates=None,
tls_config=None,
cache_namespace=None,
): ):
resource = self.get_resource(function_id, resource_id) resource = self.get_resource(function_id, resource_id)
if credentials and not re.match( if credentials and not re.match(
@ -1128,7 +1252,13 @@ class APIGatewayBackend(BaseBackend):
): ):
raise InvalidIntegrationArn() raise InvalidIntegrationArn()
integration = resource.add_integration( integration = resource.add_integration(
method_type, integration_type, uri, request_templates=request_templates method_type,
integration_type,
uri,
integration_method=integration_method,
request_templates=request_templates,
tls_config=tls_config,
cache_namespace=cache_namespace,
) )
return integration return integration
@ -1205,7 +1335,7 @@ class APIGatewayBackend(BaseBackend):
return api.delete_deployment(deployment_id) return api.delete_deployment(deployment_id)
def create_api_key(self, payload): def create_api_key(self, payload):
if payload.get("value") is not None: if payload.get("value"):
if len(payload.get("value", [])) < 20: if len(payload.get("value", [])) < 20:
raise ApiKeyValueMinLength() raise ApiKeyValueMinLength()
for api_key in self.get_api_keys(include_values=True): for api_key in self.get_api_keys(include_values=True):
@ -1229,7 +1359,9 @@ class APIGatewayBackend(BaseBackend):
return api_keys return api_keys
def get_api_key(self, api_key_id, include_value=False): def get_api_key(self, api_key_id, include_value=False):
api_key = self.keys[api_key_id] api_key = self.keys.get(api_key_id)
if not api_key:
raise ApiKeyNotFoundException()
if not include_value: if not include_value:
new_key = copy(api_key) new_key = copy(api_key)
@ -1322,7 +1454,7 @@ class APIGatewayBackend(BaseBackend):
def _uri_validator(self, uri): def _uri_validator(self, uri):
try: try:
result = urlparse(uri) result = urlparse(uri)
return all([result.scheme, result.netloc, result.path]) return all([result.scheme, result.netloc, result.path or "/"])
except Exception: except Exception:
return False return False
@ -1358,6 +1490,7 @@ class APIGatewayBackend(BaseBackend):
tags=tags, tags=tags,
security_policy=security_policy, security_policy=security_policy,
generate_cli_skeleton=generate_cli_skeleton, generate_cli_skeleton=generate_cli_skeleton,
region_name=self.region_name,
) )
self.domain_names[domain_name] = new_domain_name self.domain_names[domain_name] = new_domain_name
@ -1369,10 +1502,22 @@ class APIGatewayBackend(BaseBackend):
def get_domain_name(self, domain_name): def get_domain_name(self, domain_name):
domain_info = self.domain_names.get(domain_name) domain_info = self.domain_names.get(domain_name)
if domain_info is None: if domain_info is None:
raise DomainNameNotFound raise DomainNameNotFound()
else: else:
return self.domain_names[domain_name] return self.domain_names[domain_name]
def delete_domain_name(self, domain_name):
domain_info = self.domain_names.pop(domain_name, None)
if domain_info is None:
raise DomainNameNotFound()
def update_domain_name(self, domain_name, patch_operations):
domain_info = self.domain_names.get(domain_name)
if not domain_info:
raise DomainNameNotFound()
domain_info.apply_patch_operations(patch_operations)
return domain_info
def create_model( def create_model(
self, self,
rest_api_id, rest_api_id,

View File

@ -20,6 +20,8 @@ from .exceptions import (
ModelNotFound, ModelNotFound,
ApiKeyValueMinLength, ApiKeyValueMinLength,
InvalidRequestInput, InvalidRequestInput,
NoIntegrationDefined,
NotFoundException,
) )
API_KEY_SOURCES = ["AUTHORIZER", "HEADER"] API_KEY_SOURCES = ["AUTHORIZER", "HEADER"]
@ -29,9 +31,11 @@ ENDPOINT_CONFIGURATION_TYPES = ["PRIVATE", "EDGE", "REGIONAL"]
class APIGatewayResponse(BaseResponse): class APIGatewayResponse(BaseResponse):
def error(self, type_, message, status=400): def error(self, type_, message, status=400):
headers = self.response_headers or {}
headers["X-Amzn-Errortype"] = type_
return ( return (
status, status,
self.response_headers, headers,
json.dumps({"__type": type_, "message": message}), json.dumps({"__type": type_, "message": message}),
) )
@ -80,6 +84,7 @@ class APIGatewayResponse(BaseResponse):
endpoint_configuration = self._get_param("endpointConfiguration") endpoint_configuration = self._get_param("endpointConfiguration")
tags = self._get_param("tags") tags = self._get_param("tags")
policy = self._get_param("policy") policy = self._get_param("policy")
minimum_compression_size = self._get_param("minimumCompressionSize")
# Param validation # Param validation
response = self.__validate_api_key_source(api_key_source) response = self.__validate_api_key_source(api_key_source)
@ -97,6 +102,7 @@ class APIGatewayResponse(BaseResponse):
endpoint_configuration=endpoint_configuration, endpoint_configuration=endpoint_configuration,
tags=tags, tags=tags,
policy=policy, policy=policy,
minimum_compression_size=minimum_compression_size,
) )
return 200, {}, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
@ -162,9 +168,7 @@ class APIGatewayResponse(BaseResponse):
resource = self.backend.delete_resource(function_id, resource_id) resource = self.backend.delete_resource(function_id, resource_id)
return 200, {}, json.dumps(resource.to_dict()) return 200, {}, json.dumps(resource.to_dict())
except BadRequestException as e: except BadRequestException as e:
return self.error( return self.error("BadRequestException", e.message)
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
def resource_methods(self, request, full_url, headers): def resource_methods(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -179,15 +183,37 @@ class APIGatewayResponse(BaseResponse):
elif self.method == "PUT": elif self.method == "PUT":
authorization_type = self._get_param("authorizationType") authorization_type = self._get_param("authorizationType")
api_key_required = self._get_param("apiKeyRequired") api_key_required = self._get_param("apiKeyRequired")
request_models = self._get_param("requestModels")
operation_name = self._get_param("operationName")
authorizer_id = self._get_param("authorizerId")
authorization_scopes = self._get_param("authorizationScopes")
request_validator_id = self._get_param("requestValidatorId")
method = self.backend.create_method( method = self.backend.create_method(
function_id, function_id,
resource_id, resource_id,
method_type, method_type,
authorization_type, authorization_type,
api_key_required, api_key_required,
request_models=request_models,
operation_name=operation_name,
authorizer_id=authorizer_id,
authorization_scopes=authorization_scopes,
request_validator_id=request_validator_id,
) )
return 200, {}, json.dumps(method) return 200, {}, json.dumps(method)
elif self.method == "DELETE":
self.backend.delete_method(function_id, resource_id, method_type)
return 200, {}, ""
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
self.backend.update_method(
function_id, resource_id, method_type, patch_operations
)
return 200, {}, ""
def resource_method_responses(self, request, full_url, headers): def resource_method_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
@ -201,13 +227,27 @@ class APIGatewayResponse(BaseResponse):
function_id, resource_id, method_type, response_code function_id, resource_id, method_type, response_code
) )
elif self.method == "PUT": elif self.method == "PUT":
response_models = self._get_param("responseModels")
response_parameters = self._get_param("responseParameters")
method_response = self.backend.create_method_response( method_response = self.backend.create_method_response(
function_id, resource_id, method_type, response_code function_id,
resource_id,
method_type,
response_code,
response_models,
response_parameters,
) )
elif self.method == "DELETE": elif self.method == "DELETE":
method_response = self.backend.delete_method_response( method_response = self.backend.delete_method_response(
function_id, resource_id, method_type, response_code function_id, resource_id, method_type, response_code
) )
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
method_response = self.backend.update_method_response(
function_id, resource_id, method_type, response_code, patch_operations
)
else:
raise Exception('Unexpected HTTP method "%s"' % self.method)
return 200, {}, json.dumps(method_response) return 200, {}, json.dumps(method_response)
def restapis_authorizers(self, request, full_url, headers): def restapis_authorizers(self, request, full_url, headers):
@ -302,6 +342,8 @@ class APIGatewayResponse(BaseResponse):
description = self._get_param("description", if_none="") description = self._get_param("description", if_none="")
cacheClusterEnabled = self._get_param("cacheClusterEnabled", if_none=False) cacheClusterEnabled = self._get_param("cacheClusterEnabled", if_none=False)
cacheClusterSize = self._get_param("cacheClusterSize") cacheClusterSize = self._get_param("cacheClusterSize")
tags = self._get_param("tags")
tracing_enabled = self._get_param("tracingEnabled")
stage_response = self.backend.create_stage( stage_response = self.backend.create_stage(
function_id, function_id,
@ -311,6 +353,8 @@ class APIGatewayResponse(BaseResponse):
description=description, description=description,
cacheClusterEnabled=cacheClusterEnabled, cacheClusterEnabled=cacheClusterEnabled,
cacheClusterSize=cacheClusterSize, cacheClusterSize=cacheClusterSize,
tags=tags,
tracing_enabled=tracing_enabled,
) )
elif self.method == "GET": elif self.method == "GET":
stages = self.backend.get_stages(function_id) stages = self.backend.get_stages(function_id)
@ -353,6 +397,8 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6] method_type = url_path_parts[6]
try: try:
integration_response = {}
if self.method == "GET": if self.method == "GET":
integration_response = self.backend.get_integration( integration_response = self.backend.get_integration(
function_id, resource_id, method_type function_id, resource_id, method_type
@ -360,32 +406,39 @@ class APIGatewayResponse(BaseResponse):
elif self.method == "PUT": elif self.method == "PUT":
integration_type = self._get_param("type") integration_type = self._get_param("type")
uri = self._get_param("uri") uri = self._get_param("uri")
integration_http_method = self._get_param("httpMethod") credentials = self._get_param("credentials")
creds = self._get_param("credentials")
request_templates = self._get_param("requestTemplates") request_templates = self._get_param("requestTemplates")
tls_config = self._get_param("tlsConfig")
cache_namespace = self._get_param("cacheNamespace")
self.backend.get_method(function_id, resource_id, method_type)
integration_http_method = self._get_param(
"httpMethod"
) # default removed because it's a required parameter
integration_response = self.backend.create_integration( integration_response = self.backend.create_integration(
function_id, function_id,
resource_id, resource_id,
method_type, method_type,
integration_type, integration_type,
uri, uri,
credentials=creds, credentials=credentials,
integration_method=integration_http_method, integration_method=integration_http_method,
request_templates=request_templates, request_templates=request_templates,
tls_config=tls_config,
cache_namespace=cache_namespace,
) )
elif self.method == "DELETE": elif self.method == "DELETE":
integration_response = self.backend.delete_integration( integration_response = self.backend.delete_integration(
function_id, resource_id, method_type function_id, resource_id, method_type
) )
return 200, {}, json.dumps(integration_response) return 200, {}, json.dumps(integration_response)
except BadRequestException as e: except BadRequestException as e:
return self.error( return self.error("BadRequestException", e.message)
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message
)
except CrossAccountNotAllowed as e: except CrossAccountNotAllowed as e:
return self.error( return self.error("AccessDeniedException", e.message)
"com.amazonaws.dynamodb.v20111205#AccessDeniedException", e.message
)
def integration_responses(self, request, full_url, headers): def integration_responses(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -422,9 +475,9 @@ class APIGatewayResponse(BaseResponse):
) )
return 200, {}, json.dumps(integration_response) return 200, {}, json.dumps(integration_response)
except BadRequestException as e: except BadRequestException as e:
return self.error( return self.error("BadRequestException", e.message)
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message except NoIntegrationDefined as e:
) return self.error("NotFoundException", e.message)
def deployments(self, request, full_url, headers): def deployments(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -443,9 +496,9 @@ class APIGatewayResponse(BaseResponse):
) )
return 200, {}, json.dumps(deployment) return 200, {}, json.dumps(deployment)
except BadRequestException as e: except BadRequestException as e:
return self.error( return self.error("BadRequestException", e.message)
"com.amazonaws.dynamodb.v20111205#BadRequestException", e.message except NotFoundException as e:
) return self.error("NotFoundException", e.message)
def individual_deployment(self, request, full_url, headers): def individual_deployment(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -453,6 +506,7 @@ class APIGatewayResponse(BaseResponse):
function_id = url_path_parts[2] function_id = url_path_parts[2]
deployment_id = url_path_parts[4] deployment_id = url_path_parts[4]
deployment = None
if self.method == "GET": if self.method == "GET":
deployment = self.backend.get_deployment(function_id, deployment_id) deployment = self.backend.get_deployment(function_id, deployment_id)
elif self.method == "DELETE": elif self.method == "DELETE":
@ -652,6 +706,18 @@ class APIGatewayResponse(BaseResponse):
if self.method == "GET": if self.method == "GET":
if domain_name is not None: if domain_name is not None:
domain_names = self.backend.get_domain_name(domain_name) domain_names = self.backend.get_domain_name(domain_name)
elif self.method == "DELETE":
if domain_name is not None:
self.backend.delete_domain_name(domain_name)
elif self.method == "PATCH":
if domain_name is not None:
patch_operations = self._get_param("patchOperations")
self.backend.update_domain_name(domain_name, patch_operations)
else:
msg = (
'Method "%s" for API GW domain names not implemented' % self.method
)
return 404, {}, json.dumps({"error": msg})
return 200, {}, json.dumps(domain_names) return 200, {}, json.dumps(domain_names)
except DomainNameNotFound as error: except DomainNameNotFound as error:
return ( return (

View File

@ -155,7 +155,7 @@ class ApplicationAutoscalingBackend(BaseBackend):
service_namespace, resource_id, scalable_dimension, policy_name service_namespace, resource_id, scalable_dimension, policy_name
) )
if policy_key in self.policies: if policy_key in self.policies:
old_policy = self.policies[policy_name] old_policy = self.policies[policy_key]
policy = FakeApplicationAutoscalingPolicy( policy = FakeApplicationAutoscalingPolicy(
region_name=self.region, region_name=self.region,
policy_name=policy_name, policy_name=policy_name,

View File

@ -82,7 +82,7 @@ class AthenaResponse(BaseResponse):
def error(self, msg, status): def error(self, msg, status):
return ( return (
json.dumps({"__type": "InvalidRequestException", "Message": msg,}), json.dumps({"__type": "InvalidRequestException", "Message": msg}),
dict(status=status), dict(status=status),
) )

View File

@ -26,7 +26,7 @@ import requests.exceptions
from boto3 import Session from boto3 import Session
from moto.awslambda.policy import Policy from moto.awslambda.policy import Policy
from moto.core import BaseBackend, CloudFormationModel from moto.core import BaseBackend, BaseModel, CloudFormationModel
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
from moto.iam.models import iam_backend from moto.iam.models import iam_backend
from moto.iam.exceptions import IAMNotFoundException from moto.iam.exceptions import IAMNotFoundException
@ -1072,6 +1072,32 @@ class LayerStorage(object):
return None return None
class LambdaPermission(BaseModel):
def __init__(self, spec):
self.action = spec["Action"]
self.function_name = spec["FunctionName"]
self.principal = spec["Principal"]
# optional
self.source_account = spec.get("SourceAccount")
self.source_arn = spec.get("SourceArn")
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
spec = {
"Action": properties["Action"],
"FunctionName": properties["FunctionName"],
"Principal": properties["Principal"],
}
optional_properties = "SourceAccount SourceArn".split()
for prop in optional_properties:
if prop in properties:
spec[prop] = properties[prop]
return LambdaPermission(spec)
class LambdaBackend(BaseBackend): class LambdaBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name):
self._lambdas = LambdaStorage() self._lambdas = LambdaStorage()

View File

@ -16,7 +16,7 @@ from moto.ec2 import ec2_backends
from moto.ecs import ecs_backends from moto.ecs import ecs_backends
from moto.logs import logs_backends from moto.logs import logs_backends
from .exceptions import InvalidParameterValueException, InternalFailure, ClientException from .exceptions import InvalidParameterValueException, ClientException
from .utils import ( from .utils import (
make_arn_for_compute_env, make_arn_for_compute_env,
make_arn_for_job_queue, make_arn_for_job_queue,
@ -463,7 +463,8 @@ class Job(threading.Thread, BaseModel, DockerModel):
self.job_state = "STARTING" self.job_state = "STARTING"
log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON) log_config = docker.types.LogConfig(type=docker.types.LogConfig.types.JSON)
image_repository, image_tag = parse_image_ref(image) image_repository, image_tag = parse_image_ref(image)
self.docker_client.images.pull(image_repository, image_tag) # avoid explicit pulling here, to allow using cached images
# self.docker_client.images.pull(image_repository, image_tag)
container = self.docker_client.containers.run( container = self.docker_client.containers.run(
image, image,
cmd, cmd,
@ -479,6 +480,7 @@ class Job(threading.Thread, BaseModel, DockerModel):
container.reload() container.reload()
while container.status == "running" and not self.stop: while container.status == "running" and not self.stop:
container.reload() container.reload()
time.sleep(0.5)
# Container should be stopped by this point... unless asked to stop # Container should be stopped by this point... unless asked to stop
if container.status == "running": if container.status == "running":
@ -531,11 +533,9 @@ class Job(threading.Thread, BaseModel, DockerModel):
self._log_backend.create_log_stream(log_group, stream_name) self._log_backend.create_log_stream(log_group, stream_name)
self._log_backend.put_log_events(log_group, stream_name, logs, None) self._log_backend.put_log_events(log_group, stream_name, logs, None)
result = container.wait() result = container.wait() or {}
if self.stop or result["StatusCode"] != 0: job_failed = self.stop or result.get("StatusCode", 0) > 0
self._mark_stopped(success=False) self._mark_stopped(success=not job_failed)
else:
self._mark_stopped(success=True)
except Exception as err: except Exception as err:
logger.error( logger.error(
@ -782,6 +782,7 @@ class BatchBackend(BaseBackend):
"state": environment.state, "state": environment.state,
"type": environment.env_type, "type": environment.env_type,
"status": "VALID", "status": "VALID",
"statusReason": "Compute environment is available",
} }
if environment.env_type == "MANAGED": if environment.env_type == "MANAGED":
json_part["computeResources"] = environment.compute_resources json_part["computeResources"] = environment.compute_resources
@ -898,9 +899,10 @@ class BatchBackend(BaseBackend):
"type", "type",
): ):
if param not in cr: if param not in cr:
raise InvalidParameterValueException( pass # commenting out invalid check below - values may be missing (tf-compat)
"computeResources must contain {0}".format(param) # raise InvalidParameterValueException(
) # "computeResources must contain {0}".format(param)
# )
for profile in self.iam_backend.get_instance_profiles(): for profile in self.iam_backend.get_instance_profiles():
if profile.arn == cr["instanceRole"]: if profile.arn == cr["instanceRole"]:
break break
@ -955,9 +957,6 @@ class BatchBackend(BaseBackend):
"computeResources.type must be either EC2 | SPOT" "computeResources.type must be either EC2 | SPOT"
) )
if cr["type"] == "SPOT":
raise InternalFailure("SPOT NOT SUPPORTED YET")
@staticmethod @staticmethod
def find_min_instances_to_meet_vcpus(instance_types, target): def find_min_instances_to_meet_vcpus(instance_types, target):
""" """
@ -1027,6 +1026,7 @@ class BatchBackend(BaseBackend):
if compute_env.env_type == "MANAGED": if compute_env.env_type == "MANAGED":
# Delete compute environment # Delete compute environment
instance_ids = [instance.id for instance in compute_env.instances] instance_ids = [instance.id for instance in compute_env.instances]
if instance_ids:
self.ec2_backend.terminate_instances(instance_ids) self.ec2_backend.terminate_instances(instance_ids)
def update_compute_environment( def update_compute_environment(

View File

@ -505,9 +505,8 @@ class ResourceMap(collections_abc.Mapping):
self._parsed_resources.update(json.loads(key.value)) self._parsed_resources.update(json.loads(key.value))
def parse_ssm_parameter(self, value, value_type): def parse_ssm_parameter(self, value, value_type):
# The Value in SSM parameters is the SSM parameter path # The Value in SSM parameters is the SSM parameter path
# we need to use ssm_backend to retreive the # we need to use ssm_backend to retrieve the
# actual value from parameter store # actual value from parameter store
parameter = ssm_backends[self._region_name].get_parameter(value, False) parameter = ssm_backends[self._region_name].get_parameter(value, False)
actual_value = parameter.value actual_value = parameter.value

View File

@ -44,6 +44,8 @@ def yaml_tag_constructor(loader, tag, node):
def _f(loader, tag, node): def _f(loader, tag, node):
if tag == "!GetAtt": if tag == "!GetAtt":
if isinstance(node.value, list):
return node.value
return node.value.split(".") return node.value.split(".")
elif type(node) == yaml.SequenceNode: elif type(node) == yaml.SequenceNode:
return loader.construct_sequence(node) return loader.construct_sequence(node)

View File

@ -107,6 +107,7 @@ class FakeAlarm(BaseModel):
unit, unit,
actions_enabled, actions_enabled,
region="us-east-1", region="us-east-1",
rule=None,
): ):
self.name = name self.name = name
self.alarm_arn = make_arn_for_alarm(region, DEFAULT_ACCOUNT_ID, name) self.alarm_arn = make_arn_for_alarm(region, DEFAULT_ACCOUNT_ID, name)
@ -123,7 +124,7 @@ class FakeAlarm(BaseModel):
self.dimensions = [ self.dimensions = [
Dimension(dimension["name"], dimension["value"]) for dimension in dimensions Dimension(dimension["name"], dimension["value"]) for dimension in dimensions
] ]
self.actions_enabled = actions_enabled self.actions_enabled = True if actions_enabled is None else actions_enabled
self.alarm_actions = alarm_actions self.alarm_actions = alarm_actions
self.ok_actions = ok_actions self.ok_actions = ok_actions
self.insufficient_data_actions = insufficient_data_actions self.insufficient_data_actions = insufficient_data_actions
@ -137,6 +138,9 @@ class FakeAlarm(BaseModel):
self.state_value = "OK" self.state_value = "OK"
self.state_updated_timestamp = datetime.utcnow() self.state_updated_timestamp = datetime.utcnow()
# only used for composite alarms
self.rule = rule
def update_state(self, reason, reason_data, state_value): def update_state(self, reason, reason_data, state_value):
# History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action # History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action
self.history.append( self.history.append(
@ -156,6 +160,8 @@ class FakeAlarm(BaseModel):
def are_dimensions_same(metric_dimensions, dimensions): def are_dimensions_same(metric_dimensions, dimensions):
if len(metric_dimensions) != len(dimensions):
return False
for dimension in metric_dimensions: for dimension in metric_dimensions:
for new_dimension in dimensions: for new_dimension in dimensions:
if ( if (
@ -163,7 +169,6 @@ def are_dimensions_same(metric_dimensions, dimensions):
or dimension.value != new_dimension.value or dimension.value != new_dimension.value
): ):
return False return False
return True return True
@ -178,11 +183,12 @@ class MetricDatum(BaseModel):
] ]
self.unit = unit self.unit = unit
def filter(self, namespace, name, dimensions, already_present_metrics): def filter(self, namespace, name, dimensions, already_present_metrics=[]):
if namespace and namespace != self.namespace: if namespace and namespace != self.namespace:
return False return False
if name and name != self.name: if name and name != self.name:
return False return False
for metric in already_present_metrics: for metric in already_present_metrics:
if self.dimensions and are_dimensions_same( if self.dimensions and are_dimensions_same(
metric.dimensions, self.dimensions metric.dimensions, self.dimensions
@ -302,6 +308,7 @@ class CloudWatchBackend(BaseBackend):
unit, unit,
actions_enabled, actions_enabled,
region="us-east-1", region="us-east-1",
rule=None,
): ):
alarm = FakeAlarm( alarm = FakeAlarm(
name, name,
@ -322,6 +329,7 @@ class CloudWatchBackend(BaseBackend):
unit, unit,
actions_enabled, actions_enabled,
region, region,
rule=rule,
) )
self.alarms[name] = alarm self.alarms[name] = alarm
@ -451,7 +459,15 @@ class CloudWatchBackend(BaseBackend):
return results return results
def get_metric_statistics( def get_metric_statistics(
self, namespace, metric_name, start_time, end_time, period, stats, unit=None self,
namespace,
metric_name,
start_time,
end_time,
period,
stats,
unit=None,
dimensions=None,
): ):
period_delta = timedelta(seconds=period) period_delta = timedelta(seconds=period)
filtered_data = [ filtered_data = [
@ -464,6 +480,10 @@ class CloudWatchBackend(BaseBackend):
if unit: if unit:
filtered_data = [md for md in filtered_data if md.unit == unit] filtered_data = [md for md in filtered_data if md.unit == unit]
if dimensions:
filtered_data = [
md for md in filtered_data if md.filter(None, None, dimensions)
]
# earliest to oldest # earliest to oldest
filtered_data = sorted(filtered_data, key=lambda x: x.timestamp) filtered_data = sorted(filtered_data, key=lambda x: x.timestamp)

View File

@ -1,9 +1,11 @@
import json import json
from moto.core.utils import amzn_request_id
from moto.core.responses import BaseResponse
from .models import cloudwatch_backends, MetricDataQuery, MetricStat, Metric, Dimension
from dateutil.parser import parse as dtparse from dateutil.parser import parse as dtparse
from moto.core.responses import BaseResponse
from moto.core.utils import amzn_request_id
from .models import cloudwatch_backends, MetricDataQuery, MetricStat, Metric, Dimension
class CloudWatchResponse(BaseResponse): class CloudWatchResponse(BaseResponse):
@property @property
@ -19,34 +21,49 @@ class CloudWatchResponse(BaseResponse):
name = self._get_param("AlarmName") name = self._get_param("AlarmName")
namespace = self._get_param("Namespace") namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName") metric_name = self._get_param("MetricName")
metrics = self._get_multi_param("Metrics.member") metrics = self._get_multi_param("Metrics.member", skip_result_conversion=True)
metric_data_queries = None metric_data_queries = None
if metrics: if metrics:
metric_data_queries = [ metric_data_queries = []
for metric in metrics:
dimensions = []
dims = (
metric.get("MetricStat", {})
.get("Metric", {})
.get("Dimensions.member", [])
)
for dim in dims:
dimensions.append(
Dimension(name=dim.get("Name"), value=dim.get("Value"))
)
metric_stat = None
stat_metric_name = (
metric.get("MetricStat", {}).get("Metric", {}).get("MetricName")
)
if stat_metric_name:
stat_details = metric.get("MetricStat", {})
stat_metric_ns = stat_details.get("Metric", {}).get("Namespace")
metric_stat = MetricStat(
metric=Metric(
metric_name=stat_metric_name,
namespace=stat_metric_ns,
dimensions=dimensions,
),
period=stat_details.get("Period"),
stat=stat_details.get("Stat"),
unit=stat_details.get("Unit"),
)
metric_data_queries.append(
MetricDataQuery( MetricDataQuery(
id=metric.get("Id"), id=metric.get("Id"),
label=metric.get("Label"), label=metric.get("Label"),
period=metric.get("Period"), period=metric.get("Period"),
return_data=metric.get("ReturnData"), return_data=metric.get("ReturnData"),
expression=metric.get("Expression"), expression=metric.get("Expression"),
metric_stat=MetricStat( metric_stat=metric_stat,
metric=Metric(
metric_name=metric.get("MetricStat.Metric.MetricName"),
namespace=metric.get("MetricStat.Metric.Namespace"),
dimensions=[
Dimension(name=dim["Name"], value=dim["Value"])
for dim in metric["MetricStat.Metric.Dimensions.member"]
],
),
period=metric.get("MetricStat.Period"),
stat=metric.get("MetricStat.Stat"),
unit=metric.get("MetricStat.Unit"),
) )
if "MetricStat.Metric.MetricName" in metric
else None,
) )
for metric in metrics
]
comparison_operator = self._get_param("ComparisonOperator") comparison_operator = self._get_param("ComparisonOperator")
evaluation_periods = self._get_param("EvaluationPeriods") evaluation_periods = self._get_param("EvaluationPeriods")
datapoints_to_alarm = self._get_param("DatapointsToAlarm") datapoints_to_alarm = self._get_param("DatapointsToAlarm")
@ -62,6 +79,8 @@ class CloudWatchResponse(BaseResponse):
"InsufficientDataActions.member" "InsufficientDataActions.member"
) )
unit = self._get_param("Unit") unit = self._get_param("Unit")
# fetch AlarmRule to re-use this method for composite alarms as well
rule = self._get_param("AlarmRule")
alarm = self.cloudwatch_backend.put_metric_alarm( alarm = self.cloudwatch_backend.put_metric_alarm(
name, name,
namespace, namespace,
@ -81,6 +100,7 @@ class CloudWatchResponse(BaseResponse):
unit, unit,
actions_enabled, actions_enabled,
self.region, self.region,
rule=rule,
) )
template = self.response_template(PUT_METRIC_ALARM_TEMPLATE) template = self.response_template(PUT_METRIC_ALARM_TEMPLATE)
return template.render(alarm=alarm) return template.render(alarm=alarm)
@ -105,8 +125,13 @@ class CloudWatchResponse(BaseResponse):
else: else:
alarms = self.cloudwatch_backend.get_all_alarms() alarms = self.cloudwatch_backend.get_all_alarms()
metric_alarms = [a for a in alarms if a.rule is None]
composite_alarms = [a for a in alarms if a.rule is not None]
template = self.response_template(DESCRIBE_ALARMS_TEMPLATE) template = self.response_template(DESCRIBE_ALARMS_TEMPLATE)
return template.render(alarms=alarms) return template.render(
metric_alarms=metric_alarms, composite_alarms=composite_alarms
)
@amzn_request_id @amzn_request_id
def delete_alarms(self): def delete_alarms(self):
@ -145,12 +170,12 @@ class CloudWatchResponse(BaseResponse):
end_time = dtparse(self._get_param("EndTime")) end_time = dtparse(self._get_param("EndTime"))
period = int(self._get_param("Period")) period = int(self._get_param("Period"))
statistics = self._get_multi_param("Statistics.member") statistics = self._get_multi_param("Statistics.member")
dimensions = self._get_multi_param("Dimensions.member")
# Unsupported Parameters (To Be Implemented) # Unsupported Parameters (To Be Implemented)
unit = self._get_param("Unit") unit = self._get_param("Unit")
extended_statistics = self._get_param("ExtendedStatistics") extended_statistics = self._get_param("ExtendedStatistics")
dimensions = self._get_param("Dimensions") if extended_statistics:
if extended_statistics or dimensions:
raise NotImplementedError() raise NotImplementedError()
# TODO: this should instead throw InvalidParameterCombination # TODO: this should instead throw InvalidParameterCombination
@ -160,7 +185,14 @@ class CloudWatchResponse(BaseResponse):
) )
datapoints = self.cloudwatch_backend.get_metric_statistics( datapoints = self.cloudwatch_backend.get_metric_statistics(
namespace, metric_name, start_time, end_time, period, statistics, unit namespace,
metric_name,
start_time,
end_time,
period,
statistics,
unit,
dimensions=dimensions,
) )
template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE) template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE)
return template.render(label=metric_name, datapoints=datapoints) return template.render(label=metric_name, datapoints=datapoints)
@ -280,7 +312,8 @@ PUT_METRIC_ALARM_TEMPLATE = """<PutMetricAlarmResponse xmlns="http://monitoring.
DESCRIBE_ALARMS_TEMPLATE = """<DescribeAlarmsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/"> DESCRIBE_ALARMS_TEMPLATE = """<DescribeAlarmsResponse xmlns="http://monitoring.amazonaws.com/doc/2010-08-01/">
<DescribeAlarmsResult> <DescribeAlarmsResult>
<MetricAlarms> {% for tag_name, alarms in (('MetricAlarms', metric_alarms), ('CompositeAlarms', composite_alarms)) %}
<{{tag_name}}>
{% for alarm in alarms %} {% for alarm in alarms %}
<member> <member>
<ActionsEnabled>{{ alarm.actions_enabled }}</ActionsEnabled> <ActionsEnabled>{{ alarm.actions_enabled }}</ActionsEnabled>
@ -291,7 +324,7 @@ DESCRIBE_ALARMS_TEMPLATE = """<DescribeAlarmsResponse xmlns="http://monitoring.a
</AlarmActions> </AlarmActions>
<AlarmArn>{{ alarm.alarm_arn }}</AlarmArn> <AlarmArn>{{ alarm.alarm_arn }}</AlarmArn>
<AlarmConfigurationUpdatedTimestamp>{{ alarm.configuration_updated_timestamp }}</AlarmConfigurationUpdatedTimestamp> <AlarmConfigurationUpdatedTimestamp>{{ alarm.configuration_updated_timestamp }}</AlarmConfigurationUpdatedTimestamp>
<AlarmDescription>{{ alarm.description }}</AlarmDescription> <AlarmDescription>{{ alarm.description or '' }}</AlarmDescription>
<AlarmName>{{ alarm.name }}</AlarmName> <AlarmName>{{ alarm.name }}</AlarmName>
<ComparisonOperator>{{ alarm.comparison_operator }}</ComparisonOperator> <ComparisonOperator>{{ alarm.comparison_operator }}</ComparisonOperator>
{% if alarm.dimensions is not none %} {% if alarm.dimensions is not none %}
@ -376,13 +409,19 @@ DESCRIBE_ALARMS_TEMPLATE = """<DescribeAlarmsResponse xmlns="http://monitoring.a
{% if alarm.statistic is not none %} {% if alarm.statistic is not none %}
<Statistic>{{ alarm.statistic }}</Statistic> <Statistic>{{ alarm.statistic }}</Statistic>
{% endif %} {% endif %}
{% if alarm.threshold is not none %}
<Threshold>{{ alarm.threshold }}</Threshold> <Threshold>{{ alarm.threshold }}</Threshold>
{% endif %}
{% if alarm.unit is not none %} {% if alarm.unit is not none %}
<Unit>{{ alarm.unit }}</Unit> <Unit>{{ alarm.unit }}</Unit>
{% endif %} {% endif %}
{% if alarm.rule is not none %}
<AlarmRule>{{ alarm.rule }}</AlarmRule>
{% endif %}
</member> </member>
{% endfor %} {% endfor %}
</MetricAlarms> </{{tag_name}}>
{% endfor %}
</DescribeAlarmsResult> </DescribeAlarmsResult>
</DescribeAlarmsResponse>""" </DescribeAlarmsResponse>"""
@ -429,7 +468,9 @@ DESCRIBE_METRIC_ALARMS_TEMPLATE = """<DescribeAlarmsForMetricResponse xmlns="htt
<StateUpdatedTimestamp>{{ alarm.state_updated_timestamp }}</StateUpdatedTimestamp> <StateUpdatedTimestamp>{{ alarm.state_updated_timestamp }}</StateUpdatedTimestamp>
<StateValue>{{ alarm.state_value }}</StateValue> <StateValue>{{ alarm.state_value }}</StateValue>
<Statistic>{{ alarm.statistic }}</Statistic> <Statistic>{{ alarm.statistic }}</Statistic>
{% if alarm.threshold is not none %}
<Threshold>{{ alarm.threshold }}</Threshold> <Threshold>{{ alarm.threshold }}</Threshold>
{% endif %}
<Unit>{{ alarm.unit }}</Unit> <Unit>{{ alarm.unit }}</Unit>
</member> </member>
{% endfor %} {% endfor %}

View File

@ -29,6 +29,20 @@ class CognitoIdentity(BaseModel):
self.identity_pool_id = get_random_identity_id(region) self.identity_pool_id = get_random_identity_id(region)
self.creation_time = datetime.datetime.utcnow() self.creation_time = datetime.datetime.utcnow()
def to_json(self):
return json.dumps(
{
"IdentityPoolId": self.identity_pool_id,
"IdentityPoolName": self.identity_pool_name,
"AllowUnauthenticatedIdentities": self.allow_unauthenticated_identities,
"SupportedLoginProviders": self.supported_login_providers,
"DeveloperProviderName": self.developer_provider_name,
"OpenIdConnectProviderARNs": self.open_id_connect_provider_arns,
"CognitoIdentityProviders": self.cognito_identity_providers,
"SamlProviderARNs": self.saml_provider_arns,
}
)
class CognitoIdentityBackend(BaseBackend): class CognitoIdentityBackend(BaseBackend):
def __init__(self, region): def __init__(self, region):
@ -54,7 +68,7 @@ class CognitoIdentityBackend(BaseBackend):
"DeveloperProviderName": identity_pool.developer_provider_name, "DeveloperProviderName": identity_pool.developer_provider_name,
"IdentityPoolId": identity_pool.identity_pool_id, "IdentityPoolId": identity_pool.identity_pool_id,
"IdentityPoolName": identity_pool.identity_pool_name, "IdentityPoolName": identity_pool.identity_pool_name,
"IdentityPoolTags": {}, "IdentityPoolTags": {}, # TODO: add tags
"OpenIdConnectProviderARNs": identity_pool.open_id_connect_provider_arns, "OpenIdConnectProviderARNs": identity_pool.open_id_connect_provider_arns,
"SamlProviderARNs": identity_pool.saml_provider_arns, "SamlProviderARNs": identity_pool.saml_provider_arns,
"SupportedLoginProviders": identity_pool.supported_login_providers, "SupportedLoginProviders": identity_pool.supported_login_providers,
@ -85,19 +99,38 @@ class CognitoIdentityBackend(BaseBackend):
) )
self.identity_pools[new_identity.identity_pool_id] = new_identity self.identity_pools[new_identity.identity_pool_id] = new_identity
response = json.dumps( response = new_identity.to_json()
{ return response
"IdentityPoolId": new_identity.identity_pool_id,
"IdentityPoolName": new_identity.identity_pool_name,
"AllowUnauthenticatedIdentities": new_identity.allow_unauthenticated_identities,
"SupportedLoginProviders": new_identity.supported_login_providers,
"DeveloperProviderName": new_identity.developer_provider_name,
"OpenIdConnectProviderARNs": new_identity.open_id_connect_provider_arns,
"CognitoIdentityProviders": new_identity.cognito_identity_providers,
"SamlProviderARNs": new_identity.saml_provider_arns,
}
)
def update_identity_pool(
self,
identity_pool_id,
identity_pool_name,
allow_unauthenticated,
allow_classic,
login_providers,
provider_name,
provider_arns,
identity_providers,
saml_providers,
pool_tags,
):
pool = self.identity_pools[identity_pool_id]
pool.identity_pool_name = pool.identity_pool_name or identity_pool_name
if allow_unauthenticated is not None:
pool.allow_unauthenticated_identities = allow_unauthenticated
if login_providers is not None:
pool.supported_login_providers = login_providers
if provider_name:
pool.developer_provider_name = provider_name
if provider_arns is not None:
pool.open_id_connect_provider_arns = provider_arns
if identity_providers is not None:
pool.cognito_identity_providers = identity_providers
if saml_providers is not None:
pool.saml_provider_arns = saml_providers
response = pool.to_json()
return response return response
def get_id(self): def get_id(self):

View File

@ -27,6 +27,31 @@ class CognitoIdentityResponse(BaseResponse):
saml_provider_arns=saml_provider_arns, saml_provider_arns=saml_provider_arns,
) )
def update_identity_pool(self):
pool_id = self._get_param("IdentityPoolId")
pool_name = self._get_param("IdentityPoolName")
allow_unauthenticated = self._get_bool_param("AllowUnauthenticatedIdentities")
allow_classic = self._get_bool_param("AllowClassicFlow")
login_providers = self._get_multi_param_dict("SupportedLoginProviders")
provider_name = self._get_param("DeveloperProviderName")
provider_arns = self._get_multi_param("OpenIdConnectProviderARNs")
identity_providers = self._get_multi_param_dict("CognitoIdentityProviders")
saml_providers = self._get_multi_param("SamlProviderARNs")
pool_tags = self._get_multi_param_dict("IdentityPoolTags")
return cognitoidentity_backends[self.region].update_identity_pool(
identity_pool_id=pool_id,
identity_pool_name=pool_name,
allow_unauthenticated=allow_unauthenticated,
allow_classic=allow_classic,
login_providers=login_providers,
provider_name=provider_name,
provider_arns=provider_arns,
identity_providers=identity_providers,
saml_providers=saml_providers,
pool_tags=pool_tags,
)
def get_id(self): def get_id(self):
return cognitoidentity_backends[self.region].get_id() return cognitoidentity_backends[self.region].get_id()

View File

@ -8,10 +8,8 @@ import json
import os import os
import time import time
import uuid import uuid
from boto3 import Session from boto3 import Session
from jose import jws from jose import jws
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID from moto.core import ACCOUNT_ID as DEFAULT_ACCOUNT_ID

View File

@ -331,8 +331,11 @@ class CognitoIdpResponse(BaseResponse):
users = [ users = [
user user
for user in users for user in users
for attribute in user.attributes if [
if attribute["Name"] == name and attribute["Value"] == value attr
for attr in user.attributes
if attr["Name"] == name and attr["Value"] == value
]
] ]
response = {"Users": [user.to_json(extended=True) for user in users]} response = {"Users": [user.to_json(extended=True) for user in users]}
if token: if token:

View File

@ -4,16 +4,28 @@ from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment from jinja2 import DictLoader, Environment
import json import json
# TODO: add "<Type>Sender</Type>" to error responses below?
SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?> SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<Error> <Error>
<Code>{{error_type}}</Code> <Code>{{error_type}}</Code>
<Message>{{message}}</Message> <Message>{{message}}</Message>
{% block extra %}{% endblock %} {% block extra %}{% endblock %}
<RequestID>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestID> <{{request_id_tag}}>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</{{request_id_tag}}>
</Error> </Error>
""" """
WRAPPED_SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse{% if xmlns is defined %} xmlns="{{xmlns}}"{% endif %}>
<Error>
<Code>{{error_type}}</Code>
<Message>{{message}}</Message>
{% block extra %}{% endblock %}
<{{request_id_tag}}>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</{{request_id_tag}}>
</Error>
</ErrorResponse>"""
ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?> ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse> <ErrorResponse>
<Errors> <Errors>
@ -23,7 +35,7 @@ ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
{% block extra %}{% endblock %} {% block extra %}{% endblock %}
</Error> </Error>
</Errors> </Errors>
<RequestID>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</RequestID> <{{request_id_tag}}>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</{{request_id_tag}}>
</ErrorResponse> </ErrorResponse>
""" """
@ -36,9 +48,12 @@ ERROR_JSON_RESPONSE = """{
class RESTError(HTTPException): class RESTError(HTTPException):
code = 400 code = 400
# most APIs use <RequestId>, but some APIs (including EC2, S3) use <RequestID>
request_id_tag_name = "RequestId"
templates = { templates = {
"single_error": SINGLE_ERROR_RESPONSE, "single_error": SINGLE_ERROR_RESPONSE,
"wrapped_single_error": WRAPPED_SINGLE_ERROR_RESPONSE,
"error": ERROR_RESPONSE, "error": ERROR_RESPONSE,
"error_json": ERROR_JSON_RESPONSE, "error_json": ERROR_JSON_RESPONSE,
} }
@ -49,9 +64,23 @@ class RESTError(HTTPException):
self.error_type = error_type self.error_type = error_type
self.message = message self.message = message
self.description = env.get_template(template).render( self.description = env.get_template(template).render(
error_type=error_type, message=message, **kwargs error_type=error_type,
message=message,
request_id_tag=self.request_id_tag_name,
**kwargs
) )
self.content_type = "application/xml"
def get_headers(self, *args, **kwargs):
return [
("X-Amzn-ErrorType", self.error_type or "UnknownError"),
("Content-Type", self.content_type),
]
def get_body(self, *args, **kwargs):
return self.description
class DryRunClientError(RESTError): class DryRunClientError(RESTError):
code = 400 code = 400
@ -60,9 +89,7 @@ class DryRunClientError(RESTError):
class JsonRESTError(RESTError): class JsonRESTError(RESTError):
def __init__(self, error_type, message, template="error_json", **kwargs): def __init__(self, error_type, message, template="error_json", **kwargs):
super(JsonRESTError, self).__init__(error_type, message, template, **kwargs) super(JsonRESTError, self).__init__(error_type, message, template, **kwargs)
self.content_type = "application/json"
def get_headers(self, *args, **kwargs):
return [("Content-Type", "application/json")]
def get_body(self, *args, **kwargs): def get_body(self, *args, **kwargs):
return self.description return self.description

View File

@ -562,6 +562,7 @@ class CloudFormationModel(BaseModel):
# See for example https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html # See for example https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-dynamodb-table.html
return "AWS::SERVICE::RESOURCE" return "AWS::SERVICE::RESOURCE"
@classmethod
@abstractmethod @abstractmethod
def create_from_cloudformation_json( def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name cls, resource_name, cloudformation_json, region_name
@ -572,6 +573,7 @@ class CloudFormationModel(BaseModel):
# and return an instance of the resource class # and return an instance of the resource class
pass pass
@classmethod
@abstractmethod @abstractmethod
def update_from_cloudformation_json( def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name cls, original_resource, new_resource_name, cloudformation_json, region_name
@ -583,6 +585,7 @@ class CloudFormationModel(BaseModel):
# the change in parameters and no-op when nothing has changed. # the change in parameters and no-op when nothing has changed.
pass pass
@classmethod
@abstractmethod @abstractmethod
def delete_from_cloudformation_json( def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name cls, resource_name, cloudformation_json, region_name

View File

@ -192,7 +192,8 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
region_from_useragent_regex = re.compile( region_from_useragent_regex = re.compile(
r"region/(?P<region>[a-z]{2}-[a-z]+-\d{1})" r"region/(?P<region>[a-z]{2}-[a-z]+-\d{1})"
) )
param_list_regex = re.compile(r"(.*)\.(\d+)\.") param_list_regex = re.compile(r"^(\.?[^.]*(\.member)?)\.(\d+)\.")
param_regex = re.compile(r"([^\.]*)\.(\w+)(\..+)?")
access_key_regex = re.compile( access_key_regex = re.compile(
r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]" r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]"
) )
@ -252,7 +253,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
) )
) )
) )
except UnicodeEncodeError: except (UnicodeEncodeError, UnicodeDecodeError):
pass # ignore encoding errors, as the body may not contain a legitimate querystring pass # ignore encoding errors, as the body may not contain a legitimate querystring
if not querystring: if not querystring:
querystring.update(headers) querystring.update(headers)
@ -402,7 +403,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
try: try:
response = method() response = method()
except HTTPException as http_error: except HTTPException as http_error:
response = http_error.description, dict(status=http_error.code) response_headers = dict(http_error.get_headers() or [])
response_headers["status"] = http_error.code
response = http_error.description, response_headers
if isinstance(response, str): if isinstance(response, str):
return 200, headers, response return 200, headers, response
@ -460,15 +463,23 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def _get_bool_param(self, param_name, if_none=None): def _get_bool_param(self, param_name, if_none=None):
val = self._get_param(param_name) val = self._get_param(param_name)
if val is not None: if val is not None:
val = str(val)
if val.lower() == "true": if val.lower() == "true":
return True return True
elif val.lower() == "false": elif val.lower() == "false":
return False return False
return if_none return if_none
def _get_multi_param_helper(self, param_prefix): def _get_multi_param_dict(self, param_prefix):
return self._get_multi_param_helper(param_prefix, skip_result_conversion=True)
def _get_multi_param_helper(
self, param_prefix, skip_result_conversion=False, tracked_prefixes=None
):
value_dict = dict() value_dict = dict()
tracked_prefixes = set() # prefixes which have already been processed tracked_prefixes = (
tracked_prefixes or set()
) # prefixes which have already been processed
def is_tracked(name_param): def is_tracked(name_param):
for prefix_loop in tracked_prefixes: for prefix_loop in tracked_prefixes:
@ -496,24 +507,47 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
tracked_prefixes.add(prefix) tracked_prefixes.add(prefix)
name = prefix name = prefix
value_dict[name] = value value_dict[name] = value
else:
match = self.param_regex.search(name[len(param_prefix) :])
if match:
# enable access to params that are lists of dicts, e.g., "TagSpecification.1.ResourceType=.."
sub_attr = "%s%s.%s" % (
name[: len(param_prefix)],
match.group(1),
match.group(2),
)
if match.group(3):
value = self._get_multi_param_helper(
sub_attr,
tracked_prefixes=tracked_prefixes,
skip_result_conversion=skip_result_conversion,
)
else:
value = self._get_param(sub_attr)
tracked_prefixes.add(sub_attr)
value_dict[name] = value
else: else:
value_dict[name] = value[0] value_dict[name] = value[0]
if not value_dict: if not value_dict:
return None return None
if len(value_dict) > 1: if skip_result_conversion or len(value_dict) > 1:
# strip off period prefix # strip off period prefix
value_dict = { value_dict = {
name[len(param_prefix) + 1 :]: value name[len(param_prefix) + 1 :]: value
for name, value in value_dict.items() for name, value in value_dict.items()
} }
for k in list(value_dict.keys()):
parts = k.split(".")
if len(parts) != 2 or parts[1] != "member":
value_dict[parts[0]] = value_dict.pop(k)
else: else:
value_dict = list(value_dict.values())[0] value_dict = list(value_dict.values())[0]
return value_dict return value_dict
def _get_multi_param(self, param_prefix): def _get_multi_param(self, param_prefix, skip_result_conversion=False):
""" """
Given a querystring of ?LaunchConfigurationNames.member.1=my-test-1&LaunchConfigurationNames.member.2=my-test-2 Given a querystring of ?LaunchConfigurationNames.member.1=my-test-1&LaunchConfigurationNames.member.2=my-test-2
this will return ['my-test-1', 'my-test-2'] this will return ['my-test-1', 'my-test-2']
@ -525,7 +559,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
values = [] values = []
index = 1 index = 1
while True: while True:
value_dict = self._get_multi_param_helper(prefix + str(index)) value_dict = self._get_multi_param_helper(
prefix + str(index), skip_result_conversion=skip_result_conversion
)
if not value_dict and value_dict != "": if not value_dict and value_dict != "":
break break

View File

@ -4,6 +4,8 @@ from moto.core.exceptions import RESTError
class EC2ClientError(RESTError): class EC2ClientError(RESTError):
code = 400 code = 400
# EC2 uses <RequestID> as tag name in the XML response
request_id_tag_name = "RequestID"
class DependencyViolationError(EC2ClientError): class DependencyViolationError(EC2ClientError):
@ -612,7 +614,7 @@ class InvalidAssociationIDIamProfileAssociationError(EC2ClientError):
class InvalidVpcEndPointIdError(EC2ClientError): class InvalidVpcEndPointIdError(EC2ClientError):
def __init__(self, vpc_end_point_id): def __init__(self, vpc_end_point_id):
super(InvalidVpcEndPointIdError, self).__init__( super(InvalidVpcEndPointIdError, self).__init__(
"InvalidVpcEndPointId.NotFound", "InvalidVpcEndpointId.NotFound",
"The VpcEndPoint ID '{0}' does not exist".format(vpc_end_point_id), "The VpcEndPoint ID '{0}' does not exist".format(vpc_end_point_id),
) )

View File

@ -33,7 +33,7 @@ from moto.core.utils import (
) )
from moto.core import ACCOUNT_ID from moto.core import ACCOUNT_ID
from moto.kms import kms_backends from moto.kms import kms_backends
from moto.utilities.utils import load_resource from moto.utilities.utils import load_resource, merge_multiple_dicts
from os import listdir from os import listdir
from .exceptions import ( from .exceptions import (
@ -124,9 +124,12 @@ from .utils import (
random_internet_gateway_id, random_internet_gateway_id,
random_ip, random_ip,
random_ipv6_cidr, random_ipv6_cidr,
random_transit_gateway_attachment_id,
random_transit_gateway_route_table_id,
randor_ipv4_cidr, randor_ipv4_cidr,
random_launch_template_id, random_launch_template_id,
random_nat_gateway_id, random_nat_gateway_id,
random_transit_gateway_id,
random_key_pair, random_key_pair,
random_private_ip, random_private_ip,
random_public_ip, random_public_ip,
@ -249,7 +252,11 @@ class TaggedEC2Resource(BaseModel):
return [tag["key"] for tag in tags] return [tag["key"] for tag in tags]
elif filter_name == "tag-value": elif filter_name == "tag-value":
return [tag["value"] for tag in tags] return [tag["value"] for tag in tags]
else:
value = getattr(self, filter_name.lower().replace("-", "_"), None)
if value is not None:
return [value]
raise FilterNotImplementedError(filter_name, method_name) raise FilterNotImplementedError(filter_name, method_name)
@ -2015,7 +2022,9 @@ class SecurityRule(object):
class SecurityGroup(TaggedEC2Resource, CloudFormationModel): class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
def __init__(self, ec2_backend, group_id, name, description, vpc_id=None): def __init__(
self, ec2_backend, group_id, name, description, vpc_id=None, tags=None
):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.id = group_id self.id = group_id
self.name = name self.name = name
@ -2027,6 +2036,7 @@ class SecurityGroup(TaggedEC2Resource, CloudFormationModel):
self.enis = {} self.enis = {}
self.vpc_id = vpc_id self.vpc_id = vpc_id
self.owner_id = OWNER_ID self.owner_id = OWNER_ID
self.add_tags(tags or {})
# Append default IPv6 egress rule for VPCs with IPv6 support # Append default IPv6 egress rule for VPCs with IPv6 support
if vpc_id: if vpc_id:
@ -2188,7 +2198,9 @@ class SecurityGroupBackend(object):
super(SecurityGroupBackend, self).__init__() super(SecurityGroupBackend, self).__init__()
def create_security_group(self, name, description, vpc_id=None, force=False): def create_security_group(
self, name, description, vpc_id=None, tags=None, force=False
):
if not description: if not description:
raise MissingParameterError("GroupDescription") raise MissingParameterError("GroupDescription")
@ -2197,7 +2209,9 @@ class SecurityGroupBackend(object):
existing_group = self.get_security_group_from_name(name, vpc_id) existing_group = self.get_security_group_from_name(name, vpc_id)
if existing_group: if existing_group:
raise InvalidSecurityGroupDuplicateError(name) raise InvalidSecurityGroupDuplicateError(name)
group = SecurityGroup(self, group_id, name, description, vpc_id=vpc_id) group = SecurityGroup(
self, group_id, name, description, vpc_id=vpc_id, tags=tags
)
self.groups[vpc_id][group_id] = group self.groups[vpc_id][group_id] = group
return group return group
@ -3051,10 +3065,12 @@ class VPC(TaggedEC2Resource, CloudFormationModel):
).get("cidr_block"): ).get("cidr_block"):
raise OperationNotPermitted(association_id) raise OperationNotPermitted(association_id)
response = self.cidr_block_association_set.pop(association_id, {}) entry = response = self.cidr_block_association_set.get(association_id, {})
if response: if entry:
response = json.loads(json.dumps(entry))
response["vpc_id"] = self.id response["vpc_id"] = self.id
response["cidr_block_state"]["state"] = "disassociating" response["cidr_block_state"]["state"] = "disassociating"
entry["cidr_block_state"]["state"] = "disassociated"
return response return response
def get_cidr_block_association_set(self, ipv6=False): def get_cidr_block_association_set(self, ipv6=False):
@ -3229,7 +3245,7 @@ class VPCBackend(object):
network_interface_ids=[], network_interface_ids=[],
dns_entries=None, dns_entries=None,
client_token=None, client_token=None,
security_group=None, security_group_ids=None,
tag_specifications=None, tag_specifications=None,
private_dns_enabled=None, private_dns_enabled=None,
): ):
@ -3259,6 +3275,7 @@ class VPCBackend(object):
dns_entries = [dns_entries] dns_entries = [dns_entries]
vpc_end_point = VPCEndPoint( vpc_end_point = VPCEndPoint(
self,
vpc_endpoint_id, vpc_endpoint_id,
vpc_id, vpc_id,
service_name, service_name,
@ -3269,7 +3286,7 @@ class VPCBackend(object):
network_interface_ids, network_interface_ids,
dns_entries, dns_entries,
client_token, client_token,
security_group, security_group_ids,
tag_specifications, tag_specifications,
private_dns_enabled, private_dns_enabled,
) )
@ -4308,6 +4325,7 @@ class Route(CloudFormationModel):
class VPCEndPoint(TaggedEC2Resource): class VPCEndPoint(TaggedEC2Resource):
def __init__( def __init__(
self, self,
ec2_backend,
id, id,
vpc_id, vpc_id,
service_name, service_name,
@ -4318,10 +4336,11 @@ class VPCEndPoint(TaggedEC2Resource):
network_interface_ids=None, network_interface_ids=None,
dns_entries=None, dns_entries=None,
client_token=None, client_token=None,
security_group=None, security_group_ids=None,
tag_specifications=None, tag_specifications=None,
private_dns_enabled=None, private_dns_enabled=None,
): ):
self.ec2_backend = ec2_backend
self.id = id self.id = id
self.vpc_id = vpc_id self.vpc_id = vpc_id
self.service_name = service_name self.service_name = service_name
@ -4331,7 +4350,7 @@ class VPCEndPoint(TaggedEC2Resource):
self.network_interface_ids = network_interface_ids self.network_interface_ids = network_interface_ids
self.subnet_ids = subnet_ids self.subnet_ids = subnet_ids
self.client_token = client_token self.client_token = client_token
self.security_group = security_group self.security_group_ids = security_group_ids
self.tag_specifications = tag_specifications self.tag_specifications = tag_specifications
self.private_dns_enabled = private_dns_enabled self.private_dns_enabled = private_dns_enabled
self.created_at = datetime.utcnow() self.created_at = datetime.utcnow()
@ -5395,7 +5414,16 @@ class DHCPOptionsSetBackend(object):
class VPNConnection(TaggedEC2Resource): class VPNConnection(TaggedEC2Resource):
def __init__(self, ec2_backend, id, type, customer_gateway_id, vpn_gateway_id): def __init__(
self,
ec2_backend,
id,
type,
customer_gateway_id,
vpn_gateway_id=None,
transit_gateway_id=None,
tags={},
):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.id = id self.id = id
self.state = "available" self.state = "available"
@ -5403,9 +5431,11 @@ class VPNConnection(TaggedEC2Resource):
self.type = type self.type = type
self.customer_gateway_id = customer_gateway_id self.customer_gateway_id = customer_gateway_id
self.vpn_gateway_id = vpn_gateway_id self.vpn_gateway_id = vpn_gateway_id
self.transit_gateway_id = transit_gateway_id
self.tunnels = None self.tunnels = None
self.options = None self.options = None
self.static_routes = None self.static_routes = None
self.add_tags(tags or {})
def get_filter_value(self, filter_name): def get_filter_value(self, filter_name):
return super(VPNConnection, self).get_filter_value( return super(VPNConnection, self).get_filter_value(
@ -5419,7 +5449,13 @@ class VPNConnectionBackend(object):
super(VPNConnectionBackend, self).__init__() super(VPNConnectionBackend, self).__init__()
def create_vpn_connection( def create_vpn_connection(
self, type, customer_gateway_id, vpn_gateway_id, static_routes_only=None self,
type,
customer_gateway_id,
vpn_gateway_id=None,
transit_gateway_id=None,
static_routes_only=None,
tags={},
): ):
vpn_connection_id = random_vpn_connection_id() vpn_connection_id = random_vpn_connection_id()
if static_routes_only: if static_routes_only:
@ -5430,6 +5466,8 @@ class VPNConnectionBackend(object):
type=type, type=type,
customer_gateway_id=customer_gateway_id, customer_gateway_id=customer_gateway_id,
vpn_gateway_id=vpn_gateway_id, vpn_gateway_id=vpn_gateway_id,
transit_gateway_id=transit_gateway_id,
tags=tags,
) )
self.vpn_connections[vpn_connection.id] = vpn_connection self.vpn_connections[vpn_connection.id] = vpn_connection
return vpn_connection return vpn_connection
@ -5437,10 +5475,10 @@ class VPNConnectionBackend(object):
def delete_vpn_connection(self, vpn_connection_id): def delete_vpn_connection(self, vpn_connection_id):
if vpn_connection_id in self.vpn_connections: if vpn_connection_id in self.vpn_connections:
self.vpn_connections.pop(vpn_connection_id) self.vpn_connections[vpn_connection_id].state = "deleted"
else: else:
raise InvalidVpnConnectionIdError(vpn_connection_id) raise InvalidVpnConnectionIdError(vpn_connection_id)
return True return self.vpn_connections[vpn_connection_id]
def describe_vpn_connections(self, vpn_connection_ids=None): def describe_vpn_connections(self, vpn_connection_ids=None):
vpn_connections = [] vpn_connections = []
@ -5723,10 +5761,23 @@ class NetworkAclEntry(TaggedEC2Resource):
class VpnGateway(TaggedEC2Resource): class VpnGateway(TaggedEC2Resource):
def __init__(self, ec2_backend, id, type): def __init__(
self,
ec2_backend,
id,
type,
amazon_side_asn,
availability_zone,
tags=None,
state="available",
):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.id = id self.id = id
self.type = type self.type = type
self.amazon_side_asn = amazon_side_asn
self.availability_zone = availability_zone
self.state = state
self.add_tags(tags or {})
self.attachments = {} self.attachments = {}
super(VpnGateway, self).__init__() super(VpnGateway, self).__init__()
@ -5756,9 +5807,13 @@ class VpnGatewayBackend(object):
self.vpn_gateways = {} self.vpn_gateways = {}
super(VpnGatewayBackend, self).__init__() super(VpnGatewayBackend, self).__init__()
def create_vpn_gateway(self, type="ipsec.1"): def create_vpn_gateway(
self, type="ipsec.1", amazon_side_asn=None, availability_zone=None, tags=None
):
vpn_gateway_id = random_vpn_gateway_id() vpn_gateway_id = random_vpn_gateway_id()
vpn_gateway = VpnGateway(self, vpn_gateway_id, type) vpn_gateway = VpnGateway(
self, vpn_gateway_id, type, amazon_side_asn, availability_zone, tags
)
self.vpn_gateways[vpn_gateway_id] = vpn_gateway self.vpn_gateways[vpn_gateway_id] = vpn_gateway
return vpn_gateway return vpn_gateway
@ -5795,13 +5850,17 @@ class VpnGatewayBackend(object):
class CustomerGateway(TaggedEC2Resource): class CustomerGateway(TaggedEC2Resource):
def __init__(self, ec2_backend, id, type, ip_address, bgp_asn): def __init__(
self, ec2_backend, id, type, ip_address, bgp_asn, state="available", tags=None
):
self.ec2_backend = ec2_backend self.ec2_backend = ec2_backend
self.id = id self.id = id
self.type = type self.type = type
self.ip_address = ip_address self.ip_address = ip_address
self.bgp_asn = bgp_asn self.bgp_asn = bgp_asn
self.attachments = {} self.attachments = {}
self.state = state
self.add_tags(tags or {})
super(CustomerGateway, self).__init__() super(CustomerGateway, self).__init__()
def get_filter_value(self, filter_name): def get_filter_value(self, filter_name):
@ -5815,17 +5874,44 @@ class CustomerGatewayBackend(object):
self.customer_gateways = {} self.customer_gateways = {}
super(CustomerGatewayBackend, self).__init__() super(CustomerGatewayBackend, self).__init__()
def create_customer_gateway(self, type="ipsec.1", ip_address=None, bgp_asn=None): def create_customer_gateway(
self, type="ipsec.1", ip_address=None, bgp_asn=None, tags=None
):
customer_gateway_id = random_customer_gateway_id() customer_gateway_id = random_customer_gateway_id()
customer_gateway = CustomerGateway( customer_gateway = CustomerGateway(
self, customer_gateway_id, type, ip_address, bgp_asn self, customer_gateway_id, type, ip_address, bgp_asn, tags=tags
) )
self.customer_gateways[customer_gateway_id] = customer_gateway self.customer_gateways[customer_gateway_id] = customer_gateway
return customer_gateway return customer_gateway
def get_all_customer_gateways(self, filters=None): def get_all_customer_gateways(self, filters=None):
customer_gateways = self.customer_gateways.values() customer_gateways = self.customer_gateways.values()
return generic_filter(filters, customer_gateways) if filters is not None:
if filters.get("customer-gateway-id") is not None:
customer_gateways = [
customer_gateway
for customer_gateway in customer_gateways
if customer_gateway.id in filters["customer-gateway-id"]
]
if filters.get("type") is not None:
customer_gateways = [
customer_gateway
for customer_gateway in customer_gateways
if customer_gateway.type in filters["type"]
]
if filters.get("bgp-asn") is not None:
customer_gateways = [
customer_gateway
for customer_gateway in customer_gateways
if customer_gateway.bgp_asn in filters["bgp-asn"]
]
if filters.get("ip-address") is not None:
customer_gateways = [
customer_gateway
for customer_gateway in customer_gateways
if customer_gateway.ip_address in filters["ip-address"]
]
return customer_gateways
def get_customer_gateway(self, customer_gateway_id): def get_customer_gateway(self, customer_gateway_id):
customer_gateway = self.customer_gateways.get(customer_gateway_id, None) customer_gateway = self.customer_gateways.get(customer_gateway_id, None)
@ -5834,12 +5920,425 @@ class CustomerGatewayBackend(object):
return customer_gateway return customer_gateway
def delete_customer_gateway(self, customer_gateway_id): def delete_customer_gateway(self, customer_gateway_id):
deleted = self.customer_gateways.pop(customer_gateway_id, None) customer_gateway = self.get_customer_gateway(customer_gateway_id)
customer_gateway.state = "deleted"
# deleted = self.customer_gateways.pop(customer_gateway_id, None)
deleted = True
if not deleted: if not deleted:
raise InvalidCustomerGatewayIdError(customer_gateway_id) raise InvalidCustomerGatewayIdError(customer_gateway_id)
return deleted return deleted
class TransitGateway(TaggedEC2Resource, CloudFormationModel):
DEFAULT_OPTIONS = {
"AmazonSideAsn": "64512",
"AssociationDefaultRouteTableId": "tgw-rtb-0d571391e50cf8514",
"AutoAcceptSharedAttachments": "disable",
"DefaultRouteTableAssociation": "enable",
"DefaultRouteTablePropagation": "enable",
"DnsSupport": "enable",
"MulticastSupport": "disable",
"PropagationDefaultRouteTableId": "tgw-rtb-0d571391e50cf8514",
"TransitGatewayCidrBlocks": None,
"VpnEcmpSupport": "enable",
}
def __init__(self, backend, description=None, options=None, tags=None):
self.ec2_backend = backend
self.id = random_transit_gateway_id()
self.description = description
self.state = "available"
self.add_tags(tags or {})
self.options = merge_multiple_dicts(self.DEFAULT_OPTIONS, options or {})
self._created_at = datetime.utcnow()
@property
def physical_resource_id(self):
return self.id
@property
def create_time(self):
return iso_8601_datetime_with_milliseconds(self._created_at)
@property
def owner_id(self):
return ACCOUNT_ID
@staticmethod
def cloudformation_name_type():
return None
@staticmethod
def cloudformation_type():
# https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/aws-resource-ec2-natgateway.html
return "AWS::EC2::TransitGateway"
@classmethod
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
ec2_backend = ec2_backends[region_name]
transit_gateway = ec2_backend.create_transit_gateway(
cloudformation_json["Properties"]["Description"],
cloudformation_json["Properties"]["Options"],
)
return transit_gateway
class TransitGatewayBackend(object):
def __init__(self):
self.transit_gateways = {}
super(TransitGatewayBackend, self).__init__()
def create_transit_gateway(self, description=None, options=None, tags=None):
transit_gateway = TransitGateway(self, description, options, tags)
self.transit_gateways[transit_gateway.id] = transit_gateway
return transit_gateway
def get_all_transit_gateways(self, filters):
transit_gateways = self.transit_gateways.values()
if filters is not None:
if filters.get("transit-gateway-id") is not None:
transit_gateways = [
transit_gateway
for transit_gateway in transit_gateways
if transit_gateway.id in filters["transit-gateway-id"]
]
if filters.get("state") is not None:
transit_gateways = [
transit_gateway
for transit_gateway in transit_gateways
if transit_gateway.state in filters["state"]
]
if filters.get("owner-id") is not None:
transit_gateways = [
transit_gateway
for transit_gateway in transit_gateways
if transit_gateway.owner_id in filters["owner-id"]
]
return transit_gateways
def delete_transit_gateway(self, transit_gateway_id):
return self.transit_gateways.pop(transit_gateway_id)
def modify_transit_gateway(
self, transit_gateway_id, description=None, options=None
):
transit_gateway = self.transit_gateways.get(transit_gateway_id)
if description:
transit_gateway.description = description
if options:
transit_gateway.options.update(options)
return transit_gateway
class TransitGatewayRouteTable(TaggedEC2Resource):
def __init__(
self,
backend,
transit_gateway_id,
tags=None,
default_association_route_table=False,
default_propagation_route_table=False,
):
self.ec2_backend = backend
self.id = random_transit_gateway_route_table_id()
self.transit_gateway_id = transit_gateway_id
self._created_at = datetime.utcnow()
self.default_association_route_table = default_association_route_table
self.default_propagation_route_table = default_propagation_route_table
self.state = "available"
self.routes = {}
self.add_tags(tags or {})
@property
def physical_resource_id(self):
return self.id
@property
def create_time(self):
return iso_8601_datetime_with_milliseconds(self._created_at)
class TransitGatewayRouteTableBackend(object):
def __init__(self):
self.transit_gateways_route_tables = {}
super(TransitGatewayRouteTableBackend, self).__init__()
def create_transit_gateway_route_table(
self,
transit_gateway_id,
tags=None,
default_association_route_table=False,
default_propagation_route_table=False,
):
transit_gateways_route_table = TransitGatewayRouteTable(
self,
transit_gateway_id=transit_gateway_id,
tags=tags,
default_association_route_table=default_association_route_table,
default_propagation_route_table=default_propagation_route_table,
)
self.transit_gateways_route_tables[
transit_gateways_route_table.id
] = transit_gateways_route_table
return transit_gateways_route_table
def get_all_transit_gateway_route_tables(
self, transit_gateway_ids=None, filters=None
):
transit_gateway_route_tables = self.transit_gateways_route_tables.values()
attr_pairs = (
("default-association-route-table", "default_association_route_table"),
("default-propagation-route-table", "default_propagation_route_table"),
("state", "state"),
("transit-gateway-id", "transit_gateway_id"),
("transit-gateway-route-table-id", "id"),
)
if transit_gateway_ids:
transit_gateway_route_tables = [
transit_gateway_route_table
for transit_gateway_route_table in transit_gateway_route_tables
if transit_gateway_route_table.id in transit_gateway_ids
]
if filters:
for attrs in attr_pairs:
values = filters.get(attrs[0]) or None
if values is not None:
transit_gateway_route_tables = [
transit_gateway_route_table
for transit_gateway_route_table in transit_gateway_route_tables
if not values
or getattr(transit_gateway_route_table, attrs[1]) in values
]
return transit_gateway_route_tables
def delete_transit_gateway_route_table(self, transit_gateway_route_table_id):
return self.transit_gateways_route_tables.pop(transit_gateway_route_table_id)
def create_transit_gateway_route(
self,
transit_gateway_route_table_id,
destination_cidr_block,
transit_gateway_attachment_id=None,
blackhole=False,
):
transit_gateways_route_table = self.transit_gateways_route_tables[
transit_gateway_route_table_id
]
transit_gateways_route_table.routes[destination_cidr_block] = {
"destinationCidrBlock": destination_cidr_block,
"prefixListId": "",
"state": "blackhole" if blackhole else "active",
# TODO: needs to be fixed once we have support for transit gateway attachments
"transitGatewayAttachments": {
"resourceId": "TODO",
"resourceType": "TODO",
"transitGatewayAttachmentId": transit_gateway_attachment_id,
},
"type": "TODO",
}
return transit_gateways_route_table
def delete_transit_gateway_route(
self, transit_gateway_route_table_id, destination_cidr_block,
):
transit_gateways_route_table = self.transit_gateways_route_tables[
transit_gateway_route_table_id
]
transit_gateways_route_table.routes[destination_cidr_block]["state"] = "deleted"
return transit_gateways_route_table
def search_transit_gateway_routes(
self, transit_gateway_route_table_id, filters, max_results=None
):
transit_gateway_route_table = self.transit_gateways_route_tables[
transit_gateway_route_table_id
]
attr_pairs = (
("type", "type"),
("state", "state"),
)
for attrs in attr_pairs:
values = filters.get(attrs[0]) or None
if values:
routes = [
transit_gateway_route_table.routes[key]
for key in transit_gateway_route_table.routes
if transit_gateway_route_table.routes[key][attrs[1]] in values
]
if max_results:
routes = routes[: int(max_results)]
return routes
class TransitGatewayAttachment(TaggedEC2Resource):
def __init__(
self, backend, resource_id, resource_type, transit_gateway_id, tags=None
):
self.ec2_backend = backend
self.association = {}
self.resource_id = resource_id
self.resource_type = resource_type
self.id = random_transit_gateway_attachment_id()
self.transit_gateway_id = transit_gateway_id
self.state = "available"
self.add_tags(tags or {})
self._created_at = datetime.utcnow()
@property
def create_time(self):
return iso_8601_datetime_with_milliseconds(self._created_at)
@property
def resource_owner_id(self):
return ACCOUNT_ID
@property
def transit_gateway_owner_id(self):
return ACCOUNT_ID
class TransitGatewayVpcAttachment(TransitGatewayAttachment):
DEFAULT_OPTIONS = {
"ApplianceModeSupport": "disable",
"DnsSupport": "enable",
"Ipv6Support": "disable",
}
def __init__(
self, backend, transit_gateway_id, vpc_id, subnet_ids, tags=None, options=None
):
super().__init__(
backend=backend,
transit_gateway_id=transit_gateway_id,
resource_id=vpc_id,
resource_type="vpc",
tags=tags,
)
self.vpc_id = vpc_id
self.subnet_ids = subnet_ids
self.options = merge_multiple_dicts(self.DEFAULT_OPTIONS, options or {})
class TransitGatewayAttachmentBackend(object):
def __init__(self):
self.transit_gateways_attachments = {}
super(TransitGatewayAttachmentBackend, self).__init__()
def create_transit_gateway_vpn_attachment(
self, vpn_id, transit_gateway_id, tags=[]
):
transit_gateway_vpn_attachment = TransitGatewayAttachment(
self,
resource_id=vpn_id,
resource_type="vpn",
transit_gateway_id=transit_gateway_id,
tags=tags,
)
self.transit_gateways_attachments[
transit_gateway_vpn_attachment.id
] = transit_gateway_vpn_attachment
return transit_gateway_vpn_attachment
def create_transit_gateway_vpc_attachment(
self, transit_gateway_id, vpc_id, subnet_ids, tags=None, options=None
):
transit_gateway_vpc_attachment = TransitGatewayVpcAttachment(
self,
transit_gateway_id=transit_gateway_id,
tags=tags,
vpc_id=vpc_id,
subnet_ids=subnet_ids,
options=options,
)
self.transit_gateways_attachments[
transit_gateway_vpc_attachment.id
] = transit_gateway_vpc_attachment
return transit_gateway_vpc_attachment
def describe_transit_gateway_attachments(
self, transit_gateways_attachment_ids=None, filters=None, max_results=0
):
transit_gateways_attachments = self.transit_gateways_attachments.values()
attr_pairs = (
("resource-id", "resource_id"),
("resource-type", "resource_type"),
("transit-gateway-id", "transit_gateway_id"),
)
if transit_gateways_attachment_ids:
transit_gateways_attachments = [
transit_gateways_attachment
for transit_gateways_attachment in transit_gateways_attachments
if transit_gateways_attachment.id in transit_gateways_attachment_ids
]
if filters:
for attrs in attr_pairs:
values = filters.get(attrs[0]) or None
if values is not None:
transit_gateways_attachments = [
transit_gateways_attachment
for transit_gateways_attachment in transit_gateways_attachments
if getattr(transit_gateways_attachment, attrs[1]) in values
]
return transit_gateways_attachments
def describe_transit_gateway_vpc_attachments(
self, transit_gateways_attachment_ids=None, filters=None, max_results=0
):
transit_gateways_attachments = self.transit_gateways_attachments.values()
attr_pairs = (
("state", "state"),
("transit-gateway-attachment-id", "id"),
("transit-gateway-id", "transit_gateway_id"),
("vpc-id", "resource_id"),
)
if (
not transit_gateways_attachment_ids == []
and transit_gateways_attachment_ids is not None
):
transit_gateways_attachments = [
transit_gateways_attachment
for transit_gateways_attachment in transit_gateways_attachments
if transit_gateways_attachment.id in transit_gateways_attachment_ids
]
if filters:
for attrs in attr_pairs:
values = filters.get(attrs[0]) or None
if values is not None:
transit_gateways_attachments = [
transit_gateways_attachment
for transit_gateways_attachment in transit_gateways_attachments
if getattr(transit_gateways_attachment, attrs[1]) in values
]
return transit_gateways_attachments
class NatGateway(CloudFormationModel): class NatGateway(CloudFormationModel):
def __init__(self, backend, subnet_id, allocation_id, tags=[]): def __init__(self, backend, subnet_id, allocation_id, tags=[]):
# public properties # public properties
@ -6199,6 +6698,9 @@ class EC2Backend(
VpnGatewayBackend, VpnGatewayBackend,
CustomerGatewayBackend, CustomerGatewayBackend,
NatGatewayBackend, NatGatewayBackend,
TransitGatewayBackend,
TransitGatewayRouteTableBackend,
TransitGatewayAttachmentBackend,
LaunchTemplateBackend, LaunchTemplateBackend,
IamInstanceProfileAssociationBackend, IamInstanceProfileAssociationBackend,
): ):

View File

@ -576,5 +576,39 @@
"name": "suse-sles-11-sp4-v20151207-pv-ssd-x86_64", "name": "suse-sles-11-sp4-v20151207-pv-ssd-x86_64",
"virtualization_type": "paravirtual", "virtualization_type": "paravirtual",
"hypervisor": "xen" "hypervisor": "xen"
},
{
"ami_id": "ami-ekswin",
"state": "available",
"public": true,
"owner_id": "801119661308",
"image_location": "amazon/amazon-eks",
"sriov": "simple",
"root_device_type": "ebs",
"root_device_name": "/dev/sda1",
"description": "Microsoft Windows Server 2019 Core optimized for EKS and provided by Amazon",
"image_type": "machine",
"platform": "windows",
"architecture": "x86_64",
"name": "Windows_Server-2019-English-Core-EKS_Optimized",
"virtualization_type": "hvm",
"hypervisor": "xen"
},
{
"ami_id": "ami-ekslinux",
"state": "available",
"public": true,
"owner_id": "801119661308",
"image_location": "amazon/amazon-eks",
"sriov": "simple",
"root_device_type": "ebs",
"root_device_name": "/dev/sda1",
"description": "EKS Kubernetes Worker AMI with AmazonLinux2 image",
"image_type": "machine",
"platform": "Linux/UNIX",
"architecture": "x86_64",
"name": "amazon-eks-node-linux",
"virtualization_type": "hvm",
"hypervisor": "xen"
} }
] ]

View File

@ -34,6 +34,9 @@ from .vpc_peering_connections import VPCPeeringConnections
from .vpn_connections import VPNConnections from .vpn_connections import VPNConnections
from .windows import Windows from .windows import Windows
from .nat_gateways import NatGateways from .nat_gateways import NatGateways
from .transit_gateways import TransitGateways
from .transit_gateway_route_tables import TransitGatewayRouteTable
from .transit_gateway_attachments import TransitGatewayAttachment
from .iam_instance_profiles import IamInstanceProfiles from .iam_instance_profiles import IamInstanceProfiles
@ -72,6 +75,9 @@ class EC2Response(
VPNConnections, VPNConnections,
Windows, Windows,
NatGateways, NatGateways,
TransitGateways,
TransitGatewayRouteTable,
TransitGatewayAttachment,
IamInstanceProfiles, IamInstanceProfiles,
): ):
@property @property

0
moto/ec2/responses/amis.py Executable file → Normal file
View File

View File

@ -9,8 +9,12 @@ class CustomerGateways(BaseResponse):
type = self._get_param("Type") type = self._get_param("Type")
ip_address = self._get_param("IpAddress") ip_address = self._get_param("IpAddress")
bgp_asn = self._get_param("BgpAsn") bgp_asn = self._get_param("BgpAsn")
tags = self._get_multi_param("TagSpecification")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
customer_gateway = self.ec2_backend.create_customer_gateway( customer_gateway = self.ec2_backend.create_customer_gateway(
type, ip_address=ip_address, bgp_asn=bgp_asn type, ip_address=ip_address, bgp_asn=bgp_asn, tags=tags
) )
template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE) template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=customer_gateway) return template.render(customer_gateway=customer_gateway)
@ -19,7 +23,7 @@ class CustomerGateways(BaseResponse):
customer_gateway_id = self._get_param("CustomerGatewayId") customer_gateway_id = self._get_param("CustomerGatewayId")
delete_status = self.ec2_backend.delete_customer_gateway(customer_gateway_id) delete_status = self.ec2_backend.delete_customer_gateway(customer_gateway_id)
template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE) template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=delete_status) return template.render(delete_status=delete_status)
def describe_customer_gateways(self): def describe_customer_gateways(self):
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
@ -33,15 +37,13 @@ CREATE_CUSTOMER_GATEWAY_RESPONSE = """
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId> <requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<customerGateway> <customerGateway>
<customerGatewayId>{{ customer_gateway.id }}</customerGatewayId> <customerGatewayId>{{ customer_gateway.id }}</customerGatewayId>
<state>pending</state> <state>{{ customer_gateway.state }}</state>
<type>{{ customer_gateway.type }}</type> <type>{{ customer_gateway.type }}</type>
<ipAddress>{{ customer_gateway.ip_address }}</ipAddress> <ipAddress>{{ customer_gateway.ip_address }}</ipAddress>
<bgpAsn>{{ customer_gateway.bgp_asn }}</bgpAsn> <bgpAsn>{{ customer_gateway.bgp_asn }}</bgpAsn>
<tagSet> <tagSet>
{% for tag in customer_gateway.get_tags() %} {% for tag in customer_gateway.get_tags() %}
<item> <item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key> <key>{{ tag.key }}</key>
<value>{{ tag.value }}</value> <value>{{ tag.value }}</value>
</item> </item>
@ -64,14 +66,12 @@ DESCRIBE_CUSTOMER_GATEWAYS_RESPONSE = """
<item> <item>
<customerGatewayId>{{ customer_gateway.id }}</customerGatewayId> <customerGatewayId>{{ customer_gateway.id }}</customerGatewayId>
<state>{{ customer_gateway.state }}</state> <state>{{ customer_gateway.state }}</state>
<type>available</type> <type>{{ customer_gateway.type }}</type>
<ipAddress>{{ customer_gateway.ip_address }}</ipAddress> <ipAddress>{{ customer_gateway.ip_address }}</ipAddress>
<bgpAsn>{{ customer_gateway.bgp_asn }}</bgpAsn> <bgpAsn>{{ customer_gateway.bgp_asn }}</bgpAsn>
<tagSet> <tagSet>
{% for tag in customer_gateway.get_tags() %} {% for tag in customer_gateway.get_tags() %}
<item> <item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key> <key>{{ tag.key }}</key>
<value>{{ tag.value }}</value> <value>{{ tag.value }}</value>
</item> </item>

View File

@ -14,9 +14,11 @@ class InternetGateways(BaseResponse):
def create_internet_gateway(self): def create_internet_gateway(self):
if self.is_not_dryrun("CreateInternetGateway"): if self.is_not_dryrun("CreateInternetGateway"):
tags = self._get_multi_param("TagSpecification") tags = self._get_multi_param(
"TagSpecification", skip_result_conversion=True
)
if tags: if tags:
tags = tags[0].get("Tag") tags = tags[0].get("Tag") or []
igw = self.ec2_backend.create_internet_gateway(tags=tags) igw = self.ec2_backend.create_internet_gateway(tags=tags)
template = self.response_template(CREATE_INTERNET_GATEWAY_RESPONSE) template = self.response_template(CREATE_INTERNET_GATEWAY_RESPONSE)
return template.render(internet_gateway=igw) return template.render(internet_gateway=igw)

View File

@ -40,9 +40,9 @@ class RouteTables(BaseResponse):
def create_route_table(self): def create_route_table(self):
vpc_id = self._get_param("VpcId") vpc_id = self._get_param("VpcId")
tags = self._get_multi_param("TagSpecification") tags = self._get_multi_param("TagSpecification", skip_result_conversion=True)
if tags: if tags:
tags = tags[0].get("Tag") tags = tags[0].get("Tag") or []
route_table = self.ec2_backend.create_route_table(vpc_id, tags) route_table = self.ec2_backend.create_route_table(vpc_id, tags)
template = self.response_template(CREATE_ROUTE_TABLE_RESPONSE) template = self.response_template(CREATE_ROUTE_TABLE_RESPONSE)
return template.render(route_table=route_table) return template.render(route_table=route_table)

View File

@ -115,10 +115,14 @@ class SecurityGroups(BaseResponse):
name = self._get_param("GroupName") name = self._get_param("GroupName")
description = self._get_param("GroupDescription") description = self._get_param("GroupDescription")
vpc_id = self._get_param("VpcId") vpc_id = self._get_param("VpcId")
tags = self._get_multi_param("TagSpecification")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
if self.is_not_dryrun("CreateSecurityGroup"): if self.is_not_dryrun("CreateSecurityGroup"):
group = self.ec2_backend.create_security_group( group = self.ec2_backend.create_security_group(
name, description, vpc_id=vpc_id name, description, vpc_id=vpc_id, tags=tags
) )
template = self.response_template(CREATE_SECURITY_GROUP_RESPONSE) template = self.response_template(CREATE_SECURITY_GROUP_RESPONSE)
return template.render(group=group) return template.render(group=group)

View File

@ -64,13 +64,13 @@ CREATE_SUBNET_RESPONSE = """
<state>pending</state> <state>pending</state>
<vpcId>{{ subnet.vpc_id }}</vpcId> <vpcId>{{ subnet.vpc_id }}</vpcId>
<cidrBlock>{{ subnet.cidr_block }}</cidrBlock> <cidrBlock>{{ subnet.cidr_block }}</cidrBlock>
<availableIpAddressCount>{{ subnet.available_ip_addresses }}</availableIpAddressCount> <availableIpAddressCount>{{ subnet.available_ip_addresses or '0' }}</availableIpAddressCount>
<availabilityZone>{{ subnet._availability_zone.name }}</availabilityZone> <availabilityZone>{{ subnet._availability_zone.name }}</availabilityZone>
<availabilityZoneId>{{ subnet._availability_zone.zone_id }}</availabilityZoneId> <availabilityZoneId>{{ subnet._availability_zone.zone_id }}</availabilityZoneId>
<defaultForAz>{{ subnet.default_for_az }}</defaultForAz> <defaultForAz>{{ subnet.default_for_az }}</defaultForAz>
<mapPublicIpOnLaunch>{{ subnet.map_public_ip_on_launch }}</mapPublicIpOnLaunch> <mapPublicIpOnLaunch>{{ subnet.map_public_ip_on_launch }}</mapPublicIpOnLaunch>
<ownerId>{{ subnet.owner_id }}</ownerId> <ownerId>{{ subnet.owner_id }}</ownerId>
<assignIpv6AddressOnCreation>{{ subnet.assign_ipv6_address_on_creation }}</assignIpv6AddressOnCreation> <assignIpv6AddressOnCreation>{{ 'false' if not subnet.assign_ipv6_address_on_creation or subnet.assign_ipv6_address_on_creation == 'false' else 'true'}}</assignIpv6AddressOnCreation>
<ipv6CidrBlockAssociationSet>{{ subnet.ipv6_cidr_block_associations }}</ipv6CidrBlockAssociationSet> <ipv6CidrBlockAssociationSet>{{ subnet.ipv6_cidr_block_associations }}</ipv6CidrBlockAssociationSet>
<subnetArn>arn:aws:ec2:{{ subnet._availability_zone.name[0:-1] }}:{{ subnet.owner_id }}:subnet/{{ subnet.id }}</subnetArn> <subnetArn>arn:aws:ec2:{{ subnet._availability_zone.name[0:-1] }}:{{ subnet.owner_id }}:subnet/{{ subnet.id }}</subnetArn>
<tagSet> <tagSet>
@ -102,13 +102,13 @@ DESCRIBE_SUBNETS_RESPONSE = """
<state>{{ subnet.state }}</state> <state>{{ subnet.state }}</state>
<vpcId>{{ subnet.vpc_id }}</vpcId> <vpcId>{{ subnet.vpc_id }}</vpcId>
<cidrBlock>{{ subnet.cidr_block }}</cidrBlock> <cidrBlock>{{ subnet.cidr_block }}</cidrBlock>
<availableIpAddressCount>{{ subnet.available_ip_addresses }}</availableIpAddressCount> <availableIpAddressCount>{{ subnet.available_ip_addresses or '0' }}</availableIpAddressCount>
<availabilityZone>{{ subnet._availability_zone.name }}</availabilityZone> <availabilityZone>{{ subnet._availability_zone.name }}</availabilityZone>
<availabilityZoneId>{{ subnet._availability_zone.zone_id }}</availabilityZoneId> <availabilityZoneId>{{ subnet._availability_zone.zone_id }}</availabilityZoneId>
<defaultForAz>{{ subnet.default_for_az }}</defaultForAz> <defaultForAz>{{ subnet.default_for_az }}</defaultForAz>
<mapPublicIpOnLaunch>{{ subnet.map_public_ip_on_launch }}</mapPublicIpOnLaunch> <mapPublicIpOnLaunch>{{ subnet.map_public_ip_on_launch }}</mapPublicIpOnLaunch>
<ownerId>{{ subnet.owner_id }}</ownerId> <ownerId>{{ subnet.owner_id }}</ownerId>
<assignIpv6AddressOnCreation>{{ subnet.assign_ipv6_address_on_creation }}</assignIpv6AddressOnCreation> <assignIpv6AddressOnCreation>{{ 'false' if not subnet.assign_ipv6_address_on_creation or subnet.assign_ipv6_address_on_creation == 'false' else 'true'}}</assignIpv6AddressOnCreation>
<ipv6CidrBlockAssociationSet>{{ subnet.ipv6_cidr_block_associations }}</ipv6CidrBlockAssociationSet> <ipv6CidrBlockAssociationSet>{{ subnet.ipv6_cidr_block_associations }}</ipv6CidrBlockAssociationSet>
<subnetArn>arn:aws:ec2:{{ subnet._availability_zone.name[0:-1] }}:{{ subnet.owner_id }}:subnet/{{ subnet.id }}</subnetArn> <subnetArn>arn:aws:ec2:{{ subnet._availability_zone.name[0:-1] }}:{{ subnet.owner_id }}:subnet/{{ subnet.id }}</subnetArn>
{% if subnet.get_tags() %} {% if subnet.get_tags() %}

View File

@ -0,0 +1,155 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring
class TransitGatewayAttachment(BaseResponse):
def create_transit_gateway_vpc_attachment(self):
options = self._get_multi_param_dict("Options")
subnet_ids = self._get_multi_param("SubnetIds")
transit_gateway_id = self._get_param("TransitGatewayId")
vpc_id = self._get_param("VpcId")
tags = self._get_multi_param("TagSpecifications")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
transit_gateway_attachment = self.ec2_backend.create_transit_gateway_vpc_attachment(
transit_gateway_id=transit_gateway_id,
tags=tags,
vpc_id=vpc_id,
subnet_ids=subnet_ids,
options=options,
)
template = self.response_template(CREATE_TRANSIT_GATEWAY_VPC_ATTACHMENT)
return template.render(transit_gateway_attachment=transit_gateway_attachment)
def describe_transit_gateway_vpc_attachments(self):
transit_gateways_attachment_ids = self._get_multi_param(
"TransitGatewayAttachmentIds"
)
filters = filters_from_querystring(self.querystring)
max_results = self._get_param("MaxResults")
transit_gateway_vpc_attachments = self.ec2_backend.describe_transit_gateway_vpc_attachments(
transit_gateways_attachment_ids=transit_gateways_attachment_ids,
filters=filters,
max_results=max_results,
)
template = self.response_template(DESCRIBE_TRANSIT_GATEWAY_VPC_ATTACHMENTS)
return template.render(
transit_gateway_vpc_attachments=transit_gateway_vpc_attachments
)
def describe_transit_gateway_attachments(self):
transit_gateways_attachment_ids = self._get_multi_param(
"TransitGatewayAttachmentIds"
)
filters = filters_from_querystring(self.querystring)
max_results = self._get_param("MaxResults")
transit_gateway_attachments = self.ec2_backend.describe_transit_gateway_attachments(
transit_gateways_attachment_ids=transit_gateways_attachment_ids,
filters=filters,
max_results=max_results,
)
template = self.response_template(DESCRIBE_TRANSIT_GATEWAY_ATTACHMENTS)
return template.render(transit_gateway_attachments=transit_gateway_attachments)
CREATE_TRANSIT_GATEWAY_VPC_ATTACHMENT = """<CreateTransitGatewayVpcAttachmentResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>9b5766ac-2af6-4b92-9a8a-4d74ae46ae79</requestId>
<transitGatewayVpcAttachment>
<createTime>{{ transit_gateway_attachment.create_time }}</createTime>
<options>
<applianceModeSupport>{{ transit_gateway_attachment.options.ApplianceModeSupport }}</applianceModeSupport>
<dnsSupport>{{ transit_gateway_attachment.options.DnsSupport }}</dnsSupport>
<ipv6Support>{{ transit_gateway_attachment.options.Ipv6Support }}</ipv6Support>
</options>
<state>{{ transit_gateway_attachment.state }}</state>
<subnetIds>
{% for subnet_id in transit_gateway_attachment.subnet_ids %}
<item>{{ subnet_id }}</item>
{% endfor %}
</subnetIds>
<tagSet>
{% for tag in transit_gateway_attachment.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayAttachmentId>{{ transit_gateway_attachment.id }}</transitGatewayAttachmentId>
<transitGatewayId>{{ transit_gateway_attachment.transit_gateway_id }}</transitGatewayId>
<vpcId>{{ transit_gateway_attachment.vpc_id }}</vpcId>
<vpcOwnerId>{{ transit_gateway_attachment.resource_owner_id }}</vpcOwnerId>
</transitGatewayVpcAttachment>
</CreateTransitGatewayVpcAttachmentResponse>"""
DESCRIBE_TRANSIT_GATEWAY_ATTACHMENTS = """<DescribeTransitGatewayAttachmentsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>92aa7885-74c0-42d1-a846-e59bd07488a7</requestId>
<transitGatewayAttachments>
{% for transit_gateway_attachment in transit_gateway_attachments %}
<item>
<association>
<state>associated</state>
<transitGatewayRouteTableId>tgw-rtb-0b36edb9b88f0d5e3</transitGatewayRouteTableId>
</association>
<creationTime>2021-07-18T08:57:21.000Z</creationTime>
<resourceId>{{ transit_gateway_attachment.resource_id }}</resourceId>
<resourceOwnerId>{{ transit_gateway_attachment.resource_owner_id }}</resourceOwnerId>
<resourceType>{{ transit_gateway_attachment.resource_type }}</resourceType>
<state>{{ transit_gateway_attachment.state }}</state>
<tagSet>
{% for tag in transit_gateway_attachment.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayAttachmentId>{{ transit_gateway_attachment.id }}</transitGatewayAttachmentId>
<transitGatewayId>{{ transit_gateway_attachment.transit_gateway_id }}</transitGatewayId>
<transitGatewayOwnerId>074255357339</transitGatewayOwnerId>
</item>
{% endfor %}
</transitGatewayAttachments>
</DescribeTransitGatewayAttachmentsResponse>
"""
DESCRIBE_TRANSIT_GATEWAY_VPC_ATTACHMENTS = """<DescribeTransitGatewayVpcAttachmentsResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>bebc9670-0205-4f28-ad89-049c97e46633</requestId>
<transitGatewayVpcAttachments>
{% for transit_gateway_vpc_attachment in transit_gateway_vpc_attachments %}
<item>
<creationTime>2021-07-18T08:57:21.000Z</creationTime>
<options>
<applianceModeSupport>{{ transit_gateway_vpc_attachment.options.ApplianceModeSupport }}</applianceModeSupport>
<dnsSupport>{{ transit_gateway_vpc_attachment.options.DnsSupport }}</dnsSupport>
<ipv6Support>{{ transit_gateway_vpc_attachment.options.Ipv6Support }}</ipv6Support>
</options>
<state>{{ transit_gateway_vpc_attachment.state }}</state>
<subnetIds>
{% for id in transit_gateway_vpc_attachment.subnet_ids %}
<item>id</item>
{% endfor %}
</subnetIds>
<tagSet>
{% for tag in transit_gateway_vpc_attachment.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayAttachmentId>{{ transit_gateway_vpc_attachment.id }}</transitGatewayAttachmentId>
<transitGatewayId>{{ transit_gateway_vpc_attachment.transit_gateway_id }}</transitGatewayId>
<vpcId>{{ transit_gateway_vpc_attachment.vpc_id }}</vpcId>
<vpcOwnerId>074255357339</vpcOwnerId>
</item>
{% endfor %}
</transitGatewayVpcAttachments>
</DescribeTransitGatewayVpcAttachmentsResponse>
"""

View File

@ -0,0 +1,187 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring
from moto.utilities.utils import str2bool
class TransitGatewayRouteTable(BaseResponse):
def create_transit_gateway_route_table(self):
transit_gateway_id = self._get_param("TransitGatewayId")
tags = self._get_multi_param("TagSpecifications")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
transit_gateway_route_table = self.ec2_backend.create_transit_gateway_route_table(
transit_gateway_id=transit_gateway_id, tags=tags
)
template = self.response_template(CREATE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE)
return template.render(transit_gateway_route_table=transit_gateway_route_table)
def describe_transit_gateway_route_tables(self):
filters = filters_from_querystring(self.querystring)
transit_gateway_ids = (
self._get_multi_param("TransitGatewayRouteTableIds") or None
)
transit_gateway_route_tables = self.ec2_backend.get_all_transit_gateway_route_tables(
transit_gateway_ids, filters
)
template = self.response_template(DESCRIBE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE)
return template.render(
transit_gateway_route_tables=transit_gateway_route_tables
)
def delete_transit_gateway_route_table(self):
transit_gateway_route_table_id = self._get_param("TransitGatewayRouteTableId")
transit_gateway_route_table = self.ec2_backend.delete_transit_gateway_route_table(
transit_gateway_route_table_id
)
template = self.response_template(DELETE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE)
return template.render(transit_gateway_route_table=transit_gateway_route_table)
def create_transit_gateway_route(self):
transit_gateway_attachment_id = self._get_param("TransitGatewayAttachmentId")
destination_cidr_block = self._get_param("DestinationCidrBlock")
transit_gateway_route_table_id = self._get_param("TransitGatewayRouteTableId")
blackhole = str2bool(self._get_param("Blackhole"))
transit_gateways_route_table = self.ec2_backend.create_transit_gateway_route(
destination_cidr_block=destination_cidr_block,
transit_gateway_route_table_id=transit_gateway_route_table_id,
transit_gateway_attachment_id=transit_gateway_attachment_id,
blackhole=blackhole,
)
template = self.response_template(CREATE_TRANSIT_GATEWAY_ROUTE_RESPONSE)
return template.render(
transit_gateway_route_table=transit_gateways_route_table,
destination_cidr_block=destination_cidr_block,
)
def delete_transit_gateway_route(self):
destination_cidr_block = self._get_param("DestinationCidrBlock")
transit_gateway_route_table_id = self._get_param("TransitGatewayRouteTableId")
transit_gateway_route_table = self.ec2_backend.delete_transit_gateway_route(
destination_cidr_block=destination_cidr_block,
transit_gateway_route_table_id=transit_gateway_route_table_id,
)
template = self.response_template(DELETE_TRANSIT_GATEWAY_ROUTE_RESPONSE)
rendered_template = template.render(
transit_gateway_route_table=transit_gateway_route_table,
destination_cidr_block=destination_cidr_block,
)
del transit_gateway_route_table.routes[destination_cidr_block]
return rendered_template
def search_transit_gateway_routes(self):
transit_gateway_route_table_id = self._get_param("TransitGatewayRouteTableId")
filters = filters_from_querystring(self.querystring)
max_results = self._get_param("MaxResults")
transit_gateway_routes = self.ec2_backend.search_transit_gateway_routes(
transit_gateway_route_table_id=transit_gateway_route_table_id,
filters=filters,
max_results=max_results,
)
template = self.response_template(SEARCH_TRANSIT_GATEWAY_ROUTES_RESPONSE)
return template.render(transit_gateway_routes=transit_gateway_routes)
CREATE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE = """<CreateTransitGatewayRouteTableResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>3a495d25-08d4-466d-822e-477c9b1fc606</requestId>
<transitGatewayRouteTable>
<creationTime>{{ transit_gateway_route_table.create_time }}</creationTime>
<defaultAssociationRouteTable>{{ transit_gateway_route_table.default_association_route_table }}</defaultAssociationRouteTable>
<defaultPropagationRouteTable>{{ transit_gateway_route_table.default_propagation_route_table }}</defaultPropagationRouteTable>
<state>{{ transit_gateway_route_table.state }}</state>
<transitGatewayId>{{ transit_gateway_route_table.transit_gateway_id }}</transitGatewayId>
<transitGatewayRouteTableId>{{ transit_gateway_route_table.id }}</transitGatewayRouteTableId>
<tagSet>
{% for tag in transit_gateway_route_table.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</transitGatewayRouteTable>
</CreateTransitGatewayRouteTableResponse>
"""
DESCRIBE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE = """<DescribeTransitGatewayRouteTablesResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>f9dea58a-7bb3-458b-a40d-0b7ae32eefdb</requestId>
<transitGatewayRouteTables>
{% for transit_gateway_route_table in transit_gateway_route_tables %}
<item>
<creationTime>{{ transit_gateway_route_table.create_time }}</creationTime>
<defaultAssociationRouteTable>{{ transit_gateway_route_table.default_association_route_table }}</defaultAssociationRouteTable>
<defaultPropagationRouteTable>{{ transit_gateway_route_table.default_propagation_route_table }}</defaultPropagationRouteTable>
<state>{{ transit_gateway_route_table.state }}</state>
<tagSet>
{% for tag in transit_gateway_route_table.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayId>{{ transit_gateway_route_table.transit_gateway_id }}</transitGatewayId>
<transitGatewayRouteTableId>{{ transit_gateway_route_table.id }}</transitGatewayRouteTableId>
</item>
{% endfor %}
</transitGatewayRouteTables>
</DescribeTransitGatewayRouteTablesResponse>
"""
DELETE_TRANSIT_GATEWAY_ROUTE_TABLE_RESPONSE = """<DeleteTransitGatewayRouteTableResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>a9a07226-c7b1-4305-9934-0bcfc3ef1c5e</requestId>
<transitGatewayRouteTable>
{% for transit_gateway_route_table in transit_gateway_route_tables %}
<item>
<creationTime>{{ transit_gateway_route_table.create_time }}</creationTime>
<defaultAssociationRouteTable>{{ transit_gateway_route_table.default_association_route_table }}</defaultAssociationRouteTable>
<defaultPropagationRouteTable>{{ transit_gateway_route_table.default_propagation_route_table }}</defaultPropagationRouteTable>
<state>{{ transit_gateway_route_table.state }}</state>
<transitGatewayId>{{ transit_gateway_route_table.transit_gateway_id }}</transitGatewayId>
<transitGatewayRouteTableId>{{ transit_gateway_route_table.id }}</transitGatewayRouteTableId>
</item>
{% endfor %}
</transitGatewayRouteTable>
</DeleteTransitGatewayRouteTableResponse>
"""
CREATE_TRANSIT_GATEWAY_ROUTE_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<CreateTransitGatewayRouteResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>072b02ce-df3a-4de6-a20b-6653ae4b91a4</requestId>
<route>
<destinationCidrBlock>{{ transit_gateway_route_table.routes[destination_cidr_block]['destinationCidrBlock'] }}</destinationCidrBlock>
<state>{{ transit_gateway_route_table.routes[destination_cidr_block]['state'] }}</state>
<type>{{ transit_gateway_route_table.routes[destination_cidr_block]['type'] }}</type>
</route>
</CreateTransitGatewayRouteResponse>
"""
DELETE_TRANSIT_GATEWAY_ROUTE_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<DeleteTransitGatewayRouteResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>2109d5bb-f874-4f35-b419-4723792a638f</requestId>
<route>
<destinationCidrBlock>{{ transit_gateway_route_table.routes[destination_cidr_block]['destinationCidrBlock'] }}</destinationCidrBlock>
<state>{{ transit_gateway_route_table.routes[destination_cidr_block]['state'] }}</state>
<type>{{ transit_gateway_route_table.routes[destination_cidr_block]['type'] }}</type>
</route>
</DeleteTransitGatewayRouteResponse>
"""
SEARCH_TRANSIT_GATEWAY_ROUTES_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<SearchTransitGatewayRoutesResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>04b46ad2-5a0e-46db-afe4-68679a193b48</requestId>
<routeSet>
{% for route in transit_gateway_routes %}
<item>
<destinationCidrBlock>{{ route['destinationCidrBlock'] }}</destinationCidrBlock>
<state>{{ route['state'] }}</state>
<type>{{ route['type'] }}</type>
</item>
{% endfor %}
</routeSet>
<additionalRoutesAvailable>false</additionalRoutesAvailable>
</SearchTransitGatewayRoutesResponse>
"""

View File

@ -0,0 +1,157 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring
class TransitGateways(BaseResponse):
def create_transit_gateway(self):
description = self._get_param("Description") or None
options = self._get_multi_param_dict("Options")
tags = self._get_multi_param("TagSpecification")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
transit_gateway = self.ec2_backend.create_transit_gateway(
description=description, options=options, tags=tags
)
template = self.response_template(CREATE_TRANSIT_GATEWAY_RESPONSE)
return template.render(transit_gateway=transit_gateway)
def delete_transit_gateway(self):
transit_gateway_id = self._get_param("TransitGatewayId")
transit_gateway = self.ec2_backend.delete_transit_gateway(transit_gateway_id)
template = self.response_template(DELETE_TRANSIT_GATEWAY_RESPONSE)
return template.render(transit_gateway=transit_gateway)
def describe_transit_gateways(self):
filters = filters_from_querystring(self.querystring)
transit_gateways = self.ec2_backend.get_all_transit_gateways(filters)
template = self.response_template(DESCRIBE_TRANSIT_GATEWAY_RESPONSE)
return template.render(transit_gateways=transit_gateways)
def modify_transit_gateway(self):
transit_gateway_id = self._get_param("TransitGatewayId")
description = self._get_param("Description") or None
options = self._get_multi_param_dict("Options")
transit_gateway = self.ec2_backend.modify_transit_gateway(
transit_gateway_id=transit_gateway_id,
description=description,
options=options,
)
template = self.response_template(MODIFY_TRANSIT_GATEWAY_RESPONSE)
return template.render(transit_gateway=transit_gateway)
CREATE_TRANSIT_GATEWAY_RESPONSE = """<CreateTransitGatewayResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>151283df-f7dc-4317-89b4-01c9888b1d45</requestId>
<transitGateway>
<transitGatewayId>{{ transit_gateway.id }}</transitGatewayId>
<ownerId>{{ transit_gateway.owner_id }}</ownerId>
<description>{{ transit_gateway.description or '' }}</description>
<createTime>{{ transit_gateway.create_time }}</createTime>
<state>{{ transit_gateway.state }}</state>
{% if transit_gateway.options %}
<options>
<amazonSideAsn>{{ transit_gateway.options.AmazonSideAsn }}</amazonSideAsn>
<autoAcceptSharedAttachments>{{ transit_gateway.options.AutoAcceptSharedAttachments }}</autoAcceptSharedAttachments>
<defaultRouteTableAssociation>{{ transit_gateway.options.DefaultRouteTableAssociation }}</defaultRouteTableAssociation>
<defaultRouteTablePropagation>{{ transit_gateway.options.DefaultRouteTablePropagation }}</defaultRouteTablePropagation>
<dnsSupport>{{ transit_gateway.options.DnsSupport }}</dnsSupport>
<propagationDefaultRouteTableId>{{ transit_gateway.options.PropagationDefaultRouteTableId }}</propagationDefaultRouteTableId>
<vpnEcmpSupport>{{ transit_gateway.options.VpnEcmpSupport }}</vpnEcmpSupport>
<transitGatewayCidrBlocks>{{ transit_gateway.options.TransitGatewayCidrBlocks }}</transitGatewayCidrBlocks>
</options>
{% endif %}
<tagSet>
{% for tag in transit_gateway.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
</transitGateway>
</CreateTransitGatewayResponse>
"""
DESCRIBE_TRANSIT_GATEWAY_RESPONSE = """<DescribeTransitGatewaysResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>151283df-f7dc-4317-89b4-01c9888b1d45</requestId>
<transitGatewaySet>
{% for transit_gateway in transit_gateways %}
<item>
<creationTime>{{ transit_gateway.create_time }}</creationTime>
<description>{{ transit_gateway.description or '' }}</description>
{% if transit_gateway.options %}
<options>
<amazonSideAsn>{{ transit_gateway.options.AmazonSideAsn }}</amazonSideAsn>
<associationDefaultRouteTableId>{{ transit_gateway.options.AssociationDefaultRouteTableId }}</associationDefaultRouteTableId>
<autoAcceptSharedAttachments>{{ transit_gateway.options.AutoAcceptSharedAttachments }}</autoAcceptSharedAttachments>
<defaultRouteTableAssociation>{{ transit_gateway.options.DefaultRouteTableAssociation }}</defaultRouteTableAssociation>
<defaultRouteTablePropagation>{{ transit_gateway.options.DefaultRouteTablePropagation }}</defaultRouteTablePropagation>
<dnsSupport>{{ transit_gateway.options.DnsSupport }}</dnsSupport>
<propagationDefaultRouteTableId>{{ transit_gateway.options.PropagationDefaultRouteTableId }}</propagationDefaultRouteTableId>
<vpnEcmpSupport>{{ transit_gateway.options.VpnEcmpSupport }}</vpnEcmpSupport>
<transitGatewayCidrBlocks>{{ transit_gateway.options.TransitGatewayCidrBlocks }}</transitGatewayCidrBlocks>
</options>
{% endif %}
<ownerId>{{ transit_gateway.owner_id }}</ownerId>
<state>{{ transit_gateway.state }}</state>
<tagSet>
{% for tag in transit_gateway.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayArn>arn:aws:ec2:us-east-1:{{ transit_gateway.owner_id }}:transit-gateway/{{ transit_gateway.id }}</transitGatewayArn>
<transitGatewayId>{{ transit_gateway.id }}</transitGatewayId>
</item>
{% endfor %}
</transitGatewaySet>
</DescribeTransitGatewaysResponse>
"""
DELETE_TRANSIT_GATEWAY_RESPONSE = """<DeleteTransitGatewayResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>151283df-f7dc-4317-89b4-01c9888b1d45</requestId>
<transitGatewayId>{{ transit_gateway.id }}</transitGatewayId>
</DeleteTransitGatewayResponse>
"""
MODIFY_TRANSIT_GATEWAY_RESPONSE = """<ModifyTransitGatewaysResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>151283df-f7dc-4317-89b4-01c9888b1d45</requestId>
<transitGatewaySet>
<item>
<creationTime>{{ transit_gateway.create_time }}</creationTime>
<description>{{ transit_gateway.description or '' }}</description>
{% if transit_gateway.options %}
<options>
<amazonSideAsn>{{ transit_gateway.options.AmazonSideAsn }}</amazonSideAsn>
<associationDefaultRouteTableId>{{ transit_gateway.options.AssociationDefaultRouteTableId }}</associationDefaultRouteTableId>
<autoAcceptSharedAttachments>{{ transit_gateway.options.AutoAcceptSharedAttachments }}</autoAcceptSharedAttachments>
<defaultRouteTableAssociation>{{ transit_gateway.options.DefaultRouteTableAssociation }}</defaultRouteTableAssociation>
<defaultRouteTablePropagation>{{ transit_gateway.options.DefaultRouteTablePropagation }}</defaultRouteTablePropagation>
<dnsSupport>{{ transit_gateway.options.DnsSupport }}</dnsSupport>
<propagationDefaultRouteTableId>{{ transit_gateway.options.PropagationDefaultRouteTableId }}</propagationDefaultRouteTableId>
<vpnEcmpSupport>{{ transit_gateway.options.VpnEcmpSupport }}</vpnEcmpSupport>
<transitGatewayCidrBlocks>{{ transit_gateway.options.TransitGatewayCidrBlocks }}</transitGatewayCidrBlocks>
</options>
{% endif %}
<ownerId>{{ transit_gateway.owner_id }}</ownerId>
<state>{{ transit_gateway.state }}</state>
<tagSet>
{% for tag in transit_gateway.get_tags() %}
<item>
<key>{{ tag.key }}</key>
<value>{{ tag.value }}</value>
</item>
{% endfor %}
</tagSet>
<transitGatewayArn>arn:aws:ec2:us-east-1:{{ transit_gateway.owner_id }}:transit-gateway/{{ transit_gateway.id }}</transitGatewayArn>
<transitGatewayId>{{ transit_gateway.id }}</transitGatewayId>
</item>
</transitGatewaySet>
</ModifyTransitGatewaysResponse>
"""

View File

@ -13,7 +13,18 @@ class VirtualPrivateGateways(BaseResponse):
def create_vpn_gateway(self): def create_vpn_gateway(self):
type = self._get_param("Type") type = self._get_param("Type")
vpn_gateway = self.ec2_backend.create_vpn_gateway(type) amazon_side_asn = self._get_param("AmazonSideAsn")
availability_zone = self._get_param("AvailabilityZone")
tags = self._get_multi_param("TagSpecification")
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
vpn_gateway = self.ec2_backend.create_vpn_gateway(
type=type,
amazon_side_asn=amazon_side_asn,
availability_zone=availability_zone,
tags=tags,
)
template = self.response_template(CREATE_VPN_GATEWAY_RESPONSE) template = self.response_template(CREATE_VPN_GATEWAY_RESPONSE)
return template.render(vpn_gateway=vpn_gateway) return template.render(vpn_gateway=vpn_gateway)
@ -44,7 +55,7 @@ CREATE_VPN_GATEWAY_RESPONSE = """
<vpnGatewayId>{{ vpn_gateway.id }}</vpnGatewayId> <vpnGatewayId>{{ vpn_gateway.id }}</vpnGatewayId>
<state>available</state> <state>available</state>
<type>{{ vpn_gateway.type }}</type> <type>{{ vpn_gateway.type }}</type>
<availabilityZone>us-east-1a</availabilityZone> <availabilityZone>{{ vpn_gateway.availability_zone }}</availabilityZone>
<attachments/> <attachments/>
<tagSet> <tagSet>
{% for tag in vpn_gateway.get_tags() %} {% for tag in vpn_gateway.get_tags() %}

View File

@ -17,16 +17,16 @@ class VPCs(BaseResponse):
cidr_block = self._get_param("CidrBlock") cidr_block = self._get_param("CidrBlock")
tags = self._get_multi_param("TagSpecification") tags = self._get_multi_param("TagSpecification")
instance_tenancy = self._get_param("InstanceTenancy", if_none="default") instance_tenancy = self._get_param("InstanceTenancy", if_none="default")
amazon_provided_ipv6_cidr_blocks = self._get_param( amazon_provided_ipv6_cidr_block = self._get_param(
"AmazonProvidedIpv6CidrBlock" "AmazonProvidedIpv6CidrBlock"
) ) in ["true", "True"]
if tags: if tags:
tags = tags[0].get("Tag") tags = tags[0].get("Tag")
vpc = self.ec2_backend.create_vpc( vpc = self.ec2_backend.create_vpc(
cidr_block, cidr_block,
instance_tenancy, instance_tenancy,
amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_blocks, amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_block,
tags=tags, tags=tags,
) )
doc_date = self._get_doc_date() doc_date = self._get_doc_date()
@ -178,8 +178,8 @@ class VPCs(BaseResponse):
policy_document = self._get_param("PolicyDocument") policy_document = self._get_param("PolicyDocument")
client_token = self._get_param("ClientToken") client_token = self._get_param("ClientToken")
tag_specifications = self._get_param("TagSpecifications") tag_specifications = self._get_param("TagSpecifications")
private_dns_enabled = self._get_bool_param("PrivateDNSEnabled", if_none=True) private_dns_enabled = self._get_bool_param("PrivateDnsEnabled", if_none=True)
security_group = self._get_param("SecurityGroup") security_group_ids = self._get_multi_param("SecurityGroupId")
vpc_end_point = self.ec2_backend.create_vpc_endpoint( vpc_end_point = self.ec2_backend.create_vpc_endpoint(
vpc_id=vpc_id, vpc_id=vpc_id,
@ -189,7 +189,7 @@ class VPCs(BaseResponse):
route_table_ids=route_table_ids, route_table_ids=route_table_ids,
subnet_ids=subnet_ids, subnet_ids=subnet_ids,
client_token=client_token, client_token=client_token,
security_group=security_group, security_group_ids=security_group_ids,
tag_specifications=tag_specifications, tag_specifications=tag_specifications,
private_dns_enabled=private_dns_enabled, private_dns_enabled=private_dns_enabled,
) )
@ -479,8 +479,8 @@ DESCRIBE_VPC_ENDPOINT_SERVICES_RESPONSE = """<DescribeVpcEndpointServicesRespons
{% endfor %} {% endfor %}
</serviceNameSet> </serviceNameSet>
<serviceDetailSet> <serviceDetailSet>
<item>
{% for service in vpc_end_points.servicesDetails %} {% for service in vpc_end_points.servicesDetails %}
<item>
<owner>amazon</owner> <owner>amazon</owner>
<serviceType> <serviceType>
<item> <item>
@ -498,8 +498,8 @@ DESCRIBE_VPC_ENDPOINT_SERVICES_RESPONSE = """<DescribeVpcEndpointServicesRespons
</availabilityZoneSet> </availabilityZoneSet>
<serviceName>{{ service.service_name }}</serviceName> <serviceName>{{ service.service_name }}</serviceName>
<vpcEndpointPolicySupported>true</vpcEndpointPolicySupported> <vpcEndpointPolicySupported>true</vpcEndpointPolicySupported>
{% endfor %}
</item> </item>
{% endfor %}
</serviceDetailSet> </serviceDetailSet>
</DescribeVpcEndpointServicesResponse>""" </DescribeVpcEndpointServicesResponse>"""
@ -545,12 +545,15 @@ DESCRIBE_VPC_ENDPOINT_RESPONSE = """<DescribeVpcEndpointsResponse xmlns="http://
{% endfor %} {% endfor %}
</dnsEntries> </dnsEntries>
{% endif %} {% endif %}
{% if vpc_end_point.groups %} {% if vpc_end_point.security_group_ids %}
<groups> <groupSet>
{% for group in vpc_end_point.groups %} {% for group_id in vpc_end_point.security_group_ids %}
<item>{{ group }}</item> <item>
<groupId>{{ group_id }}</groupId>
<groupName>TODO</groupName>
</item>
{% endfor %} {% endfor %}
</groups> </groupSet>
{% endif %} {% endif %}
{% if vpc_end_point.tag_specifications %} {% if vpc_end_point.tag_specifications %}
<tagSet> <tagSet>

View File

@ -1,6 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import filters_from_querystring from moto.ec2.utils import filters_from_querystring, add_tag_specification
from xml.sax.saxutils import escape
class VPNConnections(BaseResponse): class VPNConnections(BaseResponse):
@ -8,9 +9,20 @@ class VPNConnections(BaseResponse):
type = self._get_param("Type") type = self._get_param("Type")
cgw_id = self._get_param("CustomerGatewayId") cgw_id = self._get_param("CustomerGatewayId")
vgw_id = self._get_param("VpnGatewayId") vgw_id = self._get_param("VpnGatewayId")
tgw_id = self._get_param("TransitGatewayId")
static_routes = self._get_param("StaticRoutesOnly") static_routes = self._get_param("StaticRoutesOnly")
tags = add_tag_specification(self._get_multi_param("TagSpecification"))
vpn_connection = self.ec2_backend.create_vpn_connection( vpn_connection = self.ec2_backend.create_vpn_connection(
type, cgw_id, vgw_id, static_routes_only=static_routes type,
cgw_id,
vpn_gateway_id=vgw_id,
transit_gateway_id=tgw_id,
static_routes_only=static_routes,
tags=tags,
)
if vpn_connection.transit_gateway_id:
self.ec2_backend.create_transit_gateway_vpn_attachment(
vpn_id=vpn_connection.id, transit_gateway_id=tgw_id
) )
template = self.response_template(CREATE_VPN_CONNECTION_RESPONSE) template = self.response_template(CREATE_VPN_CONNECTION_RESPONSE)
return template.render(vpn_connection=vpn_connection) return template.render(vpn_connection=vpn_connection)
@ -18,6 +30,13 @@ class VPNConnections(BaseResponse):
def delete_vpn_connection(self): def delete_vpn_connection(self):
vpn_connection_id = self._get_param("VpnConnectionId") vpn_connection_id = self._get_param("VpnConnectionId")
vpn_connection = self.ec2_backend.delete_vpn_connection(vpn_connection_id) vpn_connection = self.ec2_backend.delete_vpn_connection(vpn_connection_id)
if vpn_connection.transit_gateway_id:
transit_gateway_attachments = (
self.ec2_backend.describe_transit_gateway_attachments()
)
for attachment in transit_gateway_attachments:
if attachment.resource_id == vpn_connection.id:
attachment.state = "deleted"
template = self.response_template(DELETE_VPN_CONNECTION_RESPONSE) template = self.response_template(DELETE_VPN_CONNECTION_RESPONSE)
return template.render(vpn_connection=vpn_connection) return template.render(vpn_connection=vpn_connection)
@ -31,17 +50,11 @@ class VPNConnections(BaseResponse):
return template.render(vpn_connections=vpn_connections) return template.render(vpn_connections=vpn_connections)
CREATE_VPN_CONNECTION_RESPONSE = """ CUSTOMER_GATEWAY_CONFIGURATION_TEMPLATE = """
<CreateVpnConnectionResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<vpnConnection>
<vpnConnectionId>{{ vpn_connection.id }}</vpnConnectionId>
<state>pending</state>
<customerGatewayConfiguration>
<vpn_connection id="{{ vpn_connection.id }}"> <vpn_connection id="{{ vpn_connection.id }}">
<customer_gateway_id>{{ vpn_connection.customer_gateway_id }}</customer_gateway_id> <customer_gateway_id>{{ vpn_connection.customer_gateway_id }}</customer_gateway_id>
<vpn_gateway_id>{{ vpn_connection.vpn_gateway_id }}</vpn_gateway_id> <vpn_gateway_id> {{ vpn_connection.vpn_gateway_id if vpn_connection.vpn_gateway_id is not none }} </vpn_gateway_id>
<vpn_connection_type>ipsec.1</vpn_connection_type> <vpn_connection_type>{{ vpn_connection.type }}</vpn_connection_type>
<ipsec_tunnel> <ipsec_tunnel>
<customer_gateway> <customer_gateway>
<tunnel_outside_address> <tunnel_outside_address>
@ -149,15 +162,29 @@ CREATE_VPN_CONNECTION_RESPONSE = """
</ipsec> </ipsec>
</ipsec_tunnel> </ipsec_tunnel>
</vpn_connection> </vpn_connection>
"""
CREATE_VPN_CONNECTION_RESPONSE = (
"""
<CreateVpnConnectionResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<vpnConnection>
<vpnConnectionId>{{ vpn_connection.id }}</vpnConnectionId>
<state>{{ vpn_connection.state }}</state>
<customerGatewayConfiguration>
"""
+ escape(CUSTOMER_GATEWAY_CONFIGURATION_TEMPLATE)
+ """
</customerGatewayConfiguration> </customerGatewayConfiguration>
<type>ipsec.1</type> <type>ipsec.1</type>
<customerGatewayId>{{ vpn_connection.customer_gateway_id }}</customerGatewayId> <customerGatewayId>{{ vpn_connection.customer_gateway_id }}</customerGatewayId>
<vpnGatewayId>{{ vpn_connection.vpn_gateway_id }}</vpnGatewayId> <vpnGatewayId>{{ vpn_connection.vpn_gateway_id or '' }}</vpnGatewayId>
{% if vpn_connection.transit_gateway_id %}
<transitGatewayId>{{ vpn_connection.transit_gateway_id }}</transitGatewayId>
{% endif %}
<tagSet> <tagSet>
{% for tag in vpn_connection.get_tags() %} {% for tag in vpn_connection.get_tags() %}
<item> <item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key> <key>{{ tag.key }}</key>
<value>{{ tag.value }}</value> <value>{{ tag.value }}</value>
</item> </item>
@ -165,6 +192,8 @@ CREATE_VPN_CONNECTION_RESPONSE = """
</tagSet> </tagSet>
</vpnConnection> </vpnConnection>
</CreateVpnConnectionResponse>""" </CreateVpnConnectionResponse>"""
)
CREATE_VPN_CONNECTION_ROUTE_RESPONSE = """ CREATE_VPN_CONNECTION_ROUTE_RESPONSE = """
<CreateVpnConnectionRouteResponse xmlns="http://ec2.amazonaws.com/doc/2013-10- 15/"> <CreateVpnConnectionRouteResponse xmlns="http://ec2.amazonaws.com/doc/2013-10- 15/">
@ -184,135 +213,29 @@ DELETE_VPN_CONNECTION_ROUTE_RESPONSE = """
<return>true</return> <return>true</return>
</DeleteVpnConnectionRouteResponse>""" </DeleteVpnConnectionRouteResponse>"""
DESCRIBE_VPN_CONNECTION_RESPONSE = """ DESCRIBE_VPN_CONNECTION_RESPONSE = (
"""
<DescribeVpnConnectionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> <DescribeVpnConnectionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId> <requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<vpnConnectionSet> <vpnConnectionSet>
{% for vpn_connection in vpn_connections %} {% for vpn_connection in vpn_connections %}
<item> <item>
<vpnConnectionId>{{ vpn_connection.id }}</vpnConnectionId> <vpnConnectionId>{{ vpn_connection.id }}</vpnConnectionId>
<state>available</state> <state>{{ vpn_connection.state }}</state>
<customerGatewayConfiguration> <customerGatewayConfiguration>
<vpn_connection id="{{ vpn_connection.id }}"> """
<customer_gateway_id>{{ vpn_connection.customer_gateway_id }}</customer_gateway_id> + escape(CUSTOMER_GATEWAY_CONFIGURATION_TEMPLATE)
<vpn_gateway_id>{{ vpn_connection.vpn_gateway_id }}</vpn_gateway_id> + """
<vpn_connection_type>ipsec.1</vpn_connection_type>
<ipsec_tunnel>
<customer_gateway>
<tunnel_outside_address>
<ip_address>12.1.2.3</ip_address>
</tunnel_outside_address>
<tunnel_inside_address>
<ip_address>169.254.44.42</ip_address>
<network_mask>255.255.255.252</network_mask>
<network_cidr>30</network_cidr>
</tunnel_inside_address>
<bgp>
<asn>65000</asn>
<hold_time>30</hold_time>
</bgp>
</customer_gateway>
<vpn_gateway>
<tunnel_outside_address>
<ip_address>52.2.144.13</ip_address>
</tunnel_outside_address>
<tunnel_inside_address>
<ip_address>169.254.44.41</ip_address>
<network_mask>255.255.255.252</network_mask>
<network_cidr>30</network_cidr>
</tunnel_inside_address>
<bgp>
<asn>7224</asn>
<hold_time>30</hold_time>
</bgp>
</vpn_gateway>
<ike>
<authentication_protocol>sha1</authentication_protocol>
<encryption_protocol>aes-128-cbc</encryption_protocol>
<lifetime>28800</lifetime>
<perfect_forward_secrecy>group2</perfect_forward_secrecy>
<mode>main</mode>
<pre_shared_key>Iw2IAN9XUsQeYUrkMGP3kP59ugFDkfHg</pre_shared_key>
</ike>
<ipsec>
<protocol>esp</protocol>
<authentication_protocol>hmac-sha1-96</authentication_protocol>
<encryption_protocol>aes-128-cbc</encryption_protocol>
<lifetime>3600</lifetime>
<perfect_forward_secrecy>group2</perfect_forward_secrecy>
<mode>tunnel</mode>
<clear_df_bit>true</clear_df_bit>
<fragmentation_before_encryption>true</fragmentation_before_encryption>
<tcp_mss_adjustment>1387</tcp_mss_adjustment>
<dead_peer_detection>
<interval>10</interval>
<retries>3</retries>
</dead_peer_detection>
</ipsec>
</ipsec_tunnel>
<ipsec_tunnel>
<customer_gateway>
<tunnel_outside_address>
<ip_address>12.1.2.3</ip_address>
</tunnel_outside_address>
<tunnel_inside_address>
<ip_address>169.254.44.42</ip_address>
<network_mask>255.255.255.252</network_mask>
<network_cidr>30</network_cidr>
</tunnel_inside_address>
<bgp>
<asn>65000</asn>
<hold_time>30</hold_time>
</bgp>
</customer_gateway>
<vpn_gateway>
<tunnel_outside_address>
<ip_address>52.2.144.13</ip_address>
</tunnel_outside_address>
<tunnel_inside_address>
<ip_address>169.254.44.41</ip_address>
<network_mask>255.255.255.252</network_mask>
<network_cidr>30</network_cidr>
</tunnel_inside_address>
<bgp>
<asn>7224</asn>
<hold_time>30</hold_time>
</bgp>
</vpn_gateway>
<ike>
<authentication_protocol>sha1</authentication_protocol>
<encryption_protocol>aes-128-cbc</encryption_protocol>
<lifetime>28800</lifetime>
<perfect_forward_secrecy>group2</perfect_forward_secrecy>
<mode>main</mode>
<pre_shared_key>Iw2IAN9XUsQeYUrkMGP3kP59ugFDkfHg</pre_shared_key>
</ike>
<ipsec>
<protocol>esp</protocol>
<authentication_protocol>hmac-sha1-96</authentication_protocol>
<encryption_protocol>aes-128-cbc</encryption_protocol>
<lifetime>3600</lifetime>
<perfect_forward_secrecy>group2</perfect_forward_secrecy>
<mode>tunnel</mode>
<clear_df_bit>true</clear_df_bit>
<fragmentation_before_encryption>true</fragmentation_before_encryption>
<tcp_mss_adjustment>1387</tcp_mss_adjustment>
<dead_peer_detection>
<interval>10</interval>
<retries>3</retries>
</dead_peer_detection>
</ipsec>
</ipsec_tunnel>
</vpn_connection>
</customerGatewayConfiguration> </customerGatewayConfiguration>
<type>ipsec.1</type> <type>ipsec.1</type>
<customerGatewayId>{{ vpn_connection.customer_gateway_id }}</customerGatewayId> <customerGatewayId>{{ vpn_connection.customer_gateway_id }}</customerGatewayId>
<vpnGatewayId>{{ vpn_connection.vpn_gateway_id }}</vpnGatewayId> <vpnGatewayId>{{ vpn_connection.vpn_gateway_id or '' }}</vpnGatewayId>
{% if vpn_connection.transit_gateway_id %}
<transitGatewayId>{{ vpn_connection.transit_gateway_id }}</transitGatewayId>
{% endif %}
<tagSet> <tagSet>
{% for tag in vpn_connection.get_tags() %} {% for tag in vpn_connection.get_tags() %}
<item> <item>
<resourceId>{{ tag.resource_id }}</resourceId>
<resourceType>{{ tag.resource_type }}</resourceType>
<key>{{ tag.key }}</key> <key>{{ tag.key }}</key>
<value>{{ tag.value }}</value> <value>{{ tag.value }}</value>
</item> </item>
@ -322,3 +245,4 @@ DESCRIBE_VPN_CONNECTION_RESPONSE = """
{% endfor %} {% endfor %}
</vpnConnectionSet> </vpnConnectionSet>
</DescribeVpnConnectionsResponse>""" </DescribeVpnConnectionsResponse>"""
)

View File

@ -15,6 +15,9 @@ from moto.iam import iam_backends
EC2_RESOURCE_TO_PREFIX = { EC2_RESOURCE_TO_PREFIX = {
"customer-gateway": "cgw", "customer-gateway": "cgw",
"transit-gateway": "tgw",
"transit-gateway-route-table": "tgw-rtb",
"transit-gateway-attachment": "tgw-attach",
"dhcp-options": "dopt", "dhcp-options": "dopt",
"flow-logs": "fl", "flow-logs": "fl",
"image": "ami", "image": "ami",
@ -168,6 +171,22 @@ def random_nat_gateway_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["nat-gateway"], size=17) return random_id(prefix=EC2_RESOURCE_TO_PREFIX["nat-gateway"], size=17)
def random_transit_gateway_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["transit-gateway"], size=17)
def random_transit_gateway_route_table_id():
return random_id(
prefix=EC2_RESOURCE_TO_PREFIX["transit-gateway-route-table"], size=17
)
def random_transit_gateway_attachment_id():
return random_id(
prefix=EC2_RESOURCE_TO_PREFIX["transit-gateway-attachment"], size=17
)
def random_launch_template_id(): def random_launch_template_id():
return random_id(prefix=EC2_RESOURCE_TO_PREFIX["launch-template"], size=17) return random_id(prefix=EC2_RESOURCE_TO_PREFIX["launch-template"], size=17)
@ -207,7 +226,7 @@ def generate_route_id(route_table_id, cidr_block, ipv6_cidr_block=None):
def generate_vpc_end_point_id(vpc_id): def generate_vpc_end_point_id(vpc_id):
return "%s-%s" % ("vpce", vpc_id[4:]) return "%s-%s%s" % ("vpce", vpc_id[4:], random_resource_id(4))
def create_dns_entries(service_name, vpc_endpoint_id): def create_dns_entries(service_name, vpc_endpoint_id):
@ -342,6 +361,13 @@ def get_obj_tag_values(obj):
return tags return tags
def add_tag_specification(tags):
tags = tags[0] if isinstance(tags, list) and len(tags) == 1 else tags
tags = (tags or {}).get("Tag", [])
tags = {t["Key"]: t["Value"] for t in tags}
return tags
def tag_filter_matches(obj, filter_name, filter_values): def tag_filter_matches(obj, filter_name, filter_values):
regex_filters = [re.compile(simple_aws_filter_to_re(f)) for f in filter_values] regex_filters = [re.compile(simple_aws_filter_to_re(f)) for f in filter_values]
if filter_name == "tag-key": if filter_name == "tag-key":
@ -515,6 +541,11 @@ def random_key_pair():
def get_prefix(resource_id): def get_prefix(resource_id):
resource_id_prefix, separator, after = resource_id.partition("-") resource_id_prefix, separator, after = resource_id.partition("-")
if resource_id_prefix == EC2_RESOURCE_TO_PREFIX["transit-gateway"]:
if after.startswith("rtb"):
resource_id_prefix = EC2_RESOURCE_TO_PREFIX["transit-gateway-route-table"]
if after.startswith("attach"):
resource_id_prefix = EC2_RESOURCE_TO_PREFIX["transit-gateway-attachment"]
if resource_id_prefix == EC2_RESOURCE_TO_PREFIX["network-interface"]: if resource_id_prefix == EC2_RESOURCE_TO_PREFIX["network-interface"]:
if after.startswith("attach"): if after.startswith("attach"):
resource_id_prefix = EC2_RESOURCE_TO_PREFIX["network-interface-attachment"] resource_id_prefix = EC2_RESOURCE_TO_PREFIX["network-interface-attachment"]

View File

@ -307,6 +307,7 @@ class Service(BaseObject, CloudFormationModel):
tags=None, tags=None,
deployment_controller=None, deployment_controller=None,
launch_type=None, launch_type=None,
service_registries=None,
): ):
self.cluster_arn = cluster.arn self.cluster_arn = cluster.arn
self.arn = "arn:aws:ecs:{0}:{1}:service/{2}".format( self.arn = "arn:aws:ecs:{0}:{1}:service/{2}".format(
@ -324,6 +325,7 @@ class Service(BaseObject, CloudFormationModel):
self.deployment_controller = deployment_controller or {"type": "ECS"} self.deployment_controller = deployment_controller or {"type": "ECS"}
self.events = [] self.events = []
self.launch_type = launch_type self.launch_type = launch_type
self.service_registries = service_registries or []
if self.deployment_controller["type"] == "ECS": if self.deployment_controller["type"] == "ECS":
self.deployments = [ self.deployments = [
{ {
@ -1076,6 +1078,7 @@ class EC2ContainerServiceBackend(BaseBackend):
tags=None, tags=None,
deployment_controller=None, deployment_controller=None,
launch_type=None, launch_type=None,
service_registries=None,
): ):
cluster = self._get_cluster(cluster_str) cluster = self._get_cluster(cluster_str)
@ -1101,6 +1104,7 @@ class EC2ContainerServiceBackend(BaseBackend):
tags, tags,
deployment_controller, deployment_controller,
launch_type, launch_type,
service_registries=service_registries,
) )
cluster_service_pair = "{0}:{1}".format(cluster.name, service_name) cluster_service_pair = "{0}:{1}".format(cluster.name, service_name)
self.services[cluster_service_pair] = service self.services[cluster_service_pair] = service

View File

@ -176,6 +176,7 @@ class EC2ContainerServiceResponse(BaseResponse):
desired_count = self._get_int_param("desiredCount") desired_count = self._get_int_param("desiredCount")
load_balancers = self._get_param("loadBalancers") load_balancers = self._get_param("loadBalancers")
scheduling_strategy = self._get_param("schedulingStrategy") scheduling_strategy = self._get_param("schedulingStrategy")
service_registries = self._get_param("serviceRegistries")
tags = self._get_param("tags") tags = self._get_param("tags")
deployment_controller = self._get_param("deploymentController") deployment_controller = self._get_param("deploymentController")
launch_type = self._get_param("launchType") launch_type = self._get_param("launchType")
@ -189,6 +190,7 @@ class EC2ContainerServiceResponse(BaseResponse):
tags, tags,
deployment_controller, deployment_controller,
launch_type, launch_type,
service_registries=service_registries,
) )
return json.dumps({"service": service.response_object}) return json.dumps({"service": service.response_object})

View File

@ -103,11 +103,20 @@ class PriorityInUseError(ELBClientError):
class InvalidConditionFieldError(ELBClientError): class InvalidConditionFieldError(ELBClientError):
VALID_FIELDS = [
"path-pattern",
"host-header",
"http-header",
"http-request-method",
"query-string",
"source-ip",
]
def __init__(self, invalid_name): def __init__(self, invalid_name):
super(InvalidConditionFieldError, self).__init__( super(InvalidConditionFieldError, self).__init__(
"ValidationError", "ValidationError",
"Condition field '%s' must be one of '[path-pattern, host-header]" "Condition field '%s' must be one of '[%s]'"
% (invalid_name), % (invalid_name, ",".join(self.VALID_FIELDS)),
) )
@ -132,6 +141,14 @@ class ActionTargetGroupNotFoundError(ELBClientError):
) )
class ListenerOrBalancerMissingError(ELBClientError):
def __init__(self, arn):
super(ListenerOrBalancerMissingError, self).__init__(
"ValidationError",
"You must specify either listener ARNs or a load balancer ARN",
)
class InvalidDescribeRulesRequest(ELBClientError): class InvalidDescribeRulesRequest(ELBClientError):
def __init__(self, msg): def __init__(self, msg):
super(InvalidDescribeRulesRequest, self).__init__("ValidationError", msg) super(InvalidDescribeRulesRequest, self).__init__("ValidationError", msg)

View File

@ -70,6 +70,7 @@ class FakeTargetGroup(CloudFormationModel):
healthcheck_path=None, healthcheck_path=None,
healthcheck_interval_seconds=None, healthcheck_interval_seconds=None,
healthcheck_timeout_seconds=None, healthcheck_timeout_seconds=None,
healthcheck_enabled=None,
healthy_threshold_count=None, healthy_threshold_count=None,
unhealthy_threshold_count=None, unhealthy_threshold_count=None,
matcher=None, matcher=None,
@ -82,19 +83,21 @@ class FakeTargetGroup(CloudFormationModel):
self.vpc_id = vpc_id self.vpc_id = vpc_id
self.protocol = protocol self.protocol = protocol
self.port = port self.port = port
self.healthcheck_protocol = healthcheck_protocol or "HTTP" self.healthcheck_protocol = healthcheck_protocol or self.protocol
self.healthcheck_port = healthcheck_port or str(self.port) self.healthcheck_port = healthcheck_port
self.healthcheck_path = healthcheck_path or "/" self.healthcheck_path = healthcheck_path
self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30 self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30
self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or 5 self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or 5
self.healthcheck_enabled = healthcheck_enabled
self.healthy_threshold_count = healthy_threshold_count or 5 self.healthy_threshold_count = healthy_threshold_count or 5
self.unhealthy_threshold_count = unhealthy_threshold_count or 2 self.unhealthy_threshold_count = unhealthy_threshold_count or 2
self.load_balancer_arns = [] self.load_balancer_arns = []
self.tags = {} self.tags = {}
if matcher is None:
self.matcher = {"HttpCode": "200"}
else:
self.matcher = matcher self.matcher = matcher
if self.protocol != "TCP":
self.matcher = self.matcher or {"HttpCode": "200"}
self.healthcheck_path = self.healthcheck_path or "/"
self.healthcheck_port = self.healthcheck_port or str(self.port)
self.target_type = target_type self.target_type = target_type
self.attributes = { self.attributes = {
@ -209,7 +212,7 @@ class FakeListener(CloudFormationModel):
): ):
self.load_balancer_arn = load_balancer_arn self.load_balancer_arn = load_balancer_arn
self.arn = arn self.arn = arn
self.protocol = protocol.upper() self.protocol = (protocol or "").upper()
self.port = port self.port = port
self.ssl_policy = ssl_policy self.ssl_policy = ssl_policy
self.certificate = certificate self.certificate = certificate
@ -224,6 +227,7 @@ class FakeListener(CloudFormationModel):
actions=default_actions, actions=default_actions,
is_default=True, is_default=True,
) )
self.tags = {}
@property @property
def physical_resource_id(self): def physical_resource_id(self):
@ -437,6 +441,7 @@ class FakeLoadBalancer(CloudFormationModel):
dns_name, dns_name,
state, state,
scheme="internet-facing", scheme="internet-facing",
loadbalancer_type=None,
): ):
self.name = name self.name = name
self.created_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) self.created_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now())
@ -449,14 +454,15 @@ class FakeLoadBalancer(CloudFormationModel):
self.arn = arn self.arn = arn
self.dns_name = dns_name self.dns_name = dns_name
self.state = state self.state = state
self.loadbalancer_type = loadbalancer_type or "application"
self.stack = "ipv4" self.stack = "ipv4"
self.attrs = { self.attrs = {
"access_logs.s3.enabled": "false", # "access_logs.s3.enabled": "false", # commented out for TF compatibility
"access_logs.s3.bucket": None, "access_logs.s3.bucket": None,
"access_logs.s3.prefix": None, "access_logs.s3.prefix": None,
"deletion_protection.enabled": "false", "deletion_protection.enabled": "false",
"idle_timeout.timeout_seconds": "60", # "idle_timeout.timeout_seconds": "60", # commented out for TF compatibility
} }
@property @property
@ -573,7 +579,12 @@ class ELBv2Backend(BaseBackend):
self.__init__(region_name) self.__init__(region_name)
def create_load_balancer( def create_load_balancer(
self, name, security_groups, subnet_ids, scheme="internet-facing" self,
name,
security_groups,
subnet_ids,
scheme="internet-facing",
loadbalancer_type=None,
): ):
vpc_id = None vpc_id = None
subnets = [] subnets = []
@ -605,6 +616,7 @@ class ELBv2Backend(BaseBackend):
vpc_id=vpc_id, vpc_id=vpc_id,
dns_name=dns_name, dns_name=dns_name,
state=state, state=state,
loadbalancer_type=loadbalancer_type,
) )
self.load_balancers[arn] = new_load_balancer self.load_balancers[arn] = new_load_balancer
return new_load_balancer return new_load_balancer
@ -634,7 +646,7 @@ class ELBv2Backend(BaseBackend):
) )
elif action_type == "forward" and "ForwardConfig" not in action: elif action_type == "forward" and "ForwardConfig" not in action:
default_actions.append( default_actions.append(
{"type": action_type, "target_group_arn": action["TargetGroupArn"],} {"type": action_type, "target_group_arn": action["TargetGroupArn"]}
) )
elif action_type in [ elif action_type in [
"redirect", "redirect",
@ -666,8 +678,12 @@ class ELBv2Backend(BaseBackend):
listener = listeners[0] listener = listeners[0]
# validate conditions # validate conditions
# see: https://docs.aws.amazon.com/cli/latest/reference/elbv2/create-rule.html
self._validate_conditions(conditions) self._validate_conditions(conditions)
# TODO: check QueryStringConfig condition
# TODO: check HttpRequestMethodConfig condition
# TODO: check SourceIpConfig condition
# TODO: check pattern of value for 'host-header' # TODO: check pattern of value for 'host-header'
# TODO: check pattern of value for 'path-pattern' # TODO: check pattern of value for 'path-pattern'
@ -778,6 +794,10 @@ class ELBv2Backend(BaseBackend):
) )
if values is None or len(values) == 0: if values is None or len(values) == 0:
raise InvalidConditionValueError("A condition value must be specified") raise InvalidConditionValueError("A condition value must be specified")
if condition.get("Values") and condition.get("PathPatternConfig"):
raise InvalidConditionValueError(
"You cannot provide both Values and 'PathPatternConfig' for a condition of type 'path-pattern'"
)
for value in values: for value in values:
if len(value) > 128: if len(value) > 128:
raise InvalidConditionValueError( raise InvalidConditionValueError(
@ -955,7 +975,7 @@ Member must satisfy regular expression pattern: {}".format(
action_type = action["Type"] action_type = action["Type"]
if action_type == "forward": if action_type == "forward":
default_actions.append( default_actions.append(
{"type": action_type, "target_group_arn": action["TargetGroupArn"],} {"type": action_type, "target_group_arn": action["TargetGroupArn"]}
) )
elif action_type in [ elif action_type in [
"redirect", "redirect",
@ -1340,6 +1360,7 @@ Member must satisfy regular expression pattern: {}".format(
healthy_threshold_count=None, healthy_threshold_count=None,
unhealthy_threshold_count=None, unhealthy_threshold_count=None,
http_codes=None, http_codes=None,
health_check_enabled=None,
): ):
target_group = self.target_groups.get(arn) target_group = self.target_groups.get(arn)
if target_group is None: if target_group is None:
@ -1366,6 +1387,8 @@ Member must satisfy regular expression pattern: {}".format(
target_group.healthcheck_protocol = health_check_proto target_group.healthcheck_protocol = health_check_proto
if health_check_timeout is not None: if health_check_timeout is not None:
target_group.healthcheck_timeout_seconds = health_check_timeout target_group.healthcheck_timeout_seconds = health_check_timeout
if health_check_enabled is not None:
target_group.healthcheck_enabled = health_check_enabled
if healthy_threshold_count is not None: if healthy_threshold_count is not None:
target_group.healthy_threshold_count = healthy_threshold_count target_group.healthy_threshold_count = healthy_threshold_count
if unhealthy_threshold_count is not None: if unhealthy_threshold_count is not None:

View File

@ -6,7 +6,8 @@ from .models import elbv2_backends
from .exceptions import DuplicateTagKeysError from .exceptions import DuplicateTagKeysError
from .exceptions import LoadBalancerNotFoundError from .exceptions import LoadBalancerNotFoundError
from .exceptions import TargetGroupNotFoundError from .exceptions import TargetGroupNotFoundError
from .exceptions import ListenerNotFoundError
from .exceptions import ListenerOrBalancerMissingError
SSL_POLICIES = [ SSL_POLICIES = [
{ {
@ -138,12 +139,14 @@ class ELBV2Response(BaseResponse):
subnet_ids = self._get_multi_param("Subnets.member") subnet_ids = self._get_multi_param("Subnets.member")
security_groups = self._get_multi_param("SecurityGroups.member") security_groups = self._get_multi_param("SecurityGroups.member")
scheme = self._get_param("Scheme") scheme = self._get_param("Scheme")
loadbalancer_type = self._get_param("Type")
load_balancer = self.elbv2_backend.create_load_balancer( load_balancer = self.elbv2_backend.create_load_balancer(
name=load_balancer_name, name=load_balancer_name,
security_groups=security_groups, security_groups=security_groups,
subnet_ids=subnet_ids, subnet_ids=subnet_ids,
scheme=scheme, scheme=scheme,
loadbalancer_type=loadbalancer_type,
) )
self._add_tags(load_balancer) self._add_tags(load_balancer)
template = self.response_template(CREATE_LOAD_BALANCER_TEMPLATE) template = self.response_template(CREATE_LOAD_BALANCER_TEMPLATE)
@ -173,9 +176,11 @@ class ELBV2Response(BaseResponse):
healthcheck_path = self._get_param("HealthCheckPath") healthcheck_path = self._get_param("HealthCheckPath")
healthcheck_interval_seconds = self._get_param("HealthCheckIntervalSeconds") healthcheck_interval_seconds = self._get_param("HealthCheckIntervalSeconds")
healthcheck_timeout_seconds = self._get_param("HealthCheckTimeoutSeconds") healthcheck_timeout_seconds = self._get_param("HealthCheckTimeoutSeconds")
healthcheck_enabled = self._get_param("HealthCheckEnabled")
healthy_threshold_count = self._get_param("HealthyThresholdCount") healthy_threshold_count = self._get_param("HealthyThresholdCount")
unhealthy_threshold_count = self._get_param("UnhealthyThresholdCount") unhealthy_threshold_count = self._get_param("UnhealthyThresholdCount")
matcher = self._get_param("Matcher") matcher = self._get_param("Matcher")
target_type = self._get_param("TargetType")
target_group = self.elbv2_backend.create_target_group( target_group = self.elbv2_backend.create_target_group(
name, name,
@ -187,9 +192,11 @@ class ELBV2Response(BaseResponse):
healthcheck_path=healthcheck_path, healthcheck_path=healthcheck_path,
healthcheck_interval_seconds=healthcheck_interval_seconds, healthcheck_interval_seconds=healthcheck_interval_seconds,
healthcheck_timeout_seconds=healthcheck_timeout_seconds, healthcheck_timeout_seconds=healthcheck_timeout_seconds,
healthcheck_enabled=healthcheck_enabled,
healthy_threshold_count=healthy_threshold_count, healthy_threshold_count=healthy_threshold_count,
unhealthy_threshold_count=unhealthy_threshold_count, unhealthy_threshold_count=unhealthy_threshold_count,
matcher=matcher, matcher=matcher,
target_type=target_type,
) )
template = self.response_template(CREATE_TARGET_GROUP_TEMPLATE) template = self.response_template(CREATE_TARGET_GROUP_TEMPLATE)
@ -299,7 +306,7 @@ class ELBV2Response(BaseResponse):
load_balancer_arn = self._get_param("LoadBalancerArn") load_balancer_arn = self._get_param("LoadBalancerArn")
listener_arns = self._get_multi_param("ListenerArns.member") listener_arns = self._get_multi_param("ListenerArns.member")
if not load_balancer_arn and not listener_arns: if not load_balancer_arn and not listener_arns:
raise LoadBalancerNotFoundError() raise ListenerOrBalancerMissingError()
listeners = self.elbv2_backend.describe_listeners( listeners = self.elbv2_backend.describe_listeners(
load_balancer_arn, listener_arns load_balancer_arn, listener_arns
@ -453,6 +460,14 @@ class ELBV2Response(BaseResponse):
resource = self.elbv2_backend.load_balancers.get(arn) resource = self.elbv2_backend.load_balancers.get(arn)
if not resource: if not resource:
raise LoadBalancerNotFoundError() raise LoadBalancerNotFoundError()
elif ":listener" in arn:
lb_arn, _, _ = arn.replace(":listener", ":loadbalancer").rpartition("/")
balancer = self.elbv2_backend.load_balancers.get(lb_arn)
if not balancer:
raise LoadBalancerNotFoundError()
resource = balancer.listeners.get(arn)
if not resource:
raise ListenerNotFoundError()
else: else:
raise LoadBalancerNotFoundError() raise LoadBalancerNotFoundError()
resources.append(resource) resources.append(resource)
@ -555,6 +570,7 @@ class ELBV2Response(BaseResponse):
health_check_path = self._get_param("HealthCheckPath") health_check_path = self._get_param("HealthCheckPath")
health_check_interval = self._get_param("HealthCheckIntervalSeconds") health_check_interval = self._get_param("HealthCheckIntervalSeconds")
health_check_timeout = self._get_param("HealthCheckTimeoutSeconds") health_check_timeout = self._get_param("HealthCheckTimeoutSeconds")
health_check_enabled = self._get_param("HealthCheckEnabled")
healthy_threshold_count = self._get_param("HealthyThresholdCount") healthy_threshold_count = self._get_param("HealthyThresholdCount")
unhealthy_threshold_count = self._get_param("UnhealthyThresholdCount") unhealthy_threshold_count = self._get_param("UnhealthyThresholdCount")
http_codes = self._get_param("Matcher.HttpCode") http_codes = self._get_param("Matcher.HttpCode")
@ -569,6 +585,7 @@ class ELBV2Response(BaseResponse):
healthy_threshold_count, healthy_threshold_count,
unhealthy_threshold_count, unhealthy_threshold_count,
http_codes, http_codes,
health_check_enabled=health_check_enabled,
) )
template = self.response_template(MODIFY_TARGET_GROUP_TEMPLATE) template = self.response_template(MODIFY_TARGET_GROUP_TEMPLATE)
@ -687,7 +704,7 @@ CREATE_LOAD_BALANCER_TEMPLATE = """<CreateLoadBalancerResponse xmlns="http://ela
<State> <State>
<Code>{{ load_balancer.state }}</Code> <Code>{{ load_balancer.state }}</Code>
</State> </State>
<Type>application</Type> <Type>{{ load_balancer.loadbalancer_type }}</Type>
</member> </member>
</LoadBalancers> </LoadBalancers>
</CreateLoadBalancerResult> </CreateLoadBalancerResult>
@ -817,10 +834,11 @@ CREATE_TARGET_GROUP_TEMPLATE = """<CreateTargetGroupResponse xmlns="http://elast
<Port>{{ target_group.port }}</Port> <Port>{{ target_group.port }}</Port>
<VpcId>{{ target_group.vpc_id }}</VpcId> <VpcId>{{ target_group.vpc_id }}</VpcId>
<HealthCheckProtocol>{{ target_group.health_check_protocol }}</HealthCheckProtocol> <HealthCheckProtocol>{{ target_group.health_check_protocol }}</HealthCheckProtocol>
<HealthCheckPort>{{ target_group.healthcheck_port }}</HealthCheckPort> <HealthCheckPort>{{ target_group.healthcheck_port or '' }}</HealthCheckPort>
<HealthCheckPath>{{ target_group.healthcheck_path }}</HealthCheckPath> <HealthCheckPath>{{ target_group.healthcheck_path or '' }}</HealthCheckPath>
<HealthCheckIntervalSeconds>{{ target_group.healthcheck_interval_seconds }}</HealthCheckIntervalSeconds> <HealthCheckIntervalSeconds>{{ target_group.healthcheck_interval_seconds }}</HealthCheckIntervalSeconds>
<HealthCheckTimeoutSeconds>{{ target_group.healthcheck_timeout_seconds }}</HealthCheckTimeoutSeconds> <HealthCheckTimeoutSeconds>{{ target_group.healthcheck_timeout_seconds }}</HealthCheckTimeoutSeconds>
<HealthCheckEnabled>{{ target_group.healthcheck_enabled and 'true' or 'false' }}</HealthCheckEnabled>
<HealthyThresholdCount>{{ target_group.healthy_threshold_count }}</HealthyThresholdCount> <HealthyThresholdCount>{{ target_group.healthy_threshold_count }}</HealthyThresholdCount>
<UnhealthyThresholdCount>{{ target_group.unhealthy_threshold_count }}</UnhealthyThresholdCount> <UnhealthyThresholdCount>{{ target_group.unhealthy_threshold_count }}</UnhealthyThresholdCount>
{% if target_group.matcher %} {% if target_group.matcher %}
@ -928,7 +946,7 @@ DESCRIBE_LOAD_BALANCERS_TEMPLATE = """<DescribeLoadBalancersResponse xmlns="http
<State> <State>
<Code>{{ load_balancer.state }}</Code> <Code>{{ load_balancer.state }}</Code>
</State> </State>
<Type>application</Type> <Type>{{ load_balancer.loadbalancer_type }}</Type>
<IpAddressType>ipv4</IpAddressType> <IpAddressType>ipv4</IpAddressType>
</member> </member>
{% endfor %} {% endfor %}
@ -1052,10 +1070,11 @@ DESCRIBE_TARGET_GROUPS_TEMPLATE = """<DescribeTargetGroupsResponse xmlns="http:/
<Port>{{ target_group.port }}</Port> <Port>{{ target_group.port }}</Port>
<VpcId>{{ target_group.vpc_id }}</VpcId> <VpcId>{{ target_group.vpc_id }}</VpcId>
<HealthCheckProtocol>{{ target_group.healthcheck_protocol }}</HealthCheckProtocol> <HealthCheckProtocol>{{ target_group.healthcheck_protocol }}</HealthCheckProtocol>
<HealthCheckPort>{{ target_group.healthcheck_port }}</HealthCheckPort> <HealthCheckPort>{{ target_group.healthcheck_port or '' }}</HealthCheckPort>
<HealthCheckPath>{{ target_group.healthcheck_path }}</HealthCheckPath> <HealthCheckPath>{{ target_group.healthcheck_path or '' }}</HealthCheckPath>
<HealthCheckIntervalSeconds>{{ target_group.healthcheck_interval_seconds }}</HealthCheckIntervalSeconds> <HealthCheckIntervalSeconds>{{ target_group.healthcheck_interval_seconds }}</HealthCheckIntervalSeconds>
<HealthCheckTimeoutSeconds>{{ target_group.healthcheck_timeout_seconds }}</HealthCheckTimeoutSeconds> <HealthCheckTimeoutSeconds>{{ target_group.healthcheck_timeout_seconds }}</HealthCheckTimeoutSeconds>
<HealthCheckEnabled>{{ target_group.healthcheck_enabled and 'true' or 'false' }}</HealthCheckEnabled>
<HealthyThresholdCount>{{ target_group.healthy_threshold_count }}</HealthyThresholdCount> <HealthyThresholdCount>{{ target_group.healthy_threshold_count }}</HealthyThresholdCount>
<UnhealthyThresholdCount>{{ target_group.unhealthy_threshold_count }}</UnhealthyThresholdCount> <UnhealthyThresholdCount>{{ target_group.unhealthy_threshold_count }}</UnhealthyThresholdCount>
{% if target_group.matcher %} {% if target_group.matcher %}

View File

@ -246,9 +246,10 @@ class Rule(CloudFormationModel):
class EventBus(CloudFormationModel): class EventBus(CloudFormationModel):
def __init__(self, region_name, name): def __init__(self, region_name, name, tags=None):
self.region = region_name self.region = region_name
self.name = name self.name = name
self.tags = tags or []
self._permissions = {} self._permissions = {}
@ -545,6 +546,7 @@ class Connection(BaseModel):
def __init__( def __init__(
self, name, region_name, description, authorization_type, auth_parameters, self, name, region_name, description, authorization_type, auth_parameters,
): ):
self.uuid = uuid4()
self.name = name self.name = name
self.region = region_name self.region = region_name
self.description = description self.description = description
@ -555,10 +557,62 @@ class Connection(BaseModel):
@property @property
def arn(self): def arn(self):
return "arn:aws:events:{0}:{1}:connection/{2}".format( return "arn:aws:events:{0}:{1}:connection/{2}/{3}".format(
self.region, ACCOUNT_ID, self.name self.region, ACCOUNT_ID, self.name, self.uuid
) )
def describe_short(self):
"""
Create the short description for the Connection object.
Taken our from the Response Syntax of this API doc:
- https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DeleteConnection.html
Something to consider:
- The original response also has
- LastAuthorizedTime (number)
- LastModifiedTime (number)
- At the time of implemeting this, there was no place where to set/get
those attributes. That is why they are not in the response.
Returns:
dict
"""
return {
"ConnectionArn": self.arn,
"ConnectionState": self.state,
"CreationTime": self.creation_time,
}
def describe(self):
"""
Create a complete description for the Connection object.
Taken our from the Response Syntax of this API doc:
- https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DescribeConnection.html
Something to consider:
- The original response also has:
- LastAuthorizedTime (number)
- LastModifiedTime (number)
- SecretArn (string)
- StateReason (string)
- At the time of implemeting this, there was no place where to set/get
those attributes. That is why they are not in the response.
Returns:
dict
"""
return {
"AuthorizationType": self.authorization_type,
"AuthParameters": self.auth_parameters,
"ConnectionArn": self.arn,
"ConnectionState": self.state,
"CreationTime": self.creation_time,
"Description": self.description,
"Name": self.name,
}
class Destination(BaseModel): class Destination(BaseModel):
def __init__( def __init__(
@ -568,32 +622,71 @@ class Destination(BaseModel):
description, description,
connection_arn, connection_arn,
invocation_endpoint, invocation_endpoint,
invocation_rate_limit_per_second,
http_method, http_method,
): ):
self.uuid = uuid4()
self.name = name self.name = name
self.region = region_name self.region = region_name
self.description = description self.description = description
self.connection_arn = connection_arn self.connection_arn = connection_arn
self.invocation_endpoint = invocation_endpoint self.invocation_endpoint = invocation_endpoint
self.invocation_rate_limit_per_second = invocation_rate_limit_per_second
self.creation_time = unix_time(datetime.utcnow()) self.creation_time = unix_time(datetime.utcnow())
self.http_method = http_method self.http_method = http_method
self.state = "ACTIVE" self.state = "ACTIVE"
@property @property
def arn(self): def arn(self):
return "arn:aws:events:{0}:{1}:destination/{2}".format( return "arn:aws:events:{0}:{1}:api-destination/{2}/{3}".format(
self.region, ACCOUNT_ID, self.name self.region, ACCOUNT_ID, self.name, self.uuid
) )
def describe(self):
"""
Describes the Destination object as a dict
Docs:
Response Syntax in
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DescribeApiDestination.html
Something to consider:
- The response also has [InvocationRateLimitPerSecond] which was not
available when implementing this method
Returns:
dict
"""
return {
"ApiDestinationArn": self.arn,
"ApiDestinationState": self.state,
"ConnectionArn": self.connection_arn,
"CreationTime": self.creation_time,
"Description": self.description,
"HttpMethod": self.http_method,
"InvocationEndpoint": self.invocation_endpoint,
"InvocationRateLimitPerSecond": self.invocation_rate_limit_per_second,
"LastModifiedTime": self.creation_time,
"Name": self.name,
}
def describe_short(self):
return {
"ApiDestinationArn": self.arn,
"ApiDestinationState": self.state,
"CreationTime": self.creation_time,
"LastModifiedTime": self.creation_time,
}
class EventPattern: class EventPattern:
def __init__(self, filter): def __init__(self, filter):
self._filter = self._load_event_pattern(filter) self._filter = self._load_event_pattern(filter)
self._filter_raw = filter
if not self._validate_event_pattern(self._filter): if not self._validate_event_pattern(self._filter):
raise InvalidEventPatternException raise InvalidEventPatternException
def __str__(self): def __str__(self):
return json.dumps(self._filter) return self._filter_raw or str()
def _load_event_pattern(self, pattern): def _load_event_pattern(self, pattern):
try: try:
@ -1032,7 +1125,7 @@ class EventsBackend(BaseBackend):
return event_bus return event_bus
def create_event_bus(self, name, event_source_name=None): def create_event_bus(self, name, event_source_name=None, tags=None):
if name in self.event_buses: if name in self.event_buses:
raise JsonRESTError( raise JsonRESTError(
"ResourceAlreadyExistsException", "ResourceAlreadyExistsException",
@ -1050,7 +1143,10 @@ class EventsBackend(BaseBackend):
"Event source {} does not exist.".format(event_source_name), "Event source {} does not exist.".format(event_source_name),
) )
self.event_buses[name] = EventBus(self.region_name, name) event_bus = EventBus(self.region_name, name, tags=tags)
self.event_buses[name] = event_bus
if tags:
self.tagger.tag_resource(event_bus.arn, tags)
return self.event_buses[name] return self.event_buses[name]
@ -1069,20 +1165,26 @@ class EventsBackend(BaseBackend):
raise JsonRESTError( raise JsonRESTError(
"ValidationException", "Cannot delete event bus default." "ValidationException", "Cannot delete event bus default."
) )
self.event_buses.pop(name, None) event_bus = self.event_buses.pop(name, None)
if event_bus:
self.tagger.delete_all_tags_for_resource(event_bus.arn)
def list_tags_for_resource(self, arn): def list_tags_for_resource(self, arn):
name = arn.split("/")[-1] name = arn.split("/")[-1]
if name in self.rules: registries = [self.rules, self.event_buses]
return self.tagger.list_tags_for_resource(self.rules[name].arn) for registry in registries:
if name in registry:
return self.tagger.list_tags_for_resource(registry[name].arn)
raise ResourceNotFoundException( raise ResourceNotFoundException(
"Rule {0} does not exist on EventBus default.".format(name) "Rule {0} does not exist on EventBus default.".format(name)
) )
def tag_resource(self, arn, tags): def tag_resource(self, arn, tags):
name = arn.split("/")[-1] name = arn.split("/")[-1]
if name in self.rules: registries = [self.rules, self.event_buses]
self.tagger.tag_resource(self.rules[name].arn, tags) for registry in registries:
if name in registry:
self.tagger.tag_resource(registry[name].arn, tags)
return {} return {}
raise ResourceNotFoundException( raise ResourceNotFoundException(
"Rule {0} does not exist on EventBus default.".format(name) "Rule {0} does not exist on EventBus default.".format(name)
@ -1090,8 +1192,10 @@ class EventsBackend(BaseBackend):
def untag_resource(self, arn, tag_names): def untag_resource(self, arn, tag_names):
name = arn.split("/")[-1] name = arn.split("/")[-1]
if name in self.rules: registries = [self.rules, self.event_buses]
self.tagger.untag_resource_using_names(self.rules[name].arn, tag_names) for registry in registries:
if name in registry:
self.tagger.untag_resource_using_names(registry[name].arn, tag_names)
return {} return {}
raise ResourceNotFoundException( raise ResourceNotFoundException(
"Rule {0} does not exist on EventBus default.".format(name) "Rule {0} does not exist on EventBus default.".format(name)
@ -1337,27 +1441,145 @@ class EventsBackend(BaseBackend):
def list_connections(self): def list_connections(self):
return self.connections.values() return self.connections.values()
def create_api_destination( def describe_connection(self, name):
self, name, description, connection_arn, invocation_endpoint, http_method """
): Retrieves details about a connection.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DescribeConnection.html
Args:
name: The name of the connection to retrieve.
Raises:
ResourceNotFoundException: When the connection is not present.
Returns:
dict
"""
connection = self.connections.get(name)
if not connection:
raise ResourceNotFoundException(
"Connection '{}' does not exist.".format(name)
)
return connection.describe()
def delete_connection(self, name):
"""
Deletes a connection.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DeleteConnection.html
Args:
name: The name of the connection to delete.
Raises:
ResourceNotFoundException: When the connection is not present.
Returns:
dict
"""
connection = self.connections.pop(name, None)
if not connection:
raise ResourceNotFoundException(
"Connection '{}' does not exist.".format(name)
)
return connection.describe_short()
def create_api_destination(
self,
name,
description,
connection_arn,
invocation_endpoint,
invocation_rate_limit_per_second,
http_method,
):
"""
Creates an API destination, which is an HTTP invocation endpoint configured as a target for events.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_CreateApiDestination.html
Returns:
dict
"""
destination = Destination( destination = Destination(
name=name, name=name,
region_name=self.region_name, region_name=self.region_name,
description=description, description=description,
connection_arn=connection_arn, connection_arn=connection_arn,
invocation_endpoint=invocation_endpoint, invocation_endpoint=invocation_endpoint,
invocation_rate_limit_per_second=invocation_rate_limit_per_second,
http_method=http_method, http_method=http_method,
) )
self.destinations[name] = destination self.destinations[name] = destination
return destination return destination.describe_short()
def list_api_destinations(self): def list_api_destinations(self):
return self.destinations.values() return self.destinations.values()
def describe_api_destination(self, name): def describe_api_destination(self, name):
return self.destinations.get(name) """
Retrieves details about an API destination.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DescribeApiDestination.html
Args:
name: The name of the API destination to retrieve.
Returns:
dict
"""
destination = self.destinations.get(name)
if not destination:
raise ResourceNotFoundException(
"An api-destination '{}' does not exist.".format(name)
)
return destination.describe()
def update_api_destination(self, *, name, **kwargs):
"""
Creates an API destination, which is an HTTP invocation endpoint configured as a target for events.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_UpdateApiDestination.html
Returns:
dict
"""
destination = self.destinations.get(name)
if not destination:
raise ResourceNotFoundException(
"An api-destination '{}' does not exist.".format(name)
)
for attr, value in kwargs.items():
if value is not None and hasattr(destination, attr):
setattr(destination, attr, value)
return destination.describe_short()
def delete_api_destination(self, name):
"""
Deletes the specified API destination.
Docs:
https://docs.aws.amazon.com/eventbridge/latest/APIReference/API_DeleteApiDestination.html
Args:
name: The name of the destination to delete.
Raises:
ResourceNotFoundException: When the destination is not present.
Returns:
dict
"""
destination = self.destinations.pop(name, None)
if not destination:
raise ResourceNotFoundException(
"An api-destination '{}' does not exist.".format(name)
)
return {}
events_backends = {} events_backends = {}

View File

@ -42,6 +42,20 @@ class EventsHandler(BaseResponse):
def _get_param(self, param, if_none=None): def _get_param(self, param, if_none=None):
return self.request_params.get(param, if_none) return self.request_params.get(param, if_none)
def _create_response(self, result):
"""
Creates a proper response for the API.
It basically transforms a dict-like result from the backend
into a tuple (str, dict) properly formatted.
Args:
result (dict): result from backend
Returns:
(str, dict): dumped result and headers
"""
return json.dumps(result), self.headers
def error(self, type_, message="", status=400): def error(self, type_, message="", status=400):
headers = self.response_headers headers = self.response_headers
headers["status"] = status headers["status"] = status
@ -266,9 +280,9 @@ class EventsHandler(BaseResponse):
def create_event_bus(self): def create_event_bus(self):
name = self._get_param("Name") name = self._get_param("Name")
event_source_name = self._get_param("EventSourceName") event_source_name = self._get_param("EventSourceName")
tags = self._get_param("Tags")
event_bus = self.events_backend.create_event_bus(name, event_source_name) event_bus = self.events_backend.create_event_bus(name, event_source_name, tags)
return json.dumps({"EventBusArn": event_bus.arn}), self.response_headers return json.dumps({"EventBusArn": event_bus.arn}), self.response_headers
def list_event_buses(self): def list_event_buses(self):
@ -448,27 +462,35 @@ class EventsHandler(BaseResponse):
return json.dumps({"Connections": result}), self.response_headers return json.dumps({"Connections": result}), self.response_headers
def describe_connection(self):
name = self._get_param("Name")
result = self.events_backend.describe_connection(name)
return json.dumps(result), self.response_headers
def delete_connection(self):
name = self._get_param("Name")
result = self.events_backend.delete_connection(name)
return json.dumps(result), self.response_headers
def create_api_destination(self): def create_api_destination(self):
name = self._get_param("Name") name = self._get_param("Name")
description = self._get_param("Description") description = self._get_param("Description")
connection_arn = self._get_param("ConnectionArn") connection_arn = self._get_param("ConnectionArn")
invocation_endpoint = self._get_param("InvocationEndpoint") invocation_endpoint = self._get_param("InvocationEndpoint")
invocation_rate_limit_per_second = self._get_param(
"InvocationRateLimitPerSecond"
)
http_method = self._get_param("HttpMethod") http_method = self._get_param("HttpMethod")
destination = self.events_backend.create_api_destination( result = self.events_backend.create_api_destination(
name, description, connection_arn, invocation_endpoint, http_method name,
) description,
return ( connection_arn,
json.dumps( invocation_endpoint,
{ invocation_rate_limit_per_second,
"ApiDestinationArn": destination.arn, http_method,
"ApiDestinationState": "ACTIVE",
"CreationTime": destination.creation_time,
"LastModifiedTime": destination.creation_time,
}
),
self.response_headers,
) )
return self._create_response(result)
def list_api_destinations(self): def list_api_destinations(self):
destinations = self.events_backend.list_api_destinations() destinations = self.events_backend.list_api_destinations()
@ -491,20 +513,25 @@ class EventsHandler(BaseResponse):
def describe_api_destination(self): def describe_api_destination(self):
name = self._get_param("Name") name = self._get_param("Name")
destination = self.events_backend.describe_api_destination(name) result = self.events_backend.describe_api_destination(name)
return self._create_response(result)
return ( def update_api_destination(self):
json.dumps( updates = dict(
{ connection_arn=self._get_param("ConnectionArn"),
"ApiDestinationArn": destination.arn, description=self._get_param("Description"),
"Name": destination.name, http_method=self._get_param("HttpMethod"),
"ApiDestinationState": destination.state, invocation_endpoint=self._get_param("InvocationEndpoint"),
"ConnectionArn": destination.connection_arn, invocation_rate_limit_per_second=self._get_param(
"InvocationEndpoint": destination.invocation_endpoint, "InvocationRateLimitPerSecond"
"HttpMethod": destination.http_method,
"CreationTime": destination.creation_time,
"LastModifiedTime": destination.creation_time,
}
), ),
self.response_headers, name=self._get_param("Name"),
) )
result = self.events_backend.update_api_destination(**updates)
return self._create_response(result)
def delete_api_destination(self):
name = self._get_param("Name")
result = self.events_backend.delete_api_destination(name)
return self._create_response(result)

View File

@ -1,12 +1,16 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.exceptions import RESTError from moto.core.exceptions import RESTError
XMLNS_IAM = "https://iam.amazonaws.com/doc/2010-05-08/"
class IAMNotFoundException(RESTError): class IAMNotFoundException(RESTError):
code = 404 code = 404
def __init__(self, message): def __init__(self, message):
super(IAMNotFoundException, self).__init__("NoSuchEntity", message) super(IAMNotFoundException, self).__init__(
"NoSuchEntity", message, xmlns=XMLNS_IAM, template="wrapped_single_error"
)
class IAMConflictException(RESTError): class IAMConflictException(RESTError):
@ -134,4 +138,6 @@ class NoSuchEntity(RESTError):
code = 404 code = 404
def __init__(self, message): def __init__(self, message):
super(NoSuchEntity, self).__init__("NoSuchEntity", message) super(NoSuchEntity, self).__init__(
"NoSuchEntity", message, xmlns=XMLNS_IAM, template="wrapped_single_error"
)

32
moto/iam/models.py Executable file → Normal file
View File

@ -101,6 +101,7 @@ class Policy(CloudFormationModel):
path=None, path=None,
create_date=None, create_date=None,
update_date=None, update_date=None,
tags=None,
): ):
self.name = name self.name = name
@ -108,6 +109,7 @@ class Policy(CloudFormationModel):
self.description = description or "" self.description = description or ""
self.id = random_policy_id() self.id = random_policy_id()
self.path = path or "/" self.path = path or "/"
self.tags = {tag["Key"]: tag["Value"] for tag in tags or []}
if default_version_id: if default_version_id:
self.default_version_id = default_version_id self.default_version_id = default_version_id
@ -337,9 +339,10 @@ class AWSManagedPolicy(ManagedPolicy):
# AWS defines some of its own managed policies and we periodically # AWS defines some of its own managed policies and we periodically
# import them via `make aws_managed_policies` # import them via `make aws_managed_policies`
# FIXME: Takes about 40ms at import time # FIXME: Takes about 40ms at import time
aws_managed_policies_data_parsed = json.loads(aws_managed_policies_data)
aws_managed_policies = [ aws_managed_policies = [
AWSManagedPolicy.from_data(name, d) AWSManagedPolicy.from_data(name, d)
for name, d in json.loads(aws_managed_policies_data).items() for name, d in aws_managed_policies_data_parsed.items()
] ]
@ -646,14 +649,21 @@ class Role(CloudFormationModel):
def get_tags(self): def get_tags(self):
return [self.tags[tag] for tag in self.tags] return [self.tags[tag] for tag in self.tags]
@property
def description_escaped(self):
import html
return html.escape(self.description or "")
class InstanceProfile(CloudFormationModel): class InstanceProfile(CloudFormationModel):
def __init__(self, instance_profile_id, name, path, roles): def __init__(self, instance_profile_id, name, path, roles, tags=None):
self.id = instance_profile_id self.id = instance_profile_id
self.name = name self.name = name
self.path = path or "/" self.path = path or "/"
self.roles = roles if roles else [] self.roles = roles if roles else []
self.create_date = datetime.utcnow() self.create_date = datetime.utcnow()
self.tags = {tag["Key"]: tag["Value"] for tag in tags or []}
@property @property
def created_iso_8601(self): def created_iso_8601(self):
@ -1410,7 +1420,7 @@ class IAMBackend(BaseBackend):
self.account_aliases = [] self.account_aliases = []
self.saml_providers = {} self.saml_providers = {}
self.open_id_providers = {} self.open_id_providers = {}
self.policy_arn_regex = re.compile(r"^arn:aws:iam::[0-9]*:policy/.*$") self.policy_arn_regex = re.compile(r"^arn:aws:iam::(aws|[0-9]*):policy/.*$")
self.virtual_mfa_devices = {} self.virtual_mfa_devices = {}
self.account_password_policy = None self.account_password_policy = None
self.account_summary = AccountSummary(self) self.account_summary = AccountSummary(self)
@ -1496,12 +1506,16 @@ class IAMBackend(BaseBackend):
raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn)) raise IAMNotFoundException("Policy {0} was not found.".format(policy_arn))
policy.detach_from(self.get_user(user_name)) policy.detach_from(self.get_user(user_name))
def create_policy(self, description, path, policy_document, policy_name): def create_policy(self, description, path, policy_document, policy_name, tags=None):
iam_policy_document_validator = IAMPolicyDocumentValidator(policy_document) iam_policy_document_validator = IAMPolicyDocumentValidator(policy_document)
iam_policy_document_validator.validate() iam_policy_document_validator.validate()
policy = ManagedPolicy( policy = ManagedPolicy(
policy_name, description=description, document=policy_document, path=path policy_name,
description=description,
document=policy_document,
path=path,
tags=tags,
) )
if policy.arn in self.managed_policies: if policy.arn in self.managed_policies:
raise EntityAlreadyExists( raise EntityAlreadyExists(
@ -1551,9 +1565,9 @@ class IAMBackend(BaseBackend):
def set_default_policy_version(self, policy_arn, version_id): def set_default_policy_version(self, policy_arn, version_id):
import re import re
if re.match("v[1-9][0-9]*(\.[A-Za-z0-9-]*)?", version_id) is None: if re.match(r"v[1-9][0-9]*(\.[A-Za-z0-9-]*)?", version_id) is None:
raise ValidationError( raise ValidationError(
"Value '{0}' at 'versionId' failed to satisfy constraint: Member must satisfy regular expression pattern: v[1-9][0-9]*(\.[A-Za-z0-9-]*)?".format( "Value '{0}' at 'versionId' failed to satisfy constraint: Member must satisfy regular expression pattern: v[1-9][0-9]*(\\.[A-Za-z0-9-]*)?".format(
version_id version_id
) )
) )
@ -1823,7 +1837,7 @@ class IAMBackend(BaseBackend):
return return
raise IAMNotFoundException("Policy not found") raise IAMNotFoundException("Policy not found")
def create_instance_profile(self, name, path, role_ids): def create_instance_profile(self, name, path, role_ids, tags=None):
if self.instance_profiles.get(name): if self.instance_profiles.get(name):
raise IAMConflictException( raise IAMConflictException(
code="EntityAlreadyExists", code="EntityAlreadyExists",
@ -1833,7 +1847,7 @@ class IAMBackend(BaseBackend):
instance_profile_id = random_resource_id() instance_profile_id = random_resource_id()
roles = [iam_backend.get_role_by_id(role_id) for role_id in role_ids] roles = [iam_backend.get_role_by_id(role_id) for role_id in role_ids]
instance_profile = InstanceProfile(instance_profile_id, name, path, roles) instance_profile = InstanceProfile(instance_profile_id, name, path, roles, tags)
self.instance_profiles[name] = instance_profile self.instance_profiles[name] = instance_profile
return instance_profile return instance_profile

View File

@ -347,7 +347,7 @@ class IAMPolicyDocumentValidator:
return return
resource_partitions = resource_partitions[2].partition(":") resource_partitions = resource_partitions[2].partition(":")
if resource_partitions[0] != "aws": if resource_partitions[0] not in ["aws", "*"]:
remaining_resource_parts = resource_partitions[2].split(":") remaining_resource_parts = resource_partitions[2].split(":")
arn1 = ( arn1 = (

View File

@ -53,8 +53,9 @@ class IamResponse(BaseResponse):
path = self._get_param("Path") path = self._get_param("Path")
policy_document = self._get_param("PolicyDocument") policy_document = self._get_param("PolicyDocument")
policy_name = self._get_param("PolicyName") policy_name = self._get_param("PolicyName")
tags = self._get_multi_param("Tags.member")
policy = iam_backend.create_policy( policy = iam_backend.create_policy(
description, path, policy_document, policy_name description, path, policy_document, policy_name, tags
) )
template = self.response_template(CREATE_POLICY_TEMPLATE) template = self.response_template(CREATE_POLICY_TEMPLATE)
return template.render(policy=policy) return template.render(policy=policy)
@ -320,8 +321,11 @@ class IamResponse(BaseResponse):
def create_instance_profile(self): def create_instance_profile(self):
profile_name = self._get_param("InstanceProfileName") profile_name = self._get_param("InstanceProfileName")
path = self._get_param("Path", "/") path = self._get_param("Path", "/")
tags = self._get_multi_param("Tags.member")
profile = iam_backend.create_instance_profile(profile_name, path, role_ids=[]) profile = iam_backend.create_instance_profile(
profile_name, path, role_ids=[], tags=tags
)
template = self.response_template(CREATE_INSTANCE_PROFILE_TEMPLATE) template = self.response_template(CREATE_INSTANCE_PROFILE_TEMPLATE)
return template.render(profile=profile) return template.render(profile=profile)
@ -461,6 +465,12 @@ class IamResponse(BaseResponse):
template = self.response_template(GET_GROUP_POLICY_TEMPLATE) template = self.response_template(GET_GROUP_POLICY_TEMPLATE)
return template.render(name="GetGroupPolicyResponse", **policy_result) return template.render(name="GetGroupPolicyResponse", **policy_result)
def delete_group_policy(self):
group_name = self._get_param("GroupName")
policy_name = self._get_param("PolicyName")
iam_backend.delete_group_policy(group_name, policy_name)
return ""
def delete_group(self): def delete_group(self):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
iam_backend.delete_group(group_name) iam_backend.delete_group(group_name)
@ -1093,6 +1103,14 @@ CREATE_POLICY_TEMPLATE = """<CreatePolicyResponse>
<PolicyId>{{ policy.id }}</PolicyId> <PolicyId>{{ policy.id }}</PolicyId>
<PolicyName>{{ policy.name }}</PolicyName> <PolicyName>{{ policy.name }}</PolicyName>
<UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate> <UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
<Tags>
{% for tag_key, tag_value in policy.tags.items() %}
<member>
<Key>{{ tag_key }}</Key>
<Value>{{ tag_value }}</Value>
</member>
{% endfor %}
</Tags>
</Policy> </Policy>
</CreatePolicyResult> </CreatePolicyResult>
<ResponseMetadata> <ResponseMetadata>
@ -1112,6 +1130,14 @@ GET_POLICY_TEMPLATE = """<GetPolicyResponse>
<AttachmentCount>{{ policy.attachment_count }}</AttachmentCount> <AttachmentCount>{{ policy.attachment_count }}</AttachmentCount>
<CreateDate>{{ policy.created_iso_8601 }}</CreateDate> <CreateDate>{{ policy.created_iso_8601 }}</CreateDate>
<UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate> <UpdateDate>{{ policy.updated_iso_8601 }}</UpdateDate>
<Tags>
{% for tag_key, tag_value in policy.tags.items() %}
<member>
<Key>{{ tag_key }}</Key>
<Value>{{ tag_value }}</Value>
</member>
{% endfor %}
</Tags>
</Policy> </Policy>
</GetPolicyResult> </GetPolicyResult>
<ResponseMetadata> <ResponseMetadata>
@ -1223,11 +1249,30 @@ CREATE_INSTANCE_PROFILE_TEMPLATE = """<CreateInstanceProfileResponse xmlns="http
<CreateInstanceProfileResult> <CreateInstanceProfileResult>
<InstanceProfile> <InstanceProfile>
<InstanceProfileId>{{ profile.id }}</InstanceProfileId> <InstanceProfileId>{{ profile.id }}</InstanceProfileId>
<Roles/> <Roles>
{% for role in profile.roles %}
<member>
<Path>{{ role.path }}</Path>
<Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId>
</member>
{% endfor %}
</Roles>
<InstanceProfileName>{{ profile.name }}</InstanceProfileName> <InstanceProfileName>{{ profile.name }}</InstanceProfileName>
<Path>{{ profile.path }}</Path> <Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn> <Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.created_iso_8601 }}</CreateDate> <CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
<Tags>
{% for tag_key, tag_value in profile.tags.items() %}
<member>
<Key>{{ tag_key }}</Key>
<Value>{{ tag_value }}</Value>
</member>
{% endfor %}
</Tags>
</InstanceProfile> </InstanceProfile>
</CreateInstanceProfileResult> </CreateInstanceProfileResult>
<ResponseMetadata> <ResponseMetadata>
@ -1261,6 +1306,14 @@ GET_INSTANCE_PROFILE_TEMPLATE = """<GetInstanceProfileResponse xmlns="https://ia
<Path>{{ profile.path }}</Path> <Path>{{ profile.path }}</Path>
<Arn>{{ profile.arn }}</Arn> <Arn>{{ profile.arn }}</Arn>
<CreateDate>{{ profile.created_iso_8601 }}</CreateDate> <CreateDate>{{ profile.created_iso_8601 }}</CreateDate>
<Tags>
{% for tag_key, tag_value in profile.tags.items() %}
<member>
<Key>{{ tag_key }}</Key>
<Value>{{ tag_value }}</Value>
</member>
{% endfor %}
</Tags>
</InstanceProfile> </InstanceProfile>
</GetInstanceProfileResult> </GetInstanceProfileResult>
<ResponseMetadata> <ResponseMetadata>
@ -1276,7 +1329,7 @@ CREATE_ROLE_TEMPLATE = """<CreateRoleResponse xmlns="https://iam.amazonaws.com/d
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
{% if role.description is not none %} {% if role.description is not none %}
<Description>{{role.description}}</Description> <Description>{{ role.description_escaped }}</Description>
{% endif %} {% endif %}
<CreateDate>{{ role.created_iso_8601 }}</CreateDate> <CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
@ -1330,7 +1383,9 @@ UPDATE_ROLE_DESCRIPTION_TEMPLATE = """<UpdateRoleDescriptionResponse xmlns="http
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<Description>{{role.description}}</Description> {% if role.description is not none %}
<Description>{{ role.description_escaped }}</Description>
{% endif %}
<CreateDate>{{ role.created_iso_8601 }}</CreateDate> <CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
<MaxSessionDuration>{{ role.max_session_duration }}</MaxSessionDuration> <MaxSessionDuration>{{ role.max_session_duration }}</MaxSessionDuration>
@ -1358,8 +1413,8 @@ GET_ROLE_TEMPLATE = """<GetRoleResponse xmlns="https://iam.amazonaws.com/doc/201
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
{% if role.description %} {% if role.description is not none %}
<Description>{{role.description}}</Description> <Description>{{ role.description_escaped }}</Description>
{% endif %} {% endif %}
<CreateDate>{{ role.created_iso_8601 }}</CreateDate> <CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
@ -1422,7 +1477,7 @@ LIST_ROLES_TEMPLATE = """<ListRolesResponse xmlns="https://iam.amazonaws.com/doc
</PermissionsBoundary> </PermissionsBoundary>
{% endif %} {% endif %}
{% if role.description is not none %} {% if role.description is not none %}
<Description>{{ role.description }}</Description> <Description>{{ role.description_escaped }}</Description>
{% endif %} {% endif %}
</member> </member>
{% endfor %} {% endfor %}
@ -2243,7 +2298,9 @@ GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE = """<GetAccountAuthorizationDetailsR
<Arn>{{ role.arn }}</Arn> <Arn>{{ role.arn }}</Arn>
<RoleName>{{ role.name }}</RoleName> <RoleName>{{ role.name }}</RoleName>
<AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument> <AssumeRolePolicyDocument>{{ role.assume_role_policy_document }}</AssumeRolePolicyDocument>
<Description>{{role.description}}</Description> {% if role.description is not none %}
<Description>{{ role.description_escaped }}</Description>
{% endif %}
<CreateDate>{{ role.created_iso_8601 }}</CreateDate> <CreateDate>{{ role.created_iso_8601 }}</CreateDate>
<RoleId>{{ role.id }}</RoleId> <RoleId>{{ role.id }}</RoleId>
{% if role.permissions_boundary %} {% if role.permissions_boundary %}

View File

@ -1,5 +1,6 @@
from boto3 import Session from boto3 import Session
from moto import core as moto_core
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.core.utils import unix_time_millis from moto.core.utils import unix_time_millis
from .exceptions import ( from .exceptions import (
@ -53,7 +54,7 @@ class LogStream(BaseModel):
self.region = region self.region = region
self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format( self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format(
region=region, region=region,
id=self.__class__._log_ids, id=moto_core.ACCOUNT_ID,
log_group=log_group, log_group=log_group,
log_stream=name, log_stream=name,
) )
@ -262,6 +263,11 @@ class LogGroup(BaseModel):
) # AWS defaults to Never Expire for log group retention ) # AWS defaults to Never Expire for log group retention
self.subscription_filters = [] self.subscription_filters = []
# The Amazon Resource Name (ARN) of the CMK to use when encrypting log data. It is optional.
# Docs:
# https://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/API_CreateLogGroup.html
self.kms_key_id = kwargs.get("kmsKeyId")
def create_log_stream(self, log_stream_name): def create_log_stream(self, log_stream_name):
if log_stream_name in self.streams: if log_stream_name in self.streams:
raise ResourceAlreadyExistsException() raise ResourceAlreadyExistsException()
@ -442,6 +448,8 @@ class LogGroup(BaseModel):
# AWS only returns retentionInDays if a value is set for the log group (ie. not Never Expire) # AWS only returns retentionInDays if a value is set for the log group (ie. not Never Expire)
if self.retention_in_days: if self.retention_in_days:
log_group["retentionInDays"] = self.retention_in_days log_group["retentionInDays"] = self.retention_in_days
if self.kms_key_id:
log_group["kmsKeyId"] = self.kms_key_id
return log_group return log_group
def set_retention_policy(self, retention_in_days): def set_retention_policy(self, retention_in_days):
@ -510,6 +518,7 @@ class LogsBackend(BaseBackend):
self.region_name = region_name self.region_name = region_name
self.groups = dict() # { logGroupName: LogGroup} self.groups = dict() # { logGroupName: LogGroup}
self.queries = dict() self.queries = dict()
self.resource_policies = dict()
def reset(self): def reset(self):
region_name = self.region_name region_name = self.region_name
@ -677,6 +686,10 @@ class LogsBackend(BaseBackend):
log_group = self.groups[log_group_name] log_group = self.groups[log_group_name]
return log_group.set_retention_policy(None) return log_group.set_retention_policy(None)
def put_resource_policy(self, policy_name, policy_doc):
policy = {"policyName": policy_name, "policyDocument": policy_doc}
self.resource_policies[policy_name] = policy
def list_tags_log_group(self, log_group_name): def list_tags_log_group(self, log_group_name):
if log_group_name not in self.groups: if log_group_name not in self.groups:
raise ResourceNotFoundException() raise ResourceNotFoundException()

View File

@ -25,9 +25,10 @@ class LogsResponse(BaseResponse):
def create_log_group(self): def create_log_group(self):
log_group_name = self._get_param("logGroupName") log_group_name = self._get_param("logGroupName")
tags = self._get_param("tags") tags = self._get_param("tags")
kms_key_id = self._get_param("kmsKeyId")
assert 1 <= len(log_group_name) <= 512 # TODO: assert pattern assert 1 <= len(log_group_name) <= 512 # TODO: assert pattern
self.logs_backend.create_log_group(log_group_name, tags) self.logs_backend.create_log_group(log_group_name, tags, kmsKeyId=kms_key_id)
return "" return ""
def delete_log_group(self): def delete_log_group(self):
@ -166,6 +167,12 @@ class LogsResponse(BaseResponse):
self.logs_backend.delete_retention_policy(log_group_name) self.logs_backend.delete_retention_policy(log_group_name)
return "" return ""
def put_resource_policy(self):
policy_name = self._get_param("policyName")
policy_doc = self._get_param("policyDocument")
self.logs_backend.put_resource_policy(policy_name, policy_doc)
return ""
def list_tags_log_group(self): def list_tags_log_group(self):
log_group_name = self._get_param("logGroupName") log_group_name = self._get_param("logGroupName")
tags = self.logs_backend.list_tags_log_group(log_group_name) tags = self.logs_backend.list_tags_log_group(log_group_name)

View File

@ -32,6 +32,7 @@ class FakeOrganization(BaseModel):
self.master_account_id = utils.MASTER_ACCOUNT_ID self.master_account_id = utils.MASTER_ACCOUNT_ID
self.master_account_email = utils.MASTER_ACCOUNT_EMAIL self.master_account_email = utils.MASTER_ACCOUNT_EMAIL
self.available_policy_types = [ self.available_policy_types = [
# TODO: verify if this should be enabled by default (breaks TF tests for CloudTrail)
{"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"} {"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"}
] ]
@ -141,7 +142,10 @@ class FakeRoot(FakeOrganizationalUnit):
self.type = "ROOT" self.type = "ROOT"
self.id = organization.root_id self.id = organization.root_id
self.name = "Root" self.name = "Root"
self.policy_types = [{"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"}] self.policy_types = [
# TODO: verify if this should be enabled by default (breaks TF tests for CloudTrail)
{"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"}
]
self._arn_format = utils.ROOT_ARN_FORMAT self._arn_format = utils.ROOT_ARN_FORMAT
self.attached_policies = [] self.attached_policies = []
self.tags = {tag["Key"]: tag["Value"] for tag in kwargs.get("Tags", [])} self.tags = {tag["Key"]: tag["Value"] for tag in kwargs.get("Tags", [])}
@ -328,6 +332,9 @@ class FakeDelegatedAdministrator(BaseModel):
class OrganizationsBackend(BaseBackend): class OrganizationsBackend(BaseBackend):
def __init__(self): def __init__(self):
self._reset()
def _reset(self):
self.org = None self.org = None
self.accounts = [] self.accounts = []
self.ou = [] self.ou = []
@ -375,6 +382,10 @@ class OrganizationsBackend(BaseBackend):
raise AWSOrganizationsNotInUseException raise AWSOrganizationsNotInUseException
return self.org.describe() return self.org.describe()
def delete_organization(self, **kwargs):
self._reset()
return {}
def list_roots(self): def list_roots(self):
return dict(Roots=[ou.describe() for ou in self.ou if isinstance(ou, FakeRoot)]) return dict(Roots=[ou.describe() for ou in self.ou if isinstance(ou, FakeRoot)])

View File

@ -28,6 +28,11 @@ class OrganizationsResponse(BaseResponse):
def describe_organization(self): def describe_organization(self):
return json.dumps(self.organizations_backend.describe_organization()) return json.dumps(self.organizations_backend.describe_organization())
def delete_organization(self):
return json.dumps(
self.organizations_backend.delete_organization(**self.request_params)
)
def list_roots(self): def list_roots(self):
return json.dumps(self.organizations_backend.list_roots()) return json.dumps(self.organizations_backend.list_roots())

View File

@ -43,14 +43,7 @@ class Database(CloudFormationModel):
"engine": FilterDef(["engine"], "Engine Names"), "engine": FilterDef(["engine"], "Engine Names"),
} }
def __init__(self, **kwargs): default_engine_versions = {
self.status = "available"
self.is_replica = False
self.replicas = []
self.region = kwargs.get("region")
self.engine = kwargs.get("engine")
self.engine_version = kwargs.get("engine_version", None)
self.default_engine_versions = {
"MySQL": "5.6.21", "MySQL": "5.6.21",
"mysql": "5.6.21", "mysql": "5.6.21",
"oracle-se1": "11.2.0.4.v3", "oracle-se1": "11.2.0.4.v3",
@ -62,6 +55,14 @@ class Database(CloudFormationModel):
"sqlserver-web": "11.00.2100.60.v1", "sqlserver-web": "11.00.2100.60.v1",
"postgres": "9.3.3", "postgres": "9.3.3",
} }
def __init__(self, **kwargs):
self.status = "available"
self.is_replica = False
self.replicas = []
self.region = kwargs.get("region")
self.engine = kwargs.get("engine")
self.engine_version = kwargs.get("engine_version", None)
if not self.engine_version and self.engine in self.default_engine_versions: if not self.engine_version and self.engine in self.default_engine_versions:
self.engine_version = self.default_engine_versions[self.engine] self.engine_version = self.default_engine_versions[self.engine]
self.iops = kwargs.get("iops") self.iops = kwargs.get("iops")
@ -120,6 +121,7 @@ class Database(CloudFormationModel):
self.db_parameter_group_name = kwargs.get("db_parameter_group_name") self.db_parameter_group_name = kwargs.get("db_parameter_group_name")
if ( if (
self.db_parameter_group_name self.db_parameter_group_name
and not self.is_default_parameter_group(self.db_parameter_group_name)
and self.db_parameter_group_name and self.db_parameter_group_name
not in rds2_backends[self.region].db_parameter_groups not in rds2_backends[self.region].db_parameter_groups
): ):
@ -160,7 +162,9 @@ class Database(CloudFormationModel):
return self.db_instance_identifier return self.db_instance_identifier
def db_parameter_groups(self): def db_parameter_groups(self):
if not self.db_parameter_group_name: if not self.db_parameter_group_name or self.is_default_parameter_group(
self.db_parameter_group_name
):
( (
db_family, db_family,
db_parameter_group_name, db_parameter_group_name,
@ -182,6 +186,9 @@ class Database(CloudFormationModel):
] ]
] ]
def is_default_parameter_group(self, param_group_name):
return param_group_name.startswith("default.%s" % self.engine.lower())
def default_db_parameter_group_details(self): def default_db_parameter_group_details(self):
if not self.engine_version: if not self.engine_version:
return (None, None) return (None, None)

View File

@ -97,6 +97,8 @@ class FakeResourceGroup(BaseModel):
return True return True
def _validate_resource_query(self, value): def _validate_resource_query(self, value):
if not value:
return True
errors = [] errors = []
if value["Type"] not in {"CLOUDFORMATION_STACK_1_0", "TAG_FILTERS_1_0"}: if value["Type"] not in {"CLOUDFORMATION_STACK_1_0", "TAG_FILTERS_1_0"}:
errors.append( errors.append(
@ -229,6 +231,8 @@ class ResourceGroupsBackend(BaseBackend):
@staticmethod @staticmethod
def _validate_resource_query(resource_query): def _validate_resource_query(resource_query):
if not resource_query:
return
type = resource_query["Type"] type = resource_query["Type"]
query = json.loads(resource_query["Query"]) query = json.loads(resource_query["Query"])
query_keys = set(query.keys()) query_keys = set(query.keys())

View File

@ -44,7 +44,7 @@ class ResourceGroupsResponse(BaseResponse):
) )
def delete_group(self): def delete_group(self):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName") or self._get_param("Group")
group = self.resourcegroups_backend.delete_group(group_name=group_name) group = self.resourcegroups_backend.delete_group(group_name=group_name)
return json.dumps( return json.dumps(
{ {

View File

@ -6,6 +6,8 @@ from moto.core.responses import BaseResponse
from .models import route53_backend from .models import route53_backend
import xmltodict import xmltodict
XMLNS = "https://route53.amazonaws.com/doc/2013-04-01/"
class Route53(BaseResponse): class Route53(BaseResponse):
def list_or_create_hostzone_response(self, request, full_url, headers): def list_or_create_hostzone_response(self, request, full_url, headers):
@ -83,7 +85,7 @@ class Route53(BaseResponse):
zoneid = parsed_url.path.rstrip("/").rsplit("/", 1)[1] zoneid = parsed_url.path.rstrip("/").rsplit("/", 1)[1]
the_zone = route53_backend.get_hosted_zone(zoneid) the_zone = route53_backend.get_hosted_zone(zoneid)
if not the_zone: if not the_zone:
return 404, headers, "Zone %s not Found" % zoneid return no_such_hosted_zone_error(zoneid, headers)
if request.method == "GET": if request.method == "GET":
template = Template(GET_HOSTED_ZONE_RESPONSE) template = Template(GET_HOSTED_ZONE_RESPONSE)
@ -102,7 +104,7 @@ class Route53(BaseResponse):
zoneid = parsed_url.path.rstrip("/").rsplit("/", 2)[1] zoneid = parsed_url.path.rstrip("/").rsplit("/", 2)[1]
the_zone = route53_backend.get_hosted_zone(zoneid) the_zone = route53_backend.get_hosted_zone(zoneid)
if not the_zone: if not the_zone:
return 404, headers, "Zone %s Not Found" % zoneid return no_such_hosted_zone_error(zoneid, headers)
if method == "POST": if method == "POST":
elements = xmltodict.parse(self.body) elements = xmltodict.parse(self.body)
@ -256,6 +258,20 @@ class Route53(BaseResponse):
return 200, headers, template.render(change_id=change_id) return 200, headers, template.render(change_id=change_id)
def no_such_hosted_zone_error(zoneid, headers={}):
headers["X-Amzn-ErrorType"] = "NoSuchHostedZone"
headers["Content-Type"] = "text/xml"
message = "Zone %s Not Found" % zoneid
error_response = (
"<Error><Code>NoSuchHostedZone</Code><Message>%s</Message></Error>" % message
)
error_response = '<ErrorResponse xmlns="%s">%s</ErrorResponse>' % (
XMLNS,
error_response,
)
return 404, headers, error_response
LIST_TAGS_FOR_RESOURCE_RESPONSE = """ LIST_TAGS_FOR_RESOURCE_RESPONSE = """
<ListTagsForResourceResponse xmlns="https://route53.amazonaws.com/doc/2015-01-01/"> <ListTagsForResourceResponse xmlns="https://route53.amazonaws.com/doc/2015-01-01/">
<ResourceTagSet> <ResourceTagSet>

View File

@ -30,12 +30,28 @@ ERROR_WITH_RANGE = """{% extends 'single_error' %}
class S3ClientError(RESTError): class S3ClientError(RESTError):
# S3 API uses <RequestID> as the XML tag in response messages
request_id_tag_name = "RequestID"
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs.setdefault("template", "single_error") kwargs.setdefault("template", "single_error")
self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME
super(S3ClientError, self).__init__(*args, **kwargs) super(S3ClientError, self).__init__(*args, **kwargs)
class InvalidArgumentError(S3ClientError):
code = 400
def __init__(self, message, name, value, *args, **kwargs):
kwargs.setdefault("template", "argument_error")
kwargs["name"] = name
kwargs["value"] = value
self.templates["argument_error"] = ERROR_WITH_ARGUMENT
super(InvalidArgumentError, self).__init__(
"InvalidArgument", message, *args, **kwargs
)
class BucketError(S3ClientError): class BucketError(S3ClientError):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs.setdefault("template", "bucket_error") kwargs.setdefault("template", "bucket_error")
@ -473,3 +489,16 @@ class InvalidContinuationToken(S3ClientError):
*args, *args,
**kwargs **kwargs
) )
class InvalidFilterRuleName(InvalidArgumentError):
code = 400
def __init__(self, value, *args, **kwargs):
super(InvalidFilterRuleName, self).__init__(
"filter rule name must be either prefix or suffix",
"FilterRule.Name",
value,
*args,
**kwargs
)

View File

@ -534,7 +534,7 @@ class LifecycleAndFilter(BaseModel):
for key, value in self.tags.items(): for key, value in self.tags.items():
data.append( data.append(
{"type": "LifecycleTagPredicate", "tag": {"key": key, "value": value},} {"type": "LifecycleTagPredicate", "tag": {"key": key, "value": value}}
) )
return data return data
@ -1058,9 +1058,6 @@ class FakeBucket(CloudFormationModel):
self.accelerate_configuration = accelerate_config self.accelerate_configuration = accelerate_config
def set_website_configuration(self, website_configuration):
self.website_configuration = website_configuration
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
@ -1382,12 +1379,16 @@ class S3Backend(BaseBackend):
def set_bucket_website_configuration(self, bucket_name, website_configuration): def set_bucket_website_configuration(self, bucket_name, website_configuration):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
bucket.set_website_configuration(website_configuration) bucket.website_configuration = website_configuration
def get_bucket_website_configuration(self, bucket_name): def get_bucket_website_configuration(self, bucket_name):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)
return bucket.website_configuration return bucket.website_configuration
def delete_bucket_website(self, bucket_name):
bucket = self.get_bucket(bucket_name)
bucket.website_configuration = None
def get_bucket_public_access_block(self, bucket_name): def get_bucket_public_access_block(self, bucket_name):
bucket = self.get_bucket(bucket_name) bucket = self.get_bucket(bucket_name)

View File

@ -60,6 +60,7 @@ from .models import (
FakeGrant, FakeGrant,
FakeAcl, FakeAcl,
FakeKey, FakeKey,
FakeMultipart,
) )
from .utils import ( from .utils import (
bucket_name_from_url, bucket_name_from_url,
@ -109,6 +110,7 @@ ACTION_MAP = {
"DELETE": { "DELETE": {
"lifecycle": "PutLifecycleConfiguration", "lifecycle": "PutLifecycleConfiguration",
"policy": "DeleteBucketPolicy", "policy": "DeleteBucketPolicy",
"website": "DeleteBucketWebsite",
"tagging": "PutBucketTagging", "tagging": "PutBucketTagging",
"cors": "PutBucketCORS", "cors": "PutBucketCORS",
"public_access_block": "DeletePublicAccessBlock", "public_access_block": "DeletePublicAccessBlock",
@ -815,6 +817,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
elif "tagging" in querystring: elif "tagging" in querystring:
self.backend.delete_bucket_tagging(bucket_name) self.backend.delete_bucket_tagging(bucket_name)
return 204, {}, "" return 204, {}, ""
elif "website" in querystring:
self.backend.delete_bucket_website(bucket_name)
return 204, {}, ""
elif "cors" in querystring: elif "cors" in querystring:
self.backend.delete_bucket_cors(bucket_name) self.backend.delete_bucket_cors(bucket_name)
return 204, {}, "" return 204, {}, ""
@ -1212,7 +1217,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if_unmodified_since = str_to_rfc_1123_datetime(if_unmodified_since) if_unmodified_since = str_to_rfc_1123_datetime(if_unmodified_since)
if key.last_modified > if_unmodified_since: if key.last_modified > if_unmodified_since:
raise PreconditionFailed("If-Unmodified-Since") raise PreconditionFailed("If-Unmodified-Since")
if if_match and key.etag != if_match: if if_match and key.etag not in [if_match, '"{0}"'.format(if_match)]:
raise PreconditionFailed("If-Match") raise PreconditionFailed("If-Match")
if if_modified_since: if if_modified_since:
@ -1509,6 +1514,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
grants = [] grants = []
for header, value in headers.items(): for header, value in headers.items():
header = header.lower()
if not header.startswith("x-amz-grant-"): if not header.startswith("x-amz-grant-"):
continue continue
@ -1523,7 +1529,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
grantees = [] grantees = []
for key_and_value in value.split(","): for key_and_value in value.split(","):
key, value = re.match( key, value = re.match(
'([^=]+)="([^"]+)"', key_and_value.strip() '([^=]+)="?([^"]+)"?', key_and_value.strip()
).groups() ).groups()
if key.lower() == "id": if key.lower() == "id":
grantees.append(FakeGrantee(id=value)) grantees.append(FakeGrantee(id=value))
@ -1765,7 +1771,11 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if body == b"" and "uploads" in query: if body == b"" and "uploads" in query:
metadata = metadata_from_headers(request.headers) metadata = metadata_from_headers(request.headers)
multipart = self.backend.initiate_multipart(bucket_name, key_name, metadata) multipart = FakeMultipart(key_name, metadata)
multipart.storage = request.headers.get("x-amz-storage-class", "STANDARD")
bucket = self.backend.get_bucket(bucket_name)
bucket.multiparts[multipart.id] = multipart
template = self.response_template(S3_MULTIPART_INITIATE_RESPONSE) template = self.response_template(S3_MULTIPART_INITIATE_RESPONSE)
response = template.render( response = template.render(
@ -1775,8 +1785,26 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if query.get("uploadId"): if query.get("uploadId"):
body = self._complete_multipart_body(body) body = self._complete_multipart_body(body)
upload_id = query["uploadId"][0] multipart_id = query["uploadId"][0]
key = self.backend.complete_multipart(bucket_name, upload_id, body)
bucket = self.backend.get_bucket(bucket_name)
multipart = bucket.multiparts[multipart_id]
value, etag = multipart.complete(body)
if value is None:
return 400, {}, ""
del bucket.multiparts[multipart_id]
key = self.backend.set_object(
bucket_name,
multipart.key_name,
value,
storage=multipart.storage,
etag=etag,
multipart=multipart,
)
key.set_metadata(multipart.metadata)
template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE) template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE)
headers = {} headers = {}
if key.version_id: if key.version_id:
@ -1788,6 +1816,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
bucket_name=bucket_name, key_name=key.name, etag=key.etag bucket_name=bucket_name, key_name=key.name, etag=key.etag
), ),
) )
elif "restore" in query: elif "restore" in query:
es = minidom.parseString(body).getElementsByTagName("Days") es = minidom.parseString(body).getElementsByTagName("Days")
days = es[0].childNodes[0].wholeText days = es[0].childNodes[0].wholeText
@ -1797,6 +1826,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
r = 200 r = 200
key.restore(int(days)) key.restore(int(days))
return r, {}, "" return r, {}, ""
else: else:
raise NotImplementedError( raise NotImplementedError(
"Method POST had only been implemented for multipart uploads and restore operations, so far" "Method POST had only been implemented for multipart uploads and restore operations, so far"
@ -2237,7 +2267,7 @@ S3_ALL_MULTIPARTS = (
<KeyMarker></KeyMarker> <KeyMarker></KeyMarker>
<UploadIdMarker></UploadIdMarker> <UploadIdMarker></UploadIdMarker>
<MaxUploads>1000</MaxUploads> <MaxUploads>1000</MaxUploads>
<IsTruncated>False</IsTruncated> <IsTruncated>false</IsTruncated>
{% for upload in uploads %} {% for upload in uploads %}
<Upload> <Upload>
<Key>{{ upload.key_name }}</Key> <Key>{{ upload.key_name }}</Key>
@ -2355,7 +2385,7 @@ S3_NO_LOGGING_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
""" """
S3_ENCRYPTION_CONFIG = """<?xml version="1.0" encoding="UTF-8"?> S3_ENCRYPTION_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
<BucketEncryptionStatus xmlns="http://doc.s3.amazonaws.com/2006-03-01"> <ServerSideEncryptionConfiguration xmlns="http://doc.s3.amazonaws.com/2006-03-01">
{% for entry in encryption %} {% for entry in encryption %}
<Rule> <Rule>
<ApplyServerSideEncryptionByDefault> <ApplyServerSideEncryptionByDefault>
@ -2364,9 +2394,10 @@ S3_ENCRYPTION_CONFIG = """<?xml version="1.0" encoding="UTF-8"?>
<KMSMasterKeyID>{{ entry["Rule"]["ApplyServerSideEncryptionByDefault"]["KMSMasterKeyID"] }}</KMSMasterKeyID> <KMSMasterKeyID>{{ entry["Rule"]["ApplyServerSideEncryptionByDefault"]["KMSMasterKeyID"] }}</KMSMasterKeyID>
{% endif %} {% endif %}
</ApplyServerSideEncryptionByDefault> </ApplyServerSideEncryptionByDefault>
<BucketKeyEnabled>{{ 'true' if entry["Rule"].get("BucketKeyEnabled") == 'true' else 'false' }}</BucketKeyEnabled>
</Rule> </Rule>
{% endfor %} {% endfor %}
</BucketEncryptionStatus> </ServerSideEncryptionConfiguration>
""" """
S3_INVALID_PRESIGNED_PARAMETERS = """<?xml version="1.0" encoding="UTF-8"?> S3_INVALID_PRESIGNED_PARAMETERS = """<?xml version="1.0" encoding="UTF-8"?>

View File

@ -144,23 +144,28 @@ class _VersionedKeyStore(dict):
super(_VersionedKeyStore, self).__setitem__(key, list_) super(_VersionedKeyStore, self).__setitem__(key, list_)
def _iteritems(self): def _iteritems(self):
for key in self: for key in self._self_iterable():
yield key, self[key] yield key, self[key]
def _itervalues(self): def _itervalues(self):
for key in self: for key in self._self_iterable():
yield self[key] yield self[key]
def _iterlists(self): def _iterlists(self):
for key in self: for key in self._self_iterable():
yield key, self.getlist(key) yield key, self.getlist(key)
def item_size(self): def item_size(self):
size = 0 size = 0
for val in self.values(): for val in self._self_iterable().values():
size += sys.getsizeof(val) size += sys.getsizeof(val)
return size return size
def _self_iterable(self):
# to enable concurrency, return a copy, to avoid "dictionary changed size during iteration"
# TODO: look into replacing with a locking mechanism, potentially
return dict(self)
items = iteritems = _iteritems items = iteritems = _iteritems
lists = iterlists = _iterlists lists = iterlists = _iterlists
values = itervalues = _itervalues values = itervalues = _itervalues

View File

@ -7,6 +7,7 @@ import uuid
import datetime import datetime
from boto3 import Session from boto3 import Session
from typing import List, Tuple
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from .exceptions import ( from .exceptions import (
@ -566,15 +567,46 @@ class SecretsManagerBackend(BaseBackend):
return response return response
def list_secrets(self, filters, max_results, next_token): def list_secrets(
# TODO implement pagination and limits self, filters: List, max_results: int = 100, next_token: str = None
) -> Tuple[List, str]:
"""
Returns secrets from secretsmanager.
The result is paginated and page items depends on the token value, because token contains start element
number of secret list.
Response example:
{
SecretList: [
{
ARN: 'arn:aws:secretsmanager:us-east-1:1234567890:secret:test1-gEcah',
Name: 'test1',
...
},
{
ARN: 'arn:aws:secretsmanager:us-east-1:1234567890:secret:test2-KZwml',
Name: 'test2',
...
}
],
NextToken: '2'
}
:param filters: (List) Filter parameters.
:param max_results: (int) Max number of results per page.
:param next_token: (str) Page token.
:return: (Tuple[List,str]) Returns result list and next token.
"""
secret_list = [] secret_list = []
for secret in self.secrets.values(): for secret in self.secrets.values():
if _matches(secret, filters): if _matches(secret, filters):
secret_list.append(secret.to_dict()) secret_list.append(secret.to_dict())
return secret_list, None starting_point = int(next_token or 0)
ending_point = starting_point + int(max_results or 100)
secret_page = secret_list[starting_point:ending_point]
new_next_token = str(ending_point) if ending_point < len(secret_list) else None
return secret_page, new_next_token
def delete_secret( def delete_secret(
self, secret_id, recovery_window_in_days, force_delete_without_recovery self, secret_id, recovery_window_in_days, force_delete_without_recovery

View File

@ -441,7 +441,7 @@ class SNSBackend(BaseBackend):
return self._get_values_nexttoken(self.topics, next_token) return self._get_values_nexttoken(self.topics, next_token)
def delete_topic_subscriptions(self, topic): def delete_topic_subscriptions(self, topic):
for key, value in self.subscriptions.items(): for key, value in dict(self.subscriptions).items():
if value.topic == topic: if value.topic == topic:
self.subscriptions.pop(key) self.subscriptions.pop(key)
@ -585,7 +585,10 @@ class SNSBackend(BaseBackend):
): ):
for endpoint in self.platform_endpoints.values(): for endpoint in self.platform_endpoints.values():
if token == endpoint.token: if token == endpoint.token:
if attributes["Enabled"].lower() == endpoint.attributes["Enabled"]: if (
attributes.get("Enabled", "").lower()
== endpoint.attributes["Enabled"]
):
return endpoint return endpoint
raise DuplicateSnsEndpointError( raise DuplicateSnsEndpointError(
"Duplicate endpoint token with different attributes: %s" % token "Duplicate endpoint token with different attributes: %s" % token

View File

@ -477,6 +477,7 @@ class Queue(CloudFormationModel):
@property @property
def messages(self): def messages(self):
# TODO: This can become very inefficient if a large number of messages are in-flight
return [ return [
message message
for message in self._messages for message in self._messages
@ -832,6 +833,7 @@ class SQSBackend(BaseBackend):
if ( if (
queue.dead_letter_queue is not None queue.dead_letter_queue is not None
and queue.redrive_policy
and message.approximate_receive_count and message.approximate_receive_count
>= queue.redrive_policy["maxReceiveCount"] >= queue.redrive_policy["maxReceiveCount"]
): ):

View File

@ -3,6 +3,13 @@ import random
import string import string
def str2bool(v):
if v in ("yes", True, "true", "True", "TRUE", "t", "1"):
return True
elif v in ("no", False, "false", "False", "FALSE", "f", "0"):
return False
def random_string(length=None): def random_string(length=None):
n = length or 20 n = length or 20
random_str = "".join( random_str = "".join(
@ -20,3 +27,10 @@ def load_resource(filename, as_json=True):
""" """
with open(filename, "r", encoding="utf-8") as f: with open(filename, "r", encoding="utf-8") as f:
return json.load(f) if as_json else f.read() return json.load(f) if as_json else f.read()
def merge_multiple_dicts(*args):
result = {}
for d in args:
result.update(d)
return result

View File

@ -1,7 +0,0 @@
FROM python:3.7-buster
ADD . /moto/
ENV PYTHONUNBUFFERED 1
WORKDIR /moto/
RUN make init
RUN make test

View File

@ -0,0 +1,11 @@
TestAccAWSEc2TransitGatewayDxGatewayAttachmentDataSource
TestAccAWSEc2TransitGatewayPeeringAttachment
TestAccAWSEc2TransitGatewayPeeringAttachmentAccepter
TestAccAWSEc2TransitGatewayPeeringAttachmentDataSource
TestAccAWSEc2TransitGatewayRoute
TestAccAWSEc2TransitGatewayRouteTableAssociation
TestAccAWSEc2TransitGatewayRouteTablePropagation
TestAccAWSEc2TransitGatewayVpcAttachment
TestAccAWSEc2TransitGatewayVpcAttachmentDataSource
TestAccAWSFms
TestAccAWSIAMRolePolicy

View File

@ -4,29 +4,51 @@ TestAccAWSBillingServiceAccount
TestAccAWSCallerIdentity TestAccAWSCallerIdentity
TestAccAWSCloudTrailServiceAccount TestAccAWSCloudTrailServiceAccount
TestAccAWSCloudWatchDashboard TestAccAWSCloudWatchDashboard
TestAccAWSCloudWatchEventApiDestination
TestAccAWSCloudWatchEventArchive
TestAccAWSCloudWatchEventBus
TestAccAWSCloudwatchLogGroupDataSource
TestAccAWSDataSourceCloudwatch
TestAccAWSDataSourceElasticBeanstalkHostedZone TestAccAWSDataSourceElasticBeanstalkHostedZone
TestAccAWSDataSourceIAMGroup TestAccAWSDataSourceIAMGroup
TestAccAWSDataSourceIAMInstanceProfile TestAccAWSDataSourceIAMInstanceProfile
TestAccAWSDataSourceIAMPolicy TestAccAWSDataSourceIAMPolicy
TestAccAWSDataSourceIAMPolicyDocument
TestAccAWSDataSourceIAMRole TestAccAWSDataSourceIAMRole
TestAccAWSDataSourceIAMSessionContext TestAccAWSDataSourceIAMSessionContext
TestAccAWSDataSourceIAMUser TestAccAWSDataSourceIAMUser
TestAccAWSDefaultSecurityGroup
TestAccAWSDefaultSubnet TestAccAWSDefaultSubnet
TestAccAWSDefaultTagsDataSource TestAccAWSDefaultTagsDataSource
TestAccAWSDynamoDbTableItem TestAccAWSDynamoDbTableItem
TestAccAWSEc2InstanceTypeOfferingDataSource TestAccAWSEc2InstanceTypeOfferingDataSource
TestAccAWSEc2InstanceTypeOfferingsDataSource TestAccAWSEc2InstanceTypeOfferingsDataSource
TestAccAWSEc2Tag
TestAccAWSEc2TransitGateway
TestAccAWSEc2TransitGatewayDataSource
TestAccAWSEc2TransitGatewayRouteTable
TestAccAWSEc2TransitGatewayRouteTableDataSource
TestAccAWSEc2TransitGatewayVpcAttachmentAccepter
TestAccAWSEc2TransitGatewayVpnAttachmentDataSource
TestAccAWSElasticBeanstalkSolutionStackDataSource TestAccAWSElasticBeanstalkSolutionStackDataSource
TestAccAWSElbHostedZoneId TestAccAWSElbHostedZoneId
TestAccAWSElbServiceAccount TestAccAWSElbServiceAccount
TestAccAWSFms
TestAccAWSGroupMembership
TestAccAWSIAMAccountAlias TestAccAWSIAMAccountAlias
TestAccAWSIAMGroupPolicy
TestAccAWSIAMGroupPolicyAttachment TestAccAWSIAMGroupPolicyAttachment
TestAccAWSIAMRole
TestAccAWSIAMUserPolicy
TestAccAWSIPRanges TestAccAWSIPRanges
TestAccAWSKmsSecretDataSource TestAccAWSKmsSecretDataSource
TestAccAWSPartition TestAccAWSPartition
TestAccAWSProvider TestAccAWSProvider
TestAccAWSRedshiftServiceAccount TestAccAWSRedshiftServiceAccount
TestAccAWSRolePolicyAttachment
TestAccAWSSNSSMSPreferences TestAccAWSSNSSMSPreferences
TestAccAWSSageMakerPrebuiltECRImage TestAccAWSSageMakerPrebuiltECRImage
TestAccAWSSsmParameterDataSource TestAccAWSSsmParameterDataSource
TestAccAWSUserGroupMembership
TestAccAWSUserPolicyAttachment TestAccAWSUserPolicyAttachment
TestAccAWSUserSSHKey

View File

@ -9,6 +9,7 @@ import sure # noqa
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
from moto import mock_apigateway, mock_cognitoidp, settings from moto import mock_apigateway, mock_cognitoidp, settings
from moto.apigateway.exceptions import NoIntegrationDefined
from moto.core import ACCOUNT_ID from moto.core import ACCOUNT_ID
from moto.core.models import responses_mock from moto.core.models import responses_mock
import pytest import pytest
@ -474,13 +475,7 @@ def test_integrations():
response.should.equal( response.should.equal(
{ {
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
"httpMethod": "GET", "httpMethod": "POST",
"integrationResponses": {
"200": {
"responseTemplates": {"application/json": None},
"statusCode": 200,
}
},
"type": "HTTP", "type": "HTTP",
"uri": "http://httpbin.org/robots.txt", "uri": "http://httpbin.org/robots.txt",
} }
@ -495,13 +490,7 @@ def test_integrations():
response.should.equal( response.should.equal(
{ {
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
"httpMethod": "GET", "httpMethod": "POST",
"integrationResponses": {
"200": {
"responseTemplates": {"application/json": None},
"statusCode": 200,
}
},
"type": "HTTP", "type": "HTTP",
"uri": "http://httpbin.org/robots.txt", "uri": "http://httpbin.org/robots.txt",
} }
@ -511,18 +500,10 @@ def test_integrations():
# this is hard to match against, so remove it # this is hard to match against, so remove it
response["ResponseMetadata"].pop("HTTPHeaders", None) response["ResponseMetadata"].pop("HTTPHeaders", None)
response["ResponseMetadata"].pop("RetryAttempts", None) response["ResponseMetadata"].pop("RetryAttempts", None)
response["resourceMethods"]["GET"]["httpMethod"].should.equal("GET")
response["resourceMethods"]["GET"]["authorizationType"].should.equal("none")
response["resourceMethods"]["GET"]["methodIntegration"].should.equal( response["resourceMethods"]["GET"]["methodIntegration"].should.equal(
{ {"httpMethod": "POST", "type": "HTTP", "uri": "http://httpbin.org/robots.txt",}
"httpMethod": "GET",
"integrationResponses": {
"200": {
"responseTemplates": {"application/json": None},
"statusCode": 200,
}
},
"type": "HTTP",
"uri": "http://httpbin.org/robots.txt",
}
) )
client.delete_integration(restApiId=api_id, resourceId=root_id, httpMethod="GET") client.delete_integration(restApiId=api_id, resourceId=root_id, httpMethod="GET")
@ -611,7 +592,7 @@ def test_integration_response():
"statusCode": "200", "statusCode": "200",
"selectionPattern": "foobar", "selectionPattern": "foobar",
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
"responseTemplates": {"application/json": None}, "responseTemplates": {}, # Note: TF compatibility
} }
) )
@ -626,7 +607,7 @@ def test_integration_response():
"statusCode": "200", "statusCode": "200",
"selectionPattern": "foobar", "selectionPattern": "foobar",
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
"responseTemplates": {"application/json": None}, "responseTemplates": {}, # Note: TF compatibility
} }
) )
@ -637,7 +618,7 @@ def test_integration_response():
response["methodIntegration"]["integrationResponses"].should.equal( response["methodIntegration"]["integrationResponses"].should.equal(
{ {
"200": { "200": {
"responseTemplates": {"application/json": None}, "responseTemplates": {}, # Note: TF compatibility
"selectionPattern": "foobar", "selectionPattern": "foobar",
"statusCode": "200", "statusCode": "200",
} }
@ -687,7 +668,7 @@ def test_integration_response():
"statusCode": "200", "statusCode": "200",
"selectionPattern": "foobar", "selectionPattern": "foobar",
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
"responseTemplates": {"application/json": None}, "responseTemplates": {}, # Note: TF compatibility
"contentHandling": "CONVERT_TO_BINARY", "contentHandling": "CONVERT_TO_BINARY",
} }
) )
@ -703,7 +684,7 @@ def test_integration_response():
"statusCode": "200", "statusCode": "200",
"selectionPattern": "foobar", "selectionPattern": "foobar",
"ResponseMetadata": {"HTTPStatusCode": 200}, "ResponseMetadata": {"HTTPStatusCode": 200},
"responseTemplates": {"application/json": None}, "responseTemplates": {}, # Note: TF compatibility
"contentHandling": "CONVERT_TO_BINARY", "contentHandling": "CONVERT_TO_BINARY",
} }
) )
@ -1277,7 +1258,7 @@ def test_create_deployment_requires_REST_method_integrations():
with pytest.raises(ClientError) as ex: with pytest.raises(ClientError) as ex:
client.create_deployment(restApiId=api_id, stageName=stage_name)["id"] client.create_deployment(restApiId=api_id, stageName=stage_name)["id"]
ex.value.response["Error"]["Code"].should.equal("BadRequestException") ex.value.response["Error"]["Code"].should.equal("NotFoundException")
ex.value.response["Error"]["Message"].should.equal( ex.value.response["Error"]["Message"].should.equal(
"No integration defined for method" "No integration defined for method"
) )
@ -1873,7 +1854,7 @@ def test_http_proxying_integration():
httpMethod="GET", httpMethod="GET",
type="HTTP", type="HTTP",
uri="http://httpbin.org/robots.txt", uri="http://httpbin.org/robots.txt",
integrationHttpMethod="POST", integrationHttpMethod="GET",
) )
stage_name = "staging" stage_name = "staging"

View File

@ -2542,7 +2542,8 @@ def test_stack_elbv2_resources_integration():
] ]
}, },
} }
] ],
[{"Type": "forward", "TargetGroupArn": target_groups[1]["TargetGroupArn"]}],
) )
listener_rule[0]["Conditions"].should.equal( listener_rule[0]["Conditions"].should.equal(
[{"Field": "path-pattern", "Values": ["/*"]}] [{"Field": "path-pattern", "Values": ["/*"]}]

View File

@ -349,6 +349,75 @@ def test_get_metric_statistics():
datapoint["Sum"].should.equal(1.5) datapoint["Sum"].should.equal(1.5)
@mock_cloudwatch
def test_get_metric_statistics_dimensions():
conn = boto3.client("cloudwatch", region_name="us-east-1")
utc_now = datetime.now(tz=pytz.utc)
# put metric data with different dimensions
dimensions1 = [{"Name": "dim1", "Value": "v1"}]
dimensions2 = dimensions1 + [{"Name": "dim2", "Value": "v2"}]
metric_name = "metr-stats-dims"
conn.put_metric_data(
Namespace="tester",
MetricData=[
dict(
MetricName=metric_name,
Value=1,
Timestamp=utc_now,
Dimensions=dimensions1,
)
],
)
conn.put_metric_data(
Namespace="tester",
MetricData=[
dict(
MetricName=metric_name,
Value=2,
Timestamp=utc_now,
Dimensions=dimensions1,
)
],
)
conn.put_metric_data(
Namespace="tester",
MetricData=[
dict(
MetricName=metric_name,
Value=6,
Timestamp=utc_now,
Dimensions=dimensions2,
)
],
)
# list of (<kwargs>, <expectedSum>, <expectedAverage>)
params_list = (
# get metric stats with no restriction on dimensions
({}, 9, 3),
# get metric stats for dimensions1 (should also cover dimensions2)
({"Dimensions": dimensions1}, 9, 3),
# get metric stats for dimensions2 only
({"Dimensions": dimensions2}, 6, 6),
)
for params in params_list:
stats = conn.get_metric_statistics(
Namespace="tester",
MetricName=metric_name,
StartTime=utc_now - timedelta(seconds=60),
EndTime=utc_now + timedelta(seconds=60),
Period=60,
Statistics=["Average", "Sum"],
**params[0],
)
stats["Datapoints"].should.have.length_of(1)
datapoint = stats["Datapoints"][0]
datapoint["Sum"].should.equal(params[1])
datapoint["Average"].should.equal(params[2])
@mock_cloudwatch @mock_cloudwatch
def test_duplicate_put_metric_data(): def test_duplicate_put_metric_data():
conn = boto3.client("cloudwatch", region_name="us-east-1") conn = boto3.client("cloudwatch", region_name="us-east-1")
@ -501,16 +570,8 @@ def test_list_metrics():
# Verify format # Verify format
res.should.equal( res.should.equal(
[ [
{ {"Namespace": "list_test_1/", "Dimensions": [], "MetricName": "metric1",},
u"Namespace": "list_test_1/", {"Namespace": "list_test_1/", "Dimensions": [], "MetricName": "metric1",},
u"Dimensions": [],
u"MetricName": "metric1",
},
{
u"Namespace": "list_test_1/",
u"Dimensions": [],
u"MetricName": "metric1",
},
] ]
) )
# Verify unknown namespace still has no results # Verify unknown namespace still has no results

View File

@ -156,7 +156,7 @@ def test_create_task():
@mock_datasync @mock_datasync
def test_create_task_fail(): def test_create_task_fail():
""" Test that Locations must exist before a Task can be created """ """Test that Locations must exist before a Task can be created"""
client = boto3.client("datasync", region_name="us-east-1") client = boto3.client("datasync", region_name="us-east-1")
locations = create_locations(client, create_smb=True, create_s3=True) locations = create_locations(client, create_smb=True, create_s3=True)
with pytest.raises(ClientError) as e: with pytest.raises(ClientError) as e:

View File

@ -5970,7 +5970,7 @@ def test_dynamodb_update_item_fails_on_string_sets():
BillingMode="PAY_PER_REQUEST", BillingMode="PAY_PER_REQUEST",
) )
table.meta.client.get_waiter("table_exists").wait(TableName="test") table.meta.client.get_waiter("table_exists").wait(TableName="test")
attribute = {"test_field": {"Value": {"SS": ["test1", "test2"],}, "Action": "PUT"}} attribute = {"test_field": {"Value": {"SS": ["test1", "test2"],}, "Action": "PUT",}}
client.update_item( client.update_item(
TableName="test", TableName="test",

View File

@ -88,7 +88,7 @@ def test_validation_of_update_expression_with_keyword(table):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"update_expression", ["SET a = #b + :val2", "SET a = :val2 + #b",] "update_expression", ["SET a = #b + :val2", "SET a = :val2 + #b",],
) )
def test_validation_of_a_set_statement_with_incorrect_passed_value( def test_validation_of_a_set_statement_with_incorrect_passed_value(
update_expression, table update_expression, table
@ -150,7 +150,9 @@ def test_validation_of_update_expression_with_attribute_that_does_not_exist_in_i
assert True assert True
@pytest.mark.parametrize("update_expression", ["SET a = #c", "SET a = #c + #d",]) @pytest.mark.parametrize(
"update_expression", ["SET a = #c", "SET a = #c + #d",],
)
def test_validation_of_update_expression_with_attribute_name_that_is_not_defined( def test_validation_of_update_expression_with_attribute_name_that_is_not_defined(
update_expression, table, update_expression, table,
): ):

View File

@ -290,7 +290,7 @@ def test_ami_filters():
amis_by_architecture = conn.get_all_images(filters={"architecture": "x86_64"}) amis_by_architecture = conn.get_all_images(filters={"architecture": "x86_64"})
set([ami.id for ami in amis_by_architecture]).should.contain(imageB.id) set([ami.id for ami in amis_by_architecture]).should.contain(imageB.id)
len(amis_by_architecture).should.equal(35) len(amis_by_architecture).should.equal(37)
amis_by_kernel = conn.get_all_images(filters={"kernel-id": "k-abcd1234"}) amis_by_kernel = conn.get_all_images(filters={"kernel-id": "k-abcd1234"})
set([ami.id for ami in amis_by_kernel]).should.equal(set([imageB.id])) set([ami.id for ami in amis_by_kernel]).should.equal(set([imageB.id]))
@ -303,7 +303,7 @@ def test_ami_filters():
amis_by_platform = conn.get_all_images(filters={"platform": "windows"}) amis_by_platform = conn.get_all_images(filters={"platform": "windows"})
set([ami.id for ami in amis_by_platform]).should.contain(imageA.id) set([ami.id for ami in amis_by_platform]).should.contain(imageA.id)
len(amis_by_platform).should.equal(24) len(amis_by_platform).should.equal(25)
amis_by_id = conn.get_all_images(filters={"image-id": imageA.id}) amis_by_id = conn.get_all_images(filters={"image-id": imageA.id})
set([ami.id for ami in amis_by_id]).should.equal(set([imageA.id])) set([ami.id for ami in amis_by_id]).should.equal(set([imageA.id]))
@ -312,14 +312,14 @@ def test_ami_filters():
ami_ids_by_state = [ami.id for ami in amis_by_state] ami_ids_by_state = [ami.id for ami in amis_by_state]
ami_ids_by_state.should.contain(imageA.id) ami_ids_by_state.should.contain(imageA.id)
ami_ids_by_state.should.contain(imageB.id) ami_ids_by_state.should.contain(imageB.id)
len(amis_by_state).should.equal(36) len(amis_by_state).should.equal(38)
amis_by_name = conn.get_all_images(filters={"name": imageA.name}) amis_by_name = conn.get_all_images(filters={"name": imageA.name})
set([ami.id for ami in amis_by_name]).should.equal(set([imageA.id])) set([ami.id for ami in amis_by_name]).should.equal(set([imageA.id]))
amis_by_public = conn.get_all_images(filters={"is-public": "true"}) amis_by_public = conn.get_all_images(filters={"is-public": "true"})
set([ami.id for ami in amis_by_public]).should.contain(imageB.id) set([ami.id for ami in amis_by_public]).should.contain(imageB.id)
len(amis_by_public).should.equal(35) len(amis_by_public).should.equal(37)
amis_by_nonpublic = conn.get_all_images(filters={"is-public": "false"}) amis_by_nonpublic = conn.get_all_images(filters={"is-public": "false"})
set([ami.id for ami in amis_by_nonpublic]).should.contain(imageA.id) set([ami.id for ami in amis_by_nonpublic]).should.contain(imageA.id)

View File

@ -38,7 +38,8 @@ def test_delete_customer_gateways():
cgws[0].id.should.match(customer_gateway.id) cgws[0].id.should.match(customer_gateway.id)
deleted = conn.delete_customer_gateway(customer_gateway.id) deleted = conn.delete_customer_gateway(customer_gateway.id)
cgws = conn.get_all_customer_gateways() cgws = conn.get_all_customer_gateways()
cgws.should.have.length_of(0) cgws[0].state.should.equal("deleted")
cgws.should.have.length_of(1)
@mock_ec2_deprecated @mock_ec2_deprecated

View File

@ -16,7 +16,7 @@ SAMPLE_NAME_SERVERS = ["10.0.0.6", "10.0.0.7"]
@mock_ec2_deprecated @mock_ec2_deprecated
def test_dhcp_options_associate(): def test_dhcp_options_associate():
""" associate dhcp option """ """associate dhcp option"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS)
vpc = conn.create_vpc("10.0.0.0/16") vpc = conn.create_vpc("10.0.0.0/16")
@ -27,7 +27,7 @@ def test_dhcp_options_associate():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_dhcp_options_associate_invalid_dhcp_id(): def test_dhcp_options_associate_invalid_dhcp_id():
""" associate dhcp option bad dhcp options id """ """associate dhcp option bad dhcp options id"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
vpc = conn.create_vpc("10.0.0.0/16") vpc = conn.create_vpc("10.0.0.0/16")
@ -40,7 +40,7 @@ def test_dhcp_options_associate_invalid_dhcp_id():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_dhcp_options_associate_invalid_vpc_id(): def test_dhcp_options_associate_invalid_vpc_id():
""" associate dhcp option invalid vpc id """ """associate dhcp option invalid vpc id"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS)

View File

@ -205,7 +205,7 @@ def test_eip_boto3_vpc_association():
address.association_id.should.be.none address.association_id.should.be.none
address.instance_id.should.be.empty address.instance_id.should.be.empty
address.network_interface_id.should.be.empty address.network_interface_id.should.be.empty
association_id = client.associate_address( client.associate_address(
InstanceId=instance.id, AllocationId=allocation_id, AllowReassociation=False InstanceId=instance.id, AllocationId=allocation_id, AllowReassociation=False
) )
instance.load() instance.load()
@ -287,7 +287,6 @@ def test_eip_reassociate():
).should_not.throw(EC2ResponseError) ).should_not.throw(EC2ResponseError)
eip.release() eip.release()
eip = None
instance1.terminate() instance1.terminate()
instance2.terminate() instance2.terminate()
@ -326,7 +325,7 @@ def test_eip_reassociate_nic():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_eip_associate_invalid_args(): def test_eip_associate_invalid_args():
"""Associate EIP, invalid args """ """Associate EIP, invalid args"""
conn = boto.connect_ec2("the_key", "the_secret") conn = boto.connect_ec2("the_key", "the_secret")
reservation = conn.run_instances(EXAMPLE_AMI_ID) reservation = conn.run_instances(EXAMPLE_AMI_ID)

View File

@ -27,7 +27,7 @@ def test_elastic_network_interfaces():
"An error occurred (DryRunOperation) when calling the CreateNetworkInterface operation: Request would have succeeded, but DryRun flag is set" "An error occurred (DryRunOperation) when calling the CreateNetworkInterface operation: Request would have succeeded, but DryRun flag is set"
) )
eni = conn.create_network_interface(subnet.id) conn.create_network_interface(subnet.id)
all_enis = conn.get_all_network_interfaces() all_enis = conn.get_all_network_interfaces()
all_enis.should.have.length_of(1) all_enis.should.have.length_of(1)

View File

@ -6,6 +6,7 @@ import boto
import boto3 import boto3
from boto.exception import EC2ResponseError from boto.exception import EC2ResponseError
import sure # noqa import sure # noqa
from botocore.exceptions import ClientError
from moto import mock_ec2_deprecated, mock_ec2 from moto import mock_ec2_deprecated, mock_ec2
from tests import EXAMPLE_AMI_ID from tests import EXAMPLE_AMI_ID
@ -26,7 +27,7 @@ def test_console_output_without_instance():
with pytest.raises(EC2ResponseError) as cm: with pytest.raises(EC2ResponseError) as cm:
conn.get_console_output("i-1234abcd") conn.get_console_output("i-1234abcd")
cm.value.code.should.equal("InvalidInstanceID.NotFound") cm.value.error_code.should.equal("InvalidInstanceID.NotFound")
cm.value.status.should.equal(400) cm.value.status.should.equal(400)
cm.value.request_id.should_not.be.none cm.value.request_id.should_not.be.none

View File

@ -21,7 +21,7 @@ BAD_IGW = "igw-deadbeef"
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_create(): def test_igw_create():
""" internet gateway create """ """internet gateway create"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
conn.get_all_internet_gateways().should.have.length_of(0) conn.get_all_internet_gateways().should.have.length_of(0)
@ -44,7 +44,7 @@ def test_igw_create():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_attach(): def test_igw_attach():
""" internet gateway attach """ """internet gateway attach"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway() igw = conn.create_internet_gateway()
vpc = conn.create_vpc(VPC_CIDR) vpc = conn.create_vpc(VPC_CIDR)
@ -65,7 +65,7 @@ def test_igw_attach():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_attach_bad_vpc(): def test_igw_attach_bad_vpc():
""" internet gateway fail to attach w/ bad vpc """ """internet gateway fail to attach w/ bad vpc"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway() igw = conn.create_internet_gateway()
@ -78,7 +78,7 @@ def test_igw_attach_bad_vpc():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_attach_twice(): def test_igw_attach_twice():
""" internet gateway fail to attach twice """ """internet gateway fail to attach twice"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway() igw = conn.create_internet_gateway()
vpc1 = conn.create_vpc(VPC_CIDR) vpc1 = conn.create_vpc(VPC_CIDR)
@ -94,7 +94,7 @@ def test_igw_attach_twice():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_detach(): def test_igw_detach():
""" internet gateway detach""" """internet gateway detach"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway() igw = conn.create_internet_gateway()
vpc = conn.create_vpc(VPC_CIDR) vpc = conn.create_vpc(VPC_CIDR)
@ -115,7 +115,7 @@ def test_igw_detach():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_detach_wrong_vpc(): def test_igw_detach_wrong_vpc():
""" internet gateway fail to detach w/ wrong vpc """ """internet gateway fail to detach w/ wrong vpc"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway() igw = conn.create_internet_gateway()
vpc1 = conn.create_vpc(VPC_CIDR) vpc1 = conn.create_vpc(VPC_CIDR)
@ -131,7 +131,7 @@ def test_igw_detach_wrong_vpc():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_detach_invalid_vpc(): def test_igw_detach_invalid_vpc():
""" internet gateway fail to detach w/ invalid vpc """ """internet gateway fail to detach w/ invalid vpc"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway() igw = conn.create_internet_gateway()
vpc = conn.create_vpc(VPC_CIDR) vpc = conn.create_vpc(VPC_CIDR)
@ -146,7 +146,7 @@ def test_igw_detach_invalid_vpc():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_detach_unattached(): def test_igw_detach_unattached():
""" internet gateway fail to detach unattached """ """internet gateway fail to detach unattached"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway() igw = conn.create_internet_gateway()
vpc = conn.create_vpc(VPC_CIDR) vpc = conn.create_vpc(VPC_CIDR)
@ -160,7 +160,7 @@ def test_igw_detach_unattached():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_delete(): def test_igw_delete():
""" internet gateway delete""" """internet gateway delete"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
vpc = conn.create_vpc(VPC_CIDR) vpc = conn.create_vpc(VPC_CIDR)
conn.get_all_internet_gateways().should.have.length_of(0) conn.get_all_internet_gateways().should.have.length_of(0)
@ -181,7 +181,7 @@ def test_igw_delete():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_delete_attached(): def test_igw_delete_attached():
""" internet gateway fail to delete attached """ """internet gateway fail to delete attached"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway() igw = conn.create_internet_gateway()
vpc = conn.create_vpc(VPC_CIDR) vpc = conn.create_vpc(VPC_CIDR)
@ -196,7 +196,7 @@ def test_igw_delete_attached():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_desribe(): def test_igw_desribe():
""" internet gateway fetch by id """ """internet gateway fetch by id"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw = conn.create_internet_gateway() igw = conn.create_internet_gateway()
igw_by_search = conn.get_all_internet_gateways([igw.id])[0] igw_by_search = conn.get_all_internet_gateways([igw.id])[0]
@ -205,7 +205,7 @@ def test_igw_desribe():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_describe_bad_id(): def test_igw_describe_bad_id():
""" internet gateway fail to fetch by bad id """ """internet gateway fail to fetch by bad id"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
with pytest.raises(EC2ResponseError) as cm: with pytest.raises(EC2ResponseError) as cm:
conn.get_all_internet_gateways([BAD_IGW]) conn.get_all_internet_gateways([BAD_IGW])
@ -216,7 +216,7 @@ def test_igw_describe_bad_id():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_filter_by_vpc_id(): def test_igw_filter_by_vpc_id():
""" internet gateway filter by vpc id """ """internet gateway filter by vpc id"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw1 = conn.create_internet_gateway() igw1 = conn.create_internet_gateway()
@ -231,7 +231,7 @@ def test_igw_filter_by_vpc_id():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_filter_by_tags(): def test_igw_filter_by_tags():
""" internet gateway filter by vpc id """ """internet gateway filter by vpc id"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw1 = conn.create_internet_gateway() igw1 = conn.create_internet_gateway()
@ -245,7 +245,7 @@ def test_igw_filter_by_tags():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_filter_by_internet_gateway_id(): def test_igw_filter_by_internet_gateway_id():
""" internet gateway filter by internet gateway id """ """internet gateway filter by internet gateway id"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw1 = conn.create_internet_gateway() igw1 = conn.create_internet_gateway()
@ -258,7 +258,7 @@ def test_igw_filter_by_internet_gateway_id():
@mock_ec2_deprecated @mock_ec2_deprecated
def test_igw_filter_by_attachment_state(): def test_igw_filter_by_attachment_state():
""" internet gateway filter by attachment state """ """internet gateway filter by attachment state"""
conn = boto.connect_vpc("the_key", "the_secret") conn = boto.connect_vpc("the_key", "the_secret")
igw1 = conn.create_internet_gateway() igw1 = conn.create_internet_gateway()

View File

@ -175,6 +175,7 @@ def test_modify_subnet_attribute_assign_ipv6_address_on_creation():
# 'map_public_ip_on_launch' is set when calling 'DescribeSubnets' action # 'map_public_ip_on_launch' is set when calling 'DescribeSubnets' action
subnet.reload() subnet.reload()
subnets = client.describe_subnets()
# For non default subnet, attribute value should be 'False' # For non default subnet, attribute value should be 'False'
subnet.assign_ipv6_address_on_creation.shouldnt.be.ok subnet.assign_ipv6_address_on_creation.shouldnt.be.ok

View File

@ -36,7 +36,7 @@ def test_describe_vpn_gateway():
@mock_ec2 @mock_ec2
def test_describe_vpn_connections_attachment_vpc_id_filter(): def test_describe_vpn_connections_attachment_vpc_id_filter():
""" describe_vpn_gateways attachment.vpc-id filter """ """describe_vpn_gateways attachment.vpc-id filter"""
ec2 = boto3.client("ec2", region_name="us-east-1") ec2 = boto3.client("ec2", region_name="us-east-1")
@ -60,7 +60,7 @@ def test_describe_vpn_connections_attachment_vpc_id_filter():
@mock_ec2 @mock_ec2
def test_describe_vpn_connections_state_filter_attached(): def test_describe_vpn_connections_state_filter_attached():
""" describe_vpn_gateways attachment.state filter - match attached """ """describe_vpn_gateways attachment.state filter - match attached"""
ec2 = boto3.client("ec2", region_name="us-east-1") ec2 = boto3.client("ec2", region_name="us-east-1")
@ -84,7 +84,7 @@ def test_describe_vpn_connections_state_filter_attached():
@mock_ec2 @mock_ec2
def test_describe_vpn_connections_state_filter_deatched(): def test_describe_vpn_connections_state_filter_deatched():
""" describe_vpn_gateways attachment.state filter - don't match detatched """ """describe_vpn_gateways attachment.state filter - don't match detatched"""
ec2 = boto3.client("ec2", region_name="us-east-1") ec2 = boto3.client("ec2", region_name="us-east-1")
@ -104,7 +104,7 @@ def test_describe_vpn_connections_state_filter_deatched():
@mock_ec2 @mock_ec2
def test_describe_vpn_connections_id_filter_match(): def test_describe_vpn_connections_id_filter_match():
""" describe_vpn_gateways vpn-gateway-id filter - match correct id """ """describe_vpn_gateways vpn-gateway-id filter - match correct id"""
ec2 = boto3.client("ec2", region_name="us-east-1") ec2 = boto3.client("ec2", region_name="us-east-1")
@ -121,7 +121,7 @@ def test_describe_vpn_connections_id_filter_match():
@mock_ec2 @mock_ec2
def test_describe_vpn_connections_id_filter_miss(): def test_describe_vpn_connections_id_filter_miss():
""" describe_vpn_gateways vpn-gateway-id filter - don't match """ """describe_vpn_gateways vpn-gateway-id filter - don't match"""
ec2 = boto3.client("ec2", region_name="us-east-1") ec2 = boto3.client("ec2", region_name="us-east-1")
@ -136,7 +136,7 @@ def test_describe_vpn_connections_id_filter_miss():
@mock_ec2 @mock_ec2
def test_describe_vpn_connections_type_filter_match(): def test_describe_vpn_connections_type_filter_match():
""" describe_vpn_gateways type filter - match """ """describe_vpn_gateways type filter - match"""
ec2 = boto3.client("ec2", region_name="us-east-1") ec2 = boto3.client("ec2", region_name="us-east-1")
@ -153,7 +153,7 @@ def test_describe_vpn_connections_type_filter_match():
@mock_ec2 @mock_ec2
def test_describe_vpn_connections_type_filter_miss(): def test_describe_vpn_connections_type_filter_miss():
""" describe_vpn_gateways type filter - don't match """ """describe_vpn_gateways type filter - don't match"""
ec2 = boto3.client("ec2", region_name="us-east-1") ec2 = boto3.client("ec2", region_name="us-east-1")

View File

@ -8,7 +8,7 @@ import boto3
import boto import boto
from boto.exception import EC2ResponseError from boto.exception import EC2ResponseError
# import sure # noqa import sure # noqa
from moto import mock_ec2, mock_ec2_deprecated from moto import mock_ec2, mock_ec2_deprecated
@ -919,4 +919,4 @@ def test_describe_vpc_end_points():
VpcEndpointIds=[route_table.get("RouteTable").get("RouteTableId")] VpcEndpointIds=[route_table.get("RouteTable").get("RouteTableId")]
) )
except ClientError as err: except ClientError as err:
assert err.response["Error"]["Code"] == "InvalidVpcEndPointId.NotFound" assert err.response["Error"]["Code"] == "InvalidVpcEndpointId.NotFound"

View File

@ -29,7 +29,7 @@ def test_delete_vpn_connections():
list_of_vpn_connections.should.have.length_of(1) list_of_vpn_connections.should.have.length_of(1)
conn.delete_vpn_connection(vpn_connection.id) conn.delete_vpn_connection(vpn_connection.id)
list_of_vpn_connections = conn.get_all_vpn_connections() list_of_vpn_connections = conn.get_all_vpn_connections()
list_of_vpn_connections.should.have.length_of(0) list_of_vpn_connections[0].state.should.equal("deleted")
@mock_ec2_deprecated @mock_ec2_deprecated

View File

@ -88,3 +88,12 @@ def test_event_pattern_with_multi_numeric_event_filter():
assert two_or_three.matches_event(events[2]) assert two_or_three.matches_event(events[2])
assert two_or_three.matches_event(events[3]) assert two_or_three.matches_event(events[3])
assert not two_or_three.matches_event(events[4]) assert not two_or_three.matches_event(events[4])
@pytest.mark.parametrize(
"pattern, expected_str",
[('{"source": ["foo", "bar"]}', '{"source": ["foo", "bar"]}'), (None, ""),],
)
def test_event_pattern_str(pattern, expected_str):
event_pattern = EventPattern(pattern)
assert str(event_pattern) == expected_str

View File

@ -4,11 +4,10 @@ import unittest
from datetime import datetime from datetime import datetime
import boto3 import boto3
import pytest
import pytz import pytz
import sure # noqa import sure # noqa
from botocore.exceptions import ClientError from botocore.exceptions import ClientError
import pytest
from moto import mock_logs from moto import mock_logs
from moto.core import ACCOUNT_ID from moto.core import ACCOUNT_ID
@ -2157,15 +2156,15 @@ def test_create_and_list_connections():
}, },
) )
assert response.get( response.get("ConnectionArn").should.contain(
"ConnectionArn" "arn:aws:events:eu-central-1:{0}:connection/test/".format(ACCOUNT_ID)
) == "arn:aws:events:eu-central-1:{0}:connection/test".format(ACCOUNT_ID) )
response = client.list_connections() response = client.list_connections()
assert response.get("Connections")[0].get( response.get("Connections")[0].get("ConnectionArn").should.contain(
"ConnectionArn" "arn:aws:events:eu-central-1:{0}:connection/test/".format(ACCOUNT_ID)
) == "arn:aws:events:eu-central-1:{0}:connection/test".format(ACCOUNT_ID) )
@mock_events @mock_events
@ -2189,27 +2188,116 @@ def test_create_and_list_api_destinations():
HttpMethod="GET", HttpMethod="GET",
) )
assert destination_response.get( arn_without_uuid = f"arn:aws:events:eu-central-1:{ACCOUNT_ID}:api-destination/test/"
"ApiDestinationArn" assert destination_response.get("ApiDestinationArn").startswith(arn_without_uuid)
) == "arn:aws:events:eu-central-1:{0}:destination/test".format(ACCOUNT_ID)
assert destination_response.get("ApiDestinationState") == "ACTIVE" assert destination_response.get("ApiDestinationState") == "ACTIVE"
destination_response = client.describe_api_destination(Name="test") destination_response = client.describe_api_destination(Name="test")
assert destination_response.get( assert destination_response.get("ApiDestinationArn").startswith(arn_without_uuid)
"ApiDestinationArn"
) == "arn:aws:events:eu-central-1:{0}:destination/test".format(ACCOUNT_ID)
assert destination_response.get("Name") == "test" assert destination_response.get("Name") == "test"
assert destination_response.get("ApiDestinationState") == "ACTIVE" assert destination_response.get("ApiDestinationState") == "ACTIVE"
destination_response = client.list_api_destinations() destination_response = client.list_api_destinations()
assert destination_response.get("ApiDestinations")[0].get( assert (
"ApiDestinationArn" destination_response.get("ApiDestinations")[0]
) == "arn:aws:events:eu-central-1:{0}:destination/test".format(ACCOUNT_ID) .get("ApiDestinationArn")
.startswith(arn_without_uuid)
)
assert destination_response.get("ApiDestinations")[0].get("Name") == "test" assert destination_response.get("ApiDestinations")[0].get("Name") == "test"
assert ( assert (
destination_response.get("ApiDestinations")[0].get("ApiDestinationState") destination_response.get("ApiDestinations")[0].get("ApiDestinationState")
== "ACTIVE" == "ACTIVE"
) )
# Scenarios for describe_connection
# Scenario 01: Success
# Scenario 02: Failure - Connection not present
@mock_events
def test_describe_connection_success():
# Given
conn_name = "test_conn_name"
conn_description = "test_conn_description"
auth_type = "API_KEY"
auth_params = {
"ApiKeyAuthParameters": {"ApiKeyName": "test", "ApiKeyValue": "test"}
}
client = boto3.client("events", "eu-central-1")
_ = client.create_connection(
Name=conn_name,
Description=conn_description,
AuthorizationType=auth_type,
AuthParameters=auth_params,
)
# When
response = client.describe_connection(Name=conn_name)
# Then
assert response["Name"] == conn_name
assert response["Description"] == conn_description
assert response["AuthorizationType"] == auth_type
expected_auth_param = {"ApiKeyAuthParameters": {"ApiKeyName": "test"}}
assert response["AuthParameters"] == expected_auth_param
@mock_events
def test_describe_connection_not_present():
conn_name = "test_conn_name"
client = boto3.client("events", "eu-central-1")
# When/Then
with pytest.raises(ClientError):
_ = client.describe_connection(Name=conn_name)
# Scenarios for delete_connection
# Scenario 01: Success
# Scenario 02: Failure - Connection not present
@mock_events
def test_delete_connection_success():
# Given
conn_name = "test_conn_name"
conn_description = "test_conn_description"
auth_type = "API_KEY"
auth_params = {
"ApiKeyAuthParameters": {"ApiKeyName": "test", "ApiKeyValue": "test"}
}
client = boto3.client("events", "eu-central-1")
created_connection = client.create_connection(
Name=conn_name,
Description=conn_description,
AuthorizationType=auth_type,
AuthParameters=auth_params,
)
# When
response = client.delete_connection(Name=conn_name)
# Then
expected_arn = f"arn:aws:events:eu-central-1:{ACCOUNT_ID}:connection/{conn_name}/"
assert response["ConnectionArn"] == created_connection["ConnectionArn"]
assert response["ConnectionState"] == created_connection["ConnectionState"]
assert response["CreationTime"] == created_connection["CreationTime"]
with pytest.raises(ClientError):
_ = client.describe_connection(Name=conn_name)
@mock_events
def test_delete_connection_not_present():
conn_name = "test_conn_name"
client = boto3.client("events", "eu-central-1")
# When/Then
with pytest.raises(ClientError):
_ = client.delete_connection(Name=conn_name)

View File

@ -1,10 +1,11 @@
import os import os
import time import time
from unittest import SkipTest from unittest import SkipTest
import boto3 import boto3
from botocore.exceptions import ClientError
import pytest import pytest
import sure # noqa import sure # noqa
from botocore.exceptions import ClientError
from moto import mock_logs, settings from moto import mock_logs, settings
@ -12,14 +13,34 @@ _logs_region = "us-east-1" if settings.TEST_SERVER_MODE else "us-west-2"
@mock_logs @mock_logs
def test_create_log_group(): @pytest.mark.parametrize(
"kms_key_id",
[
"arn:aws:kms:us-east-1:000000000000:key/51d81fab-b138-4bd2-8a09-07fd6d37224d",
None,
],
)
def test_create_log_group(kms_key_id):
# Given
conn = boto3.client("logs", "us-west-2") conn = boto3.client("logs", "us-west-2")
response = conn.create_log_group(logGroupName="dummy") create_logs_params = dict(logGroupName="dummy")
if kms_key_id:
create_logs_params["kmsKeyId"] = kms_key_id
# When
response = conn.create_log_group(**create_logs_params)
response = conn.describe_log_groups() response = conn.describe_log_groups()
# Then
response["logGroups"].should.have.length_of(1) response["logGroups"].should.have.length_of(1)
response["logGroups"][0].should_not.have.key("retentionInDays")
log_group = response["logGroups"][0]
log_group.should_not.have.key("retentionInDays")
if kms_key_id:
log_group.should.have.key("kmsKeyId")
log_group["kmsKeyId"].should.equal(kms_key_id)
@mock_logs @mock_logs

View File

@ -0,0 +1,25 @@
import sure # noqa
from moto.logs.models import LogGroup
def test_log_group_to_describe_dict():
# Given
region = "us-east-1"
name = "test-log-group"
tags = {"TestTag": "TestValue"}
kms_key_id = (
"arn:aws:kms:us-east-1:000000000000:key/51d81fab-b138-4bd2-8a09-07fd6d37224d"
)
kwargs = dict(kmsKeyId=kms_key_id,)
# When
log_group = LogGroup(region, name, tags, **kwargs)
describe_dict = log_group.to_describe_dict()
# Then
expected_dict = dict(logGroupName=name, kmsKeyId=kms_key_id)
for attr, value in expected_dict.items():
describe_dict.should.have.key(attr)
describe_dict[attr].should.equal(value)

View File

@ -4972,7 +4972,9 @@ def test_encryption():
resp = conn.get_bucket_encryption(Bucket="mybucket") resp = conn.get_bucket_encryption(Bucket="mybucket")
assert "ServerSideEncryptionConfiguration" in resp assert "ServerSideEncryptionConfiguration" in resp
assert resp["ServerSideEncryptionConfiguration"] == sse_config return_config = sse_config.copy()
return_config["Rules"][0]["BucketKeyEnabled"] = False
assert resp["ServerSideEncryptionConfiguration"].should.equal(return_config)
conn.delete_bucket_encryption(Bucket="mybucket") conn.delete_bucket_encryption(Bucket="mybucket")
with pytest.raises(ClientError) as exc: with pytest.raises(ClientError) as exc:

View File

@ -9,7 +9,6 @@ from freezegun import freeze_time
import pytest import pytest
import sure # noqa import sure # noqa
from moto import mock_sts, mock_sts_deprecated, mock_iam, settings from moto import mock_sts, mock_sts_deprecated, mock_iam, settings
from moto.core import ACCOUNT_ID from moto.core import ACCOUNT_ID
from moto.sts.responses import MAX_FEDERATION_TOKEN_POLICY_LENGTH from moto.sts.responses import MAX_FEDERATION_TOKEN_POLICY_LENGTH
@ -749,3 +748,61 @@ def test_sts_regions(region):
client = boto3.client("sts", region_name=region) client = boto3.client("sts", region_name=region)
resp = client.get_caller_identity() resp = client.get_caller_identity()
resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200)
@mock_sts
@mock_iam
def test_get_caller_identity_with_iam_user_credentials():
iam_client = boto3.client("iam", region_name="us-east-1")
iam_user_name = "new-user"
iam_user = iam_client.create_user(UserName=iam_user_name)["User"]
access_key = iam_client.create_access_key(UserName=iam_user_name)["AccessKey"]
identity = boto3.client(
"sts",
region_name="us-east-1",
aws_access_key_id=access_key["AccessKeyId"],
aws_secret_access_key=access_key["SecretAccessKey"],
).get_caller_identity()
identity["Arn"].should.equal(iam_user["Arn"])
identity["UserId"].should.equal(iam_user["UserId"])
identity["Account"].should.equal(str(ACCOUNT_ID))
@mock_sts
@mock_iam
def test_get_caller_identity_with_assumed_role_credentials():
iam_client = boto3.client("iam", region_name="us-east-1")
sts_client = boto3.client("sts", region_name="us-east-1")
iam_role_name = "new-user"
trust_policy_document = {
"Version": "2012-10-17",
"Statement": {
"Effect": "Allow",
"Principal": {
"AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID)
},
"Action": "sts:AssumeRole",
},
}
iam_role_arn = iam_client.role_arn = iam_client.create_role(
RoleName=iam_role_name,
AssumeRolePolicyDocument=json.dumps(trust_policy_document),
)["Role"]["Arn"]
session_name = "new-session"
assumed_role = sts_client.assume_role(
RoleArn=iam_role_arn, RoleSessionName=session_name
)
access_key = assumed_role["Credentials"]
identity = boto3.client(
"sts",
region_name="us-east-1",
aws_access_key_id=access_key["AccessKeyId"],
aws_secret_access_key=access_key["SecretAccessKey"],
).get_caller_identity()
identity["Arn"].should.equal(assumed_role["AssumedRoleUser"]["Arn"])
identity["UserId"].should.equal(assumed_role["AssumedRoleUser"]["AssumedRoleId"])
identity["Account"].should.equal(str(ACCOUNT_ID))