Merge from master to current branch

This commit is contained in:
Ian Yon 2019-11-04 12:14:43 -03:00
commit 577419798d
514 changed files with 52633 additions and 47892 deletions

View File

@ -7,6 +7,7 @@ python:
- 2.7
- 3.6
- 3.7
- 3.8
env:
- TEST_SERVER_MODE=false
- TEST_SERVER_MODE=true
@ -17,7 +18,14 @@ install:
python setup.py sdist
if [ "$TEST_SERVER_MODE" = "true" ]; then
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${TRAVIS_PYTHON_VERSION}-stretch /moto/travis_moto_server.sh &
if [ "$TRAVIS_PYTHON_VERSION" = "3.8" ]; then
# Python 3.8 does not provide Stretch images yet [1]
# [1] https://github.com/docker-library/python/issues/428
PYTHON_DOCKER_TAG=${TRAVIS_PYTHON_VERSION}-buster
else
PYTHON_DOCKER_TAG=${TRAVIS_PYTHON_VERSION}-stretch
fi
docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${PYTHON_DOCKER_TAG} /moto/travis_moto_server.sh &
fi
travis_retry pip install boto==2.45.0
travis_retry pip install boto3
@ -29,7 +37,8 @@ install:
python wait_for.py
fi
script:
- make test
- make test-only
- if [[ $TRAVIS_PYTHON_VERSION == "3.7" ]]; then make lint; fi
after_success:
- coveralls
before_deploy:

View File

@ -14,12 +14,16 @@ init:
lint:
flake8 moto
black --check moto/ tests/
test: lint
test-only:
rm -f .coverage
rm -rf cover
@nosetests -sv --with-coverage --cover-html ./tests/ $(TEST_EXCLUDE)
test: lint test-only
test_server:
@TEST_SERVER_MODE=true nosetests -sv --with-coverage --cover-html ./tests/

View File

@ -7,9 +7,9 @@
[![Docs](https://readthedocs.org/projects/pip/badge/?version=stable)](http://docs.getmoto.org)
![PyPI](https://img.shields.io/pypi/v/moto.svg)
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/moto.svg)
![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg)
![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
# In a nutshell
## In a nutshell
Moto is a library that allows your tests to easily mock out AWS Services.

View File

@ -1,64 +1,73 @@
from __future__ import unicode_literals
import logging
# import logging
# logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = 'moto'
__version__ = '1.3.14.dev'
__title__ = "moto"
__version__ = "1.3.14.dev"
from .acm import mock_acm # flake8: noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa
from .athena import mock_athena # flake8: noqa
from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # flake8: noqa
from .awslambda import mock_lambda, mock_lambda_deprecated # flake8: noqa
from .cloudformation import mock_cloudformation, mock_cloudformation_deprecated # flake8: noqa
from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # flake8: noqa
from .cognitoidentity import mock_cognitoidentity, mock_cognitoidentity_deprecated # flake8: noqa
from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # flake8: noqa
from .config import mock_config # flake8: noqa
from .datapipeline import mock_datapipeline, mock_datapipeline_deprecated # flake8: noqa
from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # flake8: noqa
from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # flake8: noqa
from .dynamodbstreams import mock_dynamodbstreams # flake8: noqa
from .ec2 import mock_ec2, mock_ec2_deprecated # flake8: noqa
from .ecr import mock_ecr, mock_ecr_deprecated # flake8: noqa
from .ecs import mock_ecs, mock_ecs_deprecated # flake8: noqa
from .elb import mock_elb, mock_elb_deprecated # flake8: noqa
from .elbv2 import mock_elbv2 # flake8: noqa
from .emr import mock_emr, mock_emr_deprecated # flake8: noqa
from .events import mock_events # flake8: noqa
from .glacier import mock_glacier, mock_glacier_deprecated # flake8: noqa
from .glue import mock_glue # flake8: noqa
from .iam import mock_iam, mock_iam_deprecated # flake8: noqa
from .kinesis import mock_kinesis, mock_kinesis_deprecated # flake8: noqa
from .kms import mock_kms, mock_kms_deprecated # flake8: noqa
from .organizations import mock_organizations # flake8: noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa
from .polly import mock_polly # flake8: noqa
from .rds import mock_rds, mock_rds_deprecated # flake8: noqa
from .rds2 import mock_rds2, mock_rds2_deprecated # flake8: noqa
from .redshift import mock_redshift, mock_redshift_deprecated # flake8: noqa
from .resourcegroups import mock_resourcegroups # flake8: noqa
from .s3 import mock_s3, mock_s3_deprecated # flake8: noqa
from .ses import mock_ses, mock_ses_deprecated # flake8: noqa
from .secretsmanager import mock_secretsmanager # flake8: noqa
from .sns import mock_sns, mock_sns_deprecated # flake8: noqa
from .sqs import mock_sqs, mock_sqs_deprecated # flake8: noqa
from .stepfunctions import mock_stepfunctions # flake8: noqa
from .sts import mock_sts, mock_sts_deprecated # flake8: noqa
from .ssm import mock_ssm # flake8: noqa
from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa
from .swf import mock_swf, mock_swf_deprecated # flake8: noqa
from .xray import mock_xray, mock_xray_client, XRaySegment # flake8: noqa
from .logs import mock_logs, mock_logs_deprecated # flake8: noqa
from .batch import mock_batch # flake8: noqa
from .resourcegroupstaggingapi import mock_resourcegroupstaggingapi # flake8: noqa
from .iot import mock_iot # flake8: noqa
from .iotdata import mock_iotdata # flake8: noqa
from .acm import mock_acm # noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # noqa
from .athena import mock_athena # noqa
from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # noqa
from .awslambda import mock_lambda, mock_lambda_deprecated # noqa
from .cloudformation import mock_cloudformation, mock_cloudformation_deprecated # noqa
from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # noqa
from .cognitoidentity import ( # noqa
mock_cognitoidentity,
mock_cognitoidentity_deprecated,
)
from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # noqa
from .config import mock_config # noqa
from .datapipeline import mock_datapipeline, mock_datapipeline_deprecated # noqa
from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # noqa
from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # noqa
from .dynamodbstreams import mock_dynamodbstreams # noqa
from .ec2 import mock_ec2, mock_ec2_deprecated # noqa
from .ecr import mock_ecr, mock_ecr_deprecated # noqa
from .ecs import mock_ecs, mock_ecs_deprecated # noqa
from .elb import mock_elb, mock_elb_deprecated # noqa
from .elbv2 import mock_elbv2 # noqa
from .emr import mock_emr, mock_emr_deprecated # noqa
from .events import mock_events # noqa
from .glacier import mock_glacier, mock_glacier_deprecated # noqa
from .glue import mock_glue # noqa
from .iam import mock_iam, mock_iam_deprecated # noqa
from .kinesis import mock_kinesis, mock_kinesis_deprecated # noqa
from .kms import mock_kms, mock_kms_deprecated # noqa
from .organizations import mock_organizations # noqa
from .opsworks import mock_opsworks, mock_opsworks_deprecated # noqa
from .polly import mock_polly # noqa
from .rds import mock_rds, mock_rds_deprecated # noqa
from .rds2 import mock_rds2, mock_rds2_deprecated # noqa
from .redshift import mock_redshift, mock_redshift_deprecated # noqa
from .resourcegroups import mock_resourcegroups # noqa
from .s3 import mock_s3, mock_s3_deprecated # noqa
from .ses import mock_ses, mock_ses_deprecated # noqa
from .secretsmanager import mock_secretsmanager # noqa
from .sns import mock_sns, mock_sns_deprecated # noqa
from .sqs import mock_sqs, mock_sqs_deprecated # noqa
from .stepfunctions import mock_stepfunctions # noqa
from .sts import mock_sts, mock_sts_deprecated # noqa
from .ssm import mock_ssm # noqa
from .route53 import mock_route53, mock_route53_deprecated # noqa
from .swf import mock_swf, mock_swf_deprecated # noqa
from .xray import mock_xray, mock_xray_client, XRaySegment # noqa
from .logs import mock_logs, mock_logs_deprecated # noqa
from .batch import mock_batch # noqa
from .resourcegroupstaggingapi import mock_resourcegroupstaggingapi # noqa
from .iot import mock_iot # noqa
from .iotdata import mock_iotdata # noqa
try:
# Need to monkey-patch botocore requests back to underlying urllib3 classes
from botocore.awsrequest import HTTPSConnectionPool, HTTPConnectionPool, HTTPConnection, VerifiedHTTPSConnection
from botocore.awsrequest import (
HTTPSConnectionPool,
HTTPConnectionPool,
HTTPConnection,
VerifiedHTTPSConnection,
)
except ImportError:
pass
else:

View File

@ -2,5 +2,5 @@ from __future__ import unicode_literals
from .models import acm_backends
from ..core.models import base_decorator
acm_backend = acm_backends['us-east-1']
acm_backend = acm_backends["us-east-1"]
mock_acm = base_decorator(acm_backends)

View File

@ -57,20 +57,29 @@ class AWSError(Exception):
self.message = message
def response(self):
resp = {'__type': self.TYPE, 'message': self.message}
resp = {"__type": self.TYPE, "message": self.message}
return json.dumps(resp), dict(status=self.STATUS)
class AWSValidationException(AWSError):
TYPE = 'ValidationException'
TYPE = "ValidationException"
class AWSResourceNotFoundException(AWSError):
TYPE = 'ResourceNotFoundException'
TYPE = "ResourceNotFoundException"
class CertBundle(BaseModel):
def __init__(self, certificate, private_key, chain=None, region='us-east-1', arn=None, cert_type='IMPORTED', cert_status='ISSUED'):
def __init__(
self,
certificate,
private_key,
chain=None,
region="us-east-1",
arn=None,
cert_type="IMPORTED",
cert_status="ISSUED",
):
self.created_at = datetime.datetime.now()
self.cert = certificate
self._cert = None
@ -87,7 +96,7 @@ class CertBundle(BaseModel):
if self.chain is None:
self.chain = GOOGLE_ROOT_CA
else:
self.chain += b'\n' + GOOGLE_ROOT_CA
self.chain += b"\n" + GOOGLE_ROOT_CA
# Takes care of PEM checking
self.validate_pk()
@ -114,149 +123,209 @@ class CertBundle(BaseModel):
sans.add(domain_name)
sans = [cryptography.x509.DNSName(item) for item in sans]
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend())
subject = cryptography.x509.Name([
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, u"CA"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.LOCALITY_NAME, u"San Francisco"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"My Company"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, domain_name),
])
issuer = cryptography.x509.Name([ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"Amazon"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, u"Server CA 1B"),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, u"Amazon"),
])
cert = cryptography.x509.CertificateBuilder().subject_name(
subject
).issuer_name(
issuer
).public_key(
key.public_key()
).serial_number(
cryptography.x509.random_serial_number()
).not_valid_before(
datetime.datetime.utcnow()
).not_valid_after(
datetime.datetime.utcnow() + datetime.timedelta(days=365)
).add_extension(
cryptography.x509.SubjectAlternativeName(sans),
critical=False,
).sign(key, hashes.SHA512(), default_backend())
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
public_exponent=65537, key_size=2048, backend=default_backend()
)
subject = cryptography.x509.Name(
[
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COUNTRY_NAME, "US"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, "CA"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.LOCALITY_NAME, "San Francisco"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.ORGANIZATION_NAME, "My Company"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COMMON_NAME, domain_name
),
]
)
issuer = cryptography.x509.Name(
[ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COUNTRY_NAME, "US"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.ORGANIZATION_NAME, "Amazon"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COMMON_NAME, "Amazon"
),
]
)
cert = (
cryptography.x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(cryptography.x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
.add_extension(
cryptography.x509.SubjectAlternativeName(sans), critical=False
)
.sign(key, hashes.SHA512(), default_backend())
)
cert_armored = cert.public_bytes(serialization.Encoding.PEM)
private_key = key.private_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption()
encryption_algorithm=serialization.NoEncryption(),
)
return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION', region=region)
return cls(
cert_armored,
private_key,
cert_type="AMAZON_ISSUED",
cert_status="PENDING_VALIDATION",
region=region,
)
def validate_pk(self):
try:
self._key = serialization.load_pem_private_key(self.key, password=None, backend=default_backend())
self._key = serialization.load_pem_private_key(
self.key, password=None, backend=default_backend()
)
if self._key.key_size > 2048:
AWSValidationException('The private key length is not supported. Only 1024-bit and 2048-bit are allowed.')
AWSValidationException(
"The private key length is not supported. Only 1024-bit and 2048-bit are allowed."
)
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The private key is not PEM-encoded or is not valid.')
raise AWSValidationException(
"The private key is not PEM-encoded or is not valid."
)
def validate_certificate(self):
try:
self._cert = cryptography.x509.load_pem_x509_certificate(self.cert, default_backend())
self._cert = cryptography.x509.load_pem_x509_certificate(
self.cert, default_backend()
)
now = datetime.datetime.utcnow()
if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate has expired, is not valid.')
raise AWSValidationException(
"The certificate has expired, is not valid."
)
if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate is not in effect yet, is not valid.')
raise AWSValidationException(
"The certificate is not in effect yet, is not valid."
)
# Extracting some common fields for ease of use
# Have to search through cert.subject for OIDs
self.common_name = self._cert.subject.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value
self.common_name = self._cert.subject.get_attributes_for_oid(
cryptography.x509.OID_COMMON_NAME
)[0].value
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.')
raise AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
)
def validate_chain(self):
try:
self._chain = []
for cert_armored in self.chain.split(b'-\n-'):
for cert_armored in self.chain.split(b"-\n-"):
# Would leave encoded but Py2 does not have raw binary strings
cert_armored = cert_armored.decode()
# Fix missing -'s on split
cert_armored = re.sub(r'^----B', '-----B', cert_armored)
cert_armored = re.sub(r'E----$', 'E-----', cert_armored)
cert = cryptography.x509.load_pem_x509_certificate(cert_armored.encode(), default_backend())
cert_armored = re.sub(r"^----B", "-----B", cert_armored)
cert_armored = re.sub(r"E----$", "E-----", cert_armored)
cert = cryptography.x509.load_pem_x509_certificate(
cert_armored.encode(), default_backend()
)
self._chain.append(cert)
now = datetime.datetime.now()
if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate chain has expired, is not valid.')
raise AWSValidationException(
"The certificate chain has expired, is not valid."
)
if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate chain is not in effect yet, is not valid.')
raise AWSValidationException(
"The certificate chain is not in effect yet, is not valid."
)
except Exception as err:
if isinstance(err, AWSValidationException):
raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.')
raise AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
)
def check(self):
# Basically, if the certificate is pending, and then checked again after 1 min
# It will appear as if its been validated
if self.type == 'AMAZON_ISSUED' and self.status == 'PENDING_VALIDATION' and \
(datetime.datetime.now() - self.created_at).total_seconds() > 60: # 1min
self.status = 'ISSUED'
if (
self.type == "AMAZON_ISSUED"
and self.status == "PENDING_VALIDATION"
and (datetime.datetime.now() - self.created_at).total_seconds() > 60
): # 1min
self.status = "ISSUED"
def describe(self):
# 'RenewalSummary': {}, # Only when cert is amazon issued
if self._key.key_size == 1024:
key_algo = 'RSA_1024'
key_algo = "RSA_1024"
elif self._key.key_size == 2048:
key_algo = 'RSA_2048'
key_algo = "RSA_2048"
else:
key_algo = 'EC_prime256v1'
key_algo = "EC_prime256v1"
# Look for SANs
san_obj = self._cert.extensions.get_extension_for_oid(cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME)
san_obj = self._cert.extensions.get_extension_for_oid(
cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME
)
sans = []
if san_obj is not None:
sans = [item.value for item in san_obj.value]
result = {
'Certificate': {
'CertificateArn': self.arn,
'DomainName': self.common_name,
'InUseBy': [],
'Issuer': self._cert.issuer.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value,
'KeyAlgorithm': key_algo,
'NotAfter': datetime_to_epoch(self._cert.not_valid_after),
'NotBefore': datetime_to_epoch(self._cert.not_valid_before),
'Serial': self._cert.serial_number,
'SignatureAlgorithm': self._cert.signature_algorithm_oid._name.upper().replace('ENCRYPTION', ''),
'Status': self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED.
'Subject': 'CN={0}'.format(self.common_name),
'SubjectAlternativeNames': sans,
'Type': self.type # One of IMPORTED, AMAZON_ISSUED
"Certificate": {
"CertificateArn": self.arn,
"DomainName": self.common_name,
"InUseBy": [],
"Issuer": self._cert.issuer.get_attributes_for_oid(
cryptography.x509.OID_COMMON_NAME
)[0].value,
"KeyAlgorithm": key_algo,
"NotAfter": datetime_to_epoch(self._cert.not_valid_after),
"NotBefore": datetime_to_epoch(self._cert.not_valid_before),
"Serial": self._cert.serial_number,
"SignatureAlgorithm": self._cert.signature_algorithm_oid._name.upper().replace(
"ENCRYPTION", ""
),
"Status": self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED.
"Subject": "CN={0}".format(self.common_name),
"SubjectAlternativeNames": sans,
"Type": self.type, # One of IMPORTED, AMAZON_ISSUED
}
}
if self.type == 'IMPORTED':
result['Certificate']['ImportedAt'] = datetime_to_epoch(self.created_at)
if self.type == "IMPORTED":
result["Certificate"]["ImportedAt"] = datetime_to_epoch(self.created_at)
else:
result['Certificate']['CreatedAt'] = datetime_to_epoch(self.created_at)
result['Certificate']['IssuedAt'] = datetime_to_epoch(self.created_at)
result["Certificate"]["CreatedAt"] = datetime_to_epoch(self.created_at)
result["Certificate"]["IssuedAt"] = datetime_to_epoch(self.created_at)
return result
@ -264,7 +333,7 @@ class CertBundle(BaseModel):
return self.arn
def __repr__(self):
return '<Certificate>'
return "<Certificate>"
class AWSCertificateManagerBackend(BaseBackend):
@ -281,7 +350,9 @@ class AWSCertificateManagerBackend(BaseBackend):
@staticmethod
def _arn_not_found(arn):
msg = 'Certificate with arn {0} not found in account {1}'.format(arn, DEFAULT_ACCOUNT_ID)
msg = "Certificate with arn {0} not found in account {1}".format(
arn, DEFAULT_ACCOUNT_ID
)
return AWSResourceNotFoundException(msg)
def _get_arn_from_idempotency_token(self, token):
@ -298,17 +369,20 @@ class AWSCertificateManagerBackend(BaseBackend):
"""
now = datetime.datetime.now()
if token in self._idempotency_tokens:
if self._idempotency_tokens[token]['expires'] < now:
if self._idempotency_tokens[token]["expires"] < now:
# Token has expired, new request
del self._idempotency_tokens[token]
return None
else:
return self._idempotency_tokens[token]['arn']
return self._idempotency_tokens[token]["arn"]
return None
def _set_idempotency_token_arn(self, token, arn):
self._idempotency_tokens[token] = {'arn': arn, 'expires': datetime.datetime.now() + datetime.timedelta(hours=1)}
self._idempotency_tokens[token] = {
"arn": arn,
"expires": datetime.datetime.now() + datetime.timedelta(hours=1),
}
def import_cert(self, certificate, private_key, chain=None, arn=None):
if arn is not None:
@ -316,7 +390,9 @@ class AWSCertificateManagerBackend(BaseBackend):
raise self._arn_not_found(arn)
else:
# Will reuse provided ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region, arn=arn)
bundle = CertBundle(
certificate, private_key, chain=chain, region=region, arn=arn
)
else:
# Will generate a random ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region)
@ -351,13 +427,21 @@ class AWSCertificateManagerBackend(BaseBackend):
del self._certificates[arn]
def request_certificate(self, domain_name, domain_validation_options, idempotency_token, subject_alt_names):
def request_certificate(
self,
domain_name,
domain_validation_options,
idempotency_token,
subject_alt_names,
):
if idempotency_token is not None:
arn = self._get_arn_from_idempotency_token(idempotency_token)
if arn is not None:
return arn
cert = CertBundle.generate_cert(domain_name, region=self.region, sans=subject_alt_names)
cert = CertBundle.generate_cert(
domain_name, region=self.region, sans=subject_alt_names
)
if idempotency_token is not None:
self._set_idempotency_token_arn(idempotency_token, cert.arn)
self._certificates[cert.arn] = cert
@ -369,8 +453,8 @@ class AWSCertificateManagerBackend(BaseBackend):
cert_bundle = self.get_certificate(arn)
for tag in tags:
key = tag['Key']
value = tag.get('Value', None)
key = tag["Key"]
value = tag.get("Value", None)
cert_bundle.tags[key] = value
def remove_tags_from_certificate(self, arn, tags):
@ -378,8 +462,8 @@ class AWSCertificateManagerBackend(BaseBackend):
cert_bundle = self.get_certificate(arn)
for tag in tags:
key = tag['Key']
value = tag.get('Value', None)
key = tag["Key"]
value = tag.get("Value", None)
try:
# If value isnt provided, just delete key

View File

@ -7,7 +7,6 @@ from .models import acm_backends, AWSError, AWSValidationException
class AWSCertificateManagerResponse(BaseResponse):
@property
def acm_backend(self):
"""
@ -29,40 +28,49 @@ class AWSCertificateManagerResponse(BaseResponse):
return self.request_params.get(param, default)
def add_tags_to_certificate(self):
arn = self._get_param('CertificateArn')
tags = self._get_param('Tags')
arn = self._get_param("CertificateArn")
tags = self._get_param("Tags")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
self.acm_backend.add_tags_to_certificate(arn, tags)
except AWSError as err:
return err.response()
return ''
return ""
def delete_certificate(self):
arn = self._get_param('CertificateArn')
arn = self._get_param("CertificateArn")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
self.acm_backend.delete_certificate(arn)
except AWSError as err:
return err.response()
return ''
return ""
def describe_certificate(self):
arn = self._get_param('CertificateArn')
arn = self._get_param("CertificateArn")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
@ -72,11 +80,14 @@ class AWSCertificateManagerResponse(BaseResponse):
return json.dumps(cert_bundle.describe())
def get_certificate(self):
arn = self._get_param('CertificateArn')
arn = self._get_param("CertificateArn")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
@ -84,8 +95,8 @@ class AWSCertificateManagerResponse(BaseResponse):
return err.response()
result = {
'Certificate': cert_bundle.cert.decode(),
'CertificateChain': cert_bundle.chain.decode()
"Certificate": cert_bundle.cert.decode(),
"CertificateChain": cert_bundle.chain.decode(),
}
return json.dumps(result)
@ -102,104 +113,129 @@ class AWSCertificateManagerResponse(BaseResponse):
:return: str(JSON) for response
"""
certificate = self._get_param('Certificate')
private_key = self._get_param('PrivateKey')
chain = self._get_param('CertificateChain') # Optional
current_arn = self._get_param('CertificateArn') # Optional
certificate = self._get_param("Certificate")
private_key = self._get_param("PrivateKey")
chain = self._get_param("CertificateChain") # Optional
current_arn = self._get_param("CertificateArn") # Optional
# Simple parameter decoding. Rather do it here as its a data transport decision not part of the
# actual data
try:
certificate = base64.standard_b64decode(certificate)
except Exception:
return AWSValidationException('The certificate is not PEM-encoded or is not valid.').response()
return AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
).response()
try:
private_key = base64.standard_b64decode(private_key)
except Exception:
return AWSValidationException('The private key is not PEM-encoded or is not valid.').response()
return AWSValidationException(
"The private key is not PEM-encoded or is not valid."
).response()
if chain is not None:
try:
chain = base64.standard_b64decode(chain)
except Exception:
return AWSValidationException('The certificate chain is not PEM-encoded or is not valid.').response()
return AWSValidationException(
"The certificate chain is not PEM-encoded or is not valid."
).response()
try:
arn = self.acm_backend.import_cert(certificate, private_key, chain=chain, arn=current_arn)
arn = self.acm_backend.import_cert(
certificate, private_key, chain=chain, arn=current_arn
)
except AWSError as err:
return err.response()
return json.dumps({'CertificateArn': arn})
return json.dumps({"CertificateArn": arn})
def list_certificates(self):
certs = []
statuses = self._get_param('CertificateStatuses')
statuses = self._get_param("CertificateStatuses")
for cert_bundle in self.acm_backend.get_certificates_list(statuses):
certs.append({
'CertificateArn': cert_bundle.arn,
'DomainName': cert_bundle.common_name
})
certs.append(
{
"CertificateArn": cert_bundle.arn,
"DomainName": cert_bundle.common_name,
}
)
result = {'CertificateSummaryList': certs}
result = {"CertificateSummaryList": certs}
return json.dumps(result)
def list_tags_for_certificate(self):
arn = self._get_param('CertificateArn')
arn = self._get_param("CertificateArn")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return {'__type': 'MissingParameter', 'message': msg}, dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return {"__type": "MissingParameter", "message": msg}, dict(status=400)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err:
return err.response()
result = {'Tags': []}
result = {"Tags": []}
# Tag "objects" can not contain the Value part
for key, value in cert_bundle.tags.items():
tag_dict = {'Key': key}
tag_dict = {"Key": key}
if value is not None:
tag_dict['Value'] = value
result['Tags'].append(tag_dict)
tag_dict["Value"] = value
result["Tags"].append(tag_dict)
return json.dumps(result)
def remove_tags_from_certificate(self):
arn = self._get_param('CertificateArn')
tags = self._get_param('Tags')
arn = self._get_param("CertificateArn")
tags = self._get_param("Tags")
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
self.acm_backend.remove_tags_from_certificate(arn, tags)
except AWSError as err:
return err.response()
return ''
return ""
def request_certificate(self):
domain_name = self._get_param('DomainName')
domain_validation_options = self._get_param('DomainValidationOptions') # is ignored atm
idempotency_token = self._get_param('IdempotencyToken')
subject_alt_names = self._get_param('SubjectAlternativeNames')
domain_name = self._get_param("DomainName")
domain_validation_options = self._get_param(
"DomainValidationOptions"
) # is ignored atm
idempotency_token = self._get_param("IdempotencyToken")
subject_alt_names = self._get_param("SubjectAlternativeNames")
if subject_alt_names is not None and len(subject_alt_names) > 10:
# There is initial AWS limit of 10
msg = 'An ACM limit has been exceeded. Need to request SAN limit to be raised'
return json.dumps({'__type': 'LimitExceededException', 'message': msg}), dict(status=400)
msg = (
"An ACM limit has been exceeded. Need to request SAN limit to be raised"
)
return (
json.dumps({"__type": "LimitExceededException", "message": msg}),
dict(status=400),
)
try:
arn = self.acm_backend.request_certificate(domain_name, domain_validation_options, idempotency_token, subject_alt_names)
arn = self.acm_backend.request_certificate(
domain_name,
domain_validation_options,
idempotency_token,
subject_alt_names,
)
except AWSError as err:
return err.response()
return json.dumps({'CertificateArn': arn})
return json.dumps({"CertificateArn": arn})
def resend_validation_email(self):
arn = self._get_param('CertificateArn')
domain = self._get_param('Domain')
arn = self._get_param("CertificateArn")
domain = self._get_param("Domain")
# ValidationDomain not used yet.
# Contains domain which is equal to or a subset of Domain
# that AWS will send validation emails to
@ -207,18 +243,21 @@ class AWSCertificateManagerResponse(BaseResponse):
# validation_domain = self._get_param('ValidationDomain')
if arn is None:
msg = 'A required parameter for the specified action is not supplied.'
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400)
msg = "A required parameter for the specified action is not supplied."
return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try:
cert_bundle = self.acm_backend.get_certificate(arn)
if cert_bundle.common_name != domain:
msg = 'Parameter Domain does not match certificate domain'
_type = 'InvalidDomainValidationOptionsException'
return json.dumps({'__type': _type, 'message': msg}), dict(status=400)
msg = "Parameter Domain does not match certificate domain"
_type = "InvalidDomainValidationOptionsException"
return json.dumps({"__type": _type, "message": msg}), dict(status=400)
except AWSError as err:
return err.response()
return ''
return ""

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import AWSCertificateManagerResponse
url_bases = [
"https?://acm.(.+).amazonaws.com",
]
url_bases = ["https?://acm.(.+).amazonaws.com"]
url_paths = {
'{0}/$': AWSCertificateManagerResponse.dispatch,
}
url_paths = {"{0}/$": AWSCertificateManagerResponse.dispatch}

View File

@ -4,4 +4,6 @@ import uuid
def make_arn_for_certificate(account_id, region_name):
# Example
# arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b
return "arn:aws:acm:{0}:{1}:certificate/{2}".format(region_name, account_id, uuid.uuid4())
return "arn:aws:acm:{0}:{1}:certificate/{2}".format(
region_name, account_id, uuid.uuid4()
)

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import apigateway_backends
from ..core.models import base_decorator, deprecated_base_decorator
apigateway_backend = apigateway_backends['us-east-1']
apigateway_backend = apigateway_backends["us-east-1"]
mock_apigateway = base_decorator(apigateway_backends)
mock_apigateway_deprecated = deprecated_base_decorator(apigateway_backends)

View File

@ -7,7 +7,8 @@ class StageNotFoundException(RESTError):
def __init__(self):
super(StageNotFoundException, self).__init__(
"NotFoundException", "Invalid stage identifier specified")
"NotFoundException", "Invalid stage identifier specified"
)
class ApiKeyNotFoundException(RESTError):
@ -15,4 +16,5 @@ class ApiKeyNotFoundException(RESTError):
def __init__(self):
super(ApiKeyNotFoundException, self).__init__(
"NotFoundException", "Invalid API Key identifier specified")
"NotFoundException", "Invalid API Key identifier specified"
)

View File

@ -17,39 +17,33 @@ STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_nam
class Deployment(BaseModel, dict):
def __init__(self, deployment_id, name, description=""):
super(Deployment, self).__init__()
self['id'] = deployment_id
self['stageName'] = name
self['description'] = description
self['createdDate'] = int(time.time())
self["id"] = deployment_id
self["stageName"] = name
self["description"] = description
self["createdDate"] = int(time.time())
class IntegrationResponse(BaseModel, dict):
def __init__(self, status_code, selection_pattern=None):
self['responseTemplates'] = {"application/json": None}
self['statusCode'] = status_code
self["responseTemplates"] = {"application/json": None}
self["statusCode"] = status_code
if selection_pattern:
self['selectionPattern'] = selection_pattern
self["selectionPattern"] = selection_pattern
class Integration(BaseModel, dict):
def __init__(self, integration_type, uri, http_method, request_templates=None):
super(Integration, self).__init__()
self['type'] = integration_type
self['uri'] = uri
self['httpMethod'] = http_method
self['requestTemplates'] = request_templates
self["integrationResponses"] = {
"200": IntegrationResponse(200)
}
self["type"] = integration_type
self["uri"] = uri
self["httpMethod"] = http_method
self["requestTemplates"] = request_templates
self["integrationResponses"] = {"200": IntegrationResponse(200)}
def create_integration_response(self, status_code, selection_pattern):
integration_response = IntegrationResponse(
status_code, selection_pattern)
integration_response = IntegrationResponse(status_code, selection_pattern)
self["integrationResponses"][status_code] = integration_response
return integration_response
@ -61,25 +55,25 @@ class Integration(BaseModel, dict):
class MethodResponse(BaseModel, dict):
def __init__(self, status_code):
super(MethodResponse, self).__init__()
self['statusCode'] = status_code
self["statusCode"] = status_code
class Method(BaseModel, dict):
def __init__(self, method_type, authorization_type):
super(Method, self).__init__()
self.update(dict(
httpMethod=method_type,
authorizationType=authorization_type,
authorizerId=None,
apiKeyRequired=None,
requestParameters=None,
requestModels=None,
methodIntegration=None,
))
self.update(
dict(
httpMethod=method_type,
authorizationType=authorization_type,
authorizerId=None,
apiKeyRequired=None,
requestParameters=None,
requestModels=None,
methodIntegration=None,
)
)
self.method_responses = {}
def create_response(self, response_code):
@ -95,16 +89,13 @@ class Method(BaseModel, dict):
class Resource(BaseModel):
def __init__(self, id, region_name, api_id, path_part, parent_id):
self.id = id
self.region_name = region_name
self.api_id = api_id
self.path_part = path_part
self.parent_id = parent_id
self.resource_methods = {
'GET': {}
}
self.resource_methods = {"GET": {}}
def to_dict(self):
response = {
@ -113,8 +104,8 @@ class Resource(BaseModel):
"resourceMethods": self.resource_methods,
}
if self.parent_id:
response['parentId'] = self.parent_id
response['pathPart'] = self.path_part
response["parentId"] = self.parent_id
response["pathPart"] = self.path_part
return response
def get_path(self):
@ -125,102 +116,112 @@ class Resource(BaseModel):
backend = apigateway_backends[self.region_name]
parent = backend.get_resource(self.api_id, self.parent_id)
parent_path = parent.get_path()
if parent_path != '/': # Root parent
parent_path += '/'
if parent_path != "/": # Root parent
parent_path += "/"
return parent_path
else:
return ''
return ""
def get_response(self, request):
integration = self.get_integration(request.method)
integration_type = integration['type']
integration_type = integration["type"]
if integration_type == 'HTTP':
uri = integration['uri']
requests_func = getattr(requests, integration[
'httpMethod'].lower())
if integration_type == "HTTP":
uri = integration["uri"]
requests_func = getattr(requests, integration["httpMethod"].lower())
response = requests_func(uri)
else:
raise NotImplementedError(
"The {0} type has not been implemented".format(integration_type))
"The {0} type has not been implemented".format(integration_type)
)
return response.status_code, response.text
def add_method(self, method_type, authorization_type):
method = Method(method_type=method_type,
authorization_type=authorization_type)
method = Method(method_type=method_type, authorization_type=authorization_type)
self.resource_methods[method_type] = method
return method
def get_method(self, method_type):
return self.resource_methods[method_type]
def add_integration(self, method_type, integration_type, uri, request_templates=None):
def add_integration(
self, method_type, integration_type, uri, request_templates=None
):
integration = Integration(
integration_type, uri, method_type, request_templates=request_templates)
self.resource_methods[method_type]['methodIntegration'] = integration
integration_type, uri, method_type, request_templates=request_templates
)
self.resource_methods[method_type]["methodIntegration"] = integration
return integration
def get_integration(self, method_type):
return self.resource_methods[method_type]['methodIntegration']
return self.resource_methods[method_type]["methodIntegration"]
def delete_integration(self, method_type):
return self.resource_methods[method_type].pop('methodIntegration')
return self.resource_methods[method_type].pop("methodIntegration")
class Stage(BaseModel, dict):
def __init__(self, name=None, deployment_id=None, variables=None,
description='', cacheClusterEnabled=False, cacheClusterSize=None):
def __init__(
self,
name=None,
deployment_id=None,
variables=None,
description="",
cacheClusterEnabled=False,
cacheClusterSize=None,
):
super(Stage, self).__init__()
if variables is None:
variables = {}
self['stageName'] = name
self['deploymentId'] = deployment_id
self['methodSettings'] = {}
self['variables'] = variables
self['description'] = description
self['cacheClusterEnabled'] = cacheClusterEnabled
if self['cacheClusterEnabled']:
self['cacheClusterSize'] = str(0.5)
self["stageName"] = name
self["deploymentId"] = deployment_id
self["methodSettings"] = {}
self["variables"] = variables
self["description"] = description
self["cacheClusterEnabled"] = cacheClusterEnabled
if self["cacheClusterEnabled"]:
self["cacheClusterSize"] = str(0.5)
if cacheClusterSize is not None:
self['cacheClusterSize'] = str(cacheClusterSize)
self["cacheClusterSize"] = str(cacheClusterSize)
def apply_operations(self, patch_operations):
for op in patch_operations:
if 'variables/' in op['path']:
if "variables/" in op["path"]:
self._apply_operation_to_variables(op)
elif '/cacheClusterEnabled' in op['path']:
self['cacheClusterEnabled'] = self._str2bool(op['value'])
if 'cacheClusterSize' not in self and self['cacheClusterEnabled']:
self['cacheClusterSize'] = str(0.5)
elif '/cacheClusterSize' in op['path']:
self['cacheClusterSize'] = str(float(op['value']))
elif '/description' in op['path']:
self['description'] = op['value']
elif '/deploymentId' in op['path']:
self['deploymentId'] = op['value']
elif op['op'] == 'replace':
elif "/cacheClusterEnabled" in op["path"]:
self["cacheClusterEnabled"] = self._str2bool(op["value"])
if "cacheClusterSize" not in self and self["cacheClusterEnabled"]:
self["cacheClusterSize"] = str(0.5)
elif "/cacheClusterSize" in op["path"]:
self["cacheClusterSize"] = str(float(op["value"]))
elif "/description" in op["path"]:
self["description"] = op["value"]
elif "/deploymentId" in op["path"]:
self["deploymentId"] = op["value"]
elif op["op"] == "replace":
# Method Settings drop into here
# (e.g., path could be '/*/*/logging/loglevel')
split_path = op['path'].split('/', 3)
split_path = op["path"].split("/", 3)
if len(split_path) != 4:
continue
self._patch_method_setting(
'/'.join(split_path[1:3]), split_path[3], op['value'])
"/".join(split_path[1:3]), split_path[3], op["value"]
)
else:
raise Exception(
'Patch operation "%s" not implemented' % op['op'])
raise Exception('Patch operation "%s" not implemented' % op["op"])
return self
def _patch_method_setting(self, resource_path_and_method, key, value):
updated_key = self._method_settings_translations(key)
if updated_key is not None:
if resource_path_and_method not in self['methodSettings']:
self['methodSettings'][
resource_path_and_method] = self._get_default_method_settings()
self['methodSettings'][resource_path_and_method][
updated_key] = self._convert_to_type(updated_key, value)
if resource_path_and_method not in self["methodSettings"]:
self["methodSettings"][
resource_path_and_method
] = self._get_default_method_settings()
self["methodSettings"][resource_path_and_method][
updated_key
] = self._convert_to_type(updated_key, value)
def _get_default_method_settings(self):
return {
@ -232,21 +233,21 @@ class Stage(BaseModel, dict):
"cacheDataEncrypted": True,
"cachingEnabled": False,
"throttlingBurstLimit": 2000,
"requireAuthorizationForCacheControl": True
"requireAuthorizationForCacheControl": True,
}
def _method_settings_translations(self, key):
mappings = {
'metrics/enabled': 'metricsEnabled',
'logging/loglevel': 'loggingLevel',
'logging/dataTrace': 'dataTraceEnabled',
'throttling/burstLimit': 'throttlingBurstLimit',
'throttling/rateLimit': 'throttlingRateLimit',
'caching/enabled': 'cachingEnabled',
'caching/ttlInSeconds': 'cacheTtlInSeconds',
'caching/dataEncrypted': 'cacheDataEncrypted',
'caching/requireAuthorizationForCacheControl': 'requireAuthorizationForCacheControl',
'caching/unauthorizedCacheControlHeaderStrategy': 'unauthorizedCacheControlHeaderStrategy'
"metrics/enabled": "metricsEnabled",
"logging/loglevel": "loggingLevel",
"logging/dataTrace": "dataTraceEnabled",
"throttling/burstLimit": "throttlingBurstLimit",
"throttling/rateLimit": "throttlingRateLimit",
"caching/enabled": "cachingEnabled",
"caching/ttlInSeconds": "cacheTtlInSeconds",
"caching/dataEncrypted": "cacheDataEncrypted",
"caching/requireAuthorizationForCacheControl": "requireAuthorizationForCacheControl",
"caching/unauthorizedCacheControlHeaderStrategy": "unauthorizedCacheControlHeaderStrategy",
}
if key in mappings:
@ -259,26 +260,26 @@ class Stage(BaseModel, dict):
def _convert_to_type(self, key, val):
type_mappings = {
'metricsEnabled': 'bool',
'loggingLevel': 'str',
'dataTraceEnabled': 'bool',
'throttlingBurstLimit': 'int',
'throttlingRateLimit': 'float',
'cachingEnabled': 'bool',
'cacheTtlInSeconds': 'int',
'cacheDataEncrypted': 'bool',
'requireAuthorizationForCacheControl': 'bool',
'unauthorizedCacheControlHeaderStrategy': 'str'
"metricsEnabled": "bool",
"loggingLevel": "str",
"dataTraceEnabled": "bool",
"throttlingBurstLimit": "int",
"throttlingRateLimit": "float",
"cachingEnabled": "bool",
"cacheTtlInSeconds": "int",
"cacheDataEncrypted": "bool",
"requireAuthorizationForCacheControl": "bool",
"unauthorizedCacheControlHeaderStrategy": "str",
}
if key in type_mappings:
type_value = type_mappings[key]
if type_value == 'bool':
if type_value == "bool":
return self._str2bool(val)
elif type_value == 'int':
elif type_value == "int":
return int(val)
elif type_value == 'float':
elif type_value == "float":
return float(val)
else:
return str(val)
@ -286,44 +287,55 @@ class Stage(BaseModel, dict):
return str(val)
def _apply_operation_to_variables(self, op):
key = op['path'][op['path'].rindex("variables/") + 10:]
if op['op'] == 'remove':
self['variables'].pop(key, None)
elif op['op'] == 'replace':
self['variables'][key] = op['value']
key = op["path"][op["path"].rindex("variables/") + 10 :]
if op["op"] == "remove":
self["variables"].pop(key, None)
elif op["op"] == "replace":
self["variables"][key] = op["value"]
else:
raise Exception('Patch operation "%s" not implemented' % op['op'])
raise Exception('Patch operation "%s" not implemented' % op["op"])
class ApiKey(BaseModel, dict):
def __init__(self, name=None, description=None, enabled=True,
generateDistinctId=False, value=None, stageKeys=None, tags=None, customerId=None):
def __init__(
self,
name=None,
description=None,
enabled=True,
generateDistinctId=False,
value=None,
stageKeys=None,
tags=None,
customerId=None,
):
super(ApiKey, self).__init__()
self['id'] = create_id()
self['value'] = value if value else ''.join(random.sample(string.ascii_letters + string.digits, 40))
self['name'] = name
self['customerId'] = customerId
self['description'] = description
self['enabled'] = enabled
self['createdDate'] = self['lastUpdatedDate'] = int(time.time())
self['stageKeys'] = stageKeys
self['tags'] = tags
self["id"] = create_id()
self["value"] = (
value
if value
else "".join(random.sample(string.ascii_letters + string.digits, 40))
)
self["name"] = name
self["customerId"] = customerId
self["description"] = description
self["enabled"] = enabled
self["createdDate"] = self["lastUpdatedDate"] = int(time.time())
self["stageKeys"] = stageKeys
self["tags"] = tags
def update_operations(self, patch_operations):
for op in patch_operations:
if op['op'] == 'replace':
if '/name' in op['path']:
self['name'] = op['value']
elif '/customerId' in op['path']:
self['customerId'] = op['value']
elif '/description' in op['path']:
self['description'] = op['value']
elif '/enabled' in op['path']:
self['enabled'] = self._str2bool(op['value'])
if op["op"] == "replace":
if "/name" in op["path"]:
self["name"] = op["value"]
elif "/customerId" in op["path"]:
self["customerId"] = op["value"]
elif "/description" in op["path"]:
self["description"] = op["value"]
elif "/enabled" in op["path"]:
self["enabled"] = self._str2bool(op["value"])
else:
raise Exception(
'Patch operation "%s" not implemented' % op['op'])
raise Exception('Patch operation "%s" not implemented' % op["op"])
return self
def _str2bool(self, v):
@ -331,31 +343,35 @@ class ApiKey(BaseModel, dict):
class UsagePlan(BaseModel, dict):
def __init__(self, name=None, description=None, apiStages=None,
throttle=None, quota=None, tags=None):
def __init__(
self,
name=None,
description=None,
apiStages=None,
throttle=None,
quota=None,
tags=None,
):
super(UsagePlan, self).__init__()
self['id'] = create_id()
self['name'] = name
self['description'] = description
self['apiStages'] = apiStages if apiStages else []
self['throttle'] = throttle
self['quota'] = quota
self['tags'] = tags
self["id"] = create_id()
self["name"] = name
self["description"] = description
self["apiStages"] = apiStages if apiStages else []
self["throttle"] = throttle
self["quota"] = quota
self["tags"] = tags
class UsagePlanKey(BaseModel, dict):
def __init__(self, id, type, name, value):
super(UsagePlanKey, self).__init__()
self['id'] = id
self['name'] = name
self['type'] = type
self['value'] = value
self["id"] = id
self["name"] = name
self["type"] = type
self["value"] = value
class RestAPI(BaseModel):
def __init__(self, id, region_name, name, description):
self.id = id
self.region_name = region_name
@ -367,7 +383,7 @@ class RestAPI(BaseModel):
self.stages = {}
self.resources = {}
self.add_child('/') # Add default child
self.add_child("/") # Add default child
def __repr__(self):
return str(self.id)
@ -382,8 +398,13 @@ class RestAPI(BaseModel):
def add_child(self, path, parent_id=None):
child_id = create_id()
child = Resource(id=child_id, region_name=self.region_name,
api_id=self.id, path_part=path, parent_id=parent_id)
child = Resource(
id=child_id,
region_name=self.region_name,
api_id=self.id,
path_part=path,
parent_id=parent_id,
)
self.resources[child_id] = child
return child
@ -395,36 +416,53 @@ class RestAPI(BaseModel):
def resource_callback(self, request):
path = path_url(request.url)
path_after_stage_name = '/'.join(path.split("/")[2:])
path_after_stage_name = "/".join(path.split("/")[2:])
if not path_after_stage_name:
path_after_stage_name = '/'
path_after_stage_name = "/"
resource = self.get_resource_for_path(path_after_stage_name)
status_code, response = resource.get_response(request)
return status_code, {}, response
def update_integration_mocks(self, stage_name):
stage_url_lower = STAGE_URL.format(api_id=self.id.lower(),
region_name=self.region_name, stage_name=stage_name)
stage_url_upper = STAGE_URL.format(api_id=self.id.upper(),
region_name=self.region_name, stage_name=stage_name)
stage_url_lower = STAGE_URL.format(
api_id=self.id.lower(), region_name=self.region_name, stage_name=stage_name
)
stage_url_upper = STAGE_URL.format(
api_id=self.id.upper(), region_name=self.region_name, stage_name=stage_name
)
for url in [stage_url_lower, stage_url_upper]:
responses._default_mock._matches.insert(0,
responses._default_mock._matches.insert(
0,
responses.CallbackResponse(
url=url,
method=responses.GET,
callback=self.resource_callback,
content_type="text/plain",
match_querystring=False,
)
),
)
def create_stage(self, name, deployment_id, variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None):
def create_stage(
self,
name,
deployment_id,
variables=None,
description="",
cacheClusterEnabled=None,
cacheClusterSize=None,
):
if variables is None:
variables = {}
stage = Stage(name=name, deployment_id=deployment_id, variables=variables,
description=description, cacheClusterSize=cacheClusterSize, cacheClusterEnabled=cacheClusterEnabled)
stage = Stage(
name=name,
deployment_id=deployment_id,
variables=variables,
description=description,
cacheClusterSize=cacheClusterSize,
cacheClusterEnabled=cacheClusterEnabled,
)
self.stages[name] = stage
self.update_integration_mocks(name)
return stage
@ -436,7 +474,8 @@ class RestAPI(BaseModel):
deployment = Deployment(deployment_id, name, description)
self.deployments[deployment_id] = deployment
self.stages[name] = Stage(
name=name, deployment_id=deployment_id, variables=stage_variables)
name=name, deployment_id=deployment_id, variables=stage_variables
)
self.update_integration_mocks(name)
return deployment
@ -455,7 +494,6 @@ class RestAPI(BaseModel):
class APIGatewayBackend(BaseBackend):
def __init__(self, region_name):
super(APIGatewayBackend, self).__init__()
self.apis = {}
@ -497,10 +535,7 @@ class APIGatewayBackend(BaseBackend):
def create_resource(self, function_id, parent_resource_id, path_part):
api = self.get_rest_api(function_id)
child = api.add_child(
path=path_part,
parent_id=parent_resource_id,
)
child = api.add_child(path=path_part, parent_id=parent_resource_id)
return child
def delete_resource(self, function_id, resource_id):
@ -529,13 +564,27 @@ class APIGatewayBackend(BaseBackend):
api = self.get_rest_api(function_id)
return api.get_stages()
def create_stage(self, function_id, stage_name, deploymentId,
variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None):
def create_stage(
self,
function_id,
stage_name,
deploymentId,
variables=None,
description="",
cacheClusterEnabled=None,
cacheClusterSize=None,
):
if variables is None:
variables = {}
api = self.get_rest_api(function_id)
api.create_stage(stage_name, deploymentId, variables=variables,
description=description, cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize)
api.create_stage(
stage_name,
deploymentId,
variables=variables,
description=description,
cacheClusterEnabled=cacheClusterEnabled,
cacheClusterSize=cacheClusterSize,
)
return api.stages.get(stage_name)
def update_stage(self, function_id, stage_name, patch_operations):
@ -550,21 +599,33 @@ class APIGatewayBackend(BaseBackend):
method_response = method.get_response(response_code)
return method_response
def create_method_response(self, function_id, resource_id, method_type, response_code):
def create_method_response(
self, function_id, resource_id, method_type, response_code
):
method = self.get_method(function_id, resource_id, method_type)
method_response = method.create_response(response_code)
return method_response
def delete_method_response(self, function_id, resource_id, method_type, response_code):
def delete_method_response(
self, function_id, resource_id, method_type, response_code
):
method = self.get_method(function_id, resource_id, method_type)
method_response = method.delete_response(response_code)
return method_response
def create_integration(self, function_id, resource_id, method_type, integration_type, uri,
request_templates=None):
def create_integration(
self,
function_id,
resource_id,
method_type,
integration_type,
uri,
request_templates=None,
):
resource = self.get_resource(function_id, resource_id)
integration = resource.add_integration(method_type, integration_type, uri,
request_templates=request_templates)
integration = resource.add_integration(
method_type, integration_type, uri, request_templates=request_templates
)
return integration
def get_integration(self, function_id, resource_id, method_type):
@ -575,28 +636,32 @@ class APIGatewayBackend(BaseBackend):
resource = self.get_resource(function_id, resource_id)
return resource.delete_integration(method_type)
def create_integration_response(self, function_id, resource_id, method_type, status_code, selection_pattern):
integration = self.get_integration(
function_id, resource_id, method_type)
def create_integration_response(
self, function_id, resource_id, method_type, status_code, selection_pattern
):
integration = self.get_integration(function_id, resource_id, method_type)
integration_response = integration.create_integration_response(
status_code, selection_pattern)
status_code, selection_pattern
)
return integration_response
def get_integration_response(self, function_id, resource_id, method_type, status_code):
integration = self.get_integration(
function_id, resource_id, method_type)
integration_response = integration.get_integration_response(
status_code)
def get_integration_response(
self, function_id, resource_id, method_type, status_code
):
integration = self.get_integration(function_id, resource_id, method_type)
integration_response = integration.get_integration_response(status_code)
return integration_response
def delete_integration_response(self, function_id, resource_id, method_type, status_code):
integration = self.get_integration(
function_id, resource_id, method_type)
integration_response = integration.delete_integration_response(
status_code)
def delete_integration_response(
self, function_id, resource_id, method_type, status_code
):
integration = self.get_integration(function_id, resource_id, method_type)
integration_response = integration.delete_integration_response(status_code)
return integration_response
def create_deployment(self, function_id, name, description="", stage_variables=None):
def create_deployment(
self, function_id, name, description="", stage_variables=None
):
if stage_variables is None:
stage_variables = {}
api = self.get_rest_api(function_id)
@ -617,7 +682,7 @@ class APIGatewayBackend(BaseBackend):
def create_apikey(self, payload):
key = ApiKey(**payload)
self.keys[key['id']] = key
self.keys[key["id"]] = key
return key
def get_apikeys(self):
@ -636,7 +701,7 @@ class APIGatewayBackend(BaseBackend):
def create_usage_plan(self, payload):
plan = UsagePlan(**payload)
self.usage_plans[plan['id']] = plan
self.usage_plans[plan["id"]] = plan
return plan
def get_usage_plans(self, api_key_id=None):
@ -645,7 +710,7 @@ class APIGatewayBackend(BaseBackend):
plans = [
plan
for plan in plans
if self.usage_plan_keys.get(plan['id'], {}).get(api_key_id, False)
if self.usage_plan_keys.get(plan["id"], {}).get(api_key_id, False)
]
return plans
@ -666,8 +731,13 @@ class APIGatewayBackend(BaseBackend):
api_key = self.keys[key_id]
usage_plan_key = UsagePlanKey(id=key_id, type=payload["keyType"], name=api_key["name"], value=api_key["value"])
self.usage_plan_keys[usage_plan_id][usage_plan_key['id']] = usage_plan_key
usage_plan_key = UsagePlanKey(
id=key_id,
type=payload["keyType"],
name=api_key["name"],
value=api_key["value"],
)
self.usage_plan_keys[usage_plan_id][usage_plan_key["id"]] = usage_plan_key
return usage_plan_key
def get_usage_plan_keys(self, usage_plan_id):
@ -685,5 +755,5 @@ class APIGatewayBackend(BaseBackend):
apigateway_backends = {}
for region_name in Session().get_available_regions('apigateway'):
for region_name in Session().get_available_regions("apigateway"):
apigateway_backends[region_name] = APIGatewayBackend(region_name)

View File

@ -8,7 +8,6 @@ from .exceptions import StageNotFoundException, ApiKeyNotFoundException
class APIGatewayResponse(BaseResponse):
def _get_param(self, key):
return json.loads(self.body).get(key)
@ -27,14 +26,12 @@ class APIGatewayResponse(BaseResponse):
def restapis(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if self.method == 'GET':
if self.method == "GET":
apis = self.backend.list_apis()
return 200, {}, json.dumps({"item": [
api.to_dict() for api in apis
]})
elif self.method == 'POST':
name = self._get_param('name')
description = self._get_param('description')
return 200, {}, json.dumps({"item": [api.to_dict() for api in apis]})
elif self.method == "POST":
name = self._get_param("name")
description = self._get_param("description")
rest_api = self.backend.create_rest_api(name, description)
return 200, {}, json.dumps(rest_api.to_dict())
@ -42,10 +39,10 @@ class APIGatewayResponse(BaseResponse):
self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
if self.method == 'GET':
if self.method == "GET":
rest_api = self.backend.get_rest_api(function_id)
return 200, {}, json.dumps(rest_api.to_dict())
elif self.method == 'DELETE':
elif self.method == "DELETE":
rest_api = self.backend.delete_rest_api(function_id)
return 200, {}, json.dumps(rest_api.to_dict())
@ -53,24 +50,25 @@ class APIGatewayResponse(BaseResponse):
self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
if self.method == 'GET':
if self.method == "GET":
resources = self.backend.list_resources(function_id)
return 200, {}, json.dumps({"item": [
resource.to_dict() for resource in resources
]})
return (
200,
{},
json.dumps({"item": [resource.to_dict() for resource in resources]}),
)
def resource_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
resource_id = self.path.split("/")[-1]
if self.method == 'GET':
if self.method == "GET":
resource = self.backend.get_resource(function_id, resource_id)
elif self.method == 'POST':
elif self.method == "POST":
path_part = self._get_param("pathPart")
resource = self.backend.create_resource(
function_id, resource_id, path_part)
elif self.method == 'DELETE':
resource = self.backend.create_resource(function_id, resource_id, path_part)
elif self.method == "DELETE":
resource = self.backend.delete_resource(function_id, resource_id)
return 200, {}, json.dumps(resource.to_dict())
@ -81,14 +79,14 @@ class APIGatewayResponse(BaseResponse):
resource_id = url_path_parts[4]
method_type = url_path_parts[6]
if self.method == 'GET':
method = self.backend.get_method(
function_id, resource_id, method_type)
if self.method == "GET":
method = self.backend.get_method(function_id, resource_id, method_type)
return 200, {}, json.dumps(method)
elif self.method == 'PUT':
elif self.method == "PUT":
authorization_type = self._get_param("authorizationType")
method = self.backend.create_method(
function_id, resource_id, method_type, authorization_type)
function_id, resource_id, method_type, authorization_type
)
return 200, {}, json.dumps(method)
def resource_method_responses(self, request, full_url, headers):
@ -99,15 +97,18 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6]
response_code = url_path_parts[8]
if self.method == 'GET':
if self.method == "GET":
method_response = self.backend.get_method_response(
function_id, resource_id, method_type, response_code)
elif self.method == 'PUT':
function_id, resource_id, method_type, response_code
)
elif self.method == "PUT":
method_response = self.backend.create_method_response(
function_id, resource_id, method_type, response_code)
elif self.method == 'DELETE':
function_id, resource_id, method_type, response_code
)
elif self.method == "DELETE":
method_response = self.backend.delete_method_response(
function_id, resource_id, method_type, response_code)
function_id, resource_id, method_type, response_code
)
return 200, {}, json.dumps(method_response)
def restapis_stages(self, request, full_url, headers):
@ -115,21 +116,28 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/")
function_id = url_path_parts[2]
if self.method == 'POST':
if self.method == "POST":
stage_name = self._get_param("stageName")
deployment_id = self._get_param("deploymentId")
stage_variables = self._get_param_with_default_value(
'variables', {})
description = self._get_param_with_default_value('description', '')
stage_variables = self._get_param_with_default_value("variables", {})
description = self._get_param_with_default_value("description", "")
cacheClusterEnabled = self._get_param_with_default_value(
'cacheClusterEnabled', False)
"cacheClusterEnabled", False
)
cacheClusterSize = self._get_param_with_default_value(
'cacheClusterSize', None)
"cacheClusterSize", None
)
stage_response = self.backend.create_stage(function_id, stage_name, deployment_id,
variables=stage_variables, description=description,
cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize)
elif self.method == 'GET':
stage_response = self.backend.create_stage(
function_id,
stage_name,
deployment_id,
variables=stage_variables,
description=description,
cacheClusterEnabled=cacheClusterEnabled,
cacheClusterSize=cacheClusterSize,
)
elif self.method == "GET":
stages = self.backend.get_stages(function_id)
return 200, {}, json.dumps({"item": stages})
@ -141,16 +149,22 @@ class APIGatewayResponse(BaseResponse):
function_id = url_path_parts[2]
stage_name = url_path_parts[4]
if self.method == 'GET':
if self.method == "GET":
try:
stage_response = self.backend.get_stage(
function_id, stage_name)
stage_response = self.backend.get_stage(function_id, stage_name)
except StageNotFoundException as error:
return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type)
elif self.method == 'PATCH':
patch_operations = self._get_param('patchOperations')
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
stage_response = self.backend.update_stage(
function_id, stage_name, patch_operations)
function_id, stage_name, patch_operations
)
return 200, {}, json.dumps(stage_response)
def integrations(self, request, full_url, headers):
@ -160,18 +174,26 @@ class APIGatewayResponse(BaseResponse):
resource_id = url_path_parts[4]
method_type = url_path_parts[6]
if self.method == 'GET':
if self.method == "GET":
integration_response = self.backend.get_integration(
function_id, resource_id, method_type)
elif self.method == 'PUT':
integration_type = self._get_param('type')
uri = self._get_param('uri')
request_templates = self._get_param('requestTemplates')
function_id, resource_id, method_type
)
elif self.method == "PUT":
integration_type = self._get_param("type")
uri = self._get_param("uri")
request_templates = self._get_param("requestTemplates")
integration_response = self.backend.create_integration(
function_id, resource_id, method_type, integration_type, uri, request_templates=request_templates)
elif self.method == 'DELETE':
function_id,
resource_id,
method_type,
integration_type,
uri,
request_templates=request_templates,
)
elif self.method == "DELETE":
integration_response = self.backend.delete_integration(
function_id, resource_id, method_type)
function_id, resource_id, method_type
)
return 200, {}, json.dumps(integration_response)
def integration_responses(self, request, full_url, headers):
@ -182,16 +204,16 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6]
status_code = url_path_parts[9]
if self.method == 'GET':
if self.method == "GET":
integration_response = self.backend.get_integration_response(
function_id, resource_id, method_type, status_code
)
elif self.method == 'PUT':
elif self.method == "PUT":
selection_pattern = self._get_param("selectionPattern")
integration_response = self.backend.create_integration_response(
function_id, resource_id, method_type, status_code, selection_pattern
)
elif self.method == 'DELETE':
elif self.method == "DELETE":
integration_response = self.backend.delete_integration_response(
function_id, resource_id, method_type, status_code
)
@ -201,16 +223,16 @@ class APIGatewayResponse(BaseResponse):
self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
if self.method == 'GET':
if self.method == "GET":
deployments = self.backend.get_deployments(function_id)
return 200, {}, json.dumps({"item": deployments})
elif self.method == 'POST':
elif self.method == "POST":
name = self._get_param("stageName")
description = self._get_param_with_default_value("description", "")
stage_variables = self._get_param_with_default_value(
'variables', {})
stage_variables = self._get_param_with_default_value("variables", {})
deployment = self.backend.create_deployment(
function_id, name, description, stage_variables)
function_id, name, description, stage_variables
)
return 200, {}, json.dumps(deployment)
def individual_deployment(self, request, full_url, headers):
@ -219,20 +241,18 @@ class APIGatewayResponse(BaseResponse):
function_id = url_path_parts[2]
deployment_id = url_path_parts[4]
if self.method == 'GET':
deployment = self.backend.get_deployment(
function_id, deployment_id)
elif self.method == 'DELETE':
deployment = self.backend.delete_deployment(
function_id, deployment_id)
if self.method == "GET":
deployment = self.backend.get_deployment(function_id, deployment_id)
elif self.method == "DELETE":
deployment = self.backend.delete_deployment(function_id, deployment_id)
return 200, {}, json.dumps(deployment)
def apikeys(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if self.method == 'POST':
if self.method == "POST":
apikey_response = self.backend.create_apikey(json.loads(self.body))
elif self.method == 'GET':
elif self.method == "GET":
apikeys_response = self.backend.get_apikeys()
return 200, {}, json.dumps({"item": apikeys_response})
return 200, {}, json.dumps(apikey_response)
@ -243,21 +263,21 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/")
apikey = url_path_parts[2]
if self.method == 'GET':
if self.method == "GET":
apikey_response = self.backend.get_apikey(apikey)
elif self.method == 'PATCH':
patch_operations = self._get_param('patchOperations')
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
apikey_response = self.backend.update_apikey(apikey, patch_operations)
elif self.method == 'DELETE':
elif self.method == "DELETE":
apikey_response = self.backend.delete_apikey(apikey)
return 200, {}, json.dumps(apikey_response)
def usage_plans(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if self.method == 'POST':
if self.method == "POST":
usage_plan_response = self.backend.create_usage_plan(json.loads(self.body))
elif self.method == 'GET':
elif self.method == "GET":
api_key_id = self.querystring.get("keyId", [None])[0]
usage_plans_response = self.backend.get_usage_plans(api_key_id=api_key_id)
return 200, {}, json.dumps({"item": usage_plans_response})
@ -269,9 +289,9 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/")
usage_plan = url_path_parts[2]
if self.method == 'GET':
if self.method == "GET":
usage_plan_response = self.backend.get_usage_plan(usage_plan)
elif self.method == 'DELETE':
elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan(usage_plan)
return 200, {}, json.dumps(usage_plan_response)
@ -281,13 +301,21 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/")
usage_plan_id = url_path_parts[2]
if self.method == 'POST':
if self.method == "POST":
try:
usage_plan_response = self.backend.create_usage_plan_key(usage_plan_id, json.loads(self.body))
usage_plan_response = self.backend.create_usage_plan_key(
usage_plan_id, json.loads(self.body)
)
except ApiKeyNotFoundException as error:
return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type)
return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == 'GET':
elif self.method == "GET":
usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id)
return 200, {}, json.dumps({"item": usage_plans_response})
@ -300,8 +328,10 @@ class APIGatewayResponse(BaseResponse):
usage_plan_id = url_path_parts[2]
key_id = url_path_parts[4]
if self.method == 'GET':
if self.method == "GET":
usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id)
elif self.method == 'DELETE':
usage_plan_response = self.backend.delete_usage_plan_key(usage_plan_id, key_id)
elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan_key(
usage_plan_id, key_id
)
return 200, {}, json.dumps(usage_plan_response)

View File

@ -1,27 +1,25 @@
from __future__ import unicode_literals
from .responses import APIGatewayResponse
url_bases = [
"https?://apigateway.(.+).amazonaws.com"
]
url_bases = ["https?://apigateway.(.+).amazonaws.com"]
url_paths = {
'{0}/restapis$': APIGatewayResponse().restapis,
'{0}/restapis/(?P<function_id>[^/]+)/?$': APIGatewayResponse().restapis_individual,
'{0}/restapis/(?P<function_id>[^/]+)/resources$': APIGatewayResponse().resources,
'{0}/restapis/(?P<function_id>[^/]+)/stages$': APIGatewayResponse().restapis_stages,
'{0}/restapis/(?P<function_id>[^/]+)/stages/(?P<stage_name>[^/]+)/?$': APIGatewayResponse().stages,
'{0}/restapis/(?P<function_id>[^/]+)/deployments$': APIGatewayResponse().deployments,
'{0}/restapis/(?P<function_id>[^/]+)/deployments/(?P<deployment_id>[^/]+)/?$': APIGatewayResponse().individual_deployment,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/?$': APIGatewayResponse().resource_individual,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/?$': APIGatewayResponse().resource_methods,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/responses/(?P<status_code>\d+)$': APIGatewayResponse().resource_method_responses,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/?$': APIGatewayResponse().integrations,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/responses/(?P<status_code>\d+)/?$': APIGatewayResponse().integration_responses,
'{0}/apikeys$': APIGatewayResponse().apikeys,
'{0}/apikeys/(?P<apikey>[^/]+)': APIGatewayResponse().apikey_individual,
'{0}/usageplans$': APIGatewayResponse().usage_plans,
'{0}/usageplans/(?P<usage_plan_id>[^/]+)/?$': APIGatewayResponse().usage_plan_individual,
'{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys$': APIGatewayResponse().usage_plan_keys,
'{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$': APIGatewayResponse().usage_plan_key_individual,
"{0}/restapis$": APIGatewayResponse().restapis,
"{0}/restapis/(?P<function_id>[^/]+)/?$": APIGatewayResponse().restapis_individual,
"{0}/restapis/(?P<function_id>[^/]+)/resources$": APIGatewayResponse().resources,
"{0}/restapis/(?P<function_id>[^/]+)/stages$": APIGatewayResponse().restapis_stages,
"{0}/restapis/(?P<function_id>[^/]+)/stages/(?P<stage_name>[^/]+)/?$": APIGatewayResponse().stages,
"{0}/restapis/(?P<function_id>[^/]+)/deployments$": APIGatewayResponse().deployments,
"{0}/restapis/(?P<function_id>[^/]+)/deployments/(?P<deployment_id>[^/]+)/?$": APIGatewayResponse().individual_deployment,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/?$": APIGatewayResponse().resource_individual,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/?$": APIGatewayResponse().resource_methods,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/responses/(?P<status_code>\d+)$": APIGatewayResponse().resource_method_responses,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/?$": APIGatewayResponse().integrations,
"{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/responses/(?P<status_code>\d+)/?$": APIGatewayResponse().integration_responses,
"{0}/apikeys$": APIGatewayResponse().apikeys,
"{0}/apikeys/(?P<apikey>[^/]+)": APIGatewayResponse().apikey_individual,
"{0}/usageplans$": APIGatewayResponse().usage_plans,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/?$": APIGatewayResponse().usage_plan_individual,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys$": APIGatewayResponse().usage_plan_keys,
"{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$": APIGatewayResponse().usage_plan_key_individual,
}

View File

@ -7,4 +7,4 @@ import string
def create_id():
size = 10
chars = list(range(10)) + list(string.ascii_lowercase)
return ''.join(six.text_type(random.choice(chars)) for x in range(size))
return "".join(six.text_type(random.choice(chars)) for x in range(size))

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import athena_backends
from ..core.models import base_decorator, deprecated_base_decorator
athena_backend = athena_backends['us-east-1']
athena_backend = athena_backends["us-east-1"]
mock_athena = base_decorator(athena_backends)
mock_athena_deprecated = deprecated_base_decorator(athena_backends)

View File

@ -5,14 +5,15 @@ from werkzeug.exceptions import BadRequest
class AthenaClientError(BadRequest):
def __init__(self, code, message):
super(AthenaClientError, self).__init__()
self.description = json.dumps({
"Error": {
"Code": code,
"Message": message,
'Type': "InvalidRequestException",
},
'RequestId': '6876f774-7273-11e4-85dc-39e55ca848d1',
})
self.description = json.dumps(
{
"Error": {
"Code": code,
"Message": message,
"Type": "InvalidRequestException",
},
"RequestId": "6876f774-7273-11e4-85dc-39e55ca848d1",
}
)

View File

@ -19,31 +19,30 @@ class TaggableResourceMixin(object):
@property
def arn(self):
return "arn:aws:athena:{region}:{account_id}:{resource_name}".format(
region=self.region,
account_id=ACCOUNT_ID,
resource_name=self.resource_name)
region=self.region, account_id=ACCOUNT_ID, resource_name=self.resource_name
)
def create_tags(self, tags):
new_keys = [tag_set['Key'] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags
if tag_set['Key'] not in new_keys]
new_keys = [tag_set["Key"] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
self.tags.extend(tags)
return self.tags
def delete_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags
if tag_set['Key'] not in tag_keys]
self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
return self.tags
class WorkGroup(TaggableResourceMixin, BaseModel):
resource_type = 'workgroup'
state = 'ENABLED'
resource_type = "workgroup"
state = "ENABLED"
def __init__(self, athena_backend, name, configuration, description, tags):
self.region_name = athena_backend.region_name
super(WorkGroup, self).__init__(self.region_name, "workgroup/{}".format(name), tags)
super(WorkGroup, self).__init__(
self.region_name, "workgroup/{}".format(name), tags
)
self.athena_backend = athena_backend
self.name = name
self.description = description
@ -66,14 +65,17 @@ class AthenaBackend(BaseBackend):
return work_group
def list_work_groups(self):
return [{
'Name': wg.name,
'State': wg.state,
'Description': wg.description,
'CreationTime': time.time(),
} for wg in self.work_groups.values()]
return [
{
"Name": wg.name,
"State": wg.state,
"Description": wg.description,
"CreationTime": time.time(),
}
for wg in self.work_groups.values()
]
athena_backends = {}
for region in boto3.Session().get_available_regions('athena'):
for region in boto3.Session().get_available_regions("athena"):
athena_backends[region] = AthenaBackend(region)

View File

@ -5,31 +5,37 @@ from .models import athena_backends
class AthenaResponse(BaseResponse):
@property
def athena_backend(self):
return athena_backends[self.region]
def create_work_group(self):
name = self._get_param('Name')
description = self._get_param('Description')
configuration = self._get_param('Configuration')
tags = self._get_param('Tags')
work_group = self.athena_backend.create_work_group(name, configuration, description, tags)
name = self._get_param("Name")
description = self._get_param("Description")
configuration = self._get_param("Configuration")
tags = self._get_param("Tags")
work_group = self.athena_backend.create_work_group(
name, configuration, description, tags
)
if not work_group:
return json.dumps({
'__type': 'InvalidRequestException',
'Message': 'WorkGroup already exists',
}), dict(status=400)
return json.dumps({
"CreateWorkGroupResponse": {
"ResponseMetadata": {
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a",
return (
json.dumps(
{
"__type": "InvalidRequestException",
"Message": "WorkGroup already exists",
}
),
dict(status=400),
)
return json.dumps(
{
"CreateWorkGroupResponse": {
"ResponseMetadata": {
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a"
}
}
}
})
)
def list_work_groups(self):
return json.dumps({
"WorkGroups": self.athena_backend.list_work_groups()
})
return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()})

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import AthenaResponse
url_bases = [
"https?://athena.(.+).amazonaws.com",
]
url_bases = ["https?://athena.(.+).amazonaws.com"]
url_paths = {
'{0}/$': AthenaResponse.dispatch,
}
url_paths = {"{0}/$": AthenaResponse.dispatch}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import autoscaling_backends
from ..core.models import base_decorator, deprecated_base_decorator
autoscaling_backend = autoscaling_backends['us-east-1']
autoscaling_backend = autoscaling_backends["us-east-1"]
mock_autoscaling = base_decorator(autoscaling_backends)
mock_autoscaling_deprecated = deprecated_base_decorator(autoscaling_backends)

View File

@ -12,13 +12,12 @@ class ResourceContentionError(RESTError):
def __init__(self):
super(ResourceContentionError, self).__init__(
"ResourceContentionError",
"You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).")
"You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).",
)
class InvalidInstanceError(AutoscalingClientError):
def __init__(self, instance_id):
super(InvalidInstanceError, self).__init__(
"ValidationError",
"Instance [{0}] is invalid."
.format(instance_id))
"ValidationError", "Instance [{0}] is invalid.".format(instance_id)
)

View File

@ -12,7 +12,9 @@ from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends
from moto.elb.exceptions import LoadBalancerNotFoundError
from .exceptions import (
AutoscalingClientError, ResourceContentionError, InvalidInstanceError
AutoscalingClientError,
ResourceContentionError,
InvalidInstanceError,
)
# http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown
@ -22,8 +24,13 @@ ASG_NAME_TAG = "aws:autoscaling:groupName"
class InstanceState(object):
def __init__(self, instance, lifecycle_state="InService",
health_status="Healthy", protected_from_scale_in=False):
def __init__(
self,
instance,
lifecycle_state="InService",
health_status="Healthy",
protected_from_scale_in=False,
):
self.instance = instance
self.lifecycle_state = lifecycle_state
self.health_status = health_status
@ -31,8 +38,16 @@ class InstanceState(object):
class FakeScalingPolicy(BaseModel):
def __init__(self, name, policy_type, adjustment_type, as_name, scaling_adjustment,
cooldown, autoscaling_backend):
def __init__(
self,
name,
policy_type,
adjustment_type,
as_name,
scaling_adjustment,
cooldown,
autoscaling_backend,
):
self.name = name
self.policy_type = policy_type
self.adjustment_type = adjustment_type
@ -45,21 +60,38 @@ class FakeScalingPolicy(BaseModel):
self.autoscaling_backend = autoscaling_backend
def execute(self):
if self.adjustment_type == 'ExactCapacity':
if self.adjustment_type == "ExactCapacity":
self.autoscaling_backend.set_desired_capacity(
self.as_name, self.scaling_adjustment)
elif self.adjustment_type == 'ChangeInCapacity':
self.as_name, self.scaling_adjustment
)
elif self.adjustment_type == "ChangeInCapacity":
self.autoscaling_backend.change_capacity(
self.as_name, self.scaling_adjustment)
elif self.adjustment_type == 'PercentChangeInCapacity':
self.as_name, self.scaling_adjustment
)
elif self.adjustment_type == "PercentChangeInCapacity":
self.autoscaling_backend.change_capacity_percent(
self.as_name, self.scaling_adjustment)
self.as_name, self.scaling_adjustment
)
class FakeLaunchConfiguration(BaseModel):
def __init__(self, name, image_id, key_name, ramdisk_id, kernel_id, security_groups, user_data,
instance_type, instance_monitoring, instance_profile_name,
spot_price, ebs_optimized, associate_public_ip_address, block_device_mapping_dict):
def __init__(
self,
name,
image_id,
key_name,
ramdisk_id,
kernel_id,
security_groups,
user_data,
instance_type,
instance_monitoring,
instance_profile_name,
spot_price,
ebs_optimized,
associate_public_ip_address,
block_device_mapping_dict,
):
self.name = name
self.image_id = image_id
self.key_name = key_name
@ -80,8 +112,8 @@ class FakeLaunchConfiguration(BaseModel):
config = backend.create_launch_configuration(
name=name,
image_id=instance.image_id,
kernel_id='',
ramdisk_id='',
kernel_id="",
ramdisk_id="",
key_name=instance.key_name,
security_groups=instance.security_groups,
user_data=instance.user_data,
@ -91,13 +123,15 @@ class FakeLaunchConfiguration(BaseModel):
spot_price=None,
ebs_optimized=instance.ebs_optimized,
associate_public_ip_address=instance.associate_public_ip,
block_device_mappings=instance.block_device_mapping
block_device_mappings=instance.block_device_mapping,
)
return config
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
instance_profile_name = properties.get("IamInstanceProfile")
@ -115,20 +149,26 @@ class FakeLaunchConfiguration(BaseModel):
instance_profile_name=instance_profile_name,
spot_price=properties.get("SpotPrice"),
ebs_optimized=properties.get("EbsOptimized"),
associate_public_ip_address=properties.get(
"AssociatePublicIpAddress"),
block_device_mappings=properties.get("BlockDeviceMapping.member")
associate_public_ip_address=properties.get("AssociatePublicIpAddress"),
block_device_mappings=properties.get("BlockDeviceMapping.member"),
)
return config
@classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name):
def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
):
cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name)
return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name)
original_resource.name, cloudformation_json, region_name
)
return cls.create_from_cloudformation_json(
new_resource_name, cloudformation_json, region_name
)
@classmethod
def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
backend = autoscaling_backends[region_name]
try:
backend.delete_launch_configuration(resource_name)
@ -153,34 +193,49 @@ class FakeLaunchConfiguration(BaseModel):
@property
def instance_monitoring_enabled(self):
if self.instance_monitoring:
return 'true'
return 'false'
return "true"
return "false"
def _parse_block_device_mappings(self):
block_device_map = BlockDeviceMapping()
for mapping in self.block_device_mapping_dict:
block_type = BlockDeviceType()
mount_point = mapping.get('device_name')
if 'ephemeral' in mapping.get('virtual_name', ''):
block_type.ephemeral_name = mapping.get('virtual_name')
mount_point = mapping.get("device_name")
if "ephemeral" in mapping.get("virtual_name", ""):
block_type.ephemeral_name = mapping.get("virtual_name")
else:
block_type.volume_type = mapping.get('ebs._volume_type')
block_type.snapshot_id = mapping.get('ebs._snapshot_id')
block_type.volume_type = mapping.get("ebs._volume_type")
block_type.snapshot_id = mapping.get("ebs._snapshot_id")
block_type.delete_on_termination = mapping.get(
'ebs._delete_on_termination')
block_type.size = mapping.get('ebs._volume_size')
block_type.iops = mapping.get('ebs._iops')
"ebs._delete_on_termination"
)
block_type.size = mapping.get("ebs._volume_size")
block_type.iops = mapping.get("ebs._iops")
block_device_map[mount_point] = block_type
return block_device_map
class FakeAutoScalingGroup(BaseModel):
def __init__(self, name, availability_zones, desired_capacity, max_size,
min_size, launch_config_name, vpc_zone_identifier,
default_cooldown, health_check_period, health_check_type,
load_balancers, target_group_arns, placement_group, termination_policies,
autoscaling_backend, tags,
new_instances_protected_from_scale_in=False):
def __init__(
self,
name,
availability_zones,
desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
load_balancers,
target_group_arns,
placement_group,
termination_policies,
autoscaling_backend,
tags,
new_instances_protected_from_scale_in=False,
):
self.autoscaling_backend = autoscaling_backend
self.name = name
@ -190,17 +245,22 @@ class FakeAutoScalingGroup(BaseModel):
self.min_size = min_size
self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name]
launch_config_name
]
self.launch_config_name = launch_config_name
self.default_cooldown = default_cooldown if default_cooldown else DEFAULT_COOLDOWN
self.default_cooldown = (
default_cooldown if default_cooldown else DEFAULT_COOLDOWN
)
self.health_check_period = health_check_period
self.health_check_type = health_check_type if health_check_type else "EC2"
self.load_balancers = load_balancers
self.target_group_arns = target_group_arns
self.placement_group = placement_group
self.termination_policies = termination_policies
self.new_instances_protected_from_scale_in = new_instances_protected_from_scale_in
self.new_instances_protected_from_scale_in = (
new_instances_protected_from_scale_in
)
self.suspended_processes = []
self.instance_states = []
@ -215,8 +275,10 @@ class FakeAutoScalingGroup(BaseModel):
if vpc_zone_identifier:
# extract azs for vpcs
subnet_ids = vpc_zone_identifier.split(',')
subnets = self.autoscaling_backend.ec2_backend.get_all_subnets(subnet_ids=subnet_ids)
subnet_ids = vpc_zone_identifier.split(",")
subnets = self.autoscaling_backend.ec2_backend.get_all_subnets(
subnet_ids=subnet_ids
)
vpc_zones = [subnet.availability_zone for subnet in subnets]
if availability_zones and set(availability_zones) != set(vpc_zones):
@ -229,7 +291,7 @@ class FakeAutoScalingGroup(BaseModel):
if not update:
raise AutoscalingClientError(
"ValidationError",
"At least one Availability Zone or VPC Subnet is required."
"At least one Availability Zone or VPC Subnet is required.",
)
return
@ -237,8 +299,10 @@ class FakeAutoScalingGroup(BaseModel):
self.vpc_zone_identifier = vpc_zone_identifier
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
launch_config_name = properties.get("LaunchConfigurationName")
load_balancer_names = properties.get("LoadBalancerNames", [])
@ -253,7 +317,8 @@ class FakeAutoScalingGroup(BaseModel):
min_size=properties.get("MinSize"),
launch_config_name=launch_config_name,
vpc_zone_identifier=(
','.join(properties.get("VPCZoneIdentifier", [])) or None),
",".join(properties.get("VPCZoneIdentifier", [])) or None
),
default_cooldown=properties.get("Cooldown"),
health_check_period=properties.get("HealthCheckGracePeriod"),
health_check_type=properties.get("HealthCheckType"),
@ -263,18 +328,26 @@ class FakeAutoScalingGroup(BaseModel):
termination_policies=properties.get("TerminationPolicies", []),
tags=properties.get("Tags", []),
new_instances_protected_from_scale_in=properties.get(
"NewInstancesProtectedFromScaleIn", False)
"NewInstancesProtectedFromScaleIn", False
),
)
return group
@classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name):
def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
):
cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name)
return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name)
original_resource.name, cloudformation_json, region_name
)
return cls.create_from_cloudformation_json(
new_resource_name, cloudformation_json, region_name
)
@classmethod
def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
backend = autoscaling_backends[region_name]
try:
backend.delete_auto_scaling_group(resource_name)
@ -289,11 +362,21 @@ class FakeAutoScalingGroup(BaseModel):
def physical_resource_id(self):
return self.name
def update(self, availability_zones, desired_capacity, max_size, min_size,
launch_config_name, vpc_zone_identifier, default_cooldown,
health_check_period, health_check_type,
placement_group, termination_policies,
new_instances_protected_from_scale_in=None):
def update(
self,
availability_zones,
desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
new_instances_protected_from_scale_in=None,
):
self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier, update=True)
if max_size is not None:
@ -309,14 +392,17 @@ class FakeAutoScalingGroup(BaseModel):
if launch_config_name:
self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name]
launch_config_name
]
self.launch_config_name = launch_config_name
if health_check_period is not None:
self.health_check_period = health_check_period
if health_check_type is not None:
self.health_check_type = health_check_type
if new_instances_protected_from_scale_in is not None:
self.new_instances_protected_from_scale_in = new_instances_protected_from_scale_in
self.new_instances_protected_from_scale_in = (
new_instances_protected_from_scale_in
)
if desired_capacity is not None:
self.set_desired_capacity(desired_capacity)
@ -342,25 +428,30 @@ class FakeAutoScalingGroup(BaseModel):
# Need to remove some instances
count_to_remove = curr_instance_count - self.desired_capacity
instances_to_remove = [ # only remove unprotected
state for state in self.instance_states
state
for state in self.instance_states
if not state.protected_from_scale_in
][:count_to_remove]
if instances_to_remove: # just in case not instances to remove
instance_ids_to_remove = [
instance.instance.id for instance in instances_to_remove]
instance.instance.id for instance in instances_to_remove
]
self.autoscaling_backend.ec2_backend.terminate_instances(
instance_ids_to_remove)
self.instance_states = list(set(self.instance_states) - set(instances_to_remove))
instance_ids_to_remove
)
self.instance_states = list(
set(self.instance_states) - set(instances_to_remove)
)
def get_propagated_tags(self):
propagated_tags = {}
for tag in self.tags:
# boto uses 'propagate_at_launch
# boto3 and cloudformation use PropagateAtLaunch
if 'propagate_at_launch' in tag and tag['propagate_at_launch'] == 'true':
propagated_tags[tag['key']] = tag['value']
if 'PropagateAtLaunch' in tag and tag['PropagateAtLaunch']:
propagated_tags[tag['Key']] = tag['Value']
if "propagate_at_launch" in tag and tag["propagate_at_launch"] == "true":
propagated_tags[tag["key"]] = tag["value"]
if "PropagateAtLaunch" in tag and tag["PropagateAtLaunch"]:
propagated_tags[tag["Key"]] = tag["Value"]
return propagated_tags
def replace_autoscaling_group_instances(self, count_needed, propagated_tags):
@ -371,15 +462,17 @@ class FakeAutoScalingGroup(BaseModel):
self.launch_config.user_data,
self.launch_config.security_groups,
instance_type=self.launch_config.instance_type,
tags={'instance': propagated_tags},
tags={"instance": propagated_tags},
placement=random.choice(self.availability_zones),
)
for instance in reservation.instances:
instance.autoscaling_group = self
self.instance_states.append(InstanceState(
instance,
protected_from_scale_in=self.new_instances_protected_from_scale_in,
))
self.instance_states.append(
InstanceState(
instance,
protected_from_scale_in=self.new_instances_protected_from_scale_in,
)
)
def append_target_groups(self, target_group_arns):
append = [x for x in target_group_arns if x not in self.target_group_arns]
@ -402,10 +495,23 @@ class AutoScalingBackend(BaseBackend):
self.__dict__ = {}
self.__init__(ec2_backend, elb_backend, elbv2_backend)
def create_launch_configuration(self, name, image_id, key_name, kernel_id, ramdisk_id,
security_groups, user_data, instance_type,
instance_monitoring, instance_profile_name,
spot_price, ebs_optimized, associate_public_ip_address, block_device_mappings):
def create_launch_configuration(
self,
name,
image_id,
key_name,
kernel_id,
ramdisk_id,
security_groups,
user_data,
instance_type,
instance_monitoring,
instance_profile_name,
spot_price,
ebs_optimized,
associate_public_ip_address,
block_device_mappings,
):
launch_configuration = FakeLaunchConfiguration(
name=name,
image_id=image_id,
@ -428,23 +534,37 @@ class AutoScalingBackend(BaseBackend):
def describe_launch_configurations(self, names):
configurations = self.launch_configurations.values()
if names:
return [configuration for configuration in configurations if configuration.name in names]
return [
configuration
for configuration in configurations
if configuration.name in names
]
else:
return list(configurations)
def delete_launch_configuration(self, launch_configuration_name):
self.launch_configurations.pop(launch_configuration_name, None)
def create_auto_scaling_group(self, name, availability_zones,
desired_capacity, max_size, min_size,
launch_config_name, vpc_zone_identifier,
default_cooldown, health_check_period,
health_check_type, load_balancers,
target_group_arns, placement_group,
termination_policies, tags,
new_instances_protected_from_scale_in=False,
instance_id=None):
def create_auto_scaling_group(
self,
name,
availability_zones,
desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
load_balancers,
target_group_arns,
placement_group,
termination_policies,
tags,
new_instances_protected_from_scale_in=False,
instance_id=None,
):
def make_int(value):
return int(value) if value is not None else value
@ -460,7 +580,9 @@ class AutoScalingBackend(BaseBackend):
try:
instance = self.ec2_backend.get_instance(instance_id)
launch_config_name = name
FakeLaunchConfiguration.create_from_instance(launch_config_name, instance, self)
FakeLaunchConfiguration.create_from_instance(
launch_config_name, instance, self
)
except InvalidInstanceIdError:
raise InvalidInstanceError(instance_id)
@ -489,19 +611,37 @@ class AutoScalingBackend(BaseBackend):
self.update_attached_target_groups(group.name)
return group
def update_auto_scaling_group(self, name, availability_zones,
desired_capacity, max_size, min_size,
launch_config_name, vpc_zone_identifier,
default_cooldown, health_check_period,
health_check_type, placement_group,
termination_policies,
new_instances_protected_from_scale_in=None):
def update_auto_scaling_group(
self,
name,
availability_zones,
desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
new_instances_protected_from_scale_in=None,
):
group = self.autoscaling_groups[name]
group.update(availability_zones, desired_capacity, max_size,
min_size, launch_config_name, vpc_zone_identifier,
default_cooldown, health_check_period, health_check_type,
placement_group, termination_policies,
new_instances_protected_from_scale_in=new_instances_protected_from_scale_in)
group.update(
availability_zones,
desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
new_instances_protected_from_scale_in=new_instances_protected_from_scale_in,
)
return group
def describe_auto_scaling_groups(self, names):
@ -537,32 +677,48 @@ class AutoScalingBackend(BaseBackend):
for x in instance_ids
]
for instance in new_instances:
self.ec2_backend.create_tags([instance.instance.id], {ASG_NAME_TAG: group.name})
self.ec2_backend.create_tags(
[instance.instance.id], {ASG_NAME_TAG: group.name}
)
group.instance_states.extend(new_instances)
self.update_attached_elbs(group.name)
def set_instance_health(self, instance_id, health_status, should_respect_grace_period):
def set_instance_health(
self, instance_id, health_status, should_respect_grace_period
):
instance = self.ec2_backend.get_instance(instance_id)
instance_state = next(instance_state for group in self.autoscaling_groups.values()
for instance_state in group.instance_states if instance_state.instance.id == instance.id)
instance_state = next(
instance_state
for group in self.autoscaling_groups.values()
for instance_state in group.instance_states
if instance_state.instance.id == instance.id
)
instance_state.health_status = health_status
def detach_instances(self, group_name, instance_ids, should_decrement):
group = self.autoscaling_groups[group_name]
original_size = len(group.instance_states)
detached_instances = [x for x in group.instance_states if x.instance.id in instance_ids]
detached_instances = [
x for x in group.instance_states if x.instance.id in instance_ids
]
for instance in detached_instances:
self.ec2_backend.delete_tags([instance.instance.id], {ASG_NAME_TAG: group.name})
self.ec2_backend.delete_tags(
[instance.instance.id], {ASG_NAME_TAG: group.name}
)
new_instance_state = [x for x in group.instance_states if x.instance.id not in instance_ids]
new_instance_state = [
x for x in group.instance_states if x.instance.id not in instance_ids
]
group.instance_states = new_instance_state
if should_decrement:
group.desired_capacity = original_size - len(instance_ids)
else:
count_needed = len(instance_ids)
group.replace_autoscaling_group_instances(count_needed, group.get_propagated_tags())
group.replace_autoscaling_group_instances(
count_needed, group.get_propagated_tags()
)
self.update_attached_elbs(group_name)
return detached_instances
@ -593,19 +749,32 @@ class AutoScalingBackend(BaseBackend):
desired_capacity = int(desired_capacity)
self.set_desired_capacity(group_name, desired_capacity)
def create_autoscaling_policy(self, name, policy_type, adjustment_type, as_name,
scaling_adjustment, cooldown):
policy = FakeScalingPolicy(name, policy_type, adjustment_type, as_name,
scaling_adjustment, cooldown, self)
def create_autoscaling_policy(
self, name, policy_type, adjustment_type, as_name, scaling_adjustment, cooldown
):
policy = FakeScalingPolicy(
name,
policy_type,
adjustment_type,
as_name,
scaling_adjustment,
cooldown,
self,
)
self.policies[name] = policy
return policy
def describe_policies(self, autoscaling_group_name=None, policy_names=None, policy_types=None):
return [policy for policy in self.policies.values()
if (not autoscaling_group_name or policy.as_name == autoscaling_group_name) and
(not policy_names or policy.name in policy_names) and
(not policy_types or policy.policy_type in policy_types)]
def describe_policies(
self, autoscaling_group_name=None, policy_names=None, policy_types=None
):
return [
policy
for policy in self.policies.values()
if (not autoscaling_group_name or policy.as_name == autoscaling_group_name)
and (not policy_names or policy.name in policy_names)
and (not policy_types or policy.policy_type in policy_types)
]
def delete_policy(self, group_name):
self.policies.pop(group_name, None)
@ -616,16 +785,14 @@ class AutoScalingBackend(BaseBackend):
def update_attached_elbs(self, group_name):
group = self.autoscaling_groups[group_name]
group_instance_ids = set(
state.instance.id for state in group.instance_states)
group_instance_ids = set(state.instance.id for state in group.instance_states)
# skip this if group.load_balancers is empty
# otherwise elb_backend.describe_load_balancers returns all available load balancers
if not group.load_balancers:
return
try:
elbs = self.elb_backend.describe_load_balancers(
names=group.load_balancers)
elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers)
except LoadBalancerNotFoundError:
# ELBs can be deleted before their autoscaling group
return
@ -633,14 +800,15 @@ class AutoScalingBackend(BaseBackend):
for elb in elbs:
elb_instace_ids = set(elb.instance_ids)
self.elb_backend.register_instances(
elb.name, group_instance_ids - elb_instace_ids)
elb.name, group_instance_ids - elb_instace_ids
)
self.elb_backend.deregister_instances(
elb.name, elb_instace_ids - group_instance_ids)
elb.name, elb_instace_ids - group_instance_ids
)
def update_attached_target_groups(self, group_name):
group = self.autoscaling_groups[group_name]
group_instance_ids = set(
state.instance.id for state in group.instance_states)
group_instance_ids = set(state.instance.id for state in group.instance_states)
# no action necessary if target_group_arns is empty
if not group.target_group_arns:
@ -649,10 +817,13 @@ class AutoScalingBackend(BaseBackend):
target_groups = self.elbv2_backend.describe_target_groups(
target_group_arns=group.target_group_arns,
load_balancer_arn=None,
names=None)
names=None,
)
for target_group in target_groups:
asg_targets = [{'id': x, 'port': target_group.port} for x in group_instance_ids]
asg_targets = [
{"id": x, "port": target_group.port} for x in group_instance_ids
]
self.elbv2_backend.register_targets(target_group.arn, (asg_targets))
def create_or_update_tags(self, tags):
@ -670,7 +841,7 @@ class AutoScalingBackend(BaseBackend):
new_tags.append(old_tag)
# if key was never in old_tag's add it (create tag)
if not any(new_tag['key'] == tag['key'] for new_tag in new_tags):
if not any(new_tag["key"] == tag["key"] for new_tag in new_tags):
new_tags.append(tag)
group.tags = new_tags
@ -678,7 +849,8 @@ class AutoScalingBackend(BaseBackend):
def attach_load_balancers(self, group_name, load_balancer_names):
group = self.autoscaling_groups[group_name]
group.load_balancers.extend(
[x for x in load_balancer_names if x not in group.load_balancers])
[x for x in load_balancer_names if x not in group.load_balancers]
)
self.update_attached_elbs(group_name)
def describe_load_balancers(self, group_name):
@ -686,13 +858,13 @@ class AutoScalingBackend(BaseBackend):
def detach_load_balancers(self, group_name, load_balancer_names):
group = self.autoscaling_groups[group_name]
group_instance_ids = set(
state.instance.id for state in group.instance_states)
group_instance_ids = set(state.instance.id for state in group.instance_states)
elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers)
for elb in elbs:
self.elb_backend.deregister_instances(
elb.name, group_instance_ids)
group.load_balancers = [x for x in group.load_balancers if x not in load_balancer_names]
self.elb_backend.deregister_instances(elb.name, group_instance_ids)
group.load_balancers = [
x for x in group.load_balancers if x not in load_balancer_names
]
def attach_load_balancer_target_groups(self, group_name, target_group_arns):
group = self.autoscaling_groups[group_name]
@ -704,36 +876,51 @@ class AutoScalingBackend(BaseBackend):
def detach_load_balancer_target_groups(self, group_name, target_group_arns):
group = self.autoscaling_groups[group_name]
group.target_group_arns = [x for x in group.target_group_arns if x not in target_group_arns]
group.target_group_arns = [
x for x in group.target_group_arns if x not in target_group_arns
]
for target_group in target_group_arns:
asg_targets = [{'id': x.instance.id} for x in group.instance_states]
asg_targets = [{"id": x.instance.id} for x in group.instance_states]
self.elbv2_backend.deregister_targets(target_group, (asg_targets))
def suspend_processes(self, group_name, scaling_processes):
group = self.autoscaling_groups[group_name]
group.suspended_processes = scaling_processes or []
def set_instance_protection(self, group_name, instance_ids, protected_from_scale_in):
def set_instance_protection(
self, group_name, instance_ids, protected_from_scale_in
):
group = self.autoscaling_groups[group_name]
protected_instances = [
x for x in group.instance_states if x.instance.id in instance_ids]
x for x in group.instance_states if x.instance.id in instance_ids
]
for instance in protected_instances:
instance.protected_from_scale_in = protected_from_scale_in
def notify_terminate_instances(self, instance_ids):
for autoscaling_group_name, autoscaling_group in self.autoscaling_groups.items():
for (
autoscaling_group_name,
autoscaling_group,
) in self.autoscaling_groups.items():
original_instance_count = len(autoscaling_group.instance_states)
autoscaling_group.instance_states = list(filter(
lambda i_state: i_state.instance.id not in instance_ids,
autoscaling_group.instance_states = list(
filter(
lambda i_state: i_state.instance.id not in instance_ids,
autoscaling_group.instance_states,
)
)
difference = original_instance_count - len(
autoscaling_group.instance_states
))
difference = original_instance_count - len(autoscaling_group.instance_states)
)
if difference > 0:
autoscaling_group.replace_autoscaling_group_instances(difference, autoscaling_group.get_propagated_tags())
autoscaling_group.replace_autoscaling_group_instances(
difference, autoscaling_group.get_propagated_tags()
)
self.update_attached_elbs(autoscaling_group_name)
autoscaling_backends = {}
for region, ec2_backend in ec2_backends.items():
autoscaling_backends[region] = AutoScalingBackend(
ec2_backend, elb_backends[region], elbv2_backends[region])
ec2_backend, elb_backends[region], elbv2_backends[region]
)

View File

@ -6,88 +6,88 @@ from .models import autoscaling_backends
class AutoScalingResponse(BaseResponse):
@property
def autoscaling_backend(self):
return autoscaling_backends[self.region]
def create_launch_configuration(self):
instance_monitoring_string = self._get_param(
'InstanceMonitoring.Enabled')
if instance_monitoring_string == 'true':
instance_monitoring_string = self._get_param("InstanceMonitoring.Enabled")
if instance_monitoring_string == "true":
instance_monitoring = True
else:
instance_monitoring = False
self.autoscaling_backend.create_launch_configuration(
name=self._get_param('LaunchConfigurationName'),
image_id=self._get_param('ImageId'),
key_name=self._get_param('KeyName'),
ramdisk_id=self._get_param('RamdiskId'),
kernel_id=self._get_param('KernelId'),
security_groups=self._get_multi_param('SecurityGroups.member'),
user_data=self._get_param('UserData'),
instance_type=self._get_param('InstanceType'),
name=self._get_param("LaunchConfigurationName"),
image_id=self._get_param("ImageId"),
key_name=self._get_param("KeyName"),
ramdisk_id=self._get_param("RamdiskId"),
kernel_id=self._get_param("KernelId"),
security_groups=self._get_multi_param("SecurityGroups.member"),
user_data=self._get_param("UserData"),
instance_type=self._get_param("InstanceType"),
instance_monitoring=instance_monitoring,
instance_profile_name=self._get_param('IamInstanceProfile'),
spot_price=self._get_param('SpotPrice'),
ebs_optimized=self._get_param('EbsOptimized'),
associate_public_ip_address=self._get_param(
"AssociatePublicIpAddress"),
block_device_mappings=self._get_list_prefix(
'BlockDeviceMappings.member')
instance_profile_name=self._get_param("IamInstanceProfile"),
spot_price=self._get_param("SpotPrice"),
ebs_optimized=self._get_param("EbsOptimized"),
associate_public_ip_address=self._get_param("AssociatePublicIpAddress"),
block_device_mappings=self._get_list_prefix("BlockDeviceMappings.member"),
)
template = self.response_template(CREATE_LAUNCH_CONFIGURATION_TEMPLATE)
return template.render()
def describe_launch_configurations(self):
names = self._get_multi_param('LaunchConfigurationNames.member')
all_launch_configurations = self.autoscaling_backend.describe_launch_configurations(names)
marker = self._get_param('NextToken')
names = self._get_multi_param("LaunchConfigurationNames.member")
all_launch_configurations = self.autoscaling_backend.describe_launch_configurations(
names
)
marker = self._get_param("NextToken")
all_names = [lc.name for lc in all_launch_configurations]
if marker:
start = all_names.index(marker) + 1
else:
start = 0
max_records = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier
launch_configurations_resp = all_launch_configurations[start:start + max_records]
max_records = self._get_int_param(
"MaxRecords", 50
) # the default is 100, but using 50 to make testing easier
launch_configurations_resp = all_launch_configurations[
start : start + max_records
]
next_token = None
if len(all_launch_configurations) > start + max_records:
next_token = launch_configurations_resp[-1].name
template = self.response_template(
DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE)
return template.render(launch_configurations=launch_configurations_resp, next_token=next_token)
template = self.response_template(DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE)
return template.render(
launch_configurations=launch_configurations_resp, next_token=next_token
)
def delete_launch_configuration(self):
launch_configurations_name = self.querystring.get(
'LaunchConfigurationName')[0]
self.autoscaling_backend.delete_launch_configuration(
launch_configurations_name)
launch_configurations_name = self.querystring.get("LaunchConfigurationName")[0]
self.autoscaling_backend.delete_launch_configuration(launch_configurations_name)
template = self.response_template(DELETE_LAUNCH_CONFIGURATION_TEMPLATE)
return template.render()
def create_auto_scaling_group(self):
self.autoscaling_backend.create_auto_scaling_group(
name=self._get_param('AutoScalingGroupName'),
availability_zones=self._get_multi_param(
'AvailabilityZones.member'),
desired_capacity=self._get_int_param('DesiredCapacity'),
max_size=self._get_int_param('MaxSize'),
min_size=self._get_int_param('MinSize'),
instance_id=self._get_param('InstanceId'),
launch_config_name=self._get_param('LaunchConfigurationName'),
vpc_zone_identifier=self._get_param('VPCZoneIdentifier'),
default_cooldown=self._get_int_param('DefaultCooldown'),
health_check_period=self._get_int_param('HealthCheckGracePeriod'),
health_check_type=self._get_param('HealthCheckType'),
load_balancers=self._get_multi_param('LoadBalancerNames.member'),
target_group_arns=self._get_multi_param('TargetGroupARNs.member'),
placement_group=self._get_param('PlacementGroup'),
termination_policies=self._get_multi_param(
'TerminationPolicies.member'),
tags=self._get_list_prefix('Tags.member'),
name=self._get_param("AutoScalingGroupName"),
availability_zones=self._get_multi_param("AvailabilityZones.member"),
desired_capacity=self._get_int_param("DesiredCapacity"),
max_size=self._get_int_param("MaxSize"),
min_size=self._get_int_param("MinSize"),
instance_id=self._get_param("InstanceId"),
launch_config_name=self._get_param("LaunchConfigurationName"),
vpc_zone_identifier=self._get_param("VPCZoneIdentifier"),
default_cooldown=self._get_int_param("DefaultCooldown"),
health_check_period=self._get_int_param("HealthCheckGracePeriod"),
health_check_type=self._get_param("HealthCheckType"),
load_balancers=self._get_multi_param("LoadBalancerNames.member"),
target_group_arns=self._get_multi_param("TargetGroupARNs.member"),
placement_group=self._get_param("PlacementGroup"),
termination_policies=self._get_multi_param("TerminationPolicies.member"),
tags=self._get_list_prefix("Tags.member"),
new_instances_protected_from_scale_in=self._get_bool_param(
'NewInstancesProtectedFromScaleIn', False)
"NewInstancesProtectedFromScaleIn", False
),
)
template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE)
return template.render()
@ -95,68 +95,73 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32
@amzn_request_id
def attach_instances(self):
group_name = self._get_param('AutoScalingGroupName')
instance_ids = self._get_multi_param('InstanceIds.member')
self.autoscaling_backend.attach_instances(
group_name, instance_ids)
group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param("InstanceIds.member")
self.autoscaling_backend.attach_instances(group_name, instance_ids)
template = self.response_template(ATTACH_INSTANCES_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def set_instance_health(self):
instance_id = self._get_param('InstanceId')
instance_id = self._get_param("InstanceId")
health_status = self._get_param("HealthStatus")
if health_status not in ['Healthy', 'Unhealthy']:
raise ValueError('Valid instance health states are: [Healthy, Unhealthy]')
if health_status not in ["Healthy", "Unhealthy"]:
raise ValueError("Valid instance health states are: [Healthy, Unhealthy]")
should_respect_grace_period = self._get_param("ShouldRespectGracePeriod")
self.autoscaling_backend.set_instance_health(instance_id, health_status, should_respect_grace_period)
self.autoscaling_backend.set_instance_health(
instance_id, health_status, should_respect_grace_period
)
template = self.response_template(SET_INSTANCE_HEALTH_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def detach_instances(self):
group_name = self._get_param('AutoScalingGroupName')
instance_ids = self._get_multi_param('InstanceIds.member')
should_decrement_string = self._get_param('ShouldDecrementDesiredCapacity')
if should_decrement_string == 'true':
group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param("InstanceIds.member")
should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity")
if should_decrement_string == "true":
should_decrement = True
else:
should_decrement = False
detached_instances = self.autoscaling_backend.detach_instances(
group_name, instance_ids, should_decrement)
group_name, instance_ids, should_decrement
)
template = self.response_template(DETACH_INSTANCES_TEMPLATE)
return template.render(detached_instances=detached_instances)
@amz_crc32
@amzn_request_id
def attach_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName')
target_group_arns = self._get_multi_param('TargetGroupARNs.member')
group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self._get_multi_param("TargetGroupARNs.member")
self.autoscaling_backend.attach_load_balancer_target_groups(
group_name, target_group_arns)
group_name, target_group_arns
)
template = self.response_template(ATTACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def describe_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName')
group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self.autoscaling_backend.describe_load_balancer_target_groups(
group_name)
group_name
)
template = self.response_template(DESCRIBE_LOAD_BALANCER_TARGET_GROUPS)
return template.render(target_group_arns=target_group_arns)
@amz_crc32
@amzn_request_id
def detach_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName')
target_group_arns = self._get_multi_param('TargetGroupARNs.member')
group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self._get_multi_param("TargetGroupARNs.member")
self.autoscaling_backend.detach_load_balancer_target_groups(
group_name, target_group_arns)
group_name, target_group_arns
)
template = self.response_template(DETACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE)
return template.render()
@ -172,7 +177,7 @@ class AutoScalingResponse(BaseResponse):
max_records = self._get_int_param("MaxRecords", 50)
if max_records > 100:
raise ValueError
groups = all_groups[start:start + max_records]
groups = all_groups[start : start + max_records]
next_token = None
if max_records and len(all_groups) > start + max_records:
next_token = groups[-1].name
@ -181,42 +186,40 @@ class AutoScalingResponse(BaseResponse):
def update_auto_scaling_group(self):
self.autoscaling_backend.update_auto_scaling_group(
name=self._get_param('AutoScalingGroupName'),
availability_zones=self._get_multi_param(
'AvailabilityZones.member'),
desired_capacity=self._get_int_param('DesiredCapacity'),
max_size=self._get_int_param('MaxSize'),
min_size=self._get_int_param('MinSize'),
launch_config_name=self._get_param('LaunchConfigurationName'),
vpc_zone_identifier=self._get_param('VPCZoneIdentifier'),
default_cooldown=self._get_int_param('DefaultCooldown'),
health_check_period=self._get_int_param('HealthCheckGracePeriod'),
health_check_type=self._get_param('HealthCheckType'),
placement_group=self._get_param('PlacementGroup'),
termination_policies=self._get_multi_param(
'TerminationPolicies.member'),
name=self._get_param("AutoScalingGroupName"),
availability_zones=self._get_multi_param("AvailabilityZones.member"),
desired_capacity=self._get_int_param("DesiredCapacity"),
max_size=self._get_int_param("MaxSize"),
min_size=self._get_int_param("MinSize"),
launch_config_name=self._get_param("LaunchConfigurationName"),
vpc_zone_identifier=self._get_param("VPCZoneIdentifier"),
default_cooldown=self._get_int_param("DefaultCooldown"),
health_check_period=self._get_int_param("HealthCheckGracePeriod"),
health_check_type=self._get_param("HealthCheckType"),
placement_group=self._get_param("PlacementGroup"),
termination_policies=self._get_multi_param("TerminationPolicies.member"),
new_instances_protected_from_scale_in=self._get_bool_param(
'NewInstancesProtectedFromScaleIn', None)
"NewInstancesProtectedFromScaleIn", None
),
)
template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE)
return template.render()
def delete_auto_scaling_group(self):
group_name = self._get_param('AutoScalingGroupName')
group_name = self._get_param("AutoScalingGroupName")
self.autoscaling_backend.delete_auto_scaling_group(group_name)
template = self.response_template(DELETE_AUTOSCALING_GROUP_TEMPLATE)
return template.render()
def set_desired_capacity(self):
group_name = self._get_param('AutoScalingGroupName')
desired_capacity = self._get_int_param('DesiredCapacity')
self.autoscaling_backend.set_desired_capacity(
group_name, desired_capacity)
group_name = self._get_param("AutoScalingGroupName")
desired_capacity = self._get_int_param("DesiredCapacity")
self.autoscaling_backend.set_desired_capacity(group_name, desired_capacity)
template = self.response_template(SET_DESIRED_CAPACITY_TEMPLATE)
return template.render()
def create_or_update_tags(self):
tags = self._get_list_prefix('Tags.member')
tags = self._get_list_prefix("Tags.member")
self.autoscaling_backend.create_or_update_tags(tags)
template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE)
@ -224,38 +227,38 @@ class AutoScalingResponse(BaseResponse):
def describe_auto_scaling_instances(self):
instance_states = self.autoscaling_backend.describe_auto_scaling_instances()
template = self.response_template(
DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
template = self.response_template(DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
return template.render(instance_states=instance_states)
def put_scaling_policy(self):
policy = self.autoscaling_backend.create_autoscaling_policy(
name=self._get_param('PolicyName'),
policy_type=self._get_param('PolicyType'),
adjustment_type=self._get_param('AdjustmentType'),
as_name=self._get_param('AutoScalingGroupName'),
scaling_adjustment=self._get_int_param('ScalingAdjustment'),
cooldown=self._get_int_param('Cooldown'),
name=self._get_param("PolicyName"),
policy_type=self._get_param("PolicyType"),
adjustment_type=self._get_param("AdjustmentType"),
as_name=self._get_param("AutoScalingGroupName"),
scaling_adjustment=self._get_int_param("ScalingAdjustment"),
cooldown=self._get_int_param("Cooldown"),
)
template = self.response_template(CREATE_SCALING_POLICY_TEMPLATE)
return template.render(policy=policy)
def describe_policies(self):
policies = self.autoscaling_backend.describe_policies(
autoscaling_group_name=self._get_param('AutoScalingGroupName'),
policy_names=self._get_multi_param('PolicyNames.member'),
policy_types=self._get_multi_param('PolicyTypes.member'))
autoscaling_group_name=self._get_param("AutoScalingGroupName"),
policy_names=self._get_multi_param("PolicyNames.member"),
policy_types=self._get_multi_param("PolicyTypes.member"),
)
template = self.response_template(DESCRIBE_SCALING_POLICIES_TEMPLATE)
return template.render(policies=policies)
def delete_policy(self):
group_name = self._get_param('PolicyName')
group_name = self._get_param("PolicyName")
self.autoscaling_backend.delete_policy(group_name)
template = self.response_template(DELETE_POLICY_TEMPLATE)
return template.render()
def execute_policy(self):
group_name = self._get_param('PolicyName')
group_name = self._get_param("PolicyName")
self.autoscaling_backend.execute_policy(group_name)
template = self.response_template(EXECUTE_POLICY_TEMPLATE)
return template.render()
@ -263,17 +266,16 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32
@amzn_request_id
def attach_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName')
group_name = self._get_param("AutoScalingGroupName")
load_balancer_names = self._get_multi_param("LoadBalancerNames.member")
self.autoscaling_backend.attach_load_balancers(
group_name, load_balancer_names)
self.autoscaling_backend.attach_load_balancers(group_name, load_balancer_names)
template = self.response_template(ATTACH_LOAD_BALANCERS_TEMPLATE)
return template.render()
@amz_crc32
@amzn_request_id
def describe_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName')
group_name = self._get_param("AutoScalingGroupName")
load_balancers = self.autoscaling_backend.describe_load_balancers(group_name)
template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE)
return template.render(load_balancers=load_balancers)
@ -281,26 +283,28 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32
@amzn_request_id
def detach_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName')
group_name = self._get_param("AutoScalingGroupName")
load_balancer_names = self._get_multi_param("LoadBalancerNames.member")
self.autoscaling_backend.detach_load_balancers(
group_name, load_balancer_names)
self.autoscaling_backend.detach_load_balancers(group_name, load_balancer_names)
template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE)
return template.render()
def suspend_processes(self):
autoscaling_group_name = self._get_param('AutoScalingGroupName')
scaling_processes = self._get_multi_param('ScalingProcesses.member')
self.autoscaling_backend.suspend_processes(autoscaling_group_name, scaling_processes)
autoscaling_group_name = self._get_param("AutoScalingGroupName")
scaling_processes = self._get_multi_param("ScalingProcesses.member")
self.autoscaling_backend.suspend_processes(
autoscaling_group_name, scaling_processes
)
template = self.response_template(SUSPEND_PROCESSES_TEMPLATE)
return template.render()
def set_instance_protection(self):
group_name = self._get_param('AutoScalingGroupName')
instance_ids = self._get_multi_param('InstanceIds.member')
protected_from_scale_in = self._get_bool_param('ProtectedFromScaleIn')
group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param("InstanceIds.member")
protected_from_scale_in = self._get_bool_param("ProtectedFromScaleIn")
self.autoscaling_backend.set_instance_protection(
group_name, instance_ids, protected_from_scale_in)
group_name, instance_ids, protected_from_scale_in
)
template = self.response_template(SET_INSTANCE_PROTECTION_TEMPLATE)
return template.render()

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import AutoScalingResponse
url_bases = [
"https?://autoscaling.(.+).amazonaws.com",
]
url_bases = ["https?://autoscaling.(.+).amazonaws.com"]
url_paths = {
'{0}/$': AutoScalingResponse.dispatch,
}
url_paths = {"{0}/$": AutoScalingResponse.dispatch}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import lambda_backends
from ..core.models import base_decorator, deprecated_base_decorator
lambda_backend = lambda_backends['us-east-1']
lambda_backend = lambda_backends["us-east-1"]
mock_lambda = base_decorator(lambda_backends)
mock_lambda_deprecated = deprecated_base_decorator(lambda_backends)

View File

@ -38,7 +38,7 @@ from moto.dynamodbstreams import dynamodbstreams_backends
logger = logging.getLogger(__name__)
ACCOUNT_ID = '123456789012'
ACCOUNT_ID = "123456789012"
try:
@ -47,20 +47,22 @@ except ImportError:
from backports.tempfile import TemporaryDirectory
_stderr_regex = re.compile(r'START|END|REPORT RequestId: .*')
_stderr_regex = re.compile(r"START|END|REPORT RequestId: .*")
_orig_adapter_send = requests.adapters.HTTPAdapter.send
docker_3 = docker.__version__[0] >= '3'
docker_3 = docker.__version__[0] >= "3"
def zip2tar(zip_bytes):
with TemporaryDirectory() as td:
tarname = os.path.join(td, 'data.tar')
timeshift = int((datetime.datetime.now() -
datetime.datetime.utcnow()).total_seconds())
with zipfile.ZipFile(io.BytesIO(zip_bytes), 'r') as zipf, \
tarfile.TarFile(tarname, 'w') as tarf:
tarname = os.path.join(td, "data.tar")
timeshift = int(
(datetime.datetime.now() - datetime.datetime.utcnow()).total_seconds()
)
with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zipf, tarfile.TarFile(
tarname, "w"
) as tarf:
for zipinfo in zipf.infolist():
if zipinfo.filename[-1] == '/': # is_dir() is py3.6+
if zipinfo.filename[-1] == "/": # is_dir() is py3.6+
continue
tarinfo = tarfile.TarInfo(name=zipinfo.filename)
@ -69,7 +71,7 @@ def zip2tar(zip_bytes):
infile = zipf.open(zipinfo.filename)
tarf.addfile(tarinfo, infile)
with open(tarname, 'rb') as f:
with open(tarname, "rb") as f:
tar_data = f.read()
return tar_data
@ -83,7 +85,9 @@ class _VolumeRefCount:
class _DockerDataVolumeContext:
_data_vol_map = defaultdict(lambda: _VolumeRefCount(0, None)) # {sha256: _VolumeRefCount}
_data_vol_map = defaultdict(
lambda: _VolumeRefCount(0, None)
) # {sha256: _VolumeRefCount}
_lock = threading.Lock()
def __init__(self, lambda_func):
@ -109,15 +113,19 @@ class _DockerDataVolumeContext:
return self
# It doesn't exist so we need to create it
self._vol_ref.volume = self._lambda_func.docker_client.volumes.create(self._lambda_func.code_sha_256)
self._vol_ref.volume = self._lambda_func.docker_client.volumes.create(
self._lambda_func.code_sha_256
)
if docker_3:
volumes = {self.name: {'bind': '/tmp/data', 'mode': 'rw'}}
volumes = {self.name: {"bind": "/tmp/data", "mode": "rw"}}
else:
volumes = {self.name: '/tmp/data'}
container = self._lambda_func.docker_client.containers.run('alpine', 'sleep 100', volumes=volumes, detach=True)
volumes = {self.name: "/tmp/data"}
container = self._lambda_func.docker_client.containers.run(
"alpine", "sleep 100", volumes=volumes, detach=True
)
try:
tar_bytes = zip2tar(self._lambda_func.code_bytes)
container.put_archive('/tmp/data', tar_bytes)
container.put_archive("/tmp/data", tar_bytes)
finally:
container.remove(force=True)
@ -140,13 +148,13 @@ class LambdaFunction(BaseModel):
def __init__(self, spec, region, validate_s3=True, version=1):
# required
self.region = region
self.code = spec['Code']
self.function_name = spec['FunctionName']
self.handler = spec['Handler']
self.role = spec['Role']
self.run_time = spec['Runtime']
self.code = spec["Code"]
self.function_name = spec["FunctionName"]
self.handler = spec["Handler"]
self.role = spec["Role"]
self.run_time = spec["Runtime"]
self.logs_backend = logs_backends[self.region]
self.environment_vars = spec.get('Environment', {}).get('Variables', {})
self.environment_vars = spec.get("Environment", {}).get("Variables", {})
self.docker_client = docker.from_env()
self.policy = ""
@ -161,77 +169,82 @@ class LambdaFunction(BaseModel):
if isinstance(adapter, requests.adapters.HTTPAdapter):
adapter.send = functools.partial(_orig_adapter_send, adapter)
return adapter
self.docker_client.api.get_adapter = replace_adapter_send
# optional
self.description = spec.get('Description', '')
self.memory_size = spec.get('MemorySize', 128)
self.publish = spec.get('Publish', False) # this is ignored currently
self.timeout = spec.get('Timeout', 3)
self.description = spec.get("Description", "")
self.memory_size = spec.get("MemorySize", 128)
self.publish = spec.get("Publish", False) # this is ignored currently
self.timeout = spec.get("Timeout", 3)
self.logs_group_name = '/aws/lambda/{}'.format(self.function_name)
self.logs_group_name = "/aws/lambda/{}".format(self.function_name)
self.logs_backend.ensure_log_group(self.logs_group_name, [])
# this isn't finished yet. it needs to find out the VpcId value
self._vpc_config = spec.get(
'VpcConfig', {'SubnetIds': [], 'SecurityGroupIds': []})
"VpcConfig", {"SubnetIds": [], "SecurityGroupIds": []}
)
# auto-generated
self.version = version
self.last_modified = datetime.datetime.utcnow().strftime(
'%Y-%m-%d %H:%M:%S')
self.last_modified = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
if 'ZipFile' in self.code:
if "ZipFile" in self.code:
# more hackery to handle unicode/bytes/str in python3 and python2 -
# argh!
try:
to_unzip_code = base64.b64decode(
bytes(self.code['ZipFile'], 'utf-8'))
to_unzip_code = base64.b64decode(bytes(self.code["ZipFile"], "utf-8"))
except Exception:
to_unzip_code = base64.b64decode(self.code['ZipFile'])
to_unzip_code = base64.b64decode(self.code["ZipFile"])
self.code_bytes = to_unzip_code
self.code_size = len(to_unzip_code)
self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest()
# TODO: we should be putting this in a lambda bucket
self.code['UUID'] = str(uuid.uuid4())
self.code['S3Key'] = '{}-{}'.format(self.function_name, self.code['UUID'])
self.code["UUID"] = str(uuid.uuid4())
self.code["S3Key"] = "{}-{}".format(self.function_name, self.code["UUID"])
else:
# validate s3 bucket and key
key = None
try:
# FIXME: does not validate bucket region
key = s3_backend.get_key(
self.code['S3Bucket'], self.code['S3Key'])
key = s3_backend.get_key(self.code["S3Bucket"], self.code["S3Key"])
except MissingBucket:
if do_validate_s3():
raise ValueError(
"InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist")
"Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist",
)
except MissingKey:
if do_validate_s3():
raise ValueError(
"InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.")
"Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.",
)
if key:
self.code_bytes = key.value
self.code_size = key.size
self.code_sha_256 = hashlib.sha256(key.value).hexdigest()
self.function_arn = make_function_arn(self.region, ACCOUNT_ID, self.function_name)
self.function_arn = make_function_arn(
self.region, ACCOUNT_ID, self.function_name
)
self.tags = dict()
def set_version(self, version):
self.function_arn = make_function_ver_arn(self.region, ACCOUNT_ID, self.function_name, version)
self.function_arn = make_function_ver_arn(
self.region, ACCOUNT_ID, self.function_name, version
)
self.version = version
self.last_modified = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
self.last_modified = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
@property
def vpc_config(self):
config = self._vpc_config.copy()
if config['SecurityGroupIds']:
if config["SecurityGroupIds"]:
config.update({"VpcId": "vpc-123abc"})
return config
@ -260,17 +273,17 @@ class LambdaFunction(BaseModel):
}
if self.environment_vars:
config['Environment'] = {
'Variables': self.environment_vars
}
config["Environment"] = {"Variables": self.environment_vars}
return config
def get_code(self):
return {
"Code": {
"Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format(self.region, self.code['S3Key']),
"RepositoryType": "S3"
"Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format(
self.region, self.code["S3Key"]
),
"RepositoryType": "S3",
},
"Configuration": self.get_configuration(),
}
@ -297,43 +310,48 @@ class LambdaFunction(BaseModel):
return self.get_configuration()
def update_function_code(self, updated_spec):
if 'DryRun' in updated_spec and updated_spec['DryRun']:
if "DryRun" in updated_spec and updated_spec["DryRun"]:
return self.get_configuration()
if 'ZipFile' in updated_spec:
self.code['ZipFile'] = updated_spec['ZipFile']
if "ZipFile" in updated_spec:
self.code["ZipFile"] = updated_spec["ZipFile"]
# using the "hackery" from __init__ because it seems to work
# TODOs and FIXMEs included, because they'll need to be fixed
# in both places now
try:
to_unzip_code = base64.b64decode(
bytes(updated_spec['ZipFile'], 'utf-8'))
bytes(updated_spec["ZipFile"], "utf-8")
)
except Exception:
to_unzip_code = base64.b64decode(updated_spec['ZipFile'])
to_unzip_code = base64.b64decode(updated_spec["ZipFile"])
self.code_bytes = to_unzip_code
self.code_size = len(to_unzip_code)
self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest()
# TODO: we should be putting this in a lambda bucket
self.code['UUID'] = str(uuid.uuid4())
self.code['S3Key'] = '{}-{}'.format(self.function_name, self.code['UUID'])
elif 'S3Bucket' in updated_spec and 'S3Key' in updated_spec:
self.code["UUID"] = str(uuid.uuid4())
self.code["S3Key"] = "{}-{}".format(self.function_name, self.code["UUID"])
elif "S3Bucket" in updated_spec and "S3Key" in updated_spec:
key = None
try:
# FIXME: does not validate bucket region
key = s3_backend.get_key(updated_spec['S3Bucket'], updated_spec['S3Key'])
key = s3_backend.get_key(
updated_spec["S3Bucket"], updated_spec["S3Key"]
)
except MissingBucket:
if do_validate_s3():
raise ValueError(
"InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist")
"Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist",
)
except MissingKey:
if do_validate_s3():
raise ValueError(
"InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.")
"Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.",
)
if key:
self.code_bytes = key.value
self.code_size = key.size
@ -344,7 +362,7 @@ class LambdaFunction(BaseModel):
@staticmethod
def convert(s):
try:
return str(s, encoding='utf-8')
return str(s, encoding="utf-8")
except Exception:
return s
@ -372,12 +390,21 @@ class LambdaFunction(BaseModel):
container = output = exit_code = None
with _DockerDataVolumeContext(self) as data_vol:
try:
run_kwargs = dict(links={'motoserver': 'motoserver'}) if settings.TEST_SERVER_MODE else {}
run_kwargs = (
dict(links={"motoserver": "motoserver"})
if settings.TEST_SERVER_MODE
else {}
)
container = self.docker_client.containers.run(
"lambci/lambda:{}".format(self.run_time),
[self.handler, json.dumps(event)], remove=False,
[self.handler, json.dumps(event)],
remove=False,
mem_limit="{}m".format(self.memory_size),
volumes=["{}:/var/task".format(data_vol.name)], environment=env_vars, detach=True, **run_kwargs)
volumes=["{}:/var/task".format(data_vol.name)],
environment=env_vars,
detach=True,
**run_kwargs
)
finally:
if container:
try:
@ -388,32 +415,43 @@ class LambdaFunction(BaseModel):
container.kill()
else:
if docker_3:
exit_code = exit_code['StatusCode']
exit_code = exit_code["StatusCode"]
output = container.logs(stdout=False, stderr=True)
output += container.logs(stdout=True, stderr=False)
container.remove()
output = output.decode('utf-8')
output = output.decode("utf-8")
# Send output to "logs" backend
invoke_id = uuid.uuid4().hex
log_stream_name = "{date.year}/{date.month:02d}/{date.day:02d}/[{version}]{invoke_id}".format(
date=datetime.datetime.utcnow(), version=self.version, invoke_id=invoke_id
date=datetime.datetime.utcnow(),
version=self.version,
invoke_id=invoke_id,
)
self.logs_backend.create_log_stream(self.logs_group_name, log_stream_name)
log_events = [{'timestamp': unix_time_millis(), "message": line}
for line in output.splitlines()]
self.logs_backend.put_log_events(self.logs_group_name, log_stream_name, log_events, None)
log_events = [
{"timestamp": unix_time_millis(), "message": line}
for line in output.splitlines()
]
self.logs_backend.put_log_events(
self.logs_group_name, log_stream_name, log_events, None
)
if exit_code != 0:
raise Exception(
'lambda invoke failed output: {}'.format(output))
raise Exception("lambda invoke failed output: {}".format(output))
# strip out RequestId lines
output = os.linesep.join([line for line in self.convert(output).splitlines() if not _stderr_regex.match(line)])
output = os.linesep.join(
[
line
for line in self.convert(output).splitlines()
if not _stderr_regex.match(line)
]
)
return output, False
except BaseException as e:
traceback.print_exc()
@ -428,31 +466,34 @@ class LambdaFunction(BaseModel):
# Get the invocation type:
res, errored = self._invoke_lambda(code=self.code, event=body)
if request_headers.get("x-amz-invocation-type") == "RequestResponse":
encoded = base64.b64encode(res.encode('utf-8'))
response_headers["x-amz-log-result"] = encoded.decode('utf-8')
payload['result'] = response_headers["x-amz-log-result"]
result = res.encode('utf-8')
encoded = base64.b64encode(res.encode("utf-8"))
response_headers["x-amz-log-result"] = encoded.decode("utf-8")
payload["result"] = response_headers["x-amz-log-result"]
result = res.encode("utf-8")
else:
result = json.dumps(payload)
if errored:
response_headers['x-amz-function-error'] = "Handled"
response_headers["x-amz-function-error"] = "Handled"
return result
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
# required
spec = {
'Code': properties['Code'],
'FunctionName': resource_name,
'Handler': properties['Handler'],
'Role': properties['Role'],
'Runtime': properties['Runtime'],
"Code": properties["Code"],
"FunctionName": resource_name,
"Handler": properties["Handler"],
"Role": properties["Role"],
"Runtime": properties["Runtime"],
}
optional_properties = 'Description MemorySize Publish Timeout VpcConfig Environment'.split()
optional_properties = (
"Description MemorySize Publish Timeout VpcConfig Environment".split()
)
# NOTE: Not doing `properties.get(k, DEFAULT)` to avoid duplicating the
# default logic
for prop in optional_properties:
@ -462,27 +503,27 @@ class LambdaFunction(BaseModel):
# when ZipFile is present in CloudFormation, per the official docs,
# the code it's a plaintext code snippet up to 4096 bytes.
# this snippet converts this plaintext code to a proper base64-encoded ZIP file.
if 'ZipFile' in properties['Code']:
spec['Code']['ZipFile'] = base64.b64encode(
cls._create_zipfile_from_plaintext_code(
spec['Code']['ZipFile']))
if "ZipFile" in properties["Code"]:
spec["Code"]["ZipFile"] = base64.b64encode(
cls._create_zipfile_from_plaintext_code(spec["Code"]["ZipFile"])
)
backend = lambda_backends[region_name]
fn = backend.create_function(spec)
return fn
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import \
UnformattedGetAttTemplateException
if attribute_name == 'Arn':
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == "Arn":
return make_function_arn(self.region, ACCOUNT_ID, self.function_name)
raise UnformattedGetAttTemplateException()
@staticmethod
def _create_zipfile_from_plaintext_code(code):
zip_output = io.BytesIO()
zip_file = zipfile.ZipFile(zip_output, 'w', zipfile.ZIP_DEFLATED)
zip_file.writestr('lambda_function.zip', code)
zip_file = zipfile.ZipFile(zip_output, "w", zipfile.ZIP_DEFLATED)
zip_file.writestr("lambda_function.zip", code)
zip_file.close()
zip_output.seek(0)
return zip_output.read()
@ -491,61 +532,66 @@ class LambdaFunction(BaseModel):
class EventSourceMapping(BaseModel):
def __init__(self, spec):
# required
self.function_arn = spec['FunctionArn']
self.event_source_arn = spec['EventSourceArn']
self.function_arn = spec["FunctionArn"]
self.event_source_arn = spec["EventSourceArn"]
self.uuid = str(uuid.uuid4())
self.last_modified = time.mktime(datetime.datetime.utcnow().timetuple())
# BatchSize service default/max mapping
batch_size_map = {
'kinesis': (100, 10000),
'dynamodb': (100, 1000),
'sqs': (10, 10),
"kinesis": (100, 10000),
"dynamodb": (100, 1000),
"sqs": (10, 10),
}
source_type = self.event_source_arn.split(":")[2].lower()
batch_size_entry = batch_size_map.get(source_type)
if batch_size_entry:
# Use service default if not provided
batch_size = int(spec.get('BatchSize', batch_size_entry[0]))
batch_size = int(spec.get("BatchSize", batch_size_entry[0]))
if batch_size > batch_size_entry[1]:
raise ValueError("InvalidParameterValueException",
"BatchSize {} exceeds the max of {}".format(batch_size, batch_size_entry[1]))
raise ValueError(
"InvalidParameterValueException",
"BatchSize {} exceeds the max of {}".format(
batch_size, batch_size_entry[1]
),
)
else:
self.batch_size = batch_size
else:
raise ValueError("InvalidParameterValueException",
"Unsupported event source type")
raise ValueError(
"InvalidParameterValueException", "Unsupported event source type"
)
# optional
self.starting_position = spec.get('StartingPosition', 'TRIM_HORIZON')
self.enabled = spec.get('Enabled', True)
self.starting_position_timestamp = spec.get('StartingPositionTimestamp',
None)
self.starting_position = spec.get("StartingPosition", "TRIM_HORIZON")
self.enabled = spec.get("Enabled", True)
self.starting_position_timestamp = spec.get("StartingPositionTimestamp", None)
def get_configuration(self):
return {
'UUID': self.uuid,
'BatchSize': self.batch_size,
'EventSourceArn': self.event_source_arn,
'FunctionArn': self.function_arn,
'LastModified': self.last_modified,
'LastProcessingResult': '',
'State': 'Enabled' if self.enabled else 'Disabled',
'StateTransitionReason': 'User initiated'
"UUID": self.uuid,
"BatchSize": self.batch_size,
"EventSourceArn": self.event_source_arn,
"FunctionArn": self.function_arn,
"LastModified": self.last_modified,
"LastProcessingResult": "",
"State": "Enabled" if self.enabled else "Disabled",
"StateTransitionReason": "User initiated",
}
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
func = lambda_backends[region_name].get_function(properties['FunctionName'])
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
func = lambda_backends[region_name].get_function(properties["FunctionName"])
spec = {
'FunctionArn': func.function_arn,
'EventSourceArn': properties['EventSourceArn'],
'StartingPosition': properties['StartingPosition'],
'BatchSize': properties.get('BatchSize', 100)
"FunctionArn": func.function_arn,
"EventSourceArn": properties["EventSourceArn"],
"StartingPosition": properties["StartingPosition"],
"BatchSize": properties.get("BatchSize", 100),
}
optional_properties = 'BatchSize Enabled StartingPositionTimestamp'.split()
optional_properties = "BatchSize Enabled StartingPositionTimestamp".split()
for prop in optional_properties:
if prop in properties:
spec[prop] = properties[prop]
@ -554,20 +600,19 @@ class EventSourceMapping(BaseModel):
class LambdaVersion(BaseModel):
def __init__(self, spec):
self.version = spec['Version']
self.version = spec["Version"]
def __repr__(self):
return str(self.logical_resource_id)
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json,
region_name):
properties = cloudformation_json['Properties']
function_name = properties['FunctionName']
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
function_name = properties["FunctionName"]
func = lambda_backends[region_name].publish_function(function_name)
spec = {
'Version': func.version
}
spec = {"Version": func.version}
return LambdaVersion(spec)
@ -578,18 +623,18 @@ class LambdaStorage(object):
self._arns = weakref.WeakValueDictionary()
def _get_latest(self, name):
return self._functions[name]['latest']
return self._functions[name]["latest"]
def _get_version(self, name, version):
index = version - 1
try:
return self._functions[name]['versions'][index]
return self._functions[name]["versions"][index]
except IndexError:
return None
def _get_alias(self, name, alias):
return self._functions[name]['alias'].get(alias, None)
return self._functions[name]["alias"].get(alias, None)
def get_function_by_name(self, name, qualifier=None):
if name not in self._functions:
@ -601,15 +646,15 @@ class LambdaStorage(object):
try:
return self._get_version(name, int(qualifier))
except ValueError:
return self._functions[name]['latest']
return self._functions[name]["latest"]
def list_versions_by_function(self, name):
if name not in self._functions:
return None
latest = copy.copy(self._functions[name]['latest'])
latest.function_arn += ':$LATEST'
return [latest] + self._functions[name]['versions']
latest = copy.copy(self._functions[name]["latest"])
latest.function_arn += ":$LATEST"
return [latest] + self._functions[name]["versions"]
def get_arn(self, arn):
return self._arns.get(arn, None)
@ -623,12 +668,12 @@ class LambdaStorage(object):
:type fn: LambdaFunction
"""
if fn.function_name in self._functions:
self._functions[fn.function_name]['latest'] = fn
self._functions[fn.function_name]["latest"] = fn
else:
self._functions[fn.function_name] = {
'latest': fn,
'versions': [],
'alias': weakref.WeakValueDictionary()
"latest": fn,
"versions": [],
"alias": weakref.WeakValueDictionary(),
}
self._arns[fn.function_arn] = fn
@ -636,14 +681,14 @@ class LambdaStorage(object):
def publish_function(self, name):
if name not in self._functions:
return None
if not self._functions[name]['latest']:
if not self._functions[name]["latest"]:
return None
new_version = len(self._functions[name]['versions']) + 1
fn = copy.copy(self._functions[name]['latest'])
new_version = len(self._functions[name]["versions"]) + 1
fn = copy.copy(self._functions[name]["latest"])
fn.set_version(new_version)
self._functions[name]['versions'].append(fn)
self._functions[name]["versions"].append(fn)
self._arns[fn.function_arn] = fn
return fn
@ -653,21 +698,24 @@ class LambdaStorage(object):
name = function.function_name
if not qualifier:
# Something is still reffing this so delete all arns
latest = self._functions[name]['latest'].function_arn
latest = self._functions[name]["latest"].function_arn
del self._arns[latest]
for fn in self._functions[name]['versions']:
for fn in self._functions[name]["versions"]:
del self._arns[fn.function_arn]
del self._functions[name]
return True
elif qualifier == '$LATEST':
self._functions[name]['latest'] = None
elif qualifier == "$LATEST":
self._functions[name]["latest"] = None
# If theres no functions left
if not self._functions[name]['versions'] and not self._functions[name]['latest']:
if (
not self._functions[name]["versions"]
and not self._functions[name]["latest"]
):
del self._functions[name]
return True
@ -675,10 +723,13 @@ class LambdaStorage(object):
else:
fn = self.get_function_by_name(name, qualifier)
if fn:
self._functions[name]['versions'].remove(fn)
self._functions[name]["versions"].remove(fn)
# If theres no functions left
if not self._functions[name]['versions'] and not self._functions[name]['latest']:
if (
not self._functions[name]["versions"]
and not self._functions[name]["latest"]
):
del self._functions[name]
return True
@ -689,10 +740,10 @@ class LambdaStorage(object):
result = []
for function_group in self._functions.values():
if function_group['latest'] is not None:
result.append(function_group['latest'])
if function_group["latest"] is not None:
result.append(function_group["latest"])
result.extend(function_group['versions'])
result.extend(function_group["versions"])
return result
@ -709,44 +760,47 @@ class LambdaBackend(BaseBackend):
self.__init__(region_name)
def create_function(self, spec):
function_name = spec.get('FunctionName', None)
function_name = spec.get("FunctionName", None)
if function_name is None:
raise RESTError('InvalidParameterValueException', 'Missing FunctionName')
raise RESTError("InvalidParameterValueException", "Missing FunctionName")
fn = LambdaFunction(spec, self.region_name, version='$LATEST')
fn = LambdaFunction(spec, self.region_name, version="$LATEST")
self._lambdas.put_function(fn)
if spec.get('Publish'):
if spec.get("Publish"):
ver = self.publish_function(function_name)
fn.version = ver.version
return fn
def create_event_source_mapping(self, spec):
required = [
'EventSourceArn',
'FunctionName',
]
required = ["EventSourceArn", "FunctionName"]
for param in required:
if not spec.get(param):
raise RESTError('InvalidParameterValueException', 'Missing {}'.format(param))
raise RESTError(
"InvalidParameterValueException", "Missing {}".format(param)
)
# Validate function name
func = self._lambdas.get_function_by_name_or_arn(spec.pop('FunctionName', ''))
func = self._lambdas.get_function_by_name_or_arn(spec.pop("FunctionName", ""))
if not func:
raise RESTError('ResourceNotFoundException', 'Invalid FunctionName')
raise RESTError("ResourceNotFoundException", "Invalid FunctionName")
# Validate queue
for queue in sqs_backends[self.region_name].queues.values():
if queue.queue_arn == spec['EventSourceArn']:
if queue.lambda_event_source_mappings.get('func.function_arn'):
if queue.queue_arn == spec["EventSourceArn"]:
if queue.lambda_event_source_mappings.get("func.function_arn"):
# TODO: Correct exception?
raise RESTError('ResourceConflictException', 'The resource already exists.')
raise RESTError(
"ResourceConflictException", "The resource already exists."
)
if queue.fifo_queue:
raise RESTError('InvalidParameterValueException',
'{} is FIFO'.format(queue.queue_arn))
raise RESTError(
"InvalidParameterValueException",
"{} is FIFO".format(queue.queue_arn),
)
else:
spec.update({'FunctionArn': func.function_arn})
spec.update({"FunctionArn": func.function_arn})
esm = EventSourceMapping(spec)
self._event_source_mappings[esm.uuid] = esm
@ -754,16 +808,18 @@ class LambdaBackend(BaseBackend):
queue.lambda_event_source_mappings[esm.function_arn] = esm
return esm
for stream in json.loads(dynamodbstreams_backends[self.region_name].list_streams())['Streams']:
if stream['StreamArn'] == spec['EventSourceArn']:
spec.update({'FunctionArn': func.function_arn})
for stream in json.loads(
dynamodbstreams_backends[self.region_name].list_streams()
)["Streams"]:
if stream["StreamArn"] == spec["EventSourceArn"]:
spec.update({"FunctionArn": func.function_arn})
esm = EventSourceMapping(spec)
self._event_source_mappings[esm.uuid] = esm
table_name = stream['TableName']
table_name = stream["TableName"]
table = dynamodb_backends2[self.region_name].get_table(table_name)
table.lambda_event_source_mappings[esm.function_arn] = esm
return esm
raise RESTError('ResourceNotFoundException', 'Invalid EventSourceArn')
raise RESTError("ResourceNotFoundException", "Invalid EventSourceArn")
def publish_function(self, function_name):
return self._lambdas.publish_function(function_name)
@ -783,13 +839,15 @@ class LambdaBackend(BaseBackend):
def update_event_source_mapping(self, uuid, spec):
esm = self.get_event_source_mapping(uuid)
if esm:
if spec.get('FunctionName'):
func = self._lambdas.get_function_by_name_or_arn(spec.get('FunctionName'))
if spec.get("FunctionName"):
func = self._lambdas.get_function_by_name_or_arn(
spec.get("FunctionName")
)
esm.function_arn = func.function_arn
if 'BatchSize' in spec:
esm.batch_size = spec['BatchSize']
if 'Enabled' in spec:
esm.enabled = spec['Enabled']
if "BatchSize" in spec:
esm.batch_size = spec["BatchSize"]
if "Enabled" in spec:
esm.enabled = spec["Enabled"]
return esm
return False
@ -830,13 +888,13 @@ class LambdaBackend(BaseBackend):
"ApproximateReceiveCount": "1",
"SentTimestamp": "1545082649183",
"SenderId": "AIDAIENQZJOLO23YVJ4VO",
"ApproximateFirstReceiveTimestamp": "1545082649185"
"ApproximateFirstReceiveTimestamp": "1545082649185",
},
"messageAttributes": {},
"md5OfBody": "098f6bcd4621d373cade4e832627b4f6",
"eventSource": "aws:sqs",
"eventSourceARN": queue_arn,
"awsRegion": self.region_name
"awsRegion": self.region_name,
}
]
}
@ -844,7 +902,7 @@ class LambdaBackend(BaseBackend):
request_headers = {}
response_headers = {}
func.invoke(json.dumps(event), request_headers, response_headers)
return 'x-amz-function-error' not in response_headers
return "x-amz-function-error" not in response_headers
def send_sns_message(self, function_name, message, subject=None, qualifier=None):
event = {
@ -861,37 +919,35 @@ class LambdaBackend(BaseBackend):
"MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
"Message": message,
"MessageAttributes": {
"Test": {
"Type": "String",
"Value": "TestString"
},
"TestBinary": {
"Type": "Binary",
"Value": "TestBinary"
}
"Test": {"Type": "String", "Value": "TestString"},
"TestBinary": {"Type": "Binary", "Value": "TestBinary"},
},
"Type": "Notification",
"UnsubscribeUrl": "EXAMPLE",
"TopicArn": "arn:aws:sns:EXAMPLE",
"Subject": subject or "TestInvoke"
}
"Subject": subject or "TestInvoke",
},
}
]
}
func = self._lambdas.get_function_by_name_or_arn(function_name, qualifier)
func.invoke(json.dumps(event), {}, {})
def send_dynamodb_items(self, function_arn, items, source):
event = {'Records': [
{
'eventID': item.to_json()['eventID'],
'eventName': 'INSERT',
'eventVersion': item.to_json()['eventVersion'],
'eventSource': item.to_json()['eventSource'],
'awsRegion': self.region_name,
'dynamodb': item.to_json()['dynamodb'],
'eventSourceARN': source} for item in items]}
event = {
"Records": [
{
"eventID": item.to_json()["eventID"],
"eventName": "INSERT",
"eventVersion": item.to_json()["eventVersion"],
"eventSource": item.to_json()["eventSource"],
"awsRegion": self.region_name,
"dynamodb": item.to_json()["dynamodb"],
"eventSourceARN": source,
}
for item in items
]
}
func = self._lambdas.get_arn(function_arn)
func.invoke(json.dumps(event), {}, {})
@ -923,12 +979,13 @@ class LambdaBackend(BaseBackend):
def do_validate_s3():
return os.environ.get('VALIDATE_LAMBDA_S3', '') in ['', '1', 'true']
return os.environ.get("VALIDATE_LAMBDA_S3", "") in ["", "1", "true"]
# Handle us forgotten regions, unless Lambda truly only runs out of US and
lambda_backends = {_region.name: LambdaBackend(_region.name)
for _region in boto.awslambda.regions()}
lambda_backends = {
_region.name: LambdaBackend(_region.name) for _region in boto.awslambda.regions()
}
lambda_backends['ap-southeast-2'] = LambdaBackend('ap-southeast-2')
lambda_backends['us-gov-west-1'] = LambdaBackend('us-gov-west-1')
lambda_backends["ap-southeast-2"] = LambdaBackend("ap-southeast-2")
lambda_backends["us-gov-west-1"] = LambdaBackend("us-gov-west-1")

View File

@ -32,57 +32,57 @@ class LambdaResponse(BaseResponse):
def root(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
if request.method == "GET":
return self._list_functions(request, full_url, headers)
elif request.method == 'POST':
elif request.method == "POST":
return self._create_function(request, full_url, headers)
else:
raise ValueError("Cannot handle request")
def event_source_mappings(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
if request.method == "GET":
querystring = self.querystring
event_source_arn = querystring.get('EventSourceArn', [None])[0]
function_name = querystring.get('FunctionName', [None])[0]
event_source_arn = querystring.get("EventSourceArn", [None])[0]
function_name = querystring.get("FunctionName", [None])[0]
return self._list_event_source_mappings(event_source_arn, function_name)
elif request.method == 'POST':
elif request.method == "POST":
return self._create_event_source_mapping(request, full_url, headers)
else:
raise ValueError("Cannot handle request")
def event_source_mapping(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
path = request.path if hasattr(request, 'path') else path_url(request.url)
uuid = path.split('/')[-1]
if request.method == 'GET':
path = request.path if hasattr(request, "path") else path_url(request.url)
uuid = path.split("/")[-1]
if request.method == "GET":
return self._get_event_source_mapping(uuid)
elif request.method == 'PUT':
elif request.method == "PUT":
return self._update_event_source_mapping(uuid)
elif request.method == 'DELETE':
elif request.method == "DELETE":
return self._delete_event_source_mapping(uuid)
else:
raise ValueError("Cannot handle request")
def function(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
if request.method == "GET":
return self._get_function(request, full_url, headers)
elif request.method == 'DELETE':
elif request.method == "DELETE":
return self._delete_function(request, full_url, headers)
else:
raise ValueError("Cannot handle request")
def versions(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
if request.method == "GET":
# This is ListVersionByFunction
path = request.path if hasattr(request, 'path') else path_url(request.url)
function_name = path.split('/')[-2]
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2]
return self._list_versions_by_function(function_name)
elif request.method == 'POST':
elif request.method == "POST":
return self._publish_function(request, full_url, headers)
else:
raise ValueError("Cannot handle request")
@ -91,7 +91,7 @@ class LambdaResponse(BaseResponse):
@amzn_request_id
def invoke(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'POST':
if request.method == "POST":
return self._invoke(request, full_url)
else:
raise ValueError("Cannot handle request")
@ -100,46 +100,46 @@ class LambdaResponse(BaseResponse):
@amzn_request_id
def invoke_async(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'POST':
if request.method == "POST":
return self._invoke_async(request, full_url)
else:
raise ValueError("Cannot handle request")
def tag(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
if request.method == "GET":
return self._list_tags(request, full_url)
elif request.method == 'POST':
elif request.method == "POST":
return self._tag_resource(request, full_url)
elif request.method == 'DELETE':
elif request.method == "DELETE":
return self._untag_resource(request, full_url)
else:
raise ValueError("Cannot handle {0} request".format(request.method))
def policy(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'GET':
if request.method == "GET":
return self._get_policy(request, full_url, headers)
if request.method == 'POST':
if request.method == "POST":
return self._add_policy(request, full_url, headers)
def configuration(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'PUT':
if request.method == "PUT":
return self._put_configuration(request)
else:
raise ValueError("Cannot handle request")
def code(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
if request.method == 'PUT':
if request.method == "PUT":
return self._put_code()
else:
raise ValueError("Cannot handle request")
def _add_policy(self, request, full_url, headers):
path = request.path if hasattr(request, 'path') else path_url(request.url)
function_name = path.split('/')[-2]
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name):
policy = self.body
self.lambda_backend.add_policy(function_name, policy)
@ -148,24 +148,30 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}"
def _get_policy(self, request, full_url, headers):
path = request.path if hasattr(request, 'path') else path_url(request.url)
function_name = path.split('/')[-2]
path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name):
lambda_function = self.lambda_backend.get_function(function_name)
return 200, {}, json.dumps(dict(Policy="{\"Statement\":[" + lambda_function.policy + "]}"))
return (
200,
{},
json.dumps(
dict(Policy='{"Statement":[' + lambda_function.policy + "]}")
),
)
else:
return 404, {}, "{}"
def _invoke(self, request, full_url):
response_headers = {}
function_name = self.path.rsplit('/', 2)[-2]
qualifier = self._get_param('qualifier')
function_name = self.path.rsplit("/", 2)[-2]
qualifier = self._get_param("qualifier")
fn = self.lambda_backend.get_function(function_name, qualifier)
if fn:
payload = fn.invoke(self.body, self.headers, response_headers)
response_headers['Content-Length'] = str(len(payload))
response_headers["Content-Length"] = str(len(payload))
return 202, response_headers, payload
else:
return 404, response_headers, "{}"
@ -173,38 +179,34 @@ class LambdaResponse(BaseResponse):
def _invoke_async(self, request, full_url):
response_headers = {}
function_name = self.path.rsplit('/', 3)[-3]
function_name = self.path.rsplit("/", 3)[-3]
fn = self.lambda_backend.get_function(function_name, None)
if fn:
payload = fn.invoke(self.body, self.headers, response_headers)
response_headers['Content-Length'] = str(len(payload))
response_headers["Content-Length"] = str(len(payload))
return 202, response_headers, payload
else:
return 404, response_headers, "{}"
def _list_functions(self, request, full_url, headers):
result = {
'Functions': []
}
result = {"Functions": []}
for fn in self.lambda_backend.list_functions():
json_data = fn.get_configuration()
json_data['Version'] = '$LATEST'
result['Functions'].append(json_data)
json_data["Version"] = "$LATEST"
result["Functions"].append(json_data)
return 200, {}, json.dumps(result)
def _list_versions_by_function(self, function_name):
result = {
'Versions': []
}
result = {"Versions": []}
functions = self.lambda_backend.list_versions_by_function(function_name)
if functions:
for fn in functions:
json_data = fn.get_configuration()
result['Versions'].append(json_data)
result["Versions"].append(json_data)
return 200, {}, json.dumps(result)
@ -212,7 +214,11 @@ class LambdaResponse(BaseResponse):
try:
fn = self.lambda_backend.create_function(self.json_body)
except ValueError as e:
return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}})
return (
400,
{},
json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}),
)
else:
config = fn.get_configuration()
return 201, {}, json.dumps(config)
@ -221,16 +227,20 @@ class LambdaResponse(BaseResponse):
try:
fn = self.lambda_backend.create_event_source_mapping(self.json_body)
except ValueError as e:
return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}})
return (
400,
{},
json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}),
)
else:
config = fn.get_configuration()
return 201, {}, json.dumps(config)
def _list_event_source_mappings(self, event_source_arn, function_name):
esms = self.lambda_backend.list_event_source_mappings(event_source_arn, function_name)
result = {
'EventSourceMappings': [esm.get_configuration() for esm in esms]
}
esms = self.lambda_backend.list_event_source_mappings(
event_source_arn, function_name
)
result = {"EventSourceMappings": [esm.get_configuration() for esm in esms]}
return 200, {}, json.dumps(result)
def _get_event_source_mapping(self, uuid):
@ -251,13 +261,13 @@ class LambdaResponse(BaseResponse):
esm = self.lambda_backend.delete_event_source_mapping(uuid)
if esm:
json_result = esm.get_configuration()
json_result.update({'State': 'Deleting'})
json_result.update({"State": "Deleting"})
return 202, {}, json.dumps(json_result)
else:
return 404, {}, "{}"
def _publish_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 2)[-2]
function_name = self.path.rsplit("/", 2)[-2]
fn = self.lambda_backend.publish_function(function_name)
if fn:
@ -267,8 +277,8 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}"
def _delete_function(self, request, full_url, headers):
function_name = unquote(self.path.rsplit('/', 1)[-1])
qualifier = self._get_param('Qualifier', None)
function_name = unquote(self.path.rsplit("/", 1)[-1])
qualifier = self._get_param("Qualifier", None)
if self.lambda_backend.delete_function(function_name, qualifier):
return 204, {}, ""
@ -276,17 +286,17 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}"
def _get_function(self, request, full_url, headers):
function_name = unquote(self.path.rsplit('/', 1)[-1])
qualifier = self._get_param('Qualifier', None)
function_name = unquote(self.path.rsplit("/", 1)[-1])
qualifier = self._get_param("Qualifier", None)
fn = self.lambda_backend.get_function(function_name, qualifier)
if fn:
code = fn.get_code()
if qualifier is None or qualifier == '$LATEST':
code['Configuration']['Version'] = '$LATEST'
if qualifier == '$LATEST':
code['Configuration']['FunctionArn'] += ':$LATEST'
if qualifier is None or qualifier == "$LATEST":
code["Configuration"]["Version"] = "$LATEST"
if qualifier == "$LATEST":
code["Configuration"]["FunctionArn"] += ":$LATEST"
return 200, {}, json.dumps(code)
else:
return 404, {}, "{}"
@ -299,25 +309,25 @@ class LambdaResponse(BaseResponse):
return self.default_region
def _list_tags(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1])
function_arn = unquote(self.path.rsplit("/", 1)[-1])
fn = self.lambda_backend.get_function_by_arn(function_arn)
if fn:
return 200, {}, json.dumps({'Tags': fn.tags})
return 200, {}, json.dumps({"Tags": fn.tags})
else:
return 404, {}, "{}"
def _tag_resource(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1])
function_arn = unquote(self.path.rsplit("/", 1)[-1])
if self.lambda_backend.tag_resource(function_arn, self.json_body['Tags']):
if self.lambda_backend.tag_resource(function_arn, self.json_body["Tags"]):
return 200, {}, "{}"
else:
return 404, {}, "{}"
def _untag_resource(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1])
tag_keys = self.querystring['tagKeys']
function_arn = unquote(self.path.rsplit("/", 1)[-1])
tag_keys = self.querystring["tagKeys"]
if self.lambda_backend.untag_resource(function_arn, tag_keys):
return 204, {}, "{}"
@ -325,8 +335,8 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}"
def _put_configuration(self, request):
function_name = self.path.rsplit('/', 2)[-2]
qualifier = self._get_param('Qualifier', None)
function_name = self.path.rsplit("/", 2)[-2]
qualifier = self._get_param("Qualifier", None)
fn = self.lambda_backend.get_function(function_name, qualifier)
@ -337,13 +347,13 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}"
def _put_code(self):
function_name = self.path.rsplit('/', 2)[-2]
qualifier = self._get_param('Qualifier', None)
function_name = self.path.rsplit("/", 2)[-2]
qualifier = self._get_param("Qualifier", None)
fn = self.lambda_backend.get_function(function_name, qualifier)
if fn:
if self.json_body.get('Publish', False):
if self.json_body.get("Publish", False):
fn = self.lambda_backend.publish_function(function_name)
config = fn.update_function_code(self.json_body)

View File

@ -1,22 +1,20 @@
from __future__ import unicode_literals
from .responses import LambdaResponse
url_bases = [
"https?://lambda.(.+).amazonaws.com",
]
url_bases = ["https?://lambda.(.+).amazonaws.com"]
response = LambdaResponse()
url_paths = {
'{0}/(?P<api_version>[^/]+)/functions/?$': response.root,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$': response.function,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$': response.versions,
r'{0}/(?P<api_version>[^/]+)/event-source-mappings/?$': response.event_source_mappings,
r'{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$': response.event_source_mapping,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async,
r'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$': response.policy,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$': response.configuration,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$': response.code
"{0}/(?P<api_version>[^/]+)/functions/?$": response.root,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions,
r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings,
r"{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$": response.event_source_mapping,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async,
r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration,
r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code,
}

View File

@ -1,20 +1,20 @@
from collections import namedtuple
ARN = namedtuple('ARN', ['region', 'account', 'function_name', 'version'])
ARN = namedtuple("ARN", ["region", "account", "function_name", "version"])
def make_function_arn(region, account, name):
return 'arn:aws:lambda:{0}:{1}:function:{2}'.format(region, account, name)
return "arn:aws:lambda:{0}:{1}:function:{2}".format(region, account, name)
def make_function_ver_arn(region, account, name, version='1'):
def make_function_ver_arn(region, account, name, version="1"):
arn = make_function_arn(region, account, name)
return '{0}:{1}'.format(arn, version)
return "{0}:{1}".format(arn, version)
def split_function_arn(arn):
arn = arn.replace('arn:aws:lambda:')
arn = arn.replace("arn:aws:lambda:")
region, account, _, name, version = arn.split(':')
region, account, _, name, version = arn.split(":")
return ARN(region, account, name, version)

View File

@ -52,57 +52,57 @@ from moto.resourcegroupstaggingapi import resourcegroupstaggingapi_backends
from moto.config import config_backends
BACKENDS = {
'acm': acm_backends,
'apigateway': apigateway_backends,
'athena': athena_backends,
'autoscaling': autoscaling_backends,
'batch': batch_backends,
'cloudformation': cloudformation_backends,
'cloudwatch': cloudwatch_backends,
'cognito-identity': cognitoidentity_backends,
'cognito-idp': cognitoidp_backends,
'config': config_backends,
'datapipeline': datapipeline_backends,
'dynamodb': dynamodb_backends,
'dynamodb2': dynamodb_backends2,
'dynamodbstreams': dynamodbstreams_backends,
'ec2': ec2_backends,
'ecr': ecr_backends,
'ecs': ecs_backends,
'elb': elb_backends,
'elbv2': elbv2_backends,
'events': events_backends,
'emr': emr_backends,
'glacier': glacier_backends,
'glue': glue_backends,
'iam': iam_backends,
'moto_api': moto_api_backends,
'instance_metadata': instance_metadata_backends,
'logs': logs_backends,
'kinesis': kinesis_backends,
'kms': kms_backends,
'opsworks': opsworks_backends,
'organizations': organizations_backends,
'polly': polly_backends,
'redshift': redshift_backends,
'resource-groups': resourcegroups_backends,
'rds': rds2_backends,
's3': s3_backends,
's3bucket_path': s3_backends,
'ses': ses_backends,
'secretsmanager': secretsmanager_backends,
'sns': sns_backends,
'sqs': sqs_backends,
'ssm': ssm_backends,
'stepfunctions': stepfunction_backends,
'sts': sts_backends,
'swf': swf_backends,
'route53': route53_backends,
'lambda': lambda_backends,
'xray': xray_backends,
'resourcegroupstaggingapi': resourcegroupstaggingapi_backends,
'iot': iot_backends,
'iot-data': iotdata_backends,
"acm": acm_backends,
"apigateway": apigateway_backends,
"athena": athena_backends,
"autoscaling": autoscaling_backends,
"batch": batch_backends,
"cloudformation": cloudformation_backends,
"cloudwatch": cloudwatch_backends,
"cognito-identity": cognitoidentity_backends,
"cognito-idp": cognitoidp_backends,
"config": config_backends,
"datapipeline": datapipeline_backends,
"dynamodb": dynamodb_backends,
"dynamodb2": dynamodb_backends2,
"dynamodbstreams": dynamodbstreams_backends,
"ec2": ec2_backends,
"ecr": ecr_backends,
"ecs": ecs_backends,
"elb": elb_backends,
"elbv2": elbv2_backends,
"events": events_backends,
"emr": emr_backends,
"glacier": glacier_backends,
"glue": glue_backends,
"iam": iam_backends,
"moto_api": moto_api_backends,
"instance_metadata": instance_metadata_backends,
"logs": logs_backends,
"kinesis": kinesis_backends,
"kms": kms_backends,
"opsworks": opsworks_backends,
"organizations": organizations_backends,
"polly": polly_backends,
"redshift": redshift_backends,
"resource-groups": resourcegroups_backends,
"rds": rds2_backends,
"s3": s3_backends,
"s3bucket_path": s3_backends,
"ses": ses_backends,
"secretsmanager": secretsmanager_backends,
"sns": sns_backends,
"sqs": sqs_backends,
"ssm": ssm_backends,
"stepfunctions": stepfunction_backends,
"sts": sts_backends,
"swf": swf_backends,
"route53": route53_backends,
"lambda": lambda_backends,
"xray": xray_backends,
"resourcegroupstaggingapi": resourcegroupstaggingapi_backends,
"iot": iot_backends,
"iot-data": iotdata_backends,
}
@ -110,6 +110,6 @@ def get_model(name, region_name):
for backends in BACKENDS.values():
for region, backend in backends.items():
if region == region_name:
models = getattr(backend.__class__, '__models__', {})
models = getattr(backend.__class__, "__models__", {})
if name in models:
return list(getattr(backend, models[name])())

View File

@ -2,5 +2,5 @@ from __future__ import unicode_literals
from .models import batch_backends
from ..core.models import base_decorator
batch_backend = batch_backends['us-east-1']
batch_backend = batch_backends["us-east-1"]
mock_batch = base_decorator(batch_backends)

View File

@ -12,26 +12,29 @@ class AWSError(Exception):
self.status = status if status is not None else self.STATUS
def response(self):
return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status)
return (
json.dumps({"__type": self.code, "message": self.message}),
dict(status=self.status),
)
class InvalidRequestException(AWSError):
CODE = 'InvalidRequestException'
CODE = "InvalidRequestException"
class InvalidParameterValueException(AWSError):
CODE = 'InvalidParameterValue'
CODE = "InvalidParameterValue"
class ValidationError(AWSError):
CODE = 'ValidationError'
CODE = "ValidationError"
class InternalFailure(AWSError):
CODE = 'InternalFailure'
CODE = "InternalFailure"
STATUS = 500
class ClientException(AWSError):
CODE = 'ClientException'
CODE = "ClientException"
STATUS = 400

File diff suppressed because it is too large Load Diff

View File

@ -10,7 +10,7 @@ import json
class BatchResponse(BaseResponse):
def _error(self, code, message):
return json.dumps({'__type': code, 'message': message}), dict(status=400)
return json.dumps({"__type": code, "message": message}), dict(status=400)
@property
def batch_backend(self):
@ -22,9 +22,9 @@ class BatchResponse(BaseResponse):
@property
def json(self):
if self.body is None or self.body == '':
if self.body is None or self.body == "":
self._json = {}
elif not hasattr(self, '_json'):
elif not hasattr(self, "_json"):
try:
self._json = json.loads(self.body)
except ValueError:
@ -39,153 +39,146 @@ class BatchResponse(BaseResponse):
def _get_action(self):
# Return element after the /v1/*
return urlsplit(self.uri).path.lstrip('/').split('/')[1]
return urlsplit(self.uri).path.lstrip("/").split("/")[1]
# CreateComputeEnvironment
def createcomputeenvironment(self):
compute_env_name = self._get_param('computeEnvironmentName')
compute_resource = self._get_param('computeResources')
service_role = self._get_param('serviceRole')
state = self._get_param('state')
_type = self._get_param('type')
compute_env_name = self._get_param("computeEnvironmentName")
compute_resource = self._get_param("computeResources")
service_role = self._get_param("serviceRole")
state = self._get_param("state")
_type = self._get_param("type")
try:
name, arn = self.batch_backend.create_compute_environment(
compute_environment_name=compute_env_name,
_type=_type, state=state,
_type=_type,
state=state,
compute_resources=compute_resource,
service_role=service_role
service_role=service_role,
)
except AWSError as err:
return err.response()
result = {
'computeEnvironmentArn': arn,
'computeEnvironmentName': name
}
result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
return json.dumps(result)
# DescribeComputeEnvironments
def describecomputeenvironments(self):
compute_environments = self._get_param('computeEnvironments')
max_results = self._get_param('maxResults') # Ignored, should be int
next_token = self._get_param('nextToken') # Ignored
compute_environments = self._get_param("computeEnvironments")
max_results = self._get_param("maxResults") # Ignored, should be int
next_token = self._get_param("nextToken") # Ignored
envs = self.batch_backend.describe_compute_environments(compute_environments, max_results=max_results, next_token=next_token)
envs = self.batch_backend.describe_compute_environments(
compute_environments, max_results=max_results, next_token=next_token
)
result = {'computeEnvironments': envs}
result = {"computeEnvironments": envs}
return json.dumps(result)
# DeleteComputeEnvironment
def deletecomputeenvironment(self):
compute_environment = self._get_param('computeEnvironment')
compute_environment = self._get_param("computeEnvironment")
try:
self.batch_backend.delete_compute_environment(compute_environment)
except AWSError as err:
return err.response()
return ''
return ""
# UpdateComputeEnvironment
def updatecomputeenvironment(self):
compute_env_name = self._get_param('computeEnvironment')
compute_resource = self._get_param('computeResources')
service_role = self._get_param('serviceRole')
state = self._get_param('state')
compute_env_name = self._get_param("computeEnvironment")
compute_resource = self._get_param("computeResources")
service_role = self._get_param("serviceRole")
state = self._get_param("state")
try:
name, arn = self.batch_backend.update_compute_environment(
compute_environment_name=compute_env_name,
compute_resources=compute_resource,
service_role=service_role,
state=state
state=state,
)
except AWSError as err:
return err.response()
result = {
'computeEnvironmentArn': arn,
'computeEnvironmentName': name
}
result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
return json.dumps(result)
# CreateJobQueue
def createjobqueue(self):
compute_env_order = self._get_param('computeEnvironmentOrder')
queue_name = self._get_param('jobQueueName')
priority = self._get_param('priority')
state = self._get_param('state')
compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param("jobQueueName")
priority = self._get_param("priority")
state = self._get_param("state")
try:
name, arn = self.batch_backend.create_job_queue(
queue_name=queue_name,
priority=priority,
state=state,
compute_env_order=compute_env_order
compute_env_order=compute_env_order,
)
except AWSError as err:
return err.response()
result = {
'jobQueueArn': arn,
'jobQueueName': name
}
result = {"jobQueueArn": arn, "jobQueueName": name}
return json.dumps(result)
# DescribeJobQueues
def describejobqueues(self):
job_queues = self._get_param('jobQueues')
max_results = self._get_param('maxResults') # Ignored, should be int
next_token = self._get_param('nextToken') # Ignored
job_queues = self._get_param("jobQueues")
max_results = self._get_param("maxResults") # Ignored, should be int
next_token = self._get_param("nextToken") # Ignored
queues = self.batch_backend.describe_job_queues(job_queues, max_results=max_results, next_token=next_token)
queues = self.batch_backend.describe_job_queues(
job_queues, max_results=max_results, next_token=next_token
)
result = {'jobQueues': queues}
result = {"jobQueues": queues}
return json.dumps(result)
# UpdateJobQueue
def updatejobqueue(self):
compute_env_order = self._get_param('computeEnvironmentOrder')
queue_name = self._get_param('jobQueue')
priority = self._get_param('priority')
state = self._get_param('state')
compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param("jobQueue")
priority = self._get_param("priority")
state = self._get_param("state")
try:
name, arn = self.batch_backend.update_job_queue(
queue_name=queue_name,
priority=priority,
state=state,
compute_env_order=compute_env_order
compute_env_order=compute_env_order,
)
except AWSError as err:
return err.response()
result = {
'jobQueueArn': arn,
'jobQueueName': name
}
result = {"jobQueueArn": arn, "jobQueueName": name}
return json.dumps(result)
# DeleteJobQueue
def deletejobqueue(self):
queue_name = self._get_param('jobQueue')
queue_name = self._get_param("jobQueue")
self.batch_backend.delete_job_queue(queue_name)
return ''
return ""
# RegisterJobDefinition
def registerjobdefinition(self):
container_properties = self._get_param('containerProperties')
def_name = self._get_param('jobDefinitionName')
parameters = self._get_param('parameters')
retry_strategy = self._get_param('retryStrategy')
_type = self._get_param('type')
container_properties = self._get_param("containerProperties")
def_name = self._get_param("jobDefinitionName")
parameters = self._get_param("parameters")
retry_strategy = self._get_param("retryStrategy")
_type = self._get_param("type")
try:
name, arn, revision = self.batch_backend.register_job_definition(
@ -193,104 +186,113 @@ class BatchResponse(BaseResponse):
parameters=parameters,
_type=_type,
retry_strategy=retry_strategy,
container_properties=container_properties
container_properties=container_properties,
)
except AWSError as err:
return err.response()
result = {
'jobDefinitionArn': arn,
'jobDefinitionName': name,
'revision': revision
"jobDefinitionArn": arn,
"jobDefinitionName": name,
"revision": revision,
}
return json.dumps(result)
# DeregisterJobDefinition
def deregisterjobdefinition(self):
queue_name = self._get_param('jobDefinition')
queue_name = self._get_param("jobDefinition")
self.batch_backend.deregister_job_definition(queue_name)
return ''
return ""
# DescribeJobDefinitions
def describejobdefinitions(self):
job_def_name = self._get_param('jobDefinitionName')
job_def_list = self._get_param('jobDefinitions')
max_results = self._get_param('maxResults')
next_token = self._get_param('nextToken')
status = self._get_param('status')
job_def_name = self._get_param("jobDefinitionName")
job_def_list = self._get_param("jobDefinitions")
max_results = self._get_param("maxResults")
next_token = self._get_param("nextToken")
status = self._get_param("status")
job_defs = self.batch_backend.describe_job_definitions(job_def_name, job_def_list, status, max_results, next_token)
job_defs = self.batch_backend.describe_job_definitions(
job_def_name, job_def_list, status, max_results, next_token
)
result = {'jobDefinitions': [job.describe() for job in job_defs]}
result = {"jobDefinitions": [job.describe() for job in job_defs]}
return json.dumps(result)
# SubmitJob
def submitjob(self):
container_overrides = self._get_param('containerOverrides')
depends_on = self._get_param('dependsOn')
job_def = self._get_param('jobDefinition')
job_name = self._get_param('jobName')
job_queue = self._get_param('jobQueue')
parameters = self._get_param('parameters')
retries = self._get_param('retryStrategy')
container_overrides = self._get_param("containerOverrides")
depends_on = self._get_param("dependsOn")
job_def = self._get_param("jobDefinition")
job_name = self._get_param("jobName")
job_queue = self._get_param("jobQueue")
parameters = self._get_param("parameters")
retries = self._get_param("retryStrategy")
try:
name, job_id = self.batch_backend.submit_job(
job_name, job_def, job_queue,
job_name,
job_def,
job_queue,
parameters=parameters,
retries=retries,
depends_on=depends_on,
container_overrides=container_overrides
container_overrides=container_overrides,
)
except AWSError as err:
return err.response()
result = {
'jobId': job_id,
'jobName': name,
}
result = {"jobId": job_id, "jobName": name}
return json.dumps(result)
# DescribeJobs
def describejobs(self):
jobs = self._get_param('jobs')
jobs = self._get_param("jobs")
try:
return json.dumps({'jobs': self.batch_backend.describe_jobs(jobs)})
return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)})
except AWSError as err:
return err.response()
# ListJobs
def listjobs(self):
job_queue = self._get_param('jobQueue')
job_status = self._get_param('jobStatus')
max_results = self._get_param('maxResults')
next_token = self._get_param('nextToken')
job_queue = self._get_param("jobQueue")
job_status = self._get_param("jobStatus")
max_results = self._get_param("maxResults")
next_token = self._get_param("nextToken")
try:
jobs = self.batch_backend.list_jobs(job_queue, job_status, max_results, next_token)
jobs = self.batch_backend.list_jobs(
job_queue, job_status, max_results, next_token
)
except AWSError as err:
return err.response()
result = {'jobSummaryList': [{'jobId': job.job_id, 'jobName': job.job_name} for job in jobs]}
result = {
"jobSummaryList": [
{"jobId": job.job_id, "jobName": job.job_name} for job in jobs
]
}
return json.dumps(result)
# TerminateJob
def terminatejob(self):
job_id = self._get_param('jobId')
reason = self._get_param('reason')
job_id = self._get_param("jobId")
reason = self._get_param("reason")
try:
self.batch_backend.terminate_job(job_id, reason)
except AWSError as err:
return err.response()
return ''
return ""
# CancelJob
def canceljob(self): # Theres some AWS semantics on the differences but for us they're identical ;-)
def canceljob(
self,
): # Theres some AWS semantics on the differences but for us they're identical ;-)
return self.terminatejob()

View File

@ -1,25 +1,23 @@
from __future__ import unicode_literals
from .responses import BatchResponse
url_bases = [
"https?://batch.(.+).amazonaws.com",
]
url_bases = ["https?://batch.(.+).amazonaws.com"]
url_paths = {
'{0}/v1/createcomputeenvironment$': BatchResponse.dispatch,
'{0}/v1/describecomputeenvironments$': BatchResponse.dispatch,
'{0}/v1/deletecomputeenvironment': BatchResponse.dispatch,
'{0}/v1/updatecomputeenvironment': BatchResponse.dispatch,
'{0}/v1/createjobqueue': BatchResponse.dispatch,
'{0}/v1/describejobqueues': BatchResponse.dispatch,
'{0}/v1/updatejobqueue': BatchResponse.dispatch,
'{0}/v1/deletejobqueue': BatchResponse.dispatch,
'{0}/v1/registerjobdefinition': BatchResponse.dispatch,
'{0}/v1/deregisterjobdefinition': BatchResponse.dispatch,
'{0}/v1/describejobdefinitions': BatchResponse.dispatch,
'{0}/v1/submitjob': BatchResponse.dispatch,
'{0}/v1/describejobs': BatchResponse.dispatch,
'{0}/v1/listjobs': BatchResponse.dispatch,
'{0}/v1/terminatejob': BatchResponse.dispatch,
'{0}/v1/canceljob': BatchResponse.dispatch,
"{0}/v1/createcomputeenvironment$": BatchResponse.dispatch,
"{0}/v1/describecomputeenvironments$": BatchResponse.dispatch,
"{0}/v1/deletecomputeenvironment": BatchResponse.dispatch,
"{0}/v1/updatecomputeenvironment": BatchResponse.dispatch,
"{0}/v1/createjobqueue": BatchResponse.dispatch,
"{0}/v1/describejobqueues": BatchResponse.dispatch,
"{0}/v1/updatejobqueue": BatchResponse.dispatch,
"{0}/v1/deletejobqueue": BatchResponse.dispatch,
"{0}/v1/registerjobdefinition": BatchResponse.dispatch,
"{0}/v1/deregisterjobdefinition": BatchResponse.dispatch,
"{0}/v1/describejobdefinitions": BatchResponse.dispatch,
"{0}/v1/submitjob": BatchResponse.dispatch,
"{0}/v1/describejobs": BatchResponse.dispatch,
"{0}/v1/listjobs": BatchResponse.dispatch,
"{0}/v1/terminatejob": BatchResponse.dispatch,
"{0}/v1/canceljob": BatchResponse.dispatch,
}

View File

@ -2,7 +2,9 @@ from __future__ import unicode_literals
def make_arn_for_compute_env(account_id, name, region_name):
return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(region_name, account_id, name)
return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(
region_name, account_id, name
)
def make_arn_for_job_queue(account_id, name, region_name):
@ -10,7 +12,9 @@ def make_arn_for_job_queue(account_id, name, region_name):
def make_arn_for_task_def(account_id, name, revision, region_name):
return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(region_name, account_id, name, revision)
return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(
region_name, account_id, name, revision
)
def lowercase_first_key(some_dict):

View File

@ -2,7 +2,6 @@ from __future__ import unicode_literals
from .models import cloudformation_backends
from ..core.models import base_decorator, deprecated_base_decorator
cloudformation_backend = cloudformation_backends['us-east-1']
cloudformation_backend = cloudformation_backends["us-east-1"]
mock_cloudformation = base_decorator(cloudformation_backends)
mock_cloudformation_deprecated = deprecated_base_decorator(
cloudformation_backends)
mock_cloudformation_deprecated = deprecated_base_decorator(cloudformation_backends)

View File

@ -4,26 +4,23 @@ from jinja2 import Template
class UnformattedGetAttTemplateException(Exception):
description = 'Template error: resource {0} does not support attribute type {1} in Fn::GetAtt'
description = (
"Template error: resource {0} does not support attribute type {1} in Fn::GetAtt"
)
status_code = 400
class ValidationError(BadRequest):
def __init__(self, name_or_id, message=None):
if message is None:
message = "Stack with id {0} does not exist".format(name_or_id)
template = Template(ERROR_RESPONSE)
super(ValidationError, self).__init__()
self.description = template.render(
code="ValidationError",
message=message,
)
self.description = template.render(code="ValidationError", message=message)
class MissingParameterError(BadRequest):
def __init__(self, parameter_name):
template = Template(ERROR_RESPONSE)
super(MissingParameterError, self).__init__()
@ -40,8 +37,8 @@ class ExportNotFound(BadRequest):
template = Template(ERROR_RESPONSE)
super(ExportNotFound, self).__init__()
self.description = template.render(
code='ExportNotFound',
message="No export named {0} found.".format(export_name)
code="ExportNotFound",
message="No export named {0} found.".format(export_name),
)

View File

@ -21,11 +21,19 @@ from .exceptions import ValidationError
class FakeStackSet(BaseModel):
def __init__(self, stackset_id, name, template, region='us-east-1',
status='ACTIVE', description=None, parameters=None, tags=None,
admin_role='AWSCloudFormationStackSetAdministrationRole',
execution_role='AWSCloudFormationStackSetExecutionRole'):
def __init__(
self,
stackset_id,
name,
template,
region="us-east-1",
status="ACTIVE",
description=None,
parameters=None,
tags=None,
admin_role="AWSCloudFormationStackSetAdministrationRole",
execution_role="AWSCloudFormationStackSetExecutionRole",
):
self.id = stackset_id
self.arn = generate_stackset_arn(stackset_id, region)
self.name = name
@ -42,12 +50,14 @@ class FakeStackSet(BaseModel):
def _create_operation(self, operation_id, action, status, accounts=[], regions=[]):
operation = {
'OperationId': str(operation_id),
'Action': action,
'Status': status,
'CreationTimestamp': datetime.now(),
'EndTimestamp': datetime.now() + timedelta(minutes=2),
'Instances': [{account: region} for account in accounts for region in regions],
"OperationId": str(operation_id),
"Action": action,
"Status": status,
"CreationTimestamp": datetime.now(),
"EndTimestamp": datetime.now() + timedelta(minutes=2),
"Instances": [
{account: region} for account in accounts for region in regions
],
}
self.operations += [operation]
@ -55,20 +65,30 @@ class FakeStackSet(BaseModel):
def get_operation(self, operation_id):
for operation in self.operations:
if operation_id == operation['OperationId']:
if operation_id == operation["OperationId"]:
return operation
raise ValidationError(operation_id)
def update_operation(self, operation_id, status):
operation = self.get_operation(operation_id)
operation['Status'] = status
operation["Status"] = status
return operation_id
def delete(self):
self.status = 'DELETED'
self.status = "DELETED"
def update(self, template, description, parameters, tags, admin_role,
execution_role, accounts, regions, operation_id=None):
def update(
self,
template,
description,
parameters,
tags,
admin_role,
execution_role,
accounts,
regions,
operation_id=None,
):
if not operation_id:
operation_id = uuid.uuid4()
@ -82,9 +102,13 @@ class FakeStackSet(BaseModel):
if accounts and regions:
self.update_instances(accounts, regions, self.parameters)
operation = self._create_operation(operation_id=operation_id,
action='UPDATE', status='SUCCEEDED', accounts=accounts,
regions=regions)
operation = self._create_operation(
operation_id=operation_id,
action="UPDATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation
def create_stack_instances(self, accounts, regions, parameters, operation_id=None):
@ -94,8 +118,13 @@ class FakeStackSet(BaseModel):
parameters = self.parameters
self.instances.create_instances(accounts, regions, parameters, operation_id)
self._create_operation(operation_id=operation_id, action='CREATE',
status='SUCCEEDED', accounts=accounts, regions=regions)
self._create_operation(
operation_id=operation_id,
action="CREATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
def delete_stack_instances(self, accounts, regions, operation_id=None):
if not operation_id:
@ -103,8 +132,13 @@ class FakeStackSet(BaseModel):
self.instances.delete(accounts, regions)
operation = self._create_operation(operation_id=operation_id, action='DELETE',
status='SUCCEEDED', accounts=accounts, regions=regions)
operation = self._create_operation(
operation_id=operation_id,
action="DELETE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation
def update_instances(self, accounts, regions, parameters, operation_id=None):
@ -112,9 +146,13 @@ class FakeStackSet(BaseModel):
operation_id = uuid.uuid4()
self.instances.update(accounts, regions, parameters)
operation = self._create_operation(operation_id=operation_id,
action='UPDATE', status='SUCCEEDED', accounts=accounts,
regions=regions)
operation = self._create_operation(
operation_id=operation_id,
action="UPDATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation
@ -131,12 +169,12 @@ class FakeStackInstances(BaseModel):
for region in regions:
for account in accounts:
instance = {
'StackId': generate_stack_id(self.stack_name, region, account),
'StackSetId': self.stackset_id,
'Region': region,
'Account': account,
'Status': "CURRENT",
'ParameterOverrides': parameters if parameters else [],
"StackId": generate_stack_id(self.stack_name, region, account),
"StackSetId": self.stackset_id,
"Region": region,
"Account": account,
"Status": "CURRENT",
"ParameterOverrides": parameters if parameters else [],
}
new_instances.append(instance)
self.stack_instances += new_instances
@ -147,24 +185,35 @@ class FakeStackInstances(BaseModel):
for region in regions:
instance = self.get_instance(account, region)
if parameters:
instance['ParameterOverrides'] = parameters
instance["ParameterOverrides"] = parameters
else:
instance['ParameterOverrides'] = []
instance["ParameterOverrides"] = []
def delete(self, accounts, regions):
for i, instance in enumerate(self.stack_instances):
if instance['Region'] in regions and instance['Account'] in accounts:
if instance["Region"] in regions and instance["Account"] in accounts:
self.stack_instances.pop(i)
def get_instance(self, account, region):
for i, instance in enumerate(self.stack_instances):
if instance['Region'] == region and instance['Account'] == account:
if instance["Region"] == region and instance["Account"] == account:
return self.stack_instances[i]
class FakeStack(BaseModel):
def __init__(self, stack_id, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None, create_change_set=False):
def __init__(
self,
stack_id,
name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
cross_stack_resources=None,
create_change_set=False,
):
self.stack_id = stack_id
self.name = name
self.template = template
@ -176,22 +225,31 @@ class FakeStack(BaseModel):
self.tags = tags if tags else {}
self.events = []
if create_change_set:
self._add_stack_event("REVIEW_IN_PROGRESS",
resource_status_reason="User Initiated")
self._add_stack_event(
"REVIEW_IN_PROGRESS", resource_status_reason="User Initiated"
)
else:
self._add_stack_event("CREATE_IN_PROGRESS",
resource_status_reason="User Initiated")
self._add_stack_event(
"CREATE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.description = self.template_dict.get('Description')
self.description = self.template_dict.get("Description")
self.cross_stack_resources = cross_stack_resources or {}
self.resource_map = self._create_resource_map()
self.output_map = self._create_output_map()
self._add_stack_event("CREATE_COMPLETE")
self.status = 'CREATE_COMPLETE'
self.status = "CREATE_COMPLETE"
def _create_resource_map(self):
resource_map = ResourceMap(
self.stack_id, self.name, self.parameters, self.tags, self.region_name, self.template_dict, self.cross_stack_resources)
self.stack_id,
self.name,
self.parameters,
self.tags,
self.region_name,
self.template_dict,
self.cross_stack_resources,
)
resource_map.create()
return resource_map
@ -200,34 +258,46 @@ class FakeStack(BaseModel):
output_map.create()
return output_map
def _add_stack_event(self, resource_status, resource_status_reason=None, resource_properties=None):
self.events.append(FakeEvent(
stack_id=self.stack_id,
stack_name=self.name,
logical_resource_id=self.name,
physical_resource_id=self.stack_id,
resource_type="AWS::CloudFormation::Stack",
resource_status=resource_status,
resource_status_reason=resource_status_reason,
resource_properties=resource_properties,
))
def _add_stack_event(
self, resource_status, resource_status_reason=None, resource_properties=None
):
self.events.append(
FakeEvent(
stack_id=self.stack_id,
stack_name=self.name,
logical_resource_id=self.name,
physical_resource_id=self.stack_id,
resource_type="AWS::CloudFormation::Stack",
resource_status=resource_status,
resource_status_reason=resource_status_reason,
resource_properties=resource_properties,
)
)
def _add_resource_event(self, logical_resource_id, resource_status, resource_status_reason=None, resource_properties=None):
def _add_resource_event(
self,
logical_resource_id,
resource_status,
resource_status_reason=None,
resource_properties=None,
):
# not used yet... feel free to help yourself
resource = self.resource_map[logical_resource_id]
self.events.append(FakeEvent(
stack_id=self.stack_id,
stack_name=self.name,
logical_resource_id=logical_resource_id,
physical_resource_id=resource.physical_resource_id,
resource_type=resource.type,
resource_status=resource_status,
resource_status_reason=resource_status_reason,
resource_properties=resource_properties,
))
self.events.append(
FakeEvent(
stack_id=self.stack_id,
stack_name=self.name,
logical_resource_id=logical_resource_id,
physical_resource_id=resource.physical_resource_id,
resource_type=resource.type,
resource_status=resource_status,
resource_status_reason=resource_status_reason,
resource_properties=resource_properties,
)
)
def _parse_template(self):
yaml.add_multi_constructor('', yaml_tag_constructor)
yaml.add_multi_constructor("", yaml_tag_constructor)
try:
self.template_dict = yaml.load(self.template, Loader=yaml.Loader)
except yaml.parser.ParserError:
@ -250,7 +320,9 @@ class FakeStack(BaseModel):
return self.output_map.exports
def update(self, template, role_arn=None, parameters=None, tags=None):
self._add_stack_event("UPDATE_IN_PROGRESS", resource_status_reason="User Initiated")
self._add_stack_event(
"UPDATE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.template = template
self._parse_template()
self.resource_map.update(self.template_dict, parameters)
@ -264,15 +336,15 @@ class FakeStack(BaseModel):
# TODO: update tags in the resource map
def delete(self):
self._add_stack_event("DELETE_IN_PROGRESS",
resource_status_reason="User Initiated")
self._add_stack_event(
"DELETE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.resource_map.delete()
self._add_stack_event("DELETE_COMPLETE")
self.status = "DELETE_COMPLETE"
class FakeChange(BaseModel):
def __init__(self, action, logical_resource_id, resource_type):
self.action = action
self.logical_resource_id = logical_resource_id
@ -280,8 +352,21 @@ class FakeChange(BaseModel):
class FakeChangeSet(FakeStack):
def __init__(self, stack_id, stack_name, stack_template, change_set_id, change_set_name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None):
def __init__(
self,
stack_id,
stack_name,
stack_template,
change_set_id,
change_set_name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
cross_stack_resources=None,
):
super(FakeChangeSet, self).__init__(
stack_id,
stack_name,
@ -306,17 +391,28 @@ class FakeChangeSet(FakeStack):
resources_by_action = self.resource_map.diff(self.template_dict, parameters)
for action, resources in resources_by_action.items():
for resource_name, resource in resources.items():
changes.append(FakeChange(
action=action,
logical_resource_id=resource_name,
resource_type=resource['ResourceType'],
))
changes.append(
FakeChange(
action=action,
logical_resource_id=resource_name,
resource_type=resource["ResourceType"],
)
)
return changes
class FakeEvent(BaseModel):
def __init__(self, stack_id, stack_name, logical_resource_id, physical_resource_id, resource_type, resource_status, resource_status_reason=None, resource_properties=None):
def __init__(
self,
stack_id,
stack_name,
logical_resource_id,
physical_resource_id,
resource_type,
resource_status,
resource_status_reason=None,
resource_properties=None,
):
self.stack_id = stack_id
self.stack_name = stack_name
self.logical_resource_id = logical_resource_id
@ -330,7 +426,6 @@ class FakeEvent(BaseModel):
class CloudFormationBackend(BaseBackend):
def __init__(self):
self.stacks = OrderedDict()
self.stacksets = OrderedDict()
@ -338,7 +433,17 @@ class CloudFormationBackend(BaseBackend):
self.exports = OrderedDict()
self.change_sets = OrderedDict()
def create_stack_set(self, name, template, parameters, tags=None, description=None, region='us-east-1', admin_role=None, execution_role=None):
def create_stack_set(
self,
name,
template,
parameters,
tags=None,
description=None,
region="us-east-1",
admin_role=None,
execution_role=None,
):
stackset_id = generate_stackset_id(name)
new_stackset = FakeStackSet(
stackset_id=stackset_id,
@ -366,7 +471,9 @@ class CloudFormationBackend(BaseBackend):
if self.stacksets[stackset].name == name:
self.stacksets[stackset].delete()
def create_stack_instances(self, stackset_name, accounts, regions, parameters, operation_id=None):
def create_stack_instances(
self, stackset_name, accounts, regions, parameters, operation_id=None
):
stackset = self.get_stack_set(stackset_name)
stackset.create_stack_instances(
@ -377,9 +484,19 @@ class CloudFormationBackend(BaseBackend):
)
return stackset
def update_stack_set(self, stackset_name, template=None, description=None,
parameters=None, tags=None, admin_role=None, execution_role=None,
accounts=None, regions=None, operation_id=None):
def update_stack_set(
self,
stackset_name,
template=None,
description=None,
parameters=None,
tags=None,
admin_role=None,
execution_role=None,
accounts=None,
regions=None,
operation_id=None,
):
stackset = self.get_stack_set(stackset_name)
update = stackset.update(
template=template,
@ -390,16 +507,28 @@ class CloudFormationBackend(BaseBackend):
execution_role=execution_role,
accounts=accounts,
regions=regions,
operation_id=operation_id
operation_id=operation_id,
)
return update
def delete_stack_instances(self, stackset_name, accounts, regions, operation_id=None):
def delete_stack_instances(
self, stackset_name, accounts, regions, operation_id=None
):
stackset = self.get_stack_set(stackset_name)
stackset.delete_stack_instances(accounts, regions, operation_id)
return stackset
def create_stack(self, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, create_change_set=False):
def create_stack(
self,
name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
create_change_set=False,
):
stack_id = generate_stack_id(name)
new_stack = FakeStack(
stack_id=stack_id,
@ -419,10 +548,21 @@ class CloudFormationBackend(BaseBackend):
self.exports[export.name] = export
return new_stack
def create_change_set(self, stack_name, change_set_name, template, parameters, region_name, change_set_type, notification_arns=None, tags=None, role_arn=None):
def create_change_set(
self,
stack_name,
change_set_name,
template,
parameters,
region_name,
change_set_type,
notification_arns=None,
tags=None,
role_arn=None,
):
stack_id = None
stack_template = None
if change_set_type == 'UPDATE':
if change_set_type == "UPDATE":
stacks = self.stacks.values()
stack = None
for s in stacks:
@ -449,7 +589,7 @@ class CloudFormationBackend(BaseBackend):
notification_arns=notification_arns,
tags=tags,
role_arn=role_arn,
cross_stack_resources=self.exports
cross_stack_resources=self.exports,
)
self.change_sets[change_set_id] = new_change_set
self.stacks[stack_id] = new_change_set
@ -488,11 +628,11 @@ class CloudFormationBackend(BaseBackend):
stack = self.change_sets[cs]
if stack is None:
raise ValidationError(stack_name)
if stack.events[-1].resource_status == 'REVIEW_IN_PROGRESS':
stack._add_stack_event('CREATE_COMPLETE')
if stack.events[-1].resource_status == "REVIEW_IN_PROGRESS":
stack._add_stack_event("CREATE_COMPLETE")
else:
stack._add_stack_event('UPDATE_IN_PROGRESS')
stack._add_stack_event('UPDATE_COMPLETE')
stack._add_stack_event("UPDATE_IN_PROGRESS")
stack._add_stack_event("UPDATE_COMPLETE")
return True
def describe_stacks(self, name_or_stack_id):
@ -514,9 +654,7 @@ class CloudFormationBackend(BaseBackend):
return self.change_sets.values()
def list_stacks(self):
return [
v for v in self.stacks.values()
] + [
return [v for v in self.stacks.values()] + [
v for v in self.deleted_stacks.values()
]
@ -558,10 +696,10 @@ class CloudFormationBackend(BaseBackend):
all_exports = list(self.exports.values())
if token is None:
exports = all_exports[0:100]
next_token = '100' if len(all_exports) > 100 else None
next_token = "100" if len(all_exports) > 100 else None
else:
token = int(token)
exports = all_exports[token:token + 100]
exports = all_exports[token : token + 100]
next_token = str(token + 100) if len(all_exports) > token + 100 else None
return exports, next_token
@ -572,7 +710,10 @@ class CloudFormationBackend(BaseBackend):
new_stack_export_names = [x.name for x in stack.exports]
export_names = self.exports.keys()
if not set(export_names).isdisjoint(new_stack_export_names):
raise ValidationError(stack.stack_id, message='Export names must be unique across a given region')
raise ValidationError(
stack.stack_id,
message="Export names must be unique across a given region",
)
cloudformation_backends = {}

View File

@ -28,7 +28,12 @@ from moto.s3 import models as s3_models
from moto.sns import models as sns_models
from moto.sqs import models as sqs_models
from .utils import random_suffix
from .exceptions import ExportNotFound, MissingParameterError, UnformattedGetAttTemplateException, ValidationError
from .exceptions import (
ExportNotFound,
MissingParameterError,
UnformattedGetAttTemplateException,
ValidationError,
)
from boto.cloudformation.stack import Output
MODEL_MAP = {
@ -100,7 +105,7 @@ NAME_TYPE_MAP = {
"AWS::RDS::DBInstance": "DBInstanceIdentifier",
"AWS::S3::Bucket": "BucketName",
"AWS::SNS::Topic": "TopicName",
"AWS::SQS::Queue": "QueueName"
"AWS::SQS::Queue": "QueueName",
}
# Just ignore these models types for now
@ -109,13 +114,12 @@ NULL_MODELS = [
"AWS::CloudFormation::WaitConditionHandle",
]
DEFAULT_REGION = 'us-east-1'
DEFAULT_REGION = "us-east-1"
logger = logging.getLogger("moto")
class LazyDict(dict):
def __getitem__(self, key):
val = dict.__getitem__(self, key)
if callable(val):
@ -132,10 +136,10 @@ def clean_json(resource_json, resources_map):
Eventually, this is where we would add things like function parsing (fn::)
"""
if isinstance(resource_json, dict):
if 'Ref' in resource_json:
if "Ref" in resource_json:
# Parse resource reference
resource = resources_map[resource_json['Ref']]
if hasattr(resource, 'physical_resource_id'):
resource = resources_map[resource_json["Ref"]]
if hasattr(resource, "physical_resource_id"):
return resource.physical_resource_id
else:
return resource
@ -148,74 +152,92 @@ def clean_json(resource_json, resources_map):
result = result[clean_json(path, resources_map)]
return result
if 'Fn::GetAtt' in resource_json:
resource = resources_map.get(resource_json['Fn::GetAtt'][0])
if "Fn::GetAtt" in resource_json:
resource = resources_map.get(resource_json["Fn::GetAtt"][0])
if resource is None:
return resource_json
try:
return resource.get_cfn_attribute(resource_json['Fn::GetAtt'][1])
return resource.get_cfn_attribute(resource_json["Fn::GetAtt"][1])
except NotImplementedError as n:
logger.warning(str(n).format(
resource_json['Fn::GetAtt'][0]))
logger.warning(str(n).format(resource_json["Fn::GetAtt"][0]))
except UnformattedGetAttTemplateException:
raise ValidationError(
'Bad Request',
"Bad Request",
UnformattedGetAttTemplateException.description.format(
resource_json['Fn::GetAtt'][0], resource_json['Fn::GetAtt'][1]))
resource_json["Fn::GetAtt"][0], resource_json["Fn::GetAtt"][1]
),
)
if 'Fn::If' in resource_json:
condition_name, true_value, false_value = resource_json['Fn::If']
if "Fn::If" in resource_json:
condition_name, true_value, false_value = resource_json["Fn::If"]
if resources_map.lazy_condition_map[condition_name]:
return clean_json(true_value, resources_map)
else:
return clean_json(false_value, resources_map)
if 'Fn::Join' in resource_json:
join_list = clean_json(resource_json['Fn::Join'][1], resources_map)
return resource_json['Fn::Join'][0].join([str(x) for x in join_list])
if "Fn::Join" in resource_json:
join_list = clean_json(resource_json["Fn::Join"][1], resources_map)
return resource_json["Fn::Join"][0].join([str(x) for x in join_list])
if 'Fn::Split' in resource_json:
to_split = clean_json(resource_json['Fn::Split'][1], resources_map)
return to_split.split(resource_json['Fn::Split'][0])
if "Fn::Split" in resource_json:
to_split = clean_json(resource_json["Fn::Split"][1], resources_map)
return to_split.split(resource_json["Fn::Split"][0])
if 'Fn::Select' in resource_json:
select_index = int(resource_json['Fn::Select'][0])
select_list = clean_json(resource_json['Fn::Select'][1], resources_map)
if "Fn::Select" in resource_json:
select_index = int(resource_json["Fn::Select"][0])
select_list = clean_json(resource_json["Fn::Select"][1], resources_map)
return select_list[select_index]
if 'Fn::Sub' in resource_json:
if isinstance(resource_json['Fn::Sub'], list):
if "Fn::Sub" in resource_json:
if isinstance(resource_json["Fn::Sub"], list):
warnings.warn(
"Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation")
"Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation"
)
else:
fn_sub_value = clean_json(resource_json['Fn::Sub'], resources_map)
fn_sub_value = clean_json(resource_json["Fn::Sub"], resources_map)
to_sub = re.findall('(?=\${)[^!^"]*?}', fn_sub_value)
literals = re.findall('(?=\${!)[^"]*?}', fn_sub_value)
for sub in to_sub:
if '.' in sub:
cleaned_ref = clean_json({'Fn::GetAtt': re.findall('(?<=\${)[^"]*?(?=})', sub)[0].split('.')}, resources_map)
if "." in sub:
cleaned_ref = clean_json(
{
"Fn::GetAtt": re.findall('(?<=\${)[^"]*?(?=})', sub)[
0
].split(".")
},
resources_map,
)
else:
cleaned_ref = clean_json({'Ref': re.findall('(?<=\${)[^"]*?(?=})', sub)[0]}, resources_map)
cleaned_ref = clean_json(
{"Ref": re.findall('(?<=\${)[^"]*?(?=})', sub)[0]},
resources_map,
)
fn_sub_value = fn_sub_value.replace(sub, cleaned_ref)
for literal in literals:
fn_sub_value = fn_sub_value.replace(literal, literal.replace('!', ''))
fn_sub_value = fn_sub_value.replace(
literal, literal.replace("!", "")
)
return fn_sub_value
pass
if 'Fn::ImportValue' in resource_json:
cleaned_val = clean_json(resource_json['Fn::ImportValue'], resources_map)
values = [x.value for x in resources_map.cross_stack_resources.values() if x.name == cleaned_val]
if "Fn::ImportValue" in resource_json:
cleaned_val = clean_json(resource_json["Fn::ImportValue"], resources_map)
values = [
x.value
for x in resources_map.cross_stack_resources.values()
if x.name == cleaned_val
]
if any(values):
return values[0]
else:
raise ExportNotFound(cleaned_val)
if 'Fn::GetAZs' in resource_json:
region = resource_json.get('Fn::GetAZs') or DEFAULT_REGION
if "Fn::GetAZs" in resource_json:
region = resource_json.get("Fn::GetAZs") or DEFAULT_REGION
result = []
# TODO: make this configurable, to reflect the real AWS AZs
for az in ('a', 'b', 'c', 'd'):
result.append('%s%s' % (region, az))
for az in ("a", "b", "c", "d"):
result.append("%s%s" % (region, az))
return result
cleaned_json = {}
@ -246,58 +268,69 @@ def resource_name_property_from_type(resource_type):
def generate_resource_name(resource_type, stack_name, logical_id):
if resource_type in ["AWS::ElasticLoadBalancingV2::TargetGroup",
"AWS::ElasticLoadBalancingV2::LoadBalancer"]:
if resource_type in [
"AWS::ElasticLoadBalancingV2::TargetGroup",
"AWS::ElasticLoadBalancingV2::LoadBalancer",
]:
# Target group names need to be less than 32 characters, so when cloudformation creates a name for you
# it makes sure to stay under that limit
name_prefix = '{0}-{1}'.format(stack_name, logical_id)
name_prefix = "{0}-{1}".format(stack_name, logical_id)
my_random_suffix = random_suffix()
truncated_name_prefix = name_prefix[0:32 - (len(my_random_suffix) + 1)]
truncated_name_prefix = name_prefix[0 : 32 - (len(my_random_suffix) + 1)]
# if the truncated name ends in a dash, we'll end up with a double dash in the final name, which is
# not allowed
if truncated_name_prefix.endswith('-'):
if truncated_name_prefix.endswith("-"):
truncated_name_prefix = truncated_name_prefix[:-1]
return '{0}-{1}'.format(truncated_name_prefix, my_random_suffix)
return "{0}-{1}".format(truncated_name_prefix, my_random_suffix)
else:
return '{0}-{1}-{2}'.format(stack_name, logical_id, random_suffix())
return "{0}-{1}-{2}".format(stack_name, logical_id, random_suffix())
def parse_resource(logical_id, resource_json, resources_map):
resource_type = resource_json['Type']
resource_type = resource_json["Type"]
resource_class = resource_class_from_type(resource_type)
if not resource_class:
warnings.warn(
"Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format(resource_type))
"Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format(
resource_type
)
)
return None
resource_json = clean_json(resource_json, resources_map)
resource_name_property = resource_name_property_from_type(resource_type)
if resource_name_property:
if 'Properties' not in resource_json:
resource_json['Properties'] = dict()
if resource_name_property not in resource_json['Properties']:
resource_json['Properties'][resource_name_property] = generate_resource_name(
resource_type, resources_map.get('AWS::StackName'), logical_id)
resource_name = resource_json['Properties'][resource_name_property]
if "Properties" not in resource_json:
resource_json["Properties"] = dict()
if resource_name_property not in resource_json["Properties"]:
resource_json["Properties"][
resource_name_property
] = generate_resource_name(
resource_type, resources_map.get("AWS::StackName"), logical_id
)
resource_name = resource_json["Properties"][resource_name_property]
else:
resource_name = generate_resource_name(resource_type, resources_map.get('AWS::StackName'), logical_id)
resource_name = generate_resource_name(
resource_type, resources_map.get("AWS::StackName"), logical_id
)
return resource_class, resource_json, resource_name
def parse_and_create_resource(logical_id, resource_json, resources_map, region_name):
condition = resource_json.get('Condition')
condition = resource_json.get("Condition")
if condition and not resources_map.lazy_condition_map[condition]:
# If this has a False condition, don't create the resource
return None
resource_type = resource_json['Type']
resource_type = resource_json["Type"]
resource_tuple = parse_resource(logical_id, resource_json, resources_map)
if not resource_tuple:
return None
resource_class, resource_json, resource_name = resource_tuple
resource = resource_class.create_from_cloudformation_json(
resource_name, resource_json, region_name)
resource_name, resource_json, region_name
)
resource.type = resource_type
resource.logical_resource_id = logical_id
return resource
@ -305,24 +338,27 @@ def parse_and_create_resource(logical_id, resource_json, resources_map, region_n
def parse_and_update_resource(logical_id, resource_json, resources_map, region_name):
resource_class, new_resource_json, new_resource_name = parse_resource(
logical_id, resource_json, resources_map)
logical_id, resource_json, resources_map
)
original_resource = resources_map[logical_id]
new_resource = resource_class.update_from_cloudformation_json(
original_resource=original_resource,
new_resource_name=new_resource_name,
cloudformation_json=new_resource_json,
region_name=region_name
region_name=region_name,
)
new_resource.type = resource_json['Type']
new_resource.type = resource_json["Type"]
new_resource.logical_resource_id = logical_id
return new_resource
def parse_and_delete_resource(logical_id, resource_json, resources_map, region_name):
resource_class, resource_json, resource_name = parse_resource(
logical_id, resource_json, resources_map)
logical_id, resource_json, resources_map
)
resource_class.delete_from_cloudformation_json(
resource_name, resource_json, region_name)
resource_name, resource_json, region_name
)
def parse_condition(condition, resources_map, condition_map):
@ -334,8 +370,8 @@ def parse_condition(condition, resources_map, condition_map):
condition_values = []
for value in list(condition.values())[0]:
# Check if we are referencing another Condition
if 'Condition' in value:
condition_values.append(condition_map[value['Condition']])
if "Condition" in value:
condition_values.append(condition_map[value["Condition"]])
else:
condition_values.append(clean_json(value, resources_map))
@ -344,23 +380,27 @@ def parse_condition(condition, resources_map, condition_map):
elif condition_operator == "Fn::Not":
return not parse_condition(condition_values[0], resources_map, condition_map)
elif condition_operator == "Fn::And":
return all([
parse_condition(condition_value, resources_map, condition_map)
for condition_value
in condition_values])
return all(
[
parse_condition(condition_value, resources_map, condition_map)
for condition_value in condition_values
]
)
elif condition_operator == "Fn::Or":
return any([
parse_condition(condition_value, resources_map, condition_map)
for condition_value
in condition_values])
return any(
[
parse_condition(condition_value, resources_map, condition_map)
for condition_value in condition_values
]
)
def parse_output(output_logical_id, output_json, resources_map):
output_json = clean_json(output_json, resources_map)
output = Output()
output.key = output_logical_id
output.value = clean_json(output_json['Value'], resources_map)
output.description = output_json.get('Description')
output.value = clean_json(output_json["Value"], resources_map)
output.description = output_json.get("Description")
return output
@ -371,9 +411,18 @@ class ResourceMap(collections.Mapping):
each resources is passed this lazy map that it can grab dependencies from.
"""
def __init__(self, stack_id, stack_name, parameters, tags, region_name, template, cross_stack_resources):
def __init__(
self,
stack_id,
stack_name,
parameters,
tags,
region_name,
template,
cross_stack_resources,
):
self._template = template
self._resource_json_map = template['Resources']
self._resource_json_map = template["Resources"]
self._region_name = region_name
self.input_parameters = parameters
self.tags = copy.deepcopy(tags)
@ -401,7 +450,8 @@ class ResourceMap(collections.Mapping):
if not resource_json:
raise KeyError(resource_logical_id)
new_resource = parse_and_create_resource(
resource_logical_id, resource_json, self, self._region_name)
resource_logical_id, resource_json, self, self._region_name
)
if new_resource is not None:
self._parsed_resources[resource_logical_id] = new_resource
return new_resource
@ -417,13 +467,13 @@ class ResourceMap(collections.Mapping):
return self._resource_json_map.keys()
def load_mapping(self):
self._parsed_resources.update(self._template.get('Mappings', {}))
self._parsed_resources.update(self._template.get("Mappings", {}))
def load_parameters(self):
parameter_slots = self._template.get('Parameters', {})
parameter_slots = self._template.get("Parameters", {})
for parameter_name, parameter in parameter_slots.items():
# Set the default values.
self.resolved_parameters[parameter_name] = parameter.get('Default')
self.resolved_parameters[parameter_name] = parameter.get("Default")
# Set any input parameters that were passed
self.no_echo_parameter_keys = []
@ -431,11 +481,11 @@ class ResourceMap(collections.Mapping):
if key in self.resolved_parameters:
parameter_slot = parameter_slots[key]
value_type = parameter_slot.get('Type', 'String')
if value_type == 'CommaDelimitedList' or value_type.startswith("List"):
value = value.split(',')
value_type = parameter_slot.get("Type", "String")
if value_type == "CommaDelimitedList" or value_type.startswith("List"):
value = value.split(",")
if parameter_slot.get('NoEcho'):
if parameter_slot.get("NoEcho"):
self.no_echo_parameter_keys.append(key)
self.resolved_parameters[key] = value
@ -449,11 +499,15 @@ class ResourceMap(collections.Mapping):
self._parsed_resources.update(self.resolved_parameters)
def load_conditions(self):
conditions = self._template.get('Conditions', {})
conditions = self._template.get("Conditions", {})
self.lazy_condition_map = LazyDict()
for condition_name, condition in conditions.items():
self.lazy_condition_map[condition_name] = functools.partial(parse_condition,
condition, self._parsed_resources, self.lazy_condition_map)
self.lazy_condition_map[condition_name] = functools.partial(
parse_condition,
condition,
self._parsed_resources,
self.lazy_condition_map,
)
for condition_name in self.lazy_condition_map:
self.lazy_condition_map[condition_name]
@ -465,13 +519,18 @@ class ResourceMap(collections.Mapping):
# Since this is a lazy map, to create every object we just need to
# iterate through self.
self.tags.update({'aws:cloudformation:stack-name': self.get('AWS::StackName'),
'aws:cloudformation:stack-id': self.get('AWS::StackId')})
self.tags.update(
{
"aws:cloudformation:stack-name": self.get("AWS::StackName"),
"aws:cloudformation:stack-id": self.get("AWS::StackId"),
}
)
for resource in self.resources:
if isinstance(self[resource], ec2_models.TaggedEC2Resource):
self.tags['aws:cloudformation:logical-id'] = resource
self.tags["aws:cloudformation:logical-id"] = resource
ec2_models.ec2_backends[self._region_name].create_tags(
[self[resource].physical_resource_id], self.tags)
[self[resource].physical_resource_id], self.tags
)
def diff(self, template, parameters=None):
if parameters:
@ -481,36 +540,35 @@ class ResourceMap(collections.Mapping):
self.load_conditions()
old_template = self._resource_json_map
new_template = template['Resources']
new_template = template["Resources"]
resource_names_by_action = {
'Add': set(new_template) - set(old_template),
'Modify': set(name for name in new_template if name in old_template and new_template[
name] != old_template[name]),
'Remove': set(old_template) - set(new_template)
}
resources_by_action = {
'Add': {},
'Modify': {},
'Remove': {},
"Add": set(new_template) - set(old_template),
"Modify": set(
name
for name in new_template
if name in old_template and new_template[name] != old_template[name]
),
"Remove": set(old_template) - set(new_template),
}
resources_by_action = {"Add": {}, "Modify": {}, "Remove": {}}
for resource_name in resource_names_by_action['Add']:
resources_by_action['Add'][resource_name] = {
'LogicalResourceId': resource_name,
'ResourceType': new_template[resource_name]['Type']
for resource_name in resource_names_by_action["Add"]:
resources_by_action["Add"][resource_name] = {
"LogicalResourceId": resource_name,
"ResourceType": new_template[resource_name]["Type"],
}
for resource_name in resource_names_by_action['Modify']:
resources_by_action['Modify'][resource_name] = {
'LogicalResourceId': resource_name,
'ResourceType': new_template[resource_name]['Type']
for resource_name in resource_names_by_action["Modify"]:
resources_by_action["Modify"][resource_name] = {
"LogicalResourceId": resource_name,
"ResourceType": new_template[resource_name]["Type"],
}
for resource_name in resource_names_by_action['Remove']:
resources_by_action['Remove'][resource_name] = {
'LogicalResourceId': resource_name,
'ResourceType': old_template[resource_name]['Type']
for resource_name in resource_names_by_action["Remove"]:
resources_by_action["Remove"][resource_name] = {
"LogicalResourceId": resource_name,
"ResourceType": old_template[resource_name]["Type"],
}
return resources_by_action
@ -519,35 +577,38 @@ class ResourceMap(collections.Mapping):
resources_by_action = self.diff(template, parameters)
old_template = self._resource_json_map
new_template = template['Resources']
new_template = template["Resources"]
self._resource_json_map = new_template
for resource_name, resource in resources_by_action['Add'].items():
for resource_name, resource in resources_by_action["Add"].items():
resource_json = new_template[resource_name]
new_resource = parse_and_create_resource(
resource_name, resource_json, self, self._region_name)
resource_name, resource_json, self, self._region_name
)
self._parsed_resources[resource_name] = new_resource
for resource_name, resource in resources_by_action['Remove'].items():
for resource_name, resource in resources_by_action["Remove"].items():
resource_json = old_template[resource_name]
parse_and_delete_resource(
resource_name, resource_json, self, self._region_name)
resource_name, resource_json, self, self._region_name
)
self._parsed_resources.pop(resource_name)
tries = 1
while resources_by_action['Modify'] and tries < 5:
for resource_name, resource in resources_by_action['Modify'].copy().items():
while resources_by_action["Modify"] and tries < 5:
for resource_name, resource in resources_by_action["Modify"].copy().items():
resource_json = new_template[resource_name]
try:
changed_resource = parse_and_update_resource(
resource_name, resource_json, self, self._region_name)
resource_name, resource_json, self, self._region_name
)
except Exception as e:
# skip over dependency violations, and try again in a
# second pass
last_exception = e
else:
self._parsed_resources[resource_name] = changed_resource
del resources_by_action['Modify'][resource_name]
del resources_by_action["Modify"][resource_name]
tries += 1
if tries == 5:
raise last_exception
@ -559,7 +620,7 @@ class ResourceMap(collections.Mapping):
for resource in remaining_resources.copy():
parsed_resource = self._parsed_resources.get(resource)
try:
if parsed_resource and hasattr(parsed_resource, 'delete'):
if parsed_resource and hasattr(parsed_resource, "delete"):
parsed_resource.delete(self._region_name)
except Exception as e:
# skip over dependency violations, and try again in a
@ -573,11 +634,10 @@ class ResourceMap(collections.Mapping):
class OutputMap(collections.Mapping):
def __init__(self, resources, template, stack_id):
self._template = template
self._stack_id = stack_id
self._output_json_map = template.get('Outputs')
self._output_json_map = template.get("Outputs")
# Create the default resources
self._resource_map = resources
@ -591,7 +651,8 @@ class OutputMap(collections.Mapping):
else:
output_json = self._output_json_map.get(output_logical_id)
new_output = parse_output(
output_logical_id, output_json, self._resource_map)
output_logical_id, output_json, self._resource_map
)
self._parsed_outputs[output_logical_id] = new_output
return new_output
@ -610,9 +671,11 @@ class OutputMap(collections.Mapping):
exports = []
if self.outputs:
for key, value in self._output_json_map.items():
if value.get('Export'):
cleaned_name = clean_json(value['Export'].get('Name'), self._resource_map)
cleaned_value = clean_json(value.get('Value'), self._resource_map)
if value.get("Export"):
cleaned_name = clean_json(
value["Export"].get("Name"), self._resource_map
)
cleaned_value = clean_json(value.get("Value"), self._resource_map)
exports.append(Export(self._stack_id, cleaned_name, cleaned_value))
return exports
@ -622,7 +685,6 @@ class OutputMap(collections.Mapping):
class Export(object):
def __init__(self, exporting_stack_id, name, value):
self._exporting_stack_id = exporting_stack_id
self._name = name

View File

@ -12,7 +12,6 @@ from .exceptions import ValidationError
class CloudFormationResponse(BaseResponse):
@property
def cloudformation_backend(self):
return cloudformation_backends[self.region]
@ -20,17 +19,18 @@ class CloudFormationResponse(BaseResponse):
def _get_stack_from_s3_url(self, template_url):
template_url_parts = urlparse(template_url)
if "localhost" in template_url:
bucket_name, key_name = template_url_parts.path.lstrip(
"/").split("/", 1)
bucket_name, key_name = template_url_parts.path.lstrip("/").split("/", 1)
else:
if template_url_parts.netloc.endswith('amazonaws.com') \
and template_url_parts.netloc.startswith('s3'):
if template_url_parts.netloc.endswith(
"amazonaws.com"
) and template_url_parts.netloc.startswith("s3"):
# Handle when S3 url uses amazon url with bucket in path
# Also handles getting region as technically s3 is region'd
# region = template_url.netloc.split('.')[1]
bucket_name, key_name = template_url_parts.path.lstrip(
"/").split("/", 1)
bucket_name, key_name = template_url_parts.path.lstrip("/").split(
"/", 1
)
else:
bucket_name = template_url_parts.netloc.split(".")[0]
key_name = template_url_parts.path.lstrip("/")
@ -39,24 +39,26 @@ class CloudFormationResponse(BaseResponse):
return key.value.decode("utf-8")
def create_stack(self):
stack_name = self._get_param('StackName')
stack_body = self._get_param('TemplateBody')
template_url = self._get_param('TemplateURL')
role_arn = self._get_param('RoleARN')
stack_name = self._get_param("StackName")
stack_body = self._get_param("TemplateBody")
template_url = self._get_param("TemplateURL")
role_arn = self._get_param("RoleARN")
parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value'])
for item in self._get_list_prefix("Tags.member"))
tags = dict(
(item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# Hack dict-comprehension
parameters = dict([
(parameter['parameter_key'], parameter['parameter_value'])
for parameter
in parameters_list
])
parameters = dict(
[
(parameter["parameter_key"], parameter["parameter_value"])
for parameter in parameters_list
]
)
if template_url:
stack_body = self._get_stack_from_s3_url(template_url)
stack_notification_arns = self._get_multi_param(
'NotificationARNs.member')
stack_notification_arns = self._get_multi_param("NotificationARNs.member")
stack = self.cloudformation_backend.create_stack(
name=stack_name,
@ -68,34 +70,37 @@ class CloudFormationResponse(BaseResponse):
role_arn=role_arn,
)
if self.request_json:
return json.dumps({
'CreateStackResponse': {
'CreateStackResult': {
'StackId': stack.stack_id,
return json.dumps(
{
"CreateStackResponse": {
"CreateStackResult": {"StackId": stack.stack_id}
}
}
})
)
else:
template = self.response_template(CREATE_STACK_RESPONSE_TEMPLATE)
return template.render(stack=stack)
@amzn_request_id
def create_change_set(self):
stack_name = self._get_param('StackName')
change_set_name = self._get_param('ChangeSetName')
stack_body = self._get_param('TemplateBody')
template_url = self._get_param('TemplateURL')
role_arn = self._get_param('RoleARN')
update_or_create = self._get_param('ChangeSetType', 'CREATE')
stack_name = self._get_param("StackName")
change_set_name = self._get_param("ChangeSetName")
stack_body = self._get_param("TemplateBody")
template_url = self._get_param("TemplateURL")
role_arn = self._get_param("RoleARN")
update_or_create = self._get_param("ChangeSetType", "CREATE")
parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value'])
for item in self._get_list_prefix("Tags.member"))
parameters = {param['parameter_key']: param['parameter_value']
for param in parameters_list}
tags = dict(
(item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
parameters = {
param["parameter_key"]: param["parameter_value"]
for param in parameters_list
}
if template_url:
stack_body = self._get_stack_from_s3_url(template_url)
stack_notification_arns = self._get_multi_param(
'NotificationARNs.member')
stack_notification_arns = self._get_multi_param("NotificationARNs.member")
change_set_id, stack_id = self.cloudformation_backend.create_change_set(
stack_name=stack_name,
change_set_name=change_set_name,
@ -108,66 +113,64 @@ class CloudFormationResponse(BaseResponse):
change_set_type=update_or_create,
)
if self.request_json:
return json.dumps({
'CreateChangeSetResponse': {
'CreateChangeSetResult': {
'Id': change_set_id,
'StackId': stack_id,
return json.dumps(
{
"CreateChangeSetResponse": {
"CreateChangeSetResult": {
"Id": change_set_id,
"StackId": stack_id,
}
}
}
})
)
else:
template = self.response_template(CREATE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render(stack_id=stack_id, change_set_id=change_set_id)
def delete_change_set(self):
stack_name = self._get_param('StackName')
change_set_name = self._get_param('ChangeSetName')
stack_name = self._get_param("StackName")
change_set_name = self._get_param("ChangeSetName")
self.cloudformation_backend.delete_change_set(change_set_name=change_set_name, stack_name=stack_name)
self.cloudformation_backend.delete_change_set(
change_set_name=change_set_name, stack_name=stack_name
)
if self.request_json:
return json.dumps({
'DeleteChangeSetResponse': {
'DeleteChangeSetResult': {},
}
})
return json.dumps(
{"DeleteChangeSetResponse": {"DeleteChangeSetResult": {}}}
)
else:
template = self.response_template(DELETE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render()
def describe_change_set(self):
stack_name = self._get_param('StackName')
change_set_name = self._get_param('ChangeSetName')
stack_name = self._get_param("StackName")
change_set_name = self._get_param("ChangeSetName")
change_set = self.cloudformation_backend.describe_change_set(
change_set_name=change_set_name,
stack_name=stack_name,
change_set_name=change_set_name, stack_name=stack_name
)
template = self.response_template(DESCRIBE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render(change_set=change_set)
@amzn_request_id
def execute_change_set(self):
stack_name = self._get_param('StackName')
change_set_name = self._get_param('ChangeSetName')
stack_name = self._get_param("StackName")
change_set_name = self._get_param("ChangeSetName")
self.cloudformation_backend.execute_change_set(
stack_name=stack_name,
change_set_name=change_set_name,
stack_name=stack_name, change_set_name=change_set_name
)
if self.request_json:
return json.dumps({
'ExecuteChangeSetResponse': {
'ExecuteChangeSetResult': {},
}
})
return json.dumps(
{"ExecuteChangeSetResponse": {"ExecuteChangeSetResult": {}}}
)
else:
template = self.response_template(EXECUTE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render()
def describe_stacks(self):
stack_name_or_id = None
if self._get_param('StackName'):
stack_name_or_id = self.querystring.get('StackName')[0]
token = self._get_param('NextToken')
if self._get_param("StackName"):
stack_name_or_id = self.querystring.get("StackName")[0]
token = self._get_param("NextToken")
stacks = self.cloudformation_backend.describe_stacks(stack_name_or_id)
stack_ids = [stack.stack_id for stack in stacks]
if token:
@ -175,7 +178,7 @@ class CloudFormationResponse(BaseResponse):
else:
start = 0
max_results = 50 # using this to mske testing of paginated stacks more convenient than default 1 MB
stacks_resp = stacks[start:start + max_results]
stacks_resp = stacks[start : start + max_results]
next_token = None
if len(stacks) > (start + max_results):
next_token = stacks_resp[-1].stack_id
@ -183,9 +186,9 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacks=stacks_resp, next_token=next_token)
def describe_stack_resource(self):
stack_name = self._get_param('StackName')
stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name)
logical_resource_id = self._get_param('LogicalResourceId')
logical_resource_id = self._get_param("LogicalResourceId")
for stack_resource in stack.stack_resources:
if stack_resource.logical_resource_id == logical_resource_id:
@ -194,19 +197,18 @@ class CloudFormationResponse(BaseResponse):
else:
raise ValidationError(logical_resource_id)
template = self.response_template(
DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE)
template = self.response_template(DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE)
return template.render(stack=stack, resource=resource)
def describe_stack_resources(self):
stack_name = self._get_param('StackName')
stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name)
template = self.response_template(DESCRIBE_STACK_RESOURCES_RESPONSE)
return template.render(stack=stack)
def describe_stack_events(self):
stack_name = self._get_param('StackName')
stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name)
template = self.response_template(DESCRIBE_STACK_EVENTS_RESPONSE)
@ -223,68 +225,82 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacks=stacks)
def list_stack_resources(self):
stack_name_or_id = self._get_param('StackName')
resources = self.cloudformation_backend.list_stack_resources(
stack_name_or_id)
stack_name_or_id = self._get_param("StackName")
resources = self.cloudformation_backend.list_stack_resources(stack_name_or_id)
template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE)
return template.render(resources=resources)
def get_template(self):
name_or_stack_id = self.querystring.get('StackName')[0]
name_or_stack_id = self.querystring.get("StackName")[0]
stack = self.cloudformation_backend.get_stack(name_or_stack_id)
if self.request_json:
return json.dumps({
"GetTemplateResponse": {
"GetTemplateResult": {
"TemplateBody": stack.template,
"ResponseMetadata": {
"RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE"
return json.dumps(
{
"GetTemplateResponse": {
"GetTemplateResult": {
"TemplateBody": stack.template,
"ResponseMetadata": {
"RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE"
},
}
}
}
})
)
else:
template = self.response_template(GET_TEMPLATE_RESPONSE_TEMPLATE)
return template.render(stack=stack)
def update_stack(self):
stack_name = self._get_param('StackName')
role_arn = self._get_param('RoleARN')
template_url = self._get_param('TemplateURL')
stack_body = self._get_param('TemplateBody')
stack_name = self._get_param("StackName")
role_arn = self._get_param("RoleARN")
template_url = self._get_param("TemplateURL")
stack_body = self._get_param("TemplateBody")
stack = self.cloudformation_backend.get_stack(stack_name)
if self._get_param('UsePreviousTemplate') == "true":
if self._get_param("UsePreviousTemplate") == "true":
stack_body = stack.template
elif not stack_body and template_url:
stack_body = self._get_stack_from_s3_url(template_url)
incoming_params = self._get_list_prefix("Parameters.member")
parameters = dict([
(parameter['parameter_key'], parameter['parameter_value'])
for parameter
in incoming_params if 'parameter_value' in parameter
])
previous = dict([
(parameter['parameter_key'], stack.parameters[parameter['parameter_key']])
for parameter
in incoming_params if 'use_previous_value' in parameter
])
parameters = dict(
[
(parameter["parameter_key"], parameter["parameter_value"])
for parameter in incoming_params
if "parameter_value" in parameter
]
)
previous = dict(
[
(
parameter["parameter_key"],
stack.parameters[parameter["parameter_key"]],
)
for parameter in incoming_params
if "use_previous_value" in parameter
]
)
parameters.update(previous)
# boto3 is supposed to let you clear the tags by passing an empty value, but the request body doesn't
# end up containing anything we can use to differentiate between passing an empty value versus not
# passing anything. so until that changes, moto won't be able to clear tags, only update them.
tags = dict((item['key'], item['value'])
for item in self._get_list_prefix("Tags.member"))
tags = dict(
(item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# so that if we don't pass the parameter, we don't clear all the tags accidentally
if not tags:
tags = None
stack = self.cloudformation_backend.get_stack(stack_name)
if stack.status == 'ROLLBACK_COMPLETE':
if stack.status == "ROLLBACK_COMPLETE":
raise ValidationError(
stack.stack_id, message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(stack.stack_id))
stack.stack_id,
message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(
stack.stack_id
),
)
stack = self.cloudformation_backend.update_stack(
name=stack_name,
@ -295,11 +311,7 @@ class CloudFormationResponse(BaseResponse):
)
if self.request_json:
stack_body = {
'UpdateStackResponse': {
'UpdateStackResult': {
'StackId': stack.name,
}
}
"UpdateStackResponse": {"UpdateStackResult": {"StackId": stack.name}}
}
return json.dumps(stack_body)
else:
@ -307,56 +319,57 @@ class CloudFormationResponse(BaseResponse):
return template.render(stack=stack)
def delete_stack(self):
name_or_stack_id = self.querystring.get('StackName')[0]
name_or_stack_id = self.querystring.get("StackName")[0]
self.cloudformation_backend.delete_stack(name_or_stack_id)
if self.request_json:
return json.dumps({
'DeleteStackResponse': {
'DeleteStackResult': {},
}
})
return json.dumps({"DeleteStackResponse": {"DeleteStackResult": {}}})
else:
template = self.response_template(DELETE_STACK_RESPONSE_TEMPLATE)
return template.render()
def list_exports(self):
token = self._get_param('NextToken')
token = self._get_param("NextToken")
exports, next_token = self.cloudformation_backend.list_exports(token=token)
template = self.response_template(LIST_EXPORTS_RESPONSE)
return template.render(exports=exports, next_token=next_token)
def validate_template(self):
cfn_lint = self.cloudformation_backend.validate_template(self._get_param('TemplateBody'))
cfn_lint = self.cloudformation_backend.validate_template(
self._get_param("TemplateBody")
)
if cfn_lint:
raise ValidationError(cfn_lint[0].message)
description = ""
try:
description = json.loads(self._get_param('TemplateBody'))['Description']
description = json.loads(self._get_param("TemplateBody"))["Description"]
except (ValueError, KeyError):
pass
try:
description = yaml.load(self._get_param('TemplateBody'))['Description']
description = yaml.load(self._get_param("TemplateBody"))["Description"]
except (yaml.ParserError, KeyError):
pass
template = self.response_template(VALIDATE_STACK_RESPONSE_TEMPLATE)
return template.render(description=description)
def create_stack_set(self):
stackset_name = self._get_param('StackSetName')
stack_body = self._get_param('TemplateBody')
template_url = self._get_param('TemplateURL')
stackset_name = self._get_param("StackSetName")
stack_body = self._get_param("TemplateBody")
template_url = self._get_param("TemplateURL")
# role_arn = self._get_param('RoleARN')
parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value'])
for item in self._get_list_prefix("Tags.member"))
tags = dict(
(item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# Copy-Pasta - Hack dict-comprehension
parameters = dict([
(parameter['parameter_key'], parameter['parameter_value'])
for parameter
in parameters_list
])
parameters = dict(
[
(parameter["parameter_key"], parameter["parameter_value"])
for parameter in parameters_list
]
)
if template_url:
stack_body = self._get_stack_from_s3_url(template_url)
@ -368,59 +381,65 @@ class CloudFormationResponse(BaseResponse):
# role_arn=role_arn,
)
if self.request_json:
return json.dumps({
'CreateStackSetResponse': {
'CreateStackSetResult': {
'StackSetId': stackset.stackset_id,
return json.dumps(
{
"CreateStackSetResponse": {
"CreateStackSetResult": {"StackSetId": stackset.stackset_id}
}
}
})
)
else:
template = self.response_template(CREATE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(stackset=stackset)
def create_stack_instances(self):
stackset_name = self._get_param('StackSetName')
accounts = self._get_multi_param('Accounts.member')
regions = self._get_multi_param('Regions.member')
parameters = self._get_multi_param('ParameterOverrides.member')
self.cloudformation_backend.create_stack_instances(stackset_name, accounts, regions, parameters)
stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param("Regions.member")
parameters = self._get_multi_param("ParameterOverrides.member")
self.cloudformation_backend.create_stack_instances(
stackset_name, accounts, regions, parameters
)
template = self.response_template(CREATE_STACK_INSTANCES_TEMPLATE)
return template.render()
def delete_stack_set(self):
stackset_name = self._get_param('StackSetName')
stackset_name = self._get_param("StackSetName")
self.cloudformation_backend.delete_stack_set(stackset_name)
template = self.response_template(DELETE_STACK_SET_RESPONSE_TEMPLATE)
return template.render()
def delete_stack_instances(self):
stackset_name = self._get_param('StackSetName')
accounts = self._get_multi_param('Accounts.member')
regions = self._get_multi_param('Regions.member')
operation = self.cloudformation_backend.delete_stack_instances(stackset_name, accounts, regions)
stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param("Regions.member")
operation = self.cloudformation_backend.delete_stack_instances(
stackset_name, accounts, regions
)
template = self.response_template(DELETE_STACK_INSTANCES_TEMPLATE)
return template.render(operation=operation)
def describe_stack_set(self):
stackset_name = self._get_param('StackSetName')
stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
if not stackset.admin_role:
stackset.admin_role = 'arn:aws:iam::123456789012:role/AWSCloudFormationStackSetAdministrationRole'
stackset.admin_role = "arn:aws:iam::123456789012:role/AWSCloudFormationStackSetAdministrationRole"
if not stackset.execution_role:
stackset.execution_role = 'AWSCloudFormationStackSetExecutionRole'
stackset.execution_role = "AWSCloudFormationStackSetExecutionRole"
template = self.response_template(DESCRIBE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(stackset=stackset)
def describe_stack_instance(self):
stackset_name = self._get_param('StackSetName')
account = self._get_param('StackInstanceAccount')
region = self._get_param('StackInstanceRegion')
stackset_name = self._get_param("StackSetName")
account = self._get_param("StackInstanceAccount")
region = self._get_param("StackInstanceRegion")
instance = self.cloudformation_backend.get_stack_set(stackset_name).instances.get_instance(account, region)
instance = self.cloudformation_backend.get_stack_set(
stackset_name
).instances.get_instance(account, region)
template = self.response_template(DESCRIBE_STACK_INSTANCE_TEMPLATE)
rendered = template.render(instance=instance)
return rendered
@ -431,61 +450,66 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacksets=stacksets)
def list_stack_instances(self):
stackset_name = self._get_param('StackSetName')
stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
template = self.response_template(LIST_STACK_INSTANCES_TEMPLATE)
return template.render(stackset=stackset)
def list_stack_set_operations(self):
stackset_name = self._get_param('StackSetName')
stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
template = self.response_template(LIST_STACK_SET_OPERATIONS_RESPONSE_TEMPLATE)
return template.render(stackset=stackset)
def stop_stack_set_operation(self):
stackset_name = self._get_param('StackSetName')
operation_id = self._get_param('OperationId')
stackset_name = self._get_param("StackSetName")
operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
stackset.update_operation(operation_id, 'STOPPED')
stackset.update_operation(operation_id, "STOPPED")
template = self.response_template(STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE)
return template.render()
def describe_stack_set_operation(self):
stackset_name = self._get_param('StackSetName')
operation_id = self._get_param('OperationId')
stackset_name = self._get_param("StackSetName")
operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
operation = stackset.get_operation(operation_id)
template = self.response_template(DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE)
return template.render(stackset=stackset, operation=operation)
def list_stack_set_operation_results(self):
stackset_name = self._get_param('StackSetName')
operation_id = self._get_param('OperationId')
stackset_name = self._get_param("StackSetName")
operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name)
operation = stackset.get_operation(operation_id)
template = self.response_template(LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE)
template = self.response_template(
LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE
)
return template.render(operation=operation)
def update_stack_set(self):
stackset_name = self._get_param('StackSetName')
operation_id = self._get_param('OperationId')
description = self._get_param('Description')
execution_role = self._get_param('ExecutionRoleName')
admin_role = self._get_param('AdministrationRoleARN')
accounts = self._get_multi_param('Accounts.member')
regions = self._get_multi_param('Regions.member')
template_body = self._get_param('TemplateBody')
template_url = self._get_param('TemplateURL')
stackset_name = self._get_param("StackSetName")
operation_id = self._get_param("OperationId")
description = self._get_param("Description")
execution_role = self._get_param("ExecutionRoleName")
admin_role = self._get_param("AdministrationRoleARN")
accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param("Regions.member")
template_body = self._get_param("TemplateBody")
template_url = self._get_param("TemplateURL")
if template_url:
template_body = self._get_stack_from_s3_url(template_url)
tags = dict((item['key'], item['value'])
for item in self._get_list_prefix("Tags.member"))
tags = dict(
(item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
parameters_list = self._get_list_prefix("Parameters.member")
parameters = dict([
(parameter['parameter_key'], parameter['parameter_value'])
for parameter
in parameters_list
])
parameters = dict(
[
(parameter["parameter_key"], parameter["parameter_value"])
for parameter in parameters_list
]
)
operation = self.cloudformation_backend.update_stack_set(
stackset_name=stackset_name,
template=template_body,
@ -496,18 +520,20 @@ class CloudFormationResponse(BaseResponse):
execution_role=execution_role,
accounts=accounts,
regions=regions,
operation_id=operation_id
operation_id=operation_id,
)
template = self.response_template(UPDATE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(operation=operation)
def update_stack_instances(self):
stackset_name = self._get_param('StackSetName')
accounts = self._get_multi_param('Accounts.member')
regions = self._get_multi_param('Regions.member')
parameters = self._get_multi_param('ParameterOverrides.member')
operation = self.cloudformation_backend.get_stack_set(stackset_name).update_instances(accounts, regions, parameters)
stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param("Regions.member")
parameters = self._get_multi_param("ParameterOverrides.member")
operation = self.cloudformation_backend.get_stack_set(
stackset_name
).update_instances(accounts, regions, parameters)
template = self.response_template(UPDATE_STACK_INSTANCES_RESPONSE_TEMPLATE)
return template.render(operation=operation)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import CloudFormationResponse
url_bases = [
"https?://cloudformation.(.+).amazonaws.com",
]
url_bases = ["https?://cloudformation.(.+).amazonaws.com"]
url_paths = {
'{0}/$': CloudFormationResponse.dispatch,
}
url_paths = {"{0}/$": CloudFormationResponse.dispatch}

View File

@ -11,44 +11,51 @@ from cfnlint import decode, core
def generate_stack_id(stack_name, region="us-east-1", account="123456789"):
random_id = uuid.uuid4()
return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format(region, account, stack_name, random_id)
return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format(
region, account, stack_name, random_id
)
def generate_changeset_id(changeset_name, region_name):
random_id = uuid.uuid4()
return 'arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}'.format(region_name, changeset_name, random_id)
return "arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}".format(
region_name, changeset_name, random_id
)
def generate_stackset_id(stackset_name):
random_id = uuid.uuid4()
return '{}:{}'.format(stackset_name, random_id)
return "{}:{}".format(stackset_name, random_id)
def generate_stackset_arn(stackset_id, region_name):
return 'arn:aws:cloudformation:{}:123456789012:stackset/{}'.format(region_name, stackset_id)
return "arn:aws:cloudformation:{}:123456789012:stackset/{}".format(
region_name, stackset_id
)
def random_suffix():
size = 12
chars = list(range(10)) + list(string.ascii_uppercase)
return ''.join(six.text_type(random.choice(chars)) for x in range(size))
return "".join(six.text_type(random.choice(chars)) for x in range(size))
def yaml_tag_constructor(loader, tag, node):
"""convert shorthand intrinsic function to full name
"""
def _f(loader, tag, node):
if tag == '!GetAtt':
return node.value.split('.')
if tag == "!GetAtt":
return node.value.split(".")
elif type(node) == yaml.SequenceNode:
return loader.construct_sequence(node)
else:
return node.value
if tag == '!Ref':
key = 'Ref'
if tag == "!Ref":
key = "Ref"
else:
key = 'Fn::{}'.format(tag[1:])
key = "Fn::{}".format(tag[1:])
return {key: _f(loader, tag, node)}
@ -71,13 +78,9 @@ def validate_template_cfn_lint(template):
rules = core.get_rules([], [], [])
# Use us-east-1 region (spec file) for validation
regions = ['us-east-1']
regions = ["us-east-1"]
# Process all the rules and gather the errors
matches = core.run_checks(
abs_filename,
template,
rules,
regions)
matches = core.run_checks(abs_filename, template, rules, regions)
return matches

View File

@ -1,6 +1,6 @@
from .models import cloudwatch_backends
from ..core.models import base_decorator, deprecated_base_decorator
cloudwatch_backend = cloudwatch_backends['us-east-1']
cloudwatch_backend = cloudwatch_backends["us-east-1"]
mock_cloudwatch = base_decorator(cloudwatch_backends)
mock_cloudwatch_deprecated = deprecated_base_decorator(cloudwatch_backends)

View File

@ -1,4 +1,3 @@
import json
from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.core import BaseBackend, BaseModel
@ -14,7 +13,6 @@ _EMPTY_LIST = tuple()
class Dimension(object):
def __init__(self, name, value):
self.name = name
self.value = value
@ -49,10 +47,23 @@ def daterange(start, stop, step=timedelta(days=1), inclusive=False):
class FakeAlarm(BaseModel):
def __init__(self, name, namespace, metric_name, comparison_operator, evaluation_periods,
period, threshold, statistic, description, dimensions, alarm_actions,
ok_actions, insufficient_data_actions, unit):
def __init__(
self,
name,
namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
):
self.name = name
self.namespace = namespace
self.metric_name = metric_name
@ -62,8 +73,9 @@ class FakeAlarm(BaseModel):
self.threshold = threshold
self.statistic = statistic
self.description = description
self.dimensions = [Dimension(dimension['name'], dimension[
'value']) for dimension in dimensions]
self.dimensions = [
Dimension(dimension["name"], dimension["value"]) for dimension in dimensions
]
self.alarm_actions = alarm_actions
self.ok_actions = ok_actions
self.insufficient_data_actions = insufficient_data_actions
@ -72,15 +84,21 @@ class FakeAlarm(BaseModel):
self.history = []
self.state_reason = ''
self.state_reason_data = '{}'
self.state_value = 'OK'
self.state_reason = ""
self.state_reason_data = "{}"
self.state_value = "OK"
self.state_updated_timestamp = datetime.utcnow()
def update_state(self, reason, reason_data, state_value):
# History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action
self.history.append(
('StateUpdate', self.state_reason, self.state_reason_data, self.state_value, self.state_updated_timestamp)
(
"StateUpdate",
self.state_reason,
self.state_reason_data,
self.state_value,
self.state_updated_timestamp,
)
)
self.state_reason = reason
@ -90,14 +108,14 @@ class FakeAlarm(BaseModel):
class MetricDatum(BaseModel):
def __init__(self, namespace, name, value, dimensions, timestamp):
self.namespace = namespace
self.name = name
self.value = value
self.timestamp = timestamp or datetime.utcnow().replace(tzinfo=tzutc())
self.dimensions = [Dimension(dimension['Name'], dimension[
'Value']) for dimension in dimensions]
self.dimensions = [
Dimension(dimension["Name"], dimension["Value"]) for dimension in dimensions
]
class Dashboard(BaseModel):
@ -120,7 +138,7 @@ class Dashboard(BaseModel):
return len(self.body)
def __repr__(self):
return '<CloudWatchDashboard {0}>'.format(self.name)
return "<CloudWatchDashboard {0}>".format(self.name)
class Statistics:
@ -131,7 +149,7 @@ class Statistics:
@property
def sample_count(self):
if 'SampleCount' not in self.stats:
if "SampleCount" not in self.stats:
return None
return len(self.values)
@ -142,28 +160,28 @@ class Statistics:
@property
def sum(self):
if 'Sum' not in self.stats:
if "Sum" not in self.stats:
return None
return sum(self.values)
@property
def minimum(self):
if 'Minimum' not in self.stats:
if "Minimum" not in self.stats:
return None
return min(self.values)
@property
def maximum(self):
if 'Maximum' not in self.stats:
if "Maximum" not in self.stats:
return None
return max(self.values)
@property
def average(self):
if 'Average' not in self.stats:
if "Average" not in self.stats:
return None
# when moto is 3.4+ we can switch to the statistics module
@ -171,18 +189,44 @@ class Statistics:
class CloudWatchBackend(BaseBackend):
def __init__(self):
self.alarms = {}
self.dashboards = {}
self.metric_data = []
def put_metric_alarm(self, name, namespace, metric_name, comparison_operator, evaluation_periods,
period, threshold, statistic, description, dimensions,
alarm_actions, ok_actions, insufficient_data_actions, unit):
alarm = FakeAlarm(name, namespace, metric_name, comparison_operator, evaluation_periods, period,
threshold, statistic, description, dimensions, alarm_actions,
ok_actions, insufficient_data_actions, unit)
def put_metric_alarm(
self,
name,
namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
):
alarm = FakeAlarm(
name,
namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
)
self.alarms[name] = alarm
return alarm
@ -214,14 +258,12 @@ class CloudWatchBackend(BaseBackend):
]
def get_alarms_by_alarm_names(self, alarm_names):
return [
alarm
for alarm in self.alarms.values()
if alarm.name in alarm_names
]
return [alarm for alarm in self.alarms.values() if alarm.name in alarm_names]
def get_alarms_by_state_value(self, target_state):
return filter(lambda alarm: alarm.state_value == target_state, self.alarms.values())
return filter(
lambda alarm: alarm.state_value == target_state, self.alarms.values()
)
def delete_alarms(self, alarm_names):
for alarm_name in alarm_names:
@ -230,17 +272,31 @@ class CloudWatchBackend(BaseBackend):
def put_metric_data(self, namespace, metric_data):
for metric_member in metric_data:
# Preserve "datetime" for get_metric_statistics comparisons
timestamp = metric_member.get('Timestamp')
timestamp = metric_member.get("Timestamp")
if timestamp is not None and type(timestamp) != datetime:
timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ')
timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
timestamp = timestamp.replace(tzinfo=tzutc())
self.metric_data.append(MetricDatum(
namespace, metric_member['MetricName'], float(metric_member.get('Value', 0)), metric_member.get('Dimensions.member', _EMPTY_LIST), timestamp))
self.metric_data.append(
MetricDatum(
namespace,
metric_member["MetricName"],
float(metric_member.get("Value", 0)),
metric_member.get("Dimensions.member", _EMPTY_LIST),
timestamp,
)
)
def get_metric_statistics(self, namespace, metric_name, start_time, end_time, period, stats):
def get_metric_statistics(
self, namespace, metric_name, start_time, end_time, period, stats
):
period_delta = timedelta(seconds=period)
filtered_data = [md for md in self.metric_data if
md.namespace == namespace and md.name == metric_name and start_time <= md.timestamp <= end_time]
filtered_data = [
md
for md in self.metric_data
if md.namespace == namespace
and md.name == metric_name
and start_time <= md.timestamp <= end_time
]
# earliest to oldest
filtered_data = sorted(filtered_data, key=lambda x: x.timestamp)
@ -249,9 +305,15 @@ class CloudWatchBackend(BaseBackend):
idx = 0
data = list()
for dt in daterange(filtered_data[0].timestamp, filtered_data[-1].timestamp + period_delta, period_delta):
for dt in daterange(
filtered_data[0].timestamp,
filtered_data[-1].timestamp + period_delta,
period_delta,
):
s = Statistics(stats, dt)
while idx < len(filtered_data) and filtered_data[idx].timestamp < (dt + period_delta):
while idx < len(filtered_data) and filtered_data[idx].timestamp < (
dt + period_delta
):
s.values.append(filtered_data[idx].value)
idx += 1
@ -268,7 +330,7 @@ class CloudWatchBackend(BaseBackend):
def put_dashboard(self, name, body):
self.dashboards[name] = Dashboard(name, body)
def list_dashboards(self, prefix=''):
def list_dashboards(self, prefix=""):
for key, value in self.dashboards.items():
if key.startswith(prefix):
yield value
@ -280,7 +342,12 @@ class CloudWatchBackend(BaseBackend):
left_over = to_delete - all_dashboards
if len(left_over) > 0:
# Some dashboards are not found
return False, 'The specified dashboard does not exist. [{0}]'.format(', '.join(left_over))
return (
False,
"The specified dashboard does not exist. [{0}]".format(
", ".join(left_over)
),
)
for dashboard in to_delete:
del self.dashboards[dashboard]
@ -295,32 +362,36 @@ class CloudWatchBackend(BaseBackend):
if reason_data is not None:
json.loads(reason_data)
except ValueError:
raise RESTError('InvalidFormat', 'StateReasonData is invalid JSON')
raise RESTError("InvalidFormat", "StateReasonData is invalid JSON")
if alarm_name not in self.alarms:
raise RESTError('ResourceNotFound', 'Alarm {0} not found'.format(alarm_name), status=404)
raise RESTError(
"ResourceNotFound", "Alarm {0} not found".format(alarm_name), status=404
)
if state_value not in ('OK', 'ALARM', 'INSUFFICIENT_DATA'):
raise RESTError('InvalidParameterValue', 'StateValue is not one of OK | ALARM | INSUFFICIENT_DATA')
if state_value not in ("OK", "ALARM", "INSUFFICIENT_DATA"):
raise RESTError(
"InvalidParameterValue",
"StateValue is not one of OK | ALARM | INSUFFICIENT_DATA",
)
self.alarms[alarm_name].update_state(reason, reason_data, state_value)
class LogGroup(BaseModel):
def __init__(self, spec):
# required
self.name = spec['LogGroupName']
self.name = spec["LogGroupName"]
# optional
self.tags = spec.get('Tags', [])
self.tags = spec.get("Tags", [])
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
spec = {
'LogGroupName': properties['LogGroupName']
}
optional_properties = 'Tags'.split()
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
spec = {"LogGroupName": properties["LogGroupName"]}
optional_properties = "Tags".split()
for prop in optional_properties:
if prop in properties:
spec[prop] = properties[prop]

View File

@ -6,7 +6,6 @@ from dateutil.parser import parse as dtparse
class CloudWatchResponse(BaseResponse):
@property
def cloudwatch_backend(self):
return cloudwatch_backends[self.region]
@ -17,45 +16,54 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def put_metric_alarm(self):
name = self._get_param('AlarmName')
namespace = self._get_param('Namespace')
metric_name = self._get_param('MetricName')
comparison_operator = self._get_param('ComparisonOperator')
evaluation_periods = self._get_param('EvaluationPeriods')
period = self._get_param('Period')
threshold = self._get_param('Threshold')
statistic = self._get_param('Statistic')
description = self._get_param('AlarmDescription')
dimensions = self._get_list_prefix('Dimensions.member')
alarm_actions = self._get_multi_param('AlarmActions.member')
ok_actions = self._get_multi_param('OKActions.member')
name = self._get_param("AlarmName")
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
comparison_operator = self._get_param("ComparisonOperator")
evaluation_periods = self._get_param("EvaluationPeriods")
period = self._get_param("Period")
threshold = self._get_param("Threshold")
statistic = self._get_param("Statistic")
description = self._get_param("AlarmDescription")
dimensions = self._get_list_prefix("Dimensions.member")
alarm_actions = self._get_multi_param("AlarmActions.member")
ok_actions = self._get_multi_param("OKActions.member")
insufficient_data_actions = self._get_multi_param(
"InsufficientDataActions.member")
unit = self._get_param('Unit')
alarm = self.cloudwatch_backend.put_metric_alarm(name, namespace, metric_name,
comparison_operator,
evaluation_periods, period,
threshold, statistic,
description, dimensions,
alarm_actions, ok_actions,
insufficient_data_actions,
unit)
"InsufficientDataActions.member"
)
unit = self._get_param("Unit")
alarm = self.cloudwatch_backend.put_metric_alarm(
name,
namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
)
template = self.response_template(PUT_METRIC_ALARM_TEMPLATE)
return template.render(alarm=alarm)
@amzn_request_id
def describe_alarms(self):
action_prefix = self._get_param('ActionPrefix')
alarm_name_prefix = self._get_param('AlarmNamePrefix')
alarm_names = self._get_multi_param('AlarmNames.member')
state_value = self._get_param('StateValue')
action_prefix = self._get_param("ActionPrefix")
alarm_name_prefix = self._get_param("AlarmNamePrefix")
alarm_names = self._get_multi_param("AlarmNames.member")
state_value = self._get_param("StateValue")
if action_prefix:
alarms = self.cloudwatch_backend.get_alarms_by_action_prefix(
action_prefix)
alarms = self.cloudwatch_backend.get_alarms_by_action_prefix(action_prefix)
elif alarm_name_prefix:
alarms = self.cloudwatch_backend.get_alarms_by_alarm_name_prefix(
alarm_name_prefix)
alarm_name_prefix
)
elif alarm_names:
alarms = self.cloudwatch_backend.get_alarms_by_alarm_names(alarm_names)
elif state_value:
@ -68,15 +76,15 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def delete_alarms(self):
alarm_names = self._get_multi_param('AlarmNames.member')
alarm_names = self._get_multi_param("AlarmNames.member")
self.cloudwatch_backend.delete_alarms(alarm_names)
template = self.response_template(DELETE_METRIC_ALARMS_TEMPLATE)
return template.render()
@amzn_request_id
def put_metric_data(self):
namespace = self._get_param('Namespace')
metric_data = self._get_multi_param('MetricData.member')
namespace = self._get_param("Namespace")
metric_data = self._get_multi_param("MetricData.member")
self.cloudwatch_backend.put_metric_data(namespace, metric_data)
template = self.response_template(PUT_METRIC_DATA_TEMPLATE)
@ -84,25 +92,29 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def get_metric_statistics(self):
namespace = self._get_param('Namespace')
metric_name = self._get_param('MetricName')
start_time = dtparse(self._get_param('StartTime'))
end_time = dtparse(self._get_param('EndTime'))
period = int(self._get_param('Period'))
namespace = self._get_param("Namespace")
metric_name = self._get_param("MetricName")
start_time = dtparse(self._get_param("StartTime"))
end_time = dtparse(self._get_param("EndTime"))
period = int(self._get_param("Period"))
statistics = self._get_multi_param("Statistics.member")
# Unsupported Parameters (To Be Implemented)
unit = self._get_param('Unit')
extended_statistics = self._get_param('ExtendedStatistics')
dimensions = self._get_param('Dimensions')
unit = self._get_param("Unit")
extended_statistics = self._get_param("ExtendedStatistics")
dimensions = self._get_param("Dimensions")
if unit or extended_statistics or dimensions:
raise NotImplemented()
raise NotImplementedError()
# TODO: this should instead throw InvalidParameterCombination
if not statistics:
raise NotImplemented("Must specify either Statistics or ExtendedStatistics")
raise NotImplementedError(
"Must specify either Statistics or ExtendedStatistics"
)
datapoints = self.cloudwatch_backend.get_metric_statistics(namespace, metric_name, start_time, end_time, period, statistics)
datapoints = self.cloudwatch_backend.get_metric_statistics(
namespace, metric_name, start_time, end_time, period, statistics
)
template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE)
return template.render(label=metric_name, datapoints=datapoints)
@ -114,13 +126,13 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def delete_dashboards(self):
dashboards = self._get_multi_param('DashboardNames.member')
dashboards = self._get_multi_param("DashboardNames.member")
if dashboards is None:
return self._error('InvalidParameterValue', 'Need at least 1 dashboard')
return self._error("InvalidParameterValue", "Need at least 1 dashboard")
status, error = self.cloudwatch_backend.delete_dashboards(dashboards)
if not status:
return self._error('ResourceNotFound', error)
return self._error("ResourceNotFound", error)
template = self.response_template(DELETE_DASHBOARD_TEMPLATE)
return template.render()
@ -143,18 +155,18 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def get_dashboard(self):
dashboard_name = self._get_param('DashboardName')
dashboard_name = self._get_param("DashboardName")
dashboard = self.cloudwatch_backend.get_dashboard(dashboard_name)
if dashboard is None:
return self._error('ResourceNotFound', 'Dashboard does not exist')
return self._error("ResourceNotFound", "Dashboard does not exist")
template = self.response_template(GET_DASHBOARD_TEMPLATE)
return template.render(dashboard=dashboard)
@amzn_request_id
def list_dashboards(self):
prefix = self._get_param('DashboardNamePrefix', '')
prefix = self._get_param("DashboardNamePrefix", "")
dashboards = self.cloudwatch_backend.list_dashboards(prefix)
@ -163,13 +175,13 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def put_dashboard(self):
name = self._get_param('DashboardName')
body = self._get_param('DashboardBody')
name = self._get_param("DashboardName")
body = self._get_param("DashboardBody")
try:
json.loads(body)
except ValueError:
return self._error('InvalidParameterInput', 'Body is invalid JSON')
return self._error("InvalidParameterInput", "Body is invalid JSON")
self.cloudwatch_backend.put_dashboard(name, body)
@ -178,12 +190,14 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id
def set_alarm_state(self):
alarm_name = self._get_param('AlarmName')
reason = self._get_param('StateReason')
reason_data = self._get_param('StateReasonData')
state_value = self._get_param('StateValue')
alarm_name = self._get_param("AlarmName")
reason = self._get_param("StateReason")
reason_data = self._get_param("StateReasonData")
state_value = self._get_param("StateValue")
self.cloudwatch_backend.set_alarm_state(alarm_name, reason, reason_data, state_value)
self.cloudwatch_backend.set_alarm_state(
alarm_name, reason, reason_data, state_value
)
template = self.response_template(SET_ALARM_STATE_TEMPLATE)
return template.render()

View File

@ -1,9 +1,5 @@
from .responses import CloudWatchResponse
url_bases = [
"https?://monitoring.(.+).amazonaws.com",
]
url_bases = ["https?://monitoring.(.+).amazonaws.com"]
url_paths = {
'{0}/$': CloudWatchResponse.dispatch,
}
url_paths = {"{0}/$": CloudWatchResponse.dispatch}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import cognitoidentity_backends
from ..core.models import base_decorator, deprecated_base_decorator
cognitoidentity_backend = cognitoidentity_backends['us-east-1']
cognitoidentity_backend = cognitoidentity_backends["us-east-1"]
mock_cognitoidentity = base_decorator(cognitoidentity_backends)
mock_cognitoidentity_deprecated = deprecated_base_decorator(cognitoidentity_backends)

View File

@ -6,10 +6,8 @@ from werkzeug.exceptions import BadRequest
class ResourceNotFoundError(BadRequest):
def __init__(self, message):
super(ResourceNotFoundError, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'ResourceNotFoundException',
})
self.description = json.dumps(
{"message": message, "__type": "ResourceNotFoundException"}
)

View File

@ -13,22 +13,24 @@ from .utils import get_random_identity_id
class CognitoIdentity(BaseModel):
def __init__(self, region, identity_pool_name, **kwargs):
self.identity_pool_name = identity_pool_name
self.allow_unauthenticated_identities = kwargs.get('allow_unauthenticated_identities', '')
self.supported_login_providers = kwargs.get('supported_login_providers', {})
self.developer_provider_name = kwargs.get('developer_provider_name', '')
self.open_id_connect_provider_arns = kwargs.get('open_id_connect_provider_arns', [])
self.cognito_identity_providers = kwargs.get('cognito_identity_providers', [])
self.saml_provider_arns = kwargs.get('saml_provider_arns', [])
self.allow_unauthenticated_identities = kwargs.get(
"allow_unauthenticated_identities", ""
)
self.supported_login_providers = kwargs.get("supported_login_providers", {})
self.developer_provider_name = kwargs.get("developer_provider_name", "")
self.open_id_connect_provider_arns = kwargs.get(
"open_id_connect_provider_arns", []
)
self.cognito_identity_providers = kwargs.get("cognito_identity_providers", [])
self.saml_provider_arns = kwargs.get("saml_provider_arns", [])
self.identity_pool_id = get_random_identity_id(region)
self.creation_time = datetime.datetime.utcnow()
class CognitoIdentityBackend(BaseBackend):
def __init__(self, region):
super(CognitoIdentityBackend, self).__init__()
self.region = region
@ -45,47 +47,61 @@ class CognitoIdentityBackend(BaseBackend):
if not identity_pool:
raise ResourceNotFoundError(identity_pool)
response = json.dumps({
'AllowUnauthenticatedIdentities': identity_pool.allow_unauthenticated_identities,
'CognitoIdentityProviders': identity_pool.cognito_identity_providers,
'DeveloperProviderName': identity_pool.developer_provider_name,
'IdentityPoolId': identity_pool.identity_pool_id,
'IdentityPoolName': identity_pool.identity_pool_name,
'IdentityPoolTags': {},
'OpenIdConnectProviderARNs': identity_pool.open_id_connect_provider_arns,
'SamlProviderARNs': identity_pool.saml_provider_arns,
'SupportedLoginProviders': identity_pool.supported_login_providers
})
response = json.dumps(
{
"AllowUnauthenticatedIdentities": identity_pool.allow_unauthenticated_identities,
"CognitoIdentityProviders": identity_pool.cognito_identity_providers,
"DeveloperProviderName": identity_pool.developer_provider_name,
"IdentityPoolId": identity_pool.identity_pool_id,
"IdentityPoolName": identity_pool.identity_pool_name,
"IdentityPoolTags": {},
"OpenIdConnectProviderARNs": identity_pool.open_id_connect_provider_arns,
"SamlProviderARNs": identity_pool.saml_provider_arns,
"SupportedLoginProviders": identity_pool.supported_login_providers,
}
)
return response
def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities,
supported_login_providers, developer_provider_name, open_id_connect_provider_arns,
cognito_identity_providers, saml_provider_arns):
new_identity = CognitoIdentity(self.region, identity_pool_name,
def create_identity_pool(
self,
identity_pool_name,
allow_unauthenticated_identities,
supported_login_providers,
developer_provider_name,
open_id_connect_provider_arns,
cognito_identity_providers,
saml_provider_arns,
):
new_identity = CognitoIdentity(
self.region,
identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities,
supported_login_providers=supported_login_providers,
developer_provider_name=developer_provider_name,
open_id_connect_provider_arns=open_id_connect_provider_arns,
cognito_identity_providers=cognito_identity_providers,
saml_provider_arns=saml_provider_arns)
saml_provider_arns=saml_provider_arns,
)
self.identity_pools[new_identity.identity_pool_id] = new_identity
response = json.dumps({
'IdentityPoolId': new_identity.identity_pool_id,
'IdentityPoolName': new_identity.identity_pool_name,
'AllowUnauthenticatedIdentities': new_identity.allow_unauthenticated_identities,
'SupportedLoginProviders': new_identity.supported_login_providers,
'DeveloperProviderName': new_identity.developer_provider_name,
'OpenIdConnectProviderARNs': new_identity.open_id_connect_provider_arns,
'CognitoIdentityProviders': new_identity.cognito_identity_providers,
'SamlProviderARNs': new_identity.saml_provider_arns
})
response = json.dumps(
{
"IdentityPoolId": new_identity.identity_pool_id,
"IdentityPoolName": new_identity.identity_pool_name,
"AllowUnauthenticatedIdentities": new_identity.allow_unauthenticated_identities,
"SupportedLoginProviders": new_identity.supported_login_providers,
"DeveloperProviderName": new_identity.developer_provider_name,
"OpenIdConnectProviderARNs": new_identity.open_id_connect_provider_arns,
"CognitoIdentityProviders": new_identity.cognito_identity_providers,
"SamlProviderARNs": new_identity.saml_provider_arns,
}
)
return response
def get_id(self):
identity_id = {'IdentityId': get_random_identity_id(self.region)}
identity_id = {"IdentityId": get_random_identity_id(self.region)}
return json.dumps(identity_id)
def get_credentials_for_identity(self, identity_id):
@ -95,31 +111,26 @@ class CognitoIdentityBackend(BaseBackend):
expiration_str = str(iso_8601_datetime_with_milliseconds(expiration))
response = json.dumps(
{
"Credentials":
{
"AccessKeyId": "TESTACCESSKEY12345",
"Expiration": expiration_str,
"SecretKey": "ABCSECRETKEY",
"SessionToken": "ABC12345"
},
"IdentityId": identity_id
})
"Credentials": {
"AccessKeyId": "TESTACCESSKEY12345",
"Expiration": expiration_str,
"SecretKey": "ABCSECRETKEY",
"SessionToken": "ABC12345",
},
"IdentityId": identity_id,
}
)
return response
def get_open_id_token_for_developer_identity(self, identity_id):
response = json.dumps(
{
"IdentityId": identity_id,
"Token": get_random_identity_id(self.region)
})
{"IdentityId": identity_id, "Token": get_random_identity_id(self.region)}
)
return response
def get_open_id_token(self, identity_id):
response = json.dumps(
{
"IdentityId": identity_id,
"Token": get_random_identity_id(self.region)
}
{"IdentityId": identity_id, "Token": get_random_identity_id(self.region)}
)
return response

View File

@ -6,15 +6,16 @@ from .utils import get_random_identity_id
class CognitoIdentityResponse(BaseResponse):
def create_identity_pool(self):
identity_pool_name = self._get_param('IdentityPoolName')
allow_unauthenticated_identities = self._get_param('AllowUnauthenticatedIdentities')
supported_login_providers = self._get_param('SupportedLoginProviders')
developer_provider_name = self._get_param('DeveloperProviderName')
open_id_connect_provider_arns = self._get_param('OpenIdConnectProviderARNs')
cognito_identity_providers = self._get_param('CognitoIdentityProviders')
saml_provider_arns = self._get_param('SamlProviderARNs')
identity_pool_name = self._get_param("IdentityPoolName")
allow_unauthenticated_identities = self._get_param(
"AllowUnauthenticatedIdentities"
)
supported_login_providers = self._get_param("SupportedLoginProviders")
developer_provider_name = self._get_param("DeveloperProviderName")
open_id_connect_provider_arns = self._get_param("OpenIdConnectProviderARNs")
cognito_identity_providers = self._get_param("CognitoIdentityProviders")
saml_provider_arns = self._get_param("SamlProviderARNs")
return cognitoidentity_backends[self.region].create_identity_pool(
identity_pool_name=identity_pool_name,
@ -23,20 +24,27 @@ class CognitoIdentityResponse(BaseResponse):
developer_provider_name=developer_provider_name,
open_id_connect_provider_arns=open_id_connect_provider_arns,
cognito_identity_providers=cognito_identity_providers,
saml_provider_arns=saml_provider_arns)
saml_provider_arns=saml_provider_arns,
)
def get_id(self):
return cognitoidentity_backends[self.region].get_id()
def describe_identity_pool(self):
return cognitoidentity_backends[self.region].describe_identity_pool(self._get_param('IdentityPoolId'))
return cognitoidentity_backends[self.region].describe_identity_pool(
self._get_param("IdentityPoolId")
)
def get_credentials_for_identity(self):
return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId'))
return cognitoidentity_backends[self.region].get_credentials_for_identity(
self._get_param("IdentityId")
)
def get_open_id_token_for_developer_identity(self):
return cognitoidentity_backends[self.region].get_open_id_token_for_developer_identity(
self._get_param('IdentityId') or get_random_identity_id(self.region)
return cognitoidentity_backends[
self.region
].get_open_id_token_for_developer_identity(
self._get_param("IdentityId") or get_random_identity_id(self.region)
)
def get_open_id_token(self):

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import CognitoIdentityResponse
url_bases = [
"https?://cognito-identity.(.+).amazonaws.com",
]
url_bases = ["https?://cognito-identity.(.+).amazonaws.com"]
url_paths = {
'{0}/$': CognitoIdentityResponse.dispatch,
}
url_paths = {"{0}/$": CognitoIdentityResponse.dispatch}

View File

@ -5,50 +5,40 @@ from werkzeug.exceptions import BadRequest
class ResourceNotFoundError(BadRequest):
def __init__(self, message):
super(ResourceNotFoundError, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'ResourceNotFoundException',
})
self.description = json.dumps(
{"message": message, "__type": "ResourceNotFoundException"}
)
class UserNotFoundError(BadRequest):
def __init__(self, message):
super(UserNotFoundError, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'UserNotFoundException',
})
self.description = json.dumps(
{"message": message, "__type": "UserNotFoundException"}
)
class UsernameExistsException(BadRequest):
def __init__(self, message):
super(UsernameExistsException, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'UsernameExistsException',
})
self.description = json.dumps(
{"message": message, "__type": "UsernameExistsException"}
)
class GroupExistsException(BadRequest):
def __init__(self, message):
super(GroupExistsException, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'GroupExistsException',
})
self.description = json.dumps(
{"message": message, "__type": "GroupExistsException"}
)
class NotAuthorizedError(BadRequest):
def __init__(self, message):
super(NotAuthorizedError, self).__init__()
self.description = json.dumps({
"message": message,
'__type': 'NotAuthorizedException',
})
self.description = json.dumps(
{"message": message, "__type": "NotAuthorizedException"}
)

View File

@ -14,8 +14,13 @@ from jose import jws
from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel
from .exceptions import GroupExistsException, NotAuthorizedError, ResourceNotFoundError, UserNotFoundError, \
UsernameExistsException
from .exceptions import (
GroupExistsException,
NotAuthorizedError,
ResourceNotFoundError,
UserNotFoundError,
UsernameExistsException,
)
UserStatus = {
"FORCE_CHANGE_PASSWORD": "FORCE_CHANGE_PASSWORD",
@ -45,19 +50,22 @@ def paginate(limit, start_arg="next_token", limit_arg="max_results"):
def outer_wrapper(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
start = int(default_start if kwargs.get(start_arg) is None else kwargs[start_arg])
start = int(
default_start if kwargs.get(start_arg) is None else kwargs[start_arg]
)
lim = int(limit if kwargs.get(limit_arg) is None else kwargs[limit_arg])
stop = start + lim
result = func(*args, **kwargs)
limited_results = list(itertools.islice(result, start, stop))
next_token = stop if stop < len(result) else None
return limited_results, next_token
return wrapper
return outer_wrapper
class CognitoIdpUserPool(BaseModel):
def __init__(self, region, name, extended_config):
self.region = region
self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex))
@ -75,7 +83,9 @@ class CognitoIdpUserPool(BaseModel):
self.access_tokens = {}
self.id_tokens = {}
with open(os.path.join(os.path.dirname(__file__), "resources/jwks-private.json")) as f:
with open(
os.path.join(os.path.dirname(__file__), "resources/jwks-private.json")
) as f:
self.json_web_key = json.loads(f.read())
def _base_json(self):
@ -92,14 +102,18 @@ class CognitoIdpUserPool(BaseModel):
if extended:
user_pool_json.update(self.extended_config)
else:
user_pool_json["LambdaConfig"] = self.extended_config.get("LambdaConfig") or {}
user_pool_json["LambdaConfig"] = (
self.extended_config.get("LambdaConfig") or {}
)
return user_pool_json
def create_jwt(self, client_id, username, expires_in=60 * 60, extra_data={}):
now = int(time.time())
payload = {
"iss": "https://cognito-idp.{}.amazonaws.com/{}".format(self.region, self.id),
"iss": "https://cognito-idp.{}.amazonaws.com/{}".format(
self.region, self.id
),
"sub": self.users[username].id,
"aud": client_id,
"token_use": "id",
@ -108,7 +122,7 @@ class CognitoIdpUserPool(BaseModel):
}
payload.update(extra_data)
return jws.sign(payload, self.json_web_key, algorithm='RS256'), expires_in
return jws.sign(payload, self.json_web_key, algorithm="RS256"), expires_in
def create_id_token(self, client_id, username):
id_token, expires_in = self.create_jwt(client_id, username)
@ -121,11 +135,10 @@ class CognitoIdpUserPool(BaseModel):
return refresh_token
def create_access_token(self, client_id, username):
extra_data = self.get_user_extra_data_by_client_id(
client_id, username
extra_data = self.get_user_extra_data_by_client_id(client_id, username)
access_token, expires_in = self.create_jwt(
client_id, username, extra_data=extra_data
)
access_token, expires_in = self.create_jwt(client_id, username,
extra_data=extra_data)
self.access_tokens[access_token] = (client_id, username)
return access_token, expires_in
@ -143,29 +156,27 @@ class CognitoIdpUserPool(BaseModel):
current_client = self.clients.get(client_id, None)
if current_client:
for readable_field in current_client.get_readable_fields():
attribute = list(filter(
lambda f: f['Name'] == readable_field,
self.users.get(username).attributes
))
attribute = list(
filter(
lambda f: f["Name"] == readable_field,
self.users.get(username).attributes,
)
)
if len(attribute) > 0:
extra_data.update({
attribute[0]['Name']: attribute[0]['Value']
})
extra_data.update({attribute[0]["Name"]: attribute[0]["Value"]})
return extra_data
class CognitoIdpUserPoolDomain(BaseModel):
def __init__(self, user_pool_id, domain, custom_domain_config=None):
self.user_pool_id = user_pool_id
self.domain = domain
self.custom_domain_config = custom_domain_config or {}
def _distribution_name(self):
if self.custom_domain_config and \
'CertificateArn' in self.custom_domain_config:
if self.custom_domain_config and "CertificateArn" in self.custom_domain_config:
hash = hashlib.md5(
self.custom_domain_config['CertificateArn'].encode('utf-8')
self.custom_domain_config["CertificateArn"].encode("utf-8")
).hexdigest()
return "{hash}.cloudfront.net".format(hash=hash[:16])
return None
@ -183,14 +194,11 @@ class CognitoIdpUserPoolDomain(BaseModel):
"Version": None,
}
elif distribution:
return {
"CloudFrontDomain": distribution,
}
return {"CloudFrontDomain": distribution}
return None
class CognitoIdpUserPoolClient(BaseModel):
def __init__(self, user_pool_id, extended_config):
self.user_pool_id = user_pool_id
self.id = str(uuid.uuid4())
@ -212,11 +220,10 @@ class CognitoIdpUserPoolClient(BaseModel):
return user_pool_client_json
def get_readable_fields(self):
return self.extended_config.get('ReadAttributes', [])
return self.extended_config.get("ReadAttributes", [])
class CognitoIdpIdentityProvider(BaseModel):
def __init__(self, name, extended_config):
self.name = name
self.extended_config = extended_config or {}
@ -240,7 +247,6 @@ class CognitoIdpIdentityProvider(BaseModel):
class CognitoIdpGroup(BaseModel):
def __init__(self, user_pool_id, group_name, description, role_arn, precedence):
self.user_pool_id = user_pool_id
self.group_name = group_name
@ -267,7 +273,6 @@ class CognitoIdpGroup(BaseModel):
class CognitoIdpUser(BaseModel):
def __init__(self, user_pool_id, username, password, status, attributes):
self.id = str(uuid.uuid4())
self.user_pool_id = user_pool_id
@ -300,19 +305,18 @@ class CognitoIdpUser(BaseModel):
{
"Enabled": self.enabled,
attributes_key: self.attributes,
"MFAOptions": []
"MFAOptions": [],
}
)
return user_json
def update_attributes(self, new_attributes):
def flatten_attrs(attrs):
return {attr['Name']: attr['Value'] for attr in attrs}
return {attr["Name"]: attr["Value"] for attr in attrs}
def expand_attrs(attrs):
return [{'Name': k, 'Value': v} for k, v in attrs.items()]
return [{"Name": k, "Value": v} for k, v in attrs.items()]
flat_attributes = flatten_attrs(self.attributes)
flat_attributes.update(flatten_attrs(new_attributes))
@ -320,7 +324,6 @@ class CognitoIdpUser(BaseModel):
class CognitoIdpBackend(BaseBackend):
def __init__(self, region):
super(CognitoIdpBackend, self).__init__()
self.region = region
@ -496,7 +499,9 @@ class CognitoIdpBackend(BaseBackend):
if not user_pool:
raise ResourceNotFoundError(user_pool_id)
group = CognitoIdpGroup(user_pool_id, group_name, description, role_arn, precedence)
group = CognitoIdpGroup(
user_pool_id, group_name, description, role_arn, precedence
)
if group.group_name in user_pool.groups:
raise GroupExistsException("A group with the name already exists")
user_pool.groups[group.group_name] = group
@ -565,7 +570,13 @@ class CognitoIdpBackend(BaseBackend):
if username in user_pool.users:
raise UsernameExistsException(username)
user = CognitoIdpUser(user_pool_id, username, temporary_password, UserStatus["FORCE_CHANGE_PASSWORD"], attributes)
user = CognitoIdpUser(
user_pool_id,
username,
temporary_password,
UserStatus["FORCE_CHANGE_PASSWORD"],
attributes,
)
user_pool.users[user.username] = user
return user
@ -611,7 +622,9 @@ class CognitoIdpBackend(BaseBackend):
def _log_user_in(self, user_pool, client, username):
refresh_token = user_pool.create_refresh_token(client.id, username)
access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token)
access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(
refresh_token
)
return {
"AuthenticationResult": {
@ -654,7 +667,11 @@ class CognitoIdpBackend(BaseBackend):
return self._log_user_in(user_pool, client, username)
elif auth_flow == "REFRESH_TOKEN":
refresh_token = auth_parameters.get("REFRESH_TOKEN")
id_token, access_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token)
(
id_token,
access_token,
expires_in,
) = user_pool.create_tokens_from_refresh_token(refresh_token)
return {
"AuthenticationResult": {
@ -666,7 +683,9 @@ class CognitoIdpBackend(BaseBackend):
else:
return {}
def respond_to_auth_challenge(self, session, client_id, challenge_name, challenge_responses):
def respond_to_auth_challenge(
self, session, client_id, challenge_name, challenge_responses
):
user_pool = self.sessions.get(session)
if not user_pool:
raise ResourceNotFoundError(session)

View File

@ -8,7 +8,6 @@ from .models import cognitoidp_backends, find_region_by_value
class CognitoIdpResponse(BaseResponse):
@property
def parameters(self):
return json.loads(self.body)
@ -16,10 +15,10 @@ class CognitoIdpResponse(BaseResponse):
# User pool
def create_user_pool(self):
name = self.parameters.pop("PoolName")
user_pool = cognitoidp_backends[self.region].create_user_pool(name, self.parameters)
return json.dumps({
"UserPool": user_pool.to_json(extended=True)
})
user_pool = cognitoidp_backends[self.region].create_user_pool(
name, self.parameters
)
return json.dumps({"UserPool": user_pool.to_json(extended=True)})
def list_user_pools(self):
max_results = self._get_param("MaxResults")
@ -27,9 +26,7 @@ class CognitoIdpResponse(BaseResponse):
user_pools, next_token = cognitoidp_backends[self.region].list_user_pools(
max_results=max_results, next_token=next_token
)
response = {
"UserPools": [user_pool.to_json() for user_pool in user_pools],
}
response = {"UserPools": [user_pool.to_json() for user_pool in user_pools]}
if next_token:
response["NextToken"] = str(next_token)
return json.dumps(response)
@ -37,9 +34,7 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool(self):
user_pool_id = self._get_param("UserPoolId")
user_pool = cognitoidp_backends[self.region].describe_user_pool(user_pool_id)
return json.dumps({
"UserPool": user_pool.to_json(extended=True)
})
return json.dumps({"UserPool": user_pool.to_json(extended=True)})
def delete_user_pool(self):
user_pool_id = self._get_param("UserPoolId")
@ -61,14 +56,14 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool_domain(self):
domain = self._get_param("Domain")
user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain(domain)
user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain(
domain
)
domain_description = {}
if user_pool_domain:
domain_description = user_pool_domain.to_json()
return json.dumps({
"DomainDescription": domain_description
})
return json.dumps({"DomainDescription": domain_description})
def delete_user_pool_domain(self):
domain = self._get_param("Domain")
@ -89,19 +84,24 @@ class CognitoIdpResponse(BaseResponse):
# User pool client
def create_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId")
user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(user_pool_id, self.parameters)
return json.dumps({
"UserPoolClient": user_pool_client.to_json(extended=True)
})
user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(
user_pool_id, self.parameters
)
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def list_user_pool_clients(self):
user_pool_id = self._get_param("UserPoolId")
max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken", "0")
user_pool_clients, next_token = cognitoidp_backends[self.region].list_user_pool_clients(user_pool_id,
max_results=max_results, next_token=next_token)
user_pool_clients, next_token = cognitoidp_backends[
self.region
].list_user_pool_clients(
user_pool_id, max_results=max_results, next_token=next_token
)
response = {
"UserPoolClients": [user_pool_client.to_json() for user_pool_client in user_pool_clients]
"UserPoolClients": [
user_pool_client.to_json() for user_pool_client in user_pool_clients
]
}
if next_token:
response["NextToken"] = str(next_token)
@ -110,43 +110,51 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId")
user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client(user_pool_id, client_id)
return json.dumps({
"UserPoolClient": user_pool_client.to_json(extended=True)
})
user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client(
user_pool_id, client_id
)
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def update_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId")
client_id = self.parameters.pop("ClientId")
user_pool_client = cognitoidp_backends[self.region].update_user_pool_client(user_pool_id, client_id, self.parameters)
return json.dumps({
"UserPoolClient": user_pool_client.to_json(extended=True)
})
user_pool_client = cognitoidp_backends[self.region].update_user_pool_client(
user_pool_id, client_id, self.parameters
)
return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def delete_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId")
cognitoidp_backends[self.region].delete_user_pool_client(user_pool_id, client_id)
cognitoidp_backends[self.region].delete_user_pool_client(
user_pool_id, client_id
)
return ""
# Identity provider
def create_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
name = self.parameters.pop("ProviderName")
identity_provider = cognitoidp_backends[self.region].create_identity_provider(user_pool_id, name, self.parameters)
return json.dumps({
"IdentityProvider": identity_provider.to_json(extended=True)
})
identity_provider = cognitoidp_backends[self.region].create_identity_provider(
user_pool_id, name, self.parameters
)
return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def list_identity_providers(self):
user_pool_id = self._get_param("UserPoolId")
max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken", "0")
identity_providers, next_token = cognitoidp_backends[self.region].list_identity_providers(
identity_providers, next_token = cognitoidp_backends[
self.region
].list_identity_providers(
user_pool_id, max_results=max_results, next_token=next_token
)
response = {
"Providers": [identity_provider.to_json() for identity_provider in identity_providers]
"Providers": [
identity_provider.to_json() for identity_provider in identity_providers
]
}
if next_token:
response["NextToken"] = str(next_token)
@ -155,18 +163,22 @@ class CognitoIdpResponse(BaseResponse):
def describe_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].describe_identity_provider(user_pool_id, name)
return json.dumps({
"IdentityProvider": identity_provider.to_json(extended=True)
})
identity_provider = cognitoidp_backends[self.region].describe_identity_provider(
user_pool_id, name
)
return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def update_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].update_identity_provider(user_pool_id, name, self.parameters)
return json.dumps({
"IdentityProvider": identity_provider.to_json(extended=True)
})
identity_provider = cognitoidp_backends[self.region].update_identity_provider(
user_pool_id, name, self.parameters
)
return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def delete_identity_provider(self):
user_pool_id = self._get_param("UserPoolId")
@ -183,31 +195,21 @@ class CognitoIdpResponse(BaseResponse):
precedence = self._get_param("Precedence")
group = cognitoidp_backends[self.region].create_group(
user_pool_id,
group_name,
description,
role_arn,
precedence,
user_pool_id, group_name, description, role_arn, precedence
)
return json.dumps({
"Group": group.to_json(),
})
return json.dumps({"Group": group.to_json()})
def get_group(self):
group_name = self._get_param("GroupName")
user_pool_id = self._get_param("UserPoolId")
group = cognitoidp_backends[self.region].get_group(user_pool_id, group_name)
return json.dumps({
"Group": group.to_json(),
})
return json.dumps({"Group": group.to_json()})
def list_groups(self):
user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].list_groups(user_pool_id)
return json.dumps({
"Groups": [group.to_json() for group in groups],
})
return json.dumps({"Groups": [group.to_json() for group in groups]})
def delete_group(self):
group_name = self._get_param("GroupName")
@ -221,9 +223,7 @@ class CognitoIdpResponse(BaseResponse):
group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_add_user_to_group(
user_pool_id,
group_name,
username,
user_pool_id, group_name, username
)
return ""
@ -231,18 +231,18 @@ class CognitoIdpResponse(BaseResponse):
def list_users_in_group(self):
user_pool_id = self._get_param("UserPoolId")
group_name = self._get_param("GroupName")
users = cognitoidp_backends[self.region].list_users_in_group(user_pool_id, group_name)
return json.dumps({
"Users": [user.to_json(extended=True) for user in users],
})
users = cognitoidp_backends[self.region].list_users_in_group(
user_pool_id, group_name
)
return json.dumps({"Users": [user.to_json(extended=True) for user in users]})
def admin_list_groups_for_user(self):
username = self._get_param("Username")
user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].admin_list_groups_for_user(user_pool_id, username)
return json.dumps({
"Groups": [group.to_json() for group in groups],
})
groups = cognitoidp_backends[self.region].admin_list_groups_for_user(
user_pool_id, username
)
return json.dumps({"Groups": [group.to_json() for group in groups]})
def admin_remove_user_from_group(self):
user_pool_id = self._get_param("UserPoolId")
@ -250,9 +250,7 @@ class CognitoIdpResponse(BaseResponse):
group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_remove_user_from_group(
user_pool_id,
group_name,
username,
user_pool_id, group_name, username
)
return ""
@ -266,28 +264,24 @@ class CognitoIdpResponse(BaseResponse):
user_pool_id,
username,
temporary_password,
self._get_param("UserAttributes", [])
self._get_param("UserAttributes", []),
)
return json.dumps({
"User": user.to_json(extended=True)
})
return json.dumps({"User": user.to_json(extended=True)})
def admin_get_user(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
user = cognitoidp_backends[self.region].admin_get_user(user_pool_id, username)
return json.dumps(
user.to_json(extended=True, attributes_key="UserAttributes")
)
return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes"))
def list_users(self):
user_pool_id = self._get_param("UserPoolId")
limit = self._get_param("Limit")
token = self._get_param("PaginationToken")
users, token = cognitoidp_backends[self.region].list_users(user_pool_id,
limit=limit,
pagination_token=token)
users, token = cognitoidp_backends[self.region].list_users(
user_pool_id, limit=limit, pagination_token=token
)
response = {"Users": [user.to_json(extended=True) for user in users]}
if token:
response["PaginationToken"] = str(token)
@ -318,10 +312,7 @@ class CognitoIdpResponse(BaseResponse):
auth_parameters = self._get_param("AuthParameters")
auth_result = cognitoidp_backends[self.region].admin_initiate_auth(
user_pool_id,
client_id,
auth_flow,
auth_parameters,
user_pool_id, client_id, auth_flow, auth_parameters
)
return json.dumps(auth_result)
@ -332,21 +323,15 @@ class CognitoIdpResponse(BaseResponse):
challenge_name = self._get_param("ChallengeName")
challenge_responses = self._get_param("ChallengeResponses")
auth_result = cognitoidp_backends[self.region].respond_to_auth_challenge(
session,
client_id,
challenge_name,
challenge_responses,
session, client_id, challenge_name, challenge_responses
)
return json.dumps(auth_result)
def forgot_password(self):
return json.dumps({
"CodeDeliveryDetails": {
"DeliveryMedium": "EMAIL",
"Destination": "...",
}
})
return json.dumps(
{"CodeDeliveryDetails": {"DeliveryMedium": "EMAIL", "Destination": "..."}}
)
# This endpoint receives no authorization header, so if moto-server is listening
# on localhost (doesn't get a region in the host header), it doesn't know what
@ -357,7 +342,9 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username")
password = self._get_param("Password")
region = find_region_by_value("client_id", client_id)
cognitoidp_backends[region].confirm_forgot_password(client_id, username, password)
cognitoidp_backends[region].confirm_forgot_password(
client_id, username, password
)
return ""
# Ditto the comment on confirm_forgot_password.
@ -366,21 +353,26 @@ class CognitoIdpResponse(BaseResponse):
previous_password = self._get_param("PreviousPassword")
proposed_password = self._get_param("ProposedPassword")
region = find_region_by_value("access_token", access_token)
cognitoidp_backends[region].change_password(access_token, previous_password, proposed_password)
cognitoidp_backends[region].change_password(
access_token, previous_password, proposed_password
)
return ""
def admin_update_user_attributes(self):
user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username")
attributes = self._get_param("UserAttributes")
cognitoidp_backends[self.region].admin_update_user_attributes(user_pool_id, username, attributes)
cognitoidp_backends[self.region].admin_update_user_attributes(
user_pool_id, username, attributes
)
return ""
class CognitoIdpJsonWebKeyResponse(BaseResponse):
def __init__(self):
with open(os.path.join(os.path.dirname(__file__), "resources/jwks-public.json")) as f:
with open(
os.path.join(os.path.dirname(__file__), "resources/jwks-public.json")
) as f:
self.json_web_key = f.read()
def serve_json_web_key(self, request, full_url, headers):

View File

@ -1,11 +1,9 @@
from __future__ import unicode_literals
from .responses import CognitoIdpResponse, CognitoIdpJsonWebKeyResponse
url_bases = [
"https?://cognito-idp.(.+).amazonaws.com",
]
url_bases = ["https?://cognito-idp.(.+).amazonaws.com"]
url_paths = {
'{0}/$': CognitoIdpResponse.dispatch,
'{0}/<user_pool_id>/.well-known/jwks.json$': CognitoIdpJsonWebKeyResponse().serve_json_web_key,
"{0}/$": CognitoIdpResponse.dispatch,
"{0}/<user_pool_id>/.well-known/jwks.json$": CognitoIdpJsonWebKeyResponse().serve_json_web_key,
}

View File

@ -1,5 +1,5 @@
try:
from collections import OrderedDict # flake8: noqa
from collections import OrderedDict # noqa
except ImportError:
# python 2.6 or earlier, use backport
from ordereddict import OrderedDict # flake8: noqa
from ordereddict import OrderedDict # noqa

View File

@ -6,8 +6,12 @@ class NameTooLongException(JsonRESTError):
code = 400
def __init__(self, name, location):
message = '1 validation error detected: Value \'{name}\' at \'{location}\' failed to satisfy' \
' constraint: Member must have length less than or equal to 256'.format(name=name, location=location)
message = (
"1 validation error detected: Value '{name}' at '{location}' failed to satisfy"
" constraint: Member must have length less than or equal to 256".format(
name=name, location=location
)
)
super(NameTooLongException, self).__init__("ValidationException", message)
@ -15,41 +19,54 @@ class InvalidConfigurationRecorderNameException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'The configuration recorder name \'{name}\' is not valid, blank string.'.format(name=name)
super(InvalidConfigurationRecorderNameException, self).__init__("InvalidConfigurationRecorderNameException",
message)
message = "The configuration recorder name '{name}' is not valid, blank string.".format(
name=name
)
super(InvalidConfigurationRecorderNameException, self).__init__(
"InvalidConfigurationRecorderNameException", message
)
class MaxNumberOfConfigurationRecordersExceededException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'Failed to put configuration recorder \'{name}\' because the maximum number of ' \
'configuration recorders: 1 is reached.'.format(name=name)
message = (
"Failed to put configuration recorder '{name}' because the maximum number of "
"configuration recorders: 1 is reached.".format(name=name)
)
super(MaxNumberOfConfigurationRecordersExceededException, self).__init__(
"MaxNumberOfConfigurationRecordersExceededException", message)
"MaxNumberOfConfigurationRecordersExceededException", message
)
class InvalidRecordingGroupException(JsonRESTError):
code = 400
def __init__(self):
message = 'The recording group provided is not valid'
super(InvalidRecordingGroupException, self).__init__("InvalidRecordingGroupException", message)
message = "The recording group provided is not valid"
super(InvalidRecordingGroupException, self).__init__(
"InvalidRecordingGroupException", message
)
class InvalidResourceTypeException(JsonRESTError):
code = 400
def __init__(self, bad_list, good_list):
message = '{num} validation error detected: Value \'{bad_list}\' at ' \
'\'configurationRecorder.recordingGroup.resourceTypes\' failed to satisfy constraint: ' \
'Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]'.format(
num=len(bad_list), bad_list=bad_list, good_list=good_list)
message = (
"{num} validation error detected: Value '{bad_list}' at "
"'configurationRecorder.recordingGroup.resourceTypes' failed to satisfy constraint: "
"Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]".format(
num=len(bad_list), bad_list=bad_list, good_list=good_list
)
)
# For PY2:
message = str(message)
super(InvalidResourceTypeException, self).__init__("ValidationException", message)
super(InvalidResourceTypeException, self).__init__(
"ValidationException", message
)
class NoSuchConfigurationAggregatorException(JsonRESTError):
@ -57,36 +74,48 @@ class NoSuchConfigurationAggregatorException(JsonRESTError):
def __init__(self, number=1):
if number == 1:
message = 'The configuration aggregator does not exist. Check the configuration aggregator name and try again.'
message = "The configuration aggregator does not exist. Check the configuration aggregator name and try again."
else:
message = 'At least one of the configuration aggregators does not exist. Check the configuration aggregator' \
' names and try again.'
super(NoSuchConfigurationAggregatorException, self).__init__("NoSuchConfigurationAggregatorException", message)
message = (
"At least one of the configuration aggregators does not exist. Check the configuration aggregator"
" names and try again."
)
super(NoSuchConfigurationAggregatorException, self).__init__(
"NoSuchConfigurationAggregatorException", message
)
class NoSuchConfigurationRecorderException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'Cannot find configuration recorder with the specified name \'{name}\'.'.format(name=name)
super(NoSuchConfigurationRecorderException, self).__init__("NoSuchConfigurationRecorderException", message)
message = "Cannot find configuration recorder with the specified name '{name}'.".format(
name=name
)
super(NoSuchConfigurationRecorderException, self).__init__(
"NoSuchConfigurationRecorderException", message
)
class InvalidDeliveryChannelNameException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'The delivery channel name \'{name}\' is not valid, blank string.'.format(name=name)
super(InvalidDeliveryChannelNameException, self).__init__("InvalidDeliveryChannelNameException",
message)
message = "The delivery channel name '{name}' is not valid, blank string.".format(
name=name
)
super(InvalidDeliveryChannelNameException, self).__init__(
"InvalidDeliveryChannelNameException", message
)
class NoSuchBucketException(JsonRESTError):
"""We are *only* validating that there is value that is not '' here."""
code = 400
def __init__(self):
message = 'Cannot find a S3 bucket with an empty bucket name.'
message = "Cannot find a S3 bucket with an empty bucket name."
super(NoSuchBucketException, self).__init__("NoSuchBucketException", message)
@ -94,89 +123,120 @@ class InvalidNextTokenException(JsonRESTError):
code = 400
def __init__(self):
message = 'The nextToken provided is invalid'
super(InvalidNextTokenException, self).__init__("InvalidNextTokenException", message)
message = "The nextToken provided is invalid"
super(InvalidNextTokenException, self).__init__(
"InvalidNextTokenException", message
)
class InvalidS3KeyPrefixException(JsonRESTError):
code = 400
def __init__(self):
message = 'The s3 key prefix \'\' is not valid, empty s3 key prefix.'
super(InvalidS3KeyPrefixException, self).__init__("InvalidS3KeyPrefixException", message)
message = "The s3 key prefix '' is not valid, empty s3 key prefix."
super(InvalidS3KeyPrefixException, self).__init__(
"InvalidS3KeyPrefixException", message
)
class InvalidSNSTopicARNException(JsonRESTError):
"""We are *only* validating that there is value that is not '' here."""
code = 400
def __init__(self):
message = 'The sns topic arn \'\' is not valid.'
super(InvalidSNSTopicARNException, self).__init__("InvalidSNSTopicARNException", message)
message = "The sns topic arn '' is not valid."
super(InvalidSNSTopicARNException, self).__init__(
"InvalidSNSTopicARNException", message
)
class InvalidDeliveryFrequency(JsonRESTError):
code = 400
def __init__(self, value, good_list):
message = '1 validation error detected: Value \'{value}\' at ' \
'\'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency\' failed to satisfy ' \
'constraint: Member must satisfy enum value set: {good_list}'.format(value=value, good_list=good_list)
super(InvalidDeliveryFrequency, self).__init__("InvalidDeliveryFrequency", message)
message = (
"1 validation error detected: Value '{value}' at "
"'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency' failed to satisfy "
"constraint: Member must satisfy enum value set: {good_list}".format(
value=value, good_list=good_list
)
)
super(InvalidDeliveryFrequency, self).__init__(
"InvalidDeliveryFrequency", message
)
class MaxNumberOfDeliveryChannelsExceededException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'Failed to put delivery channel \'{name}\' because the maximum number of ' \
'delivery channels: 1 is reached.'.format(name=name)
message = (
"Failed to put delivery channel '{name}' because the maximum number of "
"delivery channels: 1 is reached.".format(name=name)
)
super(MaxNumberOfDeliveryChannelsExceededException, self).__init__(
"MaxNumberOfDeliveryChannelsExceededException", message)
"MaxNumberOfDeliveryChannelsExceededException", message
)
class NoSuchDeliveryChannelException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'Cannot find delivery channel with specified name \'{name}\'.'.format(name=name)
super(NoSuchDeliveryChannelException, self).__init__("NoSuchDeliveryChannelException", message)
message = "Cannot find delivery channel with specified name '{name}'.".format(
name=name
)
super(NoSuchDeliveryChannelException, self).__init__(
"NoSuchDeliveryChannelException", message
)
class NoAvailableConfigurationRecorderException(JsonRESTError):
code = 400
def __init__(self):
message = 'Configuration recorder is not available to put delivery channel.'
super(NoAvailableConfigurationRecorderException, self).__init__("NoAvailableConfigurationRecorderException",
message)
message = "Configuration recorder is not available to put delivery channel."
super(NoAvailableConfigurationRecorderException, self).__init__(
"NoAvailableConfigurationRecorderException", message
)
class NoAvailableDeliveryChannelException(JsonRESTError):
code = 400
def __init__(self):
message = 'Delivery channel is not available to start configuration recorder.'
super(NoAvailableDeliveryChannelException, self).__init__("NoAvailableDeliveryChannelException", message)
message = "Delivery channel is not available to start configuration recorder."
super(NoAvailableDeliveryChannelException, self).__init__(
"NoAvailableDeliveryChannelException", message
)
class LastDeliveryChannelDeleteFailedException(JsonRESTError):
code = 400
def __init__(self, name):
message = 'Failed to delete last specified delivery channel with name \'{name}\', because there, ' \
'because there is a running configuration recorder.'.format(name=name)
super(LastDeliveryChannelDeleteFailedException, self).__init__("LastDeliveryChannelDeleteFailedException", message)
message = (
"Failed to delete last specified delivery channel with name '{name}', because there, "
"because there is a running configuration recorder.".format(name=name)
)
super(LastDeliveryChannelDeleteFailedException, self).__init__(
"LastDeliveryChannelDeleteFailedException", message
)
class TooManyAccountSources(JsonRESTError):
code = 400
def __init__(self, length):
locations = ['com.amazonaws.xyz'] * length
locations = ["com.amazonaws.xyz"] * length
message = 'Value \'[{locations}]\' at \'accountAggregationSources\' failed to satisfy constraint: ' \
'Member must have length less than or equal to 1'.format(locations=', '.join(locations))
message = (
"Value '[{locations}]' at 'accountAggregationSources' failed to satisfy constraint: "
"Member must have length less than or equal to 1".format(
locations=", ".join(locations)
)
)
super(TooManyAccountSources, self).__init__("ValidationException", message)
@ -185,16 +245,22 @@ class DuplicateTags(JsonRESTError):
def __init__(self):
super(DuplicateTags, self).__init__(
'InvalidInput', 'Duplicate tag keys found. Please note that Tag keys are case insensitive.')
"InvalidInput",
"Duplicate tag keys found. Please note that Tag keys are case insensitive.",
)
class TagKeyTooBig(JsonRESTError):
code = 400
def __init__(self, tag, param='tags.X.member.key'):
def __init__(self, tag, param="tags.X.member.key"):
super(TagKeyTooBig, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 128".format(tag, param))
"ValidationException",
"1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 128".format(
tag, param
),
)
class TagValueTooBig(JsonRESTError):
@ -202,76 +268,100 @@ class TagValueTooBig(JsonRESTError):
def __init__(self, tag):
super(TagValueTooBig, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy "
"constraint: Member must have length less than or equal to 256".format(tag))
"ValidationException",
"1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy "
"constraint: Member must have length less than or equal to 256".format(tag),
)
class InvalidParameterValueException(JsonRESTError):
code = 400
def __init__(self, message):
super(InvalidParameterValueException, self).__init__('InvalidParameterValueException', message)
super(InvalidParameterValueException, self).__init__(
"InvalidParameterValueException", message
)
class InvalidTagCharacters(JsonRESTError):
code = 400
def __init__(self, tag, param='tags.X.member.key'):
message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(tag, param)
message += 'constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+'
def __init__(self, tag, param="tags.X.member.key"):
message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(
tag, param
)
message += "constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+"
super(InvalidTagCharacters, self).__init__('ValidationException', message)
super(InvalidTagCharacters, self).__init__("ValidationException", message)
class TooManyTags(JsonRESTError):
code = 400
def __init__(self, tags, param='tags'):
def __init__(self, tags, param="tags"):
super(TooManyTags, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 50.".format(tags, param))
"ValidationException",
"1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 50.".format(
tags, param
),
)
class InvalidResourceParameters(JsonRESTError):
code = 400
def __init__(self):
super(InvalidResourceParameters, self).__init__('ValidationException', 'Both Resource ID and Resource Name '
'cannot be specified in the request')
super(InvalidResourceParameters, self).__init__(
"ValidationException",
"Both Resource ID and Resource Name " "cannot be specified in the request",
)
class InvalidLimit(JsonRESTError):
code = 400
def __init__(self, value):
super(InvalidLimit, self).__init__('ValidationException', 'Value \'{value}\' at \'limit\' failed to satisify constraint: Member'
' must have value less than or equal to 100'.format(value=value))
super(InvalidLimit, self).__init__(
"ValidationException",
"Value '{value}' at 'limit' failed to satisify constraint: Member"
" must have value less than or equal to 100".format(value=value),
)
class TooManyResourceIds(JsonRESTError):
code = 400
def __init__(self):
super(TooManyResourceIds, self).__init__('ValidationException', "The specified list had more than 20 resource ID's. "
"It must have '20' or less items")
super(TooManyResourceIds, self).__init__(
"ValidationException",
"The specified list had more than 20 resource ID's. "
"It must have '20' or less items",
)
class ResourceNotDiscoveredException(JsonRESTError):
code = 400
def __init__(self, type, resource):
super(ResourceNotDiscoveredException, self).__init__('ResourceNotDiscoveredException',
'Resource {resource} of resourceType:{type} is unknown or has not been '
'discovered'.format(resource=resource, type=type))
super(ResourceNotDiscoveredException, self).__init__(
"ResourceNotDiscoveredException",
"Resource {resource} of resourceType:{type} is unknown or has not been "
"discovered".format(resource=resource, type=type),
)
class TooManyResourceKeys(JsonRESTError):
code = 400
def __init__(self, bad_list):
message = '1 validation error detected: Value \'{bad_list}\' at ' \
'\'resourceKeys\' failed to satisfy constraint: ' \
'Member must have length less than or equal to 100'.format(bad_list=bad_list)
message = (
"1 validation error detected: Value '{bad_list}' at "
"'resourceKeys' failed to satisfy constraint: "
"Member must have length less than or equal to 100".format(
bad_list=bad_list
)
)
# For PY2:
message = str(message)

File diff suppressed because it is too large Load Diff

View File

@ -4,116 +4,150 @@ from .models import config_backends
class ConfigResponse(BaseResponse):
@property
def config_backend(self):
return config_backends[self.region]
def put_configuration_recorder(self):
self.config_backend.put_configuration_recorder(self._get_param('ConfigurationRecorder'))
self.config_backend.put_configuration_recorder(
self._get_param("ConfigurationRecorder")
)
return ""
def put_configuration_aggregator(self):
aggregator = self.config_backend.put_configuration_aggregator(json.loads(self.body), self.region)
schema = {'ConfigurationAggregator': aggregator}
aggregator = self.config_backend.put_configuration_aggregator(
json.loads(self.body), self.region
)
schema = {"ConfigurationAggregator": aggregator}
return json.dumps(schema)
def describe_configuration_aggregators(self):
aggregators = self.config_backend.describe_configuration_aggregators(self._get_param('ConfigurationAggregatorNames'),
self._get_param('NextToken'),
self._get_param('Limit'))
aggregators = self.config_backend.describe_configuration_aggregators(
self._get_param("ConfigurationAggregatorNames"),
self._get_param("NextToken"),
self._get_param("Limit"),
)
return json.dumps(aggregators)
def delete_configuration_aggregator(self):
self.config_backend.delete_configuration_aggregator(self._get_param('ConfigurationAggregatorName'))
self.config_backend.delete_configuration_aggregator(
self._get_param("ConfigurationAggregatorName")
)
return ""
def put_aggregation_authorization(self):
agg_auth = self.config_backend.put_aggregation_authorization(self.region,
self._get_param('AuthorizedAccountId'),
self._get_param('AuthorizedAwsRegion'),
self._get_param('Tags'))
schema = {'AggregationAuthorization': agg_auth}
agg_auth = self.config_backend.put_aggregation_authorization(
self.region,
self._get_param("AuthorizedAccountId"),
self._get_param("AuthorizedAwsRegion"),
self._get_param("Tags"),
)
schema = {"AggregationAuthorization": agg_auth}
return json.dumps(schema)
def describe_aggregation_authorizations(self):
authorizations = self.config_backend.describe_aggregation_authorizations(self._get_param('NextToken'), self._get_param('Limit'))
authorizations = self.config_backend.describe_aggregation_authorizations(
self._get_param("NextToken"), self._get_param("Limit")
)
return json.dumps(authorizations)
def delete_aggregation_authorization(self):
self.config_backend.delete_aggregation_authorization(self._get_param('AuthorizedAccountId'), self._get_param('AuthorizedAwsRegion'))
self.config_backend.delete_aggregation_authorization(
self._get_param("AuthorizedAccountId"),
self._get_param("AuthorizedAwsRegion"),
)
return ""
def describe_configuration_recorders(self):
recorders = self.config_backend.describe_configuration_recorders(self._get_param('ConfigurationRecorderNames'))
schema = {'ConfigurationRecorders': recorders}
recorders = self.config_backend.describe_configuration_recorders(
self._get_param("ConfigurationRecorderNames")
)
schema = {"ConfigurationRecorders": recorders}
return json.dumps(schema)
def describe_configuration_recorder_status(self):
recorder_statuses = self.config_backend.describe_configuration_recorder_status(
self._get_param('ConfigurationRecorderNames'))
schema = {'ConfigurationRecordersStatus': recorder_statuses}
self._get_param("ConfigurationRecorderNames")
)
schema = {"ConfigurationRecordersStatus": recorder_statuses}
return json.dumps(schema)
def put_delivery_channel(self):
self.config_backend.put_delivery_channel(self._get_param('DeliveryChannel'))
self.config_backend.put_delivery_channel(self._get_param("DeliveryChannel"))
return ""
def describe_delivery_channels(self):
delivery_channels = self.config_backend.describe_delivery_channels(self._get_param('DeliveryChannelNames'))
schema = {'DeliveryChannels': delivery_channels}
delivery_channels = self.config_backend.describe_delivery_channels(
self._get_param("DeliveryChannelNames")
)
schema = {"DeliveryChannels": delivery_channels}
return json.dumps(schema)
def describe_delivery_channel_status(self):
raise NotImplementedError()
def delete_delivery_channel(self):
self.config_backend.delete_delivery_channel(self._get_param('DeliveryChannelName'))
self.config_backend.delete_delivery_channel(
self._get_param("DeliveryChannelName")
)
return ""
def delete_configuration_recorder(self):
self.config_backend.delete_configuration_recorder(self._get_param('ConfigurationRecorderName'))
self.config_backend.delete_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return ""
def start_configuration_recorder(self):
self.config_backend.start_configuration_recorder(self._get_param('ConfigurationRecorderName'))
self.config_backend.start_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return ""
def stop_configuration_recorder(self):
self.config_backend.stop_configuration_recorder(self._get_param('ConfigurationRecorderName'))
self.config_backend.stop_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return ""
def list_discovered_resources(self):
schema = self.config_backend.list_discovered_resources(self._get_param('resourceType'),
self.region,
self._get_param('resourceIds'),
self._get_param('resourceName'),
self._get_param('limit'),
self._get_param('nextToken'))
schema = self.config_backend.list_discovered_resources(
self._get_param("resourceType"),
self.region,
self._get_param("resourceIds"),
self._get_param("resourceName"),
self._get_param("limit"),
self._get_param("nextToken"),
)
return json.dumps(schema)
def list_aggregate_discovered_resources(self):
schema = self.config_backend.list_aggregate_discovered_resources(self._get_param('ConfigurationAggregatorName'),
self._get_param('ResourceType'),
self._get_param('Filters'),
self._get_param('Limit'),
self._get_param('NextToken'))
schema = self.config_backend.list_aggregate_discovered_resources(
self._get_param("ConfigurationAggregatorName"),
self._get_param("ResourceType"),
self._get_param("Filters"),
self._get_param("Limit"),
self._get_param("NextToken"),
)
return json.dumps(schema)
def get_resource_config_history(self):
schema = self.config_backend.get_resource_config_history(self._get_param('resourceType'),
self._get_param('resourceId'),
self.region)
schema = self.config_backend.get_resource_config_history(
self._get_param("resourceType"), self._get_param("resourceId"), self.region
)
return json.dumps(schema)
def batch_get_resource_config(self):
schema = self.config_backend.batch_get_resource_config(self._get_param('resourceKeys'),
self.region)
schema = self.config_backend.batch_get_resource_config(
self._get_param("resourceKeys"), self.region
)
return json.dumps(schema)
def batch_get_aggregate_resource_config(self):
schema = self.config_backend.batch_get_aggregate_resource_config(self._get_param('ConfigurationAggregatorName'),
self._get_param('ResourceIdentifiers'))
schema = self.config_backend.batch_get_aggregate_resource_config(
self._get_param("ConfigurationAggregatorName"),
self._get_param("ResourceIdentifiers"),
)
return json.dumps(schema)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import ConfigResponse
url_bases = [
"https?://config.(.+).amazonaws.com",
]
url_bases = ["https?://config.(.+).amazonaws.com"]
url_paths = {
'{0}/$': ConfigResponse.dispatch,
}
url_paths = {"{0}/$": ConfigResponse.dispatch}

View File

@ -1,7 +1,9 @@
from __future__ import unicode_literals
from .models import BaseModel, BaseBackend, moto_api_backend # flake8: noqa
from .models import BaseModel, BaseBackend, moto_api_backend # noqa
from .responses import ActionAuthenticatorMixin
moto_api_backends = {"global": moto_api_backend}
set_initial_no_auth_action_count = ActionAuthenticatorMixin.set_initial_no_auth_action_count
set_initial_no_auth_action_count = (
ActionAuthenticatorMixin.set_initial_no_auth_action_count
)

View File

@ -26,7 +26,12 @@ from six import string_types
from moto.iam.models import ACCOUNT_ID, Policy
from moto.iam import iam_backend
from moto.core.exceptions import SignatureDoesNotMatchError, AccessDeniedError, InvalidClientTokenIdError, AuthFailureError
from moto.core.exceptions import (
SignatureDoesNotMatchError,
AccessDeniedError,
InvalidClientTokenIdError,
AuthFailureError,
)
from moto.s3.exceptions import (
BucketAccessDeniedError,
S3AccessDeniedError,
@ -35,7 +40,7 @@ from moto.s3.exceptions import (
S3InvalidAccessKeyIdError,
BucketInvalidAccessKeyIdError,
BucketSignatureDoesNotMatchError,
S3SignatureDoesNotMatchError
S3SignatureDoesNotMatchError,
)
from moto.sts import sts_backend
@ -50,9 +55,8 @@ def create_access_key(access_key_id, headers):
class IAMUserAccessKey(object):
def __init__(self, access_key_id, headers):
iam_users = iam_backend.list_users('/', None, None)
iam_users = iam_backend.list_users("/", None, None)
for iam_user in iam_users:
for access_key in iam_user.access_keys:
if access_key.access_key_id == access_key_id:
@ -67,8 +71,7 @@ class IAMUserAccessKey(object):
@property
def arn(self):
return "arn:aws:iam::{account_id}:user/{iam_user_name}".format(
account_id=ACCOUNT_ID,
iam_user_name=self._owner_user_name
account_id=ACCOUNT_ID, iam_user_name=self._owner_user_name
)
def create_credentials(self):
@ -79,27 +82,34 @@ class IAMUserAccessKey(object):
inline_policy_names = iam_backend.list_user_policies(self._owner_user_name)
for inline_policy_name in inline_policy_names:
inline_policy = iam_backend.get_user_policy(self._owner_user_name, inline_policy_name)
inline_policy = iam_backend.get_user_policy(
self._owner_user_name, inline_policy_name
)
user_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_user_policies(self._owner_user_name)
attached_policies, _ = iam_backend.list_attached_user_policies(
self._owner_user_name
)
user_policies += attached_policies
user_groups = iam_backend.get_groups_for_user(self._owner_user_name)
for user_group in user_groups:
inline_group_policy_names = iam_backend.list_group_policies(user_group.name)
for inline_group_policy_name in inline_group_policy_names:
inline_user_group_policy = iam_backend.get_group_policy(user_group.name, inline_group_policy_name)
inline_user_group_policy = iam_backend.get_group_policy(
user_group.name, inline_group_policy_name
)
user_policies.append(inline_user_group_policy)
attached_group_policies, _ = iam_backend.list_attached_group_policies(user_group.name)
attached_group_policies, _ = iam_backend.list_attached_group_policies(
user_group.name
)
user_policies += attached_group_policies
return user_policies
class AssumedRoleAccessKey(object):
def __init__(self, access_key_id, headers):
for assumed_role in sts_backend.assumed_roles:
if assumed_role.access_key_id == access_key_id:
@ -118,28 +128,33 @@ class AssumedRoleAccessKey(object):
return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format(
account_id=ACCOUNT_ID,
role_name=self._owner_role_name,
session_name=self._session_name
session_name=self._session_name,
)
def create_credentials(self):
return Credentials(self._access_key_id, self._secret_access_key, self._session_token)
return Credentials(
self._access_key_id, self._secret_access_key, self._session_token
)
def collect_policies(self):
role_policies = []
inline_policy_names = iam_backend.list_role_policies(self._owner_role_name)
for inline_policy_name in inline_policy_names:
_, inline_policy = iam_backend.get_role_policy(self._owner_role_name, inline_policy_name)
_, inline_policy = iam_backend.get_role_policy(
self._owner_role_name, inline_policy_name
)
role_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_role_policies(self._owner_role_name)
attached_policies, _ = iam_backend.list_attached_role_policies(
self._owner_role_name
)
role_policies += attached_policies
return role_policies
class CreateAccessKeyFailure(Exception):
def __init__(self, reason, *args):
super(CreateAccessKeyFailure, self).__init__(*args)
self.reason = reason
@ -147,32 +162,54 @@ class CreateAccessKeyFailure(Exception):
@six.add_metaclass(ABCMeta)
class IAMRequestBase(object):
def __init__(self, method, path, data, headers):
log.debug("Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format(
class_name=self.__class__.__name__, method=method, path=path, data=data, headers=headers))
log.debug(
"Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format(
class_name=self.__class__.__name__,
method=method,
path=path,
data=data,
headers=headers,
)
)
self._method = method
self._path = path
self._data = data
self._headers = headers
credential_scope = self._get_string_between('Credential=', ',', self._headers['Authorization'])
credential_data = credential_scope.split('/')
credential_scope = self._get_string_between(
"Credential=", ",", self._headers["Authorization"]
)
credential_data = credential_scope.split("/")
self._region = credential_data[2]
self._service = credential_data[3]
self._action = self._service + ":" + (self._data["Action"][0] if isinstance(self._data["Action"], list) else self._data["Action"])
self._action = (
self._service
+ ":"
+ (
self._data["Action"][0]
if isinstance(self._data["Action"], list)
else self._data["Action"]
)
)
try:
self._access_key = create_access_key(access_key_id=credential_data[0], headers=headers)
self._access_key = create_access_key(
access_key_id=credential_data[0], headers=headers
)
except CreateAccessKeyFailure as e:
self._raise_invalid_access_key(e.reason)
def check_signature(self):
original_signature = self._get_string_between('Signature=', ',', self._headers['Authorization'])
original_signature = self._get_string_between(
"Signature=", ",", self._headers["Authorization"]
)
calculated_signature = self._calculate_signature()
if original_signature != calculated_signature:
self._raise_signature_does_not_match()
def check_action_permitted(self):
if self._action == 'sts:GetCallerIdentity': # always allowed, even if there's an explicit Deny for it
if (
self._action == "sts:GetCallerIdentity"
): # always allowed, even if there's an explicit Deny for it
return True
policies = self._access_key.collect_policies()
@ -213,10 +250,14 @@ class IAMRequestBase(object):
return headers
def _create_aws_request(self):
signed_headers = self._get_string_between('SignedHeaders=', ',', self._headers['Authorization']).split(';')
signed_headers = self._get_string_between(
"SignedHeaders=", ",", self._headers["Authorization"]
).split(";")
headers = self._create_headers_for_aws_request(signed_headers, self._headers)
request = AWSRequest(method=self._method, url=self._path, data=self._data, headers=headers)
request.context['timestamp'] = headers['X-Amz-Date']
request = AWSRequest(
method=self._method, url=self._path, data=self._data, headers=headers
)
request.context["timestamp"] = headers["X-Amz-Date"]
return request
@ -234,7 +275,6 @@ class IAMRequestBase(object):
class IAMRequest(IAMRequestBase):
def _raise_signature_does_not_match(self):
if self._service == "ec2":
raise AuthFailureError()
@ -251,14 +291,10 @@ class IAMRequest(IAMRequestBase):
return SigV4Auth(credentials, self._service, self._region)
def _raise_access_denied(self):
raise AccessDeniedError(
user_arn=self._access_key.arn,
action=self._action
)
raise AccessDeniedError(user_arn=self._access_key.arn, action=self._action)
class S3IAMRequest(IAMRequestBase):
def _raise_signature_does_not_match(self):
if "BucketName" in self._data:
raise BucketSignatureDoesNotMatchError(bucket=self._data["BucketName"])
@ -288,10 +324,13 @@ class S3IAMRequest(IAMRequestBase):
class IAMPolicy(object):
def __init__(self, policy):
if isinstance(policy, Policy):
default_version = next(policy_version for policy_version in policy.versions if policy_version.is_default)
default_version = next(
policy_version
for policy_version in policy.versions
if policy_version.is_default
)
policy_document = default_version.document
elif isinstance(policy, string_types):
policy_document = policy
@ -321,7 +360,6 @@ class IAMPolicy(object):
class IAMPolicyStatement(object):
def __init__(self, statement):
self._statement = statement

View File

@ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment
SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<Error>
<Code>{{error_type}}</Code>
<Message>{{message}}</Message>
@ -13,7 +13,7 @@ SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</Error>
"""
ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse>
<Errors>
<Error>
@ -26,7 +26,7 @@ ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</ErrorResponse>
"""
ERROR_JSON_RESPONSE = u"""{
ERROR_JSON_RESPONSE = """{
"message": "{{message}}",
"__type": "{{error_type}}"
}
@ -37,18 +37,19 @@ class RESTError(HTTPException):
code = 400
templates = {
'single_error': SINGLE_ERROR_RESPONSE,
'error': ERROR_RESPONSE,
'error_json': ERROR_JSON_RESPONSE,
"single_error": SINGLE_ERROR_RESPONSE,
"error": ERROR_RESPONSE,
"error_json": ERROR_JSON_RESPONSE,
}
def __init__(self, error_type, message, template='error', **kwargs):
def __init__(self, error_type, message, template="error", **kwargs):
super(RESTError, self).__init__()
env = Environment(loader=DictLoader(self.templates))
self.error_type = error_type
self.message = message
self.description = env.get_template(template).render(
error_type=error_type, message=message, **kwargs)
error_type=error_type, message=message, **kwargs
)
class DryRunClientError(RESTError):
@ -56,12 +57,11 @@ class DryRunClientError(RESTError):
class JsonRESTError(RESTError):
def __init__(self, error_type, message, template='error_json', **kwargs):
super(JsonRESTError, self).__init__(
error_type, message, template, **kwargs)
def __init__(self, error_type, message, template="error_json", **kwargs):
super(JsonRESTError, self).__init__(error_type, message, template, **kwargs)
def get_headers(self, *args, **kwargs):
return [('Content-Type', 'application/json')]
return [("Content-Type", "application/json")]
def get_body(self, *args, **kwargs):
return self.description
@ -72,8 +72,9 @@ class SignatureDoesNotMatchError(RESTError):
def __init__(self):
super(SignatureDoesNotMatchError, self).__init__(
'SignatureDoesNotMatch',
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.")
"SignatureDoesNotMatch",
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.",
)
class InvalidClientTokenIdError(RESTError):
@ -81,8 +82,9 @@ class InvalidClientTokenIdError(RESTError):
def __init__(self):
super(InvalidClientTokenIdError, self).__init__(
'InvalidClientTokenId',
"The security token included in the request is invalid.")
"InvalidClientTokenId",
"The security token included in the request is invalid.",
)
class AccessDeniedError(RESTError):
@ -90,11 +92,11 @@ class AccessDeniedError(RESTError):
def __init__(self, user_arn, action):
super(AccessDeniedError, self).__init__(
'AccessDenied',
"AccessDenied",
"User: {user_arn} is not authorized to perform: {operation}".format(
user_arn=user_arn,
operation=action
))
user_arn=user_arn, operation=action
),
)
class AuthFailureError(RESTError):
@ -102,13 +104,17 @@ class AuthFailureError(RESTError):
def __init__(self):
super(AuthFailureError, self).__init__(
'AuthFailure',
"AWS was not able to validate the provided access credentials")
"AuthFailure",
"AWS was not able to validate the provided access credentials",
)
class InvalidNextTokenException(JsonRESTError):
"""For AWS Config resource listing. This will be used by many different resource types, and so it is in moto.core."""
code = 400
def __init__(self):
super(InvalidNextTokenException, self).__init__('InvalidNextTokenException', 'The nextToken provided is invalid')
super(InvalidNextTokenException, self).__init__(
"InvalidNextTokenException", "The nextToken provided is invalid"
)

View File

@ -31,15 +31,19 @@ class BaseMockAWS(object):
self.backends_for_urls = {}
from moto.backends import BACKENDS
default_backends = {
"instance_metadata": BACKENDS['instance_metadata']['global'],
"moto_api": BACKENDS['moto_api']['global'],
"instance_metadata": BACKENDS["instance_metadata"]["global"],
"moto_api": BACKENDS["moto_api"]["global"],
}
self.backends_for_urls.update(self.backends)
self.backends_for_urls.update(default_backends)
# "Mock" the AWS credentials as they can't be mocked in Botocore currently
FAKE_KEYS = {"AWS_ACCESS_KEY_ID": "foobar_key", "AWS_SECRET_ACCESS_KEY": "foobar_secret"}
FAKE_KEYS = {
"AWS_ACCESS_KEY_ID": "foobar_key",
"AWS_SECRET_ACCESS_KEY": "foobar_secret",
}
self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS)
if self.__class__.nested_count == 0:
@ -72,7 +76,7 @@ class BaseMockAWS(object):
self.__class__.nested_count -= 1
if self.__class__.nested_count < 0:
raise RuntimeError('Called stop() before start().')
raise RuntimeError("Called stop() before start().")
if self.__class__.nested_count == 0:
self.disable_patching()
@ -85,6 +89,7 @@ class BaseMockAWS(object):
finally:
self.stop()
return result
functools.update_wrapper(wrapper, func)
wrapper.__wrapped__ = func
return wrapper
@ -122,7 +127,6 @@ class BaseMockAWS(object):
class HttprettyMockAWS(BaseMockAWS):
def reset(self):
HTTPretty.reset()
@ -144,18 +148,26 @@ class HttprettyMockAWS(BaseMockAWS):
HTTPretty.reset()
RESPONSES_METHODS = [responses.GET, responses.DELETE, responses.HEAD,
responses.OPTIONS, responses.PATCH, responses.POST, responses.PUT]
RESPONSES_METHODS = [
responses.GET,
responses.DELETE,
responses.HEAD,
responses.OPTIONS,
responses.PATCH,
responses.POST,
responses.PUT,
]
class CallbackResponse(responses.CallbackResponse):
'''
"""
Need to subclass so we can change a couple things
'''
"""
def get_response(self, request):
'''
"""
Need to override this so we can pass decode_content=False
'''
"""
headers = self.get_headers()
result = self.callback(request)
@ -177,17 +189,17 @@ class CallbackResponse(responses.CallbackResponse):
)
def _url_matches(self, url, other, match_querystring=False):
'''
"""
Need to override this so we can fix querystrings breaking regex matching
'''
"""
if not match_querystring:
other = other.split('?', 1)[0]
other = other.split("?", 1)[0]
if responses._is_string(url):
if responses._has_unicode(url):
url = responses._clean_unicode(url)
if not isinstance(other, six.text_type):
other = other.encode('ascii').decode('utf8')
other = other.encode("ascii").decode("utf8")
return self._url_matches_strict(url, other)
elif isinstance(url, responses.Pattern) and url.match(other):
return True
@ -195,22 +207,23 @@ class CallbackResponse(responses.CallbackResponse):
return False
botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send')
botocore_mock = responses.RequestsMock(
assert_all_requests_are_fired=False,
target="botocore.vendored.requests.adapters.HTTPAdapter.send",
)
responses_mock = responses._default_mock
# Add passthrough to allow any other requests to work
# Since this uses .startswith, it applies to http and https requests.
responses_mock.add_passthru("http")
BOTOCORE_HTTP_METHODS = [
'GET', 'DELETE', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT'
]
BOTOCORE_HTTP_METHODS = ["GET", "DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"]
class MockRawResponse(BytesIO):
def __init__(self, input):
if isinstance(input, six.text_type):
input = input.encode('utf-8')
input = input.encode("utf-8")
super(MockRawResponse, self).__init__(input)
def stream(self, **kwargs):
@ -241,7 +254,7 @@ class BotocoreStubber(object):
found_index = None
matchers = self.methods.get(request.method)
base_url = request.url.split('?', 1)[0]
base_url = request.url.split("?", 1)[0]
for i, (pattern, callback) in enumerate(matchers):
if pattern.match(base_url):
if found_index is None:
@ -254,8 +267,10 @@ class BotocoreStubber(object):
if response_callback is not None:
for header, value in request.headers.items():
if isinstance(value, six.binary_type):
request.headers[header] = value.decode('utf-8')
status, headers, body = response_callback(request, request.url, request.headers)
request.headers[header] = value.decode("utf-8")
status, headers, body = response_callback(
request, request.url, request.headers
)
body = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, body)
@ -263,7 +278,7 @@ class BotocoreStubber(object):
botocore_stubber = BotocoreStubber()
BUILTIN_HANDLERS.append(('before-send', botocore_stubber))
BUILTIN_HANDLERS.append(("before-send", botocore_stubber))
def not_implemented_callback(request):
@ -287,7 +302,9 @@ class BotocoreEventMockAWS(BaseMockAWS):
pattern = re.compile(key)
botocore_stubber.register_response(method, pattern, value)
if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'):
if not hasattr(responses_mock, "_patcher") or not hasattr(
responses_mock._patcher, "target"
):
responses_mock.start()
for method in RESPONSES_METHODS:
@ -336,9 +353,9 @@ MockAWS = BotocoreEventMockAWS
class ServerModeMockAWS(BaseMockAWS):
def reset(self):
import requests
requests.post("http://localhost:5000/moto-api/reset")
def enable_patching(self):
@ -350,13 +367,13 @@ class ServerModeMockAWS(BaseMockAWS):
import mock
def fake_boto3_client(*args, **kwargs):
if 'endpoint_url' not in kwargs:
kwargs['endpoint_url'] = "http://localhost:5000"
if "endpoint_url" not in kwargs:
kwargs["endpoint_url"] = "http://localhost:5000"
return real_boto3_client(*args, **kwargs)
def fake_boto3_resource(*args, **kwargs):
if 'endpoint_url' not in kwargs:
kwargs['endpoint_url'] = "http://localhost:5000"
if "endpoint_url" not in kwargs:
kwargs["endpoint_url"] = "http://localhost:5000"
return real_boto3_resource(*args, **kwargs)
def fake_httplib_send_output(self, message_body=None, *args, **kwargs):
@ -364,7 +381,7 @@ class ServerModeMockAWS(BaseMockAWS):
bytes_buffer = []
for chunk in mixed_buffer:
if isinstance(chunk, six.text_type):
bytes_buffer.append(chunk.encode('utf-8'))
bytes_buffer.append(chunk.encode("utf-8"))
else:
bytes_buffer.append(chunk)
msg = b"\r\n".join(bytes_buffer)
@ -385,10 +402,12 @@ class ServerModeMockAWS(BaseMockAWS):
if message_body is not None:
self.send(message_body)
self._client_patcher = mock.patch('boto3.client', fake_boto3_client)
self._resource_patcher = mock.patch('boto3.resource', fake_boto3_resource)
self._client_patcher = mock.patch("boto3.client", fake_boto3_client)
self._resource_patcher = mock.patch("boto3.resource", fake_boto3_resource)
if six.PY2:
self._httplib_patcher = mock.patch('httplib.HTTPConnection._send_output', fake_httplib_send_output)
self._httplib_patcher = mock.patch(
"httplib.HTTPConnection._send_output", fake_httplib_send_output
)
self._client_patcher.start()
self._resource_patcher.start()
@ -404,7 +423,6 @@ class ServerModeMockAWS(BaseMockAWS):
class Model(type):
def __new__(self, clsname, bases, namespace):
cls = super(Model, self).__new__(self, clsname, bases, namespace)
cls.__models__ = {}
@ -419,9 +437,11 @@ class Model(type):
@staticmethod
def prop(model_name):
""" decorator to mark a class method as returning model values """
def dec(f):
f.__returns_model__ = model_name
return f
return dec
@ -431,7 +451,7 @@ model_data = defaultdict(dict)
class InstanceTrackerMeta(type):
def __new__(meta, name, bases, dct):
cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct)
if name == 'BaseModel':
if name == "BaseModel":
return cls
service = cls.__module__.split(".")[1]
@ -450,7 +470,6 @@ class BaseModel(object):
class BaseBackend(object):
def _reset_model_refs(self):
# Remove all references to the models stored
for service, models in model_data.items():
@ -466,8 +485,9 @@ class BaseBackend(object):
def _url_module(self):
backend_module = self.__class__.__module__
backend_urls_module_name = backend_module.replace("models", "urls")
backend_urls_module = __import__(backend_urls_module_name, fromlist=[
'url_bases', 'url_paths'])
backend_urls_module = __import__(
backend_urls_module_name, fromlist=["url_bases", "url_paths"]
)
return backend_urls_module
@property
@ -523,9 +543,9 @@ class BaseBackend(object):
def decorator(self, func=None):
if settings.TEST_SERVER_MODE:
mocked_backend = ServerModeMockAWS({'global': self})
mocked_backend = ServerModeMockAWS({"global": self})
else:
mocked_backend = MockAWS({'global': self})
mocked_backend = MockAWS({"global": self})
if func:
return mocked_backend(func)
@ -534,9 +554,9 @@ class BaseBackend(object):
def deprecated_decorator(self, func=None):
if func:
return HttprettyMockAWS({'global': self})(func)
return HttprettyMockAWS({"global": self})(func)
else:
return HttprettyMockAWS({'global': self})
return HttprettyMockAWS({"global": self})
# def list_config_service_resources(self, resource_ids, resource_name, limit, next_token):
# """For AWS Config. This will list all of the resources of the given type and optional resource name and region"""
@ -544,12 +564,19 @@ class BaseBackend(object):
class ConfigQueryModel(object):
def __init__(self, backends):
"""Inits based on the resource type's backends (1 for each region if applicable)"""
self.backends = backends
def list_config_service_resources(self, resource_ids, resource_name, limit, next_token, backend_region=None, resource_region=None):
def list_config_service_resources(
self,
resource_ids,
resource_name,
limit,
next_token,
backend_region=None,
resource_region=None,
):
"""For AWS Config. This will list all of the resources of the given type and optional resource name and region.
This supports both aggregated and non-aggregated listing. The following notes the difference:
@ -593,7 +620,9 @@ class ConfigQueryModel(object):
"""
raise NotImplementedError()
def get_config_resource(self, resource_id, resource_name=None, backend_region=None, resource_region=None):
def get_config_resource(
self, resource_id, resource_name=None, backend_region=None, resource_region=None
):
"""For AWS Config. This will query the backend for the specific resource type configuration.
This supports both aggregated, and non-aggregated fetching -- for batched fetching -- the Config batching requests
@ -644,9 +673,9 @@ class deprecated_base_decorator(base_decorator):
class MotoAPIBackend(BaseBackend):
def reset(self):
from moto.backends import BACKENDS
for name, backends in BACKENDS.items():
if name == "moto_api":
continue

View File

@ -40,7 +40,7 @@ def _decode_dict(d):
newkey = []
for k in key:
if isinstance(k, six.binary_type):
newkey.append(k.decode('utf-8'))
newkey.append(k.decode("utf-8"))
else:
newkey.append(k)
else:
@ -52,7 +52,7 @@ def _decode_dict(d):
newvalue = []
for v in value:
if isinstance(v, six.binary_type):
newvalue.append(v.decode('utf-8'))
newvalue.append(v.decode("utf-8"))
else:
newvalue.append(v)
else:
@ -83,12 +83,15 @@ class DynamicDictLoader(DictLoader):
class _TemplateEnvironmentMixin(object):
LEFT_PATTERN = re.compile(r"[\s\n]+<")
RIGHT_PATTERN = re.compile(r">[\s\n]+")
def __init__(self):
super(_TemplateEnvironmentMixin, self).__init__()
self.loader = DynamicDictLoader({})
self.environment = Environment(
loader=self.loader, autoescape=self.should_autoescape)
loader=self.loader, autoescape=self.should_autoescape
)
@property
def should_autoescape(self):
@ -101,9 +104,16 @@ class _TemplateEnvironmentMixin(object):
def response_template(self, source):
template_id = id(source)
if not self.contains_template(template_id):
self.loader.update({template_id: source})
self.environment = Environment(loader=self.loader, autoescape=self.should_autoescape, trim_blocks=True,
lstrip_blocks=True)
collapsed = re.sub(
self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source)
)
self.loader.update({template_id: collapsed})
self.environment = Environment(
loader=self.loader,
autoescape=self.should_autoescape,
trim_blocks=True,
lstrip_blocks=True,
)
return self.environment.get_template(template_id)
@ -112,8 +122,13 @@ class ActionAuthenticatorMixin(object):
request_count = 0
def _authenticate_and_authorize_action(self, iam_request_cls):
if ActionAuthenticatorMixin.request_count >= settings.INITIAL_NO_AUTH_ACTION_COUNT:
iam_request = iam_request_cls(method=self.method, path=self.path, data=self.data, headers=self.headers)
if (
ActionAuthenticatorMixin.request_count
>= settings.INITIAL_NO_AUTH_ACTION_COUNT
):
iam_request = iam_request_cls(
method=self.method, path=self.path, data=self.data, headers=self.headers
)
iam_request.check_signature()
iam_request.check_action_permitted()
else:
@ -130,10 +145,17 @@ class ActionAuthenticatorMixin(object):
def decorator(function):
def wrapper(*args, **kwargs):
if settings.TEST_SERVER_MODE:
response = requests.post("http://localhost:5000/moto-api/reset-auth", data=str(initial_no_auth_action_count).encode())
original_initial_no_auth_action_count = response.json()['PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT']
response = requests.post(
"http://localhost:5000/moto-api/reset-auth",
data=str(initial_no_auth_action_count).encode(),
)
original_initial_no_auth_action_count = response.json()[
"PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT"
]
else:
original_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT
original_initial_no_auth_action_count = (
settings.INITIAL_NO_AUTH_ACTION_COUNT
)
original_request_count = ActionAuthenticatorMixin.request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = initial_no_auth_action_count
ActionAuthenticatorMixin.request_count = 0
@ -141,10 +163,15 @@ class ActionAuthenticatorMixin(object):
result = function(*args, **kwargs)
finally:
if settings.TEST_SERVER_MODE:
requests.post("http://localhost:5000/moto-api/reset-auth", data=str(original_initial_no_auth_action_count).encode())
requests.post(
"http://localhost:5000/moto-api/reset-auth",
data=str(original_initial_no_auth_action_count).encode(),
)
else:
ActionAuthenticatorMixin.request_count = original_request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = original_initial_no_auth_action_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = (
original_initial_no_auth_action_count
)
return result
functools.update_wrapper(wrapper, function)
@ -156,11 +183,13 @@ class ActionAuthenticatorMixin(object):
class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
default_region = 'us-east-1'
default_region = "us-east-1"
# to extract region, use [^.]
region_regex = re.compile(r'\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com')
param_list_regex = re.compile(r'(.*)\.(\d+)\.')
access_key_regex = re.compile(r'AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]')
region_regex = re.compile(r"\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com")
param_list_regex = re.compile(r"(.*)\.(\d+)\.")
access_key_regex = re.compile(
r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]"
)
aws_service_spec = None
@classmethod
@ -169,7 +198,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def setup_class(self, request, full_url, headers):
querystring = {}
if hasattr(request, 'body'):
if hasattr(request, "body"):
# Boto
self.body = request.body
else:
@ -182,24 +211,29 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
querystring = {}
for key, value in request.form.items():
querystring[key] = [value, ]
querystring[key] = [value]
raw_body = self.body
if isinstance(self.body, six.binary_type):
self.body = self.body.decode('utf-8')
self.body = self.body.decode("utf-8")
if not querystring:
querystring.update(
parse_qs(urlparse(full_url).query, keep_blank_values=True))
parse_qs(urlparse(full_url).query, keep_blank_values=True)
)
if not querystring:
if 'json' in request.headers.get('content-type', []) and self.aws_service_spec:
if (
"json" in request.headers.get("content-type", [])
and self.aws_service_spec
):
decoded = json.loads(self.body)
target = request.headers.get(
'x-amz-target') or request.headers.get('X-Amz-Target')
service, method = target.split('.')
target = request.headers.get("x-amz-target") or request.headers.get(
"X-Amz-Target"
)
service, method = target.split(".")
input_spec = self.aws_service_spec.input_spec(method)
flat = flatten_json_request_body('', decoded, input_spec)
flat = flatten_json_request_body("", decoded, input_spec)
for key, value in flat.items():
querystring[key] = [value]
elif self.body:
@ -224,17 +258,19 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.uri_match = None
self.headers = request.headers
if 'host' not in self.headers:
self.headers['host'] = urlparse(full_url).netloc
if "host" not in self.headers:
self.headers["host"] = urlparse(full_url).netloc
self.response_headers = {"server": "amazon.com"}
def get_region_from_url(self, request, full_url):
match = self.region_regex.search(full_url)
if match:
region = match.group(1)
elif 'Authorization' in request.headers and 'AWS4' in request.headers['Authorization']:
region = request.headers['Authorization'].split(",")[
0].split("/")[2]
elif (
"Authorization" in request.headers
and "AWS4" in request.headers["Authorization"]
):
region = request.headers["Authorization"].split(",")[0].split("/")[2]
else:
region = self.default_region
return region
@ -243,16 +279,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
"""
Returns the access key id used in this request as the current user id
"""
if 'Authorization' in self.headers:
match = self.access_key_regex.search(self.headers['Authorization'])
if "Authorization" in self.headers:
match = self.access_key_regex.search(self.headers["Authorization"])
if match:
return match.group(1)
if self.querystring.get('AWSAccessKeyId'):
return self.querystring.get('AWSAccessKeyId')
if self.querystring.get("AWSAccessKeyId"):
return self.querystring.get("AWSAccessKeyId")
else:
# Should we raise an unauthorized exception instead?
return '111122223333'
return "111122223333"
def _dispatch(self, request, full_url, headers):
self.setup_class(request, full_url, headers)
@ -267,17 +303,22 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
-> '^/cars/.*/drivers/.*/drive$'
"""
def _convert(elem, is_last):
if not re.match('^{.*}$', elem):
return elem
name = elem.replace('{', '').replace('}', '')
if is_last:
return '(?P<%s>[^/]*)' % name
return '(?P<%s>.*)' % name
elems = uri.split('/')
def _convert(elem, is_last):
if not re.match("^{.*}$", elem):
return elem
name = elem.replace("{", "").replace("}", "")
if is_last:
return "(?P<%s>[^/]*)" % name
return "(?P<%s>.*)" % name
elems = uri.split("/")
num_elems = len(elems)
regexp = '^{}$'.format('/'.join([_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)]))
regexp = "^{}$".format(
"/".join(
[_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)]
)
)
return regexp
def _get_action_from_method_and_request_uri(self, method, request_uri):
@ -288,19 +329,19 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# service response class should have 'SERVICE_NAME' class member,
# if you want to get action from method and url
if not hasattr(self, 'SERVICE_NAME'):
if not hasattr(self, "SERVICE_NAME"):
return None
service = self.SERVICE_NAME
conn = boto3.client(service, region_name=self.region)
# make cache if it does not exist yet
if not hasattr(self, 'method_urls'):
if not hasattr(self, "method_urls"):
self.method_urls = defaultdict(lambda: defaultdict(str))
op_names = conn._service_model.operation_names
for op_name in op_names:
op_model = conn._service_model.operation_model(op_name)
_method = op_model.http['method']
uri_regexp = self.uri_to_regexp(op_model.http['requestUri'])
_method = op_model.http["method"]
uri_regexp = self.uri_to_regexp(op_model.http["requestUri"])
self.method_urls[_method][uri_regexp] = op_model.name
regexp_and_names = self.method_urls[method]
for regexp, name in regexp_and_names.items():
@ -311,11 +352,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return None
def _get_action(self):
action = self.querystring.get('Action', [""])[0]
action = self.querystring.get("Action", [""])[0]
if not action: # Some services use a header for the action
# Headers are case-insensitive. Probably a better way to do this.
match = self.headers.get(
'x-amz-target') or self.headers.get('X-Amz-Target')
match = self.headers.get("x-amz-target") or self.headers.get("X-Amz-Target")
if match:
action = match.split(".")[-1]
# get action from method and uri
@ -347,10 +387,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return self._send_response(headers, response)
if not action:
return 404, headers, ''
return 404, headers, ""
raise NotImplementedError(
"The {0} action has not been implemented".format(action))
"The {0} action has not been implemented".format(action)
)
@staticmethod
def _send_response(headers, response):
@ -358,11 +399,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
body, new_headers = response
else:
status, new_headers, body = response
status = new_headers.get('status', 200)
status = new_headers.get("status", 200)
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
headers["status"] = str(headers["status"])
return status, headers, body
def _get_param(self, param_name, if_none=None):
@ -396,9 +437,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def _get_bool_param(self, param_name, if_none=None):
val = self._get_param(param_name)
if val is not None:
if val.lower() == 'true':
if val.lower() == "true":
return True
elif val.lower() == 'false':
elif val.lower() == "false":
return False
return if_none
@ -416,11 +457,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if is_tracked(name) or not name.startswith(param_prefix):
continue
if len(name) > len(param_prefix) and \
not name[len(param_prefix):].startswith('.'):
if len(name) > len(param_prefix) and not name[
len(param_prefix) :
].startswith("."):
continue
match = self.param_list_regex.search(name[len(param_prefix):]) if len(name) > len(param_prefix) else None
match = (
self.param_list_regex.search(name[len(param_prefix) :])
if len(name) > len(param_prefix)
else None
)
if match:
prefix = param_prefix + match.group(1)
value = self._get_multi_param(prefix)
@ -435,7 +481,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if len(value_dict) > 1:
# strip off period prefix
value_dict = {name[len(param_prefix) + 1:]: value for name, value in value_dict.items()}
value_dict = {
name[len(param_prefix) + 1 :]: value
for name, value in value_dict.items()
}
else:
value_dict = list(value_dict.values())[0]
@ -454,7 +503,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
index = 1
while True:
value_dict = self._get_multi_param_helper(prefix + str(index))
if not value_dict and value_dict != '':
if not value_dict and value_dict != "":
break
values.append(value_dict)
@ -479,8 +528,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
params = {}
for key, value in self.querystring.items():
if key.startswith(param_prefix):
params[camelcase_to_underscores(
key.replace(param_prefix, ""))] = value[0]
params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[
0
]
return params
def _get_list_prefix(self, param_prefix):
@ -513,19 +563,20 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
new_items = {}
for key, value in self.querystring.items():
if key.startswith(index_prefix):
new_items[camelcase_to_underscores(
key.replace(index_prefix, ""))] = value[0]
new_items[
camelcase_to_underscores(key.replace(index_prefix, ""))
] = value[0]
if not new_items:
break
results.append(new_items)
param_index += 1
return results
def _get_map_prefix(self, param_prefix, key_end='.key', value_end='.value'):
def _get_map_prefix(self, param_prefix, key_end=".key", value_end=".value"):
results = {}
param_index = 1
while 1:
index_prefix = '{0}.{1}.'.format(param_prefix, param_index)
index_prefix = "{0}.{1}.".format(param_prefix, param_index)
k, v = None, None
for key, value in self.querystring.items():
@ -552,8 +603,8 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
param_index = 1
while True:
key_name = 'tag.{0}._key'.format(param_index)
value_name = 'tag.{0}._value'.format(param_index)
key_name = "tag.{0}._key".format(param_index)
value_name = "tag.{0}._value".format(param_index)
try:
results[resource_type][tag[key_name]] = tag[value_name]
@ -563,7 +614,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return results
def _get_object_map(self, prefix, name='Name', value='Value'):
def _get_object_map(self, prefix, name="Name", value="Value"):
"""
Given a query dict like
{
@ -591,15 +642,14 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
index = 1
while True:
# Loop through looking for keys representing object name
name_key = '{0}.{1}.{2}'.format(prefix, index, name)
name_key = "{0}.{1}.{2}".format(prefix, index, name)
obj_name = self.querystring.get(name_key)
if not obj_name:
# Found all keys
break
obj = {}
value_key_prefix = '{0}.{1}.{2}.'.format(
prefix, index, value)
value_key_prefix = "{0}.{1}.{2}.".format(prefix, index, value)
for k, v in self.querystring.items():
if k.startswith(value_key_prefix):
_, value_key = k.split(value_key_prefix, 1)
@ -613,31 +663,46 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
@property
def request_json(self):
return 'JSON' in self.querystring.get('ContentType', [])
return "JSON" in self.querystring.get("ContentType", [])
def is_not_dryrun(self, action):
if 'true' in self.querystring.get('DryRun', ['false']):
message = 'An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set' % action
raise DryRunClientError(
error_type="DryRunOperation", message=message)
if "true" in self.querystring.get("DryRun", ["false"]):
message = (
"An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set"
% action
)
raise DryRunClientError(error_type="DryRunOperation", message=message)
return True
class MotoAPIResponse(BaseResponse):
def reset_response(self, request, full_url, headers):
if request.method == "POST":
from .models import moto_api_backend
moto_api_backend.reset()
return 200, {}, json.dumps({"status": "ok"})
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"})
def reset_auth_response(self, request, full_url, headers):
if request.method == "POST":
previous_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT
previous_initial_no_auth_action_count = (
settings.INITIAL_NO_AUTH_ACTION_COUNT
)
settings.INITIAL_NO_AUTH_ACTION_COUNT = float(request.data.decode())
ActionAuthenticatorMixin.request_count = 0
return 200, {}, json.dumps({"status": "ok", "PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(previous_initial_no_auth_action_count)})
return (
200,
{},
json.dumps(
{
"status": "ok",
"PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(
previous_initial_no_auth_action_count
),
}
),
)
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"})
def model_data(self, request, full_url, headers):
@ -665,7 +730,8 @@ class MotoAPIResponse(BaseResponse):
def dashboard(self, request, full_url, headers):
from flask import render_template
return render_template('dashboard.html')
return render_template("dashboard.html")
class _RecursiveDictRef(object):
@ -676,7 +742,7 @@ class _RecursiveDictRef(object):
self.dic = {}
def __repr__(self):
return '{!r}'.format(self.dic)
return "{!r}".format(self.dic)
def __getattr__(self, key):
return self.dic.__getattr__(key)
@ -700,21 +766,21 @@ class AWSServiceSpec(object):
"""
def __init__(self, path):
self.path = resource_filename('botocore', path)
with io.open(self.path, 'r', encoding='utf-8') as f:
self.path = resource_filename("botocore", path)
with io.open(self.path, "r", encoding="utf-8") as f:
spec = json.load(f)
self.metadata = spec['metadata']
self.operations = spec['operations']
self.shapes = spec['shapes']
self.metadata = spec["metadata"]
self.operations = spec["operations"]
self.shapes = spec["shapes"]
def input_spec(self, operation):
try:
op = self.operations[operation]
except KeyError:
raise ValueError('Invalid operation: {}'.format(operation))
if 'input' not in op:
raise ValueError("Invalid operation: {}".format(operation))
if "input" not in op:
return {}
shape = self.shapes[op['input']['shape']]
shape = self.shapes[op["input"]["shape"]]
return self._expand(shape)
def output_spec(self, operation):
@ -728,129 +794,133 @@ class AWSServiceSpec(object):
try:
op = self.operations[operation]
except KeyError:
raise ValueError('Invalid operation: {}'.format(operation))
if 'output' not in op:
raise ValueError("Invalid operation: {}".format(operation))
if "output" not in op:
return {}
shape = self.shapes[op['output']['shape']]
shape = self.shapes[op["output"]["shape"]]
return self._expand(shape)
def _expand(self, shape):
def expand(dic, seen=None):
seen = seen or {}
if dic['type'] == 'structure':
if dic["type"] == "structure":
nodes = {}
for k, v in dic['members'].items():
for k, v in dic["members"].items():
seen_till_here = dict(seen)
if k in seen_till_here:
nodes[k] = seen_till_here[k]
continue
seen_till_here[k] = _RecursiveDictRef()
nodes[k] = expand(self.shapes[v['shape']], seen_till_here)
nodes[k] = expand(self.shapes[v["shape"]], seen_till_here)
seen_till_here[k].set_reference(k, nodes[k])
nodes['type'] = 'structure'
nodes["type"] = "structure"
return nodes
elif dic['type'] == 'list':
elif dic["type"] == "list":
seen_till_here = dict(seen)
shape = dic['member']['shape']
shape = dic["member"]["shape"]
if shape in seen_till_here:
return seen_till_here[shape]
seen_till_here[shape] = _RecursiveDictRef()
expanded = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, expanded)
return {'type': 'list', 'member': expanded}
return {"type": "list", "member": expanded}
elif dic['type'] == 'map':
elif dic["type"] == "map":
seen_till_here = dict(seen)
node = {'type': 'map'}
node = {"type": "map"}
if 'shape' in dic['key']:
shape = dic['key']['shape']
if "shape" in dic["key"]:
shape = dic["key"]["shape"]
seen_till_here[shape] = _RecursiveDictRef()
node['key'] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node['key'])
node["key"] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node["key"])
else:
node['key'] = dic['key']['type']
node["key"] = dic["key"]["type"]
if 'shape' in dic['value']:
shape = dic['value']['shape']
if "shape" in dic["value"]:
shape = dic["value"]["shape"]
seen_till_here[shape] = _RecursiveDictRef()
node['value'] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node['value'])
node["value"] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node["value"])
else:
node['value'] = dic['value']['type']
node["value"] = dic["value"]["type"]
return node
else:
return {'type': dic['type']}
return {"type": dic["type"]}
return expand(shape)
def to_str(value, spec):
vtype = spec['type']
if vtype == 'boolean':
return 'true' if value else 'false'
elif vtype == 'integer':
vtype = spec["type"]
if vtype == "boolean":
return "true" if value else "false"
elif vtype == "integer":
return str(value)
elif vtype == 'float':
elif vtype == "float":
return str(value)
elif vtype == 'double':
elif vtype == "double":
return str(value)
elif vtype == 'timestamp':
return datetime.datetime.utcfromtimestamp(
value).replace(tzinfo=pytz.utc).isoformat()
elif vtype == 'string':
elif vtype == "timestamp":
return (
datetime.datetime.utcfromtimestamp(value)
.replace(tzinfo=pytz.utc)
.isoformat()
)
elif vtype == "string":
return str(value)
elif value is None:
return 'null'
return "null"
else:
raise TypeError('Unknown type {}'.format(vtype))
raise TypeError("Unknown type {}".format(vtype))
def from_str(value, spec):
vtype = spec['type']
if vtype == 'boolean':
return True if value == 'true' else False
elif vtype == 'integer':
vtype = spec["type"]
if vtype == "boolean":
return True if value == "true" else False
elif vtype == "integer":
return int(value)
elif vtype == 'float':
elif vtype == "float":
return float(value)
elif vtype == 'double':
elif vtype == "double":
return float(value)
elif vtype == 'timestamp':
elif vtype == "timestamp":
return value
elif vtype == 'string':
elif vtype == "string":
return value
raise TypeError('Unknown type {}'.format(vtype))
raise TypeError("Unknown type {}".format(vtype))
def flatten_json_request_body(prefix, dict_body, spec):
"""Convert a JSON request body into query params."""
if len(spec) == 1 and 'type' in spec:
if len(spec) == 1 and "type" in spec:
return {prefix: to_str(dict_body, spec)}
flat = {}
for key, value in dict_body.items():
node_type = spec[key]['type']
if node_type == 'list':
node_type = spec[key]["type"]
if node_type == "list":
for idx, v in enumerate(value, 1):
pref = key + '.member.' + str(idx)
flat.update(flatten_json_request_body(
pref, v, spec[key]['member']))
elif node_type == 'map':
pref = key + ".member." + str(idx)
flat.update(flatten_json_request_body(pref, v, spec[key]["member"]))
elif node_type == "map":
for idx, (k, v) in enumerate(value.items(), 1):
pref = key + '.entry.' + str(idx)
flat.update(flatten_json_request_body(
pref + '.key', k, spec[key]['key']))
flat.update(flatten_json_request_body(
pref + '.value', v, spec[key]['value']))
pref = key + ".entry." + str(idx)
flat.update(
flatten_json_request_body(pref + ".key", k, spec[key]["key"])
)
flat.update(
flatten_json_request_body(pref + ".value", v, spec[key]["value"])
)
else:
flat.update(flatten_json_request_body(key, value, spec[key]))
if prefix:
prefix = prefix + '.'
prefix = prefix + "."
return dict((prefix + k, v) for k, v in flat.items())
@ -873,41 +943,40 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
od = OrderedDict()
for k, v in value.items():
if k.startswith('@'):
if k.startswith("@"):
continue
if k not in spec:
# this can happen when with an older version of
# botocore for which the node in XML template is not
# defined in service spec.
log.warning(
'Field %s is not defined by the botocore version in use', k)
log.warning("Field %s is not defined by the botocore version in use", k)
continue
if spec[k]['type'] == 'list':
if spec[k]["type"] == "list":
if v is None:
od[k] = []
elif len(spec[k]['member']) == 1:
if isinstance(v['member'], list):
od[k] = transform(v['member'], spec[k]['member'])
elif len(spec[k]["member"]) == 1:
if isinstance(v["member"], list):
od[k] = transform(v["member"], spec[k]["member"])
else:
od[k] = [transform(v['member'], spec[k]['member'])]
elif isinstance(v['member'], list):
od[k] = [transform(o, spec[k]['member'])
for o in v['member']]
elif isinstance(v['member'], OrderedDict):
od[k] = [transform(v['member'], spec[k]['member'])]
od[k] = [transform(v["member"], spec[k]["member"])]
elif isinstance(v["member"], list):
od[k] = [transform(o, spec[k]["member"]) for o in v["member"]]
elif isinstance(v["member"], OrderedDict):
od[k] = [transform(v["member"], spec[k]["member"])]
else:
raise ValueError('Malformatted input')
elif spec[k]['type'] == 'map':
raise ValueError("Malformatted input")
elif spec[k]["type"] == "map":
if v is None:
od[k] = {}
else:
items = ([v['entry']] if not isinstance(v['entry'], list) else
v['entry'])
items = (
[v["entry"]] if not isinstance(v["entry"], list) else v["entry"]
)
for item in items:
key = from_str(item['key'], spec[k]['key'])
val = from_str(item['value'], spec[k]['value'])
key = from_str(item["key"], spec[k]["key"])
val = from_str(item["value"], spec[k]["value"])
if k not in od:
od[k] = {}
od[k][key] = val
@ -921,7 +990,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
dic = xmltodict.parse(xml)
output_spec = service_spec.output_spec(operation)
try:
for k in (result_node or (operation + 'Response', operation + 'Result')):
for k in result_node or (operation + "Response", operation + "Result"):
dic = dic[k]
except KeyError:
return None

View File

@ -1,15 +1,13 @@
from __future__ import unicode_literals
from .responses import MotoAPIResponse
url_bases = [
"https?://motoapi.amazonaws.com"
]
url_bases = ["https?://motoapi.amazonaws.com"]
response_instance = MotoAPIResponse()
url_paths = {
'{0}/moto-api/$': response_instance.dashboard,
'{0}/moto-api/data.json': response_instance.model_data,
'{0}/moto-api/reset': response_instance.reset_response,
'{0}/moto-api/reset-auth': response_instance.reset_auth_response,
"{0}/moto-api/$": response_instance.dashboard,
"{0}/moto-api/data.json": response_instance.model_data,
"{0}/moto-api/reset": response_instance.reset_response,
"{0}/moto-api/reset-auth": response_instance.reset_auth_response,
}

View File

@ -15,9 +15,9 @@ REQUEST_ID_LONG = string.digits + string.ascii_uppercase
def camelcase_to_underscores(argument):
''' Converts a camelcase param like theNewAttribute to the equivalent
python underscore variable like the_new_attribute'''
result = ''
""" Converts a camelcase param like theNewAttribute to the equivalent
python underscore variable like the_new_attribute"""
result = ""
prev_char_title = True
if not argument:
return argument
@ -41,18 +41,18 @@ def camelcase_to_underscores(argument):
def underscores_to_camelcase(argument):
''' Converts a camelcase param like the_new_attribute to the equivalent
""" Converts a camelcase param like the_new_attribute to the equivalent
camelcase version like theNewAttribute. Note that the first letter is
NOT capitalized by this function '''
result = ''
NOT capitalized by this function """
result = ""
previous_was_underscore = False
for char in argument:
if char != '_':
if char != "_":
if previous_was_underscore:
result += char.upper()
else:
result += char
previous_was_underscore = char == '_'
previous_was_underscore = char == "_"
return result
@ -69,12 +69,18 @@ def method_names_from_class(clazz):
def get_random_hex(length=8):
chars = list(range(10)) + ['a', 'b', 'c', 'd', 'e', 'f']
return ''.join(six.text_type(random.choice(chars)) for x in range(length))
chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"]
return "".join(six.text_type(random.choice(chars)) for x in range(length))
def get_random_message_id():
return '{0}-{1}-{2}-{3}-{4}'.format(get_random_hex(8), get_random_hex(4), get_random_hex(4), get_random_hex(4), get_random_hex(12))
return "{0}-{1}-{2}-{3}-{4}".format(
get_random_hex(8),
get_random_hex(4),
get_random_hex(4),
get_random_hex(4),
get_random_hex(12),
)
def convert_regex_to_flask_path(url_path):
@ -97,7 +103,6 @@ def convert_regex_to_flask_path(url_path):
class convert_httpretty_response(object):
def __init__(self, callback):
self.callback = callback
@ -114,13 +119,12 @@ class convert_httpretty_response(object):
def __call__(self, request, url, headers, **kwargs):
result = self.callback(request, url, headers)
status, headers, response = result
if 'server' not in headers:
if "server" not in headers:
headers["server"] = "amazon.com"
return status, headers, response
class convert_flask_to_httpretty_response(object):
def __init__(self, callback):
self.callback = callback
@ -145,13 +149,12 @@ class convert_flask_to_httpretty_response(object):
status, headers, content = 200, {}, result
response = Response(response=content, status=status, headers=headers)
if request.method == "HEAD" and 'content-length' in headers:
response.headers['Content-Length'] = headers['content-length']
if request.method == "HEAD" and "content-length" in headers:
response.headers["Content-Length"] = headers["content-length"]
return response
class convert_flask_to_responses_response(object):
def __init__(self, callback):
self.callback = callback
@ -176,14 +179,14 @@ class convert_flask_to_responses_response(object):
def iso_8601_datetime_with_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + 'Z'
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
def iso_8601_datetime_without_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + 'Z'
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
RFC1123 = '%a, %d %b %Y %H:%M:%S GMT'
RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
def rfc_1123_datetime(datetime):
@ -212,16 +215,16 @@ def gen_amz_crc32(response, headerdict=None):
crc = str(binascii.crc32(response))
if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amz-crc32': crc})
headerdict.update({"x-amz-crc32": crc})
return crc
def gen_amzn_requestid_long(headerdict=None):
req_id = ''.join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)])
req_id = "".join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)])
if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amzn-requestid': req_id})
headerdict.update({"x-amzn-requestid": req_id})
return req_id
@ -239,13 +242,13 @@ def amz_crc32(f):
else:
if len(response) == 2:
body, new_headers = response
status = new_headers.get('status', 200)
status = new_headers.get("status", 200)
else:
status, new_headers, body = response
headers.update(new_headers)
# Cast status to string
if "status" in headers:
headers['status'] = str(headers['status'])
headers["status"] = str(headers["status"])
try:
# Doesnt work on python2 for some odd unicode strings
@ -271,7 +274,7 @@ def amzn_request_id(f):
else:
if len(response) == 2:
body, new_headers = response
status = new_headers.get('status', 200)
status = new_headers.get("status", 200)
else:
status, new_headers, body = response
headers.update(new_headers)
@ -280,7 +283,7 @@ def amzn_request_id(f):
# Update request ID in XML
try:
body = re.sub(r'(?<=<RequestId>).*(?=<\/RequestId>)', request_id, body)
body = re.sub(r"(?<=<RequestId>).*(?=<\/RequestId>)", request_id, body)
except Exception: # Will just ignore if it cant work on bytes (which are str's on python2)
pass
@ -293,7 +296,7 @@ def path_url(url):
parsed_url = urlparse(url)
path = parsed_url.path
if not path:
path = '/'
path = "/"
if parsed_url.query:
path = path + '?' + parsed_url.query
path = path + "?" + parsed_url.query
return path

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import datapipeline_backends
from ..core.models import base_decorator, deprecated_base_decorator
datapipeline_backend = datapipeline_backends['us-east-1']
datapipeline_backend = datapipeline_backends["us-east-1"]
mock_datapipeline = base_decorator(datapipeline_backends)
mock_datapipeline_deprecated = deprecated_base_decorator(datapipeline_backends)

View File

@ -8,85 +8,65 @@ from .utils import get_random_pipeline_id, remove_capitalization_of_dict_keys
class PipelineObject(BaseModel):
def __init__(self, object_id, name, fields):
self.object_id = object_id
self.name = name
self.fields = fields
def to_json(self):
return {
"fields": self.fields,
"id": self.object_id,
"name": self.name,
}
return {"fields": self.fields, "id": self.object_id, "name": self.name}
class Pipeline(BaseModel):
def __init__(self, name, unique_id, **kwargs):
self.name = name
self.unique_id = unique_id
self.description = kwargs.get('description', '')
self.description = kwargs.get("description", "")
self.pipeline_id = get_random_pipeline_id()
self.creation_time = datetime.datetime.utcnow()
self.objects = []
self.status = "PENDING"
self.tags = kwargs.get('tags', [])
self.tags = kwargs.get("tags", [])
@property
def physical_resource_id(self):
return self.pipeline_id
def to_meta_json(self):
return {
"id": self.pipeline_id,
"name": self.name,
}
return {"id": self.pipeline_id, "name": self.name}
def to_json(self):
return {
"description": self.description,
"fields": [{
"key": "@pipelineState",
"stringValue": self.status,
}, {
"key": "description",
"stringValue": self.description
}, {
"key": "name",
"stringValue": self.name
}, {
"key": "@creationTime",
"stringValue": datetime.datetime.strftime(self.creation_time, '%Y-%m-%dT%H-%M-%S'),
}, {
"key": "@id",
"stringValue": self.pipeline_id,
}, {
"key": "@sphere",
"stringValue": "PIPELINE"
}, {
"key": "@version",
"stringValue": "1"
}, {
"key": "@userId",
"stringValue": "924374875933"
}, {
"key": "@accountId",
"stringValue": "924374875933"
}, {
"key": "uniqueId",
"stringValue": self.unique_id
}],
"fields": [
{"key": "@pipelineState", "stringValue": self.status},
{"key": "description", "stringValue": self.description},
{"key": "name", "stringValue": self.name},
{
"key": "@creationTime",
"stringValue": datetime.datetime.strftime(
self.creation_time, "%Y-%m-%dT%H-%M-%S"
),
},
{"key": "@id", "stringValue": self.pipeline_id},
{"key": "@sphere", "stringValue": "PIPELINE"},
{"key": "@version", "stringValue": "1"},
{"key": "@userId", "stringValue": "924374875933"},
{"key": "@accountId", "stringValue": "924374875933"},
{"key": "uniqueId", "stringValue": self.unique_id},
],
"name": self.name,
"pipelineId": self.pipeline_id,
"tags": self.tags
"tags": self.tags,
}
def set_pipeline_objects(self, pipeline_objects):
self.objects = [
PipelineObject(pipeline_object['id'], pipeline_object[
'name'], pipeline_object['fields'])
PipelineObject(
pipeline_object["id"],
pipeline_object["name"],
pipeline_object["fields"],
)
for pipeline_object in remove_capitalization_of_dict_keys(pipeline_objects)
]
@ -94,15 +74,19 @@ class Pipeline(BaseModel):
self.status = "SCHEDULED"
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
datapipeline_backend = datapipeline_backends[region_name]
properties = cloudformation_json["Properties"]
cloudformation_unique_id = "cf-" + properties["Name"]
pipeline = datapipeline_backend.create_pipeline(
properties["Name"], cloudformation_unique_id)
properties["Name"], cloudformation_unique_id
)
datapipeline_backend.put_pipeline_definition(
pipeline.pipeline_id, properties["PipelineObjects"])
pipeline.pipeline_id, properties["PipelineObjects"]
)
if properties["Activate"]:
pipeline.activate()
@ -110,7 +94,6 @@ class Pipeline(BaseModel):
class DataPipelineBackend(BaseBackend):
def __init__(self):
self.pipelines = OrderedDict()
@ -123,8 +106,11 @@ class DataPipelineBackend(BaseBackend):
return self.pipelines.values()
def describe_pipelines(self, pipeline_ids):
pipelines = [pipeline for pipeline in self.pipelines.values(
) if pipeline.pipeline_id in pipeline_ids]
pipelines = [
pipeline
for pipeline in self.pipelines.values()
if pipeline.pipeline_id in pipeline_ids
]
return pipelines
def get_pipeline(self, pipeline_id):
@ -144,7 +130,8 @@ class DataPipelineBackend(BaseBackend):
def describe_objects(self, object_ids, pipeline_id):
pipeline = self.get_pipeline(pipeline_id)
pipeline_objects = [
pipeline_object for pipeline_object in pipeline.objects
pipeline_object
for pipeline_object in pipeline.objects
if pipeline_object.object_id in object_ids
]
return pipeline_objects

View File

@ -7,7 +7,6 @@ from .models import datapipeline_backends
class DataPipelineResponse(BaseResponse):
@property
def parameters(self):
# TODO this should really be moved to core/responses.py
@ -21,47 +20,47 @@ class DataPipelineResponse(BaseResponse):
return datapipeline_backends[self.region]
def create_pipeline(self):
name = self.parameters.get('name')
unique_id = self.parameters.get('uniqueId')
description = self.parameters.get('description', '')
tags = self.parameters.get('tags', [])
pipeline = self.datapipeline_backend.create_pipeline(name, unique_id, description=description, tags=tags)
return json.dumps({
"pipelineId": pipeline.pipeline_id,
})
name = self.parameters.get("name")
unique_id = self.parameters.get("uniqueId")
description = self.parameters.get("description", "")
tags = self.parameters.get("tags", [])
pipeline = self.datapipeline_backend.create_pipeline(
name, unique_id, description=description, tags=tags
)
return json.dumps({"pipelineId": pipeline.pipeline_id})
def list_pipelines(self):
pipelines = list(self.datapipeline_backend.list_pipelines())
pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines]
max_pipelines = 50
marker = self.parameters.get('marker')
marker = self.parameters.get("marker")
if marker:
start = pipeline_ids.index(marker) + 1
else:
start = 0
pipelines_resp = pipelines[start:start + max_pipelines]
pipelines_resp = pipelines[start : start + max_pipelines]
has_more_results = False
marker = None
if start + max_pipelines < len(pipeline_ids) - 1:
has_more_results = True
marker = pipelines_resp[-1].pipeline_id
return json.dumps({
"hasMoreResults": has_more_results,
"marker": marker,
"pipelineIdList": [
pipeline.to_meta_json() for pipeline in pipelines_resp
]
})
return json.dumps(
{
"hasMoreResults": has_more_results,
"marker": marker,
"pipelineIdList": [
pipeline.to_meta_json() for pipeline in pipelines_resp
],
}
)
def describe_pipelines(self):
pipeline_ids = self.parameters["pipelineIds"]
pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids)
return json.dumps({
"pipelineDescriptionList": [
pipeline.to_json() for pipeline in pipelines
]
})
return json.dumps(
{"pipelineDescriptionList": [pipeline.to_json() for pipeline in pipelines]}
)
def delete_pipeline(self):
pipeline_id = self.parameters["pipelineId"]
@ -72,31 +71,38 @@ class DataPipelineResponse(BaseResponse):
pipeline_id = self.parameters["pipelineId"]
pipeline_objects = self.parameters["pipelineObjects"]
self.datapipeline_backend.put_pipeline_definition(
pipeline_id, pipeline_objects)
self.datapipeline_backend.put_pipeline_definition(pipeline_id, pipeline_objects)
return json.dumps({"errored": False})
def get_pipeline_definition(self):
pipeline_id = self.parameters["pipelineId"]
pipeline_definition = self.datapipeline_backend.get_pipeline_definition(
pipeline_id)
return json.dumps({
"pipelineObjects": [pipeline_object.to_json() for pipeline_object in pipeline_definition]
})
pipeline_id
)
return json.dumps(
{
"pipelineObjects": [
pipeline_object.to_json() for pipeline_object in pipeline_definition
]
}
)
def describe_objects(self):
pipeline_id = self.parameters["pipelineId"]
object_ids = self.parameters["objectIds"]
pipeline_objects = self.datapipeline_backend.describe_objects(
object_ids, pipeline_id)
return json.dumps({
"hasMoreResults": False,
"marker": None,
"pipelineObjects": [
pipeline_object.to_json() for pipeline_object in pipeline_objects
]
})
object_ids, pipeline_id
)
return json.dumps(
{
"hasMoreResults": False,
"marker": None,
"pipelineObjects": [
pipeline_object.to_json() for pipeline_object in pipeline_objects
],
}
)
def activate_pipeline(self):
pipeline_id = self.parameters["pipelineId"]

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import DataPipelineResponse
url_bases = [
"https?://datapipeline.(.+).amazonaws.com",
]
url_bases = ["https?://datapipeline.(.+).amazonaws.com"]
url_paths = {
'{0}/$': DataPipelineResponse.dispatch,
}
url_paths = {"{0}/$": DataPipelineResponse.dispatch}

View File

@ -14,7 +14,9 @@ def remove_capitalization_of_dict_keys(obj):
normalized_key = key[:1].lower() + key[1:]
result[normalized_key] = remove_capitalization_of_dict_keys(value)
return result
elif isinstance(obj, collections.Iterable) and not isinstance(obj, six.string_types):
elif isinstance(obj, collections.Iterable) and not isinstance(
obj, six.string_types
):
result = obj.__class__()
for item in obj:
result += (remove_capitalization_of_dict_keys(item),)

View File

@ -1,19 +1,22 @@
from __future__ import unicode_literals
# TODO add tests for all of these
COMPARISON_FUNCS = {
'EQ': lambda item_value, test_value: item_value == test_value,
'NE': lambda item_value, test_value: item_value != test_value,
'LE': lambda item_value, test_value: item_value <= test_value,
'LT': lambda item_value, test_value: item_value < test_value,
'GE': lambda item_value, test_value: item_value >= test_value,
'GT': lambda item_value, test_value: item_value > test_value,
'NULL': lambda item_value: item_value is None,
'NOT_NULL': lambda item_value: item_value is not None,
'CONTAINS': lambda item_value, test_value: test_value in item_value,
'NOT_CONTAINS': lambda item_value, test_value: test_value not in item_value,
'BEGINS_WITH': lambda item_value, test_value: item_value.startswith(test_value),
'IN': lambda item_value, *test_values: item_value in test_values,
'BETWEEN': lambda item_value, lower_test_value, upper_test_value: lower_test_value <= item_value <= upper_test_value,
"EQ": lambda item_value, test_value: item_value == test_value,
"NE": lambda item_value, test_value: item_value != test_value,
"LE": lambda item_value, test_value: item_value <= test_value,
"LT": lambda item_value, test_value: item_value < test_value,
"GE": lambda item_value, test_value: item_value >= test_value,
"GT": lambda item_value, test_value: item_value > test_value,
"NULL": lambda item_value: item_value is None,
"NOT_NULL": lambda item_value: item_value is not None,
"CONTAINS": lambda item_value, test_value: test_value in item_value,
"NOT_CONTAINS": lambda item_value, test_value: test_value not in item_value,
"BEGINS_WITH": lambda item_value, test_value: item_value.startswith(test_value),
"IN": lambda item_value, *test_values: item_value in test_values,
"BETWEEN": lambda item_value, lower_test_value, upper_test_value: lower_test_value
<= item_value
<= upper_test_value,
}

View File

@ -10,9 +10,8 @@ from .comparisons import get_comparison_func
class DynamoJsonEncoder(json.JSONEncoder):
def default(self, obj):
if hasattr(obj, 'to_json'):
if hasattr(obj, "to_json"):
return obj.to_json()
@ -33,10 +32,7 @@ class DynamoType(object):
return hash((self.type, self.value))
def __eq__(self, other):
return (
self.type == other.type and
self.value == other.value
)
return self.type == other.type and self.value == other.value
def __repr__(self):
return "DynamoType: {0}".format(self.to_json())
@ -54,7 +50,6 @@ class DynamoType(object):
class Item(BaseModel):
def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
self.hash_key = hash_key
self.hash_key_type = hash_key_type
@ -73,9 +68,7 @@ class Item(BaseModel):
for attribute_key, attribute in self.attrs.items():
attributes[attribute_key] = attribute.value
return {
"Attributes": attributes
}
return {"Attributes": attributes}
def describe_attrs(self, attributes):
if attributes:
@ -85,16 +78,20 @@ class Item(BaseModel):
included[key] = value
else:
included = self.attrs
return {
"Item": included
}
return {"Item": included}
class Table(BaseModel):
def __init__(self, name, hash_key_attr, hash_key_type,
range_key_attr=None, range_key_type=None, read_capacity=None,
write_capacity=None):
def __init__(
self,
name,
hash_key_attr,
hash_key_type,
range_key_attr=None,
range_key_type=None,
read_capacity=None,
write_capacity=None,
):
self.name = name
self.hash_key_attr = hash_key_attr
self.hash_key_type = hash_key_type
@ -117,12 +114,12 @@ class Table(BaseModel):
"KeySchema": {
"HashKeyElement": {
"AttributeName": self.hash_key_attr,
"AttributeType": self.hash_key_type
},
"AttributeType": self.hash_key_type,
}
},
"ProvisionedThroughput": {
"ReadCapacityUnits": self.read_capacity,
"WriteCapacityUnits": self.write_capacity
"WriteCapacityUnits": self.write_capacity,
},
"TableName": self.name,
"TableStatus": "ACTIVE",
@ -133,19 +130,29 @@ class Table(BaseModel):
if self.has_range_key:
results["Table"]["KeySchema"]["RangeKeyElement"] = {
"AttributeName": self.range_key_attr,
"AttributeType": self.range_key_type
"AttributeType": self.range_key_type,
}
return results
@classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name):
properties = cloudformation_json['Properties']
key_attr = [i['AttributeName'] for i in properties['KeySchema'] if i['KeyType'] == 'HASH'][0]
key_type = [i['AttributeType'] for i in properties['AttributeDefinitions'] if i['AttributeName'] == key_attr][0]
def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
key_attr = [
i["AttributeName"]
for i in properties["KeySchema"]
if i["KeyType"] == "HASH"
][0]
key_type = [
i["AttributeType"]
for i in properties["AttributeDefinitions"]
if i["AttributeName"] == key_attr
][0]
spec = {
'name': properties['TableName'],
'hash_key_attr': key_attr,
'hash_key_type': key_type
"name": properties["TableName"],
"hash_key_attr": key_attr,
"hash_key_type": key_type,
}
# TODO: optional properties still missing:
# range_key_attr, range_key_type, read_capacity, write_capacity
@ -173,8 +180,9 @@ class Table(BaseModel):
else:
range_value = None
item = Item(hash_value, self.hash_key_type, range_value,
self.range_key_type, item_attrs)
item = Item(
hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs
)
if range_value:
self.items[hash_value][range_value] = item
@ -185,7 +193,8 @@ class Table(BaseModel):
def get_item(self, hash_key, range_key):
if self.has_range_key and not range_key:
raise ValueError(
"Table has a range key, but no range key was passed into get_item")
"Table has a range key, but no range key was passed into get_item"
)
try:
if range_key:
return self.items[hash_key][range_key]
@ -228,7 +237,10 @@ class Table(BaseModel):
for result in self.all_items():
scanned_count += 1
passes_all_conditions = True
for attribute_name, (comparison_operator, comparison_objs) in filters.items():
for (
attribute_name,
(comparison_operator, comparison_objs),
) in filters.items():
attribute = result.attrs.get(attribute_name)
if attribute:
@ -236,7 +248,7 @@ class Table(BaseModel):
if not attribute.compare(comparison_operator, comparison_objs):
passes_all_conditions = False
break
elif comparison_operator == 'NULL':
elif comparison_operator == "NULL":
# Comparison is NULL and we don't have the attribute
continue
else:
@ -261,15 +273,17 @@ class Table(BaseModel):
def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == 'StreamArn':
region = 'us-east-1'
time = '2000-01-01T00:00:00.000'
return 'arn:aws:dynamodb:{0}:123456789012:table/{1}/stream/{2}'.format(region, self.name, time)
if attribute_name == "StreamArn":
region = "us-east-1"
time = "2000-01-01T00:00:00.000"
return "arn:aws:dynamodb:{0}:123456789012:table/{1}/stream/{2}".format(
region, self.name, time
)
raise UnformattedGetAttTemplateException()
class DynamoDBBackend(BaseBackend):
def __init__(self):
self.tables = OrderedDict()
@ -310,8 +324,7 @@ class DynamoDBBackend(BaseBackend):
return None, None
hash_key = DynamoType(hash_key_dict)
range_values = [DynamoType(range_value)
for range_value in range_value_dicts]
range_values = [DynamoType(range_value) for range_value in range_value_dicts]
return table.query(hash_key, range_comparison, range_values)

View File

@ -8,7 +8,6 @@ from .models import dynamodb_backend, dynamo_json_dump
class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers):
"""Parses request headers and extracts part od the X-Amz-Target
that corresponds to a method of DynamoHandler
@ -16,15 +15,15 @@ class DynamoHandler(BaseResponse):
ie: X-Amz-Target: DynamoDB_20111205.ListTables -> ListTables
"""
# Headers are case-insensitive. Probably a better way to do this.
match = headers.get('x-amz-target') or headers.get('X-Amz-Target')
match = headers.get("x-amz-target") or headers.get("X-Amz-Target")
if match:
return match.split(".")[1]
def error(self, type_, status=400):
return status, self.response_headers, dynamo_json_dump({'__type': type_})
return status, self.response_headers, dynamo_json_dump({"__type": type_})
def call_action(self):
self.body = json.loads(self.body or '{}')
self.body = json.loads(self.body or "{}")
endpoint = self.get_endpoint_name(self.headers)
if endpoint:
endpoint = camelcase_to_underscores(endpoint)
@ -41,7 +40,7 @@ class DynamoHandler(BaseResponse):
def list_tables(self):
body = self.body
limit = body.get('Limit')
limit = body.get("Limit")
if body.get("ExclusiveStartTableName"):
last = body.get("ExclusiveStartTableName")
start = list(dynamodb_backend.tables.keys()).index(last) + 1
@ -49,7 +48,7 @@ class DynamoHandler(BaseResponse):
start = 0
all_tables = list(dynamodb_backend.tables.keys())
if limit:
tables = all_tables[start:start + limit]
tables = all_tables[start : start + limit]
else:
tables = all_tables[start:]
response = {"TableNames": tables}
@ -59,16 +58,16 @@ class DynamoHandler(BaseResponse):
def create_table(self):
body = self.body
name = body['TableName']
name = body["TableName"]
key_schema = body['KeySchema']
hash_key = key_schema['HashKeyElement']
hash_key_attr = hash_key['AttributeName']
hash_key_type = hash_key['AttributeType']
key_schema = body["KeySchema"]
hash_key = key_schema["HashKeyElement"]
hash_key_attr = hash_key["AttributeName"]
hash_key_type = hash_key["AttributeType"]
range_key = key_schema.get('RangeKeyElement', {})
range_key_attr = range_key.get('AttributeName')
range_key_type = range_key.get('AttributeType')
range_key = key_schema.get("RangeKeyElement", {})
range_key_attr = range_key.get("AttributeName")
range_key_type = range_key.get("AttributeType")
throughput = body["ProvisionedThroughput"]
read_units = throughput["ReadCapacityUnits"]
@ -86,137 +85,131 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(table.describe)
def delete_table(self):
name = self.body['TableName']
name = self.body["TableName"]
table = dynamodb_backend.delete_table(name)
if table:
return dynamo_json_dump(table.describe)
else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er)
def update_table(self):
name = self.body['TableName']
name = self.body["TableName"]
throughput = self.body["ProvisionedThroughput"]
new_read_units = throughput["ReadCapacityUnits"]
new_write_units = throughput["WriteCapacityUnits"]
table = dynamodb_backend.update_table_throughput(
name, new_read_units, new_write_units)
name, new_read_units, new_write_units
)
return dynamo_json_dump(table.describe)
def describe_table(self):
name = self.body['TableName']
name = self.body["TableName"]
try:
table = dynamodb_backend.tables[name]
except KeyError:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er)
return dynamo_json_dump(table.describe)
def put_item(self):
name = self.body['TableName']
item = self.body['Item']
name = self.body["TableName"]
item = self.body["Item"]
result = dynamodb_backend.put_item(name, item)
if result:
item_dict = result.to_json()
item_dict['ConsumedCapacityUnits'] = 1
item_dict["ConsumedCapacityUnits"] = 1
return dynamo_json_dump(item_dict)
else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er)
def batch_write_item(self):
table_batches = self.body['RequestItems']
table_batches = self.body["RequestItems"]
for table_name, table_requests in table_batches.items():
for table_request in table_requests:
request_type = list(table_request)[0]
request = list(table_request.values())[0]
if request_type == 'PutRequest':
item = request['Item']
if request_type == "PutRequest":
item = request["Item"]
dynamodb_backend.put_item(table_name, item)
elif request_type == 'DeleteRequest':
key = request['Key']
hash_key = key['HashKeyElement']
range_key = key.get('RangeKeyElement')
item = dynamodb_backend.delete_item(
table_name, hash_key, range_key)
elif request_type == "DeleteRequest":
key = request["Key"]
hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement")
item = dynamodb_backend.delete_item(table_name, hash_key, range_key)
response = {
"Responses": {
"Thread": {
"ConsumedCapacityUnits": 1.0
},
"Reply": {
"ConsumedCapacityUnits": 1.0
}
"Thread": {"ConsumedCapacityUnits": 1.0},
"Reply": {"ConsumedCapacityUnits": 1.0},
},
"UnprocessedItems": {}
"UnprocessedItems": {},
}
return dynamo_json_dump(response)
def get_item(self):
name = self.body['TableName']
key = self.body['Key']
hash_key = key['HashKeyElement']
range_key = key.get('RangeKeyElement')
attrs_to_get = self.body.get('AttributesToGet')
name = self.body["TableName"]
key = self.body["Key"]
hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement")
attrs_to_get = self.body.get("AttributesToGet")
try:
item = dynamodb_backend.get_item(name, hash_key, range_key)
except ValueError:
er = 'com.amazon.coral.validate#ValidationException'
er = "com.amazon.coral.validate#ValidationException"
return self.error(er, status=400)
if item:
item_dict = item.describe_attrs(attrs_to_get)
item_dict['ConsumedCapacityUnits'] = 0.5
item_dict["ConsumedCapacityUnits"] = 0.5
return dynamo_json_dump(item_dict)
else:
# Item not found
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er, status=404)
def batch_get_item(self):
table_batches = self.body['RequestItems']
table_batches = self.body["RequestItems"]
results = {
"Responses": {
"UnprocessedKeys": {}
}
}
results = {"Responses": {"UnprocessedKeys": {}}}
for table_name, table_request in table_batches.items():
items = []
keys = table_request['Keys']
attributes_to_get = table_request.get('AttributesToGet')
keys = table_request["Keys"]
attributes_to_get = table_request.get("AttributesToGet")
for key in keys:
hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement")
item = dynamodb_backend.get_item(
table_name, hash_key, range_key)
item = dynamodb_backend.get_item(table_name, hash_key, range_key)
if item:
item_describe = item.describe_attrs(attributes_to_get)
items.append(item_describe)
results["Responses"][table_name] = {
"Items": items, "ConsumedCapacityUnits": 1}
"Items": items,
"ConsumedCapacityUnits": 1,
}
return dynamo_json_dump(results)
def query(self):
name = self.body['TableName']
hash_key = self.body['HashKeyValue']
range_condition = self.body.get('RangeKeyCondition')
name = self.body["TableName"]
hash_key = self.body["HashKeyValue"]
range_condition = self.body.get("RangeKeyCondition")
if range_condition:
range_comparison = range_condition['ComparisonOperator']
range_values = range_condition['AttributeValueList']
range_comparison = range_condition["ComparisonOperator"]
range_values = range_condition["AttributeValueList"]
else:
range_comparison = None
range_values = []
items, last_page = dynamodb_backend.query(
name, hash_key, range_comparison, range_values)
name, hash_key, range_comparison, range_values
)
if items is None:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er)
result = {
@ -234,10 +227,10 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(result)
def scan(self):
name = self.body['TableName']
name = self.body["TableName"]
filters = {}
scan_filters = self.body.get('ScanFilter', {})
scan_filters = self.body.get("ScanFilter", {})
for attribute_name, scan_filter in scan_filters.items():
# Keys are attribute names. Values are tuples of (comparison,
# comparison_value)
@ -248,14 +241,14 @@ class DynamoHandler(BaseResponse):
items, scanned_count, last_page = dynamodb_backend.scan(name, filters)
if items is None:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er)
result = {
"Count": len(items),
"Items": [item.attrs for item in items if item],
"ConsumedCapacityUnits": 1,
"ScannedCount": scanned_count
"ScannedCount": scanned_count,
}
# Implement this when we do pagination
@ -267,19 +260,19 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(result)
def delete_item(self):
name = self.body['TableName']
key = self.body['Key']
hash_key = key['HashKeyElement']
range_key = key.get('RangeKeyElement')
return_values = self.body.get('ReturnValues', '')
name = self.body["TableName"]
key = self.body["Key"]
hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement")
return_values = self.body.get("ReturnValues", "")
item = dynamodb_backend.delete_item(name, hash_key, range_key)
if item:
if return_values == 'ALL_OLD':
if return_values == "ALL_OLD":
item_dict = item.to_json()
else:
item_dict = {'Attributes': []}
item_dict['ConsumedCapacityUnits'] = 0.5
item_dict = {"Attributes": []}
item_dict["ConsumedCapacityUnits"] = 0.5
return dynamo_json_dump(item_dict)
else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException'
er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import DynamoHandler
url_bases = [
"https?://dynamodb.(.+).amazonaws.com"
]
url_bases = ["https?://dynamodb.(.+).amazonaws.com"]
url_paths = {
"{0}/": DynamoHandler.dispatch,
}
url_paths = {"{0}/": DynamoHandler.dispatch}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import dynamodb_backends as dynamodb_backends2
from ..core.models import base_decorator, deprecated_base_decorator
dynamodb_backend2 = dynamodb_backends2['us-east-1']
dynamodb_backend2 = dynamodb_backends2["us-east-1"]
mock_dynamodb2 = base_decorator(dynamodb_backends2)
mock_dynamodb2_deprecated = deprecated_base_decorator(dynamodb_backends2)

File diff suppressed because it is too large Load Diff

View File

@ -7,4 +7,4 @@ class InvalidUpdateExpression(ValueError):
class ItemSizeTooLarge(Exception):
message = 'Item size has exceeded the maximum allowed size'
message = "Item size has exceeded the maximum allowed size"

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import DynamoHandler
url_bases = [
"https?://dynamodb.(.+).amazonaws.com"
]
url_bases = ["https?://dynamodb.(.+).amazonaws.com"]
url_paths = {
"{0}/": DynamoHandler.dispatch,
}
url_paths = {"{0}/": DynamoHandler.dispatch}

View File

@ -2,5 +2,5 @@ from __future__ import unicode_literals
from .models import dynamodbstreams_backends
from ..core.models import base_decorator
dynamodbstreams_backend = dynamodbstreams_backends['us-east-1']
dynamodbstreams_backend = dynamodbstreams_backends["us-east-1"]
mock_dynamodbstreams = base_decorator(dynamodbstreams_backends)

View File

@ -10,51 +10,59 @@ from moto.dynamodb2.models import dynamodb_backends
class ShardIterator(BaseModel):
def __init__(self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None):
self.id = base64.b64encode(os.urandom(472)).decode('utf-8')
def __init__(
self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None
):
self.id = base64.b64encode(os.urandom(472)).decode("utf-8")
self.streams_backend = streams_backend
self.stream_shard = stream_shard
self.shard_iterator_type = shard_iterator_type
if shard_iterator_type == 'TRIM_HORIZON':
if shard_iterator_type == "TRIM_HORIZON":
self.sequence_number = stream_shard.starting_sequence_number
elif shard_iterator_type == 'LATEST':
self.sequence_number = stream_shard.starting_sequence_number + len(stream_shard.items)
elif shard_iterator_type == 'AT_SEQUENCE_NUMBER':
elif shard_iterator_type == "LATEST":
self.sequence_number = stream_shard.starting_sequence_number + len(
stream_shard.items
)
elif shard_iterator_type == "AT_SEQUENCE_NUMBER":
self.sequence_number = sequence_number
elif shard_iterator_type == 'AFTER_SEQUENCE_NUMBER':
elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER":
self.sequence_number = sequence_number + 1
@property
def arn(self):
return '{}/stream/{}|1|{}'.format(
return "{}/stream/{}|1|{}".format(
self.stream_shard.table.table_arn,
self.stream_shard.table.latest_stream_label,
self.id)
self.id,
)
def to_json(self):
return {
'ShardIterator': self.arn
}
return {"ShardIterator": self.arn}
def get(self, limit=1000):
items = self.stream_shard.get(self.sequence_number, limit)
try:
last_sequence_number = max(int(i['dynamodb']['SequenceNumber']) for i in items)
new_shard_iterator = ShardIterator(self.streams_backend,
self.stream_shard,
'AFTER_SEQUENCE_NUMBER',
last_sequence_number)
last_sequence_number = max(
int(i["dynamodb"]["SequenceNumber"]) for i in items
)
new_shard_iterator = ShardIterator(
self.streams_backend,
self.stream_shard,
"AFTER_SEQUENCE_NUMBER",
last_sequence_number,
)
except ValueError:
new_shard_iterator = ShardIterator(self.streams_backend,
self.stream_shard,
'AT_SEQUENCE_NUMBER',
self.sequence_number)
new_shard_iterator = ShardIterator(
self.streams_backend,
self.stream_shard,
"AT_SEQUENCE_NUMBER",
self.sequence_number,
)
self.streams_backend.shard_iterators[new_shard_iterator.arn] = new_shard_iterator
return {
'NextShardIterator': new_shard_iterator.arn,
'Records': items
}
self.streams_backend.shard_iterators[
new_shard_iterator.arn
] = new_shard_iterator
return {"NextShardIterator": new_shard_iterator.arn, "Records": items}
class DynamoDBStreamsBackend(BaseBackend):
@ -72,23 +80,27 @@ class DynamoDBStreamsBackend(BaseBackend):
return dynamodb_backends[self.region]
def _get_table_from_arn(self, arn):
table_name = arn.split(':', 6)[5].split('/')[1]
table_name = arn.split(":", 6)[5].split("/")[1]
return self.dynamodb.get_table(table_name)
def describe_stream(self, arn):
table = self._get_table_from_arn(arn)
resp = {'StreamDescription': {
'StreamArn': arn,
'StreamLabel': table.latest_stream_label,
'StreamStatus': ('ENABLED' if table.latest_stream_label
else 'DISABLED'),
'StreamViewType': table.stream_specification['StreamViewType'],
'CreationRequestDateTime': table.stream_shard.created_on.isoformat(),
'TableName': table.name,
'KeySchema': table.schema,
'Shards': ([table.stream_shard.to_json()] if table.stream_shard
else [])
}}
resp = {
"StreamDescription": {
"StreamArn": arn,
"StreamLabel": table.latest_stream_label,
"StreamStatus": (
"ENABLED" if table.latest_stream_label else "DISABLED"
),
"StreamViewType": table.stream_specification["StreamViewType"],
"CreationRequestDateTime": table.stream_shard.created_on.isoformat(),
"TableName": table.name,
"KeySchema": table.schema,
"Shards": (
[table.stream_shard.to_json()] if table.stream_shard else []
),
}
}
return json.dumps(resp)
@ -98,22 +110,26 @@ class DynamoDBStreamsBackend(BaseBackend):
if table_name is not None and table.name != table_name:
continue
if table.latest_stream_label:
d = table.describe(base_key='Table')
streams.append({
'StreamArn': d['Table']['LatestStreamArn'],
'TableName': d['Table']['TableName'],
'StreamLabel': d['Table']['LatestStreamLabel']
})
d = table.describe(base_key="Table")
streams.append(
{
"StreamArn": d["Table"]["LatestStreamArn"],
"TableName": d["Table"]["TableName"],
"StreamLabel": d["Table"]["LatestStreamLabel"],
}
)
return json.dumps({'Streams': streams})
return json.dumps({"Streams": streams})
def get_shard_iterator(self, arn, shard_id, shard_iterator_type, sequence_number=None):
def get_shard_iterator(
self, arn, shard_id, shard_iterator_type, sequence_number=None
):
table = self._get_table_from_arn(arn)
assert table.stream_shard.id == shard_id
shard_iterator = ShardIterator(self, table.stream_shard,
shard_iterator_type,
sequence_number)
shard_iterator = ShardIterator(
self, table.stream_shard, shard_iterator_type, sequence_number
)
self.shard_iterators[shard_iterator.arn] = shard_iterator
return json.dumps(shard_iterator.to_json())
@ -123,7 +139,7 @@ class DynamoDBStreamsBackend(BaseBackend):
return json.dumps(shard_iterator.get(limit))
available_regions = boto3.session.Session().get_available_regions(
'dynamodbstreams')
dynamodbstreams_backends = {region: DynamoDBStreamsBackend(region=region)
for region in available_regions}
available_regions = boto3.session.Session().get_available_regions("dynamodbstreams")
dynamodbstreams_backends = {
region: DynamoDBStreamsBackend(region=region) for region in available_regions
}

View File

@ -7,34 +7,34 @@ from six import string_types
class DynamoDBStreamsHandler(BaseResponse):
@property
def backend(self):
return dynamodbstreams_backends[self.region]
def describe_stream(self):
arn = self._get_param('StreamArn')
arn = self._get_param("StreamArn")
return self.backend.describe_stream(arn)
def list_streams(self):
table_name = self._get_param('TableName')
table_name = self._get_param("TableName")
return self.backend.list_streams(table_name)
def get_shard_iterator(self):
arn = self._get_param('StreamArn')
shard_id = self._get_param('ShardId')
shard_iterator_type = self._get_param('ShardIteratorType')
sequence_number = self._get_param('SequenceNumber')
arn = self._get_param("StreamArn")
shard_id = self._get_param("ShardId")
shard_iterator_type = self._get_param("ShardIteratorType")
sequence_number = self._get_param("SequenceNumber")
# according to documentation sequence_number param should be string
if isinstance(sequence_number, string_types):
sequence_number = int(sequence_number)
return self.backend.get_shard_iterator(arn, shard_id,
shard_iterator_type, sequence_number)
return self.backend.get_shard_iterator(
arn, shard_id, shard_iterator_type, sequence_number
)
def get_records(self):
arn = self._get_param('ShardIterator')
limit = self._get_param('Limit')
arn = self._get_param("ShardIterator")
limit = self._get_param("Limit")
if limit is None:
limit = 1000
return self.backend.get_records(arn, limit)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals
from .responses import DynamoDBStreamsHandler
url_bases = [
"https?://streams.dynamodb.(.+).amazonaws.com"
]
url_bases = ["https?://streams.dynamodb.(.+).amazonaws.com"]
url_paths = {
"{0}/$": DynamoDBStreamsHandler.dispatch,
}
url_paths = {"{0}/$": DynamoDBStreamsHandler.dispatch}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import ec2_backends
from ..core.models import base_decorator, deprecated_base_decorator
ec2_backend = ec2_backends['us-east-1']
ec2_backend = ec2_backends["us-east-1"]
mock_ec2 = base_decorator(ec2_backends)
mock_ec2_deprecated = deprecated_base_decorator(ec2_backends)

View File

@ -7,499 +7,468 @@ class EC2ClientError(RESTError):
class DependencyViolationError(EC2ClientError):
def __init__(self, message):
super(DependencyViolationError, self).__init__(
"DependencyViolation", message)
super(DependencyViolationError, self).__init__("DependencyViolation", message)
class MissingParameterError(EC2ClientError):
def __init__(self, parameter):
super(MissingParameterError, self).__init__(
"MissingParameter",
"The request must contain the parameter {0}"
.format(parameter))
"The request must contain the parameter {0}".format(parameter),
)
class InvalidDHCPOptionsIdError(EC2ClientError):
def __init__(self, dhcp_options_id):
super(InvalidDHCPOptionsIdError, self).__init__(
"InvalidDhcpOptionID.NotFound",
"DhcpOptionID {0} does not exist."
.format(dhcp_options_id))
"DhcpOptionID {0} does not exist.".format(dhcp_options_id),
)
class MalformedDHCPOptionsIdError(EC2ClientError):
def __init__(self, dhcp_options_id):
super(MalformedDHCPOptionsIdError, self).__init__(
"InvalidDhcpOptionsId.Malformed",
"Invalid id: \"{0}\" (expecting \"dopt-...\")"
.format(dhcp_options_id))
'Invalid id: "{0}" (expecting "dopt-...")'.format(dhcp_options_id),
)
class InvalidKeyPairNameError(EC2ClientError):
def __init__(self, key):
super(InvalidKeyPairNameError, self).__init__(
"InvalidKeyPair.NotFound",
"The keypair '{0}' does not exist."
.format(key))
"InvalidKeyPair.NotFound", "The keypair '{0}' does not exist.".format(key)
)
class InvalidKeyPairDuplicateError(EC2ClientError):
def __init__(self, key):
super(InvalidKeyPairDuplicateError, self).__init__(
"InvalidKeyPair.Duplicate",
"The keypair '{0}' already exists."
.format(key))
"InvalidKeyPair.Duplicate", "The keypair '{0}' already exists.".format(key)
)
class InvalidKeyPairFormatError(EC2ClientError):
def __init__(self):
super(InvalidKeyPairFormatError, self).__init__(
"InvalidKeyPair.Format",
"Key is not in valid OpenSSH public key format")
"InvalidKeyPair.Format", "Key is not in valid OpenSSH public key format"
)
class InvalidVPCIdError(EC2ClientError):
def __init__(self, vpc_id):
super(InvalidVPCIdError, self).__init__(
"InvalidVpcID.NotFound",
"VpcID {0} does not exist."
.format(vpc_id))
"InvalidVpcID.NotFound", "VpcID {0} does not exist.".format(vpc_id)
)
class InvalidSubnetIdError(EC2ClientError):
def __init__(self, subnet_id):
super(InvalidSubnetIdError, self).__init__(
"InvalidSubnetID.NotFound",
"The subnet ID '{0}' does not exist"
.format(subnet_id))
"The subnet ID '{0}' does not exist".format(subnet_id),
)
class InvalidNetworkAclIdError(EC2ClientError):
def __init__(self, network_acl_id):
super(InvalidNetworkAclIdError, self).__init__(
"InvalidNetworkAclID.NotFound",
"The network acl ID '{0}' does not exist"
.format(network_acl_id))
"The network acl ID '{0}' does not exist".format(network_acl_id),
)
class InvalidVpnGatewayIdError(EC2ClientError):
def __init__(self, network_acl_id):
super(InvalidVpnGatewayIdError, self).__init__(
"InvalidVpnGatewayID.NotFound",
"The virtual private gateway ID '{0}' does not exist"
.format(network_acl_id))
"The virtual private gateway ID '{0}' does not exist".format(
network_acl_id
),
)
class InvalidVpnConnectionIdError(EC2ClientError):
def __init__(self, network_acl_id):
super(InvalidVpnConnectionIdError, self).__init__(
"InvalidVpnConnectionID.NotFound",
"The vpnConnection ID '{0}' does not exist"
.format(network_acl_id))
"The vpnConnection ID '{0}' does not exist".format(network_acl_id),
)
class InvalidCustomerGatewayIdError(EC2ClientError):
def __init__(self, customer_gateway_id):
super(InvalidCustomerGatewayIdError, self).__init__(
"InvalidCustomerGatewayID.NotFound",
"The customer gateway ID '{0}' does not exist"
.format(customer_gateway_id))
"The customer gateway ID '{0}' does not exist".format(customer_gateway_id),
)
class InvalidNetworkInterfaceIdError(EC2ClientError):
def __init__(self, eni_id):
super(InvalidNetworkInterfaceIdError, self).__init__(
"InvalidNetworkInterfaceID.NotFound",
"The network interface ID '{0}' does not exist"
.format(eni_id))
"The network interface ID '{0}' does not exist".format(eni_id),
)
class InvalidNetworkAttachmentIdError(EC2ClientError):
def __init__(self, attachment_id):
super(InvalidNetworkAttachmentIdError, self).__init__(
"InvalidAttachmentID.NotFound",
"The network interface attachment ID '{0}' does not exist"
.format(attachment_id))
"The network interface attachment ID '{0}' does not exist".format(
attachment_id
),
)
class InvalidSecurityGroupDuplicateError(EC2ClientError):
def __init__(self, name):
super(InvalidSecurityGroupDuplicateError, self).__init__(
"InvalidGroup.Duplicate",
"The security group '{0}' already exists"
.format(name))
"The security group '{0}' already exists".format(name),
)
class InvalidSecurityGroupNotFoundError(EC2ClientError):
def __init__(self, name):
super(InvalidSecurityGroupNotFoundError, self).__init__(
"InvalidGroup.NotFound",
"The security group '{0}' does not exist"
.format(name))
"The security group '{0}' does not exist".format(name),
)
class InvalidPermissionNotFoundError(EC2ClientError):
def __init__(self):
super(InvalidPermissionNotFoundError, self).__init__(
"InvalidPermission.NotFound",
"The specified rule does not exist in this security group")
"The specified rule does not exist in this security group",
)
class InvalidPermissionDuplicateError(EC2ClientError):
def __init__(self):
super(InvalidPermissionDuplicateError, self).__init__(
"InvalidPermission.Duplicate",
"The specified rule already exists")
"InvalidPermission.Duplicate", "The specified rule already exists"
)
class InvalidRouteTableIdError(EC2ClientError):
def __init__(self, route_table_id):
super(InvalidRouteTableIdError, self).__init__(
"InvalidRouteTableID.NotFound",
"The routeTable ID '{0}' does not exist"
.format(route_table_id))
"The routeTable ID '{0}' does not exist".format(route_table_id),
)
class InvalidRouteError(EC2ClientError):
def __init__(self, route_table_id, cidr):
super(InvalidRouteError, self).__init__(
"InvalidRoute.NotFound",
"no route with destination-cidr-block {0} in route table {1}"
.format(cidr, route_table_id))
"no route with destination-cidr-block {0} in route table {1}".format(
cidr, route_table_id
),
)
class InvalidInstanceIdError(EC2ClientError):
def __init__(self, instance_id):
super(InvalidInstanceIdError, self).__init__(
"InvalidInstanceID.NotFound",
"The instance ID '{0}' does not exist"
.format(instance_id))
"The instance ID '{0}' does not exist".format(instance_id),
)
class InvalidAMIIdError(EC2ClientError):
def __init__(self, ami_id):
super(InvalidAMIIdError, self).__init__(
"InvalidAMIID.NotFound",
"The image id '[{0}]' does not exist"
.format(ami_id))
"The image id '[{0}]' does not exist".format(ami_id),
)
class InvalidAMIAttributeItemValueError(EC2ClientError):
def __init__(self, attribute, value):
super(InvalidAMIAttributeItemValueError, self).__init__(
"InvalidAMIAttributeItemValue",
"Invalid attribute item value \"{0}\" for {1} item type."
.format(value, attribute))
'Invalid attribute item value "{0}" for {1} item type.'.format(
value, attribute
),
)
class MalformedAMIIdError(EC2ClientError):
def __init__(self, ami_id):
super(MalformedAMIIdError, self).__init__(
"InvalidAMIID.Malformed",
"Invalid id: \"{0}\" (expecting \"ami-...\")"
.format(ami_id))
'Invalid id: "{0}" (expecting "ami-...")'.format(ami_id),
)
class InvalidSnapshotIdError(EC2ClientError):
def __init__(self, snapshot_id):
super(InvalidSnapshotIdError, self).__init__(
"InvalidSnapshot.NotFound",
"") # Note: AWS returns empty message for this, as of 2014.08.22.
"InvalidSnapshot.NotFound", ""
) # Note: AWS returns empty message for this, as of 2014.08.22.
class InvalidVolumeIdError(EC2ClientError):
def __init__(self, volume_id):
super(InvalidVolumeIdError, self).__init__(
"InvalidVolume.NotFound",
"The volume '{0}' does not exist."
.format(volume_id))
"The volume '{0}' does not exist.".format(volume_id),
)
class InvalidVolumeAttachmentError(EC2ClientError):
def __init__(self, volume_id, instance_id):
super(InvalidVolumeAttachmentError, self).__init__(
"InvalidAttachment.NotFound",
"Volume {0} can not be detached from {1} because it is not attached"
.format(volume_id, instance_id))
"Volume {0} can not be detached from {1} because it is not attached".format(
volume_id, instance_id
),
)
class InvalidDomainError(EC2ClientError):
def __init__(self, domain):
super(InvalidDomainError, self).__init__(
"InvalidParameterValue",
"Invalid value '{0}' for domain."
.format(domain))
"InvalidParameterValue", "Invalid value '{0}' for domain.".format(domain)
)
class InvalidAddressError(EC2ClientError):
def __init__(self, ip):
super(InvalidAddressError, self).__init__(
"InvalidAddress.NotFound",
"Address '{0}' not found."
.format(ip))
"InvalidAddress.NotFound", "Address '{0}' not found.".format(ip)
)
class InvalidAllocationIdError(EC2ClientError):
def __init__(self, allocation_id):
super(InvalidAllocationIdError, self).__init__(
"InvalidAllocationID.NotFound",
"Allocation ID '{0}' not found."
.format(allocation_id))
"Allocation ID '{0}' not found.".format(allocation_id),
)
class InvalidAssociationIdError(EC2ClientError):
def __init__(self, association_id):
super(InvalidAssociationIdError, self).__init__(
"InvalidAssociationID.NotFound",
"Association ID '{0}' not found."
.format(association_id))
"Association ID '{0}' not found.".format(association_id),
)
class InvalidVpcCidrBlockAssociationIdError(EC2ClientError):
def __init__(self, association_id):
super(InvalidVpcCidrBlockAssociationIdError, self).__init__(
"InvalidVpcCidrBlockAssociationIdError.NotFound",
"The vpc CIDR block association ID '{0}' does not exist"
.format(association_id))
"The vpc CIDR block association ID '{0}' does not exist".format(
association_id
),
)
class InvalidVPCPeeringConnectionIdError(EC2ClientError):
def __init__(self, vpc_peering_connection_id):
super(InvalidVPCPeeringConnectionIdError, self).__init__(
"InvalidVpcPeeringConnectionId.NotFound",
"VpcPeeringConnectionID {0} does not exist."
.format(vpc_peering_connection_id))
"VpcPeeringConnectionID {0} does not exist.".format(
vpc_peering_connection_id
),
)
class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError):
def __init__(self, vpc_peering_connection_id):
super(InvalidVPCPeeringConnectionStateTransitionError, self).__init__(
"InvalidStateTransition",
"VpcPeeringConnectionID {0} is not in the correct state for the request."
.format(vpc_peering_connection_id))
"VpcPeeringConnectionID {0} is not in the correct state for the request.".format(
vpc_peering_connection_id
),
)
class InvalidParameterValueError(EC2ClientError):
def __init__(self, parameter_value):
super(InvalidParameterValueError, self).__init__(
"InvalidParameterValue",
"Value {0} is invalid for parameter."
.format(parameter_value))
"Value {0} is invalid for parameter.".format(parameter_value),
)
class InvalidParameterValueErrorTagNull(EC2ClientError):
def __init__(self):
super(InvalidParameterValueErrorTagNull, self).__init__(
"InvalidParameterValue",
"Tag value cannot be null. Use empty string instead.")
"Tag value cannot be null. Use empty string instead.",
)
class InvalidParameterValueErrorUnknownAttribute(EC2ClientError):
def __init__(self, parameter_value):
super(InvalidParameterValueErrorUnknownAttribute, self).__init__(
"InvalidParameterValue",
"Value ({0}) for parameter attribute is invalid. Unknown attribute."
.format(parameter_value))
"Value ({0}) for parameter attribute is invalid. Unknown attribute.".format(
parameter_value
),
)
class InvalidInternetGatewayIdError(EC2ClientError):
def __init__(self, internet_gateway_id):
super(InvalidInternetGatewayIdError, self).__init__(
"InvalidInternetGatewayID.NotFound",
"InternetGatewayID {0} does not exist."
.format(internet_gateway_id))
"InternetGatewayID {0} does not exist.".format(internet_gateway_id),
)
class GatewayNotAttachedError(EC2ClientError):
def __init__(self, internet_gateway_id, vpc_id):
super(GatewayNotAttachedError, self).__init__(
"Gateway.NotAttached",
"InternetGatewayID {0} is not attached to a VPC {1}."
.format(internet_gateway_id, vpc_id))
"InternetGatewayID {0} is not attached to a VPC {1}.".format(
internet_gateway_id, vpc_id
),
)
class ResourceAlreadyAssociatedError(EC2ClientError):
def __init__(self, resource_id):
super(ResourceAlreadyAssociatedError, self).__init__(
"Resource.AlreadyAssociated",
"Resource {0} is already associated."
.format(resource_id))
"Resource {0} is already associated.".format(resource_id),
)
class TagLimitExceeded(EC2ClientError):
def __init__(self):
super(TagLimitExceeded, self).__init__(
"TagLimitExceeded",
"The maximum number of Tags for a resource has been reached.")
"The maximum number of Tags for a resource has been reached.",
)
class InvalidID(EC2ClientError):
def __init__(self, resource_id):
super(InvalidID, self).__init__(
"InvalidID",
"The ID '{0}' is not valid"
.format(resource_id))
"InvalidID", "The ID '{0}' is not valid".format(resource_id)
)
class InvalidCIDRSubnetError(EC2ClientError):
def __init__(self, cidr):
super(InvalidCIDRSubnetError, self).__init__(
"InvalidParameterValue",
"invalid CIDR subnet specification: {0}"
.format(cidr))
"invalid CIDR subnet specification: {0}".format(cidr),
)
class RulesPerSecurityGroupLimitExceededError(EC2ClientError):
def __init__(self):
super(RulesPerSecurityGroupLimitExceededError, self).__init__(
"RulesPerSecurityGroupLimitExceeded",
'The maximum number of rules per security group '
'has been reached.')
"The maximum number of rules per security group " "has been reached.",
)
class MotoNotImplementedError(NotImplementedError):
def __init__(self, blurb):
super(MotoNotImplementedError, self).__init__(
"{0} has not been implemented in Moto yet."
" Feel free to open an issue at"
" https://github.com/spulec/moto/issues".format(blurb))
" https://github.com/spulec/moto/issues".format(blurb)
)
class FilterNotImplementedError(MotoNotImplementedError):
def __init__(self, filter_name, method_name):
super(FilterNotImplementedError, self).__init__(
"The filter '{0}' for {1}".format(
filter_name, method_name))
"The filter '{0}' for {1}".format(filter_name, method_name)
)
class CidrLimitExceeded(EC2ClientError):
def __init__(self, vpc_id, max_cidr_limit):
super(CidrLimitExceeded, self).__init__(
"CidrLimitExceeded",
"This network '{0}' has met its maximum number of allowed CIDRs: {1}".format(vpc_id, max_cidr_limit)
"This network '{0}' has met its maximum number of allowed CIDRs: {1}".format(
vpc_id, max_cidr_limit
),
)
class OperationNotPermitted(EC2ClientError):
def __init__(self, association_id):
super(OperationNotPermitted, self).__init__(
"OperationNotPermitted",
"The vpc CIDR block with association ID {} may not be disassociated. "
"It is the primary IPv4 CIDR block of the VPC".format(association_id)
"It is the primary IPv4 CIDR block of the VPC".format(association_id),
)
class InvalidAvailabilityZoneError(EC2ClientError):
def __init__(self, availability_zone_value, valid_availability_zones):
super(InvalidAvailabilityZoneError, self).__init__(
"InvalidParameterValue",
"Value ({0}) for parameter availabilityZone is invalid. "
"Subnets can currently only be created in the following availability zones: {1}.".format(availability_zone_value, valid_availability_zones)
"Subnets can currently only be created in the following availability zones: {1}.".format(
availability_zone_value, valid_availability_zones
),
)
class NetworkAclEntryAlreadyExistsError(EC2ClientError):
def __init__(self, rule_number):
super(NetworkAclEntryAlreadyExistsError, self).__init__(
"NetworkAclEntryAlreadyExists",
"The network acl entry identified by {} already exists.".format(rule_number)
"The network acl entry identified by {} already exists.".format(
rule_number
),
)
class InvalidSubnetRangeError(EC2ClientError):
def __init__(self, cidr_block):
super(InvalidSubnetRangeError, self).__init__(
"InvalidSubnet.Range",
"The CIDR '{}' is invalid.".format(cidr_block)
"InvalidSubnet.Range", "The CIDR '{}' is invalid.".format(cidr_block)
)
class InvalidCIDRBlockParameterError(EC2ClientError):
def __init__(self, cidr_block):
super(InvalidCIDRBlockParameterError, self).__init__(
"InvalidParameterValue",
"Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format(cidr_block)
"Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format(
cidr_block
),
)
class InvalidDestinationCIDRBlockParameterError(EC2ClientError):
def __init__(self, cidr_block):
super(InvalidDestinationCIDRBlockParameterError, self).__init__(
"InvalidParameterValue",
"Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format(cidr_block)
"Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format(
cidr_block
),
)
class InvalidSubnetConflictError(EC2ClientError):
def __init__(self, cidr_block):
super(InvalidSubnetConflictError, self).__init__(
"InvalidSubnet.Conflict",
"The CIDR '{}' conflicts with another subnet".format(cidr_block)
"The CIDR '{}' conflicts with another subnet".format(cidr_block),
)
class InvalidVPCRangeError(EC2ClientError):
def __init__(self, cidr_block):
super(InvalidVPCRangeError, self).__init__(
"InvalidVpc.Range",
"The CIDR '{}' is invalid.".format(cidr_block)
"InvalidVpc.Range", "The CIDR '{}' is invalid.".format(cidr_block)
)
@ -509,7 +478,9 @@ class OperationNotPermitted2(EC2ClientError):
super(OperationNotPermitted2, self).__init__(
"OperationNotPermitted",
"Incorrect region ({0}) specified for this request."
"VPC peering connection {1} must be accepted in region {2}".format(client_region, pcx_id, acceptor_region)
"VPC peering connection {1} must be accepted in region {2}".format(
client_region, pcx_id, acceptor_region
),
)
@ -519,9 +490,9 @@ class OperationNotPermitted3(EC2ClientError):
super(OperationNotPermitted3, self).__init__(
"OperationNotPermitted",
"Incorrect region ({0}) specified for this request."
"VPC peering connection {1} must be accepted or rejected in region {2}".format(client_region,
pcx_id,
acceptor_region)
"VPC peering connection {1} must be accepted or rejected in region {2}".format(
client_region, pcx_id, acceptor_region
),
)
@ -529,5 +500,5 @@ class InvalidLaunchTemplateNameError(EC2ClientError):
def __init__(self):
super(InvalidLaunchTemplateNameError, self).__init__(
"InvalidLaunchTemplateName.AlreadyExistsException",
"Launch template name already in use."
"Launch template name already in use.",
)

File diff suppressed because it is too large Load Diff

View File

@ -70,10 +70,10 @@ class EC2Response(
Windows,
NatGateways,
):
@property
def ec2_backend(self):
from moto.ec2.models import ec2_backends
return ec2_backends[self.region]
@property

View File

@ -3,13 +3,12 @@ from moto.core.responses import BaseResponse
class AccountAttributes(BaseResponse):
def describe_account_attributes(self):
template = self.response_template(DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT)
return template.render()
DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT = u"""
DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT = """
<DescribeAccountAttributesResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<accountAttributeSet>

View File

@ -3,7 +3,7 @@ from moto.core.responses import BaseResponse
class AmazonDevPay(BaseResponse):
def confirm_product_instance(self):
raise NotImplementedError(
'AmazonDevPay.confirm_product_instance is not yet implemented')
"AmazonDevPay.confirm_product_instance is not yet implemented"
)

View File

@ -4,76 +4,83 @@ from moto.ec2.utils import filters_from_querystring
class AmisResponse(BaseResponse):
def create_image(self):
name = self.querystring.get('Name')[0]
description = self._get_param('Description', if_none='')
instance_id = self._get_param('InstanceId')
if self.is_not_dryrun('CreateImage'):
name = self.querystring.get("Name")[0]
description = self._get_param("Description", if_none="")
instance_id = self._get_param("InstanceId")
if self.is_not_dryrun("CreateImage"):
image = self.ec2_backend.create_image(
instance_id, name, description, context=self)
instance_id, name, description, context=self
)
template = self.response_template(CREATE_IMAGE_RESPONSE)
return template.render(image=image)
def copy_image(self):
source_image_id = self._get_param('SourceImageId')
source_region = self._get_param('SourceRegion')
name = self._get_param('Name')
description = self._get_param('Description')
if self.is_not_dryrun('CopyImage'):
source_image_id = self._get_param("SourceImageId")
source_region = self._get_param("SourceRegion")
name = self._get_param("Name")
description = self._get_param("Description")
if self.is_not_dryrun("CopyImage"):
image = self.ec2_backend.copy_image(
source_image_id, source_region, name, description)
source_image_id, source_region, name, description
)
template = self.response_template(COPY_IMAGE_RESPONSE)
return template.render(image=image)
def deregister_image(self):
ami_id = self._get_param('ImageId')
if self.is_not_dryrun('DeregisterImage'):
ami_id = self._get_param("ImageId")
if self.is_not_dryrun("DeregisterImage"):
success = self.ec2_backend.deregister_image(ami_id)
template = self.response_template(DEREGISTER_IMAGE_RESPONSE)
return template.render(success=str(success).lower())
def describe_images(self):
ami_ids = self._get_multi_param('ImageId')
ami_ids = self._get_multi_param("ImageId")
filters = filters_from_querystring(self.querystring)
owners = self._get_multi_param('Owner')
exec_users = self._get_multi_param('ExecutableBy')
owners = self._get_multi_param("Owner")
exec_users = self._get_multi_param("ExecutableBy")
images = self.ec2_backend.describe_images(
ami_ids=ami_ids, filters=filters, exec_users=exec_users,
owners=owners, context=self)
ami_ids=ami_ids,
filters=filters,
exec_users=exec_users,
owners=owners,
context=self,
)
template = self.response_template(DESCRIBE_IMAGES_RESPONSE)
return template.render(images=images)
def describe_image_attribute(self):
ami_id = self._get_param('ImageId')
ami_id = self._get_param("ImageId")
groups = self.ec2_backend.get_launch_permission_groups(ami_id)
users = self.ec2_backend.get_launch_permission_users(ami_id)
template = self.response_template(DESCRIBE_IMAGE_ATTRIBUTES_RESPONSE)
return template.render(ami_id=ami_id, groups=groups, users=users)
def modify_image_attribute(self):
ami_id = self._get_param('ImageId')
operation_type = self._get_param('OperationType')
group = self._get_param('UserGroup.1')
user_ids = self._get_multi_param('UserId')
if self.is_not_dryrun('ModifyImageAttribute'):
if (operation_type == 'add'):
ami_id = self._get_param("ImageId")
operation_type = self._get_param("OperationType")
group = self._get_param("UserGroup.1")
user_ids = self._get_multi_param("UserId")
if self.is_not_dryrun("ModifyImageAttribute"):
if operation_type == "add":
self.ec2_backend.add_launch_permission(
ami_id, user_ids=user_ids, group=group)
elif (operation_type == 'remove'):
ami_id, user_ids=user_ids, group=group
)
elif operation_type == "remove":
self.ec2_backend.remove_launch_permission(
ami_id, user_ids=user_ids, group=group)
ami_id, user_ids=user_ids, group=group
)
return MODIFY_IMAGE_ATTRIBUTE_RESPONSE
def register_image(self):
if self.is_not_dryrun('RegisterImage'):
raise NotImplementedError(
'AMIs.register_image is not yet implemented')
if self.is_not_dryrun("RegisterImage"):
raise NotImplementedError("AMIs.register_image is not yet implemented")
def reset_image_attribute(self):
if self.is_not_dryrun('ResetImageAttribute'):
if self.is_not_dryrun("ResetImageAttribute"):
raise NotImplementedError(
'AMIs.reset_image_attribute is not yet implemented')
"AMIs.reset_image_attribute is not yet implemented"
)
CREATE_IMAGE_RESPONSE = """<CreateImageResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">

View File

@ -3,14 +3,13 @@ from moto.core.responses import BaseResponse
class AvailabilityZonesAndRegions(BaseResponse):
def describe_availability_zones(self):
zones = self.ec2_backend.describe_availability_zones()
template = self.response_template(DESCRIBE_ZONES_RESPONSE)
return template.render(zones=zones)
def describe_regions(self):
region_names = self._get_multi_param('RegionName')
region_names = self._get_multi_param("RegionName")
regions = self.ec2_backend.describe_regions(region_names)
template = self.response_template(DESCRIBE_REGIONS_RESPONSE)
return template.render(regions=regions)

View File

@ -4,21 +4,20 @@ from moto.ec2.utils import filters_from_querystring
class CustomerGateways(BaseResponse):
def create_customer_gateway(self):
# raise NotImplementedError('CustomerGateways(AmazonVPC).create_customer_gateway is not yet implemented')
type = self._get_param('Type')
ip_address = self._get_param('IpAddress')
bgp_asn = self._get_param('BgpAsn')
type = self._get_param("Type")
ip_address = self._get_param("IpAddress")
bgp_asn = self._get_param("BgpAsn")
customer_gateway = self.ec2_backend.create_customer_gateway(
type, ip_address=ip_address, bgp_asn=bgp_asn)
type, ip_address=ip_address, bgp_asn=bgp_asn
)
template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=customer_gateway)
def delete_customer_gateway(self):
customer_gateway_id = self._get_param('CustomerGatewayId')
delete_status = self.ec2_backend.delete_customer_gateway(
customer_gateway_id)
customer_gateway_id = self._get_param("CustomerGatewayId")
delete_status = self.ec2_backend.delete_customer_gateway(customer_gateway_id)
template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=delete_status)

View File

@ -1,15 +1,12 @@
from __future__ import unicode_literals
from moto.core.responses import BaseResponse
from moto.ec2.utils import (
filters_from_querystring,
dhcp_configuration_from_querystring)
from moto.ec2.utils import filters_from_querystring, dhcp_configuration_from_querystring
class DHCPOptions(BaseResponse):
def associate_dhcp_options(self):
dhcp_opt_id = self._get_param('DhcpOptionsId')
vpc_id = self._get_param('VpcId')
dhcp_opt_id = self._get_param("DhcpOptionsId")
vpc_id = self._get_param("VpcId")
dhcp_opt = self.ec2_backend.describe_dhcp_options([dhcp_opt_id])[0]
vpc = self.ec2_backend.get_vpc(vpc_id)
@ -35,14 +32,14 @@ class DHCPOptions(BaseResponse):
domain_name=domain_name,
ntp_servers=ntp_servers,
netbios_name_servers=netbios_name_servers,
netbios_node_type=netbios_node_type
netbios_node_type=netbios_node_type,
)
template = self.response_template(CREATE_DHCP_OPTIONS_RESPONSE)
return template.render(dhcp_options_set=dhcp_options_set)
def delete_dhcp_options(self):
dhcp_opt_id = self._get_param('DhcpOptionsId')
dhcp_opt_id = self._get_param("DhcpOptionsId")
delete_status = self.ec2_backend.delete_dhcp_options_set(dhcp_opt_id)
template = self.response_template(DELETE_DHCP_OPTIONS_RESPONSE)
return template.render(delete_status=delete_status)
@ -50,13 +47,12 @@ class DHCPOptions(BaseResponse):
def describe_dhcp_options(self):
dhcp_opt_ids = self._get_multi_param("DhcpOptionsId")
filters = filters_from_querystring(self.querystring)
dhcp_opts = self.ec2_backend.get_all_dhcp_options(
dhcp_opt_ids, filters)
dhcp_opts = self.ec2_backend.get_all_dhcp_options(dhcp_opt_ids, filters)
template = self.response_template(DESCRIBE_DHCP_OPTIONS_RESPONSE)
return template.render(dhcp_options=dhcp_opts)
CREATE_DHCP_OPTIONS_RESPONSE = u"""
CREATE_DHCP_OPTIONS_RESPONSE = """
<CreateDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<dhcpOptions>
@ -92,14 +88,14 @@ CREATE_DHCP_OPTIONS_RESPONSE = u"""
</CreateDhcpOptionsResponse>
"""
DELETE_DHCP_OPTIONS_RESPONSE = u"""
DELETE_DHCP_OPTIONS_RESPONSE = """
<DeleteDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<return>{{delete_status}}</return>
</DeleteDhcpOptionsResponse>
"""
DESCRIBE_DHCP_OPTIONS_RESPONSE = u"""
DESCRIBE_DHCP_OPTIONS_RESPONSE = """
<DescribeDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<dhcpOptionsSet>
@ -139,7 +135,7 @@ DESCRIBE_DHCP_OPTIONS_RESPONSE = u"""
</DescribeDhcpOptionsResponse>
"""
ASSOCIATE_DHCP_OPTIONS_RESPONSE = u"""
ASSOCIATE_DHCP_OPTIONS_RESPONSE = """
<AssociateDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<return>true</return>

View File

@ -4,137 +4,148 @@ from moto.ec2.utils import filters_from_querystring
class ElasticBlockStore(BaseResponse):
def attach_volume(self):
volume_id = self._get_param('VolumeId')
instance_id = self._get_param('InstanceId')
device_path = self._get_param('Device')
if self.is_not_dryrun('AttachVolume'):
volume_id = self._get_param("VolumeId")
instance_id = self._get_param("InstanceId")
device_path = self._get_param("Device")
if self.is_not_dryrun("AttachVolume"):
attachment = self.ec2_backend.attach_volume(
volume_id, instance_id, device_path)
volume_id, instance_id, device_path
)
template = self.response_template(ATTACHED_VOLUME_RESPONSE)
return template.render(attachment=attachment)
def copy_snapshot(self):
source_snapshot_id = self._get_param('SourceSnapshotId')
source_region = self._get_param('SourceRegion')
description = self._get_param('Description')
if self.is_not_dryrun('CopySnapshot'):
source_snapshot_id = self._get_param("SourceSnapshotId")
source_region = self._get_param("SourceRegion")
description = self._get_param("Description")
if self.is_not_dryrun("CopySnapshot"):
snapshot = self.ec2_backend.copy_snapshot(
source_snapshot_id, source_region, description)
source_snapshot_id, source_region, description
)
template = self.response_template(COPY_SNAPSHOT_RESPONSE)
return template.render(snapshot=snapshot)
def create_snapshot(self):
volume_id = self._get_param('VolumeId')
description = self._get_param('Description')
volume_id = self._get_param("VolumeId")
description = self._get_param("Description")
tags = self._parse_tag_specification("TagSpecification")
snapshot_tags = tags.get('snapshot', {})
if self.is_not_dryrun('CreateSnapshot'):
snapshot_tags = tags.get("snapshot", {})
if self.is_not_dryrun("CreateSnapshot"):
snapshot = self.ec2_backend.create_snapshot(volume_id, description)
snapshot.add_tags(snapshot_tags)
template = self.response_template(CREATE_SNAPSHOT_RESPONSE)
return template.render(snapshot=snapshot)
def create_volume(self):
size = self._get_param('Size')
zone = self._get_param('AvailabilityZone')
snapshot_id = self._get_param('SnapshotId')
size = self._get_param("Size")
zone = self._get_param("AvailabilityZone")
snapshot_id = self._get_param("SnapshotId")
tags = self._parse_tag_specification("TagSpecification")
volume_tags = tags.get('volume', {})
encrypted = self._get_param('Encrypted', if_none=False)
if self.is_not_dryrun('CreateVolume'):
volume = self.ec2_backend.create_volume(
size, zone, snapshot_id, encrypted)
volume_tags = tags.get("volume", {})
encrypted = self._get_param("Encrypted", if_none=False)
if self.is_not_dryrun("CreateVolume"):
volume = self.ec2_backend.create_volume(size, zone, snapshot_id, encrypted)
volume.add_tags(volume_tags)
template = self.response_template(CREATE_VOLUME_RESPONSE)
return template.render(volume=volume)
def delete_snapshot(self):
snapshot_id = self._get_param('SnapshotId')
if self.is_not_dryrun('DeleteSnapshot'):
snapshot_id = self._get_param("SnapshotId")
if self.is_not_dryrun("DeleteSnapshot"):
self.ec2_backend.delete_snapshot(snapshot_id)
return DELETE_SNAPSHOT_RESPONSE
def delete_volume(self):
volume_id = self._get_param('VolumeId')
if self.is_not_dryrun('DeleteVolume'):
volume_id = self._get_param("VolumeId")
if self.is_not_dryrun("DeleteVolume"):
self.ec2_backend.delete_volume(volume_id)
return DELETE_VOLUME_RESPONSE
def describe_snapshots(self):
filters = filters_from_querystring(self.querystring)
snapshot_ids = self._get_multi_param('SnapshotId')
snapshots = self.ec2_backend.describe_snapshots(snapshot_ids=snapshot_ids, filters=filters)
snapshot_ids = self._get_multi_param("SnapshotId")
snapshots = self.ec2_backend.describe_snapshots(
snapshot_ids=snapshot_ids, filters=filters
)
template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE)
return template.render(snapshots=snapshots)
def describe_volumes(self):
filters = filters_from_querystring(self.querystring)
volume_ids = self._get_multi_param('VolumeId')
volumes = self.ec2_backend.describe_volumes(volume_ids=volume_ids, filters=filters)
volume_ids = self._get_multi_param("VolumeId")
volumes = self.ec2_backend.describe_volumes(
volume_ids=volume_ids, filters=filters
)
template = self.response_template(DESCRIBE_VOLUMES_RESPONSE)
return template.render(volumes=volumes)
def describe_volume_attribute(self):
raise NotImplementedError(
'ElasticBlockStore.describe_volume_attribute is not yet implemented')
"ElasticBlockStore.describe_volume_attribute is not yet implemented"
)
def describe_volume_status(self):
raise NotImplementedError(
'ElasticBlockStore.describe_volume_status is not yet implemented')
"ElasticBlockStore.describe_volume_status is not yet implemented"
)
def detach_volume(self):
volume_id = self._get_param('VolumeId')
instance_id = self._get_param('InstanceId')
device_path = self._get_param('Device')
if self.is_not_dryrun('DetachVolume'):
volume_id = self._get_param("VolumeId")
instance_id = self._get_param("InstanceId")
device_path = self._get_param("Device")
if self.is_not_dryrun("DetachVolume"):
attachment = self.ec2_backend.detach_volume(
volume_id, instance_id, device_path)
volume_id, instance_id, device_path
)
template = self.response_template(DETATCH_VOLUME_RESPONSE)
return template.render(attachment=attachment)
def enable_volume_io(self):
if self.is_not_dryrun('EnableVolumeIO'):
if self.is_not_dryrun("EnableVolumeIO"):
raise NotImplementedError(
'ElasticBlockStore.enable_volume_io is not yet implemented')
"ElasticBlockStore.enable_volume_io is not yet implemented"
)
def import_volume(self):
if self.is_not_dryrun('ImportVolume'):
if self.is_not_dryrun("ImportVolume"):
raise NotImplementedError(
'ElasticBlockStore.import_volume is not yet implemented')
"ElasticBlockStore.import_volume is not yet implemented"
)
def describe_snapshot_attribute(self):
snapshot_id = self._get_param('SnapshotId')
groups = self.ec2_backend.get_create_volume_permission_groups(
snapshot_id)
template = self.response_template(
DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE)
snapshot_id = self._get_param("SnapshotId")
groups = self.ec2_backend.get_create_volume_permission_groups(snapshot_id)
template = self.response_template(DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE)
return template.render(snapshot_id=snapshot_id, groups=groups)
def modify_snapshot_attribute(self):
snapshot_id = self._get_param('SnapshotId')
operation_type = self._get_param('OperationType')
group = self._get_param('UserGroup.1')
user_id = self._get_param('UserId.1')
if self.is_not_dryrun('ModifySnapshotAttribute'):
if (operation_type == 'add'):
snapshot_id = self._get_param("SnapshotId")
operation_type = self._get_param("OperationType")
group = self._get_param("UserGroup.1")
user_id = self._get_param("UserId.1")
if self.is_not_dryrun("ModifySnapshotAttribute"):
if operation_type == "add":
self.ec2_backend.add_create_volume_permission(
snapshot_id, user_id=user_id, group=group)
elif (operation_type == 'remove'):
snapshot_id, user_id=user_id, group=group
)
elif operation_type == "remove":
self.ec2_backend.remove_create_volume_permission(
snapshot_id, user_id=user_id, group=group)
snapshot_id, user_id=user_id, group=group
)
return MODIFY_SNAPSHOT_ATTRIBUTE_RESPONSE
def modify_volume_attribute(self):
if self.is_not_dryrun('ModifyVolumeAttribute'):
if self.is_not_dryrun("ModifyVolumeAttribute"):
raise NotImplementedError(
'ElasticBlockStore.modify_volume_attribute is not yet implemented')
"ElasticBlockStore.modify_volume_attribute is not yet implemented"
)
def reset_snapshot_attribute(self):
if self.is_not_dryrun('ResetSnapshotAttribute'):
if self.is_not_dryrun("ResetSnapshotAttribute"):
raise NotImplementedError(
'ElasticBlockStore.reset_snapshot_attribute is not yet implemented')
"ElasticBlockStore.reset_snapshot_attribute is not yet implemented"
)
CREATE_VOLUME_RESPONSE = """<CreateVolumeResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">

View File

@ -4,14 +4,14 @@ from moto.ec2.utils import filters_from_querystring
class ElasticIPAddresses(BaseResponse):
def allocate_address(self):
domain = self._get_param('Domain', if_none='standard')
reallocate_address = self._get_param('Address', if_none=None)
if self.is_not_dryrun('AllocateAddress'):
domain = self._get_param("Domain", if_none="standard")
reallocate_address = self._get_param("Address", if_none=None)
if self.is_not_dryrun("AllocateAddress"):
if reallocate_address:
address = self.ec2_backend.allocate_address(
domain, address=reallocate_address)
domain, address=reallocate_address
)
else:
address = self.ec2_backend.allocate_address(domain)
template = self.response_template(ALLOCATE_ADDRESS_RESPONSE)
@ -21,73 +21,92 @@ class ElasticIPAddresses(BaseResponse):
instance = eni = None
if "InstanceId" in self.querystring:
instance = self.ec2_backend.get_instance(
self._get_param('InstanceId'))
instance = self.ec2_backend.get_instance(self._get_param("InstanceId"))
elif "NetworkInterfaceId" in self.querystring:
eni = self.ec2_backend.get_network_interface(
self._get_param('NetworkInterfaceId'))
self._get_param("NetworkInterfaceId")
)
else:
self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect InstanceId/NetworkId parameter.")
"MissingParameter",
"Invalid request, expect InstanceId/NetworkId parameter.",
)
reassociate = False
if "AllowReassociation" in self.querystring:
reassociate = self._get_param('AllowReassociation') == "true"
reassociate = self._get_param("AllowReassociation") == "true"
if self.is_not_dryrun('AssociateAddress'):
if self.is_not_dryrun("AssociateAddress"):
if instance or eni:
if "PublicIp" in self.querystring:
eip = self.ec2_backend.associate_address(
instance=instance, eni=eni,
address=self._get_param('PublicIp'), reassociate=reassociate)
instance=instance,
eni=eni,
address=self._get_param("PublicIp"),
reassociate=reassociate,
)
elif "AllocationId" in self.querystring:
eip = self.ec2_backend.associate_address(
instance=instance, eni=eni,
allocation_id=self._get_param('AllocationId'), reassociate=reassociate)
instance=instance,
eni=eni,
allocation_id=self._get_param("AllocationId"),
reassociate=reassociate,
)
else:
self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.")
"MissingParameter",
"Invalid request, expect PublicIp/AllocationId parameter.",
)
else:
self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect either instance or ENI.")
"MissingParameter",
"Invalid request, expect either instance or ENI.",
)
template = self.response_template(ASSOCIATE_ADDRESS_RESPONSE)
return template.render(address=eip)
def describe_addresses(self):
allocation_ids = self._get_multi_param('AllocationId')
public_ips = self._get_multi_param('PublicIp')
allocation_ids = self._get_multi_param("AllocationId")
public_ips = self._get_multi_param("PublicIp")
filters = filters_from_querystring(self.querystring)
addresses = self.ec2_backend.describe_addresses(
allocation_ids, public_ips, filters)
allocation_ids, public_ips, filters
)
template = self.response_template(DESCRIBE_ADDRESS_RESPONSE)
return template.render(addresses=addresses)
def disassociate_address(self):
if self.is_not_dryrun('DisAssociateAddress'):
if self.is_not_dryrun("DisAssociateAddress"):
if "PublicIp" in self.querystring:
self.ec2_backend.disassociate_address(
address=self._get_param('PublicIp'))
address=self._get_param("PublicIp")
)
elif "AssociationId" in self.querystring:
self.ec2_backend.disassociate_address(
association_id=self._get_param('AssociationId'))
association_id=self._get_param("AssociationId")
)
else:
self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AssociationId parameter.")
"MissingParameter",
"Invalid request, expect PublicIp/AssociationId parameter.",
)
return self.response_template(DISASSOCIATE_ADDRESS_RESPONSE).render()
def release_address(self):
if self.is_not_dryrun('ReleaseAddress'):
if self.is_not_dryrun("ReleaseAddress"):
if "PublicIp" in self.querystring:
self.ec2_backend.release_address(
address=self._get_param('PublicIp'))
self.ec2_backend.release_address(address=self._get_param("PublicIp"))
elif "AllocationId" in self.querystring:
self.ec2_backend.release_address(
allocation_id=self._get_param('AllocationId'))
allocation_id=self._get_param("AllocationId")
)
else:
self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.")
"MissingParameter",
"Invalid request, expect PublicIp/AllocationId parameter.",
)
return self.response_template(RELEASE_ADDRESS_RESPONSE).render()

Some files were not shown because too many files have changed in this diff Show More