Run black on moto & test directories.

This commit is contained in:
Asher Foa 2019-10-31 08:44:26 -07:00
parent c820395dbf
commit 96e5b1993d
507 changed files with 52521 additions and 47794 deletions

View File

@ -1,9 +1,10 @@
from __future__ import unicode_literals from __future__ import unicode_literals
# import logging # import logging
# logging.getLogger('boto').setLevel(logging.CRITICAL) # logging.getLogger('boto').setLevel(logging.CRITICAL)
__title__ = 'moto' __title__ = "moto"
__version__ = '1.3.14.dev' __version__ = "1.3.14.dev"
from .acm import mock_acm # noqa from .acm import mock_acm # noqa
from .apigateway import mock_apigateway, mock_apigateway_deprecated # noqa from .apigateway import mock_apigateway, mock_apigateway_deprecated # noqa
@ -12,7 +13,10 @@ from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # noqa
from .awslambda import mock_lambda, mock_lambda_deprecated # noqa from .awslambda import mock_lambda, mock_lambda_deprecated # noqa
from .cloudformation import mock_cloudformation, mock_cloudformation_deprecated # noqa from .cloudformation import mock_cloudformation, mock_cloudformation_deprecated # noqa
from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # noqa from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # noqa
from .cognitoidentity import mock_cognitoidentity, mock_cognitoidentity_deprecated # noqa from .cognitoidentity import ( # noqa
mock_cognitoidentity,
mock_cognitoidentity_deprecated,
)
from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # noqa from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # noqa
from .config import mock_config # noqa from .config import mock_config # noqa
from .datapipeline import mock_datapipeline, mock_datapipeline_deprecated # noqa from .datapipeline import mock_datapipeline, mock_datapipeline_deprecated # noqa
@ -58,7 +62,12 @@ from .iotdata import mock_iotdata # noqa
try: try:
# Need to monkey-patch botocore requests back to underlying urllib3 classes # Need to monkey-patch botocore requests back to underlying urllib3 classes
from botocore.awsrequest import HTTPSConnectionPool, HTTPConnectionPool, HTTPConnection, VerifiedHTTPSConnection from botocore.awsrequest import (
HTTPSConnectionPool,
HTTPConnectionPool,
HTTPConnection,
VerifiedHTTPSConnection,
)
except ImportError: except ImportError:
pass pass
else: else:

View File

@ -2,5 +2,5 @@ from __future__ import unicode_literals
from .models import acm_backends from .models import acm_backends
from ..core.models import base_decorator from ..core.models import base_decorator
acm_backend = acm_backends['us-east-1'] acm_backend = acm_backends["us-east-1"]
mock_acm = base_decorator(acm_backends) mock_acm = base_decorator(acm_backends)

View File

@ -57,20 +57,29 @@ class AWSError(Exception):
self.message = message self.message = message
def response(self): def response(self):
resp = {'__type': self.TYPE, 'message': self.message} resp = {"__type": self.TYPE, "message": self.message}
return json.dumps(resp), dict(status=self.STATUS) return json.dumps(resp), dict(status=self.STATUS)
class AWSValidationException(AWSError): class AWSValidationException(AWSError):
TYPE = 'ValidationException' TYPE = "ValidationException"
class AWSResourceNotFoundException(AWSError): class AWSResourceNotFoundException(AWSError):
TYPE = 'ResourceNotFoundException' TYPE = "ResourceNotFoundException"
class CertBundle(BaseModel): class CertBundle(BaseModel):
def __init__(self, certificate, private_key, chain=None, region='us-east-1', arn=None, cert_type='IMPORTED', cert_status='ISSUED'): def __init__(
self,
certificate,
private_key,
chain=None,
region="us-east-1",
arn=None,
cert_type="IMPORTED",
cert_status="ISSUED",
):
self.created_at = datetime.datetime.now() self.created_at = datetime.datetime.now()
self.cert = certificate self.cert = certificate
self._cert = None self._cert = None
@ -87,7 +96,7 @@ class CertBundle(BaseModel):
if self.chain is None: if self.chain is None:
self.chain = GOOGLE_ROOT_CA self.chain = GOOGLE_ROOT_CA
else: else:
self.chain += b'\n' + GOOGLE_ROOT_CA self.chain += b"\n" + GOOGLE_ROOT_CA
# Takes care of PEM checking # Takes care of PEM checking
self.validate_pk() self.validate_pk()
@ -114,149 +123,209 @@ class CertBundle(BaseModel):
sans.add(domain_name) sans.add(domain_name)
sans = [cryptography.x509.DNSName(item) for item in sans] sans = [cryptography.x509.DNSName(item) for item in sans]
key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(
subject = cryptography.x509.Name([ public_exponent=65537, key_size=2048, backend=default_backend()
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"), )
cryptography.x509.NameAttribute(cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, u"CA"), subject = cryptography.x509.Name(
cryptography.x509.NameAttribute(cryptography.x509.NameOID.LOCALITY_NAME, u"San Francisco"), [
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"My Company"), cryptography.x509.NameAttribute(
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, domain_name), cryptography.x509.NameOID.COUNTRY_NAME, "US"
]) ),
issuer = cryptography.x509.Name([ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon cryptography.x509.NameAttribute(
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"), cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, "CA"
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"Amazon"), ),
cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, u"Server CA 1B"), cryptography.x509.NameAttribute(
cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, u"Amazon"), cryptography.x509.NameOID.LOCALITY_NAME, "San Francisco"
]) ),
cert = cryptography.x509.CertificateBuilder().subject_name( cryptography.x509.NameAttribute(
subject cryptography.x509.NameOID.ORGANIZATION_NAME, "My Company"
).issuer_name( ),
issuer cryptography.x509.NameAttribute(
).public_key( cryptography.x509.NameOID.COMMON_NAME, domain_name
key.public_key() ),
).serial_number( ]
cryptography.x509.random_serial_number() )
).not_valid_before( issuer = cryptography.x509.Name(
datetime.datetime.utcnow() [ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon
).not_valid_after( cryptography.x509.NameAttribute(
datetime.datetime.utcnow() + datetime.timedelta(days=365) cryptography.x509.NameOID.COUNTRY_NAME, "US"
).add_extension( ),
cryptography.x509.SubjectAlternativeName(sans), cryptography.x509.NameAttribute(
critical=False, cryptography.x509.NameOID.ORGANIZATION_NAME, "Amazon"
).sign(key, hashes.SHA512(), default_backend()) ),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B"
),
cryptography.x509.NameAttribute(
cryptography.x509.NameOID.COMMON_NAME, "Amazon"
),
]
)
cert = (
cryptography.x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(key.public_key())
.serial_number(cryptography.x509.random_serial_number())
.not_valid_before(datetime.datetime.utcnow())
.not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365))
.add_extension(
cryptography.x509.SubjectAlternativeName(sans), critical=False
)
.sign(key, hashes.SHA512(), default_backend())
)
cert_armored = cert.public_bytes(serialization.Encoding.PEM) cert_armored = cert.public_bytes(serialization.Encoding.PEM)
private_key = key.private_bytes( private_key = key.private_bytes(
encoding=serialization.Encoding.PEM, encoding=serialization.Encoding.PEM,
format=serialization.PrivateFormat.TraditionalOpenSSL, format=serialization.PrivateFormat.TraditionalOpenSSL,
encryption_algorithm=serialization.NoEncryption() encryption_algorithm=serialization.NoEncryption(),
) )
return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION', region=region) return cls(
cert_armored,
private_key,
cert_type="AMAZON_ISSUED",
cert_status="PENDING_VALIDATION",
region=region,
)
def validate_pk(self): def validate_pk(self):
try: try:
self._key = serialization.load_pem_private_key(self.key, password=None, backend=default_backend()) self._key = serialization.load_pem_private_key(
self.key, password=None, backend=default_backend()
)
if self._key.key_size > 2048: if self._key.key_size > 2048:
AWSValidationException('The private key length is not supported. Only 1024-bit and 2048-bit are allowed.') AWSValidationException(
"The private key length is not supported. Only 1024-bit and 2048-bit are allowed."
)
except Exception as err: except Exception as err:
if isinstance(err, AWSValidationException): if isinstance(err, AWSValidationException):
raise raise
raise AWSValidationException('The private key is not PEM-encoded or is not valid.') raise AWSValidationException(
"The private key is not PEM-encoded or is not valid."
)
def validate_certificate(self): def validate_certificate(self):
try: try:
self._cert = cryptography.x509.load_pem_x509_certificate(self.cert, default_backend()) self._cert = cryptography.x509.load_pem_x509_certificate(
self.cert, default_backend()
)
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
if self._cert.not_valid_after < now: if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate has expired, is not valid.') raise AWSValidationException(
"The certificate has expired, is not valid."
)
if self._cert.not_valid_before > now: if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate is not in effect yet, is not valid.') raise AWSValidationException(
"The certificate is not in effect yet, is not valid."
)
# Extracting some common fields for ease of use # Extracting some common fields for ease of use
# Have to search through cert.subject for OIDs # Have to search through cert.subject for OIDs
self.common_name = self._cert.subject.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value self.common_name = self._cert.subject.get_attributes_for_oid(
cryptography.x509.OID_COMMON_NAME
)[0].value
except Exception as err: except Exception as err:
if isinstance(err, AWSValidationException): if isinstance(err, AWSValidationException):
raise raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.') raise AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
)
def validate_chain(self): def validate_chain(self):
try: try:
self._chain = [] self._chain = []
for cert_armored in self.chain.split(b'-\n-'): for cert_armored in self.chain.split(b"-\n-"):
# Would leave encoded but Py2 does not have raw binary strings # Would leave encoded but Py2 does not have raw binary strings
cert_armored = cert_armored.decode() cert_armored = cert_armored.decode()
# Fix missing -'s on split # Fix missing -'s on split
cert_armored = re.sub(r'^----B', '-----B', cert_armored) cert_armored = re.sub(r"^----B", "-----B", cert_armored)
cert_armored = re.sub(r'E----$', 'E-----', cert_armored) cert_armored = re.sub(r"E----$", "E-----", cert_armored)
cert = cryptography.x509.load_pem_x509_certificate(cert_armored.encode(), default_backend()) cert = cryptography.x509.load_pem_x509_certificate(
cert_armored.encode(), default_backend()
)
self._chain.append(cert) self._chain.append(cert)
now = datetime.datetime.now() now = datetime.datetime.now()
if self._cert.not_valid_after < now: if self._cert.not_valid_after < now:
raise AWSValidationException('The certificate chain has expired, is not valid.') raise AWSValidationException(
"The certificate chain has expired, is not valid."
)
if self._cert.not_valid_before > now: if self._cert.not_valid_before > now:
raise AWSValidationException('The certificate chain is not in effect yet, is not valid.') raise AWSValidationException(
"The certificate chain is not in effect yet, is not valid."
)
except Exception as err: except Exception as err:
if isinstance(err, AWSValidationException): if isinstance(err, AWSValidationException):
raise raise
raise AWSValidationException('The certificate is not PEM-encoded or is not valid.') raise AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
)
def check(self): def check(self):
# Basically, if the certificate is pending, and then checked again after 1 min # Basically, if the certificate is pending, and then checked again after 1 min
# It will appear as if its been validated # It will appear as if its been validated
if self.type == 'AMAZON_ISSUED' and self.status == 'PENDING_VALIDATION' and \ if (
(datetime.datetime.now() - self.created_at).total_seconds() > 60: # 1min self.type == "AMAZON_ISSUED"
self.status = 'ISSUED' and self.status == "PENDING_VALIDATION"
and (datetime.datetime.now() - self.created_at).total_seconds() > 60
): # 1min
self.status = "ISSUED"
def describe(self): def describe(self):
# 'RenewalSummary': {}, # Only when cert is amazon issued # 'RenewalSummary': {}, # Only when cert is amazon issued
if self._key.key_size == 1024: if self._key.key_size == 1024:
key_algo = 'RSA_1024' key_algo = "RSA_1024"
elif self._key.key_size == 2048: elif self._key.key_size == 2048:
key_algo = 'RSA_2048' key_algo = "RSA_2048"
else: else:
key_algo = 'EC_prime256v1' key_algo = "EC_prime256v1"
# Look for SANs # Look for SANs
san_obj = self._cert.extensions.get_extension_for_oid(cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME) san_obj = self._cert.extensions.get_extension_for_oid(
cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME
)
sans = [] sans = []
if san_obj is not None: if san_obj is not None:
sans = [item.value for item in san_obj.value] sans = [item.value for item in san_obj.value]
result = { result = {
'Certificate': { "Certificate": {
'CertificateArn': self.arn, "CertificateArn": self.arn,
'DomainName': self.common_name, "DomainName": self.common_name,
'InUseBy': [], "InUseBy": [],
'Issuer': self._cert.issuer.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value, "Issuer": self._cert.issuer.get_attributes_for_oid(
'KeyAlgorithm': key_algo, cryptography.x509.OID_COMMON_NAME
'NotAfter': datetime_to_epoch(self._cert.not_valid_after), )[0].value,
'NotBefore': datetime_to_epoch(self._cert.not_valid_before), "KeyAlgorithm": key_algo,
'Serial': self._cert.serial_number, "NotAfter": datetime_to_epoch(self._cert.not_valid_after),
'SignatureAlgorithm': self._cert.signature_algorithm_oid._name.upper().replace('ENCRYPTION', ''), "NotBefore": datetime_to_epoch(self._cert.not_valid_before),
'Status': self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED. "Serial": self._cert.serial_number,
'Subject': 'CN={0}'.format(self.common_name), "SignatureAlgorithm": self._cert.signature_algorithm_oid._name.upper().replace(
'SubjectAlternativeNames': sans, "ENCRYPTION", ""
'Type': self.type # One of IMPORTED, AMAZON_ISSUED ),
"Status": self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED.
"Subject": "CN={0}".format(self.common_name),
"SubjectAlternativeNames": sans,
"Type": self.type, # One of IMPORTED, AMAZON_ISSUED
} }
} }
if self.type == 'IMPORTED': if self.type == "IMPORTED":
result['Certificate']['ImportedAt'] = datetime_to_epoch(self.created_at) result["Certificate"]["ImportedAt"] = datetime_to_epoch(self.created_at)
else: else:
result['Certificate']['CreatedAt'] = datetime_to_epoch(self.created_at) result["Certificate"]["CreatedAt"] = datetime_to_epoch(self.created_at)
result['Certificate']['IssuedAt'] = datetime_to_epoch(self.created_at) result["Certificate"]["IssuedAt"] = datetime_to_epoch(self.created_at)
return result return result
@ -264,7 +333,7 @@ class CertBundle(BaseModel):
return self.arn return self.arn
def __repr__(self): def __repr__(self):
return '<Certificate>' return "<Certificate>"
class AWSCertificateManagerBackend(BaseBackend): class AWSCertificateManagerBackend(BaseBackend):
@ -281,7 +350,9 @@ class AWSCertificateManagerBackend(BaseBackend):
@staticmethod @staticmethod
def _arn_not_found(arn): def _arn_not_found(arn):
msg = 'Certificate with arn {0} not found in account {1}'.format(arn, DEFAULT_ACCOUNT_ID) msg = "Certificate with arn {0} not found in account {1}".format(
arn, DEFAULT_ACCOUNT_ID
)
return AWSResourceNotFoundException(msg) return AWSResourceNotFoundException(msg)
def _get_arn_from_idempotency_token(self, token): def _get_arn_from_idempotency_token(self, token):
@ -298,17 +369,20 @@ class AWSCertificateManagerBackend(BaseBackend):
""" """
now = datetime.datetime.now() now = datetime.datetime.now()
if token in self._idempotency_tokens: if token in self._idempotency_tokens:
if self._idempotency_tokens[token]['expires'] < now: if self._idempotency_tokens[token]["expires"] < now:
# Token has expired, new request # Token has expired, new request
del self._idempotency_tokens[token] del self._idempotency_tokens[token]
return None return None
else: else:
return self._idempotency_tokens[token]['arn'] return self._idempotency_tokens[token]["arn"]
return None return None
def _set_idempotency_token_arn(self, token, arn): def _set_idempotency_token_arn(self, token, arn):
self._idempotency_tokens[token] = {'arn': arn, 'expires': datetime.datetime.now() + datetime.timedelta(hours=1)} self._idempotency_tokens[token] = {
"arn": arn,
"expires": datetime.datetime.now() + datetime.timedelta(hours=1),
}
def import_cert(self, certificate, private_key, chain=None, arn=None): def import_cert(self, certificate, private_key, chain=None, arn=None):
if arn is not None: if arn is not None:
@ -316,7 +390,9 @@ class AWSCertificateManagerBackend(BaseBackend):
raise self._arn_not_found(arn) raise self._arn_not_found(arn)
else: else:
# Will reuse provided ARN # Will reuse provided ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region, arn=arn) bundle = CertBundle(
certificate, private_key, chain=chain, region=region, arn=arn
)
else: else:
# Will generate a random ARN # Will generate a random ARN
bundle = CertBundle(certificate, private_key, chain=chain, region=region) bundle = CertBundle(certificate, private_key, chain=chain, region=region)
@ -351,13 +427,21 @@ class AWSCertificateManagerBackend(BaseBackend):
del self._certificates[arn] del self._certificates[arn]
def request_certificate(self, domain_name, domain_validation_options, idempotency_token, subject_alt_names): def request_certificate(
self,
domain_name,
domain_validation_options,
idempotency_token,
subject_alt_names,
):
if idempotency_token is not None: if idempotency_token is not None:
arn = self._get_arn_from_idempotency_token(idempotency_token) arn = self._get_arn_from_idempotency_token(idempotency_token)
if arn is not None: if arn is not None:
return arn return arn
cert = CertBundle.generate_cert(domain_name, region=self.region, sans=subject_alt_names) cert = CertBundle.generate_cert(
domain_name, region=self.region, sans=subject_alt_names
)
if idempotency_token is not None: if idempotency_token is not None:
self._set_idempotency_token_arn(idempotency_token, cert.arn) self._set_idempotency_token_arn(idempotency_token, cert.arn)
self._certificates[cert.arn] = cert self._certificates[cert.arn] = cert
@ -369,8 +453,8 @@ class AWSCertificateManagerBackend(BaseBackend):
cert_bundle = self.get_certificate(arn) cert_bundle = self.get_certificate(arn)
for tag in tags: for tag in tags:
key = tag['Key'] key = tag["Key"]
value = tag.get('Value', None) value = tag.get("Value", None)
cert_bundle.tags[key] = value cert_bundle.tags[key] = value
def remove_tags_from_certificate(self, arn, tags): def remove_tags_from_certificate(self, arn, tags):
@ -378,8 +462,8 @@ class AWSCertificateManagerBackend(BaseBackend):
cert_bundle = self.get_certificate(arn) cert_bundle = self.get_certificate(arn)
for tag in tags: for tag in tags:
key = tag['Key'] key = tag["Key"]
value = tag.get('Value', None) value = tag.get("Value", None)
try: try:
# If value isnt provided, just delete key # If value isnt provided, just delete key

View File

@ -7,7 +7,6 @@ from .models import acm_backends, AWSError, AWSValidationException
class AWSCertificateManagerResponse(BaseResponse): class AWSCertificateManagerResponse(BaseResponse):
@property @property
def acm_backend(self): def acm_backend(self):
""" """
@ -29,40 +28,49 @@ class AWSCertificateManagerResponse(BaseResponse):
return self.request_params.get(param, default) return self.request_params.get(param, default)
def add_tags_to_certificate(self): def add_tags_to_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
tags = self._get_param('Tags') tags = self._get_param("Tags")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
self.acm_backend.add_tags_to_certificate(arn, tags) self.acm_backend.add_tags_to_certificate(arn, tags)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""
def delete_certificate(self): def delete_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
self.acm_backend.delete_certificate(arn) self.acm_backend.delete_certificate(arn)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""
def describe_certificate(self): def describe_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
cert_bundle = self.acm_backend.get_certificate(arn) cert_bundle = self.acm_backend.get_certificate(arn)
@ -72,11 +80,14 @@ class AWSCertificateManagerResponse(BaseResponse):
return json.dumps(cert_bundle.describe()) return json.dumps(cert_bundle.describe())
def get_certificate(self): def get_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
cert_bundle = self.acm_backend.get_certificate(arn) cert_bundle = self.acm_backend.get_certificate(arn)
@ -84,8 +95,8 @@ class AWSCertificateManagerResponse(BaseResponse):
return err.response() return err.response()
result = { result = {
'Certificate': cert_bundle.cert.decode(), "Certificate": cert_bundle.cert.decode(),
'CertificateChain': cert_bundle.chain.decode() "CertificateChain": cert_bundle.chain.decode(),
} }
return json.dumps(result) return json.dumps(result)
@ -102,104 +113,129 @@ class AWSCertificateManagerResponse(BaseResponse):
:return: str(JSON) for response :return: str(JSON) for response
""" """
certificate = self._get_param('Certificate') certificate = self._get_param("Certificate")
private_key = self._get_param('PrivateKey') private_key = self._get_param("PrivateKey")
chain = self._get_param('CertificateChain') # Optional chain = self._get_param("CertificateChain") # Optional
current_arn = self._get_param('CertificateArn') # Optional current_arn = self._get_param("CertificateArn") # Optional
# Simple parameter decoding. Rather do it here as its a data transport decision not part of the # Simple parameter decoding. Rather do it here as its a data transport decision not part of the
# actual data # actual data
try: try:
certificate = base64.standard_b64decode(certificate) certificate = base64.standard_b64decode(certificate)
except Exception: except Exception:
return AWSValidationException('The certificate is not PEM-encoded or is not valid.').response() return AWSValidationException(
"The certificate is not PEM-encoded or is not valid."
).response()
try: try:
private_key = base64.standard_b64decode(private_key) private_key = base64.standard_b64decode(private_key)
except Exception: except Exception:
return AWSValidationException('The private key is not PEM-encoded or is not valid.').response() return AWSValidationException(
"The private key is not PEM-encoded or is not valid."
).response()
if chain is not None: if chain is not None:
try: try:
chain = base64.standard_b64decode(chain) chain = base64.standard_b64decode(chain)
except Exception: except Exception:
return AWSValidationException('The certificate chain is not PEM-encoded or is not valid.').response() return AWSValidationException(
"The certificate chain is not PEM-encoded or is not valid."
).response()
try: try:
arn = self.acm_backend.import_cert(certificate, private_key, chain=chain, arn=current_arn) arn = self.acm_backend.import_cert(
certificate, private_key, chain=chain, arn=current_arn
)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return json.dumps({'CertificateArn': arn}) return json.dumps({"CertificateArn": arn})
def list_certificates(self): def list_certificates(self):
certs = [] certs = []
statuses = self._get_param('CertificateStatuses') statuses = self._get_param("CertificateStatuses")
for cert_bundle in self.acm_backend.get_certificates_list(statuses): for cert_bundle in self.acm_backend.get_certificates_list(statuses):
certs.append({ certs.append(
'CertificateArn': cert_bundle.arn, {
'DomainName': cert_bundle.common_name "CertificateArn": cert_bundle.arn,
}) "DomainName": cert_bundle.common_name,
}
)
result = {'CertificateSummaryList': certs} result = {"CertificateSummaryList": certs}
return json.dumps(result) return json.dumps(result)
def list_tags_for_certificate(self): def list_tags_for_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return {'__type': 'MissingParameter', 'message': msg}, dict(status=400) return {"__type": "MissingParameter", "message": msg}, dict(status=400)
try: try:
cert_bundle = self.acm_backend.get_certificate(arn) cert_bundle = self.acm_backend.get_certificate(arn)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = {'Tags': []} result = {"Tags": []}
# Tag "objects" can not contain the Value part # Tag "objects" can not contain the Value part
for key, value in cert_bundle.tags.items(): for key, value in cert_bundle.tags.items():
tag_dict = {'Key': key} tag_dict = {"Key": key}
if value is not None: if value is not None:
tag_dict['Value'] = value tag_dict["Value"] = value
result['Tags'].append(tag_dict) result["Tags"].append(tag_dict)
return json.dumps(result) return json.dumps(result)
def remove_tags_from_certificate(self): def remove_tags_from_certificate(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
tags = self._get_param('Tags') tags = self._get_param("Tags")
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
self.acm_backend.remove_tags_from_certificate(arn, tags) self.acm_backend.remove_tags_from_certificate(arn, tags)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""
def request_certificate(self): def request_certificate(self):
domain_name = self._get_param('DomainName') domain_name = self._get_param("DomainName")
domain_validation_options = self._get_param('DomainValidationOptions') # is ignored atm domain_validation_options = self._get_param(
idempotency_token = self._get_param('IdempotencyToken') "DomainValidationOptions"
subject_alt_names = self._get_param('SubjectAlternativeNames') ) # is ignored atm
idempotency_token = self._get_param("IdempotencyToken")
subject_alt_names = self._get_param("SubjectAlternativeNames")
if subject_alt_names is not None and len(subject_alt_names) > 10: if subject_alt_names is not None and len(subject_alt_names) > 10:
# There is initial AWS limit of 10 # There is initial AWS limit of 10
msg = 'An ACM limit has been exceeded. Need to request SAN limit to be raised' msg = (
return json.dumps({'__type': 'LimitExceededException', 'message': msg}), dict(status=400) "An ACM limit has been exceeded. Need to request SAN limit to be raised"
)
return (
json.dumps({"__type": "LimitExceededException", "message": msg}),
dict(status=400),
)
try: try:
arn = self.acm_backend.request_certificate(domain_name, domain_validation_options, idempotency_token, subject_alt_names) arn = self.acm_backend.request_certificate(
domain_name,
domain_validation_options,
idempotency_token,
subject_alt_names,
)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return json.dumps({'CertificateArn': arn}) return json.dumps({"CertificateArn": arn})
def resend_validation_email(self): def resend_validation_email(self):
arn = self._get_param('CertificateArn') arn = self._get_param("CertificateArn")
domain = self._get_param('Domain') domain = self._get_param("Domain")
# ValidationDomain not used yet. # ValidationDomain not used yet.
# Contains domain which is equal to or a subset of Domain # Contains domain which is equal to or a subset of Domain
# that AWS will send validation emails to # that AWS will send validation emails to
@ -207,18 +243,21 @@ class AWSCertificateManagerResponse(BaseResponse):
# validation_domain = self._get_param('ValidationDomain') # validation_domain = self._get_param('ValidationDomain')
if arn is None: if arn is None:
msg = 'A required parameter for the specified action is not supplied.' msg = "A required parameter for the specified action is not supplied."
return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) return (
json.dumps({"__type": "MissingParameter", "message": msg}),
dict(status=400),
)
try: try:
cert_bundle = self.acm_backend.get_certificate(arn) cert_bundle = self.acm_backend.get_certificate(arn)
if cert_bundle.common_name != domain: if cert_bundle.common_name != domain:
msg = 'Parameter Domain does not match certificate domain' msg = "Parameter Domain does not match certificate domain"
_type = 'InvalidDomainValidationOptionsException' _type = "InvalidDomainValidationOptionsException"
return json.dumps({'__type': _type, 'message': msg}), dict(status=400) return json.dumps({"__type": _type, "message": msg}), dict(status=400)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import AWSCertificateManagerResponse from .responses import AWSCertificateManagerResponse
url_bases = [ url_bases = ["https?://acm.(.+).amazonaws.com"]
"https?://acm.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": AWSCertificateManagerResponse.dispatch}
'{0}/$': AWSCertificateManagerResponse.dispatch,
}

View File

@ -4,4 +4,6 @@ import uuid
def make_arn_for_certificate(account_id, region_name): def make_arn_for_certificate(account_id, region_name):
# Example # Example
# arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b # arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b
return "arn:aws:acm:{0}:{1}:certificate/{2}".format(region_name, account_id, uuid.uuid4()) return "arn:aws:acm:{0}:{1}:certificate/{2}".format(
region_name, account_id, uuid.uuid4()
)

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import apigateway_backends from .models import apigateway_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
apigateway_backend = apigateway_backends['us-east-1'] apigateway_backend = apigateway_backends["us-east-1"]
mock_apigateway = base_decorator(apigateway_backends) mock_apigateway = base_decorator(apigateway_backends)
mock_apigateway_deprecated = deprecated_base_decorator(apigateway_backends) mock_apigateway_deprecated = deprecated_base_decorator(apigateway_backends)

View File

@ -7,7 +7,8 @@ class StageNotFoundException(RESTError):
def __init__(self): def __init__(self):
super(StageNotFoundException, self).__init__( super(StageNotFoundException, self).__init__(
"NotFoundException", "Invalid stage identifier specified") "NotFoundException", "Invalid stage identifier specified"
)
class ApiKeyNotFoundException(RESTError): class ApiKeyNotFoundException(RESTError):
@ -15,4 +16,5 @@ class ApiKeyNotFoundException(RESTError):
def __init__(self): def __init__(self):
super(ApiKeyNotFoundException, self).__init__( super(ApiKeyNotFoundException, self).__init__(
"NotFoundException", "Invalid API Key identifier specified") "NotFoundException", "Invalid API Key identifier specified"
)

View File

@ -17,39 +17,33 @@ STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_nam
class Deployment(BaseModel, dict): class Deployment(BaseModel, dict):
def __init__(self, deployment_id, name, description=""): def __init__(self, deployment_id, name, description=""):
super(Deployment, self).__init__() super(Deployment, self).__init__()
self['id'] = deployment_id self["id"] = deployment_id
self['stageName'] = name self["stageName"] = name
self['description'] = description self["description"] = description
self['createdDate'] = int(time.time()) self["createdDate"] = int(time.time())
class IntegrationResponse(BaseModel, dict): class IntegrationResponse(BaseModel, dict):
def __init__(self, status_code, selection_pattern=None): def __init__(self, status_code, selection_pattern=None):
self['responseTemplates'] = {"application/json": None} self["responseTemplates"] = {"application/json": None}
self['statusCode'] = status_code self["statusCode"] = status_code
if selection_pattern: if selection_pattern:
self['selectionPattern'] = selection_pattern self["selectionPattern"] = selection_pattern
class Integration(BaseModel, dict): class Integration(BaseModel, dict):
def __init__(self, integration_type, uri, http_method, request_templates=None): def __init__(self, integration_type, uri, http_method, request_templates=None):
super(Integration, self).__init__() super(Integration, self).__init__()
self['type'] = integration_type self["type"] = integration_type
self['uri'] = uri self["uri"] = uri
self['httpMethod'] = http_method self["httpMethod"] = http_method
self['requestTemplates'] = request_templates self["requestTemplates"] = request_templates
self["integrationResponses"] = { self["integrationResponses"] = {"200": IntegrationResponse(200)}
"200": IntegrationResponse(200)
}
def create_integration_response(self, status_code, selection_pattern): def create_integration_response(self, status_code, selection_pattern):
integration_response = IntegrationResponse( integration_response = IntegrationResponse(status_code, selection_pattern)
status_code, selection_pattern)
self["integrationResponses"][status_code] = integration_response self["integrationResponses"][status_code] = integration_response
return integration_response return integration_response
@ -61,17 +55,16 @@ class Integration(BaseModel, dict):
class MethodResponse(BaseModel, dict): class MethodResponse(BaseModel, dict):
def __init__(self, status_code): def __init__(self, status_code):
super(MethodResponse, self).__init__() super(MethodResponse, self).__init__()
self['statusCode'] = status_code self["statusCode"] = status_code
class Method(BaseModel, dict): class Method(BaseModel, dict):
def __init__(self, method_type, authorization_type): def __init__(self, method_type, authorization_type):
super(Method, self).__init__() super(Method, self).__init__()
self.update(dict( self.update(
dict(
httpMethod=method_type, httpMethod=method_type,
authorizationType=authorization_type, authorizationType=authorization_type,
authorizerId=None, authorizerId=None,
@ -79,7 +72,8 @@ class Method(BaseModel, dict):
requestParameters=None, requestParameters=None,
requestModels=None, requestModels=None,
methodIntegration=None, methodIntegration=None,
)) )
)
self.method_responses = {} self.method_responses = {}
def create_response(self, response_code): def create_response(self, response_code):
@ -95,16 +89,13 @@ class Method(BaseModel, dict):
class Resource(BaseModel): class Resource(BaseModel):
def __init__(self, id, region_name, api_id, path_part, parent_id): def __init__(self, id, region_name, api_id, path_part, parent_id):
self.id = id self.id = id
self.region_name = region_name self.region_name = region_name
self.api_id = api_id self.api_id = api_id
self.path_part = path_part self.path_part = path_part
self.parent_id = parent_id self.parent_id = parent_id
self.resource_methods = { self.resource_methods = {"GET": {}}
'GET': {}
}
def to_dict(self): def to_dict(self):
response = { response = {
@ -113,8 +104,8 @@ class Resource(BaseModel):
"resourceMethods": self.resource_methods, "resourceMethods": self.resource_methods,
} }
if self.parent_id: if self.parent_id:
response['parentId'] = self.parent_id response["parentId"] = self.parent_id
response['pathPart'] = self.path_part response["pathPart"] = self.path_part
return response return response
def get_path(self): def get_path(self):
@ -125,102 +116,112 @@ class Resource(BaseModel):
backend = apigateway_backends[self.region_name] backend = apigateway_backends[self.region_name]
parent = backend.get_resource(self.api_id, self.parent_id) parent = backend.get_resource(self.api_id, self.parent_id)
parent_path = parent.get_path() parent_path = parent.get_path()
if parent_path != '/': # Root parent if parent_path != "/": # Root parent
parent_path += '/' parent_path += "/"
return parent_path return parent_path
else: else:
return '' return ""
def get_response(self, request): def get_response(self, request):
integration = self.get_integration(request.method) integration = self.get_integration(request.method)
integration_type = integration['type'] integration_type = integration["type"]
if integration_type == 'HTTP': if integration_type == "HTTP":
uri = integration['uri'] uri = integration["uri"]
requests_func = getattr(requests, integration[ requests_func = getattr(requests, integration["httpMethod"].lower())
'httpMethod'].lower())
response = requests_func(uri) response = requests_func(uri)
else: else:
raise NotImplementedError( raise NotImplementedError(
"The {0} type has not been implemented".format(integration_type)) "The {0} type has not been implemented".format(integration_type)
)
return response.status_code, response.text return response.status_code, response.text
def add_method(self, method_type, authorization_type): def add_method(self, method_type, authorization_type):
method = Method(method_type=method_type, method = Method(method_type=method_type, authorization_type=authorization_type)
authorization_type=authorization_type)
self.resource_methods[method_type] = method self.resource_methods[method_type] = method
return method return method
def get_method(self, method_type): def get_method(self, method_type):
return self.resource_methods[method_type] return self.resource_methods[method_type]
def add_integration(self, method_type, integration_type, uri, request_templates=None): def add_integration(
self, method_type, integration_type, uri, request_templates=None
):
integration = Integration( integration = Integration(
integration_type, uri, method_type, request_templates=request_templates) integration_type, uri, method_type, request_templates=request_templates
self.resource_methods[method_type]['methodIntegration'] = integration )
self.resource_methods[method_type]["methodIntegration"] = integration
return integration return integration
def get_integration(self, method_type): def get_integration(self, method_type):
return self.resource_methods[method_type]['methodIntegration'] return self.resource_methods[method_type]["methodIntegration"]
def delete_integration(self, method_type): def delete_integration(self, method_type):
return self.resource_methods[method_type].pop('methodIntegration') return self.resource_methods[method_type].pop("methodIntegration")
class Stage(BaseModel, dict): class Stage(BaseModel, dict):
def __init__(
def __init__(self, name=None, deployment_id=None, variables=None, self,
description='', cacheClusterEnabled=False, cacheClusterSize=None): name=None,
deployment_id=None,
variables=None,
description="",
cacheClusterEnabled=False,
cacheClusterSize=None,
):
super(Stage, self).__init__() super(Stage, self).__init__()
if variables is None: if variables is None:
variables = {} variables = {}
self['stageName'] = name self["stageName"] = name
self['deploymentId'] = deployment_id self["deploymentId"] = deployment_id
self['methodSettings'] = {} self["methodSettings"] = {}
self['variables'] = variables self["variables"] = variables
self['description'] = description self["description"] = description
self['cacheClusterEnabled'] = cacheClusterEnabled self["cacheClusterEnabled"] = cacheClusterEnabled
if self['cacheClusterEnabled']: if self["cacheClusterEnabled"]:
self['cacheClusterSize'] = str(0.5) self["cacheClusterSize"] = str(0.5)
if cacheClusterSize is not None: if cacheClusterSize is not None:
self['cacheClusterSize'] = str(cacheClusterSize) self["cacheClusterSize"] = str(cacheClusterSize)
def apply_operations(self, patch_operations): def apply_operations(self, patch_operations):
for op in patch_operations: for op in patch_operations:
if 'variables/' in op['path']: if "variables/" in op["path"]:
self._apply_operation_to_variables(op) self._apply_operation_to_variables(op)
elif '/cacheClusterEnabled' in op['path']: elif "/cacheClusterEnabled" in op["path"]:
self['cacheClusterEnabled'] = self._str2bool(op['value']) self["cacheClusterEnabled"] = self._str2bool(op["value"])
if 'cacheClusterSize' not in self and self['cacheClusterEnabled']: if "cacheClusterSize" not in self and self["cacheClusterEnabled"]:
self['cacheClusterSize'] = str(0.5) self["cacheClusterSize"] = str(0.5)
elif '/cacheClusterSize' in op['path']: elif "/cacheClusterSize" in op["path"]:
self['cacheClusterSize'] = str(float(op['value'])) self["cacheClusterSize"] = str(float(op["value"]))
elif '/description' in op['path']: elif "/description" in op["path"]:
self['description'] = op['value'] self["description"] = op["value"]
elif '/deploymentId' in op['path']: elif "/deploymentId" in op["path"]:
self['deploymentId'] = op['value'] self["deploymentId"] = op["value"]
elif op['op'] == 'replace': elif op["op"] == "replace":
# Method Settings drop into here # Method Settings drop into here
# (e.g., path could be '/*/*/logging/loglevel') # (e.g., path could be '/*/*/logging/loglevel')
split_path = op['path'].split('/', 3) split_path = op["path"].split("/", 3)
if len(split_path) != 4: if len(split_path) != 4:
continue continue
self._patch_method_setting( self._patch_method_setting(
'/'.join(split_path[1:3]), split_path[3], op['value']) "/".join(split_path[1:3]), split_path[3], op["value"]
)
else: else:
raise Exception( raise Exception('Patch operation "%s" not implemented' % op["op"])
'Patch operation "%s" not implemented' % op['op'])
return self return self
def _patch_method_setting(self, resource_path_and_method, key, value): def _patch_method_setting(self, resource_path_and_method, key, value):
updated_key = self._method_settings_translations(key) updated_key = self._method_settings_translations(key)
if updated_key is not None: if updated_key is not None:
if resource_path_and_method not in self['methodSettings']: if resource_path_and_method not in self["methodSettings"]:
self['methodSettings'][ self["methodSettings"][
resource_path_and_method] = self._get_default_method_settings() resource_path_and_method
self['methodSettings'][resource_path_and_method][ ] = self._get_default_method_settings()
updated_key] = self._convert_to_type(updated_key, value) self["methodSettings"][resource_path_and_method][
updated_key
] = self._convert_to_type(updated_key, value)
def _get_default_method_settings(self): def _get_default_method_settings(self):
return { return {
@ -232,21 +233,21 @@ class Stage(BaseModel, dict):
"cacheDataEncrypted": True, "cacheDataEncrypted": True,
"cachingEnabled": False, "cachingEnabled": False,
"throttlingBurstLimit": 2000, "throttlingBurstLimit": 2000,
"requireAuthorizationForCacheControl": True "requireAuthorizationForCacheControl": True,
} }
def _method_settings_translations(self, key): def _method_settings_translations(self, key):
mappings = { mappings = {
'metrics/enabled': 'metricsEnabled', "metrics/enabled": "metricsEnabled",
'logging/loglevel': 'loggingLevel', "logging/loglevel": "loggingLevel",
'logging/dataTrace': 'dataTraceEnabled', "logging/dataTrace": "dataTraceEnabled",
'throttling/burstLimit': 'throttlingBurstLimit', "throttling/burstLimit": "throttlingBurstLimit",
'throttling/rateLimit': 'throttlingRateLimit', "throttling/rateLimit": "throttlingRateLimit",
'caching/enabled': 'cachingEnabled', "caching/enabled": "cachingEnabled",
'caching/ttlInSeconds': 'cacheTtlInSeconds', "caching/ttlInSeconds": "cacheTtlInSeconds",
'caching/dataEncrypted': 'cacheDataEncrypted', "caching/dataEncrypted": "cacheDataEncrypted",
'caching/requireAuthorizationForCacheControl': 'requireAuthorizationForCacheControl', "caching/requireAuthorizationForCacheControl": "requireAuthorizationForCacheControl",
'caching/unauthorizedCacheControlHeaderStrategy': 'unauthorizedCacheControlHeaderStrategy' "caching/unauthorizedCacheControlHeaderStrategy": "unauthorizedCacheControlHeaderStrategy",
} }
if key in mappings: if key in mappings:
@ -259,26 +260,26 @@ class Stage(BaseModel, dict):
def _convert_to_type(self, key, val): def _convert_to_type(self, key, val):
type_mappings = { type_mappings = {
'metricsEnabled': 'bool', "metricsEnabled": "bool",
'loggingLevel': 'str', "loggingLevel": "str",
'dataTraceEnabled': 'bool', "dataTraceEnabled": "bool",
'throttlingBurstLimit': 'int', "throttlingBurstLimit": "int",
'throttlingRateLimit': 'float', "throttlingRateLimit": "float",
'cachingEnabled': 'bool', "cachingEnabled": "bool",
'cacheTtlInSeconds': 'int', "cacheTtlInSeconds": "int",
'cacheDataEncrypted': 'bool', "cacheDataEncrypted": "bool",
'requireAuthorizationForCacheControl': 'bool', "requireAuthorizationForCacheControl": "bool",
'unauthorizedCacheControlHeaderStrategy': 'str' "unauthorizedCacheControlHeaderStrategy": "str",
} }
if key in type_mappings: if key in type_mappings:
type_value = type_mappings[key] type_value = type_mappings[key]
if type_value == 'bool': if type_value == "bool":
return self._str2bool(val) return self._str2bool(val)
elif type_value == 'int': elif type_value == "int":
return int(val) return int(val)
elif type_value == 'float': elif type_value == "float":
return float(val) return float(val)
else: else:
return str(val) return str(val)
@ -286,44 +287,55 @@ class Stage(BaseModel, dict):
return str(val) return str(val)
def _apply_operation_to_variables(self, op): def _apply_operation_to_variables(self, op):
key = op['path'][op['path'].rindex("variables/") + 10:] key = op["path"][op["path"].rindex("variables/") + 10 :]
if op['op'] == 'remove': if op["op"] == "remove":
self['variables'].pop(key, None) self["variables"].pop(key, None)
elif op['op'] == 'replace': elif op["op"] == "replace":
self['variables'][key] = op['value'] self["variables"][key] = op["value"]
else: else:
raise Exception('Patch operation "%s" not implemented' % op['op']) raise Exception('Patch operation "%s" not implemented' % op["op"])
class ApiKey(BaseModel, dict): class ApiKey(BaseModel, dict):
def __init__(
def __init__(self, name=None, description=None, enabled=True, self,
generateDistinctId=False, value=None, stageKeys=None, tags=None, customerId=None): name=None,
description=None,
enabled=True,
generateDistinctId=False,
value=None,
stageKeys=None,
tags=None,
customerId=None,
):
super(ApiKey, self).__init__() super(ApiKey, self).__init__()
self['id'] = create_id() self["id"] = create_id()
self['value'] = value if value else ''.join(random.sample(string.ascii_letters + string.digits, 40)) self["value"] = (
self['name'] = name value
self['customerId'] = customerId if value
self['description'] = description else "".join(random.sample(string.ascii_letters + string.digits, 40))
self['enabled'] = enabled )
self['createdDate'] = self['lastUpdatedDate'] = int(time.time()) self["name"] = name
self['stageKeys'] = stageKeys self["customerId"] = customerId
self['tags'] = tags self["description"] = description
self["enabled"] = enabled
self["createdDate"] = self["lastUpdatedDate"] = int(time.time())
self["stageKeys"] = stageKeys
self["tags"] = tags
def update_operations(self, patch_operations): def update_operations(self, patch_operations):
for op in patch_operations: for op in patch_operations:
if op['op'] == 'replace': if op["op"] == "replace":
if '/name' in op['path']: if "/name" in op["path"]:
self['name'] = op['value'] self["name"] = op["value"]
elif '/customerId' in op['path']: elif "/customerId" in op["path"]:
self['customerId'] = op['value'] self["customerId"] = op["value"]
elif '/description' in op['path']: elif "/description" in op["path"]:
self['description'] = op['value'] self["description"] = op["value"]
elif '/enabled' in op['path']: elif "/enabled" in op["path"]:
self['enabled'] = self._str2bool(op['value']) self["enabled"] = self._str2bool(op["value"])
else: else:
raise Exception( raise Exception('Patch operation "%s" not implemented' % op["op"])
'Patch operation "%s" not implemented' % op['op'])
return self return self
def _str2bool(self, v): def _str2bool(self, v):
@ -331,31 +343,35 @@ class ApiKey(BaseModel, dict):
class UsagePlan(BaseModel, dict): class UsagePlan(BaseModel, dict):
def __init__(
def __init__(self, name=None, description=None, apiStages=None, self,
throttle=None, quota=None, tags=None): name=None,
description=None,
apiStages=None,
throttle=None,
quota=None,
tags=None,
):
super(UsagePlan, self).__init__() super(UsagePlan, self).__init__()
self['id'] = create_id() self["id"] = create_id()
self['name'] = name self["name"] = name
self['description'] = description self["description"] = description
self['apiStages'] = apiStages if apiStages else [] self["apiStages"] = apiStages if apiStages else []
self['throttle'] = throttle self["throttle"] = throttle
self['quota'] = quota self["quota"] = quota
self['tags'] = tags self["tags"] = tags
class UsagePlanKey(BaseModel, dict): class UsagePlanKey(BaseModel, dict):
def __init__(self, id, type, name, value): def __init__(self, id, type, name, value):
super(UsagePlanKey, self).__init__() super(UsagePlanKey, self).__init__()
self['id'] = id self["id"] = id
self['name'] = name self["name"] = name
self['type'] = type self["type"] = type
self['value'] = value self["value"] = value
class RestAPI(BaseModel): class RestAPI(BaseModel):
def __init__(self, id, region_name, name, description): def __init__(self, id, region_name, name, description):
self.id = id self.id = id
self.region_name = region_name self.region_name = region_name
@ -367,7 +383,7 @@ class RestAPI(BaseModel):
self.stages = {} self.stages = {}
self.resources = {} self.resources = {}
self.add_child('/') # Add default child self.add_child("/") # Add default child
def __repr__(self): def __repr__(self):
return str(self.id) return str(self.id)
@ -382,8 +398,13 @@ class RestAPI(BaseModel):
def add_child(self, path, parent_id=None): def add_child(self, path, parent_id=None):
child_id = create_id() child_id = create_id()
child = Resource(id=child_id, region_name=self.region_name, child = Resource(
api_id=self.id, path_part=path, parent_id=parent_id) id=child_id,
region_name=self.region_name,
api_id=self.id,
path_part=path,
parent_id=parent_id,
)
self.resources[child_id] = child self.resources[child_id] = child
return child return child
@ -395,36 +416,53 @@ class RestAPI(BaseModel):
def resource_callback(self, request): def resource_callback(self, request):
path = path_url(request.url) path = path_url(request.url)
path_after_stage_name = '/'.join(path.split("/")[2:]) path_after_stage_name = "/".join(path.split("/")[2:])
if not path_after_stage_name: if not path_after_stage_name:
path_after_stage_name = '/' path_after_stage_name = "/"
resource = self.get_resource_for_path(path_after_stage_name) resource = self.get_resource_for_path(path_after_stage_name)
status_code, response = resource.get_response(request) status_code, response = resource.get_response(request)
return status_code, {}, response return status_code, {}, response
def update_integration_mocks(self, stage_name): def update_integration_mocks(self, stage_name):
stage_url_lower = STAGE_URL.format(api_id=self.id.lower(), stage_url_lower = STAGE_URL.format(
region_name=self.region_name, stage_name=stage_name) api_id=self.id.lower(), region_name=self.region_name, stage_name=stage_name
stage_url_upper = STAGE_URL.format(api_id=self.id.upper(), )
region_name=self.region_name, stage_name=stage_name) stage_url_upper = STAGE_URL.format(
api_id=self.id.upper(), region_name=self.region_name, stage_name=stage_name
)
for url in [stage_url_lower, stage_url_upper]: for url in [stage_url_lower, stage_url_upper]:
responses._default_mock._matches.insert(0, responses._default_mock._matches.insert(
0,
responses.CallbackResponse( responses.CallbackResponse(
url=url, url=url,
method=responses.GET, method=responses.GET,
callback=self.resource_callback, callback=self.resource_callback,
content_type="text/plain", content_type="text/plain",
match_querystring=False, match_querystring=False,
) ),
) )
def create_stage(self, name, deployment_id, variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None): def create_stage(
self,
name,
deployment_id,
variables=None,
description="",
cacheClusterEnabled=None,
cacheClusterSize=None,
):
if variables is None: if variables is None:
variables = {} variables = {}
stage = Stage(name=name, deployment_id=deployment_id, variables=variables, stage = Stage(
description=description, cacheClusterSize=cacheClusterSize, cacheClusterEnabled=cacheClusterEnabled) name=name,
deployment_id=deployment_id,
variables=variables,
description=description,
cacheClusterSize=cacheClusterSize,
cacheClusterEnabled=cacheClusterEnabled,
)
self.stages[name] = stage self.stages[name] = stage
self.update_integration_mocks(name) self.update_integration_mocks(name)
return stage return stage
@ -436,7 +474,8 @@ class RestAPI(BaseModel):
deployment = Deployment(deployment_id, name, description) deployment = Deployment(deployment_id, name, description)
self.deployments[deployment_id] = deployment self.deployments[deployment_id] = deployment
self.stages[name] = Stage( self.stages[name] = Stage(
name=name, deployment_id=deployment_id, variables=stage_variables) name=name, deployment_id=deployment_id, variables=stage_variables
)
self.update_integration_mocks(name) self.update_integration_mocks(name)
return deployment return deployment
@ -455,7 +494,6 @@ class RestAPI(BaseModel):
class APIGatewayBackend(BaseBackend): class APIGatewayBackend(BaseBackend):
def __init__(self, region_name): def __init__(self, region_name):
super(APIGatewayBackend, self).__init__() super(APIGatewayBackend, self).__init__()
self.apis = {} self.apis = {}
@ -497,10 +535,7 @@ class APIGatewayBackend(BaseBackend):
def create_resource(self, function_id, parent_resource_id, path_part): def create_resource(self, function_id, parent_resource_id, path_part):
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
child = api.add_child( child = api.add_child(path=path_part, parent_id=parent_resource_id)
path=path_part,
parent_id=parent_resource_id,
)
return child return child
def delete_resource(self, function_id, resource_id): def delete_resource(self, function_id, resource_id):
@ -529,13 +564,27 @@ class APIGatewayBackend(BaseBackend):
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
return api.get_stages() return api.get_stages()
def create_stage(self, function_id, stage_name, deploymentId, def create_stage(
variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None): self,
function_id,
stage_name,
deploymentId,
variables=None,
description="",
cacheClusterEnabled=None,
cacheClusterSize=None,
):
if variables is None: if variables is None:
variables = {} variables = {}
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
api.create_stage(stage_name, deploymentId, variables=variables, api.create_stage(
description=description, cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize) stage_name,
deploymentId,
variables=variables,
description=description,
cacheClusterEnabled=cacheClusterEnabled,
cacheClusterSize=cacheClusterSize,
)
return api.stages.get(stage_name) return api.stages.get(stage_name)
def update_stage(self, function_id, stage_name, patch_operations): def update_stage(self, function_id, stage_name, patch_operations):
@ -550,21 +599,33 @@ class APIGatewayBackend(BaseBackend):
method_response = method.get_response(response_code) method_response = method.get_response(response_code)
return method_response return method_response
def create_method_response(self, function_id, resource_id, method_type, response_code): def create_method_response(
self, function_id, resource_id, method_type, response_code
):
method = self.get_method(function_id, resource_id, method_type) method = self.get_method(function_id, resource_id, method_type)
method_response = method.create_response(response_code) method_response = method.create_response(response_code)
return method_response return method_response
def delete_method_response(self, function_id, resource_id, method_type, response_code): def delete_method_response(
self, function_id, resource_id, method_type, response_code
):
method = self.get_method(function_id, resource_id, method_type) method = self.get_method(function_id, resource_id, method_type)
method_response = method.delete_response(response_code) method_response = method.delete_response(response_code)
return method_response return method_response
def create_integration(self, function_id, resource_id, method_type, integration_type, uri, def create_integration(
request_templates=None): self,
function_id,
resource_id,
method_type,
integration_type,
uri,
request_templates=None,
):
resource = self.get_resource(function_id, resource_id) resource = self.get_resource(function_id, resource_id)
integration = resource.add_integration(method_type, integration_type, uri, integration = resource.add_integration(
request_templates=request_templates) method_type, integration_type, uri, request_templates=request_templates
)
return integration return integration
def get_integration(self, function_id, resource_id, method_type): def get_integration(self, function_id, resource_id, method_type):
@ -575,28 +636,32 @@ class APIGatewayBackend(BaseBackend):
resource = self.get_resource(function_id, resource_id) resource = self.get_resource(function_id, resource_id)
return resource.delete_integration(method_type) return resource.delete_integration(method_type)
def create_integration_response(self, function_id, resource_id, method_type, status_code, selection_pattern): def create_integration_response(
integration = self.get_integration( self, function_id, resource_id, method_type, status_code, selection_pattern
function_id, resource_id, method_type) ):
integration = self.get_integration(function_id, resource_id, method_type)
integration_response = integration.create_integration_response( integration_response = integration.create_integration_response(
status_code, selection_pattern) status_code, selection_pattern
)
return integration_response return integration_response
def get_integration_response(self, function_id, resource_id, method_type, status_code): def get_integration_response(
integration = self.get_integration( self, function_id, resource_id, method_type, status_code
function_id, resource_id, method_type) ):
integration_response = integration.get_integration_response( integration = self.get_integration(function_id, resource_id, method_type)
status_code) integration_response = integration.get_integration_response(status_code)
return integration_response return integration_response
def delete_integration_response(self, function_id, resource_id, method_type, status_code): def delete_integration_response(
integration = self.get_integration( self, function_id, resource_id, method_type, status_code
function_id, resource_id, method_type) ):
integration_response = integration.delete_integration_response( integration = self.get_integration(function_id, resource_id, method_type)
status_code) integration_response = integration.delete_integration_response(status_code)
return integration_response return integration_response
def create_deployment(self, function_id, name, description="", stage_variables=None): def create_deployment(
self, function_id, name, description="", stage_variables=None
):
if stage_variables is None: if stage_variables is None:
stage_variables = {} stage_variables = {}
api = self.get_rest_api(function_id) api = self.get_rest_api(function_id)
@ -617,7 +682,7 @@ class APIGatewayBackend(BaseBackend):
def create_apikey(self, payload): def create_apikey(self, payload):
key = ApiKey(**payload) key = ApiKey(**payload)
self.keys[key['id']] = key self.keys[key["id"]] = key
return key return key
def get_apikeys(self): def get_apikeys(self):
@ -636,7 +701,7 @@ class APIGatewayBackend(BaseBackend):
def create_usage_plan(self, payload): def create_usage_plan(self, payload):
plan = UsagePlan(**payload) plan = UsagePlan(**payload)
self.usage_plans[plan['id']] = plan self.usage_plans[plan["id"]] = plan
return plan return plan
def get_usage_plans(self, api_key_id=None): def get_usage_plans(self, api_key_id=None):
@ -645,7 +710,7 @@ class APIGatewayBackend(BaseBackend):
plans = [ plans = [
plan plan
for plan in plans for plan in plans
if self.usage_plan_keys.get(plan['id'], {}).get(api_key_id, False) if self.usage_plan_keys.get(plan["id"], {}).get(api_key_id, False)
] ]
return plans return plans
@ -666,8 +731,13 @@ class APIGatewayBackend(BaseBackend):
api_key = self.keys[key_id] api_key = self.keys[key_id]
usage_plan_key = UsagePlanKey(id=key_id, type=payload["keyType"], name=api_key["name"], value=api_key["value"]) usage_plan_key = UsagePlanKey(
self.usage_plan_keys[usage_plan_id][usage_plan_key['id']] = usage_plan_key id=key_id,
type=payload["keyType"],
name=api_key["name"],
value=api_key["value"],
)
self.usage_plan_keys[usage_plan_id][usage_plan_key["id"]] = usage_plan_key
return usage_plan_key return usage_plan_key
def get_usage_plan_keys(self, usage_plan_id): def get_usage_plan_keys(self, usage_plan_id):
@ -685,5 +755,5 @@ class APIGatewayBackend(BaseBackend):
apigateway_backends = {} apigateway_backends = {}
for region_name in Session().get_available_regions('apigateway'): for region_name in Session().get_available_regions("apigateway"):
apigateway_backends[region_name] = APIGatewayBackend(region_name) apigateway_backends[region_name] = APIGatewayBackend(region_name)

View File

@ -8,7 +8,6 @@ from .exceptions import StageNotFoundException, ApiKeyNotFoundException
class APIGatewayResponse(BaseResponse): class APIGatewayResponse(BaseResponse):
def _get_param(self, key): def _get_param(self, key):
return json.loads(self.body).get(key) return json.loads(self.body).get(key)
@ -27,14 +26,12 @@ class APIGatewayResponse(BaseResponse):
def restapis(self, request, full_url, headers): def restapis(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == 'GET': if self.method == "GET":
apis = self.backend.list_apis() apis = self.backend.list_apis()
return 200, {}, json.dumps({"item": [ return 200, {}, json.dumps({"item": [api.to_dict() for api in apis]})
api.to_dict() for api in apis elif self.method == "POST":
]}) name = self._get_param("name")
elif self.method == 'POST': description = self._get_param("description")
name = self._get_param('name')
description = self._get_param('description')
rest_api = self.backend.create_rest_api(name, description) rest_api = self.backend.create_rest_api(name, description)
return 200, {}, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
@ -42,10 +39,10 @@ class APIGatewayResponse(BaseResponse):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
if self.method == 'GET': if self.method == "GET":
rest_api = self.backend.get_rest_api(function_id) rest_api = self.backend.get_rest_api(function_id)
return 200, {}, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
elif self.method == 'DELETE': elif self.method == "DELETE":
rest_api = self.backend.delete_rest_api(function_id) rest_api = self.backend.delete_rest_api(function_id)
return 200, {}, json.dumps(rest_api.to_dict()) return 200, {}, json.dumps(rest_api.to_dict())
@ -53,24 +50,25 @@ class APIGatewayResponse(BaseResponse):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
if self.method == 'GET': if self.method == "GET":
resources = self.backend.list_resources(function_id) resources = self.backend.list_resources(function_id)
return 200, {}, json.dumps({"item": [ return (
resource.to_dict() for resource in resources 200,
]}) {},
json.dumps({"item": [resource.to_dict() for resource in resources]}),
)
def resource_individual(self, request, full_url, headers): def resource_individual(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
resource_id = self.path.split("/")[-1] resource_id = self.path.split("/")[-1]
if self.method == 'GET': if self.method == "GET":
resource = self.backend.get_resource(function_id, resource_id) resource = self.backend.get_resource(function_id, resource_id)
elif self.method == 'POST': elif self.method == "POST":
path_part = self._get_param("pathPart") path_part = self._get_param("pathPart")
resource = self.backend.create_resource( resource = self.backend.create_resource(function_id, resource_id, path_part)
function_id, resource_id, path_part) elif self.method == "DELETE":
elif self.method == 'DELETE':
resource = self.backend.delete_resource(function_id, resource_id) resource = self.backend.delete_resource(function_id, resource_id)
return 200, {}, json.dumps(resource.to_dict()) return 200, {}, json.dumps(resource.to_dict())
@ -81,14 +79,14 @@ class APIGatewayResponse(BaseResponse):
resource_id = url_path_parts[4] resource_id = url_path_parts[4]
method_type = url_path_parts[6] method_type = url_path_parts[6]
if self.method == 'GET': if self.method == "GET":
method = self.backend.get_method( method = self.backend.get_method(function_id, resource_id, method_type)
function_id, resource_id, method_type)
return 200, {}, json.dumps(method) return 200, {}, json.dumps(method)
elif self.method == 'PUT': elif self.method == "PUT":
authorization_type = self._get_param("authorizationType") authorization_type = self._get_param("authorizationType")
method = self.backend.create_method( method = self.backend.create_method(
function_id, resource_id, method_type, authorization_type) function_id, resource_id, method_type, authorization_type
)
return 200, {}, json.dumps(method) return 200, {}, json.dumps(method)
def resource_method_responses(self, request, full_url, headers): def resource_method_responses(self, request, full_url, headers):
@ -99,15 +97,18 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6] method_type = url_path_parts[6]
response_code = url_path_parts[8] response_code = url_path_parts[8]
if self.method == 'GET': if self.method == "GET":
method_response = self.backend.get_method_response( method_response = self.backend.get_method_response(
function_id, resource_id, method_type, response_code) function_id, resource_id, method_type, response_code
elif self.method == 'PUT': )
elif self.method == "PUT":
method_response = self.backend.create_method_response( method_response = self.backend.create_method_response(
function_id, resource_id, method_type, response_code) function_id, resource_id, method_type, response_code
elif self.method == 'DELETE': )
elif self.method == "DELETE":
method_response = self.backend.delete_method_response( method_response = self.backend.delete_method_response(
function_id, resource_id, method_type, response_code) function_id, resource_id, method_type, response_code
)
return 200, {}, json.dumps(method_response) return 200, {}, json.dumps(method_response)
def restapis_stages(self, request, full_url, headers): def restapis_stages(self, request, full_url, headers):
@ -115,21 +116,28 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
function_id = url_path_parts[2] function_id = url_path_parts[2]
if self.method == 'POST': if self.method == "POST":
stage_name = self._get_param("stageName") stage_name = self._get_param("stageName")
deployment_id = self._get_param("deploymentId") deployment_id = self._get_param("deploymentId")
stage_variables = self._get_param_with_default_value( stage_variables = self._get_param_with_default_value("variables", {})
'variables', {}) description = self._get_param_with_default_value("description", "")
description = self._get_param_with_default_value('description', '')
cacheClusterEnabled = self._get_param_with_default_value( cacheClusterEnabled = self._get_param_with_default_value(
'cacheClusterEnabled', False) "cacheClusterEnabled", False
)
cacheClusterSize = self._get_param_with_default_value( cacheClusterSize = self._get_param_with_default_value(
'cacheClusterSize', None) "cacheClusterSize", None
)
stage_response = self.backend.create_stage(function_id, stage_name, deployment_id, stage_response = self.backend.create_stage(
variables=stage_variables, description=description, function_id,
cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize) stage_name,
elif self.method == 'GET': deployment_id,
variables=stage_variables,
description=description,
cacheClusterEnabled=cacheClusterEnabled,
cacheClusterSize=cacheClusterSize,
)
elif self.method == "GET":
stages = self.backend.get_stages(function_id) stages = self.backend.get_stages(function_id)
return 200, {}, json.dumps({"item": stages}) return 200, {}, json.dumps({"item": stages})
@ -141,16 +149,22 @@ class APIGatewayResponse(BaseResponse):
function_id = url_path_parts[2] function_id = url_path_parts[2]
stage_name = url_path_parts[4] stage_name = url_path_parts[4]
if self.method == 'GET': if self.method == "GET":
try: try:
stage_response = self.backend.get_stage( stage_response = self.backend.get_stage(function_id, stage_name)
function_id, stage_name)
except StageNotFoundException as error: except StageNotFoundException as error:
return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type) return (
elif self.method == 'PATCH': error.code,
patch_operations = self._get_param('patchOperations') {},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == "PATCH":
patch_operations = self._get_param("patchOperations")
stage_response = self.backend.update_stage( stage_response = self.backend.update_stage(
function_id, stage_name, patch_operations) function_id, stage_name, patch_operations
)
return 200, {}, json.dumps(stage_response) return 200, {}, json.dumps(stage_response)
def integrations(self, request, full_url, headers): def integrations(self, request, full_url, headers):
@ -160,18 +174,26 @@ class APIGatewayResponse(BaseResponse):
resource_id = url_path_parts[4] resource_id = url_path_parts[4]
method_type = url_path_parts[6] method_type = url_path_parts[6]
if self.method == 'GET': if self.method == "GET":
integration_response = self.backend.get_integration( integration_response = self.backend.get_integration(
function_id, resource_id, method_type) function_id, resource_id, method_type
elif self.method == 'PUT': )
integration_type = self._get_param('type') elif self.method == "PUT":
uri = self._get_param('uri') integration_type = self._get_param("type")
request_templates = self._get_param('requestTemplates') uri = self._get_param("uri")
request_templates = self._get_param("requestTemplates")
integration_response = self.backend.create_integration( integration_response = self.backend.create_integration(
function_id, resource_id, method_type, integration_type, uri, request_templates=request_templates) function_id,
elif self.method == 'DELETE': resource_id,
method_type,
integration_type,
uri,
request_templates=request_templates,
)
elif self.method == "DELETE":
integration_response = self.backend.delete_integration( integration_response = self.backend.delete_integration(
function_id, resource_id, method_type) function_id, resource_id, method_type
)
return 200, {}, json.dumps(integration_response) return 200, {}, json.dumps(integration_response)
def integration_responses(self, request, full_url, headers): def integration_responses(self, request, full_url, headers):
@ -182,16 +204,16 @@ class APIGatewayResponse(BaseResponse):
method_type = url_path_parts[6] method_type = url_path_parts[6]
status_code = url_path_parts[9] status_code = url_path_parts[9]
if self.method == 'GET': if self.method == "GET":
integration_response = self.backend.get_integration_response( integration_response = self.backend.get_integration_response(
function_id, resource_id, method_type, status_code function_id, resource_id, method_type, status_code
) )
elif self.method == 'PUT': elif self.method == "PUT":
selection_pattern = self._get_param("selectionPattern") selection_pattern = self._get_param("selectionPattern")
integration_response = self.backend.create_integration_response( integration_response = self.backend.create_integration_response(
function_id, resource_id, method_type, status_code, selection_pattern function_id, resource_id, method_type, status_code, selection_pattern
) )
elif self.method == 'DELETE': elif self.method == "DELETE":
integration_response = self.backend.delete_integration_response( integration_response = self.backend.delete_integration_response(
function_id, resource_id, method_type, status_code function_id, resource_id, method_type, status_code
) )
@ -201,16 +223,16 @@ class APIGatewayResponse(BaseResponse):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
function_id = self.path.replace("/restapis/", "", 1).split("/")[0] function_id = self.path.replace("/restapis/", "", 1).split("/")[0]
if self.method == 'GET': if self.method == "GET":
deployments = self.backend.get_deployments(function_id) deployments = self.backend.get_deployments(function_id)
return 200, {}, json.dumps({"item": deployments}) return 200, {}, json.dumps({"item": deployments})
elif self.method == 'POST': elif self.method == "POST":
name = self._get_param("stageName") name = self._get_param("stageName")
description = self._get_param_with_default_value("description", "") description = self._get_param_with_default_value("description", "")
stage_variables = self._get_param_with_default_value( stage_variables = self._get_param_with_default_value("variables", {})
'variables', {})
deployment = self.backend.create_deployment( deployment = self.backend.create_deployment(
function_id, name, description, stage_variables) function_id, name, description, stage_variables
)
return 200, {}, json.dumps(deployment) return 200, {}, json.dumps(deployment)
def individual_deployment(self, request, full_url, headers): def individual_deployment(self, request, full_url, headers):
@ -219,20 +241,18 @@ class APIGatewayResponse(BaseResponse):
function_id = url_path_parts[2] function_id = url_path_parts[2]
deployment_id = url_path_parts[4] deployment_id = url_path_parts[4]
if self.method == 'GET': if self.method == "GET":
deployment = self.backend.get_deployment( deployment = self.backend.get_deployment(function_id, deployment_id)
function_id, deployment_id) elif self.method == "DELETE":
elif self.method == 'DELETE': deployment = self.backend.delete_deployment(function_id, deployment_id)
deployment = self.backend.delete_deployment(
function_id, deployment_id)
return 200, {}, json.dumps(deployment) return 200, {}, json.dumps(deployment)
def apikeys(self, request, full_url, headers): def apikeys(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == 'POST': if self.method == "POST":
apikey_response = self.backend.create_apikey(json.loads(self.body)) apikey_response = self.backend.create_apikey(json.loads(self.body))
elif self.method == 'GET': elif self.method == "GET":
apikeys_response = self.backend.get_apikeys() apikeys_response = self.backend.get_apikeys()
return 200, {}, json.dumps({"item": apikeys_response}) return 200, {}, json.dumps({"item": apikeys_response})
return 200, {}, json.dumps(apikey_response) return 200, {}, json.dumps(apikey_response)
@ -243,21 +263,21 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
apikey = url_path_parts[2] apikey = url_path_parts[2]
if self.method == 'GET': if self.method == "GET":
apikey_response = self.backend.get_apikey(apikey) apikey_response = self.backend.get_apikey(apikey)
elif self.method == 'PATCH': elif self.method == "PATCH":
patch_operations = self._get_param('patchOperations') patch_operations = self._get_param("patchOperations")
apikey_response = self.backend.update_apikey(apikey, patch_operations) apikey_response = self.backend.update_apikey(apikey, patch_operations)
elif self.method == 'DELETE': elif self.method == "DELETE":
apikey_response = self.backend.delete_apikey(apikey) apikey_response = self.backend.delete_apikey(apikey)
return 200, {}, json.dumps(apikey_response) return 200, {}, json.dumps(apikey_response)
def usage_plans(self, request, full_url, headers): def usage_plans(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if self.method == 'POST': if self.method == "POST":
usage_plan_response = self.backend.create_usage_plan(json.loads(self.body)) usage_plan_response = self.backend.create_usage_plan(json.loads(self.body))
elif self.method == 'GET': elif self.method == "GET":
api_key_id = self.querystring.get("keyId", [None])[0] api_key_id = self.querystring.get("keyId", [None])[0]
usage_plans_response = self.backend.get_usage_plans(api_key_id=api_key_id) usage_plans_response = self.backend.get_usage_plans(api_key_id=api_key_id)
return 200, {}, json.dumps({"item": usage_plans_response}) return 200, {}, json.dumps({"item": usage_plans_response})
@ -269,9 +289,9 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
usage_plan = url_path_parts[2] usage_plan = url_path_parts[2]
if self.method == 'GET': if self.method == "GET":
usage_plan_response = self.backend.get_usage_plan(usage_plan) usage_plan_response = self.backend.get_usage_plan(usage_plan)
elif self.method == 'DELETE': elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan(usage_plan) usage_plan_response = self.backend.delete_usage_plan(usage_plan)
return 200, {}, json.dumps(usage_plan_response) return 200, {}, json.dumps(usage_plan_response)
@ -281,13 +301,21 @@ class APIGatewayResponse(BaseResponse):
url_path_parts = self.path.split("/") url_path_parts = self.path.split("/")
usage_plan_id = url_path_parts[2] usage_plan_id = url_path_parts[2]
if self.method == 'POST': if self.method == "POST":
try: try:
usage_plan_response = self.backend.create_usage_plan_key(usage_plan_id, json.loads(self.body)) usage_plan_response = self.backend.create_usage_plan_key(
usage_plan_id, json.loads(self.body)
)
except ApiKeyNotFoundException as error: except ApiKeyNotFoundException as error:
return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type) return (
error.code,
{},
'{{"message":"{0}","code":"{1}"}}'.format(
error.message, error.error_type
),
)
elif self.method == 'GET': elif self.method == "GET":
usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id) usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id)
return 200, {}, json.dumps({"item": usage_plans_response}) return 200, {}, json.dumps({"item": usage_plans_response})
@ -300,8 +328,10 @@ class APIGatewayResponse(BaseResponse):
usage_plan_id = url_path_parts[2] usage_plan_id = url_path_parts[2]
key_id = url_path_parts[4] key_id = url_path_parts[4]
if self.method == 'GET': if self.method == "GET":
usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id) usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id)
elif self.method == 'DELETE': elif self.method == "DELETE":
usage_plan_response = self.backend.delete_usage_plan_key(usage_plan_id, key_id) usage_plan_response = self.backend.delete_usage_plan_key(
usage_plan_id, key_id
)
return 200, {}, json.dumps(usage_plan_response) return 200, {}, json.dumps(usage_plan_response)

View File

@ -1,27 +1,25 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import APIGatewayResponse from .responses import APIGatewayResponse
url_bases = [ url_bases = ["https?://apigateway.(.+).amazonaws.com"]
"https?://apigateway.(.+).amazonaws.com"
]
url_paths = { url_paths = {
'{0}/restapis$': APIGatewayResponse().restapis, "{0}/restapis$": APIGatewayResponse().restapis,
'{0}/restapis/(?P<function_id>[^/]+)/?$': APIGatewayResponse().restapis_individual, "{0}/restapis/(?P<function_id>[^/]+)/?$": APIGatewayResponse().restapis_individual,
'{0}/restapis/(?P<function_id>[^/]+)/resources$': APIGatewayResponse().resources, "{0}/restapis/(?P<function_id>[^/]+)/resources$": APIGatewayResponse().resources,
'{0}/restapis/(?P<function_id>[^/]+)/stages$': APIGatewayResponse().restapis_stages, "{0}/restapis/(?P<function_id>[^/]+)/stages$": APIGatewayResponse().restapis_stages,
'{0}/restapis/(?P<function_id>[^/]+)/stages/(?P<stage_name>[^/]+)/?$': APIGatewayResponse().stages, "{0}/restapis/(?P<function_id>[^/]+)/stages/(?P<stage_name>[^/]+)/?$": APIGatewayResponse().stages,
'{0}/restapis/(?P<function_id>[^/]+)/deployments$': APIGatewayResponse().deployments, "{0}/restapis/(?P<function_id>[^/]+)/deployments$": APIGatewayResponse().deployments,
'{0}/restapis/(?P<function_id>[^/]+)/deployments/(?P<deployment_id>[^/]+)/?$': APIGatewayResponse().individual_deployment, "{0}/restapis/(?P<function_id>[^/]+)/deployments/(?P<deployment_id>[^/]+)/?$": APIGatewayResponse().individual_deployment,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/?$': APIGatewayResponse().resource_individual, "{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/?$": APIGatewayResponse().resource_individual,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/?$': APIGatewayResponse().resource_methods, "{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/?$": APIGatewayResponse().resource_methods,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/responses/(?P<status_code>\d+)$': APIGatewayResponse().resource_method_responses, "{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/responses/(?P<status_code>\d+)$": APIGatewayResponse().resource_method_responses,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/?$': APIGatewayResponse().integrations, "{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/?$": APIGatewayResponse().integrations,
'{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/responses/(?P<status_code>\d+)/?$': APIGatewayResponse().integration_responses, "{0}/restapis/(?P<function_id>[^/]+)/resources/(?P<resource_id>[^/]+)/methods/(?P<method_name>[^/]+)/integration/responses/(?P<status_code>\d+)/?$": APIGatewayResponse().integration_responses,
'{0}/apikeys$': APIGatewayResponse().apikeys, "{0}/apikeys$": APIGatewayResponse().apikeys,
'{0}/apikeys/(?P<apikey>[^/]+)': APIGatewayResponse().apikey_individual, "{0}/apikeys/(?P<apikey>[^/]+)": APIGatewayResponse().apikey_individual,
'{0}/usageplans$': APIGatewayResponse().usage_plans, "{0}/usageplans$": APIGatewayResponse().usage_plans,
'{0}/usageplans/(?P<usage_plan_id>[^/]+)/?$': APIGatewayResponse().usage_plan_individual, "{0}/usageplans/(?P<usage_plan_id>[^/]+)/?$": APIGatewayResponse().usage_plan_individual,
'{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys$': APIGatewayResponse().usage_plan_keys, "{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys$": APIGatewayResponse().usage_plan_keys,
'{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$': APIGatewayResponse().usage_plan_key_individual, "{0}/usageplans/(?P<usage_plan_id>[^/]+)/keys/(?P<api_key_id>[^/]+)/?$": APIGatewayResponse().usage_plan_key_individual,
} }

View File

@ -7,4 +7,4 @@ import string
def create_id(): def create_id():
size = 10 size = 10
chars = list(range(10)) + list(string.ascii_lowercase) chars = list(range(10)) + list(string.ascii_lowercase)
return ''.join(six.text_type(random.choice(chars)) for x in range(size)) return "".join(six.text_type(random.choice(chars)) for x in range(size))

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import athena_backends from .models import athena_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
athena_backend = athena_backends['us-east-1'] athena_backend = athena_backends["us-east-1"]
mock_athena = base_decorator(athena_backends) mock_athena = base_decorator(athena_backends)
mock_athena_deprecated = deprecated_base_decorator(athena_backends) mock_athena_deprecated = deprecated_base_decorator(athena_backends)

View File

@ -5,14 +5,15 @@ from werkzeug.exceptions import BadRequest
class AthenaClientError(BadRequest): class AthenaClientError(BadRequest):
def __init__(self, code, message): def __init__(self, code, message):
super(AthenaClientError, self).__init__() super(AthenaClientError, self).__init__()
self.description = json.dumps({ self.description = json.dumps(
{
"Error": { "Error": {
"Code": code, "Code": code,
"Message": message, "Message": message,
'Type': "InvalidRequestException", "Type": "InvalidRequestException",
}, },
'RequestId': '6876f774-7273-11e4-85dc-39e55ca848d1', "RequestId": "6876f774-7273-11e4-85dc-39e55ca848d1",
}) }
)

View File

@ -19,31 +19,30 @@ class TaggableResourceMixin(object):
@property @property
def arn(self): def arn(self):
return "arn:aws:athena:{region}:{account_id}:{resource_name}".format( return "arn:aws:athena:{region}:{account_id}:{resource_name}".format(
region=self.region, region=self.region, account_id=ACCOUNT_ID, resource_name=self.resource_name
account_id=ACCOUNT_ID, )
resource_name=self.resource_name)
def create_tags(self, tags): def create_tags(self, tags):
new_keys = [tag_set['Key'] for tag_set in tags] new_keys = [tag_set["Key"] for tag_set in tags]
self.tags = [tag_set for tag_set in self.tags self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys]
if tag_set['Key'] not in new_keys]
self.tags.extend(tags) self.tags.extend(tags)
return self.tags return self.tags
def delete_tags(self, tag_keys): def delete_tags(self, tag_keys):
self.tags = [tag_set for tag_set in self.tags self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys]
if tag_set['Key'] not in tag_keys]
return self.tags return self.tags
class WorkGroup(TaggableResourceMixin, BaseModel): class WorkGroup(TaggableResourceMixin, BaseModel):
resource_type = 'workgroup' resource_type = "workgroup"
state = 'ENABLED' state = "ENABLED"
def __init__(self, athena_backend, name, configuration, description, tags): def __init__(self, athena_backend, name, configuration, description, tags):
self.region_name = athena_backend.region_name self.region_name = athena_backend.region_name
super(WorkGroup, self).__init__(self.region_name, "workgroup/{}".format(name), tags) super(WorkGroup, self).__init__(
self.region_name, "workgroup/{}".format(name), tags
)
self.athena_backend = athena_backend self.athena_backend = athena_backend
self.name = name self.name = name
self.description = description self.description = description
@ -66,14 +65,17 @@ class AthenaBackend(BaseBackend):
return work_group return work_group
def list_work_groups(self): def list_work_groups(self):
return [{ return [
'Name': wg.name, {
'State': wg.state, "Name": wg.name,
'Description': wg.description, "State": wg.state,
'CreationTime': time.time(), "Description": wg.description,
} for wg in self.work_groups.values()] "CreationTime": time.time(),
}
for wg in self.work_groups.values()
]
athena_backends = {} athena_backends = {}
for region in boto3.Session().get_available_regions('athena'): for region in boto3.Session().get_available_regions("athena"):
athena_backends[region] = AthenaBackend(region) athena_backends[region] = AthenaBackend(region)

View File

@ -5,31 +5,37 @@ from .models import athena_backends
class AthenaResponse(BaseResponse): class AthenaResponse(BaseResponse):
@property @property
def athena_backend(self): def athena_backend(self):
return athena_backends[self.region] return athena_backends[self.region]
def create_work_group(self): def create_work_group(self):
name = self._get_param('Name') name = self._get_param("Name")
description = self._get_param('Description') description = self._get_param("Description")
configuration = self._get_param('Configuration') configuration = self._get_param("Configuration")
tags = self._get_param('Tags') tags = self._get_param("Tags")
work_group = self.athena_backend.create_work_group(name, configuration, description, tags) work_group = self.athena_backend.create_work_group(
name, configuration, description, tags
)
if not work_group: if not work_group:
return json.dumps({ return (
'__type': 'InvalidRequestException', json.dumps(
'Message': 'WorkGroup already exists', {
}), dict(status=400) "__type": "InvalidRequestException",
return json.dumps({ "Message": "WorkGroup already exists",
}
),
dict(status=400),
)
return json.dumps(
{
"CreateWorkGroupResponse": { "CreateWorkGroupResponse": {
"ResponseMetadata": { "ResponseMetadata": {
"RequestId": "384ac68d-3775-11df-8963-01868b7c937a", "RequestId": "384ac68d-3775-11df-8963-01868b7c937a"
} }
} }
}) }
)
def list_work_groups(self): def list_work_groups(self):
return json.dumps({ return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()})
"WorkGroups": self.athena_backend.list_work_groups()
})

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import AthenaResponse from .responses import AthenaResponse
url_bases = [ url_bases = ["https?://athena.(.+).amazonaws.com"]
"https?://athena.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": AthenaResponse.dispatch}
'{0}/$': AthenaResponse.dispatch,
}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import autoscaling_backends from .models import autoscaling_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
autoscaling_backend = autoscaling_backends['us-east-1'] autoscaling_backend = autoscaling_backends["us-east-1"]
mock_autoscaling = base_decorator(autoscaling_backends) mock_autoscaling = base_decorator(autoscaling_backends)
mock_autoscaling_deprecated = deprecated_base_decorator(autoscaling_backends) mock_autoscaling_deprecated = deprecated_base_decorator(autoscaling_backends)

View File

@ -12,13 +12,12 @@ class ResourceContentionError(RESTError):
def __init__(self): def __init__(self):
super(ResourceContentionError, self).__init__( super(ResourceContentionError, self).__init__(
"ResourceContentionError", "ResourceContentionError",
"You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).") "You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).",
)
class InvalidInstanceError(AutoscalingClientError): class InvalidInstanceError(AutoscalingClientError):
def __init__(self, instance_id): def __init__(self, instance_id):
super(InvalidInstanceError, self).__init__( super(InvalidInstanceError, self).__init__(
"ValidationError", "ValidationError", "Instance [{0}] is invalid.".format(instance_id)
"Instance [{0}] is invalid." )
.format(instance_id))

View File

@ -12,7 +12,9 @@ from moto.elb import elb_backends
from moto.elbv2 import elbv2_backends from moto.elbv2 import elbv2_backends
from moto.elb.exceptions import LoadBalancerNotFoundError from moto.elb.exceptions import LoadBalancerNotFoundError
from .exceptions import ( from .exceptions import (
AutoscalingClientError, ResourceContentionError, InvalidInstanceError AutoscalingClientError,
ResourceContentionError,
InvalidInstanceError,
) )
# http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown # http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown
@ -22,8 +24,13 @@ ASG_NAME_TAG = "aws:autoscaling:groupName"
class InstanceState(object): class InstanceState(object):
def __init__(self, instance, lifecycle_state="InService", def __init__(
health_status="Healthy", protected_from_scale_in=False): self,
instance,
lifecycle_state="InService",
health_status="Healthy",
protected_from_scale_in=False,
):
self.instance = instance self.instance = instance
self.lifecycle_state = lifecycle_state self.lifecycle_state = lifecycle_state
self.health_status = health_status self.health_status = health_status
@ -31,8 +38,16 @@ class InstanceState(object):
class FakeScalingPolicy(BaseModel): class FakeScalingPolicy(BaseModel):
def __init__(self, name, policy_type, adjustment_type, as_name, scaling_adjustment, def __init__(
cooldown, autoscaling_backend): self,
name,
policy_type,
adjustment_type,
as_name,
scaling_adjustment,
cooldown,
autoscaling_backend,
):
self.name = name self.name = name
self.policy_type = policy_type self.policy_type = policy_type
self.adjustment_type = adjustment_type self.adjustment_type = adjustment_type
@ -45,21 +60,38 @@ class FakeScalingPolicy(BaseModel):
self.autoscaling_backend = autoscaling_backend self.autoscaling_backend = autoscaling_backend
def execute(self): def execute(self):
if self.adjustment_type == 'ExactCapacity': if self.adjustment_type == "ExactCapacity":
self.autoscaling_backend.set_desired_capacity( self.autoscaling_backend.set_desired_capacity(
self.as_name, self.scaling_adjustment) self.as_name, self.scaling_adjustment
elif self.adjustment_type == 'ChangeInCapacity': )
elif self.adjustment_type == "ChangeInCapacity":
self.autoscaling_backend.change_capacity( self.autoscaling_backend.change_capacity(
self.as_name, self.scaling_adjustment) self.as_name, self.scaling_adjustment
elif self.adjustment_type == 'PercentChangeInCapacity': )
elif self.adjustment_type == "PercentChangeInCapacity":
self.autoscaling_backend.change_capacity_percent( self.autoscaling_backend.change_capacity_percent(
self.as_name, self.scaling_adjustment) self.as_name, self.scaling_adjustment
)
class FakeLaunchConfiguration(BaseModel): class FakeLaunchConfiguration(BaseModel):
def __init__(self, name, image_id, key_name, ramdisk_id, kernel_id, security_groups, user_data, def __init__(
instance_type, instance_monitoring, instance_profile_name, self,
spot_price, ebs_optimized, associate_public_ip_address, block_device_mapping_dict): name,
image_id,
key_name,
ramdisk_id,
kernel_id,
security_groups,
user_data,
instance_type,
instance_monitoring,
instance_profile_name,
spot_price,
ebs_optimized,
associate_public_ip_address,
block_device_mapping_dict,
):
self.name = name self.name = name
self.image_id = image_id self.image_id = image_id
self.key_name = key_name self.key_name = key_name
@ -80,8 +112,8 @@ class FakeLaunchConfiguration(BaseModel):
config = backend.create_launch_configuration( config = backend.create_launch_configuration(
name=name, name=name,
image_id=instance.image_id, image_id=instance.image_id,
kernel_id='', kernel_id="",
ramdisk_id='', ramdisk_id="",
key_name=instance.key_name, key_name=instance.key_name,
security_groups=instance.security_groups, security_groups=instance.security_groups,
user_data=instance.user_data, user_data=instance.user_data,
@ -91,13 +123,15 @@ class FakeLaunchConfiguration(BaseModel):
spot_price=None, spot_price=None,
ebs_optimized=instance.ebs_optimized, ebs_optimized=instance.ebs_optimized,
associate_public_ip_address=instance.associate_public_ip, associate_public_ip_address=instance.associate_public_ip,
block_device_mappings=instance.block_device_mapping block_device_mappings=instance.block_device_mapping,
) )
return config return config
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(
properties = cloudformation_json['Properties'] cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
instance_profile_name = properties.get("IamInstanceProfile") instance_profile_name = properties.get("IamInstanceProfile")
@ -115,20 +149,26 @@ class FakeLaunchConfiguration(BaseModel):
instance_profile_name=instance_profile_name, instance_profile_name=instance_profile_name,
spot_price=properties.get("SpotPrice"), spot_price=properties.get("SpotPrice"),
ebs_optimized=properties.get("EbsOptimized"), ebs_optimized=properties.get("EbsOptimized"),
associate_public_ip_address=properties.get( associate_public_ip_address=properties.get("AssociatePublicIpAddress"),
"AssociatePublicIpAddress"), block_device_mappings=properties.get("BlockDeviceMapping.member"),
block_device_mappings=properties.get("BlockDeviceMapping.member")
) )
return config return config
@classmethod @classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
):
cls.delete_from_cloudformation_json( cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name) original_resource.name, cloudformation_json, region_name
return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) )
return cls.create_from_cloudformation_json(
new_resource_name, cloudformation_json, region_name
)
@classmethod @classmethod
def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
backend = autoscaling_backends[region_name] backend = autoscaling_backends[region_name]
try: try:
backend.delete_launch_configuration(resource_name) backend.delete_launch_configuration(resource_name)
@ -153,34 +193,49 @@ class FakeLaunchConfiguration(BaseModel):
@property @property
def instance_monitoring_enabled(self): def instance_monitoring_enabled(self):
if self.instance_monitoring: if self.instance_monitoring:
return 'true' return "true"
return 'false' return "false"
def _parse_block_device_mappings(self): def _parse_block_device_mappings(self):
block_device_map = BlockDeviceMapping() block_device_map = BlockDeviceMapping()
for mapping in self.block_device_mapping_dict: for mapping in self.block_device_mapping_dict:
block_type = BlockDeviceType() block_type = BlockDeviceType()
mount_point = mapping.get('device_name') mount_point = mapping.get("device_name")
if 'ephemeral' in mapping.get('virtual_name', ''): if "ephemeral" in mapping.get("virtual_name", ""):
block_type.ephemeral_name = mapping.get('virtual_name') block_type.ephemeral_name = mapping.get("virtual_name")
else: else:
block_type.volume_type = mapping.get('ebs._volume_type') block_type.volume_type = mapping.get("ebs._volume_type")
block_type.snapshot_id = mapping.get('ebs._snapshot_id') block_type.snapshot_id = mapping.get("ebs._snapshot_id")
block_type.delete_on_termination = mapping.get( block_type.delete_on_termination = mapping.get(
'ebs._delete_on_termination') "ebs._delete_on_termination"
block_type.size = mapping.get('ebs._volume_size') )
block_type.iops = mapping.get('ebs._iops') block_type.size = mapping.get("ebs._volume_size")
block_type.iops = mapping.get("ebs._iops")
block_device_map[mount_point] = block_type block_device_map[mount_point] = block_type
return block_device_map return block_device_map
class FakeAutoScalingGroup(BaseModel): class FakeAutoScalingGroup(BaseModel):
def __init__(self, name, availability_zones, desired_capacity, max_size, def __init__(
min_size, launch_config_name, vpc_zone_identifier, self,
default_cooldown, health_check_period, health_check_type, name,
load_balancers, target_group_arns, placement_group, termination_policies, availability_zones,
autoscaling_backend, tags, desired_capacity,
new_instances_protected_from_scale_in=False): max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
load_balancers,
target_group_arns,
placement_group,
termination_policies,
autoscaling_backend,
tags,
new_instances_protected_from_scale_in=False,
):
self.autoscaling_backend = autoscaling_backend self.autoscaling_backend = autoscaling_backend
self.name = name self.name = name
@ -190,17 +245,22 @@ class FakeAutoScalingGroup(BaseModel):
self.min_size = min_size self.min_size = min_size
self.launch_config = self.autoscaling_backend.launch_configurations[ self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name] launch_config_name
]
self.launch_config_name = launch_config_name self.launch_config_name = launch_config_name
self.default_cooldown = default_cooldown if default_cooldown else DEFAULT_COOLDOWN self.default_cooldown = (
default_cooldown if default_cooldown else DEFAULT_COOLDOWN
)
self.health_check_period = health_check_period self.health_check_period = health_check_period
self.health_check_type = health_check_type if health_check_type else "EC2" self.health_check_type = health_check_type if health_check_type else "EC2"
self.load_balancers = load_balancers self.load_balancers = load_balancers
self.target_group_arns = target_group_arns self.target_group_arns = target_group_arns
self.placement_group = placement_group self.placement_group = placement_group
self.termination_policies = termination_policies self.termination_policies = termination_policies
self.new_instances_protected_from_scale_in = new_instances_protected_from_scale_in self.new_instances_protected_from_scale_in = (
new_instances_protected_from_scale_in
)
self.suspended_processes = [] self.suspended_processes = []
self.instance_states = [] self.instance_states = []
@ -215,8 +275,10 @@ class FakeAutoScalingGroup(BaseModel):
if vpc_zone_identifier: if vpc_zone_identifier:
# extract azs for vpcs # extract azs for vpcs
subnet_ids = vpc_zone_identifier.split(',') subnet_ids = vpc_zone_identifier.split(",")
subnets = self.autoscaling_backend.ec2_backend.get_all_subnets(subnet_ids=subnet_ids) subnets = self.autoscaling_backend.ec2_backend.get_all_subnets(
subnet_ids=subnet_ids
)
vpc_zones = [subnet.availability_zone for subnet in subnets] vpc_zones = [subnet.availability_zone for subnet in subnets]
if availability_zones and set(availability_zones) != set(vpc_zones): if availability_zones and set(availability_zones) != set(vpc_zones):
@ -229,7 +291,7 @@ class FakeAutoScalingGroup(BaseModel):
if not update: if not update:
raise AutoscalingClientError( raise AutoscalingClientError(
"ValidationError", "ValidationError",
"At least one Availability Zone or VPC Subnet is required." "At least one Availability Zone or VPC Subnet is required.",
) )
return return
@ -237,8 +299,10 @@ class FakeAutoScalingGroup(BaseModel):
self.vpc_zone_identifier = vpc_zone_identifier self.vpc_zone_identifier = vpc_zone_identifier
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(
properties = cloudformation_json['Properties'] cls, resource_name, cloudformation_json, region_name
):
properties = cloudformation_json["Properties"]
launch_config_name = properties.get("LaunchConfigurationName") launch_config_name = properties.get("LaunchConfigurationName")
load_balancer_names = properties.get("LoadBalancerNames", []) load_balancer_names = properties.get("LoadBalancerNames", [])
@ -253,7 +317,8 @@ class FakeAutoScalingGroup(BaseModel):
min_size=properties.get("MinSize"), min_size=properties.get("MinSize"),
launch_config_name=launch_config_name, launch_config_name=launch_config_name,
vpc_zone_identifier=( vpc_zone_identifier=(
','.join(properties.get("VPCZoneIdentifier", [])) or None), ",".join(properties.get("VPCZoneIdentifier", [])) or None
),
default_cooldown=properties.get("Cooldown"), default_cooldown=properties.get("Cooldown"),
health_check_period=properties.get("HealthCheckGracePeriod"), health_check_period=properties.get("HealthCheckGracePeriod"),
health_check_type=properties.get("HealthCheckType"), health_check_type=properties.get("HealthCheckType"),
@ -263,18 +328,26 @@ class FakeAutoScalingGroup(BaseModel):
termination_policies=properties.get("TerminationPolicies", []), termination_policies=properties.get("TerminationPolicies", []),
tags=properties.get("Tags", []), tags=properties.get("Tags", []),
new_instances_protected_from_scale_in=properties.get( new_instances_protected_from_scale_in=properties.get(
"NewInstancesProtectedFromScaleIn", False) "NewInstancesProtectedFromScaleIn", False
),
) )
return group return group
@classmethod @classmethod
def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): def update_from_cloudformation_json(
cls, original_resource, new_resource_name, cloudformation_json, region_name
):
cls.delete_from_cloudformation_json( cls.delete_from_cloudformation_json(
original_resource.name, cloudformation_json, region_name) original_resource.name, cloudformation_json, region_name
return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) )
return cls.create_from_cloudformation_json(
new_resource_name, cloudformation_json, region_name
)
@classmethod @classmethod
def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def delete_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
backend = autoscaling_backends[region_name] backend = autoscaling_backends[region_name]
try: try:
backend.delete_auto_scaling_group(resource_name) backend.delete_auto_scaling_group(resource_name)
@ -289,11 +362,21 @@ class FakeAutoScalingGroup(BaseModel):
def physical_resource_id(self): def physical_resource_id(self):
return self.name return self.name
def update(self, availability_zones, desired_capacity, max_size, min_size, def update(
launch_config_name, vpc_zone_identifier, default_cooldown, self,
health_check_period, health_check_type, availability_zones,
placement_group, termination_policies, desired_capacity,
new_instances_protected_from_scale_in=None): max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
new_instances_protected_from_scale_in=None,
):
self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier, update=True) self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier, update=True)
if max_size is not None: if max_size is not None:
@ -309,14 +392,17 @@ class FakeAutoScalingGroup(BaseModel):
if launch_config_name: if launch_config_name:
self.launch_config = self.autoscaling_backend.launch_configurations[ self.launch_config = self.autoscaling_backend.launch_configurations[
launch_config_name] launch_config_name
]
self.launch_config_name = launch_config_name self.launch_config_name = launch_config_name
if health_check_period is not None: if health_check_period is not None:
self.health_check_period = health_check_period self.health_check_period = health_check_period
if health_check_type is not None: if health_check_type is not None:
self.health_check_type = health_check_type self.health_check_type = health_check_type
if new_instances_protected_from_scale_in is not None: if new_instances_protected_from_scale_in is not None:
self.new_instances_protected_from_scale_in = new_instances_protected_from_scale_in self.new_instances_protected_from_scale_in = (
new_instances_protected_from_scale_in
)
if desired_capacity is not None: if desired_capacity is not None:
self.set_desired_capacity(desired_capacity) self.set_desired_capacity(desired_capacity)
@ -342,25 +428,30 @@ class FakeAutoScalingGroup(BaseModel):
# Need to remove some instances # Need to remove some instances
count_to_remove = curr_instance_count - self.desired_capacity count_to_remove = curr_instance_count - self.desired_capacity
instances_to_remove = [ # only remove unprotected instances_to_remove = [ # only remove unprotected
state for state in self.instance_states state
for state in self.instance_states
if not state.protected_from_scale_in if not state.protected_from_scale_in
][:count_to_remove] ][:count_to_remove]
if instances_to_remove: # just in case not instances to remove if instances_to_remove: # just in case not instances to remove
instance_ids_to_remove = [ instance_ids_to_remove = [
instance.instance.id for instance in instances_to_remove] instance.instance.id for instance in instances_to_remove
]
self.autoscaling_backend.ec2_backend.terminate_instances( self.autoscaling_backend.ec2_backend.terminate_instances(
instance_ids_to_remove) instance_ids_to_remove
self.instance_states = list(set(self.instance_states) - set(instances_to_remove)) )
self.instance_states = list(
set(self.instance_states) - set(instances_to_remove)
)
def get_propagated_tags(self): def get_propagated_tags(self):
propagated_tags = {} propagated_tags = {}
for tag in self.tags: for tag in self.tags:
# boto uses 'propagate_at_launch # boto uses 'propagate_at_launch
# boto3 and cloudformation use PropagateAtLaunch # boto3 and cloudformation use PropagateAtLaunch
if 'propagate_at_launch' in tag and tag['propagate_at_launch'] == 'true': if "propagate_at_launch" in tag and tag["propagate_at_launch"] == "true":
propagated_tags[tag['key']] = tag['value'] propagated_tags[tag["key"]] = tag["value"]
if 'PropagateAtLaunch' in tag and tag['PropagateAtLaunch']: if "PropagateAtLaunch" in tag and tag["PropagateAtLaunch"]:
propagated_tags[tag['Key']] = tag['Value'] propagated_tags[tag["Key"]] = tag["Value"]
return propagated_tags return propagated_tags
def replace_autoscaling_group_instances(self, count_needed, propagated_tags): def replace_autoscaling_group_instances(self, count_needed, propagated_tags):
@ -371,15 +462,17 @@ class FakeAutoScalingGroup(BaseModel):
self.launch_config.user_data, self.launch_config.user_data,
self.launch_config.security_groups, self.launch_config.security_groups,
instance_type=self.launch_config.instance_type, instance_type=self.launch_config.instance_type,
tags={'instance': propagated_tags}, tags={"instance": propagated_tags},
placement=random.choice(self.availability_zones), placement=random.choice(self.availability_zones),
) )
for instance in reservation.instances: for instance in reservation.instances:
instance.autoscaling_group = self instance.autoscaling_group = self
self.instance_states.append(InstanceState( self.instance_states.append(
InstanceState(
instance, instance,
protected_from_scale_in=self.new_instances_protected_from_scale_in, protected_from_scale_in=self.new_instances_protected_from_scale_in,
)) )
)
def append_target_groups(self, target_group_arns): def append_target_groups(self, target_group_arns):
append = [x for x in target_group_arns if x not in self.target_group_arns] append = [x for x in target_group_arns if x not in self.target_group_arns]
@ -402,10 +495,23 @@ class AutoScalingBackend(BaseBackend):
self.__dict__ = {} self.__dict__ = {}
self.__init__(ec2_backend, elb_backend, elbv2_backend) self.__init__(ec2_backend, elb_backend, elbv2_backend)
def create_launch_configuration(self, name, image_id, key_name, kernel_id, ramdisk_id, def create_launch_configuration(
security_groups, user_data, instance_type, self,
instance_monitoring, instance_profile_name, name,
spot_price, ebs_optimized, associate_public_ip_address, block_device_mappings): image_id,
key_name,
kernel_id,
ramdisk_id,
security_groups,
user_data,
instance_type,
instance_monitoring,
instance_profile_name,
spot_price,
ebs_optimized,
associate_public_ip_address,
block_device_mappings,
):
launch_configuration = FakeLaunchConfiguration( launch_configuration = FakeLaunchConfiguration(
name=name, name=name,
image_id=image_id, image_id=image_id,
@ -428,23 +534,37 @@ class AutoScalingBackend(BaseBackend):
def describe_launch_configurations(self, names): def describe_launch_configurations(self, names):
configurations = self.launch_configurations.values() configurations = self.launch_configurations.values()
if names: if names:
return [configuration for configuration in configurations if configuration.name in names] return [
configuration
for configuration in configurations
if configuration.name in names
]
else: else:
return list(configurations) return list(configurations)
def delete_launch_configuration(self, launch_configuration_name): def delete_launch_configuration(self, launch_configuration_name):
self.launch_configurations.pop(launch_configuration_name, None) self.launch_configurations.pop(launch_configuration_name, None)
def create_auto_scaling_group(self, name, availability_zones, def create_auto_scaling_group(
desired_capacity, max_size, min_size, self,
launch_config_name, vpc_zone_identifier, name,
default_cooldown, health_check_period, availability_zones,
health_check_type, load_balancers, desired_capacity,
target_group_arns, placement_group, max_size,
termination_policies, tags, min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
load_balancers,
target_group_arns,
placement_group,
termination_policies,
tags,
new_instances_protected_from_scale_in=False, new_instances_protected_from_scale_in=False,
instance_id=None): instance_id=None,
):
def make_int(value): def make_int(value):
return int(value) if value is not None else value return int(value) if value is not None else value
@ -460,7 +580,9 @@ class AutoScalingBackend(BaseBackend):
try: try:
instance = self.ec2_backend.get_instance(instance_id) instance = self.ec2_backend.get_instance(instance_id)
launch_config_name = name launch_config_name = name
FakeLaunchConfiguration.create_from_instance(launch_config_name, instance, self) FakeLaunchConfiguration.create_from_instance(
launch_config_name, instance, self
)
except InvalidInstanceIdError: except InvalidInstanceIdError:
raise InvalidInstanceError(instance_id) raise InvalidInstanceError(instance_id)
@ -489,19 +611,37 @@ class AutoScalingBackend(BaseBackend):
self.update_attached_target_groups(group.name) self.update_attached_target_groups(group.name)
return group return group
def update_auto_scaling_group(self, name, availability_zones, def update_auto_scaling_group(
desired_capacity, max_size, min_size, self,
launch_config_name, vpc_zone_identifier, name,
default_cooldown, health_check_period, availability_zones,
health_check_type, placement_group, desired_capacity,
max_size,
min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies, termination_policies,
new_instances_protected_from_scale_in=None): new_instances_protected_from_scale_in=None,
):
group = self.autoscaling_groups[name] group = self.autoscaling_groups[name]
group.update(availability_zones, desired_capacity, max_size, group.update(
min_size, launch_config_name, vpc_zone_identifier, availability_zones,
default_cooldown, health_check_period, health_check_type, desired_capacity,
placement_group, termination_policies, max_size,
new_instances_protected_from_scale_in=new_instances_protected_from_scale_in) min_size,
launch_config_name,
vpc_zone_identifier,
default_cooldown,
health_check_period,
health_check_type,
placement_group,
termination_policies,
new_instances_protected_from_scale_in=new_instances_protected_from_scale_in,
)
return group return group
def describe_auto_scaling_groups(self, names): def describe_auto_scaling_groups(self, names):
@ -537,32 +677,48 @@ class AutoScalingBackend(BaseBackend):
for x in instance_ids for x in instance_ids
] ]
for instance in new_instances: for instance in new_instances:
self.ec2_backend.create_tags([instance.instance.id], {ASG_NAME_TAG: group.name}) self.ec2_backend.create_tags(
[instance.instance.id], {ASG_NAME_TAG: group.name}
)
group.instance_states.extend(new_instances) group.instance_states.extend(new_instances)
self.update_attached_elbs(group.name) self.update_attached_elbs(group.name)
def set_instance_health(self, instance_id, health_status, should_respect_grace_period): def set_instance_health(
self, instance_id, health_status, should_respect_grace_period
):
instance = self.ec2_backend.get_instance(instance_id) instance = self.ec2_backend.get_instance(instance_id)
instance_state = next(instance_state for group in self.autoscaling_groups.values() instance_state = next(
for instance_state in group.instance_states if instance_state.instance.id == instance.id) instance_state
for group in self.autoscaling_groups.values()
for instance_state in group.instance_states
if instance_state.instance.id == instance.id
)
instance_state.health_status = health_status instance_state.health_status = health_status
def detach_instances(self, group_name, instance_ids, should_decrement): def detach_instances(self, group_name, instance_ids, should_decrement):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
original_size = len(group.instance_states) original_size = len(group.instance_states)
detached_instances = [x for x in group.instance_states if x.instance.id in instance_ids] detached_instances = [
x for x in group.instance_states if x.instance.id in instance_ids
]
for instance in detached_instances: for instance in detached_instances:
self.ec2_backend.delete_tags([instance.instance.id], {ASG_NAME_TAG: group.name}) self.ec2_backend.delete_tags(
[instance.instance.id], {ASG_NAME_TAG: group.name}
)
new_instance_state = [x for x in group.instance_states if x.instance.id not in instance_ids] new_instance_state = [
x for x in group.instance_states if x.instance.id not in instance_ids
]
group.instance_states = new_instance_state group.instance_states = new_instance_state
if should_decrement: if should_decrement:
group.desired_capacity = original_size - len(instance_ids) group.desired_capacity = original_size - len(instance_ids)
else: else:
count_needed = len(instance_ids) count_needed = len(instance_ids)
group.replace_autoscaling_group_instances(count_needed, group.get_propagated_tags()) group.replace_autoscaling_group_instances(
count_needed, group.get_propagated_tags()
)
self.update_attached_elbs(group_name) self.update_attached_elbs(group_name)
return detached_instances return detached_instances
@ -593,19 +749,32 @@ class AutoScalingBackend(BaseBackend):
desired_capacity = int(desired_capacity) desired_capacity = int(desired_capacity)
self.set_desired_capacity(group_name, desired_capacity) self.set_desired_capacity(group_name, desired_capacity)
def create_autoscaling_policy(self, name, policy_type, adjustment_type, as_name, def create_autoscaling_policy(
scaling_adjustment, cooldown): self, name, policy_type, adjustment_type, as_name, scaling_adjustment, cooldown
policy = FakeScalingPolicy(name, policy_type, adjustment_type, as_name, ):
scaling_adjustment, cooldown, self) policy = FakeScalingPolicy(
name,
policy_type,
adjustment_type,
as_name,
scaling_adjustment,
cooldown,
self,
)
self.policies[name] = policy self.policies[name] = policy
return policy return policy
def describe_policies(self, autoscaling_group_name=None, policy_names=None, policy_types=None): def describe_policies(
return [policy for policy in self.policies.values() self, autoscaling_group_name=None, policy_names=None, policy_types=None
if (not autoscaling_group_name or policy.as_name == autoscaling_group_name) and ):
(not policy_names or policy.name in policy_names) and return [
(not policy_types or policy.policy_type in policy_types)] policy
for policy in self.policies.values()
if (not autoscaling_group_name or policy.as_name == autoscaling_group_name)
and (not policy_names or policy.name in policy_names)
and (not policy_types or policy.policy_type in policy_types)
]
def delete_policy(self, group_name): def delete_policy(self, group_name):
self.policies.pop(group_name, None) self.policies.pop(group_name, None)
@ -616,16 +785,14 @@ class AutoScalingBackend(BaseBackend):
def update_attached_elbs(self, group_name): def update_attached_elbs(self, group_name):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group_instance_ids = set( group_instance_ids = set(state.instance.id for state in group.instance_states)
state.instance.id for state in group.instance_states)
# skip this if group.load_balancers is empty # skip this if group.load_balancers is empty
# otherwise elb_backend.describe_load_balancers returns all available load balancers # otherwise elb_backend.describe_load_balancers returns all available load balancers
if not group.load_balancers: if not group.load_balancers:
return return
try: try:
elbs = self.elb_backend.describe_load_balancers( elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers)
names=group.load_balancers)
except LoadBalancerNotFoundError: except LoadBalancerNotFoundError:
# ELBs can be deleted before their autoscaling group # ELBs can be deleted before their autoscaling group
return return
@ -633,14 +800,15 @@ class AutoScalingBackend(BaseBackend):
for elb in elbs: for elb in elbs:
elb_instace_ids = set(elb.instance_ids) elb_instace_ids = set(elb.instance_ids)
self.elb_backend.register_instances( self.elb_backend.register_instances(
elb.name, group_instance_ids - elb_instace_ids) elb.name, group_instance_ids - elb_instace_ids
)
self.elb_backend.deregister_instances( self.elb_backend.deregister_instances(
elb.name, elb_instace_ids - group_instance_ids) elb.name, elb_instace_ids - group_instance_ids
)
def update_attached_target_groups(self, group_name): def update_attached_target_groups(self, group_name):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group_instance_ids = set( group_instance_ids = set(state.instance.id for state in group.instance_states)
state.instance.id for state in group.instance_states)
# no action necessary if target_group_arns is empty # no action necessary if target_group_arns is empty
if not group.target_group_arns: if not group.target_group_arns:
@ -649,10 +817,13 @@ class AutoScalingBackend(BaseBackend):
target_groups = self.elbv2_backend.describe_target_groups( target_groups = self.elbv2_backend.describe_target_groups(
target_group_arns=group.target_group_arns, target_group_arns=group.target_group_arns,
load_balancer_arn=None, load_balancer_arn=None,
names=None) names=None,
)
for target_group in target_groups: for target_group in target_groups:
asg_targets = [{'id': x, 'port': target_group.port} for x in group_instance_ids] asg_targets = [
{"id": x, "port": target_group.port} for x in group_instance_ids
]
self.elbv2_backend.register_targets(target_group.arn, (asg_targets)) self.elbv2_backend.register_targets(target_group.arn, (asg_targets))
def create_or_update_tags(self, tags): def create_or_update_tags(self, tags):
@ -670,7 +841,7 @@ class AutoScalingBackend(BaseBackend):
new_tags.append(old_tag) new_tags.append(old_tag)
# if key was never in old_tag's add it (create tag) # if key was never in old_tag's add it (create tag)
if not any(new_tag['key'] == tag['key'] for new_tag in new_tags): if not any(new_tag["key"] == tag["key"] for new_tag in new_tags):
new_tags.append(tag) new_tags.append(tag)
group.tags = new_tags group.tags = new_tags
@ -678,7 +849,8 @@ class AutoScalingBackend(BaseBackend):
def attach_load_balancers(self, group_name, load_balancer_names): def attach_load_balancers(self, group_name, load_balancer_names):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group.load_balancers.extend( group.load_balancers.extend(
[x for x in load_balancer_names if x not in group.load_balancers]) [x for x in load_balancer_names if x not in group.load_balancers]
)
self.update_attached_elbs(group_name) self.update_attached_elbs(group_name)
def describe_load_balancers(self, group_name): def describe_load_balancers(self, group_name):
@ -686,13 +858,13 @@ class AutoScalingBackend(BaseBackend):
def detach_load_balancers(self, group_name, load_balancer_names): def detach_load_balancers(self, group_name, load_balancer_names):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group_instance_ids = set( group_instance_ids = set(state.instance.id for state in group.instance_states)
state.instance.id for state in group.instance_states)
elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers) elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers)
for elb in elbs: for elb in elbs:
self.elb_backend.deregister_instances( self.elb_backend.deregister_instances(elb.name, group_instance_ids)
elb.name, group_instance_ids) group.load_balancers = [
group.load_balancers = [x for x in group.load_balancers if x not in load_balancer_names] x for x in group.load_balancers if x not in load_balancer_names
]
def attach_load_balancer_target_groups(self, group_name, target_group_arns): def attach_load_balancer_target_groups(self, group_name, target_group_arns):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
@ -704,36 +876,51 @@ class AutoScalingBackend(BaseBackend):
def detach_load_balancer_target_groups(self, group_name, target_group_arns): def detach_load_balancer_target_groups(self, group_name, target_group_arns):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group.target_group_arns = [x for x in group.target_group_arns if x not in target_group_arns] group.target_group_arns = [
x for x in group.target_group_arns if x not in target_group_arns
]
for target_group in target_group_arns: for target_group in target_group_arns:
asg_targets = [{'id': x.instance.id} for x in group.instance_states] asg_targets = [{"id": x.instance.id} for x in group.instance_states]
self.elbv2_backend.deregister_targets(target_group, (asg_targets)) self.elbv2_backend.deregister_targets(target_group, (asg_targets))
def suspend_processes(self, group_name, scaling_processes): def suspend_processes(self, group_name, scaling_processes):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
group.suspended_processes = scaling_processes or [] group.suspended_processes = scaling_processes or []
def set_instance_protection(self, group_name, instance_ids, protected_from_scale_in): def set_instance_protection(
self, group_name, instance_ids, protected_from_scale_in
):
group = self.autoscaling_groups[group_name] group = self.autoscaling_groups[group_name]
protected_instances = [ protected_instances = [
x for x in group.instance_states if x.instance.id in instance_ids] x for x in group.instance_states if x.instance.id in instance_ids
]
for instance in protected_instances: for instance in protected_instances:
instance.protected_from_scale_in = protected_from_scale_in instance.protected_from_scale_in = protected_from_scale_in
def notify_terminate_instances(self, instance_ids): def notify_terminate_instances(self, instance_ids):
for autoscaling_group_name, autoscaling_group in self.autoscaling_groups.items(): for (
autoscaling_group_name,
autoscaling_group,
) in self.autoscaling_groups.items():
original_instance_count = len(autoscaling_group.instance_states) original_instance_count = len(autoscaling_group.instance_states)
autoscaling_group.instance_states = list(filter( autoscaling_group.instance_states = list(
filter(
lambda i_state: i_state.instance.id not in instance_ids, lambda i_state: i_state.instance.id not in instance_ids,
autoscaling_group.instance_states,
)
)
difference = original_instance_count - len(
autoscaling_group.instance_states autoscaling_group.instance_states
)) )
difference = original_instance_count - len(autoscaling_group.instance_states)
if difference > 0: if difference > 0:
autoscaling_group.replace_autoscaling_group_instances(difference, autoscaling_group.get_propagated_tags()) autoscaling_group.replace_autoscaling_group_instances(
difference, autoscaling_group.get_propagated_tags()
)
self.update_attached_elbs(autoscaling_group_name) self.update_attached_elbs(autoscaling_group_name)
autoscaling_backends = {} autoscaling_backends = {}
for region, ec2_backend in ec2_backends.items(): for region, ec2_backend in ec2_backends.items():
autoscaling_backends[region] = AutoScalingBackend( autoscaling_backends[region] = AutoScalingBackend(
ec2_backend, elb_backends[region], elbv2_backends[region]) ec2_backend, elb_backends[region], elbv2_backends[region]
)

View File

@ -6,88 +6,88 @@ from .models import autoscaling_backends
class AutoScalingResponse(BaseResponse): class AutoScalingResponse(BaseResponse):
@property @property
def autoscaling_backend(self): def autoscaling_backend(self):
return autoscaling_backends[self.region] return autoscaling_backends[self.region]
def create_launch_configuration(self): def create_launch_configuration(self):
instance_monitoring_string = self._get_param( instance_monitoring_string = self._get_param("InstanceMonitoring.Enabled")
'InstanceMonitoring.Enabled') if instance_monitoring_string == "true":
if instance_monitoring_string == 'true':
instance_monitoring = True instance_monitoring = True
else: else:
instance_monitoring = False instance_monitoring = False
self.autoscaling_backend.create_launch_configuration( self.autoscaling_backend.create_launch_configuration(
name=self._get_param('LaunchConfigurationName'), name=self._get_param("LaunchConfigurationName"),
image_id=self._get_param('ImageId'), image_id=self._get_param("ImageId"),
key_name=self._get_param('KeyName'), key_name=self._get_param("KeyName"),
ramdisk_id=self._get_param('RamdiskId'), ramdisk_id=self._get_param("RamdiskId"),
kernel_id=self._get_param('KernelId'), kernel_id=self._get_param("KernelId"),
security_groups=self._get_multi_param('SecurityGroups.member'), security_groups=self._get_multi_param("SecurityGroups.member"),
user_data=self._get_param('UserData'), user_data=self._get_param("UserData"),
instance_type=self._get_param('InstanceType'), instance_type=self._get_param("InstanceType"),
instance_monitoring=instance_monitoring, instance_monitoring=instance_monitoring,
instance_profile_name=self._get_param('IamInstanceProfile'), instance_profile_name=self._get_param("IamInstanceProfile"),
spot_price=self._get_param('SpotPrice'), spot_price=self._get_param("SpotPrice"),
ebs_optimized=self._get_param('EbsOptimized'), ebs_optimized=self._get_param("EbsOptimized"),
associate_public_ip_address=self._get_param( associate_public_ip_address=self._get_param("AssociatePublicIpAddress"),
"AssociatePublicIpAddress"), block_device_mappings=self._get_list_prefix("BlockDeviceMappings.member"),
block_device_mappings=self._get_list_prefix(
'BlockDeviceMappings.member')
) )
template = self.response_template(CREATE_LAUNCH_CONFIGURATION_TEMPLATE) template = self.response_template(CREATE_LAUNCH_CONFIGURATION_TEMPLATE)
return template.render() return template.render()
def describe_launch_configurations(self): def describe_launch_configurations(self):
names = self._get_multi_param('LaunchConfigurationNames.member') names = self._get_multi_param("LaunchConfigurationNames.member")
all_launch_configurations = self.autoscaling_backend.describe_launch_configurations(names) all_launch_configurations = self.autoscaling_backend.describe_launch_configurations(
marker = self._get_param('NextToken') names
)
marker = self._get_param("NextToken")
all_names = [lc.name for lc in all_launch_configurations] all_names = [lc.name for lc in all_launch_configurations]
if marker: if marker:
start = all_names.index(marker) + 1 start = all_names.index(marker) + 1
else: else:
start = 0 start = 0
max_records = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier max_records = self._get_int_param(
launch_configurations_resp = all_launch_configurations[start:start + max_records] "MaxRecords", 50
) # the default is 100, but using 50 to make testing easier
launch_configurations_resp = all_launch_configurations[
start : start + max_records
]
next_token = None next_token = None
if len(all_launch_configurations) > start + max_records: if len(all_launch_configurations) > start + max_records:
next_token = launch_configurations_resp[-1].name next_token = launch_configurations_resp[-1].name
template = self.response_template( template = self.response_template(DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE)
DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE) return template.render(
return template.render(launch_configurations=launch_configurations_resp, next_token=next_token) launch_configurations=launch_configurations_resp, next_token=next_token
)
def delete_launch_configuration(self): def delete_launch_configuration(self):
launch_configurations_name = self.querystring.get( launch_configurations_name = self.querystring.get("LaunchConfigurationName")[0]
'LaunchConfigurationName')[0] self.autoscaling_backend.delete_launch_configuration(launch_configurations_name)
self.autoscaling_backend.delete_launch_configuration(
launch_configurations_name)
template = self.response_template(DELETE_LAUNCH_CONFIGURATION_TEMPLATE) template = self.response_template(DELETE_LAUNCH_CONFIGURATION_TEMPLATE)
return template.render() return template.render()
def create_auto_scaling_group(self): def create_auto_scaling_group(self):
self.autoscaling_backend.create_auto_scaling_group( self.autoscaling_backend.create_auto_scaling_group(
name=self._get_param('AutoScalingGroupName'), name=self._get_param("AutoScalingGroupName"),
availability_zones=self._get_multi_param( availability_zones=self._get_multi_param("AvailabilityZones.member"),
'AvailabilityZones.member'), desired_capacity=self._get_int_param("DesiredCapacity"),
desired_capacity=self._get_int_param('DesiredCapacity'), max_size=self._get_int_param("MaxSize"),
max_size=self._get_int_param('MaxSize'), min_size=self._get_int_param("MinSize"),
min_size=self._get_int_param('MinSize'), instance_id=self._get_param("InstanceId"),
instance_id=self._get_param('InstanceId'), launch_config_name=self._get_param("LaunchConfigurationName"),
launch_config_name=self._get_param('LaunchConfigurationName'), vpc_zone_identifier=self._get_param("VPCZoneIdentifier"),
vpc_zone_identifier=self._get_param('VPCZoneIdentifier'), default_cooldown=self._get_int_param("DefaultCooldown"),
default_cooldown=self._get_int_param('DefaultCooldown'), health_check_period=self._get_int_param("HealthCheckGracePeriod"),
health_check_period=self._get_int_param('HealthCheckGracePeriod'), health_check_type=self._get_param("HealthCheckType"),
health_check_type=self._get_param('HealthCheckType'), load_balancers=self._get_multi_param("LoadBalancerNames.member"),
load_balancers=self._get_multi_param('LoadBalancerNames.member'), target_group_arns=self._get_multi_param("TargetGroupARNs.member"),
target_group_arns=self._get_multi_param('TargetGroupARNs.member'), placement_group=self._get_param("PlacementGroup"),
placement_group=self._get_param('PlacementGroup'), termination_policies=self._get_multi_param("TerminationPolicies.member"),
termination_policies=self._get_multi_param( tags=self._get_list_prefix("Tags.member"),
'TerminationPolicies.member'),
tags=self._get_list_prefix('Tags.member'),
new_instances_protected_from_scale_in=self._get_bool_param( new_instances_protected_from_scale_in=self._get_bool_param(
'NewInstancesProtectedFromScaleIn', False) "NewInstancesProtectedFromScaleIn", False
),
) )
template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE) template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE)
return template.render() return template.render()
@ -95,68 +95,73 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def attach_instances(self): def attach_instances(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param('InstanceIds.member') instance_ids = self._get_multi_param("InstanceIds.member")
self.autoscaling_backend.attach_instances( self.autoscaling_backend.attach_instances(group_name, instance_ids)
group_name, instance_ids)
template = self.response_template(ATTACH_INSTANCES_TEMPLATE) template = self.response_template(ATTACH_INSTANCES_TEMPLATE)
return template.render() return template.render()
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def set_instance_health(self): def set_instance_health(self):
instance_id = self._get_param('InstanceId') instance_id = self._get_param("InstanceId")
health_status = self._get_param("HealthStatus") health_status = self._get_param("HealthStatus")
if health_status not in ['Healthy', 'Unhealthy']: if health_status not in ["Healthy", "Unhealthy"]:
raise ValueError('Valid instance health states are: [Healthy, Unhealthy]') raise ValueError("Valid instance health states are: [Healthy, Unhealthy]")
should_respect_grace_period = self._get_param("ShouldRespectGracePeriod") should_respect_grace_period = self._get_param("ShouldRespectGracePeriod")
self.autoscaling_backend.set_instance_health(instance_id, health_status, should_respect_grace_period) self.autoscaling_backend.set_instance_health(
instance_id, health_status, should_respect_grace_period
)
template = self.response_template(SET_INSTANCE_HEALTH_TEMPLATE) template = self.response_template(SET_INSTANCE_HEALTH_TEMPLATE)
return template.render() return template.render()
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def detach_instances(self): def detach_instances(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param('InstanceIds.member') instance_ids = self._get_multi_param("InstanceIds.member")
should_decrement_string = self._get_param('ShouldDecrementDesiredCapacity') should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity")
if should_decrement_string == 'true': if should_decrement_string == "true":
should_decrement = True should_decrement = True
else: else:
should_decrement = False should_decrement = False
detached_instances = self.autoscaling_backend.detach_instances( detached_instances = self.autoscaling_backend.detach_instances(
group_name, instance_ids, should_decrement) group_name, instance_ids, should_decrement
)
template = self.response_template(DETACH_INSTANCES_TEMPLATE) template = self.response_template(DETACH_INSTANCES_TEMPLATE)
return template.render(detached_instances=detached_instances) return template.render(detached_instances=detached_instances)
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def attach_load_balancer_target_groups(self): def attach_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self._get_multi_param('TargetGroupARNs.member') target_group_arns = self._get_multi_param("TargetGroupARNs.member")
self.autoscaling_backend.attach_load_balancer_target_groups( self.autoscaling_backend.attach_load_balancer_target_groups(
group_name, target_group_arns) group_name, target_group_arns
)
template = self.response_template(ATTACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE) template = self.response_template(ATTACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE)
return template.render() return template.render()
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def describe_load_balancer_target_groups(self): def describe_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self.autoscaling_backend.describe_load_balancer_target_groups( target_group_arns = self.autoscaling_backend.describe_load_balancer_target_groups(
group_name) group_name
)
template = self.response_template(DESCRIBE_LOAD_BALANCER_TARGET_GROUPS) template = self.response_template(DESCRIBE_LOAD_BALANCER_TARGET_GROUPS)
return template.render(target_group_arns=target_group_arns) return template.render(target_group_arns=target_group_arns)
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def detach_load_balancer_target_groups(self): def detach_load_balancer_target_groups(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
target_group_arns = self._get_multi_param('TargetGroupARNs.member') target_group_arns = self._get_multi_param("TargetGroupARNs.member")
self.autoscaling_backend.detach_load_balancer_target_groups( self.autoscaling_backend.detach_load_balancer_target_groups(
group_name, target_group_arns) group_name, target_group_arns
)
template = self.response_template(DETACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE) template = self.response_template(DETACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE)
return template.render() return template.render()
@ -172,7 +177,7 @@ class AutoScalingResponse(BaseResponse):
max_records = self._get_int_param("MaxRecords", 50) max_records = self._get_int_param("MaxRecords", 50)
if max_records > 100: if max_records > 100:
raise ValueError raise ValueError
groups = all_groups[start:start + max_records] groups = all_groups[start : start + max_records]
next_token = None next_token = None
if max_records and len(all_groups) > start + max_records: if max_records and len(all_groups) > start + max_records:
next_token = groups[-1].name next_token = groups[-1].name
@ -181,42 +186,40 @@ class AutoScalingResponse(BaseResponse):
def update_auto_scaling_group(self): def update_auto_scaling_group(self):
self.autoscaling_backend.update_auto_scaling_group( self.autoscaling_backend.update_auto_scaling_group(
name=self._get_param('AutoScalingGroupName'), name=self._get_param("AutoScalingGroupName"),
availability_zones=self._get_multi_param( availability_zones=self._get_multi_param("AvailabilityZones.member"),
'AvailabilityZones.member'), desired_capacity=self._get_int_param("DesiredCapacity"),
desired_capacity=self._get_int_param('DesiredCapacity'), max_size=self._get_int_param("MaxSize"),
max_size=self._get_int_param('MaxSize'), min_size=self._get_int_param("MinSize"),
min_size=self._get_int_param('MinSize'), launch_config_name=self._get_param("LaunchConfigurationName"),
launch_config_name=self._get_param('LaunchConfigurationName'), vpc_zone_identifier=self._get_param("VPCZoneIdentifier"),
vpc_zone_identifier=self._get_param('VPCZoneIdentifier'), default_cooldown=self._get_int_param("DefaultCooldown"),
default_cooldown=self._get_int_param('DefaultCooldown'), health_check_period=self._get_int_param("HealthCheckGracePeriod"),
health_check_period=self._get_int_param('HealthCheckGracePeriod'), health_check_type=self._get_param("HealthCheckType"),
health_check_type=self._get_param('HealthCheckType'), placement_group=self._get_param("PlacementGroup"),
placement_group=self._get_param('PlacementGroup'), termination_policies=self._get_multi_param("TerminationPolicies.member"),
termination_policies=self._get_multi_param(
'TerminationPolicies.member'),
new_instances_protected_from_scale_in=self._get_bool_param( new_instances_protected_from_scale_in=self._get_bool_param(
'NewInstancesProtectedFromScaleIn', None) "NewInstancesProtectedFromScaleIn", None
),
) )
template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE) template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE)
return template.render() return template.render()
def delete_auto_scaling_group(self): def delete_auto_scaling_group(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
self.autoscaling_backend.delete_auto_scaling_group(group_name) self.autoscaling_backend.delete_auto_scaling_group(group_name)
template = self.response_template(DELETE_AUTOSCALING_GROUP_TEMPLATE) template = self.response_template(DELETE_AUTOSCALING_GROUP_TEMPLATE)
return template.render() return template.render()
def set_desired_capacity(self): def set_desired_capacity(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
desired_capacity = self._get_int_param('DesiredCapacity') desired_capacity = self._get_int_param("DesiredCapacity")
self.autoscaling_backend.set_desired_capacity( self.autoscaling_backend.set_desired_capacity(group_name, desired_capacity)
group_name, desired_capacity)
template = self.response_template(SET_DESIRED_CAPACITY_TEMPLATE) template = self.response_template(SET_DESIRED_CAPACITY_TEMPLATE)
return template.render() return template.render()
def create_or_update_tags(self): def create_or_update_tags(self):
tags = self._get_list_prefix('Tags.member') tags = self._get_list_prefix("Tags.member")
self.autoscaling_backend.create_or_update_tags(tags) self.autoscaling_backend.create_or_update_tags(tags)
template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE) template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE)
@ -224,38 +227,38 @@ class AutoScalingResponse(BaseResponse):
def describe_auto_scaling_instances(self): def describe_auto_scaling_instances(self):
instance_states = self.autoscaling_backend.describe_auto_scaling_instances() instance_states = self.autoscaling_backend.describe_auto_scaling_instances()
template = self.response_template( template = self.response_template(DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE)
return template.render(instance_states=instance_states) return template.render(instance_states=instance_states)
def put_scaling_policy(self): def put_scaling_policy(self):
policy = self.autoscaling_backend.create_autoscaling_policy( policy = self.autoscaling_backend.create_autoscaling_policy(
name=self._get_param('PolicyName'), name=self._get_param("PolicyName"),
policy_type=self._get_param('PolicyType'), policy_type=self._get_param("PolicyType"),
adjustment_type=self._get_param('AdjustmentType'), adjustment_type=self._get_param("AdjustmentType"),
as_name=self._get_param('AutoScalingGroupName'), as_name=self._get_param("AutoScalingGroupName"),
scaling_adjustment=self._get_int_param('ScalingAdjustment'), scaling_adjustment=self._get_int_param("ScalingAdjustment"),
cooldown=self._get_int_param('Cooldown'), cooldown=self._get_int_param("Cooldown"),
) )
template = self.response_template(CREATE_SCALING_POLICY_TEMPLATE) template = self.response_template(CREATE_SCALING_POLICY_TEMPLATE)
return template.render(policy=policy) return template.render(policy=policy)
def describe_policies(self): def describe_policies(self):
policies = self.autoscaling_backend.describe_policies( policies = self.autoscaling_backend.describe_policies(
autoscaling_group_name=self._get_param('AutoScalingGroupName'), autoscaling_group_name=self._get_param("AutoScalingGroupName"),
policy_names=self._get_multi_param('PolicyNames.member'), policy_names=self._get_multi_param("PolicyNames.member"),
policy_types=self._get_multi_param('PolicyTypes.member')) policy_types=self._get_multi_param("PolicyTypes.member"),
)
template = self.response_template(DESCRIBE_SCALING_POLICIES_TEMPLATE) template = self.response_template(DESCRIBE_SCALING_POLICIES_TEMPLATE)
return template.render(policies=policies) return template.render(policies=policies)
def delete_policy(self): def delete_policy(self):
group_name = self._get_param('PolicyName') group_name = self._get_param("PolicyName")
self.autoscaling_backend.delete_policy(group_name) self.autoscaling_backend.delete_policy(group_name)
template = self.response_template(DELETE_POLICY_TEMPLATE) template = self.response_template(DELETE_POLICY_TEMPLATE)
return template.render() return template.render()
def execute_policy(self): def execute_policy(self):
group_name = self._get_param('PolicyName') group_name = self._get_param("PolicyName")
self.autoscaling_backend.execute_policy(group_name) self.autoscaling_backend.execute_policy(group_name)
template = self.response_template(EXECUTE_POLICY_TEMPLATE) template = self.response_template(EXECUTE_POLICY_TEMPLATE)
return template.render() return template.render()
@ -263,17 +266,16 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def attach_load_balancers(self): def attach_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
load_balancer_names = self._get_multi_param("LoadBalancerNames.member") load_balancer_names = self._get_multi_param("LoadBalancerNames.member")
self.autoscaling_backend.attach_load_balancers( self.autoscaling_backend.attach_load_balancers(group_name, load_balancer_names)
group_name, load_balancer_names)
template = self.response_template(ATTACH_LOAD_BALANCERS_TEMPLATE) template = self.response_template(ATTACH_LOAD_BALANCERS_TEMPLATE)
return template.render() return template.render()
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def describe_load_balancers(self): def describe_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
load_balancers = self.autoscaling_backend.describe_load_balancers(group_name) load_balancers = self.autoscaling_backend.describe_load_balancers(group_name)
template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE) template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE)
return template.render(load_balancers=load_balancers) return template.render(load_balancers=load_balancers)
@ -281,26 +283,28 @@ class AutoScalingResponse(BaseResponse):
@amz_crc32 @amz_crc32
@amzn_request_id @amzn_request_id
def detach_load_balancers(self): def detach_load_balancers(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
load_balancer_names = self._get_multi_param("LoadBalancerNames.member") load_balancer_names = self._get_multi_param("LoadBalancerNames.member")
self.autoscaling_backend.detach_load_balancers( self.autoscaling_backend.detach_load_balancers(group_name, load_balancer_names)
group_name, load_balancer_names)
template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE) template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE)
return template.render() return template.render()
def suspend_processes(self): def suspend_processes(self):
autoscaling_group_name = self._get_param('AutoScalingGroupName') autoscaling_group_name = self._get_param("AutoScalingGroupName")
scaling_processes = self._get_multi_param('ScalingProcesses.member') scaling_processes = self._get_multi_param("ScalingProcesses.member")
self.autoscaling_backend.suspend_processes(autoscaling_group_name, scaling_processes) self.autoscaling_backend.suspend_processes(
autoscaling_group_name, scaling_processes
)
template = self.response_template(SUSPEND_PROCESSES_TEMPLATE) template = self.response_template(SUSPEND_PROCESSES_TEMPLATE)
return template.render() return template.render()
def set_instance_protection(self): def set_instance_protection(self):
group_name = self._get_param('AutoScalingGroupName') group_name = self._get_param("AutoScalingGroupName")
instance_ids = self._get_multi_param('InstanceIds.member') instance_ids = self._get_multi_param("InstanceIds.member")
protected_from_scale_in = self._get_bool_param('ProtectedFromScaleIn') protected_from_scale_in = self._get_bool_param("ProtectedFromScaleIn")
self.autoscaling_backend.set_instance_protection( self.autoscaling_backend.set_instance_protection(
group_name, instance_ids, protected_from_scale_in) group_name, instance_ids, protected_from_scale_in
)
template = self.response_template(SET_INSTANCE_PROTECTION_TEMPLATE) template = self.response_template(SET_INSTANCE_PROTECTION_TEMPLATE)
return template.render() return template.render()

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import AutoScalingResponse from .responses import AutoScalingResponse
url_bases = [ url_bases = ["https?://autoscaling.(.+).amazonaws.com"]
"https?://autoscaling.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": AutoScalingResponse.dispatch}
'{0}/$': AutoScalingResponse.dispatch,
}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import lambda_backends from .models import lambda_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
lambda_backend = lambda_backends['us-east-1'] lambda_backend = lambda_backends["us-east-1"]
mock_lambda = base_decorator(lambda_backends) mock_lambda = base_decorator(lambda_backends)
mock_lambda_deprecated = deprecated_base_decorator(lambda_backends) mock_lambda_deprecated = deprecated_base_decorator(lambda_backends)

View File

@ -38,7 +38,7 @@ from moto.dynamodbstreams import dynamodbstreams_backends
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
ACCOUNT_ID = '123456789012' ACCOUNT_ID = "123456789012"
try: try:
@ -47,20 +47,22 @@ except ImportError:
from backports.tempfile import TemporaryDirectory from backports.tempfile import TemporaryDirectory
_stderr_regex = re.compile(r'START|END|REPORT RequestId: .*') _stderr_regex = re.compile(r"START|END|REPORT RequestId: .*")
_orig_adapter_send = requests.adapters.HTTPAdapter.send _orig_adapter_send = requests.adapters.HTTPAdapter.send
docker_3 = docker.__version__[0] >= '3' docker_3 = docker.__version__[0] >= "3"
def zip2tar(zip_bytes): def zip2tar(zip_bytes):
with TemporaryDirectory() as td: with TemporaryDirectory() as td:
tarname = os.path.join(td, 'data.tar') tarname = os.path.join(td, "data.tar")
timeshift = int((datetime.datetime.now() - timeshift = int(
datetime.datetime.utcnow()).total_seconds()) (datetime.datetime.now() - datetime.datetime.utcnow()).total_seconds()
with zipfile.ZipFile(io.BytesIO(zip_bytes), 'r') as zipf, \ )
tarfile.TarFile(tarname, 'w') as tarf: with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zipf, tarfile.TarFile(
tarname, "w"
) as tarf:
for zipinfo in zipf.infolist(): for zipinfo in zipf.infolist():
if zipinfo.filename[-1] == '/': # is_dir() is py3.6+ if zipinfo.filename[-1] == "/": # is_dir() is py3.6+
continue continue
tarinfo = tarfile.TarInfo(name=zipinfo.filename) tarinfo = tarfile.TarInfo(name=zipinfo.filename)
@ -69,7 +71,7 @@ def zip2tar(zip_bytes):
infile = zipf.open(zipinfo.filename) infile = zipf.open(zipinfo.filename)
tarf.addfile(tarinfo, infile) tarf.addfile(tarinfo, infile)
with open(tarname, 'rb') as f: with open(tarname, "rb") as f:
tar_data = f.read() tar_data = f.read()
return tar_data return tar_data
@ -83,7 +85,9 @@ class _VolumeRefCount:
class _DockerDataVolumeContext: class _DockerDataVolumeContext:
_data_vol_map = defaultdict(lambda: _VolumeRefCount(0, None)) # {sha256: _VolumeRefCount} _data_vol_map = defaultdict(
lambda: _VolumeRefCount(0, None)
) # {sha256: _VolumeRefCount}
_lock = threading.Lock() _lock = threading.Lock()
def __init__(self, lambda_func): def __init__(self, lambda_func):
@ -109,15 +113,19 @@ class _DockerDataVolumeContext:
return self return self
# It doesn't exist so we need to create it # It doesn't exist so we need to create it
self._vol_ref.volume = self._lambda_func.docker_client.volumes.create(self._lambda_func.code_sha_256) self._vol_ref.volume = self._lambda_func.docker_client.volumes.create(
self._lambda_func.code_sha_256
)
if docker_3: if docker_3:
volumes = {self.name: {'bind': '/tmp/data', 'mode': 'rw'}} volumes = {self.name: {"bind": "/tmp/data", "mode": "rw"}}
else: else:
volumes = {self.name: '/tmp/data'} volumes = {self.name: "/tmp/data"}
container = self._lambda_func.docker_client.containers.run('alpine', 'sleep 100', volumes=volumes, detach=True) container = self._lambda_func.docker_client.containers.run(
"alpine", "sleep 100", volumes=volumes, detach=True
)
try: try:
tar_bytes = zip2tar(self._lambda_func.code_bytes) tar_bytes = zip2tar(self._lambda_func.code_bytes)
container.put_archive('/tmp/data', tar_bytes) container.put_archive("/tmp/data", tar_bytes)
finally: finally:
container.remove(force=True) container.remove(force=True)
@ -140,13 +148,13 @@ class LambdaFunction(BaseModel):
def __init__(self, spec, region, validate_s3=True, version=1): def __init__(self, spec, region, validate_s3=True, version=1):
# required # required
self.region = region self.region = region
self.code = spec['Code'] self.code = spec["Code"]
self.function_name = spec['FunctionName'] self.function_name = spec["FunctionName"]
self.handler = spec['Handler'] self.handler = spec["Handler"]
self.role = spec['Role'] self.role = spec["Role"]
self.run_time = spec['Runtime'] self.run_time = spec["Runtime"]
self.logs_backend = logs_backends[self.region] self.logs_backend = logs_backends[self.region]
self.environment_vars = spec.get('Environment', {}).get('Variables', {}) self.environment_vars = spec.get("Environment", {}).get("Variables", {})
self.docker_client = docker.from_env() self.docker_client = docker.from_env()
self.policy = "" self.policy = ""
@ -161,77 +169,82 @@ class LambdaFunction(BaseModel):
if isinstance(adapter, requests.adapters.HTTPAdapter): if isinstance(adapter, requests.adapters.HTTPAdapter):
adapter.send = functools.partial(_orig_adapter_send, adapter) adapter.send = functools.partial(_orig_adapter_send, adapter)
return adapter return adapter
self.docker_client.api.get_adapter = replace_adapter_send self.docker_client.api.get_adapter = replace_adapter_send
# optional # optional
self.description = spec.get('Description', '') self.description = spec.get("Description", "")
self.memory_size = spec.get('MemorySize', 128) self.memory_size = spec.get("MemorySize", 128)
self.publish = spec.get('Publish', False) # this is ignored currently self.publish = spec.get("Publish", False) # this is ignored currently
self.timeout = spec.get('Timeout', 3) self.timeout = spec.get("Timeout", 3)
self.logs_group_name = '/aws/lambda/{}'.format(self.function_name) self.logs_group_name = "/aws/lambda/{}".format(self.function_name)
self.logs_backend.ensure_log_group(self.logs_group_name, []) self.logs_backend.ensure_log_group(self.logs_group_name, [])
# this isn't finished yet. it needs to find out the VpcId value # this isn't finished yet. it needs to find out the VpcId value
self._vpc_config = spec.get( self._vpc_config = spec.get(
'VpcConfig', {'SubnetIds': [], 'SecurityGroupIds': []}) "VpcConfig", {"SubnetIds": [], "SecurityGroupIds": []}
)
# auto-generated # auto-generated
self.version = version self.version = version
self.last_modified = datetime.datetime.utcnow().strftime( self.last_modified = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
'%Y-%m-%d %H:%M:%S')
if 'ZipFile' in self.code: if "ZipFile" in self.code:
# more hackery to handle unicode/bytes/str in python3 and python2 - # more hackery to handle unicode/bytes/str in python3 and python2 -
# argh! # argh!
try: try:
to_unzip_code = base64.b64decode( to_unzip_code = base64.b64decode(bytes(self.code["ZipFile"], "utf-8"))
bytes(self.code['ZipFile'], 'utf-8'))
except Exception: except Exception:
to_unzip_code = base64.b64decode(self.code['ZipFile']) to_unzip_code = base64.b64decode(self.code["ZipFile"])
self.code_bytes = to_unzip_code self.code_bytes = to_unzip_code
self.code_size = len(to_unzip_code) self.code_size = len(to_unzip_code)
self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest() self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest()
# TODO: we should be putting this in a lambda bucket # TODO: we should be putting this in a lambda bucket
self.code['UUID'] = str(uuid.uuid4()) self.code["UUID"] = str(uuid.uuid4())
self.code['S3Key'] = '{}-{}'.format(self.function_name, self.code['UUID']) self.code["S3Key"] = "{}-{}".format(self.function_name, self.code["UUID"])
else: else:
# validate s3 bucket and key # validate s3 bucket and key
key = None key = None
try: try:
# FIXME: does not validate bucket region # FIXME: does not validate bucket region
key = s3_backend.get_key( key = s3_backend.get_key(self.code["S3Bucket"], self.code["S3Key"])
self.code['S3Bucket'], self.code['S3Key'])
except MissingBucket: except MissingBucket:
if do_validate_s3(): if do_validate_s3():
raise ValueError( raise ValueError(
"InvalidParameterValueException", "InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist") "Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist",
)
except MissingKey: except MissingKey:
if do_validate_s3(): if do_validate_s3():
raise ValueError( raise ValueError(
"InvalidParameterValueException", "InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.") "Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.",
)
if key: if key:
self.code_bytes = key.value self.code_bytes = key.value
self.code_size = key.size self.code_size = key.size
self.code_sha_256 = hashlib.sha256(key.value).hexdigest() self.code_sha_256 = hashlib.sha256(key.value).hexdigest()
self.function_arn = make_function_arn(self.region, ACCOUNT_ID, self.function_name) self.function_arn = make_function_arn(
self.region, ACCOUNT_ID, self.function_name
)
self.tags = dict() self.tags = dict()
def set_version(self, version): def set_version(self, version):
self.function_arn = make_function_ver_arn(self.region, ACCOUNT_ID, self.function_name, version) self.function_arn = make_function_ver_arn(
self.region, ACCOUNT_ID, self.function_name, version
)
self.version = version self.version = version
self.last_modified = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S') self.last_modified = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
@property @property
def vpc_config(self): def vpc_config(self):
config = self._vpc_config.copy() config = self._vpc_config.copy()
if config['SecurityGroupIds']: if config["SecurityGroupIds"]:
config.update({"VpcId": "vpc-123abc"}) config.update({"VpcId": "vpc-123abc"})
return config return config
@ -260,17 +273,17 @@ class LambdaFunction(BaseModel):
} }
if self.environment_vars: if self.environment_vars:
config['Environment'] = { config["Environment"] = {"Variables": self.environment_vars}
'Variables': self.environment_vars
}
return config return config
def get_code(self): def get_code(self):
return { return {
"Code": { "Code": {
"Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format(self.region, self.code['S3Key']), "Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format(
"RepositoryType": "S3" self.region, self.code["S3Key"]
),
"RepositoryType": "S3",
}, },
"Configuration": self.get_configuration(), "Configuration": self.get_configuration(),
} }
@ -295,43 +308,48 @@ class LambdaFunction(BaseModel):
return self.get_configuration() return self.get_configuration()
def update_function_code(self, updated_spec): def update_function_code(self, updated_spec):
if 'DryRun' in updated_spec and updated_spec['DryRun']: if "DryRun" in updated_spec and updated_spec["DryRun"]:
return self.get_configuration() return self.get_configuration()
if 'ZipFile' in updated_spec: if "ZipFile" in updated_spec:
self.code['ZipFile'] = updated_spec['ZipFile'] self.code["ZipFile"] = updated_spec["ZipFile"]
# using the "hackery" from __init__ because it seems to work # using the "hackery" from __init__ because it seems to work
# TODOs and FIXMEs included, because they'll need to be fixed # TODOs and FIXMEs included, because they'll need to be fixed
# in both places now # in both places now
try: try:
to_unzip_code = base64.b64decode( to_unzip_code = base64.b64decode(
bytes(updated_spec['ZipFile'], 'utf-8')) bytes(updated_spec["ZipFile"], "utf-8")
)
except Exception: except Exception:
to_unzip_code = base64.b64decode(updated_spec['ZipFile']) to_unzip_code = base64.b64decode(updated_spec["ZipFile"])
self.code_bytes = to_unzip_code self.code_bytes = to_unzip_code
self.code_size = len(to_unzip_code) self.code_size = len(to_unzip_code)
self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest() self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest()
# TODO: we should be putting this in a lambda bucket # TODO: we should be putting this in a lambda bucket
self.code['UUID'] = str(uuid.uuid4()) self.code["UUID"] = str(uuid.uuid4())
self.code['S3Key'] = '{}-{}'.format(self.function_name, self.code['UUID']) self.code["S3Key"] = "{}-{}".format(self.function_name, self.code["UUID"])
elif 'S3Bucket' in updated_spec and 'S3Key' in updated_spec: elif "S3Bucket" in updated_spec and "S3Key" in updated_spec:
key = None key = None
try: try:
# FIXME: does not validate bucket region # FIXME: does not validate bucket region
key = s3_backend.get_key(updated_spec['S3Bucket'], updated_spec['S3Key']) key = s3_backend.get_key(
updated_spec["S3Bucket"], updated_spec["S3Key"]
)
except MissingBucket: except MissingBucket:
if do_validate_s3(): if do_validate_s3():
raise ValueError( raise ValueError(
"InvalidParameterValueException", "InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist") "Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist",
)
except MissingKey: except MissingKey:
if do_validate_s3(): if do_validate_s3():
raise ValueError( raise ValueError(
"InvalidParameterValueException", "InvalidParameterValueException",
"Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.") "Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.",
)
if key: if key:
self.code_bytes = key.value self.code_bytes = key.value
self.code_size = key.size self.code_size = key.size
@ -342,7 +360,7 @@ class LambdaFunction(BaseModel):
@staticmethod @staticmethod
def convert(s): def convert(s):
try: try:
return str(s, encoding='utf-8') return str(s, encoding="utf-8")
except Exception: except Exception:
return s return s
@ -370,12 +388,21 @@ class LambdaFunction(BaseModel):
container = output = exit_code = None container = output = exit_code = None
with _DockerDataVolumeContext(self) as data_vol: with _DockerDataVolumeContext(self) as data_vol:
try: try:
run_kwargs = dict(links={'motoserver': 'motoserver'}) if settings.TEST_SERVER_MODE else {} run_kwargs = (
dict(links={"motoserver": "motoserver"})
if settings.TEST_SERVER_MODE
else {}
)
container = self.docker_client.containers.run( container = self.docker_client.containers.run(
"lambci/lambda:{}".format(self.run_time), "lambci/lambda:{}".format(self.run_time),
[self.handler, json.dumps(event)], remove=False, [self.handler, json.dumps(event)],
remove=False,
mem_limit="{}m".format(self.memory_size), mem_limit="{}m".format(self.memory_size),
volumes=["{}:/var/task".format(data_vol.name)], environment=env_vars, detach=True, **run_kwargs) volumes=["{}:/var/task".format(data_vol.name)],
environment=env_vars,
detach=True,
**run_kwargs
)
finally: finally:
if container: if container:
try: try:
@ -386,32 +413,43 @@ class LambdaFunction(BaseModel):
container.kill() container.kill()
else: else:
if docker_3: if docker_3:
exit_code = exit_code['StatusCode'] exit_code = exit_code["StatusCode"]
output = container.logs(stdout=False, stderr=True) output = container.logs(stdout=False, stderr=True)
output += container.logs(stdout=True, stderr=False) output += container.logs(stdout=True, stderr=False)
container.remove() container.remove()
output = output.decode('utf-8') output = output.decode("utf-8")
# Send output to "logs" backend # Send output to "logs" backend
invoke_id = uuid.uuid4().hex invoke_id = uuid.uuid4().hex
log_stream_name = "{date.year}/{date.month:02d}/{date.day:02d}/[{version}]{invoke_id}".format( log_stream_name = "{date.year}/{date.month:02d}/{date.day:02d}/[{version}]{invoke_id}".format(
date=datetime.datetime.utcnow(), version=self.version, invoke_id=invoke_id date=datetime.datetime.utcnow(),
version=self.version,
invoke_id=invoke_id,
) )
self.logs_backend.create_log_stream(self.logs_group_name, log_stream_name) self.logs_backend.create_log_stream(self.logs_group_name, log_stream_name)
log_events = [{'timestamp': unix_time_millis(), "message": line} log_events = [
for line in output.splitlines()] {"timestamp": unix_time_millis(), "message": line}
self.logs_backend.put_log_events(self.logs_group_name, log_stream_name, log_events, None) for line in output.splitlines()
]
self.logs_backend.put_log_events(
self.logs_group_name, log_stream_name, log_events, None
)
if exit_code != 0: if exit_code != 0:
raise Exception( raise Exception("lambda invoke failed output: {}".format(output))
'lambda invoke failed output: {}'.format(output))
# strip out RequestId lines # strip out RequestId lines
output = os.linesep.join([line for line in self.convert(output).splitlines() if not _stderr_regex.match(line)]) output = os.linesep.join(
[
line
for line in self.convert(output).splitlines()
if not _stderr_regex.match(line)
]
)
return output, False return output, False
except BaseException as e: except BaseException as e:
traceback.print_exc() traceback.print_exc()
@ -426,31 +464,34 @@ class LambdaFunction(BaseModel):
# Get the invocation type: # Get the invocation type:
res, errored = self._invoke_lambda(code=self.code, event=body) res, errored = self._invoke_lambda(code=self.code, event=body)
if request_headers.get("x-amz-invocation-type") == "RequestResponse": if request_headers.get("x-amz-invocation-type") == "RequestResponse":
encoded = base64.b64encode(res.encode('utf-8')) encoded = base64.b64encode(res.encode("utf-8"))
response_headers["x-amz-log-result"] = encoded.decode('utf-8') response_headers["x-amz-log-result"] = encoded.decode("utf-8")
payload['result'] = response_headers["x-amz-log-result"] payload["result"] = response_headers["x-amz-log-result"]
result = res.encode('utf-8') result = res.encode("utf-8")
else: else:
result = json.dumps(payload) result = json.dumps(payload)
if errored: if errored:
response_headers['x-amz-function-error'] = "Handled" response_headers["x-amz-function-error"] = "Handled"
return result return result
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, def create_from_cloudformation_json(
region_name): cls, resource_name, cloudformation_json, region_name
properties = cloudformation_json['Properties'] ):
properties = cloudformation_json["Properties"]
# required # required
spec = { spec = {
'Code': properties['Code'], "Code": properties["Code"],
'FunctionName': resource_name, "FunctionName": resource_name,
'Handler': properties['Handler'], "Handler": properties["Handler"],
'Role': properties['Role'], "Role": properties["Role"],
'Runtime': properties['Runtime'], "Runtime": properties["Runtime"],
} }
optional_properties = 'Description MemorySize Publish Timeout VpcConfig Environment'.split() optional_properties = (
"Description MemorySize Publish Timeout VpcConfig Environment".split()
)
# NOTE: Not doing `properties.get(k, DEFAULT)` to avoid duplicating the # NOTE: Not doing `properties.get(k, DEFAULT)` to avoid duplicating the
# default logic # default logic
for prop in optional_properties: for prop in optional_properties:
@ -460,27 +501,27 @@ class LambdaFunction(BaseModel):
# when ZipFile is present in CloudFormation, per the official docs, # when ZipFile is present in CloudFormation, per the official docs,
# the code it's a plaintext code snippet up to 4096 bytes. # the code it's a plaintext code snippet up to 4096 bytes.
# this snippet converts this plaintext code to a proper base64-encoded ZIP file. # this snippet converts this plaintext code to a proper base64-encoded ZIP file.
if 'ZipFile' in properties['Code']: if "ZipFile" in properties["Code"]:
spec['Code']['ZipFile'] = base64.b64encode( spec["Code"]["ZipFile"] = base64.b64encode(
cls._create_zipfile_from_plaintext_code( cls._create_zipfile_from_plaintext_code(spec["Code"]["ZipFile"])
spec['Code']['ZipFile'])) )
backend = lambda_backends[region_name] backend = lambda_backends[region_name]
fn = backend.create_function(spec) fn = backend.create_function(spec)
return fn return fn
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import \ from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
UnformattedGetAttTemplateException
if attribute_name == 'Arn': if attribute_name == "Arn":
return make_function_arn(self.region, ACCOUNT_ID, self.function_name) return make_function_arn(self.region, ACCOUNT_ID, self.function_name)
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
@staticmethod @staticmethod
def _create_zipfile_from_plaintext_code(code): def _create_zipfile_from_plaintext_code(code):
zip_output = io.BytesIO() zip_output = io.BytesIO()
zip_file = zipfile.ZipFile(zip_output, 'w', zipfile.ZIP_DEFLATED) zip_file = zipfile.ZipFile(zip_output, "w", zipfile.ZIP_DEFLATED)
zip_file.writestr('lambda_function.zip', code) zip_file.writestr("lambda_function.zip", code)
zip_file.close() zip_file.close()
zip_output.seek(0) zip_output.seek(0)
return zip_output.read() return zip_output.read()
@ -489,61 +530,66 @@ class LambdaFunction(BaseModel):
class EventSourceMapping(BaseModel): class EventSourceMapping(BaseModel):
def __init__(self, spec): def __init__(self, spec):
# required # required
self.function_arn = spec['FunctionArn'] self.function_arn = spec["FunctionArn"]
self.event_source_arn = spec['EventSourceArn'] self.event_source_arn = spec["EventSourceArn"]
self.uuid = str(uuid.uuid4()) self.uuid = str(uuid.uuid4())
self.last_modified = time.mktime(datetime.datetime.utcnow().timetuple()) self.last_modified = time.mktime(datetime.datetime.utcnow().timetuple())
# BatchSize service default/max mapping # BatchSize service default/max mapping
batch_size_map = { batch_size_map = {
'kinesis': (100, 10000), "kinesis": (100, 10000),
'dynamodb': (100, 1000), "dynamodb": (100, 1000),
'sqs': (10, 10), "sqs": (10, 10),
} }
source_type = self.event_source_arn.split(":")[2].lower() source_type = self.event_source_arn.split(":")[2].lower()
batch_size_entry = batch_size_map.get(source_type) batch_size_entry = batch_size_map.get(source_type)
if batch_size_entry: if batch_size_entry:
# Use service default if not provided # Use service default if not provided
batch_size = int(spec.get('BatchSize', batch_size_entry[0])) batch_size = int(spec.get("BatchSize", batch_size_entry[0]))
if batch_size > batch_size_entry[1]: if batch_size > batch_size_entry[1]:
raise ValueError("InvalidParameterValueException", raise ValueError(
"BatchSize {} exceeds the max of {}".format(batch_size, batch_size_entry[1])) "InvalidParameterValueException",
"BatchSize {} exceeds the max of {}".format(
batch_size, batch_size_entry[1]
),
)
else: else:
self.batch_size = batch_size self.batch_size = batch_size
else: else:
raise ValueError("InvalidParameterValueException", raise ValueError(
"Unsupported event source type") "InvalidParameterValueException", "Unsupported event source type"
)
# optional # optional
self.starting_position = spec.get('StartingPosition', 'TRIM_HORIZON') self.starting_position = spec.get("StartingPosition", "TRIM_HORIZON")
self.enabled = spec.get('Enabled', True) self.enabled = spec.get("Enabled", True)
self.starting_position_timestamp = spec.get('StartingPositionTimestamp', self.starting_position_timestamp = spec.get("StartingPositionTimestamp", None)
None)
def get_configuration(self): def get_configuration(self):
return { return {
'UUID': self.uuid, "UUID": self.uuid,
'BatchSize': self.batch_size, "BatchSize": self.batch_size,
'EventSourceArn': self.event_source_arn, "EventSourceArn": self.event_source_arn,
'FunctionArn': self.function_arn, "FunctionArn": self.function_arn,
'LastModified': self.last_modified, "LastModified": self.last_modified,
'LastProcessingResult': '', "LastProcessingResult": "",
'State': 'Enabled' if self.enabled else 'Disabled', "State": "Enabled" if self.enabled else "Disabled",
'StateTransitionReason': 'User initiated' "StateTransitionReason": "User initiated",
} }
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, def create_from_cloudformation_json(
region_name): cls, resource_name, cloudformation_json, region_name
properties = cloudformation_json['Properties'] ):
func = lambda_backends[region_name].get_function(properties['FunctionName']) properties = cloudformation_json["Properties"]
func = lambda_backends[region_name].get_function(properties["FunctionName"])
spec = { spec = {
'FunctionArn': func.function_arn, "FunctionArn": func.function_arn,
'EventSourceArn': properties['EventSourceArn'], "EventSourceArn": properties["EventSourceArn"],
'StartingPosition': properties['StartingPosition'], "StartingPosition": properties["StartingPosition"],
'BatchSize': properties.get('BatchSize', 100) "BatchSize": properties.get("BatchSize", 100),
} }
optional_properties = 'BatchSize Enabled StartingPositionTimestamp'.split() optional_properties = "BatchSize Enabled StartingPositionTimestamp".split()
for prop in optional_properties: for prop in optional_properties:
if prop in properties: if prop in properties:
spec[prop] = properties[prop] spec[prop] = properties[prop]
@ -552,20 +598,19 @@ class EventSourceMapping(BaseModel):
class LambdaVersion(BaseModel): class LambdaVersion(BaseModel):
def __init__(self, spec): def __init__(self, spec):
self.version = spec['Version'] self.version = spec["Version"]
def __repr__(self): def __repr__(self):
return str(self.logical_resource_id) return str(self.logical_resource_id)
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, def create_from_cloudformation_json(
region_name): cls, resource_name, cloudformation_json, region_name
properties = cloudformation_json['Properties'] ):
function_name = properties['FunctionName'] properties = cloudformation_json["Properties"]
function_name = properties["FunctionName"]
func = lambda_backends[region_name].publish_function(function_name) func = lambda_backends[region_name].publish_function(function_name)
spec = { spec = {"Version": func.version}
'Version': func.version
}
return LambdaVersion(spec) return LambdaVersion(spec)
@ -576,18 +621,18 @@ class LambdaStorage(object):
self._arns = weakref.WeakValueDictionary() self._arns = weakref.WeakValueDictionary()
def _get_latest(self, name): def _get_latest(self, name):
return self._functions[name]['latest'] return self._functions[name]["latest"]
def _get_version(self, name, version): def _get_version(self, name, version):
index = version - 1 index = version - 1
try: try:
return self._functions[name]['versions'][index] return self._functions[name]["versions"][index]
except IndexError: except IndexError:
return None return None
def _get_alias(self, name, alias): def _get_alias(self, name, alias):
return self._functions[name]['alias'].get(alias, None) return self._functions[name]["alias"].get(alias, None)
def get_function(self, name, qualifier=None): def get_function(self, name, qualifier=None):
if name not in self._functions: if name not in self._functions:
@ -599,15 +644,15 @@ class LambdaStorage(object):
try: try:
return self._get_version(name, int(qualifier)) return self._get_version(name, int(qualifier))
except ValueError: except ValueError:
return self._functions[name]['latest'] return self._functions[name]["latest"]
def list_versions_by_function(self, name): def list_versions_by_function(self, name):
if name not in self._functions: if name not in self._functions:
return None return None
latest = copy.copy(self._functions[name]['latest']) latest = copy.copy(self._functions[name]["latest"])
latest.function_arn += ':$LATEST' latest.function_arn += ":$LATEST"
return [latest] + self._functions[name]['versions'] return [latest] + self._functions[name]["versions"]
def get_arn(self, arn): def get_arn(self, arn):
return self._arns.get(arn, None) return self._arns.get(arn, None)
@ -621,12 +666,12 @@ class LambdaStorage(object):
:type fn: LambdaFunction :type fn: LambdaFunction
""" """
if fn.function_name in self._functions: if fn.function_name in self._functions:
self._functions[fn.function_name]['latest'] = fn self._functions[fn.function_name]["latest"] = fn
else: else:
self._functions[fn.function_name] = { self._functions[fn.function_name] = {
'latest': fn, "latest": fn,
'versions': [], "versions": [],
'alias': weakref.WeakValueDictionary() "alias": weakref.WeakValueDictionary(),
} }
self._arns[fn.function_arn] = fn self._arns[fn.function_arn] = fn
@ -634,14 +679,14 @@ class LambdaStorage(object):
def publish_function(self, name): def publish_function(self, name):
if name not in self._functions: if name not in self._functions:
return None return None
if not self._functions[name]['latest']: if not self._functions[name]["latest"]:
return None return None
new_version = len(self._functions[name]['versions']) + 1 new_version = len(self._functions[name]["versions"]) + 1
fn = copy.copy(self._functions[name]['latest']) fn = copy.copy(self._functions[name]["latest"])
fn.set_version(new_version) fn.set_version(new_version)
self._functions[name]['versions'].append(fn) self._functions[name]["versions"].append(fn)
self._arns[fn.function_arn] = fn self._arns[fn.function_arn] = fn
return fn return fn
@ -651,21 +696,24 @@ class LambdaStorage(object):
name = function.function_name name = function.function_name
if not qualifier: if not qualifier:
# Something is still reffing this so delete all arns # Something is still reffing this so delete all arns
latest = self._functions[name]['latest'].function_arn latest = self._functions[name]["latest"].function_arn
del self._arns[latest] del self._arns[latest]
for fn in self._functions[name]['versions']: for fn in self._functions[name]["versions"]:
del self._arns[fn.function_arn] del self._arns[fn.function_arn]
del self._functions[name] del self._functions[name]
return True return True
elif qualifier == '$LATEST': elif qualifier == "$LATEST":
self._functions[name]['latest'] = None self._functions[name]["latest"] = None
# If theres no functions left # If theres no functions left
if not self._functions[name]['versions'] and not self._functions[name]['latest']: if (
not self._functions[name]["versions"]
and not self._functions[name]["latest"]
):
del self._functions[name] del self._functions[name]
return True return True
@ -673,10 +721,13 @@ class LambdaStorage(object):
else: else:
fn = self.get_function(name, qualifier) fn = self.get_function(name, qualifier)
if fn: if fn:
self._functions[name]['versions'].remove(fn) self._functions[name]["versions"].remove(fn)
# If theres no functions left # If theres no functions left
if not self._functions[name]['versions'] and not self._functions[name]['latest']: if (
not self._functions[name]["versions"]
and not self._functions[name]["latest"]
):
del self._functions[name] del self._functions[name]
return True return True
@ -687,10 +738,10 @@ class LambdaStorage(object):
result = [] result = []
for function_group in self._functions.values(): for function_group in self._functions.values():
if function_group['latest'] is not None: if function_group["latest"] is not None:
result.append(function_group['latest']) result.append(function_group["latest"])
result.extend(function_group['versions']) result.extend(function_group["versions"])
return result return result
@ -707,44 +758,47 @@ class LambdaBackend(BaseBackend):
self.__init__(region_name) self.__init__(region_name)
def create_function(self, spec): def create_function(self, spec):
function_name = spec.get('FunctionName', None) function_name = spec.get("FunctionName", None)
if function_name is None: if function_name is None:
raise RESTError('InvalidParameterValueException', 'Missing FunctionName') raise RESTError("InvalidParameterValueException", "Missing FunctionName")
fn = LambdaFunction(spec, self.region_name, version='$LATEST') fn = LambdaFunction(spec, self.region_name, version="$LATEST")
self._lambdas.put_function(fn) self._lambdas.put_function(fn)
if spec.get('Publish'): if spec.get("Publish"):
ver = self.publish_function(function_name) ver = self.publish_function(function_name)
fn.version = ver.version fn.version = ver.version
return fn return fn
def create_event_source_mapping(self, spec): def create_event_source_mapping(self, spec):
required = [ required = ["EventSourceArn", "FunctionName"]
'EventSourceArn',
'FunctionName',
]
for param in required: for param in required:
if not spec.get(param): if not spec.get(param):
raise RESTError('InvalidParameterValueException', 'Missing {}'.format(param)) raise RESTError(
"InvalidParameterValueException", "Missing {}".format(param)
)
# Validate function name # Validate function name
func = self._lambdas.get_function_by_name_or_arn(spec.pop('FunctionName', '')) func = self._lambdas.get_function_by_name_or_arn(spec.pop("FunctionName", ""))
if not func: if not func:
raise RESTError('ResourceNotFoundException', 'Invalid FunctionName') raise RESTError("ResourceNotFoundException", "Invalid FunctionName")
# Validate queue # Validate queue
for queue in sqs_backends[self.region_name].queues.values(): for queue in sqs_backends[self.region_name].queues.values():
if queue.queue_arn == spec['EventSourceArn']: if queue.queue_arn == spec["EventSourceArn"]:
if queue.lambda_event_source_mappings.get('func.function_arn'): if queue.lambda_event_source_mappings.get("func.function_arn"):
# TODO: Correct exception? # TODO: Correct exception?
raise RESTError('ResourceConflictException', 'The resource already exists.') raise RESTError(
"ResourceConflictException", "The resource already exists."
)
if queue.fifo_queue: if queue.fifo_queue:
raise RESTError('InvalidParameterValueException', raise RESTError(
'{} is FIFO'.format(queue.queue_arn)) "InvalidParameterValueException",
"{} is FIFO".format(queue.queue_arn),
)
else: else:
spec.update({'FunctionArn': func.function_arn}) spec.update({"FunctionArn": func.function_arn})
esm = EventSourceMapping(spec) esm = EventSourceMapping(spec)
self._event_source_mappings[esm.uuid] = esm self._event_source_mappings[esm.uuid] = esm
@ -752,16 +806,18 @@ class LambdaBackend(BaseBackend):
queue.lambda_event_source_mappings[esm.function_arn] = esm queue.lambda_event_source_mappings[esm.function_arn] = esm
return esm return esm
for stream in json.loads(dynamodbstreams_backends[self.region_name].list_streams())['Streams']: for stream in json.loads(
if stream['StreamArn'] == spec['EventSourceArn']: dynamodbstreams_backends[self.region_name].list_streams()
spec.update({'FunctionArn': func.function_arn}) )["Streams"]:
if stream["StreamArn"] == spec["EventSourceArn"]:
spec.update({"FunctionArn": func.function_arn})
esm = EventSourceMapping(spec) esm = EventSourceMapping(spec)
self._event_source_mappings[esm.uuid] = esm self._event_source_mappings[esm.uuid] = esm
table_name = stream['TableName'] table_name = stream["TableName"]
table = dynamodb_backends2[self.region_name].get_table(table_name) table = dynamodb_backends2[self.region_name].get_table(table_name)
table.lambda_event_source_mappings[esm.function_arn] = esm table.lambda_event_source_mappings[esm.function_arn] = esm
return esm return esm
raise RESTError('ResourceNotFoundException', 'Invalid EventSourceArn') raise RESTError("ResourceNotFoundException", "Invalid EventSourceArn")
def publish_function(self, function_name): def publish_function(self, function_name):
return self._lambdas.publish_function(function_name) return self._lambdas.publish_function(function_name)
@ -781,13 +837,15 @@ class LambdaBackend(BaseBackend):
def update_event_source_mapping(self, uuid, spec): def update_event_source_mapping(self, uuid, spec):
esm = self.get_event_source_mapping(uuid) esm = self.get_event_source_mapping(uuid)
if esm: if esm:
if spec.get('FunctionName'): if spec.get("FunctionName"):
func = self._lambdas.get_function_by_name_or_arn(spec.get('FunctionName')) func = self._lambdas.get_function_by_name_or_arn(
spec.get("FunctionName")
)
esm.function_arn = func.function_arn esm.function_arn = func.function_arn
if 'BatchSize' in spec: if "BatchSize" in spec:
esm.batch_size = spec['BatchSize'] esm.batch_size = spec["BatchSize"]
if 'Enabled' in spec: if "Enabled" in spec:
esm.enabled = spec['Enabled'] esm.enabled = spec["Enabled"]
return esm return esm
return False return False
@ -828,13 +886,13 @@ class LambdaBackend(BaseBackend):
"ApproximateReceiveCount": "1", "ApproximateReceiveCount": "1",
"SentTimestamp": "1545082649183", "SentTimestamp": "1545082649183",
"SenderId": "AIDAIENQZJOLO23YVJ4VO", "SenderId": "AIDAIENQZJOLO23YVJ4VO",
"ApproximateFirstReceiveTimestamp": "1545082649185" "ApproximateFirstReceiveTimestamp": "1545082649185",
}, },
"messageAttributes": {}, "messageAttributes": {},
"md5OfBody": "098f6bcd4621d373cade4e832627b4f6", "md5OfBody": "098f6bcd4621d373cade4e832627b4f6",
"eventSource": "aws:sqs", "eventSource": "aws:sqs",
"eventSourceARN": queue_arn, "eventSourceARN": queue_arn,
"awsRegion": self.region_name "awsRegion": self.region_name,
} }
] ]
} }
@ -842,7 +900,7 @@ class LambdaBackend(BaseBackend):
request_headers = {} request_headers = {}
response_headers = {} response_headers = {}
func.invoke(json.dumps(event), request_headers, response_headers) func.invoke(json.dumps(event), request_headers, response_headers)
return 'x-amz-function-error' not in response_headers return "x-amz-function-error" not in response_headers
def send_sns_message(self, function_name, message, subject=None, qualifier=None): def send_sns_message(self, function_name, message, subject=None, qualifier=None):
event = { event = {
@ -859,37 +917,35 @@ class LambdaBackend(BaseBackend):
"MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e", "MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
"Message": message, "Message": message,
"MessageAttributes": { "MessageAttributes": {
"Test": { "Test": {"Type": "String", "Value": "TestString"},
"Type": "String", "TestBinary": {"Type": "Binary", "Value": "TestBinary"},
"Value": "TestString"
},
"TestBinary": {
"Type": "Binary",
"Value": "TestBinary"
}
}, },
"Type": "Notification", "Type": "Notification",
"UnsubscribeUrl": "EXAMPLE", "UnsubscribeUrl": "EXAMPLE",
"TopicArn": "arn:aws:sns:EXAMPLE", "TopicArn": "arn:aws:sns:EXAMPLE",
"Subject": subject or "TestInvoke" "Subject": subject or "TestInvoke",
} },
} }
] ]
} }
func = self._lambdas.get_function(function_name, qualifier) func = self._lambdas.get_function(function_name, qualifier)
func.invoke(json.dumps(event), {}, {}) func.invoke(json.dumps(event), {}, {})
def send_dynamodb_items(self, function_arn, items, source): def send_dynamodb_items(self, function_arn, items, source):
event = {'Records': [ event = {
"Records": [
{ {
'eventID': item.to_json()['eventID'], "eventID": item.to_json()["eventID"],
'eventName': 'INSERT', "eventName": "INSERT",
'eventVersion': item.to_json()['eventVersion'], "eventVersion": item.to_json()["eventVersion"],
'eventSource': item.to_json()['eventSource'], "eventSource": item.to_json()["eventSource"],
'awsRegion': self.region_name, "awsRegion": self.region_name,
'dynamodb': item.to_json()['dynamodb'], "dynamodb": item.to_json()["dynamodb"],
'eventSourceARN': source} for item in items]} "eventSourceARN": source,
}
for item in items
]
}
func = self._lambdas.get_arn(function_arn) func = self._lambdas.get_arn(function_arn)
func.invoke(json.dumps(event), {}, {}) func.invoke(json.dumps(event), {}, {})
@ -921,12 +977,13 @@ class LambdaBackend(BaseBackend):
def do_validate_s3(): def do_validate_s3():
return os.environ.get('VALIDATE_LAMBDA_S3', '') in ['', '1', 'true'] return os.environ.get("VALIDATE_LAMBDA_S3", "") in ["", "1", "true"]
# Handle us forgotten regions, unless Lambda truly only runs out of US and # Handle us forgotten regions, unless Lambda truly only runs out of US and
lambda_backends = {_region.name: LambdaBackend(_region.name) lambda_backends = {
for _region in boto.awslambda.regions()} _region.name: LambdaBackend(_region.name) for _region in boto.awslambda.regions()
}
lambda_backends['ap-southeast-2'] = LambdaBackend('ap-southeast-2') lambda_backends["ap-southeast-2"] = LambdaBackend("ap-southeast-2")
lambda_backends['us-gov-west-1'] = LambdaBackend('us-gov-west-1') lambda_backends["us-gov-west-1"] = LambdaBackend("us-gov-west-1")

View File

@ -32,57 +32,57 @@ class LambdaResponse(BaseResponse):
def root(self, request, full_url, headers): def root(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'GET': if request.method == "GET":
return self._list_functions(request, full_url, headers) return self._list_functions(request, full_url, headers)
elif request.method == 'POST': elif request.method == "POST":
return self._create_function(request, full_url, headers) return self._create_function(request, full_url, headers)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def event_source_mappings(self, request, full_url, headers): def event_source_mappings(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'GET': if request.method == "GET":
querystring = self.querystring querystring = self.querystring
event_source_arn = querystring.get('EventSourceArn', [None])[0] event_source_arn = querystring.get("EventSourceArn", [None])[0]
function_name = querystring.get('FunctionName', [None])[0] function_name = querystring.get("FunctionName", [None])[0]
return self._list_event_source_mappings(event_source_arn, function_name) return self._list_event_source_mappings(event_source_arn, function_name)
elif request.method == 'POST': elif request.method == "POST":
return self._create_event_source_mapping(request, full_url, headers) return self._create_event_source_mapping(request, full_url, headers)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def event_source_mapping(self, request, full_url, headers): def event_source_mapping(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
path = request.path if hasattr(request, 'path') else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
uuid = path.split('/')[-1] uuid = path.split("/")[-1]
if request.method == 'GET': if request.method == "GET":
return self._get_event_source_mapping(uuid) return self._get_event_source_mapping(uuid)
elif request.method == 'PUT': elif request.method == "PUT":
return self._update_event_source_mapping(uuid) return self._update_event_source_mapping(uuid)
elif request.method == 'DELETE': elif request.method == "DELETE":
return self._delete_event_source_mapping(uuid) return self._delete_event_source_mapping(uuid)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def function(self, request, full_url, headers): def function(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'GET': if request.method == "GET":
return self._get_function(request, full_url, headers) return self._get_function(request, full_url, headers)
elif request.method == 'DELETE': elif request.method == "DELETE":
return self._delete_function(request, full_url, headers) return self._delete_function(request, full_url, headers)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def versions(self, request, full_url, headers): def versions(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'GET': if request.method == "GET":
# This is ListVersionByFunction # This is ListVersionByFunction
path = request.path if hasattr(request, 'path') else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split('/')[-2] function_name = path.split("/")[-2]
return self._list_versions_by_function(function_name) return self._list_versions_by_function(function_name)
elif request.method == 'POST': elif request.method == "POST":
return self._publish_function(request, full_url, headers) return self._publish_function(request, full_url, headers)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
@ -91,7 +91,7 @@ class LambdaResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def invoke(self, request, full_url, headers): def invoke(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'POST': if request.method == "POST":
return self._invoke(request, full_url) return self._invoke(request, full_url)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
@ -100,46 +100,46 @@ class LambdaResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def invoke_async(self, request, full_url, headers): def invoke_async(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'POST': if request.method == "POST":
return self._invoke_async(request, full_url) return self._invoke_async(request, full_url)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def tag(self, request, full_url, headers): def tag(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'GET': if request.method == "GET":
return self._list_tags(request, full_url) return self._list_tags(request, full_url)
elif request.method == 'POST': elif request.method == "POST":
return self._tag_resource(request, full_url) return self._tag_resource(request, full_url)
elif request.method == 'DELETE': elif request.method == "DELETE":
return self._untag_resource(request, full_url) return self._untag_resource(request, full_url)
else: else:
raise ValueError("Cannot handle {0} request".format(request.method)) raise ValueError("Cannot handle {0} request".format(request.method))
def policy(self, request, full_url, headers): def policy(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'GET': if request.method == "GET":
return self._get_policy(request, full_url, headers) return self._get_policy(request, full_url, headers)
if request.method == 'POST': if request.method == "POST":
return self._add_policy(request, full_url, headers) return self._add_policy(request, full_url, headers)
def configuration(self, request, full_url, headers): def configuration(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'PUT': if request.method == "PUT":
return self._put_configuration(request) return self._put_configuration(request)
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def code(self, request, full_url, headers): def code(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
if request.method == 'PUT': if request.method == "PUT":
return self._put_code() return self._put_code()
else: else:
raise ValueError("Cannot handle request") raise ValueError("Cannot handle request")
def _add_policy(self, request, full_url, headers): def _add_policy(self, request, full_url, headers):
path = request.path if hasattr(request, 'path') else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split('/')[-2] function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name): if self.lambda_backend.get_function(function_name):
policy = self.body policy = self.body
self.lambda_backend.add_policy(function_name, policy) self.lambda_backend.add_policy(function_name, policy)
@ -148,24 +148,30 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}" return 404, {}, "{}"
def _get_policy(self, request, full_url, headers): def _get_policy(self, request, full_url, headers):
path = request.path if hasattr(request, 'path') else path_url(request.url) path = request.path if hasattr(request, "path") else path_url(request.url)
function_name = path.split('/')[-2] function_name = path.split("/")[-2]
if self.lambda_backend.get_function(function_name): if self.lambda_backend.get_function(function_name):
lambda_function = self.lambda_backend.get_function(function_name) lambda_function = self.lambda_backend.get_function(function_name)
return 200, {}, json.dumps(dict(Policy="{\"Statement\":[" + lambda_function.policy + "]}")) return (
200,
{},
json.dumps(
dict(Policy='{"Statement":[' + lambda_function.policy + "]}")
),
)
else: else:
return 404, {}, "{}" return 404, {}, "{}"
def _invoke(self, request, full_url): def _invoke(self, request, full_url):
response_headers = {} response_headers = {}
function_name = self.path.rsplit('/', 2)[-2] function_name = self.path.rsplit("/", 2)[-2]
qualifier = self._get_param('qualifier') qualifier = self._get_param("qualifier")
fn = self.lambda_backend.get_function(function_name, qualifier) fn = self.lambda_backend.get_function(function_name, qualifier)
if fn: if fn:
payload = fn.invoke(self.body, self.headers, response_headers) payload = fn.invoke(self.body, self.headers, response_headers)
response_headers['Content-Length'] = str(len(payload)) response_headers["Content-Length"] = str(len(payload))
return 202, response_headers, payload return 202, response_headers, payload
else: else:
return 404, response_headers, "{}" return 404, response_headers, "{}"
@ -173,38 +179,34 @@ class LambdaResponse(BaseResponse):
def _invoke_async(self, request, full_url): def _invoke_async(self, request, full_url):
response_headers = {} response_headers = {}
function_name = self.path.rsplit('/', 3)[-3] function_name = self.path.rsplit("/", 3)[-3]
fn = self.lambda_backend.get_function(function_name, None) fn = self.lambda_backend.get_function(function_name, None)
if fn: if fn:
payload = fn.invoke(self.body, self.headers, response_headers) payload = fn.invoke(self.body, self.headers, response_headers)
response_headers['Content-Length'] = str(len(payload)) response_headers["Content-Length"] = str(len(payload))
return 202, response_headers, payload return 202, response_headers, payload
else: else:
return 404, response_headers, "{}" return 404, response_headers, "{}"
def _list_functions(self, request, full_url, headers): def _list_functions(self, request, full_url, headers):
result = { result = {"Functions": []}
'Functions': []
}
for fn in self.lambda_backend.list_functions(): for fn in self.lambda_backend.list_functions():
json_data = fn.get_configuration() json_data = fn.get_configuration()
json_data['Version'] = '$LATEST' json_data["Version"] = "$LATEST"
result['Functions'].append(json_data) result["Functions"].append(json_data)
return 200, {}, json.dumps(result) return 200, {}, json.dumps(result)
def _list_versions_by_function(self, function_name): def _list_versions_by_function(self, function_name):
result = { result = {"Versions": []}
'Versions': []
}
functions = self.lambda_backend.list_versions_by_function(function_name) functions = self.lambda_backend.list_versions_by_function(function_name)
if functions: if functions:
for fn in functions: for fn in functions:
json_data = fn.get_configuration() json_data = fn.get_configuration()
result['Versions'].append(json_data) result["Versions"].append(json_data)
return 200, {}, json.dumps(result) return 200, {}, json.dumps(result)
@ -212,7 +214,11 @@ class LambdaResponse(BaseResponse):
try: try:
fn = self.lambda_backend.create_function(self.json_body) fn = self.lambda_backend.create_function(self.json_body)
except ValueError as e: except ValueError as e:
return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}) return (
400,
{},
json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}),
)
else: else:
config = fn.get_configuration() config = fn.get_configuration()
return 201, {}, json.dumps(config) return 201, {}, json.dumps(config)
@ -221,16 +227,20 @@ class LambdaResponse(BaseResponse):
try: try:
fn = self.lambda_backend.create_event_source_mapping(self.json_body) fn = self.lambda_backend.create_event_source_mapping(self.json_body)
except ValueError as e: except ValueError as e:
return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}) return (
400,
{},
json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}),
)
else: else:
config = fn.get_configuration() config = fn.get_configuration()
return 201, {}, json.dumps(config) return 201, {}, json.dumps(config)
def _list_event_source_mappings(self, event_source_arn, function_name): def _list_event_source_mappings(self, event_source_arn, function_name):
esms = self.lambda_backend.list_event_source_mappings(event_source_arn, function_name) esms = self.lambda_backend.list_event_source_mappings(
result = { event_source_arn, function_name
'EventSourceMappings': [esm.get_configuration() for esm in esms] )
} result = {"EventSourceMappings": [esm.get_configuration() for esm in esms]}
return 200, {}, json.dumps(result) return 200, {}, json.dumps(result)
def _get_event_source_mapping(self, uuid): def _get_event_source_mapping(self, uuid):
@ -251,13 +261,13 @@ class LambdaResponse(BaseResponse):
esm = self.lambda_backend.delete_event_source_mapping(uuid) esm = self.lambda_backend.delete_event_source_mapping(uuid)
if esm: if esm:
json_result = esm.get_configuration() json_result = esm.get_configuration()
json_result.update({'State': 'Deleting'}) json_result.update({"State": "Deleting"})
return 202, {}, json.dumps(json_result) return 202, {}, json.dumps(json_result)
else: else:
return 404, {}, "{}" return 404, {}, "{}"
def _publish_function(self, request, full_url, headers): def _publish_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 2)[-2] function_name = self.path.rsplit("/", 2)[-2]
fn = self.lambda_backend.publish_function(function_name) fn = self.lambda_backend.publish_function(function_name)
if fn: if fn:
@ -267,8 +277,8 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}" return 404, {}, "{}"
def _delete_function(self, request, full_url, headers): def _delete_function(self, request, full_url, headers):
function_name = unquote(self.path.rsplit('/', 1)[-1]) function_name = unquote(self.path.rsplit("/", 1)[-1])
qualifier = self._get_param('Qualifier', None) qualifier = self._get_param("Qualifier", None)
if self.lambda_backend.delete_function(function_name, qualifier): if self.lambda_backend.delete_function(function_name, qualifier):
return 204, {}, "" return 204, {}, ""
@ -276,17 +286,17 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}" return 404, {}, "{}"
def _get_function(self, request, full_url, headers): def _get_function(self, request, full_url, headers):
function_name = self.path.rsplit('/', 1)[-1] function_name = self.path.rsplit("/", 1)[-1]
qualifier = self._get_param('Qualifier', None) qualifier = self._get_param("Qualifier", None)
fn = self.lambda_backend.get_function(function_name, qualifier) fn = self.lambda_backend.get_function(function_name, qualifier)
if fn: if fn:
code = fn.get_code() code = fn.get_code()
if qualifier is None or qualifier == '$LATEST': if qualifier is None or qualifier == "$LATEST":
code['Configuration']['Version'] = '$LATEST' code["Configuration"]["Version"] = "$LATEST"
if qualifier == '$LATEST': if qualifier == "$LATEST":
code['Configuration']['FunctionArn'] += ':$LATEST' code["Configuration"]["FunctionArn"] += ":$LATEST"
return 200, {}, json.dumps(code) return 200, {}, json.dumps(code)
else: else:
return 404, {}, "{}" return 404, {}, "{}"
@ -299,25 +309,25 @@ class LambdaResponse(BaseResponse):
return self.default_region return self.default_region
def _list_tags(self, request, full_url): def _list_tags(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1]) function_arn = unquote(self.path.rsplit("/", 1)[-1])
fn = self.lambda_backend.get_function_by_arn(function_arn) fn = self.lambda_backend.get_function_by_arn(function_arn)
if fn: if fn:
return 200, {}, json.dumps({'Tags': fn.tags}) return 200, {}, json.dumps({"Tags": fn.tags})
else: else:
return 404, {}, "{}" return 404, {}, "{}"
def _tag_resource(self, request, full_url): def _tag_resource(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1]) function_arn = unquote(self.path.rsplit("/", 1)[-1])
if self.lambda_backend.tag_resource(function_arn, self.json_body['Tags']): if self.lambda_backend.tag_resource(function_arn, self.json_body["Tags"]):
return 200, {}, "{}" return 200, {}, "{}"
else: else:
return 404, {}, "{}" return 404, {}, "{}"
def _untag_resource(self, request, full_url): def _untag_resource(self, request, full_url):
function_arn = unquote(self.path.rsplit('/', 1)[-1]) function_arn = unquote(self.path.rsplit("/", 1)[-1])
tag_keys = self.querystring['tagKeys'] tag_keys = self.querystring["tagKeys"]
if self.lambda_backend.untag_resource(function_arn, tag_keys): if self.lambda_backend.untag_resource(function_arn, tag_keys):
return 204, {}, "{}" return 204, {}, "{}"
@ -325,8 +335,8 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}" return 404, {}, "{}"
def _put_configuration(self, request): def _put_configuration(self, request):
function_name = self.path.rsplit('/', 2)[-2] function_name = self.path.rsplit("/", 2)[-2]
qualifier = self._get_param('Qualifier', None) qualifier = self._get_param("Qualifier", None)
fn = self.lambda_backend.get_function(function_name, qualifier) fn = self.lambda_backend.get_function(function_name, qualifier)
@ -337,13 +347,13 @@ class LambdaResponse(BaseResponse):
return 404, {}, "{}" return 404, {}, "{}"
def _put_code(self): def _put_code(self):
function_name = self.path.rsplit('/', 2)[-2] function_name = self.path.rsplit("/", 2)[-2]
qualifier = self._get_param('Qualifier', None) qualifier = self._get_param("Qualifier", None)
fn = self.lambda_backend.get_function(function_name, qualifier) fn = self.lambda_backend.get_function(function_name, qualifier)
if fn: if fn:
if self.json_body.get('Publish', False): if self.json_body.get("Publish", False):
fn = self.lambda_backend.publish_function(function_name) fn = self.lambda_backend.publish_function(function_name)
config = fn.update_function_code(self.json_body) config = fn.update_function_code(self.json_body)

View File

@ -1,22 +1,20 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import LambdaResponse from .responses import LambdaResponse
url_bases = [ url_bases = ["https?://lambda.(.+).amazonaws.com"]
"https?://lambda.(.+).amazonaws.com",
]
response = LambdaResponse() response = LambdaResponse()
url_paths = { url_paths = {
'{0}/(?P<api_version>[^/]+)/functions/?$': response.root, "{0}/(?P<api_version>[^/]+)/functions/?$": response.root,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$': response.function, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_:%-]+)/?$": response.function,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$': response.versions, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/versions/?$": response.versions,
r'{0}/(?P<api_version>[^/]+)/event-source-mappings/?$': response.event_source_mappings, r"{0}/(?P<api_version>[^/]+)/event-source-mappings/?$": response.event_source_mappings,
r'{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$': response.event_source_mapping, r"{0}/(?P<api_version>[^/]+)/event-source-mappings/(?P<UUID>[\w_-]+)/?$": response.event_source_mapping,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$': response.invoke, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invocations/?$": response.invoke,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$': response.invoke_async, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/invoke-async/?$": response.invoke_async,
r'{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)': response.tag, r"{0}/(?P<api_version>[^/]+)/tags/(?P<resource_arn>.+)": response.tag,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$': response.policy, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/policy/?$": response.policy,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$': response.configuration, r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/configuration/?$": response.configuration,
r'{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$': response.code r"{0}/(?P<api_version>[^/]+)/functions/(?P<function_name>[\w_-]+)/code/?$": response.code,
} }

View File

@ -1,20 +1,20 @@
from collections import namedtuple from collections import namedtuple
ARN = namedtuple('ARN', ['region', 'account', 'function_name', 'version']) ARN = namedtuple("ARN", ["region", "account", "function_name", "version"])
def make_function_arn(region, account, name): def make_function_arn(region, account, name):
return 'arn:aws:lambda:{0}:{1}:function:{2}'.format(region, account, name) return "arn:aws:lambda:{0}:{1}:function:{2}".format(region, account, name)
def make_function_ver_arn(region, account, name, version='1'): def make_function_ver_arn(region, account, name, version="1"):
arn = make_function_arn(region, account, name) arn = make_function_arn(region, account, name)
return '{0}:{1}'.format(arn, version) return "{0}:{1}".format(arn, version)
def split_function_arn(arn): def split_function_arn(arn):
arn = arn.replace('arn:aws:lambda:') arn = arn.replace("arn:aws:lambda:")
region, account, _, name, version = arn.split(':') region, account, _, name, version = arn.split(":")
return ARN(region, account, name, version) return ARN(region, account, name, version)

View File

@ -52,57 +52,57 @@ from moto.resourcegroupstaggingapi import resourcegroupstaggingapi_backends
from moto.config import config_backends from moto.config import config_backends
BACKENDS = { BACKENDS = {
'acm': acm_backends, "acm": acm_backends,
'apigateway': apigateway_backends, "apigateway": apigateway_backends,
'athena': athena_backends, "athena": athena_backends,
'autoscaling': autoscaling_backends, "autoscaling": autoscaling_backends,
'batch': batch_backends, "batch": batch_backends,
'cloudformation': cloudformation_backends, "cloudformation": cloudformation_backends,
'cloudwatch': cloudwatch_backends, "cloudwatch": cloudwatch_backends,
'cognito-identity': cognitoidentity_backends, "cognito-identity": cognitoidentity_backends,
'cognito-idp': cognitoidp_backends, "cognito-idp": cognitoidp_backends,
'config': config_backends, "config": config_backends,
'datapipeline': datapipeline_backends, "datapipeline": datapipeline_backends,
'dynamodb': dynamodb_backends, "dynamodb": dynamodb_backends,
'dynamodb2': dynamodb_backends2, "dynamodb2": dynamodb_backends2,
'dynamodbstreams': dynamodbstreams_backends, "dynamodbstreams": dynamodbstreams_backends,
'ec2': ec2_backends, "ec2": ec2_backends,
'ecr': ecr_backends, "ecr": ecr_backends,
'ecs': ecs_backends, "ecs": ecs_backends,
'elb': elb_backends, "elb": elb_backends,
'elbv2': elbv2_backends, "elbv2": elbv2_backends,
'events': events_backends, "events": events_backends,
'emr': emr_backends, "emr": emr_backends,
'glacier': glacier_backends, "glacier": glacier_backends,
'glue': glue_backends, "glue": glue_backends,
'iam': iam_backends, "iam": iam_backends,
'moto_api': moto_api_backends, "moto_api": moto_api_backends,
'instance_metadata': instance_metadata_backends, "instance_metadata": instance_metadata_backends,
'logs': logs_backends, "logs": logs_backends,
'kinesis': kinesis_backends, "kinesis": kinesis_backends,
'kms': kms_backends, "kms": kms_backends,
'opsworks': opsworks_backends, "opsworks": opsworks_backends,
'organizations': organizations_backends, "organizations": organizations_backends,
'polly': polly_backends, "polly": polly_backends,
'redshift': redshift_backends, "redshift": redshift_backends,
'resource-groups': resourcegroups_backends, "resource-groups": resourcegroups_backends,
'rds': rds2_backends, "rds": rds2_backends,
's3': s3_backends, "s3": s3_backends,
's3bucket_path': s3_backends, "s3bucket_path": s3_backends,
'ses': ses_backends, "ses": ses_backends,
'secretsmanager': secretsmanager_backends, "secretsmanager": secretsmanager_backends,
'sns': sns_backends, "sns": sns_backends,
'sqs': sqs_backends, "sqs": sqs_backends,
'ssm': ssm_backends, "ssm": ssm_backends,
'stepfunctions': stepfunction_backends, "stepfunctions": stepfunction_backends,
'sts': sts_backends, "sts": sts_backends,
'swf': swf_backends, "swf": swf_backends,
'route53': route53_backends, "route53": route53_backends,
'lambda': lambda_backends, "lambda": lambda_backends,
'xray': xray_backends, "xray": xray_backends,
'resourcegroupstaggingapi': resourcegroupstaggingapi_backends, "resourcegroupstaggingapi": resourcegroupstaggingapi_backends,
'iot': iot_backends, "iot": iot_backends,
'iot-data': iotdata_backends, "iot-data": iotdata_backends,
} }
@ -110,6 +110,6 @@ def get_model(name, region_name):
for backends in BACKENDS.values(): for backends in BACKENDS.values():
for region, backend in backends.items(): for region, backend in backends.items():
if region == region_name: if region == region_name:
models = getattr(backend.__class__, '__models__', {}) models = getattr(backend.__class__, "__models__", {})
if name in models: if name in models:
return list(getattr(backend, models[name])()) return list(getattr(backend, models[name])())

View File

@ -2,5 +2,5 @@ from __future__ import unicode_literals
from .models import batch_backends from .models import batch_backends
from ..core.models import base_decorator from ..core.models import base_decorator
batch_backend = batch_backends['us-east-1'] batch_backend = batch_backends["us-east-1"]
mock_batch = base_decorator(batch_backends) mock_batch = base_decorator(batch_backends)

View File

@ -12,26 +12,29 @@ class AWSError(Exception):
self.status = status if status is not None else self.STATUS self.status = status if status is not None else self.STATUS
def response(self): def response(self):
return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status) return (
json.dumps({"__type": self.code, "message": self.message}),
dict(status=self.status),
)
class InvalidRequestException(AWSError): class InvalidRequestException(AWSError):
CODE = 'InvalidRequestException' CODE = "InvalidRequestException"
class InvalidParameterValueException(AWSError): class InvalidParameterValueException(AWSError):
CODE = 'InvalidParameterValue' CODE = "InvalidParameterValue"
class ValidationError(AWSError): class ValidationError(AWSError):
CODE = 'ValidationError' CODE = "ValidationError"
class InternalFailure(AWSError): class InternalFailure(AWSError):
CODE = 'InternalFailure' CODE = "InternalFailure"
STATUS = 500 STATUS = 500
class ClientException(AWSError): class ClientException(AWSError):
CODE = 'ClientException' CODE = "ClientException"
STATUS = 400 STATUS = 400

File diff suppressed because it is too large Load Diff

View File

@ -10,7 +10,7 @@ import json
class BatchResponse(BaseResponse): class BatchResponse(BaseResponse):
def _error(self, code, message): def _error(self, code, message):
return json.dumps({'__type': code, 'message': message}), dict(status=400) return json.dumps({"__type": code, "message": message}), dict(status=400)
@property @property
def batch_backend(self): def batch_backend(self):
@ -22,9 +22,9 @@ class BatchResponse(BaseResponse):
@property @property
def json(self): def json(self):
if self.body is None or self.body == '': if self.body is None or self.body == "":
self._json = {} self._json = {}
elif not hasattr(self, '_json'): elif not hasattr(self, "_json"):
try: try:
self._json = json.loads(self.body) self._json = json.loads(self.body)
except ValueError: except ValueError:
@ -39,153 +39,146 @@ class BatchResponse(BaseResponse):
def _get_action(self): def _get_action(self):
# Return element after the /v1/* # Return element after the /v1/*
return urlsplit(self.uri).path.lstrip('/').split('/')[1] return urlsplit(self.uri).path.lstrip("/").split("/")[1]
# CreateComputeEnvironment # CreateComputeEnvironment
def createcomputeenvironment(self): def createcomputeenvironment(self):
compute_env_name = self._get_param('computeEnvironmentName') compute_env_name = self._get_param("computeEnvironmentName")
compute_resource = self._get_param('computeResources') compute_resource = self._get_param("computeResources")
service_role = self._get_param('serviceRole') service_role = self._get_param("serviceRole")
state = self._get_param('state') state = self._get_param("state")
_type = self._get_param('type') _type = self._get_param("type")
try: try:
name, arn = self.batch_backend.create_compute_environment( name, arn = self.batch_backend.create_compute_environment(
compute_environment_name=compute_env_name, compute_environment_name=compute_env_name,
_type=_type, state=state, _type=_type,
state=state,
compute_resources=compute_resource, compute_resources=compute_resource,
service_role=service_role service_role=service_role,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
'computeEnvironmentArn': arn,
'computeEnvironmentName': name
}
return json.dumps(result) return json.dumps(result)
# DescribeComputeEnvironments # DescribeComputeEnvironments
def describecomputeenvironments(self): def describecomputeenvironments(self):
compute_environments = self._get_param('computeEnvironments') compute_environments = self._get_param("computeEnvironments")
max_results = self._get_param('maxResults') # Ignored, should be int max_results = self._get_param("maxResults") # Ignored, should be int
next_token = self._get_param('nextToken') # Ignored next_token = self._get_param("nextToken") # Ignored
envs = self.batch_backend.describe_compute_environments(compute_environments, max_results=max_results, next_token=next_token) envs = self.batch_backend.describe_compute_environments(
compute_environments, max_results=max_results, next_token=next_token
)
result = {'computeEnvironments': envs} result = {"computeEnvironments": envs}
return json.dumps(result) return json.dumps(result)
# DeleteComputeEnvironment # DeleteComputeEnvironment
def deletecomputeenvironment(self): def deletecomputeenvironment(self):
compute_environment = self._get_param('computeEnvironment') compute_environment = self._get_param("computeEnvironment")
try: try:
self.batch_backend.delete_compute_environment(compute_environment) self.batch_backend.delete_compute_environment(compute_environment)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""
# UpdateComputeEnvironment # UpdateComputeEnvironment
def updatecomputeenvironment(self): def updatecomputeenvironment(self):
compute_env_name = self._get_param('computeEnvironment') compute_env_name = self._get_param("computeEnvironment")
compute_resource = self._get_param('computeResources') compute_resource = self._get_param("computeResources")
service_role = self._get_param('serviceRole') service_role = self._get_param("serviceRole")
state = self._get_param('state') state = self._get_param("state")
try: try:
name, arn = self.batch_backend.update_compute_environment( name, arn = self.batch_backend.update_compute_environment(
compute_environment_name=compute_env_name, compute_environment_name=compute_env_name,
compute_resources=compute_resource, compute_resources=compute_resource,
service_role=service_role, service_role=service_role,
state=state state=state,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name}
'computeEnvironmentArn': arn,
'computeEnvironmentName': name
}
return json.dumps(result) return json.dumps(result)
# CreateJobQueue # CreateJobQueue
def createjobqueue(self): def createjobqueue(self):
compute_env_order = self._get_param('computeEnvironmentOrder') compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param('jobQueueName') queue_name = self._get_param("jobQueueName")
priority = self._get_param('priority') priority = self._get_param("priority")
state = self._get_param('state') state = self._get_param("state")
try: try:
name, arn = self.batch_backend.create_job_queue( name, arn = self.batch_backend.create_job_queue(
queue_name=queue_name, queue_name=queue_name,
priority=priority, priority=priority,
state=state, state=state,
compute_env_order=compute_env_order compute_env_order=compute_env_order,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {"jobQueueArn": arn, "jobQueueName": name}
'jobQueueArn': arn,
'jobQueueName': name
}
return json.dumps(result) return json.dumps(result)
# DescribeJobQueues # DescribeJobQueues
def describejobqueues(self): def describejobqueues(self):
job_queues = self._get_param('jobQueues') job_queues = self._get_param("jobQueues")
max_results = self._get_param('maxResults') # Ignored, should be int max_results = self._get_param("maxResults") # Ignored, should be int
next_token = self._get_param('nextToken') # Ignored next_token = self._get_param("nextToken") # Ignored
queues = self.batch_backend.describe_job_queues(job_queues, max_results=max_results, next_token=next_token) queues = self.batch_backend.describe_job_queues(
job_queues, max_results=max_results, next_token=next_token
)
result = {'jobQueues': queues} result = {"jobQueues": queues}
return json.dumps(result) return json.dumps(result)
# UpdateJobQueue # UpdateJobQueue
def updatejobqueue(self): def updatejobqueue(self):
compute_env_order = self._get_param('computeEnvironmentOrder') compute_env_order = self._get_param("computeEnvironmentOrder")
queue_name = self._get_param('jobQueue') queue_name = self._get_param("jobQueue")
priority = self._get_param('priority') priority = self._get_param("priority")
state = self._get_param('state') state = self._get_param("state")
try: try:
name, arn = self.batch_backend.update_job_queue( name, arn = self.batch_backend.update_job_queue(
queue_name=queue_name, queue_name=queue_name,
priority=priority, priority=priority,
state=state, state=state,
compute_env_order=compute_env_order compute_env_order=compute_env_order,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {"jobQueueArn": arn, "jobQueueName": name}
'jobQueueArn': arn,
'jobQueueName': name
}
return json.dumps(result) return json.dumps(result)
# DeleteJobQueue # DeleteJobQueue
def deletejobqueue(self): def deletejobqueue(self):
queue_name = self._get_param('jobQueue') queue_name = self._get_param("jobQueue")
self.batch_backend.delete_job_queue(queue_name) self.batch_backend.delete_job_queue(queue_name)
return '' return ""
# RegisterJobDefinition # RegisterJobDefinition
def registerjobdefinition(self): def registerjobdefinition(self):
container_properties = self._get_param('containerProperties') container_properties = self._get_param("containerProperties")
def_name = self._get_param('jobDefinitionName') def_name = self._get_param("jobDefinitionName")
parameters = self._get_param('parameters') parameters = self._get_param("parameters")
retry_strategy = self._get_param('retryStrategy') retry_strategy = self._get_param("retryStrategy")
_type = self._get_param('type') _type = self._get_param("type")
try: try:
name, arn, revision = self.batch_backend.register_job_definition( name, arn, revision = self.batch_backend.register_job_definition(
@ -193,104 +186,113 @@ class BatchResponse(BaseResponse):
parameters=parameters, parameters=parameters,
_type=_type, _type=_type,
retry_strategy=retry_strategy, retry_strategy=retry_strategy,
container_properties=container_properties container_properties=container_properties,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {
'jobDefinitionArn': arn, "jobDefinitionArn": arn,
'jobDefinitionName': name, "jobDefinitionName": name,
'revision': revision "revision": revision,
} }
return json.dumps(result) return json.dumps(result)
# DeregisterJobDefinition # DeregisterJobDefinition
def deregisterjobdefinition(self): def deregisterjobdefinition(self):
queue_name = self._get_param('jobDefinition') queue_name = self._get_param("jobDefinition")
self.batch_backend.deregister_job_definition(queue_name) self.batch_backend.deregister_job_definition(queue_name)
return '' return ""
# DescribeJobDefinitions # DescribeJobDefinitions
def describejobdefinitions(self): def describejobdefinitions(self):
job_def_name = self._get_param('jobDefinitionName') job_def_name = self._get_param("jobDefinitionName")
job_def_list = self._get_param('jobDefinitions') job_def_list = self._get_param("jobDefinitions")
max_results = self._get_param('maxResults') max_results = self._get_param("maxResults")
next_token = self._get_param('nextToken') next_token = self._get_param("nextToken")
status = self._get_param('status') status = self._get_param("status")
job_defs = self.batch_backend.describe_job_definitions(job_def_name, job_def_list, status, max_results, next_token) job_defs = self.batch_backend.describe_job_definitions(
job_def_name, job_def_list, status, max_results, next_token
)
result = {'jobDefinitions': [job.describe() for job in job_defs]} result = {"jobDefinitions": [job.describe() for job in job_defs]}
return json.dumps(result) return json.dumps(result)
# SubmitJob # SubmitJob
def submitjob(self): def submitjob(self):
container_overrides = self._get_param('containerOverrides') container_overrides = self._get_param("containerOverrides")
depends_on = self._get_param('dependsOn') depends_on = self._get_param("dependsOn")
job_def = self._get_param('jobDefinition') job_def = self._get_param("jobDefinition")
job_name = self._get_param('jobName') job_name = self._get_param("jobName")
job_queue = self._get_param('jobQueue') job_queue = self._get_param("jobQueue")
parameters = self._get_param('parameters') parameters = self._get_param("parameters")
retries = self._get_param('retryStrategy') retries = self._get_param("retryStrategy")
try: try:
name, job_id = self.batch_backend.submit_job( name, job_id = self.batch_backend.submit_job(
job_name, job_def, job_queue, job_name,
job_def,
job_queue,
parameters=parameters, parameters=parameters,
retries=retries, retries=retries,
depends_on=depends_on, depends_on=depends_on,
container_overrides=container_overrides container_overrides=container_overrides,
) )
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = { result = {"jobId": job_id, "jobName": name}
'jobId': job_id,
'jobName': name,
}
return json.dumps(result) return json.dumps(result)
# DescribeJobs # DescribeJobs
def describejobs(self): def describejobs(self):
jobs = self._get_param('jobs') jobs = self._get_param("jobs")
try: try:
return json.dumps({'jobs': self.batch_backend.describe_jobs(jobs)}) return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)})
except AWSError as err: except AWSError as err:
return err.response() return err.response()
# ListJobs # ListJobs
def listjobs(self): def listjobs(self):
job_queue = self._get_param('jobQueue') job_queue = self._get_param("jobQueue")
job_status = self._get_param('jobStatus') job_status = self._get_param("jobStatus")
max_results = self._get_param('maxResults') max_results = self._get_param("maxResults")
next_token = self._get_param('nextToken') next_token = self._get_param("nextToken")
try: try:
jobs = self.batch_backend.list_jobs(job_queue, job_status, max_results, next_token) jobs = self.batch_backend.list_jobs(
job_queue, job_status, max_results, next_token
)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
result = {'jobSummaryList': [{'jobId': job.job_id, 'jobName': job.job_name} for job in jobs]} result = {
"jobSummaryList": [
{"jobId": job.job_id, "jobName": job.job_name} for job in jobs
]
}
return json.dumps(result) return json.dumps(result)
# TerminateJob # TerminateJob
def terminatejob(self): def terminatejob(self):
job_id = self._get_param('jobId') job_id = self._get_param("jobId")
reason = self._get_param('reason') reason = self._get_param("reason")
try: try:
self.batch_backend.terminate_job(job_id, reason) self.batch_backend.terminate_job(job_id, reason)
except AWSError as err: except AWSError as err:
return err.response() return err.response()
return '' return ""
# CancelJob # CancelJob
def canceljob(self): # Theres some AWS semantics on the differences but for us they're identical ;-) def canceljob(
self,
): # Theres some AWS semantics on the differences but for us they're identical ;-)
return self.terminatejob() return self.terminatejob()

View File

@ -1,25 +1,23 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import BatchResponse from .responses import BatchResponse
url_bases = [ url_bases = ["https?://batch.(.+).amazonaws.com"]
"https?://batch.(.+).amazonaws.com",
]
url_paths = { url_paths = {
'{0}/v1/createcomputeenvironment$': BatchResponse.dispatch, "{0}/v1/createcomputeenvironment$": BatchResponse.dispatch,
'{0}/v1/describecomputeenvironments$': BatchResponse.dispatch, "{0}/v1/describecomputeenvironments$": BatchResponse.dispatch,
'{0}/v1/deletecomputeenvironment': BatchResponse.dispatch, "{0}/v1/deletecomputeenvironment": BatchResponse.dispatch,
'{0}/v1/updatecomputeenvironment': BatchResponse.dispatch, "{0}/v1/updatecomputeenvironment": BatchResponse.dispatch,
'{0}/v1/createjobqueue': BatchResponse.dispatch, "{0}/v1/createjobqueue": BatchResponse.dispatch,
'{0}/v1/describejobqueues': BatchResponse.dispatch, "{0}/v1/describejobqueues": BatchResponse.dispatch,
'{0}/v1/updatejobqueue': BatchResponse.dispatch, "{0}/v1/updatejobqueue": BatchResponse.dispatch,
'{0}/v1/deletejobqueue': BatchResponse.dispatch, "{0}/v1/deletejobqueue": BatchResponse.dispatch,
'{0}/v1/registerjobdefinition': BatchResponse.dispatch, "{0}/v1/registerjobdefinition": BatchResponse.dispatch,
'{0}/v1/deregisterjobdefinition': BatchResponse.dispatch, "{0}/v1/deregisterjobdefinition": BatchResponse.dispatch,
'{0}/v1/describejobdefinitions': BatchResponse.dispatch, "{0}/v1/describejobdefinitions": BatchResponse.dispatch,
'{0}/v1/submitjob': BatchResponse.dispatch, "{0}/v1/submitjob": BatchResponse.dispatch,
'{0}/v1/describejobs': BatchResponse.dispatch, "{0}/v1/describejobs": BatchResponse.dispatch,
'{0}/v1/listjobs': BatchResponse.dispatch, "{0}/v1/listjobs": BatchResponse.dispatch,
'{0}/v1/terminatejob': BatchResponse.dispatch, "{0}/v1/terminatejob": BatchResponse.dispatch,
'{0}/v1/canceljob': BatchResponse.dispatch, "{0}/v1/canceljob": BatchResponse.dispatch,
} }

View File

@ -2,7 +2,9 @@ from __future__ import unicode_literals
def make_arn_for_compute_env(account_id, name, region_name): def make_arn_for_compute_env(account_id, name, region_name):
return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(region_name, account_id, name) return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(
region_name, account_id, name
)
def make_arn_for_job_queue(account_id, name, region_name): def make_arn_for_job_queue(account_id, name, region_name):
@ -10,7 +12,9 @@ def make_arn_for_job_queue(account_id, name, region_name):
def make_arn_for_task_def(account_id, name, revision, region_name): def make_arn_for_task_def(account_id, name, revision, region_name):
return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(region_name, account_id, name, revision) return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(
region_name, account_id, name, revision
)
def lowercase_first_key(some_dict): def lowercase_first_key(some_dict):

View File

@ -2,7 +2,6 @@ from __future__ import unicode_literals
from .models import cloudformation_backends from .models import cloudformation_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
cloudformation_backend = cloudformation_backends['us-east-1'] cloudformation_backend = cloudformation_backends["us-east-1"]
mock_cloudformation = base_decorator(cloudformation_backends) mock_cloudformation = base_decorator(cloudformation_backends)
mock_cloudformation_deprecated = deprecated_base_decorator( mock_cloudformation_deprecated = deprecated_base_decorator(cloudformation_backends)
cloudformation_backends)

View File

@ -4,26 +4,23 @@ from jinja2 import Template
class UnformattedGetAttTemplateException(Exception): class UnformattedGetAttTemplateException(Exception):
description = 'Template error: resource {0} does not support attribute type {1} in Fn::GetAtt' description = (
"Template error: resource {0} does not support attribute type {1} in Fn::GetAtt"
)
status_code = 400 status_code = 400
class ValidationError(BadRequest): class ValidationError(BadRequest):
def __init__(self, name_or_id, message=None): def __init__(self, name_or_id, message=None):
if message is None: if message is None:
message = "Stack with id {0} does not exist".format(name_or_id) message = "Stack with id {0} does not exist".format(name_or_id)
template = Template(ERROR_RESPONSE) template = Template(ERROR_RESPONSE)
super(ValidationError, self).__init__() super(ValidationError, self).__init__()
self.description = template.render( self.description = template.render(code="ValidationError", message=message)
code="ValidationError",
message=message,
)
class MissingParameterError(BadRequest): class MissingParameterError(BadRequest):
def __init__(self, parameter_name): def __init__(self, parameter_name):
template = Template(ERROR_RESPONSE) template = Template(ERROR_RESPONSE)
super(MissingParameterError, self).__init__() super(MissingParameterError, self).__init__()
@ -40,8 +37,8 @@ class ExportNotFound(BadRequest):
template = Template(ERROR_RESPONSE) template = Template(ERROR_RESPONSE)
super(ExportNotFound, self).__init__() super(ExportNotFound, self).__init__()
self.description = template.render( self.description = template.render(
code='ExportNotFound', code="ExportNotFound",
message="No export named {0} found.".format(export_name) message="No export named {0} found.".format(export_name),
) )

View File

@ -21,11 +21,19 @@ from .exceptions import ValidationError
class FakeStackSet(BaseModel): class FakeStackSet(BaseModel):
def __init__(
def __init__(self, stackset_id, name, template, region='us-east-1', self,
status='ACTIVE', description=None, parameters=None, tags=None, stackset_id,
admin_role='AWSCloudFormationStackSetAdministrationRole', name,
execution_role='AWSCloudFormationStackSetExecutionRole'): template,
region="us-east-1",
status="ACTIVE",
description=None,
parameters=None,
tags=None,
admin_role="AWSCloudFormationStackSetAdministrationRole",
execution_role="AWSCloudFormationStackSetExecutionRole",
):
self.id = stackset_id self.id = stackset_id
self.arn = generate_stackset_arn(stackset_id, region) self.arn = generate_stackset_arn(stackset_id, region)
self.name = name self.name = name
@ -42,12 +50,14 @@ class FakeStackSet(BaseModel):
def _create_operation(self, operation_id, action, status, accounts=[], regions=[]): def _create_operation(self, operation_id, action, status, accounts=[], regions=[]):
operation = { operation = {
'OperationId': str(operation_id), "OperationId": str(operation_id),
'Action': action, "Action": action,
'Status': status, "Status": status,
'CreationTimestamp': datetime.now(), "CreationTimestamp": datetime.now(),
'EndTimestamp': datetime.now() + timedelta(minutes=2), "EndTimestamp": datetime.now() + timedelta(minutes=2),
'Instances': [{account: region} for account in accounts for region in regions], "Instances": [
{account: region} for account in accounts for region in regions
],
} }
self.operations += [operation] self.operations += [operation]
@ -55,20 +65,30 @@ class FakeStackSet(BaseModel):
def get_operation(self, operation_id): def get_operation(self, operation_id):
for operation in self.operations: for operation in self.operations:
if operation_id == operation['OperationId']: if operation_id == operation["OperationId"]:
return operation return operation
raise ValidationError(operation_id) raise ValidationError(operation_id)
def update_operation(self, operation_id, status): def update_operation(self, operation_id, status):
operation = self.get_operation(operation_id) operation = self.get_operation(operation_id)
operation['Status'] = status operation["Status"] = status
return operation_id return operation_id
def delete(self): def delete(self):
self.status = 'DELETED' self.status = "DELETED"
def update(self, template, description, parameters, tags, admin_role, def update(
execution_role, accounts, regions, operation_id=None): self,
template,
description,
parameters,
tags,
admin_role,
execution_role,
accounts,
regions,
operation_id=None,
):
if not operation_id: if not operation_id:
operation_id = uuid.uuid4() operation_id = uuid.uuid4()
@ -82,9 +102,13 @@ class FakeStackSet(BaseModel):
if accounts and regions: if accounts and regions:
self.update_instances(accounts, regions, self.parameters) self.update_instances(accounts, regions, self.parameters)
operation = self._create_operation(operation_id=operation_id, operation = self._create_operation(
action='UPDATE', status='SUCCEEDED', accounts=accounts, operation_id=operation_id,
regions=regions) action="UPDATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation return operation
def create_stack_instances(self, accounts, regions, parameters, operation_id=None): def create_stack_instances(self, accounts, regions, parameters, operation_id=None):
@ -94,8 +118,13 @@ class FakeStackSet(BaseModel):
parameters = self.parameters parameters = self.parameters
self.instances.create_instances(accounts, regions, parameters, operation_id) self.instances.create_instances(accounts, regions, parameters, operation_id)
self._create_operation(operation_id=operation_id, action='CREATE', self._create_operation(
status='SUCCEEDED', accounts=accounts, regions=regions) operation_id=operation_id,
action="CREATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
def delete_stack_instances(self, accounts, regions, operation_id=None): def delete_stack_instances(self, accounts, regions, operation_id=None):
if not operation_id: if not operation_id:
@ -103,8 +132,13 @@ class FakeStackSet(BaseModel):
self.instances.delete(accounts, regions) self.instances.delete(accounts, regions)
operation = self._create_operation(operation_id=operation_id, action='DELETE', operation = self._create_operation(
status='SUCCEEDED', accounts=accounts, regions=regions) operation_id=operation_id,
action="DELETE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation return operation
def update_instances(self, accounts, regions, parameters, operation_id=None): def update_instances(self, accounts, regions, parameters, operation_id=None):
@ -112,9 +146,13 @@ class FakeStackSet(BaseModel):
operation_id = uuid.uuid4() operation_id = uuid.uuid4()
self.instances.update(accounts, regions, parameters) self.instances.update(accounts, regions, parameters)
operation = self._create_operation(operation_id=operation_id, operation = self._create_operation(
action='UPDATE', status='SUCCEEDED', accounts=accounts, operation_id=operation_id,
regions=regions) action="UPDATE",
status="SUCCEEDED",
accounts=accounts,
regions=regions,
)
return operation return operation
@ -131,12 +169,12 @@ class FakeStackInstances(BaseModel):
for region in regions: for region in regions:
for account in accounts: for account in accounts:
instance = { instance = {
'StackId': generate_stack_id(self.stack_name, region, account), "StackId": generate_stack_id(self.stack_name, region, account),
'StackSetId': self.stackset_id, "StackSetId": self.stackset_id,
'Region': region, "Region": region,
'Account': account, "Account": account,
'Status': "CURRENT", "Status": "CURRENT",
'ParameterOverrides': parameters if parameters else [], "ParameterOverrides": parameters if parameters else [],
} }
new_instances.append(instance) new_instances.append(instance)
self.stack_instances += new_instances self.stack_instances += new_instances
@ -147,24 +185,35 @@ class FakeStackInstances(BaseModel):
for region in regions: for region in regions:
instance = self.get_instance(account, region) instance = self.get_instance(account, region)
if parameters: if parameters:
instance['ParameterOverrides'] = parameters instance["ParameterOverrides"] = parameters
else: else:
instance['ParameterOverrides'] = [] instance["ParameterOverrides"] = []
def delete(self, accounts, regions): def delete(self, accounts, regions):
for i, instance in enumerate(self.stack_instances): for i, instance in enumerate(self.stack_instances):
if instance['Region'] in regions and instance['Account'] in accounts: if instance["Region"] in regions and instance["Account"] in accounts:
self.stack_instances.pop(i) self.stack_instances.pop(i)
def get_instance(self, account, region): def get_instance(self, account, region):
for i, instance in enumerate(self.stack_instances): for i, instance in enumerate(self.stack_instances):
if instance['Region'] == region and instance['Account'] == account: if instance["Region"] == region and instance["Account"] == account:
return self.stack_instances[i] return self.stack_instances[i]
class FakeStack(BaseModel): class FakeStack(BaseModel):
def __init__(
def __init__(self, stack_id, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None, create_change_set=False): self,
stack_id,
name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
cross_stack_resources=None,
create_change_set=False,
):
self.stack_id = stack_id self.stack_id = stack_id
self.name = name self.name = name
self.template = template self.template = template
@ -176,22 +225,31 @@ class FakeStack(BaseModel):
self.tags = tags if tags else {} self.tags = tags if tags else {}
self.events = [] self.events = []
if create_change_set: if create_change_set:
self._add_stack_event("REVIEW_IN_PROGRESS", self._add_stack_event(
resource_status_reason="User Initiated") "REVIEW_IN_PROGRESS", resource_status_reason="User Initiated"
)
else: else:
self._add_stack_event("CREATE_IN_PROGRESS", self._add_stack_event(
resource_status_reason="User Initiated") "CREATE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.description = self.template_dict.get('Description') self.description = self.template_dict.get("Description")
self.cross_stack_resources = cross_stack_resources or {} self.cross_stack_resources = cross_stack_resources or {}
self.resource_map = self._create_resource_map() self.resource_map = self._create_resource_map()
self.output_map = self._create_output_map() self.output_map = self._create_output_map()
self._add_stack_event("CREATE_COMPLETE") self._add_stack_event("CREATE_COMPLETE")
self.status = 'CREATE_COMPLETE' self.status = "CREATE_COMPLETE"
def _create_resource_map(self): def _create_resource_map(self):
resource_map = ResourceMap( resource_map = ResourceMap(
self.stack_id, self.name, self.parameters, self.tags, self.region_name, self.template_dict, self.cross_stack_resources) self.stack_id,
self.name,
self.parameters,
self.tags,
self.region_name,
self.template_dict,
self.cross_stack_resources,
)
resource_map.create() resource_map.create()
return resource_map return resource_map
@ -200,8 +258,11 @@ class FakeStack(BaseModel):
output_map.create() output_map.create()
return output_map return output_map
def _add_stack_event(self, resource_status, resource_status_reason=None, resource_properties=None): def _add_stack_event(
self.events.append(FakeEvent( self, resource_status, resource_status_reason=None, resource_properties=None
):
self.events.append(
FakeEvent(
stack_id=self.stack_id, stack_id=self.stack_id,
stack_name=self.name, stack_name=self.name,
logical_resource_id=self.name, logical_resource_id=self.name,
@ -210,12 +271,20 @@ class FakeStack(BaseModel):
resource_status=resource_status, resource_status=resource_status,
resource_status_reason=resource_status_reason, resource_status_reason=resource_status_reason,
resource_properties=resource_properties, resource_properties=resource_properties,
)) )
)
def _add_resource_event(self, logical_resource_id, resource_status, resource_status_reason=None, resource_properties=None): def _add_resource_event(
self,
logical_resource_id,
resource_status,
resource_status_reason=None,
resource_properties=None,
):
# not used yet... feel free to help yourself # not used yet... feel free to help yourself
resource = self.resource_map[logical_resource_id] resource = self.resource_map[logical_resource_id]
self.events.append(FakeEvent( self.events.append(
FakeEvent(
stack_id=self.stack_id, stack_id=self.stack_id,
stack_name=self.name, stack_name=self.name,
logical_resource_id=logical_resource_id, logical_resource_id=logical_resource_id,
@ -224,10 +293,11 @@ class FakeStack(BaseModel):
resource_status=resource_status, resource_status=resource_status,
resource_status_reason=resource_status_reason, resource_status_reason=resource_status_reason,
resource_properties=resource_properties, resource_properties=resource_properties,
)) )
)
def _parse_template(self): def _parse_template(self):
yaml.add_multi_constructor('', yaml_tag_constructor) yaml.add_multi_constructor("", yaml_tag_constructor)
try: try:
self.template_dict = yaml.load(self.template, Loader=yaml.Loader) self.template_dict = yaml.load(self.template, Loader=yaml.Loader)
except yaml.parser.ParserError: except yaml.parser.ParserError:
@ -250,7 +320,9 @@ class FakeStack(BaseModel):
return self.output_map.exports return self.output_map.exports
def update(self, template, role_arn=None, parameters=None, tags=None): def update(self, template, role_arn=None, parameters=None, tags=None):
self._add_stack_event("UPDATE_IN_PROGRESS", resource_status_reason="User Initiated") self._add_stack_event(
"UPDATE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.template = template self.template = template
self._parse_template() self._parse_template()
self.resource_map.update(self.template_dict, parameters) self.resource_map.update(self.template_dict, parameters)
@ -264,15 +336,15 @@ class FakeStack(BaseModel):
# TODO: update tags in the resource map # TODO: update tags in the resource map
def delete(self): def delete(self):
self._add_stack_event("DELETE_IN_PROGRESS", self._add_stack_event(
resource_status_reason="User Initiated") "DELETE_IN_PROGRESS", resource_status_reason="User Initiated"
)
self.resource_map.delete() self.resource_map.delete()
self._add_stack_event("DELETE_COMPLETE") self._add_stack_event("DELETE_COMPLETE")
self.status = "DELETE_COMPLETE" self.status = "DELETE_COMPLETE"
class FakeChange(BaseModel): class FakeChange(BaseModel):
def __init__(self, action, logical_resource_id, resource_type): def __init__(self, action, logical_resource_id, resource_type):
self.action = action self.action = action
self.logical_resource_id = logical_resource_id self.logical_resource_id = logical_resource_id
@ -280,8 +352,21 @@ class FakeChange(BaseModel):
class FakeChangeSet(FakeStack): class FakeChangeSet(FakeStack):
def __init__(
def __init__(self, stack_id, stack_name, stack_template, change_set_id, change_set_name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None): self,
stack_id,
stack_name,
stack_template,
change_set_id,
change_set_name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
cross_stack_resources=None,
):
super(FakeChangeSet, self).__init__( super(FakeChangeSet, self).__init__(
stack_id, stack_id,
stack_name, stack_name,
@ -306,17 +391,28 @@ class FakeChangeSet(FakeStack):
resources_by_action = self.resource_map.diff(self.template_dict, parameters) resources_by_action = self.resource_map.diff(self.template_dict, parameters)
for action, resources in resources_by_action.items(): for action, resources in resources_by_action.items():
for resource_name, resource in resources.items(): for resource_name, resource in resources.items():
changes.append(FakeChange( changes.append(
FakeChange(
action=action, action=action,
logical_resource_id=resource_name, logical_resource_id=resource_name,
resource_type=resource['ResourceType'], resource_type=resource["ResourceType"],
)) )
)
return changes return changes
class FakeEvent(BaseModel): class FakeEvent(BaseModel):
def __init__(
def __init__(self, stack_id, stack_name, logical_resource_id, physical_resource_id, resource_type, resource_status, resource_status_reason=None, resource_properties=None): self,
stack_id,
stack_name,
logical_resource_id,
physical_resource_id,
resource_type,
resource_status,
resource_status_reason=None,
resource_properties=None,
):
self.stack_id = stack_id self.stack_id = stack_id
self.stack_name = stack_name self.stack_name = stack_name
self.logical_resource_id = logical_resource_id self.logical_resource_id = logical_resource_id
@ -330,7 +426,6 @@ class FakeEvent(BaseModel):
class CloudFormationBackend(BaseBackend): class CloudFormationBackend(BaseBackend):
def __init__(self): def __init__(self):
self.stacks = OrderedDict() self.stacks = OrderedDict()
self.stacksets = OrderedDict() self.stacksets = OrderedDict()
@ -338,7 +433,17 @@ class CloudFormationBackend(BaseBackend):
self.exports = OrderedDict() self.exports = OrderedDict()
self.change_sets = OrderedDict() self.change_sets = OrderedDict()
def create_stack_set(self, name, template, parameters, tags=None, description=None, region='us-east-1', admin_role=None, execution_role=None): def create_stack_set(
self,
name,
template,
parameters,
tags=None,
description=None,
region="us-east-1",
admin_role=None,
execution_role=None,
):
stackset_id = generate_stackset_id(name) stackset_id = generate_stackset_id(name)
new_stackset = FakeStackSet( new_stackset = FakeStackSet(
stackset_id=stackset_id, stackset_id=stackset_id,
@ -366,7 +471,9 @@ class CloudFormationBackend(BaseBackend):
if self.stacksets[stackset].name == name: if self.stacksets[stackset].name == name:
self.stacksets[stackset].delete() self.stacksets[stackset].delete()
def create_stack_instances(self, stackset_name, accounts, regions, parameters, operation_id=None): def create_stack_instances(
self, stackset_name, accounts, regions, parameters, operation_id=None
):
stackset = self.get_stack_set(stackset_name) stackset = self.get_stack_set(stackset_name)
stackset.create_stack_instances( stackset.create_stack_instances(
@ -377,9 +484,19 @@ class CloudFormationBackend(BaseBackend):
) )
return stackset return stackset
def update_stack_set(self, stackset_name, template=None, description=None, def update_stack_set(
parameters=None, tags=None, admin_role=None, execution_role=None, self,
accounts=None, regions=None, operation_id=None): stackset_name,
template=None,
description=None,
parameters=None,
tags=None,
admin_role=None,
execution_role=None,
accounts=None,
regions=None,
operation_id=None,
):
stackset = self.get_stack_set(stackset_name) stackset = self.get_stack_set(stackset_name)
update = stackset.update( update = stackset.update(
template=template, template=template,
@ -390,16 +507,28 @@ class CloudFormationBackend(BaseBackend):
execution_role=execution_role, execution_role=execution_role,
accounts=accounts, accounts=accounts,
regions=regions, regions=regions,
operation_id=operation_id operation_id=operation_id,
) )
return update return update
def delete_stack_instances(self, stackset_name, accounts, regions, operation_id=None): def delete_stack_instances(
self, stackset_name, accounts, regions, operation_id=None
):
stackset = self.get_stack_set(stackset_name) stackset = self.get_stack_set(stackset_name)
stackset.delete_stack_instances(accounts, regions, operation_id) stackset.delete_stack_instances(accounts, regions, operation_id)
return stackset return stackset
def create_stack(self, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, create_change_set=False): def create_stack(
self,
name,
template,
parameters,
region_name,
notification_arns=None,
tags=None,
role_arn=None,
create_change_set=False,
):
stack_id = generate_stack_id(name) stack_id = generate_stack_id(name)
new_stack = FakeStack( new_stack = FakeStack(
stack_id=stack_id, stack_id=stack_id,
@ -419,10 +548,21 @@ class CloudFormationBackend(BaseBackend):
self.exports[export.name] = export self.exports[export.name] = export
return new_stack return new_stack
def create_change_set(self, stack_name, change_set_name, template, parameters, region_name, change_set_type, notification_arns=None, tags=None, role_arn=None): def create_change_set(
self,
stack_name,
change_set_name,
template,
parameters,
region_name,
change_set_type,
notification_arns=None,
tags=None,
role_arn=None,
):
stack_id = None stack_id = None
stack_template = None stack_template = None
if change_set_type == 'UPDATE': if change_set_type == "UPDATE":
stacks = self.stacks.values() stacks = self.stacks.values()
stack = None stack = None
for s in stacks: for s in stacks:
@ -449,7 +589,7 @@ class CloudFormationBackend(BaseBackend):
notification_arns=notification_arns, notification_arns=notification_arns,
tags=tags, tags=tags,
role_arn=role_arn, role_arn=role_arn,
cross_stack_resources=self.exports cross_stack_resources=self.exports,
) )
self.change_sets[change_set_id] = new_change_set self.change_sets[change_set_id] = new_change_set
self.stacks[stack_id] = new_change_set self.stacks[stack_id] = new_change_set
@ -488,11 +628,11 @@ class CloudFormationBackend(BaseBackend):
stack = self.change_sets[cs] stack = self.change_sets[cs]
if stack is None: if stack is None:
raise ValidationError(stack_name) raise ValidationError(stack_name)
if stack.events[-1].resource_status == 'REVIEW_IN_PROGRESS': if stack.events[-1].resource_status == "REVIEW_IN_PROGRESS":
stack._add_stack_event('CREATE_COMPLETE') stack._add_stack_event("CREATE_COMPLETE")
else: else:
stack._add_stack_event('UPDATE_IN_PROGRESS') stack._add_stack_event("UPDATE_IN_PROGRESS")
stack._add_stack_event('UPDATE_COMPLETE') stack._add_stack_event("UPDATE_COMPLETE")
return True return True
def describe_stacks(self, name_or_stack_id): def describe_stacks(self, name_or_stack_id):
@ -514,9 +654,7 @@ class CloudFormationBackend(BaseBackend):
return self.change_sets.values() return self.change_sets.values()
def list_stacks(self): def list_stacks(self):
return [ return [v for v in self.stacks.values()] + [
v for v in self.stacks.values()
] + [
v for v in self.deleted_stacks.values() v for v in self.deleted_stacks.values()
] ]
@ -558,10 +696,10 @@ class CloudFormationBackend(BaseBackend):
all_exports = list(self.exports.values()) all_exports = list(self.exports.values())
if token is None: if token is None:
exports = all_exports[0:100] exports = all_exports[0:100]
next_token = '100' if len(all_exports) > 100 else None next_token = "100" if len(all_exports) > 100 else None
else: else:
token = int(token) token = int(token)
exports = all_exports[token:token + 100] exports = all_exports[token : token + 100]
next_token = str(token + 100) if len(all_exports) > token + 100 else None next_token = str(token + 100) if len(all_exports) > token + 100 else None
return exports, next_token return exports, next_token
@ -572,7 +710,10 @@ class CloudFormationBackend(BaseBackend):
new_stack_export_names = [x.name for x in stack.exports] new_stack_export_names = [x.name for x in stack.exports]
export_names = self.exports.keys() export_names = self.exports.keys()
if not set(export_names).isdisjoint(new_stack_export_names): if not set(export_names).isdisjoint(new_stack_export_names):
raise ValidationError(stack.stack_id, message='Export names must be unique across a given region') raise ValidationError(
stack.stack_id,
message="Export names must be unique across a given region",
)
cloudformation_backends = {} cloudformation_backends = {}

View File

@ -28,7 +28,12 @@ from moto.s3 import models as s3_models
from moto.sns import models as sns_models from moto.sns import models as sns_models
from moto.sqs import models as sqs_models from moto.sqs import models as sqs_models
from .utils import random_suffix from .utils import random_suffix
from .exceptions import ExportNotFound, MissingParameterError, UnformattedGetAttTemplateException, ValidationError from .exceptions import (
ExportNotFound,
MissingParameterError,
UnformattedGetAttTemplateException,
ValidationError,
)
from boto.cloudformation.stack import Output from boto.cloudformation.stack import Output
MODEL_MAP = { MODEL_MAP = {
@ -100,7 +105,7 @@ NAME_TYPE_MAP = {
"AWS::RDS::DBInstance": "DBInstanceIdentifier", "AWS::RDS::DBInstance": "DBInstanceIdentifier",
"AWS::S3::Bucket": "BucketName", "AWS::S3::Bucket": "BucketName",
"AWS::SNS::Topic": "TopicName", "AWS::SNS::Topic": "TopicName",
"AWS::SQS::Queue": "QueueName" "AWS::SQS::Queue": "QueueName",
} }
# Just ignore these models types for now # Just ignore these models types for now
@ -109,13 +114,12 @@ NULL_MODELS = [
"AWS::CloudFormation::WaitConditionHandle", "AWS::CloudFormation::WaitConditionHandle",
] ]
DEFAULT_REGION = 'us-east-1' DEFAULT_REGION = "us-east-1"
logger = logging.getLogger("moto") logger = logging.getLogger("moto")
class LazyDict(dict): class LazyDict(dict):
def __getitem__(self, key): def __getitem__(self, key):
val = dict.__getitem__(self, key) val = dict.__getitem__(self, key)
if callable(val): if callable(val):
@ -132,10 +136,10 @@ def clean_json(resource_json, resources_map):
Eventually, this is where we would add things like function parsing (fn::) Eventually, this is where we would add things like function parsing (fn::)
""" """
if isinstance(resource_json, dict): if isinstance(resource_json, dict):
if 'Ref' in resource_json: if "Ref" in resource_json:
# Parse resource reference # Parse resource reference
resource = resources_map[resource_json['Ref']] resource = resources_map[resource_json["Ref"]]
if hasattr(resource, 'physical_resource_id'): if hasattr(resource, "physical_resource_id"):
return resource.physical_resource_id return resource.physical_resource_id
else: else:
return resource return resource
@ -148,74 +152,92 @@ def clean_json(resource_json, resources_map):
result = result[clean_json(path, resources_map)] result = result[clean_json(path, resources_map)]
return result return result
if 'Fn::GetAtt' in resource_json: if "Fn::GetAtt" in resource_json:
resource = resources_map.get(resource_json['Fn::GetAtt'][0]) resource = resources_map.get(resource_json["Fn::GetAtt"][0])
if resource is None: if resource is None:
return resource_json return resource_json
try: try:
return resource.get_cfn_attribute(resource_json['Fn::GetAtt'][1]) return resource.get_cfn_attribute(resource_json["Fn::GetAtt"][1])
except NotImplementedError as n: except NotImplementedError as n:
logger.warning(str(n).format( logger.warning(str(n).format(resource_json["Fn::GetAtt"][0]))
resource_json['Fn::GetAtt'][0]))
except UnformattedGetAttTemplateException: except UnformattedGetAttTemplateException:
raise ValidationError( raise ValidationError(
'Bad Request', "Bad Request",
UnformattedGetAttTemplateException.description.format( UnformattedGetAttTemplateException.description.format(
resource_json['Fn::GetAtt'][0], resource_json['Fn::GetAtt'][1])) resource_json["Fn::GetAtt"][0], resource_json["Fn::GetAtt"][1]
),
)
if 'Fn::If' in resource_json: if "Fn::If" in resource_json:
condition_name, true_value, false_value = resource_json['Fn::If'] condition_name, true_value, false_value = resource_json["Fn::If"]
if resources_map.lazy_condition_map[condition_name]: if resources_map.lazy_condition_map[condition_name]:
return clean_json(true_value, resources_map) return clean_json(true_value, resources_map)
else: else:
return clean_json(false_value, resources_map) return clean_json(false_value, resources_map)
if 'Fn::Join' in resource_json: if "Fn::Join" in resource_json:
join_list = clean_json(resource_json['Fn::Join'][1], resources_map) join_list = clean_json(resource_json["Fn::Join"][1], resources_map)
return resource_json['Fn::Join'][0].join([str(x) for x in join_list]) return resource_json["Fn::Join"][0].join([str(x) for x in join_list])
if 'Fn::Split' in resource_json: if "Fn::Split" in resource_json:
to_split = clean_json(resource_json['Fn::Split'][1], resources_map) to_split = clean_json(resource_json["Fn::Split"][1], resources_map)
return to_split.split(resource_json['Fn::Split'][0]) return to_split.split(resource_json["Fn::Split"][0])
if 'Fn::Select' in resource_json: if "Fn::Select" in resource_json:
select_index = int(resource_json['Fn::Select'][0]) select_index = int(resource_json["Fn::Select"][0])
select_list = clean_json(resource_json['Fn::Select'][1], resources_map) select_list = clean_json(resource_json["Fn::Select"][1], resources_map)
return select_list[select_index] return select_list[select_index]
if 'Fn::Sub' in resource_json: if "Fn::Sub" in resource_json:
if isinstance(resource_json['Fn::Sub'], list): if isinstance(resource_json["Fn::Sub"], list):
warnings.warn( warnings.warn(
"Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation") "Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation"
)
else: else:
fn_sub_value = clean_json(resource_json['Fn::Sub'], resources_map) fn_sub_value = clean_json(resource_json["Fn::Sub"], resources_map)
to_sub = re.findall('(?=\${)[^!^"]*?}', fn_sub_value) to_sub = re.findall('(?=\${)[^!^"]*?}', fn_sub_value)
literals = re.findall('(?=\${!)[^"]*?}', fn_sub_value) literals = re.findall('(?=\${!)[^"]*?}', fn_sub_value)
for sub in to_sub: for sub in to_sub:
if '.' in sub: if "." in sub:
cleaned_ref = clean_json({'Fn::GetAtt': re.findall('(?<=\${)[^"]*?(?=})', sub)[0].split('.')}, resources_map) cleaned_ref = clean_json(
{
"Fn::GetAtt": re.findall('(?<=\${)[^"]*?(?=})', sub)[
0
].split(".")
},
resources_map,
)
else: else:
cleaned_ref = clean_json({'Ref': re.findall('(?<=\${)[^"]*?(?=})', sub)[0]}, resources_map) cleaned_ref = clean_json(
{"Ref": re.findall('(?<=\${)[^"]*?(?=})', sub)[0]},
resources_map,
)
fn_sub_value = fn_sub_value.replace(sub, cleaned_ref) fn_sub_value = fn_sub_value.replace(sub, cleaned_ref)
for literal in literals: for literal in literals:
fn_sub_value = fn_sub_value.replace(literal, literal.replace('!', '')) fn_sub_value = fn_sub_value.replace(
literal, literal.replace("!", "")
)
return fn_sub_value return fn_sub_value
pass pass
if 'Fn::ImportValue' in resource_json: if "Fn::ImportValue" in resource_json:
cleaned_val = clean_json(resource_json['Fn::ImportValue'], resources_map) cleaned_val = clean_json(resource_json["Fn::ImportValue"], resources_map)
values = [x.value for x in resources_map.cross_stack_resources.values() if x.name == cleaned_val] values = [
x.value
for x in resources_map.cross_stack_resources.values()
if x.name == cleaned_val
]
if any(values): if any(values):
return values[0] return values[0]
else: else:
raise ExportNotFound(cleaned_val) raise ExportNotFound(cleaned_val)
if 'Fn::GetAZs' in resource_json: if "Fn::GetAZs" in resource_json:
region = resource_json.get('Fn::GetAZs') or DEFAULT_REGION region = resource_json.get("Fn::GetAZs") or DEFAULT_REGION
result = [] result = []
# TODO: make this configurable, to reflect the real AWS AZs # TODO: make this configurable, to reflect the real AWS AZs
for az in ('a', 'b', 'c', 'd'): for az in ("a", "b", "c", "d"):
result.append('%s%s' % (region, az)) result.append("%s%s" % (region, az))
return result return result
cleaned_json = {} cleaned_json = {}
@ -246,58 +268,69 @@ def resource_name_property_from_type(resource_type):
def generate_resource_name(resource_type, stack_name, logical_id): def generate_resource_name(resource_type, stack_name, logical_id):
if resource_type in ["AWS::ElasticLoadBalancingV2::TargetGroup", if resource_type in [
"AWS::ElasticLoadBalancingV2::LoadBalancer"]: "AWS::ElasticLoadBalancingV2::TargetGroup",
"AWS::ElasticLoadBalancingV2::LoadBalancer",
]:
# Target group names need to be less than 32 characters, so when cloudformation creates a name for you # Target group names need to be less than 32 characters, so when cloudformation creates a name for you
# it makes sure to stay under that limit # it makes sure to stay under that limit
name_prefix = '{0}-{1}'.format(stack_name, logical_id) name_prefix = "{0}-{1}".format(stack_name, logical_id)
my_random_suffix = random_suffix() my_random_suffix = random_suffix()
truncated_name_prefix = name_prefix[0:32 - (len(my_random_suffix) + 1)] truncated_name_prefix = name_prefix[0 : 32 - (len(my_random_suffix) + 1)]
# if the truncated name ends in a dash, we'll end up with a double dash in the final name, which is # if the truncated name ends in a dash, we'll end up with a double dash in the final name, which is
# not allowed # not allowed
if truncated_name_prefix.endswith('-'): if truncated_name_prefix.endswith("-"):
truncated_name_prefix = truncated_name_prefix[:-1] truncated_name_prefix = truncated_name_prefix[:-1]
return '{0}-{1}'.format(truncated_name_prefix, my_random_suffix) return "{0}-{1}".format(truncated_name_prefix, my_random_suffix)
else: else:
return '{0}-{1}-{2}'.format(stack_name, logical_id, random_suffix()) return "{0}-{1}-{2}".format(stack_name, logical_id, random_suffix())
def parse_resource(logical_id, resource_json, resources_map): def parse_resource(logical_id, resource_json, resources_map):
resource_type = resource_json['Type'] resource_type = resource_json["Type"]
resource_class = resource_class_from_type(resource_type) resource_class = resource_class_from_type(resource_type)
if not resource_class: if not resource_class:
warnings.warn( warnings.warn(
"Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format(resource_type)) "Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format(
resource_type
)
)
return None return None
resource_json = clean_json(resource_json, resources_map) resource_json = clean_json(resource_json, resources_map)
resource_name_property = resource_name_property_from_type(resource_type) resource_name_property = resource_name_property_from_type(resource_type)
if resource_name_property: if resource_name_property:
if 'Properties' not in resource_json: if "Properties" not in resource_json:
resource_json['Properties'] = dict() resource_json["Properties"] = dict()
if resource_name_property not in resource_json['Properties']: if resource_name_property not in resource_json["Properties"]:
resource_json['Properties'][resource_name_property] = generate_resource_name( resource_json["Properties"][
resource_type, resources_map.get('AWS::StackName'), logical_id) resource_name_property
resource_name = resource_json['Properties'][resource_name_property] ] = generate_resource_name(
resource_type, resources_map.get("AWS::StackName"), logical_id
)
resource_name = resource_json["Properties"][resource_name_property]
else: else:
resource_name = generate_resource_name(resource_type, resources_map.get('AWS::StackName'), logical_id) resource_name = generate_resource_name(
resource_type, resources_map.get("AWS::StackName"), logical_id
)
return resource_class, resource_json, resource_name return resource_class, resource_json, resource_name
def parse_and_create_resource(logical_id, resource_json, resources_map, region_name): def parse_and_create_resource(logical_id, resource_json, resources_map, region_name):
condition = resource_json.get('Condition') condition = resource_json.get("Condition")
if condition and not resources_map.lazy_condition_map[condition]: if condition and not resources_map.lazy_condition_map[condition]:
# If this has a False condition, don't create the resource # If this has a False condition, don't create the resource
return None return None
resource_type = resource_json['Type'] resource_type = resource_json["Type"]
resource_tuple = parse_resource(logical_id, resource_json, resources_map) resource_tuple = parse_resource(logical_id, resource_json, resources_map)
if not resource_tuple: if not resource_tuple:
return None return None
resource_class, resource_json, resource_name = resource_tuple resource_class, resource_json, resource_name = resource_tuple
resource = resource_class.create_from_cloudformation_json( resource = resource_class.create_from_cloudformation_json(
resource_name, resource_json, region_name) resource_name, resource_json, region_name
)
resource.type = resource_type resource.type = resource_type
resource.logical_resource_id = logical_id resource.logical_resource_id = logical_id
return resource return resource
@ -305,24 +338,27 @@ def parse_and_create_resource(logical_id, resource_json, resources_map, region_n
def parse_and_update_resource(logical_id, resource_json, resources_map, region_name): def parse_and_update_resource(logical_id, resource_json, resources_map, region_name):
resource_class, new_resource_json, new_resource_name = parse_resource( resource_class, new_resource_json, new_resource_name = parse_resource(
logical_id, resource_json, resources_map) logical_id, resource_json, resources_map
)
original_resource = resources_map[logical_id] original_resource = resources_map[logical_id]
new_resource = resource_class.update_from_cloudformation_json( new_resource = resource_class.update_from_cloudformation_json(
original_resource=original_resource, original_resource=original_resource,
new_resource_name=new_resource_name, new_resource_name=new_resource_name,
cloudformation_json=new_resource_json, cloudformation_json=new_resource_json,
region_name=region_name region_name=region_name,
) )
new_resource.type = resource_json['Type'] new_resource.type = resource_json["Type"]
new_resource.logical_resource_id = logical_id new_resource.logical_resource_id = logical_id
return new_resource return new_resource
def parse_and_delete_resource(logical_id, resource_json, resources_map, region_name): def parse_and_delete_resource(logical_id, resource_json, resources_map, region_name):
resource_class, resource_json, resource_name = parse_resource( resource_class, resource_json, resource_name = parse_resource(
logical_id, resource_json, resources_map) logical_id, resource_json, resources_map
)
resource_class.delete_from_cloudformation_json( resource_class.delete_from_cloudformation_json(
resource_name, resource_json, region_name) resource_name, resource_json, region_name
)
def parse_condition(condition, resources_map, condition_map): def parse_condition(condition, resources_map, condition_map):
@ -334,8 +370,8 @@ def parse_condition(condition, resources_map, condition_map):
condition_values = [] condition_values = []
for value in list(condition.values())[0]: for value in list(condition.values())[0]:
# Check if we are referencing another Condition # Check if we are referencing another Condition
if 'Condition' in value: if "Condition" in value:
condition_values.append(condition_map[value['Condition']]) condition_values.append(condition_map[value["Condition"]])
else: else:
condition_values.append(clean_json(value, resources_map)) condition_values.append(clean_json(value, resources_map))
@ -344,23 +380,27 @@ def parse_condition(condition, resources_map, condition_map):
elif condition_operator == "Fn::Not": elif condition_operator == "Fn::Not":
return not parse_condition(condition_values[0], resources_map, condition_map) return not parse_condition(condition_values[0], resources_map, condition_map)
elif condition_operator == "Fn::And": elif condition_operator == "Fn::And":
return all([ return all(
[
parse_condition(condition_value, resources_map, condition_map) parse_condition(condition_value, resources_map, condition_map)
for condition_value for condition_value in condition_values
in condition_values]) ]
)
elif condition_operator == "Fn::Or": elif condition_operator == "Fn::Or":
return any([ return any(
[
parse_condition(condition_value, resources_map, condition_map) parse_condition(condition_value, resources_map, condition_map)
for condition_value for condition_value in condition_values
in condition_values]) ]
)
def parse_output(output_logical_id, output_json, resources_map): def parse_output(output_logical_id, output_json, resources_map):
output_json = clean_json(output_json, resources_map) output_json = clean_json(output_json, resources_map)
output = Output() output = Output()
output.key = output_logical_id output.key = output_logical_id
output.value = clean_json(output_json['Value'], resources_map) output.value = clean_json(output_json["Value"], resources_map)
output.description = output_json.get('Description') output.description = output_json.get("Description")
return output return output
@ -371,9 +411,18 @@ class ResourceMap(collections.Mapping):
each resources is passed this lazy map that it can grab dependencies from. each resources is passed this lazy map that it can grab dependencies from.
""" """
def __init__(self, stack_id, stack_name, parameters, tags, region_name, template, cross_stack_resources): def __init__(
self,
stack_id,
stack_name,
parameters,
tags,
region_name,
template,
cross_stack_resources,
):
self._template = template self._template = template
self._resource_json_map = template['Resources'] self._resource_json_map = template["Resources"]
self._region_name = region_name self._region_name = region_name
self.input_parameters = parameters self.input_parameters = parameters
self.tags = copy.deepcopy(tags) self.tags = copy.deepcopy(tags)
@ -401,7 +450,8 @@ class ResourceMap(collections.Mapping):
if not resource_json: if not resource_json:
raise KeyError(resource_logical_id) raise KeyError(resource_logical_id)
new_resource = parse_and_create_resource( new_resource = parse_and_create_resource(
resource_logical_id, resource_json, self, self._region_name) resource_logical_id, resource_json, self, self._region_name
)
if new_resource is not None: if new_resource is not None:
self._parsed_resources[resource_logical_id] = new_resource self._parsed_resources[resource_logical_id] = new_resource
return new_resource return new_resource
@ -417,13 +467,13 @@ class ResourceMap(collections.Mapping):
return self._resource_json_map.keys() return self._resource_json_map.keys()
def load_mapping(self): def load_mapping(self):
self._parsed_resources.update(self._template.get('Mappings', {})) self._parsed_resources.update(self._template.get("Mappings", {}))
def load_parameters(self): def load_parameters(self):
parameter_slots = self._template.get('Parameters', {}) parameter_slots = self._template.get("Parameters", {})
for parameter_name, parameter in parameter_slots.items(): for parameter_name, parameter in parameter_slots.items():
# Set the default values. # Set the default values.
self.resolved_parameters[parameter_name] = parameter.get('Default') self.resolved_parameters[parameter_name] = parameter.get("Default")
# Set any input parameters that were passed # Set any input parameters that were passed
self.no_echo_parameter_keys = [] self.no_echo_parameter_keys = []
@ -431,11 +481,11 @@ class ResourceMap(collections.Mapping):
if key in self.resolved_parameters: if key in self.resolved_parameters:
parameter_slot = parameter_slots[key] parameter_slot = parameter_slots[key]
value_type = parameter_slot.get('Type', 'String') value_type = parameter_slot.get("Type", "String")
if value_type == 'CommaDelimitedList' or value_type.startswith("List"): if value_type == "CommaDelimitedList" or value_type.startswith("List"):
value = value.split(',') value = value.split(",")
if parameter_slot.get('NoEcho'): if parameter_slot.get("NoEcho"):
self.no_echo_parameter_keys.append(key) self.no_echo_parameter_keys.append(key)
self.resolved_parameters[key] = value self.resolved_parameters[key] = value
@ -449,11 +499,15 @@ class ResourceMap(collections.Mapping):
self._parsed_resources.update(self.resolved_parameters) self._parsed_resources.update(self.resolved_parameters)
def load_conditions(self): def load_conditions(self):
conditions = self._template.get('Conditions', {}) conditions = self._template.get("Conditions", {})
self.lazy_condition_map = LazyDict() self.lazy_condition_map = LazyDict()
for condition_name, condition in conditions.items(): for condition_name, condition in conditions.items():
self.lazy_condition_map[condition_name] = functools.partial(parse_condition, self.lazy_condition_map[condition_name] = functools.partial(
condition, self._parsed_resources, self.lazy_condition_map) parse_condition,
condition,
self._parsed_resources,
self.lazy_condition_map,
)
for condition_name in self.lazy_condition_map: for condition_name in self.lazy_condition_map:
self.lazy_condition_map[condition_name] self.lazy_condition_map[condition_name]
@ -465,13 +519,18 @@ class ResourceMap(collections.Mapping):
# Since this is a lazy map, to create every object we just need to # Since this is a lazy map, to create every object we just need to
# iterate through self. # iterate through self.
self.tags.update({'aws:cloudformation:stack-name': self.get('AWS::StackName'), self.tags.update(
'aws:cloudformation:stack-id': self.get('AWS::StackId')}) {
"aws:cloudformation:stack-name": self.get("AWS::StackName"),
"aws:cloudformation:stack-id": self.get("AWS::StackId"),
}
)
for resource in self.resources: for resource in self.resources:
if isinstance(self[resource], ec2_models.TaggedEC2Resource): if isinstance(self[resource], ec2_models.TaggedEC2Resource):
self.tags['aws:cloudformation:logical-id'] = resource self.tags["aws:cloudformation:logical-id"] = resource
ec2_models.ec2_backends[self._region_name].create_tags( ec2_models.ec2_backends[self._region_name].create_tags(
[self[resource].physical_resource_id], self.tags) [self[resource].physical_resource_id], self.tags
)
def diff(self, template, parameters=None): def diff(self, template, parameters=None):
if parameters: if parameters:
@ -481,36 +540,35 @@ class ResourceMap(collections.Mapping):
self.load_conditions() self.load_conditions()
old_template = self._resource_json_map old_template = self._resource_json_map
new_template = template['Resources'] new_template = template["Resources"]
resource_names_by_action = { resource_names_by_action = {
'Add': set(new_template) - set(old_template), "Add": set(new_template) - set(old_template),
'Modify': set(name for name in new_template if name in old_template and new_template[ "Modify": set(
name] != old_template[name]), name
'Remove': set(old_template) - set(new_template) for name in new_template
if name in old_template and new_template[name] != old_template[name]
),
"Remove": set(old_template) - set(new_template),
} }
resources_by_action = { resources_by_action = {"Add": {}, "Modify": {}, "Remove": {}}
'Add': {},
'Modify': {}, for resource_name in resource_names_by_action["Add"]:
'Remove': {}, resources_by_action["Add"][resource_name] = {
"LogicalResourceId": resource_name,
"ResourceType": new_template[resource_name]["Type"],
} }
for resource_name in resource_names_by_action['Add']: for resource_name in resource_names_by_action["Modify"]:
resources_by_action['Add'][resource_name] = { resources_by_action["Modify"][resource_name] = {
'LogicalResourceId': resource_name, "LogicalResourceId": resource_name,
'ResourceType': new_template[resource_name]['Type'] "ResourceType": new_template[resource_name]["Type"],
} }
for resource_name in resource_names_by_action['Modify']: for resource_name in resource_names_by_action["Remove"]:
resources_by_action['Modify'][resource_name] = { resources_by_action["Remove"][resource_name] = {
'LogicalResourceId': resource_name, "LogicalResourceId": resource_name,
'ResourceType': new_template[resource_name]['Type'] "ResourceType": old_template[resource_name]["Type"],
}
for resource_name in resource_names_by_action['Remove']:
resources_by_action['Remove'][resource_name] = {
'LogicalResourceId': resource_name,
'ResourceType': old_template[resource_name]['Type']
} }
return resources_by_action return resources_by_action
@ -519,35 +577,38 @@ class ResourceMap(collections.Mapping):
resources_by_action = self.diff(template, parameters) resources_by_action = self.diff(template, parameters)
old_template = self._resource_json_map old_template = self._resource_json_map
new_template = template['Resources'] new_template = template["Resources"]
self._resource_json_map = new_template self._resource_json_map = new_template
for resource_name, resource in resources_by_action['Add'].items(): for resource_name, resource in resources_by_action["Add"].items():
resource_json = new_template[resource_name] resource_json = new_template[resource_name]
new_resource = parse_and_create_resource( new_resource = parse_and_create_resource(
resource_name, resource_json, self, self._region_name) resource_name, resource_json, self, self._region_name
)
self._parsed_resources[resource_name] = new_resource self._parsed_resources[resource_name] = new_resource
for resource_name, resource in resources_by_action['Remove'].items(): for resource_name, resource in resources_by_action["Remove"].items():
resource_json = old_template[resource_name] resource_json = old_template[resource_name]
parse_and_delete_resource( parse_and_delete_resource(
resource_name, resource_json, self, self._region_name) resource_name, resource_json, self, self._region_name
)
self._parsed_resources.pop(resource_name) self._parsed_resources.pop(resource_name)
tries = 1 tries = 1
while resources_by_action['Modify'] and tries < 5: while resources_by_action["Modify"] and tries < 5:
for resource_name, resource in resources_by_action['Modify'].copy().items(): for resource_name, resource in resources_by_action["Modify"].copy().items():
resource_json = new_template[resource_name] resource_json = new_template[resource_name]
try: try:
changed_resource = parse_and_update_resource( changed_resource = parse_and_update_resource(
resource_name, resource_json, self, self._region_name) resource_name, resource_json, self, self._region_name
)
except Exception as e: except Exception as e:
# skip over dependency violations, and try again in a # skip over dependency violations, and try again in a
# second pass # second pass
last_exception = e last_exception = e
else: else:
self._parsed_resources[resource_name] = changed_resource self._parsed_resources[resource_name] = changed_resource
del resources_by_action['Modify'][resource_name] del resources_by_action["Modify"][resource_name]
tries += 1 tries += 1
if tries == 5: if tries == 5:
raise last_exception raise last_exception
@ -559,7 +620,7 @@ class ResourceMap(collections.Mapping):
for resource in remaining_resources.copy(): for resource in remaining_resources.copy():
parsed_resource = self._parsed_resources.get(resource) parsed_resource = self._parsed_resources.get(resource)
try: try:
if parsed_resource and hasattr(parsed_resource, 'delete'): if parsed_resource and hasattr(parsed_resource, "delete"):
parsed_resource.delete(self._region_name) parsed_resource.delete(self._region_name)
except Exception as e: except Exception as e:
# skip over dependency violations, and try again in a # skip over dependency violations, and try again in a
@ -573,11 +634,10 @@ class ResourceMap(collections.Mapping):
class OutputMap(collections.Mapping): class OutputMap(collections.Mapping):
def __init__(self, resources, template, stack_id): def __init__(self, resources, template, stack_id):
self._template = template self._template = template
self._stack_id = stack_id self._stack_id = stack_id
self._output_json_map = template.get('Outputs') self._output_json_map = template.get("Outputs")
# Create the default resources # Create the default resources
self._resource_map = resources self._resource_map = resources
@ -591,7 +651,8 @@ class OutputMap(collections.Mapping):
else: else:
output_json = self._output_json_map.get(output_logical_id) output_json = self._output_json_map.get(output_logical_id)
new_output = parse_output( new_output = parse_output(
output_logical_id, output_json, self._resource_map) output_logical_id, output_json, self._resource_map
)
self._parsed_outputs[output_logical_id] = new_output self._parsed_outputs[output_logical_id] = new_output
return new_output return new_output
@ -610,9 +671,11 @@ class OutputMap(collections.Mapping):
exports = [] exports = []
if self.outputs: if self.outputs:
for key, value in self._output_json_map.items(): for key, value in self._output_json_map.items():
if value.get('Export'): if value.get("Export"):
cleaned_name = clean_json(value['Export'].get('Name'), self._resource_map) cleaned_name = clean_json(
cleaned_value = clean_json(value.get('Value'), self._resource_map) value["Export"].get("Name"), self._resource_map
)
cleaned_value = clean_json(value.get("Value"), self._resource_map)
exports.append(Export(self._stack_id, cleaned_name, cleaned_value)) exports.append(Export(self._stack_id, cleaned_name, cleaned_value))
return exports return exports
@ -622,7 +685,6 @@ class OutputMap(collections.Mapping):
class Export(object): class Export(object):
def __init__(self, exporting_stack_id, name, value): def __init__(self, exporting_stack_id, name, value):
self._exporting_stack_id = exporting_stack_id self._exporting_stack_id = exporting_stack_id
self._name = name self._name = name

View File

@ -12,7 +12,6 @@ from .exceptions import ValidationError
class CloudFormationResponse(BaseResponse): class CloudFormationResponse(BaseResponse):
@property @property
def cloudformation_backend(self): def cloudformation_backend(self):
return cloudformation_backends[self.region] return cloudformation_backends[self.region]
@ -20,17 +19,18 @@ class CloudFormationResponse(BaseResponse):
def _get_stack_from_s3_url(self, template_url): def _get_stack_from_s3_url(self, template_url):
template_url_parts = urlparse(template_url) template_url_parts = urlparse(template_url)
if "localhost" in template_url: if "localhost" in template_url:
bucket_name, key_name = template_url_parts.path.lstrip( bucket_name, key_name = template_url_parts.path.lstrip("/").split("/", 1)
"/").split("/", 1)
else: else:
if template_url_parts.netloc.endswith('amazonaws.com') \ if template_url_parts.netloc.endswith(
and template_url_parts.netloc.startswith('s3'): "amazonaws.com"
) and template_url_parts.netloc.startswith("s3"):
# Handle when S3 url uses amazon url with bucket in path # Handle when S3 url uses amazon url with bucket in path
# Also handles getting region as technically s3 is region'd # Also handles getting region as technically s3 is region'd
# region = template_url.netloc.split('.')[1] # region = template_url.netloc.split('.')[1]
bucket_name, key_name = template_url_parts.path.lstrip( bucket_name, key_name = template_url_parts.path.lstrip("/").split(
"/").split("/", 1) "/", 1
)
else: else:
bucket_name = template_url_parts.netloc.split(".")[0] bucket_name = template_url_parts.netloc.split(".")[0]
key_name = template_url_parts.path.lstrip("/") key_name = template_url_parts.path.lstrip("/")
@ -39,24 +39,26 @@ class CloudFormationResponse(BaseResponse):
return key.value.decode("utf-8") return key.value.decode("utf-8")
def create_stack(self): def create_stack(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
stack_body = self._get_param('TemplateBody') stack_body = self._get_param("TemplateBody")
template_url = self._get_param('TemplateURL') template_url = self._get_param("TemplateURL")
role_arn = self._get_param('RoleARN') role_arn = self._get_param("RoleARN")
parameters_list = self._get_list_prefix("Parameters.member") parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value']) tags = dict(
for item in self._get_list_prefix("Tags.member")) (item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# Hack dict-comprehension # Hack dict-comprehension
parameters = dict([ parameters = dict(
(parameter['parameter_key'], parameter['parameter_value']) [
for parameter (parameter["parameter_key"], parameter["parameter_value"])
in parameters_list for parameter in parameters_list
]) ]
)
if template_url: if template_url:
stack_body = self._get_stack_from_s3_url(template_url) stack_body = self._get_stack_from_s3_url(template_url)
stack_notification_arns = self._get_multi_param( stack_notification_arns = self._get_multi_param("NotificationARNs.member")
'NotificationARNs.member')
stack = self.cloudformation_backend.create_stack( stack = self.cloudformation_backend.create_stack(
name=stack_name, name=stack_name,
@ -68,34 +70,37 @@ class CloudFormationResponse(BaseResponse):
role_arn=role_arn, role_arn=role_arn,
) )
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
'CreateStackResponse': { {
'CreateStackResult': { "CreateStackResponse": {
'StackId': stack.stack_id, "CreateStackResult": {"StackId": stack.stack_id}
} }
} }
}) )
else: else:
template = self.response_template(CREATE_STACK_RESPONSE_TEMPLATE) template = self.response_template(CREATE_STACK_RESPONSE_TEMPLATE)
return template.render(stack=stack) return template.render(stack=stack)
@amzn_request_id @amzn_request_id
def create_change_set(self): def create_change_set(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
change_set_name = self._get_param('ChangeSetName') change_set_name = self._get_param("ChangeSetName")
stack_body = self._get_param('TemplateBody') stack_body = self._get_param("TemplateBody")
template_url = self._get_param('TemplateURL') template_url = self._get_param("TemplateURL")
role_arn = self._get_param('RoleARN') role_arn = self._get_param("RoleARN")
update_or_create = self._get_param('ChangeSetType', 'CREATE') update_or_create = self._get_param("ChangeSetType", "CREATE")
parameters_list = self._get_list_prefix("Parameters.member") parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value']) tags = dict(
for item in self._get_list_prefix("Tags.member")) (item["key"], item["value"])
parameters = {param['parameter_key']: param['parameter_value'] for item in self._get_list_prefix("Tags.member")
for param in parameters_list} )
parameters = {
param["parameter_key"]: param["parameter_value"]
for param in parameters_list
}
if template_url: if template_url:
stack_body = self._get_stack_from_s3_url(template_url) stack_body = self._get_stack_from_s3_url(template_url)
stack_notification_arns = self._get_multi_param( stack_notification_arns = self._get_multi_param("NotificationARNs.member")
'NotificationARNs.member')
change_set_id, stack_id = self.cloudformation_backend.create_change_set( change_set_id, stack_id = self.cloudformation_backend.create_change_set(
stack_name=stack_name, stack_name=stack_name,
change_set_name=change_set_name, change_set_name=change_set_name,
@ -108,66 +113,64 @@ class CloudFormationResponse(BaseResponse):
change_set_type=update_or_create, change_set_type=update_or_create,
) )
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
'CreateChangeSetResponse': { {
'CreateChangeSetResult': { "CreateChangeSetResponse": {
'Id': change_set_id, "CreateChangeSetResult": {
'StackId': stack_id, "Id": change_set_id,
"StackId": stack_id,
} }
} }
}) }
)
else: else:
template = self.response_template(CREATE_CHANGE_SET_RESPONSE_TEMPLATE) template = self.response_template(CREATE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render(stack_id=stack_id, change_set_id=change_set_id) return template.render(stack_id=stack_id, change_set_id=change_set_id)
def delete_change_set(self): def delete_change_set(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
change_set_name = self._get_param('ChangeSetName') change_set_name = self._get_param("ChangeSetName")
self.cloudformation_backend.delete_change_set(change_set_name=change_set_name, stack_name=stack_name) self.cloudformation_backend.delete_change_set(
change_set_name=change_set_name, stack_name=stack_name
)
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
'DeleteChangeSetResponse': { {"DeleteChangeSetResponse": {"DeleteChangeSetResult": {}}}
'DeleteChangeSetResult': {}, )
}
})
else: else:
template = self.response_template(DELETE_CHANGE_SET_RESPONSE_TEMPLATE) template = self.response_template(DELETE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render() return template.render()
def describe_change_set(self): def describe_change_set(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
change_set_name = self._get_param('ChangeSetName') change_set_name = self._get_param("ChangeSetName")
change_set = self.cloudformation_backend.describe_change_set( change_set = self.cloudformation_backend.describe_change_set(
change_set_name=change_set_name, change_set_name=change_set_name, stack_name=stack_name
stack_name=stack_name,
) )
template = self.response_template(DESCRIBE_CHANGE_SET_RESPONSE_TEMPLATE) template = self.response_template(DESCRIBE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render(change_set=change_set) return template.render(change_set=change_set)
@amzn_request_id @amzn_request_id
def execute_change_set(self): def execute_change_set(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
change_set_name = self._get_param('ChangeSetName') change_set_name = self._get_param("ChangeSetName")
self.cloudformation_backend.execute_change_set( self.cloudformation_backend.execute_change_set(
stack_name=stack_name, stack_name=stack_name, change_set_name=change_set_name
change_set_name=change_set_name,
) )
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
'ExecuteChangeSetResponse': { {"ExecuteChangeSetResponse": {"ExecuteChangeSetResult": {}}}
'ExecuteChangeSetResult': {}, )
}
})
else: else:
template = self.response_template(EXECUTE_CHANGE_SET_RESPONSE_TEMPLATE) template = self.response_template(EXECUTE_CHANGE_SET_RESPONSE_TEMPLATE)
return template.render() return template.render()
def describe_stacks(self): def describe_stacks(self):
stack_name_or_id = None stack_name_or_id = None
if self._get_param('StackName'): if self._get_param("StackName"):
stack_name_or_id = self.querystring.get('StackName')[0] stack_name_or_id = self.querystring.get("StackName")[0]
token = self._get_param('NextToken') token = self._get_param("NextToken")
stacks = self.cloudformation_backend.describe_stacks(stack_name_or_id) stacks = self.cloudformation_backend.describe_stacks(stack_name_or_id)
stack_ids = [stack.stack_id for stack in stacks] stack_ids = [stack.stack_id for stack in stacks]
if token: if token:
@ -175,7 +178,7 @@ class CloudFormationResponse(BaseResponse):
else: else:
start = 0 start = 0
max_results = 50 # using this to mske testing of paginated stacks more convenient than default 1 MB max_results = 50 # using this to mske testing of paginated stacks more convenient than default 1 MB
stacks_resp = stacks[start:start + max_results] stacks_resp = stacks[start : start + max_results]
next_token = None next_token = None
if len(stacks) > (start + max_results): if len(stacks) > (start + max_results):
next_token = stacks_resp[-1].stack_id next_token = stacks_resp[-1].stack_id
@ -183,9 +186,9 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacks=stacks_resp, next_token=next_token) return template.render(stacks=stacks_resp, next_token=next_token)
def describe_stack_resource(self): def describe_stack_resource(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name) stack = self.cloudformation_backend.get_stack(stack_name)
logical_resource_id = self._get_param('LogicalResourceId') logical_resource_id = self._get_param("LogicalResourceId")
for stack_resource in stack.stack_resources: for stack_resource in stack.stack_resources:
if stack_resource.logical_resource_id == logical_resource_id: if stack_resource.logical_resource_id == logical_resource_id:
@ -194,19 +197,18 @@ class CloudFormationResponse(BaseResponse):
else: else:
raise ValidationError(logical_resource_id) raise ValidationError(logical_resource_id)
template = self.response_template( template = self.response_template(DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE)
DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE)
return template.render(stack=stack, resource=resource) return template.render(stack=stack, resource=resource)
def describe_stack_resources(self): def describe_stack_resources(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name) stack = self.cloudformation_backend.get_stack(stack_name)
template = self.response_template(DESCRIBE_STACK_RESOURCES_RESPONSE) template = self.response_template(DESCRIBE_STACK_RESOURCES_RESPONSE)
return template.render(stack=stack) return template.render(stack=stack)
def describe_stack_events(self): def describe_stack_events(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
stack = self.cloudformation_backend.get_stack(stack_name) stack = self.cloudformation_backend.get_stack(stack_name)
template = self.response_template(DESCRIBE_STACK_EVENTS_RESPONSE) template = self.response_template(DESCRIBE_STACK_EVENTS_RESPONSE)
@ -223,68 +225,82 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacks=stacks) return template.render(stacks=stacks)
def list_stack_resources(self): def list_stack_resources(self):
stack_name_or_id = self._get_param('StackName') stack_name_or_id = self._get_param("StackName")
resources = self.cloudformation_backend.list_stack_resources( resources = self.cloudformation_backend.list_stack_resources(stack_name_or_id)
stack_name_or_id)
template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE) template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE)
return template.render(resources=resources) return template.render(resources=resources)
def get_template(self): def get_template(self):
name_or_stack_id = self.querystring.get('StackName')[0] name_or_stack_id = self.querystring.get("StackName")[0]
stack = self.cloudformation_backend.get_stack(name_or_stack_id) stack = self.cloudformation_backend.get_stack(name_or_stack_id)
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
{
"GetTemplateResponse": { "GetTemplateResponse": {
"GetTemplateResult": { "GetTemplateResult": {
"TemplateBody": stack.template, "TemplateBody": stack.template,
"ResponseMetadata": { "ResponseMetadata": {
"RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE"
},
} }
} }
} }
}) )
else: else:
template = self.response_template(GET_TEMPLATE_RESPONSE_TEMPLATE) template = self.response_template(GET_TEMPLATE_RESPONSE_TEMPLATE)
return template.render(stack=stack) return template.render(stack=stack)
def update_stack(self): def update_stack(self):
stack_name = self._get_param('StackName') stack_name = self._get_param("StackName")
role_arn = self._get_param('RoleARN') role_arn = self._get_param("RoleARN")
template_url = self._get_param('TemplateURL') template_url = self._get_param("TemplateURL")
stack_body = self._get_param('TemplateBody') stack_body = self._get_param("TemplateBody")
stack = self.cloudformation_backend.get_stack(stack_name) stack = self.cloudformation_backend.get_stack(stack_name)
if self._get_param('UsePreviousTemplate') == "true": if self._get_param("UsePreviousTemplate") == "true":
stack_body = stack.template stack_body = stack.template
elif not stack_body and template_url: elif not stack_body and template_url:
stack_body = self._get_stack_from_s3_url(template_url) stack_body = self._get_stack_from_s3_url(template_url)
incoming_params = self._get_list_prefix("Parameters.member") incoming_params = self._get_list_prefix("Parameters.member")
parameters = dict([ parameters = dict(
(parameter['parameter_key'], parameter['parameter_value']) [
for parameter (parameter["parameter_key"], parameter["parameter_value"])
in incoming_params if 'parameter_value' in parameter for parameter in incoming_params
]) if "parameter_value" in parameter
previous = dict([ ]
(parameter['parameter_key'], stack.parameters[parameter['parameter_key']]) )
for parameter previous = dict(
in incoming_params if 'use_previous_value' in parameter [
]) (
parameter["parameter_key"],
stack.parameters[parameter["parameter_key"]],
)
for parameter in incoming_params
if "use_previous_value" in parameter
]
)
parameters.update(previous) parameters.update(previous)
# boto3 is supposed to let you clear the tags by passing an empty value, but the request body doesn't # boto3 is supposed to let you clear the tags by passing an empty value, but the request body doesn't
# end up containing anything we can use to differentiate between passing an empty value versus not # end up containing anything we can use to differentiate between passing an empty value versus not
# passing anything. so until that changes, moto won't be able to clear tags, only update them. # passing anything. so until that changes, moto won't be able to clear tags, only update them.
tags = dict((item['key'], item['value']) tags = dict(
for item in self._get_list_prefix("Tags.member")) (item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# so that if we don't pass the parameter, we don't clear all the tags accidentally # so that if we don't pass the parameter, we don't clear all the tags accidentally
if not tags: if not tags:
tags = None tags = None
stack = self.cloudformation_backend.get_stack(stack_name) stack = self.cloudformation_backend.get_stack(stack_name)
if stack.status == 'ROLLBACK_COMPLETE': if stack.status == "ROLLBACK_COMPLETE":
raise ValidationError( raise ValidationError(
stack.stack_id, message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(stack.stack_id)) stack.stack_id,
message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(
stack.stack_id
),
)
stack = self.cloudformation_backend.update_stack( stack = self.cloudformation_backend.update_stack(
name=stack_name, name=stack_name,
@ -295,11 +311,7 @@ class CloudFormationResponse(BaseResponse):
) )
if self.request_json: if self.request_json:
stack_body = { stack_body = {
'UpdateStackResponse': { "UpdateStackResponse": {"UpdateStackResult": {"StackId": stack.name}}
'UpdateStackResult': {
'StackId': stack.name,
}
}
} }
return json.dumps(stack_body) return json.dumps(stack_body)
else: else:
@ -307,56 +319,57 @@ class CloudFormationResponse(BaseResponse):
return template.render(stack=stack) return template.render(stack=stack)
def delete_stack(self): def delete_stack(self):
name_or_stack_id = self.querystring.get('StackName')[0] name_or_stack_id = self.querystring.get("StackName")[0]
self.cloudformation_backend.delete_stack(name_or_stack_id) self.cloudformation_backend.delete_stack(name_or_stack_id)
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps({"DeleteStackResponse": {"DeleteStackResult": {}}})
'DeleteStackResponse': {
'DeleteStackResult': {},
}
})
else: else:
template = self.response_template(DELETE_STACK_RESPONSE_TEMPLATE) template = self.response_template(DELETE_STACK_RESPONSE_TEMPLATE)
return template.render() return template.render()
def list_exports(self): def list_exports(self):
token = self._get_param('NextToken') token = self._get_param("NextToken")
exports, next_token = self.cloudformation_backend.list_exports(token=token) exports, next_token = self.cloudformation_backend.list_exports(token=token)
template = self.response_template(LIST_EXPORTS_RESPONSE) template = self.response_template(LIST_EXPORTS_RESPONSE)
return template.render(exports=exports, next_token=next_token) return template.render(exports=exports, next_token=next_token)
def validate_template(self): def validate_template(self):
cfn_lint = self.cloudformation_backend.validate_template(self._get_param('TemplateBody')) cfn_lint = self.cloudformation_backend.validate_template(
self._get_param("TemplateBody")
)
if cfn_lint: if cfn_lint:
raise ValidationError(cfn_lint[0].message) raise ValidationError(cfn_lint[0].message)
description = "" description = ""
try: try:
description = json.loads(self._get_param('TemplateBody'))['Description'] description = json.loads(self._get_param("TemplateBody"))["Description"]
except (ValueError, KeyError): except (ValueError, KeyError):
pass pass
try: try:
description = yaml.load(self._get_param('TemplateBody'))['Description'] description = yaml.load(self._get_param("TemplateBody"))["Description"]
except (yaml.ParserError, KeyError): except (yaml.ParserError, KeyError):
pass pass
template = self.response_template(VALIDATE_STACK_RESPONSE_TEMPLATE) template = self.response_template(VALIDATE_STACK_RESPONSE_TEMPLATE)
return template.render(description=description) return template.render(description=description)
def create_stack_set(self): def create_stack_set(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
stack_body = self._get_param('TemplateBody') stack_body = self._get_param("TemplateBody")
template_url = self._get_param('TemplateURL') template_url = self._get_param("TemplateURL")
# role_arn = self._get_param('RoleARN') # role_arn = self._get_param('RoleARN')
parameters_list = self._get_list_prefix("Parameters.member") parameters_list = self._get_list_prefix("Parameters.member")
tags = dict((item['key'], item['value']) tags = dict(
for item in self._get_list_prefix("Tags.member")) (item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
# Copy-Pasta - Hack dict-comprehension # Copy-Pasta - Hack dict-comprehension
parameters = dict([ parameters = dict(
(parameter['parameter_key'], parameter['parameter_value']) [
for parameter (parameter["parameter_key"], parameter["parameter_value"])
in parameters_list for parameter in parameters_list
]) ]
)
if template_url: if template_url:
stack_body = self._get_stack_from_s3_url(template_url) stack_body = self._get_stack_from_s3_url(template_url)
@ -368,59 +381,65 @@ class CloudFormationResponse(BaseResponse):
# role_arn=role_arn, # role_arn=role_arn,
) )
if self.request_json: if self.request_json:
return json.dumps({ return json.dumps(
'CreateStackSetResponse': { {
'CreateStackSetResult': { "CreateStackSetResponse": {
'StackSetId': stackset.stackset_id, "CreateStackSetResult": {"StackSetId": stackset.stackset_id}
} }
} }
}) )
else: else:
template = self.response_template(CREATE_STACK_SET_RESPONSE_TEMPLATE) template = self.response_template(CREATE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(stackset=stackset) return template.render(stackset=stackset)
def create_stack_instances(self): def create_stack_instances(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param('Accounts.member') accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param('Regions.member') regions = self._get_multi_param("Regions.member")
parameters = self._get_multi_param('ParameterOverrides.member') parameters = self._get_multi_param("ParameterOverrides.member")
self.cloudformation_backend.create_stack_instances(stackset_name, accounts, regions, parameters) self.cloudformation_backend.create_stack_instances(
stackset_name, accounts, regions, parameters
)
template = self.response_template(CREATE_STACK_INSTANCES_TEMPLATE) template = self.response_template(CREATE_STACK_INSTANCES_TEMPLATE)
return template.render() return template.render()
def delete_stack_set(self): def delete_stack_set(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
self.cloudformation_backend.delete_stack_set(stackset_name) self.cloudformation_backend.delete_stack_set(stackset_name)
template = self.response_template(DELETE_STACK_SET_RESPONSE_TEMPLATE) template = self.response_template(DELETE_STACK_SET_RESPONSE_TEMPLATE)
return template.render() return template.render()
def delete_stack_instances(self): def delete_stack_instances(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param('Accounts.member') accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param('Regions.member') regions = self._get_multi_param("Regions.member")
operation = self.cloudformation_backend.delete_stack_instances(stackset_name, accounts, regions) operation = self.cloudformation_backend.delete_stack_instances(
stackset_name, accounts, regions
)
template = self.response_template(DELETE_STACK_INSTANCES_TEMPLATE) template = self.response_template(DELETE_STACK_INSTANCES_TEMPLATE)
return template.render(operation=operation) return template.render(operation=operation)
def describe_stack_set(self): def describe_stack_set(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
if not stackset.admin_role: if not stackset.admin_role:
stackset.admin_role = 'arn:aws:iam::123456789012:role/AWSCloudFormationStackSetAdministrationRole' stackset.admin_role = "arn:aws:iam::123456789012:role/AWSCloudFormationStackSetAdministrationRole"
if not stackset.execution_role: if not stackset.execution_role:
stackset.execution_role = 'AWSCloudFormationStackSetExecutionRole' stackset.execution_role = "AWSCloudFormationStackSetExecutionRole"
template = self.response_template(DESCRIBE_STACK_SET_RESPONSE_TEMPLATE) template = self.response_template(DESCRIBE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(stackset=stackset) return template.render(stackset=stackset)
def describe_stack_instance(self): def describe_stack_instance(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
account = self._get_param('StackInstanceAccount') account = self._get_param("StackInstanceAccount")
region = self._get_param('StackInstanceRegion') region = self._get_param("StackInstanceRegion")
instance = self.cloudformation_backend.get_stack_set(stackset_name).instances.get_instance(account, region) instance = self.cloudformation_backend.get_stack_set(
stackset_name
).instances.get_instance(account, region)
template = self.response_template(DESCRIBE_STACK_INSTANCE_TEMPLATE) template = self.response_template(DESCRIBE_STACK_INSTANCE_TEMPLATE)
rendered = template.render(instance=instance) rendered = template.render(instance=instance)
return rendered return rendered
@ -431,61 +450,66 @@ class CloudFormationResponse(BaseResponse):
return template.render(stacksets=stacksets) return template.render(stacksets=stacksets)
def list_stack_instances(self): def list_stack_instances(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
template = self.response_template(LIST_STACK_INSTANCES_TEMPLATE) template = self.response_template(LIST_STACK_INSTANCES_TEMPLATE)
return template.render(stackset=stackset) return template.render(stackset=stackset)
def list_stack_set_operations(self): def list_stack_set_operations(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
template = self.response_template(LIST_STACK_SET_OPERATIONS_RESPONSE_TEMPLATE) template = self.response_template(LIST_STACK_SET_OPERATIONS_RESPONSE_TEMPLATE)
return template.render(stackset=stackset) return template.render(stackset=stackset)
def stop_stack_set_operation(self): def stop_stack_set_operation(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
operation_id = self._get_param('OperationId') operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
stackset.update_operation(operation_id, 'STOPPED') stackset.update_operation(operation_id, "STOPPED")
template = self.response_template(STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE) template = self.response_template(STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE)
return template.render() return template.render()
def describe_stack_set_operation(self): def describe_stack_set_operation(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
operation_id = self._get_param('OperationId') operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
operation = stackset.get_operation(operation_id) operation = stackset.get_operation(operation_id)
template = self.response_template(DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE) template = self.response_template(DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE)
return template.render(stackset=stackset, operation=operation) return template.render(stackset=stackset, operation=operation)
def list_stack_set_operation_results(self): def list_stack_set_operation_results(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
operation_id = self._get_param('OperationId') operation_id = self._get_param("OperationId")
stackset = self.cloudformation_backend.get_stack_set(stackset_name) stackset = self.cloudformation_backend.get_stack_set(stackset_name)
operation = stackset.get_operation(operation_id) operation = stackset.get_operation(operation_id)
template = self.response_template(LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE) template = self.response_template(
LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE
)
return template.render(operation=operation) return template.render(operation=operation)
def update_stack_set(self): def update_stack_set(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
operation_id = self._get_param('OperationId') operation_id = self._get_param("OperationId")
description = self._get_param('Description') description = self._get_param("Description")
execution_role = self._get_param('ExecutionRoleName') execution_role = self._get_param("ExecutionRoleName")
admin_role = self._get_param('AdministrationRoleARN') admin_role = self._get_param("AdministrationRoleARN")
accounts = self._get_multi_param('Accounts.member') accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param('Regions.member') regions = self._get_multi_param("Regions.member")
template_body = self._get_param('TemplateBody') template_body = self._get_param("TemplateBody")
template_url = self._get_param('TemplateURL') template_url = self._get_param("TemplateURL")
if template_url: if template_url:
template_body = self._get_stack_from_s3_url(template_url) template_body = self._get_stack_from_s3_url(template_url)
tags = dict((item['key'], item['value']) tags = dict(
for item in self._get_list_prefix("Tags.member")) (item["key"], item["value"])
for item in self._get_list_prefix("Tags.member")
)
parameters_list = self._get_list_prefix("Parameters.member") parameters_list = self._get_list_prefix("Parameters.member")
parameters = dict([ parameters = dict(
(parameter['parameter_key'], parameter['parameter_value']) [
for parameter (parameter["parameter_key"], parameter["parameter_value"])
in parameters_list for parameter in parameters_list
]) ]
)
operation = self.cloudformation_backend.update_stack_set( operation = self.cloudformation_backend.update_stack_set(
stackset_name=stackset_name, stackset_name=stackset_name,
template=template_body, template=template_body,
@ -496,18 +520,20 @@ class CloudFormationResponse(BaseResponse):
execution_role=execution_role, execution_role=execution_role,
accounts=accounts, accounts=accounts,
regions=regions, regions=regions,
operation_id=operation_id operation_id=operation_id,
) )
template = self.response_template(UPDATE_STACK_SET_RESPONSE_TEMPLATE) template = self.response_template(UPDATE_STACK_SET_RESPONSE_TEMPLATE)
return template.render(operation=operation) return template.render(operation=operation)
def update_stack_instances(self): def update_stack_instances(self):
stackset_name = self._get_param('StackSetName') stackset_name = self._get_param("StackSetName")
accounts = self._get_multi_param('Accounts.member') accounts = self._get_multi_param("Accounts.member")
regions = self._get_multi_param('Regions.member') regions = self._get_multi_param("Regions.member")
parameters = self._get_multi_param('ParameterOverrides.member') parameters = self._get_multi_param("ParameterOverrides.member")
operation = self.cloudformation_backend.get_stack_set(stackset_name).update_instances(accounts, regions, parameters) operation = self.cloudformation_backend.get_stack_set(
stackset_name
).update_instances(accounts, regions, parameters)
template = self.response_template(UPDATE_STACK_INSTANCES_RESPONSE_TEMPLATE) template = self.response_template(UPDATE_STACK_INSTANCES_RESPONSE_TEMPLATE)
return template.render(operation=operation) return template.render(operation=operation)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import CloudFormationResponse from .responses import CloudFormationResponse
url_bases = [ url_bases = ["https?://cloudformation.(.+).amazonaws.com"]
"https?://cloudformation.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": CloudFormationResponse.dispatch}
'{0}/$': CloudFormationResponse.dispatch,
}

View File

@ -11,44 +11,51 @@ from cfnlint import decode, core
def generate_stack_id(stack_name, region="us-east-1", account="123456789"): def generate_stack_id(stack_name, region="us-east-1", account="123456789"):
random_id = uuid.uuid4() random_id = uuid.uuid4()
return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format(region, account, stack_name, random_id) return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format(
region, account, stack_name, random_id
)
def generate_changeset_id(changeset_name, region_name): def generate_changeset_id(changeset_name, region_name):
random_id = uuid.uuid4() random_id = uuid.uuid4()
return 'arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}'.format(region_name, changeset_name, random_id) return "arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}".format(
region_name, changeset_name, random_id
)
def generate_stackset_id(stackset_name): def generate_stackset_id(stackset_name):
random_id = uuid.uuid4() random_id = uuid.uuid4()
return '{}:{}'.format(stackset_name, random_id) return "{}:{}".format(stackset_name, random_id)
def generate_stackset_arn(stackset_id, region_name): def generate_stackset_arn(stackset_id, region_name):
return 'arn:aws:cloudformation:{}:123456789012:stackset/{}'.format(region_name, stackset_id) return "arn:aws:cloudformation:{}:123456789012:stackset/{}".format(
region_name, stackset_id
)
def random_suffix(): def random_suffix():
size = 12 size = 12
chars = list(range(10)) + list(string.ascii_uppercase) chars = list(range(10)) + list(string.ascii_uppercase)
return ''.join(six.text_type(random.choice(chars)) for x in range(size)) return "".join(six.text_type(random.choice(chars)) for x in range(size))
def yaml_tag_constructor(loader, tag, node): def yaml_tag_constructor(loader, tag, node):
"""convert shorthand intrinsic function to full name """convert shorthand intrinsic function to full name
""" """
def _f(loader, tag, node): def _f(loader, tag, node):
if tag == '!GetAtt': if tag == "!GetAtt":
return node.value.split('.') return node.value.split(".")
elif type(node) == yaml.SequenceNode: elif type(node) == yaml.SequenceNode:
return loader.construct_sequence(node) return loader.construct_sequence(node)
else: else:
return node.value return node.value
if tag == '!Ref': if tag == "!Ref":
key = 'Ref' key = "Ref"
else: else:
key = 'Fn::{}'.format(tag[1:]) key = "Fn::{}".format(tag[1:])
return {key: _f(loader, tag, node)} return {key: _f(loader, tag, node)}
@ -71,13 +78,9 @@ def validate_template_cfn_lint(template):
rules = core.get_rules([], [], []) rules = core.get_rules([], [], [])
# Use us-east-1 region (spec file) for validation # Use us-east-1 region (spec file) for validation
regions = ['us-east-1'] regions = ["us-east-1"]
# Process all the rules and gather the errors # Process all the rules and gather the errors
matches = core.run_checks( matches = core.run_checks(abs_filename, template, rules, regions)
abs_filename,
template,
rules,
regions)
return matches return matches

View File

@ -1,6 +1,6 @@
from .models import cloudwatch_backends from .models import cloudwatch_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
cloudwatch_backend = cloudwatch_backends['us-east-1'] cloudwatch_backend = cloudwatch_backends["us-east-1"]
mock_cloudwatch = base_decorator(cloudwatch_backends) mock_cloudwatch = base_decorator(cloudwatch_backends)
mock_cloudwatch_deprecated = deprecated_base_decorator(cloudwatch_backends) mock_cloudwatch_deprecated = deprecated_base_decorator(cloudwatch_backends)

View File

@ -1,4 +1,3 @@
import json import json
from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core.utils import iso_8601_datetime_with_milliseconds
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
@ -14,7 +13,6 @@ _EMPTY_LIST = tuple()
class Dimension(object): class Dimension(object):
def __init__(self, name, value): def __init__(self, name, value):
self.name = name self.name = name
self.value = value self.value = value
@ -49,10 +47,23 @@ def daterange(start, stop, step=timedelta(days=1), inclusive=False):
class FakeAlarm(BaseModel): class FakeAlarm(BaseModel):
def __init__(
def __init__(self, name, namespace, metric_name, comparison_operator, evaluation_periods, self,
period, threshold, statistic, description, dimensions, alarm_actions, name,
ok_actions, insufficient_data_actions, unit): namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
):
self.name = name self.name = name
self.namespace = namespace self.namespace = namespace
self.metric_name = metric_name self.metric_name = metric_name
@ -62,8 +73,9 @@ class FakeAlarm(BaseModel):
self.threshold = threshold self.threshold = threshold
self.statistic = statistic self.statistic = statistic
self.description = description self.description = description
self.dimensions = [Dimension(dimension['name'], dimension[ self.dimensions = [
'value']) for dimension in dimensions] Dimension(dimension["name"], dimension["value"]) for dimension in dimensions
]
self.alarm_actions = alarm_actions self.alarm_actions = alarm_actions
self.ok_actions = ok_actions self.ok_actions = ok_actions
self.insufficient_data_actions = insufficient_data_actions self.insufficient_data_actions = insufficient_data_actions
@ -72,15 +84,21 @@ class FakeAlarm(BaseModel):
self.history = [] self.history = []
self.state_reason = '' self.state_reason = ""
self.state_reason_data = '{}' self.state_reason_data = "{}"
self.state_value = 'OK' self.state_value = "OK"
self.state_updated_timestamp = datetime.utcnow() self.state_updated_timestamp = datetime.utcnow()
def update_state(self, reason, reason_data, state_value): def update_state(self, reason, reason_data, state_value):
# History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action # History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action
self.history.append( self.history.append(
('StateUpdate', self.state_reason, self.state_reason_data, self.state_value, self.state_updated_timestamp) (
"StateUpdate",
self.state_reason,
self.state_reason_data,
self.state_value,
self.state_updated_timestamp,
)
) )
self.state_reason = reason self.state_reason = reason
@ -90,14 +108,14 @@ class FakeAlarm(BaseModel):
class MetricDatum(BaseModel): class MetricDatum(BaseModel):
def __init__(self, namespace, name, value, dimensions, timestamp): def __init__(self, namespace, name, value, dimensions, timestamp):
self.namespace = namespace self.namespace = namespace
self.name = name self.name = name
self.value = value self.value = value
self.timestamp = timestamp or datetime.utcnow().replace(tzinfo=tzutc()) self.timestamp = timestamp or datetime.utcnow().replace(tzinfo=tzutc())
self.dimensions = [Dimension(dimension['Name'], dimension[ self.dimensions = [
'Value']) for dimension in dimensions] Dimension(dimension["Name"], dimension["Value"]) for dimension in dimensions
]
class Dashboard(BaseModel): class Dashboard(BaseModel):
@ -120,7 +138,7 @@ class Dashboard(BaseModel):
return len(self.body) return len(self.body)
def __repr__(self): def __repr__(self):
return '<CloudWatchDashboard {0}>'.format(self.name) return "<CloudWatchDashboard {0}>".format(self.name)
class Statistics: class Statistics:
@ -131,7 +149,7 @@ class Statistics:
@property @property
def sample_count(self): def sample_count(self):
if 'SampleCount' not in self.stats: if "SampleCount" not in self.stats:
return None return None
return len(self.values) return len(self.values)
@ -142,28 +160,28 @@ class Statistics:
@property @property
def sum(self): def sum(self):
if 'Sum' not in self.stats: if "Sum" not in self.stats:
return None return None
return sum(self.values) return sum(self.values)
@property @property
def minimum(self): def minimum(self):
if 'Minimum' not in self.stats: if "Minimum" not in self.stats:
return None return None
return min(self.values) return min(self.values)
@property @property
def maximum(self): def maximum(self):
if 'Maximum' not in self.stats: if "Maximum" not in self.stats:
return None return None
return max(self.values) return max(self.values)
@property @property
def average(self): def average(self):
if 'Average' not in self.stats: if "Average" not in self.stats:
return None return None
# when moto is 3.4+ we can switch to the statistics module # when moto is 3.4+ we can switch to the statistics module
@ -171,18 +189,44 @@ class Statistics:
class CloudWatchBackend(BaseBackend): class CloudWatchBackend(BaseBackend):
def __init__(self): def __init__(self):
self.alarms = {} self.alarms = {}
self.dashboards = {} self.dashboards = {}
self.metric_data = [] self.metric_data = []
def put_metric_alarm(self, name, namespace, metric_name, comparison_operator, evaluation_periods, def put_metric_alarm(
period, threshold, statistic, description, dimensions, self,
alarm_actions, ok_actions, insufficient_data_actions, unit): name,
alarm = FakeAlarm(name, namespace, metric_name, comparison_operator, evaluation_periods, period, namespace,
threshold, statistic, description, dimensions, alarm_actions, metric_name,
ok_actions, insufficient_data_actions, unit) comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
):
alarm = FakeAlarm(
name,
namespace,
metric_name,
comparison_operator,
evaluation_periods,
period,
threshold,
statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions,
unit,
)
self.alarms[name] = alarm self.alarms[name] = alarm
return alarm return alarm
@ -214,14 +258,12 @@ class CloudWatchBackend(BaseBackend):
] ]
def get_alarms_by_alarm_names(self, alarm_names): def get_alarms_by_alarm_names(self, alarm_names):
return [ return [alarm for alarm in self.alarms.values() if alarm.name in alarm_names]
alarm
for alarm in self.alarms.values()
if alarm.name in alarm_names
]
def get_alarms_by_state_value(self, target_state): def get_alarms_by_state_value(self, target_state):
return filter(lambda alarm: alarm.state_value == target_state, self.alarms.values()) return filter(
lambda alarm: alarm.state_value == target_state, self.alarms.values()
)
def delete_alarms(self, alarm_names): def delete_alarms(self, alarm_names):
for alarm_name in alarm_names: for alarm_name in alarm_names:
@ -230,17 +272,31 @@ class CloudWatchBackend(BaseBackend):
def put_metric_data(self, namespace, metric_data): def put_metric_data(self, namespace, metric_data):
for metric_member in metric_data: for metric_member in metric_data:
# Preserve "datetime" for get_metric_statistics comparisons # Preserve "datetime" for get_metric_statistics comparisons
timestamp = metric_member.get('Timestamp') timestamp = metric_member.get("Timestamp")
if timestamp is not None and type(timestamp) != datetime: if timestamp is not None and type(timestamp) != datetime:
timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ') timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
timestamp = timestamp.replace(tzinfo=tzutc()) timestamp = timestamp.replace(tzinfo=tzutc())
self.metric_data.append(MetricDatum( self.metric_data.append(
namespace, metric_member['MetricName'], float(metric_member.get('Value', 0)), metric_member.get('Dimensions.member', _EMPTY_LIST), timestamp)) MetricDatum(
namespace,
metric_member["MetricName"],
float(metric_member.get("Value", 0)),
metric_member.get("Dimensions.member", _EMPTY_LIST),
timestamp,
)
)
def get_metric_statistics(self, namespace, metric_name, start_time, end_time, period, stats): def get_metric_statistics(
self, namespace, metric_name, start_time, end_time, period, stats
):
period_delta = timedelta(seconds=period) period_delta = timedelta(seconds=period)
filtered_data = [md for md in self.metric_data if filtered_data = [
md.namespace == namespace and md.name == metric_name and start_time <= md.timestamp <= end_time] md
for md in self.metric_data
if md.namespace == namespace
and md.name == metric_name
and start_time <= md.timestamp <= end_time
]
# earliest to oldest # earliest to oldest
filtered_data = sorted(filtered_data, key=lambda x: x.timestamp) filtered_data = sorted(filtered_data, key=lambda x: x.timestamp)
@ -249,9 +305,15 @@ class CloudWatchBackend(BaseBackend):
idx = 0 idx = 0
data = list() data = list()
for dt in daterange(filtered_data[0].timestamp, filtered_data[-1].timestamp + period_delta, period_delta): for dt in daterange(
filtered_data[0].timestamp,
filtered_data[-1].timestamp + period_delta,
period_delta,
):
s = Statistics(stats, dt) s = Statistics(stats, dt)
while idx < len(filtered_data) and filtered_data[idx].timestamp < (dt + period_delta): while idx < len(filtered_data) and filtered_data[idx].timestamp < (
dt + period_delta
):
s.values.append(filtered_data[idx].value) s.values.append(filtered_data[idx].value)
idx += 1 idx += 1
@ -268,7 +330,7 @@ class CloudWatchBackend(BaseBackend):
def put_dashboard(self, name, body): def put_dashboard(self, name, body):
self.dashboards[name] = Dashboard(name, body) self.dashboards[name] = Dashboard(name, body)
def list_dashboards(self, prefix=''): def list_dashboards(self, prefix=""):
for key, value in self.dashboards.items(): for key, value in self.dashboards.items():
if key.startswith(prefix): if key.startswith(prefix):
yield value yield value
@ -280,7 +342,12 @@ class CloudWatchBackend(BaseBackend):
left_over = to_delete - all_dashboards left_over = to_delete - all_dashboards
if len(left_over) > 0: if len(left_over) > 0:
# Some dashboards are not found # Some dashboards are not found
return False, 'The specified dashboard does not exist. [{0}]'.format(', '.join(left_over)) return (
False,
"The specified dashboard does not exist. [{0}]".format(
", ".join(left_over)
),
)
for dashboard in to_delete: for dashboard in to_delete:
del self.dashboards[dashboard] del self.dashboards[dashboard]
@ -295,32 +362,36 @@ class CloudWatchBackend(BaseBackend):
if reason_data is not None: if reason_data is not None:
json.loads(reason_data) json.loads(reason_data)
except ValueError: except ValueError:
raise RESTError('InvalidFormat', 'StateReasonData is invalid JSON') raise RESTError("InvalidFormat", "StateReasonData is invalid JSON")
if alarm_name not in self.alarms: if alarm_name not in self.alarms:
raise RESTError('ResourceNotFound', 'Alarm {0} not found'.format(alarm_name), status=404) raise RESTError(
"ResourceNotFound", "Alarm {0} not found".format(alarm_name), status=404
)
if state_value not in ('OK', 'ALARM', 'INSUFFICIENT_DATA'): if state_value not in ("OK", "ALARM", "INSUFFICIENT_DATA"):
raise RESTError('InvalidParameterValue', 'StateValue is not one of OK | ALARM | INSUFFICIENT_DATA') raise RESTError(
"InvalidParameterValue",
"StateValue is not one of OK | ALARM | INSUFFICIENT_DATA",
)
self.alarms[alarm_name].update_state(reason, reason_data, state_value) self.alarms[alarm_name].update_state(reason, reason_data, state_value)
class LogGroup(BaseModel): class LogGroup(BaseModel):
def __init__(self, spec): def __init__(self, spec):
# required # required
self.name = spec['LogGroupName'] self.name = spec["LogGroupName"]
# optional # optional
self.tags = spec.get('Tags', []) self.tags = spec.get("Tags", [])
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(
properties = cloudformation_json['Properties'] cls, resource_name, cloudformation_json, region_name
spec = { ):
'LogGroupName': properties['LogGroupName'] properties = cloudformation_json["Properties"]
} spec = {"LogGroupName": properties["LogGroupName"]}
optional_properties = 'Tags'.split() optional_properties = "Tags".split()
for prop in optional_properties: for prop in optional_properties:
if prop in properties: if prop in properties:
spec[prop] = properties[prop] spec[prop] = properties[prop]

View File

@ -6,7 +6,6 @@ from dateutil.parser import parse as dtparse
class CloudWatchResponse(BaseResponse): class CloudWatchResponse(BaseResponse):
@property @property
def cloudwatch_backend(self): def cloudwatch_backend(self):
return cloudwatch_backends[self.region] return cloudwatch_backends[self.region]
@ -17,45 +16,54 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def put_metric_alarm(self): def put_metric_alarm(self):
name = self._get_param('AlarmName') name = self._get_param("AlarmName")
namespace = self._get_param('Namespace') namespace = self._get_param("Namespace")
metric_name = self._get_param('MetricName') metric_name = self._get_param("MetricName")
comparison_operator = self._get_param('ComparisonOperator') comparison_operator = self._get_param("ComparisonOperator")
evaluation_periods = self._get_param('EvaluationPeriods') evaluation_periods = self._get_param("EvaluationPeriods")
period = self._get_param('Period') period = self._get_param("Period")
threshold = self._get_param('Threshold') threshold = self._get_param("Threshold")
statistic = self._get_param('Statistic') statistic = self._get_param("Statistic")
description = self._get_param('AlarmDescription') description = self._get_param("AlarmDescription")
dimensions = self._get_list_prefix('Dimensions.member') dimensions = self._get_list_prefix("Dimensions.member")
alarm_actions = self._get_multi_param('AlarmActions.member') alarm_actions = self._get_multi_param("AlarmActions.member")
ok_actions = self._get_multi_param('OKActions.member') ok_actions = self._get_multi_param("OKActions.member")
insufficient_data_actions = self._get_multi_param( insufficient_data_actions = self._get_multi_param(
"InsufficientDataActions.member") "InsufficientDataActions.member"
unit = self._get_param('Unit') )
alarm = self.cloudwatch_backend.put_metric_alarm(name, namespace, metric_name, unit = self._get_param("Unit")
alarm = self.cloudwatch_backend.put_metric_alarm(
name,
namespace,
metric_name,
comparison_operator, comparison_operator,
evaluation_periods, period, evaluation_periods,
threshold, statistic, period,
description, dimensions, threshold,
alarm_actions, ok_actions, statistic,
description,
dimensions,
alarm_actions,
ok_actions,
insufficient_data_actions, insufficient_data_actions,
unit) unit,
)
template = self.response_template(PUT_METRIC_ALARM_TEMPLATE) template = self.response_template(PUT_METRIC_ALARM_TEMPLATE)
return template.render(alarm=alarm) return template.render(alarm=alarm)
@amzn_request_id @amzn_request_id
def describe_alarms(self): def describe_alarms(self):
action_prefix = self._get_param('ActionPrefix') action_prefix = self._get_param("ActionPrefix")
alarm_name_prefix = self._get_param('AlarmNamePrefix') alarm_name_prefix = self._get_param("AlarmNamePrefix")
alarm_names = self._get_multi_param('AlarmNames.member') alarm_names = self._get_multi_param("AlarmNames.member")
state_value = self._get_param('StateValue') state_value = self._get_param("StateValue")
if action_prefix: if action_prefix:
alarms = self.cloudwatch_backend.get_alarms_by_action_prefix( alarms = self.cloudwatch_backend.get_alarms_by_action_prefix(action_prefix)
action_prefix)
elif alarm_name_prefix: elif alarm_name_prefix:
alarms = self.cloudwatch_backend.get_alarms_by_alarm_name_prefix( alarms = self.cloudwatch_backend.get_alarms_by_alarm_name_prefix(
alarm_name_prefix) alarm_name_prefix
)
elif alarm_names: elif alarm_names:
alarms = self.cloudwatch_backend.get_alarms_by_alarm_names(alarm_names) alarms = self.cloudwatch_backend.get_alarms_by_alarm_names(alarm_names)
elif state_value: elif state_value:
@ -68,15 +76,15 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def delete_alarms(self): def delete_alarms(self):
alarm_names = self._get_multi_param('AlarmNames.member') alarm_names = self._get_multi_param("AlarmNames.member")
self.cloudwatch_backend.delete_alarms(alarm_names) self.cloudwatch_backend.delete_alarms(alarm_names)
template = self.response_template(DELETE_METRIC_ALARMS_TEMPLATE) template = self.response_template(DELETE_METRIC_ALARMS_TEMPLATE)
return template.render() return template.render()
@amzn_request_id @amzn_request_id
def put_metric_data(self): def put_metric_data(self):
namespace = self._get_param('Namespace') namespace = self._get_param("Namespace")
metric_data = self._get_multi_param('MetricData.member') metric_data = self._get_multi_param("MetricData.member")
self.cloudwatch_backend.put_metric_data(namespace, metric_data) self.cloudwatch_backend.put_metric_data(namespace, metric_data)
template = self.response_template(PUT_METRIC_DATA_TEMPLATE) template = self.response_template(PUT_METRIC_DATA_TEMPLATE)
@ -84,25 +92,29 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def get_metric_statistics(self): def get_metric_statistics(self):
namespace = self._get_param('Namespace') namespace = self._get_param("Namespace")
metric_name = self._get_param('MetricName') metric_name = self._get_param("MetricName")
start_time = dtparse(self._get_param('StartTime')) start_time = dtparse(self._get_param("StartTime"))
end_time = dtparse(self._get_param('EndTime')) end_time = dtparse(self._get_param("EndTime"))
period = int(self._get_param('Period')) period = int(self._get_param("Period"))
statistics = self._get_multi_param("Statistics.member") statistics = self._get_multi_param("Statistics.member")
# Unsupported Parameters (To Be Implemented) # Unsupported Parameters (To Be Implemented)
unit = self._get_param('Unit') unit = self._get_param("Unit")
extended_statistics = self._get_param('ExtendedStatistics') extended_statistics = self._get_param("ExtendedStatistics")
dimensions = self._get_param('Dimensions') dimensions = self._get_param("Dimensions")
if unit or extended_statistics or dimensions: if unit or extended_statistics or dimensions:
raise NotImplementedError() raise NotImplementedError()
# TODO: this should instead throw InvalidParameterCombination # TODO: this should instead throw InvalidParameterCombination
if not statistics: if not statistics:
raise NotImplementedError("Must specify either Statistics or ExtendedStatistics") raise NotImplementedError(
"Must specify either Statistics or ExtendedStatistics"
)
datapoints = self.cloudwatch_backend.get_metric_statistics(namespace, metric_name, start_time, end_time, period, statistics) datapoints = self.cloudwatch_backend.get_metric_statistics(
namespace, metric_name, start_time, end_time, period, statistics
)
template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE) template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE)
return template.render(label=metric_name, datapoints=datapoints) return template.render(label=metric_name, datapoints=datapoints)
@ -114,13 +126,13 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def delete_dashboards(self): def delete_dashboards(self):
dashboards = self._get_multi_param('DashboardNames.member') dashboards = self._get_multi_param("DashboardNames.member")
if dashboards is None: if dashboards is None:
return self._error('InvalidParameterValue', 'Need at least 1 dashboard') return self._error("InvalidParameterValue", "Need at least 1 dashboard")
status, error = self.cloudwatch_backend.delete_dashboards(dashboards) status, error = self.cloudwatch_backend.delete_dashboards(dashboards)
if not status: if not status:
return self._error('ResourceNotFound', error) return self._error("ResourceNotFound", error)
template = self.response_template(DELETE_DASHBOARD_TEMPLATE) template = self.response_template(DELETE_DASHBOARD_TEMPLATE)
return template.render() return template.render()
@ -143,18 +155,18 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def get_dashboard(self): def get_dashboard(self):
dashboard_name = self._get_param('DashboardName') dashboard_name = self._get_param("DashboardName")
dashboard = self.cloudwatch_backend.get_dashboard(dashboard_name) dashboard = self.cloudwatch_backend.get_dashboard(dashboard_name)
if dashboard is None: if dashboard is None:
return self._error('ResourceNotFound', 'Dashboard does not exist') return self._error("ResourceNotFound", "Dashboard does not exist")
template = self.response_template(GET_DASHBOARD_TEMPLATE) template = self.response_template(GET_DASHBOARD_TEMPLATE)
return template.render(dashboard=dashboard) return template.render(dashboard=dashboard)
@amzn_request_id @amzn_request_id
def list_dashboards(self): def list_dashboards(self):
prefix = self._get_param('DashboardNamePrefix', '') prefix = self._get_param("DashboardNamePrefix", "")
dashboards = self.cloudwatch_backend.list_dashboards(prefix) dashboards = self.cloudwatch_backend.list_dashboards(prefix)
@ -163,13 +175,13 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def put_dashboard(self): def put_dashboard(self):
name = self._get_param('DashboardName') name = self._get_param("DashboardName")
body = self._get_param('DashboardBody') body = self._get_param("DashboardBody")
try: try:
json.loads(body) json.loads(body)
except ValueError: except ValueError:
return self._error('InvalidParameterInput', 'Body is invalid JSON') return self._error("InvalidParameterInput", "Body is invalid JSON")
self.cloudwatch_backend.put_dashboard(name, body) self.cloudwatch_backend.put_dashboard(name, body)
@ -178,12 +190,14 @@ class CloudWatchResponse(BaseResponse):
@amzn_request_id @amzn_request_id
def set_alarm_state(self): def set_alarm_state(self):
alarm_name = self._get_param('AlarmName') alarm_name = self._get_param("AlarmName")
reason = self._get_param('StateReason') reason = self._get_param("StateReason")
reason_data = self._get_param('StateReasonData') reason_data = self._get_param("StateReasonData")
state_value = self._get_param('StateValue') state_value = self._get_param("StateValue")
self.cloudwatch_backend.set_alarm_state(alarm_name, reason, reason_data, state_value) self.cloudwatch_backend.set_alarm_state(
alarm_name, reason, reason_data, state_value
)
template = self.response_template(SET_ALARM_STATE_TEMPLATE) template = self.response_template(SET_ALARM_STATE_TEMPLATE)
return template.render() return template.render()

View File

@ -1,9 +1,5 @@
from .responses import CloudWatchResponse from .responses import CloudWatchResponse
url_bases = [ url_bases = ["https?://monitoring.(.+).amazonaws.com"]
"https?://monitoring.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": CloudWatchResponse.dispatch}
'{0}/$': CloudWatchResponse.dispatch,
}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import cognitoidentity_backends from .models import cognitoidentity_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
cognitoidentity_backend = cognitoidentity_backends['us-east-1'] cognitoidentity_backend = cognitoidentity_backends["us-east-1"]
mock_cognitoidentity = base_decorator(cognitoidentity_backends) mock_cognitoidentity = base_decorator(cognitoidentity_backends)
mock_cognitoidentity_deprecated = deprecated_base_decorator(cognitoidentity_backends) mock_cognitoidentity_deprecated = deprecated_base_decorator(cognitoidentity_backends)

View File

@ -6,10 +6,8 @@ from werkzeug.exceptions import BadRequest
class ResourceNotFoundError(BadRequest): class ResourceNotFoundError(BadRequest):
def __init__(self, message): def __init__(self, message):
super(ResourceNotFoundError, self).__init__() super(ResourceNotFoundError, self).__init__()
self.description = json.dumps({ self.description = json.dumps(
"message": message, {"message": message, "__type": "ResourceNotFoundException"}
'__type': 'ResourceNotFoundException', )
})

View File

@ -13,22 +13,24 @@ from .utils import get_random_identity_id
class CognitoIdentity(BaseModel): class CognitoIdentity(BaseModel):
def __init__(self, region, identity_pool_name, **kwargs): def __init__(self, region, identity_pool_name, **kwargs):
self.identity_pool_name = identity_pool_name self.identity_pool_name = identity_pool_name
self.allow_unauthenticated_identities = kwargs.get('allow_unauthenticated_identities', '') self.allow_unauthenticated_identities = kwargs.get(
self.supported_login_providers = kwargs.get('supported_login_providers', {}) "allow_unauthenticated_identities", ""
self.developer_provider_name = kwargs.get('developer_provider_name', '') )
self.open_id_connect_provider_arns = kwargs.get('open_id_connect_provider_arns', []) self.supported_login_providers = kwargs.get("supported_login_providers", {})
self.cognito_identity_providers = kwargs.get('cognito_identity_providers', []) self.developer_provider_name = kwargs.get("developer_provider_name", "")
self.saml_provider_arns = kwargs.get('saml_provider_arns', []) self.open_id_connect_provider_arns = kwargs.get(
"open_id_connect_provider_arns", []
)
self.cognito_identity_providers = kwargs.get("cognito_identity_providers", [])
self.saml_provider_arns = kwargs.get("saml_provider_arns", [])
self.identity_pool_id = get_random_identity_id(region) self.identity_pool_id = get_random_identity_id(region)
self.creation_time = datetime.datetime.utcnow() self.creation_time = datetime.datetime.utcnow()
class CognitoIdentityBackend(BaseBackend): class CognitoIdentityBackend(BaseBackend):
def __init__(self, region): def __init__(self, region):
super(CognitoIdentityBackend, self).__init__() super(CognitoIdentityBackend, self).__init__()
self.region = region self.region = region
@ -45,47 +47,61 @@ class CognitoIdentityBackend(BaseBackend):
if not identity_pool: if not identity_pool:
raise ResourceNotFoundError(identity_pool) raise ResourceNotFoundError(identity_pool)
response = json.dumps({ response = json.dumps(
'AllowUnauthenticatedIdentities': identity_pool.allow_unauthenticated_identities, {
'CognitoIdentityProviders': identity_pool.cognito_identity_providers, "AllowUnauthenticatedIdentities": identity_pool.allow_unauthenticated_identities,
'DeveloperProviderName': identity_pool.developer_provider_name, "CognitoIdentityProviders": identity_pool.cognito_identity_providers,
'IdentityPoolId': identity_pool.identity_pool_id, "DeveloperProviderName": identity_pool.developer_provider_name,
'IdentityPoolName': identity_pool.identity_pool_name, "IdentityPoolId": identity_pool.identity_pool_id,
'IdentityPoolTags': {}, "IdentityPoolName": identity_pool.identity_pool_name,
'OpenIdConnectProviderARNs': identity_pool.open_id_connect_provider_arns, "IdentityPoolTags": {},
'SamlProviderARNs': identity_pool.saml_provider_arns, "OpenIdConnectProviderARNs": identity_pool.open_id_connect_provider_arns,
'SupportedLoginProviders': identity_pool.supported_login_providers "SamlProviderARNs": identity_pool.saml_provider_arns,
}) "SupportedLoginProviders": identity_pool.supported_login_providers,
}
)
return response return response
def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities, def create_identity_pool(
supported_login_providers, developer_provider_name, open_id_connect_provider_arns, self,
cognito_identity_providers, saml_provider_arns): identity_pool_name,
new_identity = CognitoIdentity(self.region, identity_pool_name, allow_unauthenticated_identities,
supported_login_providers,
developer_provider_name,
open_id_connect_provider_arns,
cognito_identity_providers,
saml_provider_arns,
):
new_identity = CognitoIdentity(
self.region,
identity_pool_name,
allow_unauthenticated_identities=allow_unauthenticated_identities, allow_unauthenticated_identities=allow_unauthenticated_identities,
supported_login_providers=supported_login_providers, supported_login_providers=supported_login_providers,
developer_provider_name=developer_provider_name, developer_provider_name=developer_provider_name,
open_id_connect_provider_arns=open_id_connect_provider_arns, open_id_connect_provider_arns=open_id_connect_provider_arns,
cognito_identity_providers=cognito_identity_providers, cognito_identity_providers=cognito_identity_providers,
saml_provider_arns=saml_provider_arns) saml_provider_arns=saml_provider_arns,
)
self.identity_pools[new_identity.identity_pool_id] = new_identity self.identity_pools[new_identity.identity_pool_id] = new_identity
response = json.dumps({ response = json.dumps(
'IdentityPoolId': new_identity.identity_pool_id, {
'IdentityPoolName': new_identity.identity_pool_name, "IdentityPoolId": new_identity.identity_pool_id,
'AllowUnauthenticatedIdentities': new_identity.allow_unauthenticated_identities, "IdentityPoolName": new_identity.identity_pool_name,
'SupportedLoginProviders': new_identity.supported_login_providers, "AllowUnauthenticatedIdentities": new_identity.allow_unauthenticated_identities,
'DeveloperProviderName': new_identity.developer_provider_name, "SupportedLoginProviders": new_identity.supported_login_providers,
'OpenIdConnectProviderARNs': new_identity.open_id_connect_provider_arns, "DeveloperProviderName": new_identity.developer_provider_name,
'CognitoIdentityProviders': new_identity.cognito_identity_providers, "OpenIdConnectProviderARNs": new_identity.open_id_connect_provider_arns,
'SamlProviderARNs': new_identity.saml_provider_arns "CognitoIdentityProviders": new_identity.cognito_identity_providers,
}) "SamlProviderARNs": new_identity.saml_provider_arns,
}
)
return response return response
def get_id(self): def get_id(self):
identity_id = {'IdentityId': get_random_identity_id(self.region)} identity_id = {"IdentityId": get_random_identity_id(self.region)}
return json.dumps(identity_id) return json.dumps(identity_id)
def get_credentials_for_identity(self, identity_id): def get_credentials_for_identity(self, identity_id):
@ -95,31 +111,26 @@ class CognitoIdentityBackend(BaseBackend):
expiration_str = str(iso_8601_datetime_with_milliseconds(expiration)) expiration_str = str(iso_8601_datetime_with_milliseconds(expiration))
response = json.dumps( response = json.dumps(
{ {
"Credentials": "Credentials": {
{
"AccessKeyId": "TESTACCESSKEY12345", "AccessKeyId": "TESTACCESSKEY12345",
"Expiration": expiration_str, "Expiration": expiration_str,
"SecretKey": "ABCSECRETKEY", "SecretKey": "ABCSECRETKEY",
"SessionToken": "ABC12345" "SessionToken": "ABC12345",
}, },
"IdentityId": identity_id "IdentityId": identity_id,
}) }
)
return response return response
def get_open_id_token_for_developer_identity(self, identity_id): def get_open_id_token_for_developer_identity(self, identity_id):
response = json.dumps( response = json.dumps(
{ {"IdentityId": identity_id, "Token": get_random_identity_id(self.region)}
"IdentityId": identity_id, )
"Token": get_random_identity_id(self.region)
})
return response return response
def get_open_id_token(self, identity_id): def get_open_id_token(self, identity_id):
response = json.dumps( response = json.dumps(
{ {"IdentityId": identity_id, "Token": get_random_identity_id(self.region)}
"IdentityId": identity_id,
"Token": get_random_identity_id(self.region)
}
) )
return response return response

View File

@ -6,15 +6,16 @@ from .utils import get_random_identity_id
class CognitoIdentityResponse(BaseResponse): class CognitoIdentityResponse(BaseResponse):
def create_identity_pool(self): def create_identity_pool(self):
identity_pool_name = self._get_param('IdentityPoolName') identity_pool_name = self._get_param("IdentityPoolName")
allow_unauthenticated_identities = self._get_param('AllowUnauthenticatedIdentities') allow_unauthenticated_identities = self._get_param(
supported_login_providers = self._get_param('SupportedLoginProviders') "AllowUnauthenticatedIdentities"
developer_provider_name = self._get_param('DeveloperProviderName') )
open_id_connect_provider_arns = self._get_param('OpenIdConnectProviderARNs') supported_login_providers = self._get_param("SupportedLoginProviders")
cognito_identity_providers = self._get_param('CognitoIdentityProviders') developer_provider_name = self._get_param("DeveloperProviderName")
saml_provider_arns = self._get_param('SamlProviderARNs') open_id_connect_provider_arns = self._get_param("OpenIdConnectProviderARNs")
cognito_identity_providers = self._get_param("CognitoIdentityProviders")
saml_provider_arns = self._get_param("SamlProviderARNs")
return cognitoidentity_backends[self.region].create_identity_pool( return cognitoidentity_backends[self.region].create_identity_pool(
identity_pool_name=identity_pool_name, identity_pool_name=identity_pool_name,
@ -23,20 +24,27 @@ class CognitoIdentityResponse(BaseResponse):
developer_provider_name=developer_provider_name, developer_provider_name=developer_provider_name,
open_id_connect_provider_arns=open_id_connect_provider_arns, open_id_connect_provider_arns=open_id_connect_provider_arns,
cognito_identity_providers=cognito_identity_providers, cognito_identity_providers=cognito_identity_providers,
saml_provider_arns=saml_provider_arns) saml_provider_arns=saml_provider_arns,
)
def get_id(self): def get_id(self):
return cognitoidentity_backends[self.region].get_id() return cognitoidentity_backends[self.region].get_id()
def describe_identity_pool(self): def describe_identity_pool(self):
return cognitoidentity_backends[self.region].describe_identity_pool(self._get_param('IdentityPoolId')) return cognitoidentity_backends[self.region].describe_identity_pool(
self._get_param("IdentityPoolId")
)
def get_credentials_for_identity(self): def get_credentials_for_identity(self):
return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId')) return cognitoidentity_backends[self.region].get_credentials_for_identity(
self._get_param("IdentityId")
)
def get_open_id_token_for_developer_identity(self): def get_open_id_token_for_developer_identity(self):
return cognitoidentity_backends[self.region].get_open_id_token_for_developer_identity( return cognitoidentity_backends[
self._get_param('IdentityId') or get_random_identity_id(self.region) self.region
].get_open_id_token_for_developer_identity(
self._get_param("IdentityId") or get_random_identity_id(self.region)
) )
def get_open_id_token(self): def get_open_id_token(self):

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import CognitoIdentityResponse from .responses import CognitoIdentityResponse
url_bases = [ url_bases = ["https?://cognito-identity.(.+).amazonaws.com"]
"https?://cognito-identity.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": CognitoIdentityResponse.dispatch}
'{0}/$': CognitoIdentityResponse.dispatch,
}

View File

@ -5,50 +5,40 @@ from werkzeug.exceptions import BadRequest
class ResourceNotFoundError(BadRequest): class ResourceNotFoundError(BadRequest):
def __init__(self, message): def __init__(self, message):
super(ResourceNotFoundError, self).__init__() super(ResourceNotFoundError, self).__init__()
self.description = json.dumps({ self.description = json.dumps(
"message": message, {"message": message, "__type": "ResourceNotFoundException"}
'__type': 'ResourceNotFoundException', )
})
class UserNotFoundError(BadRequest): class UserNotFoundError(BadRequest):
def __init__(self, message): def __init__(self, message):
super(UserNotFoundError, self).__init__() super(UserNotFoundError, self).__init__()
self.description = json.dumps({ self.description = json.dumps(
"message": message, {"message": message, "__type": "UserNotFoundException"}
'__type': 'UserNotFoundException', )
})
class UsernameExistsException(BadRequest): class UsernameExistsException(BadRequest):
def __init__(self, message): def __init__(self, message):
super(UsernameExistsException, self).__init__() super(UsernameExistsException, self).__init__()
self.description = json.dumps({ self.description = json.dumps(
"message": message, {"message": message, "__type": "UsernameExistsException"}
'__type': 'UsernameExistsException', )
})
class GroupExistsException(BadRequest): class GroupExistsException(BadRequest):
def __init__(self, message): def __init__(self, message):
super(GroupExistsException, self).__init__() super(GroupExistsException, self).__init__()
self.description = json.dumps({ self.description = json.dumps(
"message": message, {"message": message, "__type": "GroupExistsException"}
'__type': 'GroupExistsException', )
})
class NotAuthorizedError(BadRequest): class NotAuthorizedError(BadRequest):
def __init__(self, message): def __init__(self, message):
super(NotAuthorizedError, self).__init__() super(NotAuthorizedError, self).__init__()
self.description = json.dumps({ self.description = json.dumps(
"message": message, {"message": message, "__type": "NotAuthorizedException"}
'__type': 'NotAuthorizedException', )
})

View File

@ -14,8 +14,13 @@ from jose import jws
from moto.compat import OrderedDict from moto.compat import OrderedDict
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from .exceptions import GroupExistsException, NotAuthorizedError, ResourceNotFoundError, UserNotFoundError, \ from .exceptions import (
UsernameExistsException GroupExistsException,
NotAuthorizedError,
ResourceNotFoundError,
UserNotFoundError,
UsernameExistsException,
)
UserStatus = { UserStatus = {
"FORCE_CHANGE_PASSWORD": "FORCE_CHANGE_PASSWORD", "FORCE_CHANGE_PASSWORD": "FORCE_CHANGE_PASSWORD",
@ -45,19 +50,22 @@ def paginate(limit, start_arg="next_token", limit_arg="max_results"):
def outer_wrapper(func): def outer_wrapper(func):
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
start = int(default_start if kwargs.get(start_arg) is None else kwargs[start_arg]) start = int(
default_start if kwargs.get(start_arg) is None else kwargs[start_arg]
)
lim = int(limit if kwargs.get(limit_arg) is None else kwargs[limit_arg]) lim = int(limit if kwargs.get(limit_arg) is None else kwargs[limit_arg])
stop = start + lim stop = start + lim
result = func(*args, **kwargs) result = func(*args, **kwargs)
limited_results = list(itertools.islice(result, start, stop)) limited_results = list(itertools.islice(result, start, stop))
next_token = stop if stop < len(result) else None next_token = stop if stop < len(result) else None
return limited_results, next_token return limited_results, next_token
return wrapper return wrapper
return outer_wrapper return outer_wrapper
class CognitoIdpUserPool(BaseModel): class CognitoIdpUserPool(BaseModel):
def __init__(self, region, name, extended_config): def __init__(self, region, name, extended_config):
self.region = region self.region = region
self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex)) self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex))
@ -75,7 +83,9 @@ class CognitoIdpUserPool(BaseModel):
self.access_tokens = {} self.access_tokens = {}
self.id_tokens = {} self.id_tokens = {}
with open(os.path.join(os.path.dirname(__file__), "resources/jwks-private.json")) as f: with open(
os.path.join(os.path.dirname(__file__), "resources/jwks-private.json")
) as f:
self.json_web_key = json.loads(f.read()) self.json_web_key = json.loads(f.read())
def _base_json(self): def _base_json(self):
@ -92,14 +102,18 @@ class CognitoIdpUserPool(BaseModel):
if extended: if extended:
user_pool_json.update(self.extended_config) user_pool_json.update(self.extended_config)
else: else:
user_pool_json["LambdaConfig"] = self.extended_config.get("LambdaConfig") or {} user_pool_json["LambdaConfig"] = (
self.extended_config.get("LambdaConfig") or {}
)
return user_pool_json return user_pool_json
def create_jwt(self, client_id, username, expires_in=60 * 60, extra_data={}): def create_jwt(self, client_id, username, expires_in=60 * 60, extra_data={}):
now = int(time.time()) now = int(time.time())
payload = { payload = {
"iss": "https://cognito-idp.{}.amazonaws.com/{}".format(self.region, self.id), "iss": "https://cognito-idp.{}.amazonaws.com/{}".format(
self.region, self.id
),
"sub": self.users[username].id, "sub": self.users[username].id,
"aud": client_id, "aud": client_id,
"token_use": "id", "token_use": "id",
@ -108,7 +122,7 @@ class CognitoIdpUserPool(BaseModel):
} }
payload.update(extra_data) payload.update(extra_data)
return jws.sign(payload, self.json_web_key, algorithm='RS256'), expires_in return jws.sign(payload, self.json_web_key, algorithm="RS256"), expires_in
def create_id_token(self, client_id, username): def create_id_token(self, client_id, username):
id_token, expires_in = self.create_jwt(client_id, username) id_token, expires_in = self.create_jwt(client_id, username)
@ -121,11 +135,10 @@ class CognitoIdpUserPool(BaseModel):
return refresh_token return refresh_token
def create_access_token(self, client_id, username): def create_access_token(self, client_id, username):
extra_data = self.get_user_extra_data_by_client_id( extra_data = self.get_user_extra_data_by_client_id(client_id, username)
client_id, username access_token, expires_in = self.create_jwt(
client_id, username, extra_data=extra_data
) )
access_token, expires_in = self.create_jwt(client_id, username,
extra_data=extra_data)
self.access_tokens[access_token] = (client_id, username) self.access_tokens[access_token] = (client_id, username)
return access_token, expires_in return access_token, expires_in
@ -143,29 +156,27 @@ class CognitoIdpUserPool(BaseModel):
current_client = self.clients.get(client_id, None) current_client = self.clients.get(client_id, None)
if current_client: if current_client:
for readable_field in current_client.get_readable_fields(): for readable_field in current_client.get_readable_fields():
attribute = list(filter( attribute = list(
lambda f: f['Name'] == readable_field, filter(
self.users.get(username).attributes lambda f: f["Name"] == readable_field,
)) self.users.get(username).attributes,
)
)
if len(attribute) > 0: if len(attribute) > 0:
extra_data.update({ extra_data.update({attribute[0]["Name"]: attribute[0]["Value"]})
attribute[0]['Name']: attribute[0]['Value']
})
return extra_data return extra_data
class CognitoIdpUserPoolDomain(BaseModel): class CognitoIdpUserPoolDomain(BaseModel):
def __init__(self, user_pool_id, domain, custom_domain_config=None): def __init__(self, user_pool_id, domain, custom_domain_config=None):
self.user_pool_id = user_pool_id self.user_pool_id = user_pool_id
self.domain = domain self.domain = domain
self.custom_domain_config = custom_domain_config or {} self.custom_domain_config = custom_domain_config or {}
def _distribution_name(self): def _distribution_name(self):
if self.custom_domain_config and \ if self.custom_domain_config and "CertificateArn" in self.custom_domain_config:
'CertificateArn' in self.custom_domain_config:
hash = hashlib.md5( hash = hashlib.md5(
self.custom_domain_config['CertificateArn'].encode('utf-8') self.custom_domain_config["CertificateArn"].encode("utf-8")
).hexdigest() ).hexdigest()
return "{hash}.cloudfront.net".format(hash=hash[:16]) return "{hash}.cloudfront.net".format(hash=hash[:16])
return None return None
@ -183,14 +194,11 @@ class CognitoIdpUserPoolDomain(BaseModel):
"Version": None, "Version": None,
} }
elif distribution: elif distribution:
return { return {"CloudFrontDomain": distribution}
"CloudFrontDomain": distribution,
}
return None return None
class CognitoIdpUserPoolClient(BaseModel): class CognitoIdpUserPoolClient(BaseModel):
def __init__(self, user_pool_id, extended_config): def __init__(self, user_pool_id, extended_config):
self.user_pool_id = user_pool_id self.user_pool_id = user_pool_id
self.id = str(uuid.uuid4()) self.id = str(uuid.uuid4())
@ -212,11 +220,10 @@ class CognitoIdpUserPoolClient(BaseModel):
return user_pool_client_json return user_pool_client_json
def get_readable_fields(self): def get_readable_fields(self):
return self.extended_config.get('ReadAttributes', []) return self.extended_config.get("ReadAttributes", [])
class CognitoIdpIdentityProvider(BaseModel): class CognitoIdpIdentityProvider(BaseModel):
def __init__(self, name, extended_config): def __init__(self, name, extended_config):
self.name = name self.name = name
self.extended_config = extended_config or {} self.extended_config = extended_config or {}
@ -240,7 +247,6 @@ class CognitoIdpIdentityProvider(BaseModel):
class CognitoIdpGroup(BaseModel): class CognitoIdpGroup(BaseModel):
def __init__(self, user_pool_id, group_name, description, role_arn, precedence): def __init__(self, user_pool_id, group_name, description, role_arn, precedence):
self.user_pool_id = user_pool_id self.user_pool_id = user_pool_id
self.group_name = group_name self.group_name = group_name
@ -267,7 +273,6 @@ class CognitoIdpGroup(BaseModel):
class CognitoIdpUser(BaseModel): class CognitoIdpUser(BaseModel):
def __init__(self, user_pool_id, username, password, status, attributes): def __init__(self, user_pool_id, username, password, status, attributes):
self.id = str(uuid.uuid4()) self.id = str(uuid.uuid4())
self.user_pool_id = user_pool_id self.user_pool_id = user_pool_id
@ -300,19 +305,18 @@ class CognitoIdpUser(BaseModel):
{ {
"Enabled": self.enabled, "Enabled": self.enabled,
attributes_key: self.attributes, attributes_key: self.attributes,
"MFAOptions": [] "MFAOptions": [],
} }
) )
return user_json return user_json
def update_attributes(self, new_attributes): def update_attributes(self, new_attributes):
def flatten_attrs(attrs): def flatten_attrs(attrs):
return {attr['Name']: attr['Value'] for attr in attrs} return {attr["Name"]: attr["Value"] for attr in attrs}
def expand_attrs(attrs): def expand_attrs(attrs):
return [{'Name': k, 'Value': v} for k, v in attrs.items()] return [{"Name": k, "Value": v} for k, v in attrs.items()]
flat_attributes = flatten_attrs(self.attributes) flat_attributes = flatten_attrs(self.attributes)
flat_attributes.update(flatten_attrs(new_attributes)) flat_attributes.update(flatten_attrs(new_attributes))
@ -320,7 +324,6 @@ class CognitoIdpUser(BaseModel):
class CognitoIdpBackend(BaseBackend): class CognitoIdpBackend(BaseBackend):
def __init__(self, region): def __init__(self, region):
super(CognitoIdpBackend, self).__init__() super(CognitoIdpBackend, self).__init__()
self.region = region self.region = region
@ -496,7 +499,9 @@ class CognitoIdpBackend(BaseBackend):
if not user_pool: if not user_pool:
raise ResourceNotFoundError(user_pool_id) raise ResourceNotFoundError(user_pool_id)
group = CognitoIdpGroup(user_pool_id, group_name, description, role_arn, precedence) group = CognitoIdpGroup(
user_pool_id, group_name, description, role_arn, precedence
)
if group.group_name in user_pool.groups: if group.group_name in user_pool.groups:
raise GroupExistsException("A group with the name already exists") raise GroupExistsException("A group with the name already exists")
user_pool.groups[group.group_name] = group user_pool.groups[group.group_name] = group
@ -565,7 +570,13 @@ class CognitoIdpBackend(BaseBackend):
if username in user_pool.users: if username in user_pool.users:
raise UsernameExistsException(username) raise UsernameExistsException(username)
user = CognitoIdpUser(user_pool_id, username, temporary_password, UserStatus["FORCE_CHANGE_PASSWORD"], attributes) user = CognitoIdpUser(
user_pool_id,
username,
temporary_password,
UserStatus["FORCE_CHANGE_PASSWORD"],
attributes,
)
user_pool.users[user.username] = user user_pool.users[user.username] = user
return user return user
@ -611,7 +622,9 @@ class CognitoIdpBackend(BaseBackend):
def _log_user_in(self, user_pool, client, username): def _log_user_in(self, user_pool, client, username):
refresh_token = user_pool.create_refresh_token(client.id, username) refresh_token = user_pool.create_refresh_token(client.id, username)
access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token) access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(
refresh_token
)
return { return {
"AuthenticationResult": { "AuthenticationResult": {
@ -654,7 +667,11 @@ class CognitoIdpBackend(BaseBackend):
return self._log_user_in(user_pool, client, username) return self._log_user_in(user_pool, client, username)
elif auth_flow == "REFRESH_TOKEN": elif auth_flow == "REFRESH_TOKEN":
refresh_token = auth_parameters.get("REFRESH_TOKEN") refresh_token = auth_parameters.get("REFRESH_TOKEN")
id_token, access_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token) (
id_token,
access_token,
expires_in,
) = user_pool.create_tokens_from_refresh_token(refresh_token)
return { return {
"AuthenticationResult": { "AuthenticationResult": {
@ -666,7 +683,9 @@ class CognitoIdpBackend(BaseBackend):
else: else:
return {} return {}
def respond_to_auth_challenge(self, session, client_id, challenge_name, challenge_responses): def respond_to_auth_challenge(
self, session, client_id, challenge_name, challenge_responses
):
user_pool = self.sessions.get(session) user_pool = self.sessions.get(session)
if not user_pool: if not user_pool:
raise ResourceNotFoundError(session) raise ResourceNotFoundError(session)

View File

@ -8,7 +8,6 @@ from .models import cognitoidp_backends, find_region_by_value
class CognitoIdpResponse(BaseResponse): class CognitoIdpResponse(BaseResponse):
@property @property
def parameters(self): def parameters(self):
return json.loads(self.body) return json.loads(self.body)
@ -16,10 +15,10 @@ class CognitoIdpResponse(BaseResponse):
# User pool # User pool
def create_user_pool(self): def create_user_pool(self):
name = self.parameters.pop("PoolName") name = self.parameters.pop("PoolName")
user_pool = cognitoidp_backends[self.region].create_user_pool(name, self.parameters) user_pool = cognitoidp_backends[self.region].create_user_pool(
return json.dumps({ name, self.parameters
"UserPool": user_pool.to_json(extended=True) )
}) return json.dumps({"UserPool": user_pool.to_json(extended=True)})
def list_user_pools(self): def list_user_pools(self):
max_results = self._get_param("MaxResults") max_results = self._get_param("MaxResults")
@ -27,9 +26,7 @@ class CognitoIdpResponse(BaseResponse):
user_pools, next_token = cognitoidp_backends[self.region].list_user_pools( user_pools, next_token = cognitoidp_backends[self.region].list_user_pools(
max_results=max_results, next_token=next_token max_results=max_results, next_token=next_token
) )
response = { response = {"UserPools": [user_pool.to_json() for user_pool in user_pools]}
"UserPools": [user_pool.to_json() for user_pool in user_pools],
}
if next_token: if next_token:
response["NextToken"] = str(next_token) response["NextToken"] = str(next_token)
return json.dumps(response) return json.dumps(response)
@ -37,9 +34,7 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool(self): def describe_user_pool(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
user_pool = cognitoidp_backends[self.region].describe_user_pool(user_pool_id) user_pool = cognitoidp_backends[self.region].describe_user_pool(user_pool_id)
return json.dumps({ return json.dumps({"UserPool": user_pool.to_json(extended=True)})
"UserPool": user_pool.to_json(extended=True)
})
def delete_user_pool(self): def delete_user_pool(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
@ -61,14 +56,14 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool_domain(self): def describe_user_pool_domain(self):
domain = self._get_param("Domain") domain = self._get_param("Domain")
user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain(domain) user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain(
domain
)
domain_description = {} domain_description = {}
if user_pool_domain: if user_pool_domain:
domain_description = user_pool_domain.to_json() domain_description = user_pool_domain.to_json()
return json.dumps({ return json.dumps({"DomainDescription": domain_description})
"DomainDescription": domain_description
})
def delete_user_pool_domain(self): def delete_user_pool_domain(self):
domain = self._get_param("Domain") domain = self._get_param("Domain")
@ -89,19 +84,24 @@ class CognitoIdpResponse(BaseResponse):
# User pool client # User pool client
def create_user_pool_client(self): def create_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId") user_pool_id = self.parameters.pop("UserPoolId")
user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(user_pool_id, self.parameters) user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(
return json.dumps({ user_pool_id, self.parameters
"UserPoolClient": user_pool_client.to_json(extended=True) )
}) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def list_user_pool_clients(self): def list_user_pool_clients(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
max_results = self._get_param("MaxResults") max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken", "0") next_token = self._get_param("NextToken", "0")
user_pool_clients, next_token = cognitoidp_backends[self.region].list_user_pool_clients(user_pool_id, user_pool_clients, next_token = cognitoidp_backends[
max_results=max_results, next_token=next_token) self.region
].list_user_pool_clients(
user_pool_id, max_results=max_results, next_token=next_token
)
response = { response = {
"UserPoolClients": [user_pool_client.to_json() for user_pool_client in user_pool_clients] "UserPoolClients": [
user_pool_client.to_json() for user_pool_client in user_pool_clients
]
} }
if next_token: if next_token:
response["NextToken"] = str(next_token) response["NextToken"] = str(next_token)
@ -110,43 +110,51 @@ class CognitoIdpResponse(BaseResponse):
def describe_user_pool_client(self): def describe_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId") client_id = self._get_param("ClientId")
user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client(user_pool_id, client_id) user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client(
return json.dumps({ user_pool_id, client_id
"UserPoolClient": user_pool_client.to_json(extended=True) )
}) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def update_user_pool_client(self): def update_user_pool_client(self):
user_pool_id = self.parameters.pop("UserPoolId") user_pool_id = self.parameters.pop("UserPoolId")
client_id = self.parameters.pop("ClientId") client_id = self.parameters.pop("ClientId")
user_pool_client = cognitoidp_backends[self.region].update_user_pool_client(user_pool_id, client_id, self.parameters) user_pool_client = cognitoidp_backends[self.region].update_user_pool_client(
return json.dumps({ user_pool_id, client_id, self.parameters
"UserPoolClient": user_pool_client.to_json(extended=True) )
}) return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)})
def delete_user_pool_client(self): def delete_user_pool_client(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
client_id = self._get_param("ClientId") client_id = self._get_param("ClientId")
cognitoidp_backends[self.region].delete_user_pool_client(user_pool_id, client_id) cognitoidp_backends[self.region].delete_user_pool_client(
user_pool_id, client_id
)
return "" return ""
# Identity provider # Identity provider
def create_identity_provider(self): def create_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
name = self.parameters.pop("ProviderName") name = self.parameters.pop("ProviderName")
identity_provider = cognitoidp_backends[self.region].create_identity_provider(user_pool_id, name, self.parameters) identity_provider = cognitoidp_backends[self.region].create_identity_provider(
return json.dumps({ user_pool_id, name, self.parameters
"IdentityProvider": identity_provider.to_json(extended=True) )
}) return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def list_identity_providers(self): def list_identity_providers(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
max_results = self._get_param("MaxResults") max_results = self._get_param("MaxResults")
next_token = self._get_param("NextToken", "0") next_token = self._get_param("NextToken", "0")
identity_providers, next_token = cognitoidp_backends[self.region].list_identity_providers( identity_providers, next_token = cognitoidp_backends[
self.region
].list_identity_providers(
user_pool_id, max_results=max_results, next_token=next_token user_pool_id, max_results=max_results, next_token=next_token
) )
response = { response = {
"Providers": [identity_provider.to_json() for identity_provider in identity_providers] "Providers": [
identity_provider.to_json() for identity_provider in identity_providers
]
} }
if next_token: if next_token:
response["NextToken"] = str(next_token) response["NextToken"] = str(next_token)
@ -155,18 +163,22 @@ class CognitoIdpResponse(BaseResponse):
def describe_identity_provider(self): def describe_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName") name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].describe_identity_provider(user_pool_id, name) identity_provider = cognitoidp_backends[self.region].describe_identity_provider(
return json.dumps({ user_pool_id, name
"IdentityProvider": identity_provider.to_json(extended=True) )
}) return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def update_identity_provider(self): def update_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
name = self._get_param("ProviderName") name = self._get_param("ProviderName")
identity_provider = cognitoidp_backends[self.region].update_identity_provider(user_pool_id, name, self.parameters) identity_provider = cognitoidp_backends[self.region].update_identity_provider(
return json.dumps({ user_pool_id, name, self.parameters
"IdentityProvider": identity_provider.to_json(extended=True) )
}) return json.dumps(
{"IdentityProvider": identity_provider.to_json(extended=True)}
)
def delete_identity_provider(self): def delete_identity_provider(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
@ -183,31 +195,21 @@ class CognitoIdpResponse(BaseResponse):
precedence = self._get_param("Precedence") precedence = self._get_param("Precedence")
group = cognitoidp_backends[self.region].create_group( group = cognitoidp_backends[self.region].create_group(
user_pool_id, user_pool_id, group_name, description, role_arn, precedence
group_name,
description,
role_arn,
precedence,
) )
return json.dumps({ return json.dumps({"Group": group.to_json()})
"Group": group.to_json(),
})
def get_group(self): def get_group(self):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
group = cognitoidp_backends[self.region].get_group(user_pool_id, group_name) group = cognitoidp_backends[self.region].get_group(user_pool_id, group_name)
return json.dumps({ return json.dumps({"Group": group.to_json()})
"Group": group.to_json(),
})
def list_groups(self): def list_groups(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].list_groups(user_pool_id) groups = cognitoidp_backends[self.region].list_groups(user_pool_id)
return json.dumps({ return json.dumps({"Groups": [group.to_json() for group in groups]})
"Groups": [group.to_json() for group in groups],
})
def delete_group(self): def delete_group(self):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
@ -221,9 +223,7 @@ class CognitoIdpResponse(BaseResponse):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_add_user_to_group( cognitoidp_backends[self.region].admin_add_user_to_group(
user_pool_id, user_pool_id, group_name, username
group_name,
username,
) )
return "" return ""
@ -231,18 +231,18 @@ class CognitoIdpResponse(BaseResponse):
def list_users_in_group(self): def list_users_in_group(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
users = cognitoidp_backends[self.region].list_users_in_group(user_pool_id, group_name) users = cognitoidp_backends[self.region].list_users_in_group(
return json.dumps({ user_pool_id, group_name
"Users": [user.to_json(extended=True) for user in users], )
}) return json.dumps({"Users": [user.to_json(extended=True) for user in users]})
def admin_list_groups_for_user(self): def admin_list_groups_for_user(self):
username = self._get_param("Username") username = self._get_param("Username")
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
groups = cognitoidp_backends[self.region].admin_list_groups_for_user(user_pool_id, username) groups = cognitoidp_backends[self.region].admin_list_groups_for_user(
return json.dumps({ user_pool_id, username
"Groups": [group.to_json() for group in groups], )
}) return json.dumps({"Groups": [group.to_json() for group in groups]})
def admin_remove_user_from_group(self): def admin_remove_user_from_group(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
@ -250,9 +250,7 @@ class CognitoIdpResponse(BaseResponse):
group_name = self._get_param("GroupName") group_name = self._get_param("GroupName")
cognitoidp_backends[self.region].admin_remove_user_from_group( cognitoidp_backends[self.region].admin_remove_user_from_group(
user_pool_id, user_pool_id, group_name, username
group_name,
username,
) )
return "" return ""
@ -266,28 +264,24 @@ class CognitoIdpResponse(BaseResponse):
user_pool_id, user_pool_id,
username, username,
temporary_password, temporary_password,
self._get_param("UserAttributes", []) self._get_param("UserAttributes", []),
) )
return json.dumps({ return json.dumps({"User": user.to_json(extended=True)})
"User": user.to_json(extended=True)
})
def admin_get_user(self): def admin_get_user(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
user = cognitoidp_backends[self.region].admin_get_user(user_pool_id, username) user = cognitoidp_backends[self.region].admin_get_user(user_pool_id, username)
return json.dumps( return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes"))
user.to_json(extended=True, attributes_key="UserAttributes")
)
def list_users(self): def list_users(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
limit = self._get_param("Limit") limit = self._get_param("Limit")
token = self._get_param("PaginationToken") token = self._get_param("PaginationToken")
users, token = cognitoidp_backends[self.region].list_users(user_pool_id, users, token = cognitoidp_backends[self.region].list_users(
limit=limit, user_pool_id, limit=limit, pagination_token=token
pagination_token=token) )
response = {"Users": [user.to_json(extended=True) for user in users]} response = {"Users": [user.to_json(extended=True) for user in users]}
if token: if token:
response["PaginationToken"] = str(token) response["PaginationToken"] = str(token)
@ -318,10 +312,7 @@ class CognitoIdpResponse(BaseResponse):
auth_parameters = self._get_param("AuthParameters") auth_parameters = self._get_param("AuthParameters")
auth_result = cognitoidp_backends[self.region].admin_initiate_auth( auth_result = cognitoidp_backends[self.region].admin_initiate_auth(
user_pool_id, user_pool_id, client_id, auth_flow, auth_parameters
client_id,
auth_flow,
auth_parameters,
) )
return json.dumps(auth_result) return json.dumps(auth_result)
@ -332,21 +323,15 @@ class CognitoIdpResponse(BaseResponse):
challenge_name = self._get_param("ChallengeName") challenge_name = self._get_param("ChallengeName")
challenge_responses = self._get_param("ChallengeResponses") challenge_responses = self._get_param("ChallengeResponses")
auth_result = cognitoidp_backends[self.region].respond_to_auth_challenge( auth_result = cognitoidp_backends[self.region].respond_to_auth_challenge(
session, session, client_id, challenge_name, challenge_responses
client_id,
challenge_name,
challenge_responses,
) )
return json.dumps(auth_result) return json.dumps(auth_result)
def forgot_password(self): def forgot_password(self):
return json.dumps({ return json.dumps(
"CodeDeliveryDetails": { {"CodeDeliveryDetails": {"DeliveryMedium": "EMAIL", "Destination": "..."}}
"DeliveryMedium": "EMAIL", )
"Destination": "...",
}
})
# This endpoint receives no authorization header, so if moto-server is listening # This endpoint receives no authorization header, so if moto-server is listening
# on localhost (doesn't get a region in the host header), it doesn't know what # on localhost (doesn't get a region in the host header), it doesn't know what
@ -357,7 +342,9 @@ class CognitoIdpResponse(BaseResponse):
username = self._get_param("Username") username = self._get_param("Username")
password = self._get_param("Password") password = self._get_param("Password")
region = find_region_by_value("client_id", client_id) region = find_region_by_value("client_id", client_id)
cognitoidp_backends[region].confirm_forgot_password(client_id, username, password) cognitoidp_backends[region].confirm_forgot_password(
client_id, username, password
)
return "" return ""
# Ditto the comment on confirm_forgot_password. # Ditto the comment on confirm_forgot_password.
@ -366,21 +353,26 @@ class CognitoIdpResponse(BaseResponse):
previous_password = self._get_param("PreviousPassword") previous_password = self._get_param("PreviousPassword")
proposed_password = self._get_param("ProposedPassword") proposed_password = self._get_param("ProposedPassword")
region = find_region_by_value("access_token", access_token) region = find_region_by_value("access_token", access_token)
cognitoidp_backends[region].change_password(access_token, previous_password, proposed_password) cognitoidp_backends[region].change_password(
access_token, previous_password, proposed_password
)
return "" return ""
def admin_update_user_attributes(self): def admin_update_user_attributes(self):
user_pool_id = self._get_param("UserPoolId") user_pool_id = self._get_param("UserPoolId")
username = self._get_param("Username") username = self._get_param("Username")
attributes = self._get_param("UserAttributes") attributes = self._get_param("UserAttributes")
cognitoidp_backends[self.region].admin_update_user_attributes(user_pool_id, username, attributes) cognitoidp_backends[self.region].admin_update_user_attributes(
user_pool_id, username, attributes
)
return "" return ""
class CognitoIdpJsonWebKeyResponse(BaseResponse): class CognitoIdpJsonWebKeyResponse(BaseResponse):
def __init__(self): def __init__(self):
with open(os.path.join(os.path.dirname(__file__), "resources/jwks-public.json")) as f: with open(
os.path.join(os.path.dirname(__file__), "resources/jwks-public.json")
) as f:
self.json_web_key = f.read() self.json_web_key = f.read()
def serve_json_web_key(self, request, full_url, headers): def serve_json_web_key(self, request, full_url, headers):

View File

@ -1,11 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import CognitoIdpResponse, CognitoIdpJsonWebKeyResponse from .responses import CognitoIdpResponse, CognitoIdpJsonWebKeyResponse
url_bases = [ url_bases = ["https?://cognito-idp.(.+).amazonaws.com"]
"https?://cognito-idp.(.+).amazonaws.com",
]
url_paths = { url_paths = {
'{0}/$': CognitoIdpResponse.dispatch, "{0}/$": CognitoIdpResponse.dispatch,
'{0}/<user_pool_id>/.well-known/jwks.json$': CognitoIdpJsonWebKeyResponse().serve_json_web_key, "{0}/<user_pool_id>/.well-known/jwks.json$": CognitoIdpJsonWebKeyResponse().serve_json_web_key,
} }

View File

@ -6,8 +6,12 @@ class NameTooLongException(JsonRESTError):
code = 400 code = 400
def __init__(self, name, location): def __init__(self, name, location):
message = '1 validation error detected: Value \'{name}\' at \'{location}\' failed to satisfy' \ message = (
' constraint: Member must have length less than or equal to 256'.format(name=name, location=location) "1 validation error detected: Value '{name}' at '{location}' failed to satisfy"
" constraint: Member must have length less than or equal to 256".format(
name=name, location=location
)
)
super(NameTooLongException, self).__init__("ValidationException", message) super(NameTooLongException, self).__init__("ValidationException", message)
@ -15,41 +19,54 @@ class InvalidConfigurationRecorderNameException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'The configuration recorder name \'{name}\' is not valid, blank string.'.format(name=name) message = "The configuration recorder name '{name}' is not valid, blank string.".format(
super(InvalidConfigurationRecorderNameException, self).__init__("InvalidConfigurationRecorderNameException", name=name
message) )
super(InvalidConfigurationRecorderNameException, self).__init__(
"InvalidConfigurationRecorderNameException", message
)
class MaxNumberOfConfigurationRecordersExceededException(JsonRESTError): class MaxNumberOfConfigurationRecordersExceededException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'Failed to put configuration recorder \'{name}\' because the maximum number of ' \ message = (
'configuration recorders: 1 is reached.'.format(name=name) "Failed to put configuration recorder '{name}' because the maximum number of "
"configuration recorders: 1 is reached.".format(name=name)
)
super(MaxNumberOfConfigurationRecordersExceededException, self).__init__( super(MaxNumberOfConfigurationRecordersExceededException, self).__init__(
"MaxNumberOfConfigurationRecordersExceededException", message) "MaxNumberOfConfigurationRecordersExceededException", message
)
class InvalidRecordingGroupException(JsonRESTError): class InvalidRecordingGroupException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'The recording group provided is not valid' message = "The recording group provided is not valid"
super(InvalidRecordingGroupException, self).__init__("InvalidRecordingGroupException", message) super(InvalidRecordingGroupException, self).__init__(
"InvalidRecordingGroupException", message
)
class InvalidResourceTypeException(JsonRESTError): class InvalidResourceTypeException(JsonRESTError):
code = 400 code = 400
def __init__(self, bad_list, good_list): def __init__(self, bad_list, good_list):
message = '{num} validation error detected: Value \'{bad_list}\' at ' \ message = (
'\'configurationRecorder.recordingGroup.resourceTypes\' failed to satisfy constraint: ' \ "{num} validation error detected: Value '{bad_list}' at "
'Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]'.format( "'configurationRecorder.recordingGroup.resourceTypes' failed to satisfy constraint: "
num=len(bad_list), bad_list=bad_list, good_list=good_list) "Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]".format(
num=len(bad_list), bad_list=bad_list, good_list=good_list
)
)
# For PY2: # For PY2:
message = str(message) message = str(message)
super(InvalidResourceTypeException, self).__init__("ValidationException", message) super(InvalidResourceTypeException, self).__init__(
"ValidationException", message
)
class NoSuchConfigurationAggregatorException(JsonRESTError): class NoSuchConfigurationAggregatorException(JsonRESTError):
@ -57,36 +74,48 @@ class NoSuchConfigurationAggregatorException(JsonRESTError):
def __init__(self, number=1): def __init__(self, number=1):
if number == 1: if number == 1:
message = 'The configuration aggregator does not exist. Check the configuration aggregator name and try again.' message = "The configuration aggregator does not exist. Check the configuration aggregator name and try again."
else: else:
message = 'At least one of the configuration aggregators does not exist. Check the configuration aggregator' \ message = (
' names and try again.' "At least one of the configuration aggregators does not exist. Check the configuration aggregator"
super(NoSuchConfigurationAggregatorException, self).__init__("NoSuchConfigurationAggregatorException", message) " names and try again."
)
super(NoSuchConfigurationAggregatorException, self).__init__(
"NoSuchConfigurationAggregatorException", message
)
class NoSuchConfigurationRecorderException(JsonRESTError): class NoSuchConfigurationRecorderException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'Cannot find configuration recorder with the specified name \'{name}\'.'.format(name=name) message = "Cannot find configuration recorder with the specified name '{name}'.".format(
super(NoSuchConfigurationRecorderException, self).__init__("NoSuchConfigurationRecorderException", message) name=name
)
super(NoSuchConfigurationRecorderException, self).__init__(
"NoSuchConfigurationRecorderException", message
)
class InvalidDeliveryChannelNameException(JsonRESTError): class InvalidDeliveryChannelNameException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'The delivery channel name \'{name}\' is not valid, blank string.'.format(name=name) message = "The delivery channel name '{name}' is not valid, blank string.".format(
super(InvalidDeliveryChannelNameException, self).__init__("InvalidDeliveryChannelNameException", name=name
message) )
super(InvalidDeliveryChannelNameException, self).__init__(
"InvalidDeliveryChannelNameException", message
)
class NoSuchBucketException(JsonRESTError): class NoSuchBucketException(JsonRESTError):
"""We are *only* validating that there is value that is not '' here.""" """We are *only* validating that there is value that is not '' here."""
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'Cannot find a S3 bucket with an empty bucket name.' message = "Cannot find a S3 bucket with an empty bucket name."
super(NoSuchBucketException, self).__init__("NoSuchBucketException", message) super(NoSuchBucketException, self).__init__("NoSuchBucketException", message)
@ -94,89 +123,120 @@ class InvalidNextTokenException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'The nextToken provided is invalid' message = "The nextToken provided is invalid"
super(InvalidNextTokenException, self).__init__("InvalidNextTokenException", message) super(InvalidNextTokenException, self).__init__(
"InvalidNextTokenException", message
)
class InvalidS3KeyPrefixException(JsonRESTError): class InvalidS3KeyPrefixException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'The s3 key prefix \'\' is not valid, empty s3 key prefix.' message = "The s3 key prefix '' is not valid, empty s3 key prefix."
super(InvalidS3KeyPrefixException, self).__init__("InvalidS3KeyPrefixException", message) super(InvalidS3KeyPrefixException, self).__init__(
"InvalidS3KeyPrefixException", message
)
class InvalidSNSTopicARNException(JsonRESTError): class InvalidSNSTopicARNException(JsonRESTError):
"""We are *only* validating that there is value that is not '' here.""" """We are *only* validating that there is value that is not '' here."""
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'The sns topic arn \'\' is not valid.' message = "The sns topic arn '' is not valid."
super(InvalidSNSTopicARNException, self).__init__("InvalidSNSTopicARNException", message) super(InvalidSNSTopicARNException, self).__init__(
"InvalidSNSTopicARNException", message
)
class InvalidDeliveryFrequency(JsonRESTError): class InvalidDeliveryFrequency(JsonRESTError):
code = 400 code = 400
def __init__(self, value, good_list): def __init__(self, value, good_list):
message = '1 validation error detected: Value \'{value}\' at ' \ message = (
'\'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency\' failed to satisfy ' \ "1 validation error detected: Value '{value}' at "
'constraint: Member must satisfy enum value set: {good_list}'.format(value=value, good_list=good_list) "'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency' failed to satisfy "
super(InvalidDeliveryFrequency, self).__init__("InvalidDeliveryFrequency", message) "constraint: Member must satisfy enum value set: {good_list}".format(
value=value, good_list=good_list
)
)
super(InvalidDeliveryFrequency, self).__init__(
"InvalidDeliveryFrequency", message
)
class MaxNumberOfDeliveryChannelsExceededException(JsonRESTError): class MaxNumberOfDeliveryChannelsExceededException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'Failed to put delivery channel \'{name}\' because the maximum number of ' \ message = (
'delivery channels: 1 is reached.'.format(name=name) "Failed to put delivery channel '{name}' because the maximum number of "
"delivery channels: 1 is reached.".format(name=name)
)
super(MaxNumberOfDeliveryChannelsExceededException, self).__init__( super(MaxNumberOfDeliveryChannelsExceededException, self).__init__(
"MaxNumberOfDeliveryChannelsExceededException", message) "MaxNumberOfDeliveryChannelsExceededException", message
)
class NoSuchDeliveryChannelException(JsonRESTError): class NoSuchDeliveryChannelException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'Cannot find delivery channel with specified name \'{name}\'.'.format(name=name) message = "Cannot find delivery channel with specified name '{name}'.".format(
super(NoSuchDeliveryChannelException, self).__init__("NoSuchDeliveryChannelException", message) name=name
)
super(NoSuchDeliveryChannelException, self).__init__(
"NoSuchDeliveryChannelException", message
)
class NoAvailableConfigurationRecorderException(JsonRESTError): class NoAvailableConfigurationRecorderException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'Configuration recorder is not available to put delivery channel.' message = "Configuration recorder is not available to put delivery channel."
super(NoAvailableConfigurationRecorderException, self).__init__("NoAvailableConfigurationRecorderException", super(NoAvailableConfigurationRecorderException, self).__init__(
message) "NoAvailableConfigurationRecorderException", message
)
class NoAvailableDeliveryChannelException(JsonRESTError): class NoAvailableDeliveryChannelException(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self):
message = 'Delivery channel is not available to start configuration recorder.' message = "Delivery channel is not available to start configuration recorder."
super(NoAvailableDeliveryChannelException, self).__init__("NoAvailableDeliveryChannelException", message) super(NoAvailableDeliveryChannelException, self).__init__(
"NoAvailableDeliveryChannelException", message
)
class LastDeliveryChannelDeleteFailedException(JsonRESTError): class LastDeliveryChannelDeleteFailedException(JsonRESTError):
code = 400 code = 400
def __init__(self, name): def __init__(self, name):
message = 'Failed to delete last specified delivery channel with name \'{name}\', because there, ' \ message = (
'because there is a running configuration recorder.'.format(name=name) "Failed to delete last specified delivery channel with name '{name}', because there, "
super(LastDeliveryChannelDeleteFailedException, self).__init__("LastDeliveryChannelDeleteFailedException", message) "because there is a running configuration recorder.".format(name=name)
)
super(LastDeliveryChannelDeleteFailedException, self).__init__(
"LastDeliveryChannelDeleteFailedException", message
)
class TooManyAccountSources(JsonRESTError): class TooManyAccountSources(JsonRESTError):
code = 400 code = 400
def __init__(self, length): def __init__(self, length):
locations = ['com.amazonaws.xyz'] * length locations = ["com.amazonaws.xyz"] * length
message = 'Value \'[{locations}]\' at \'accountAggregationSources\' failed to satisfy constraint: ' \ message = (
'Member must have length less than or equal to 1'.format(locations=', '.join(locations)) "Value '[{locations}]' at 'accountAggregationSources' failed to satisfy constraint: "
"Member must have length less than or equal to 1".format(
locations=", ".join(locations)
)
)
super(TooManyAccountSources, self).__init__("ValidationException", message) super(TooManyAccountSources, self).__init__("ValidationException", message)
@ -185,16 +245,22 @@ class DuplicateTags(JsonRESTError):
def __init__(self): def __init__(self):
super(DuplicateTags, self).__init__( super(DuplicateTags, self).__init__(
'InvalidInput', 'Duplicate tag keys found. Please note that Tag keys are case insensitive.') "InvalidInput",
"Duplicate tag keys found. Please note that Tag keys are case insensitive.",
)
class TagKeyTooBig(JsonRESTError): class TagKeyTooBig(JsonRESTError):
code = 400 code = 400
def __init__(self, tag, param='tags.X.member.key'): def __init__(self, tag, param="tags.X.member.key"):
super(TagKeyTooBig, self).__init__( super(TagKeyTooBig, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy " "ValidationException",
"constraint: Member must have length less than or equal to 128".format(tag, param)) "1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 128".format(
tag, param
),
)
class TagValueTooBig(JsonRESTError): class TagValueTooBig(JsonRESTError):
@ -202,76 +268,100 @@ class TagValueTooBig(JsonRESTError):
def __init__(self, tag): def __init__(self, tag):
super(TagValueTooBig, self).__init__( super(TagValueTooBig, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy " "ValidationException",
"constraint: Member must have length less than or equal to 256".format(tag)) "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy "
"constraint: Member must have length less than or equal to 256".format(tag),
)
class InvalidParameterValueException(JsonRESTError): class InvalidParameterValueException(JsonRESTError):
code = 400 code = 400
def __init__(self, message): def __init__(self, message):
super(InvalidParameterValueException, self).__init__('InvalidParameterValueException', message) super(InvalidParameterValueException, self).__init__(
"InvalidParameterValueException", message
)
class InvalidTagCharacters(JsonRESTError): class InvalidTagCharacters(JsonRESTError):
code = 400 code = 400
def __init__(self, tag, param='tags.X.member.key'): def __init__(self, tag, param="tags.X.member.key"):
message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(tag, param) message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(
message += 'constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+' tag, param
)
message += "constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+"
super(InvalidTagCharacters, self).__init__('ValidationException', message) super(InvalidTagCharacters, self).__init__("ValidationException", message)
class TooManyTags(JsonRESTError): class TooManyTags(JsonRESTError):
code = 400 code = 400
def __init__(self, tags, param='tags'): def __init__(self, tags, param="tags"):
super(TooManyTags, self).__init__( super(TooManyTags, self).__init__(
'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy " "ValidationException",
"constraint: Member must have length less than or equal to 50.".format(tags, param)) "1 validation error detected: Value '{}' at '{}' failed to satisfy "
"constraint: Member must have length less than or equal to 50.".format(
tags, param
),
)
class InvalidResourceParameters(JsonRESTError): class InvalidResourceParameters(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self):
super(InvalidResourceParameters, self).__init__('ValidationException', 'Both Resource ID and Resource Name ' super(InvalidResourceParameters, self).__init__(
'cannot be specified in the request') "ValidationException",
"Both Resource ID and Resource Name " "cannot be specified in the request",
)
class InvalidLimit(JsonRESTError): class InvalidLimit(JsonRESTError):
code = 400 code = 400
def __init__(self, value): def __init__(self, value):
super(InvalidLimit, self).__init__('ValidationException', 'Value \'{value}\' at \'limit\' failed to satisify constraint: Member' super(InvalidLimit, self).__init__(
' must have value less than or equal to 100'.format(value=value)) "ValidationException",
"Value '{value}' at 'limit' failed to satisify constraint: Member"
" must have value less than or equal to 100".format(value=value),
)
class TooManyResourceIds(JsonRESTError): class TooManyResourceIds(JsonRESTError):
code = 400 code = 400
def __init__(self): def __init__(self):
super(TooManyResourceIds, self).__init__('ValidationException', "The specified list had more than 20 resource ID's. " super(TooManyResourceIds, self).__init__(
"It must have '20' or less items") "ValidationException",
"The specified list had more than 20 resource ID's. "
"It must have '20' or less items",
)
class ResourceNotDiscoveredException(JsonRESTError): class ResourceNotDiscoveredException(JsonRESTError):
code = 400 code = 400
def __init__(self, type, resource): def __init__(self, type, resource):
super(ResourceNotDiscoveredException, self).__init__('ResourceNotDiscoveredException', super(ResourceNotDiscoveredException, self).__init__(
'Resource {resource} of resourceType:{type} is unknown or has not been ' "ResourceNotDiscoveredException",
'discovered'.format(resource=resource, type=type)) "Resource {resource} of resourceType:{type} is unknown or has not been "
"discovered".format(resource=resource, type=type),
)
class TooManyResourceKeys(JsonRESTError): class TooManyResourceKeys(JsonRESTError):
code = 400 code = 400
def __init__(self, bad_list): def __init__(self, bad_list):
message = '1 validation error detected: Value \'{bad_list}\' at ' \ message = (
'\'resourceKeys\' failed to satisfy constraint: ' \ "1 validation error detected: Value '{bad_list}' at "
'Member must have length less than or equal to 100'.format(bad_list=bad_list) "'resourceKeys' failed to satisfy constraint: "
"Member must have length less than or equal to 100".format(
bad_list=bad_list
)
)
# For PY2: # For PY2:
message = str(message) message = str(message)

View File

@ -9,35 +9,55 @@ from datetime import datetime
from boto3 import Session from boto3 import Session
from moto.config.exceptions import InvalidResourceTypeException, InvalidDeliveryFrequency, \ from moto.config.exceptions import (
InvalidConfigurationRecorderNameException, NameTooLongException, \ InvalidResourceTypeException,
MaxNumberOfConfigurationRecordersExceededException, InvalidRecordingGroupException, \ InvalidDeliveryFrequency,
NoSuchConfigurationRecorderException, NoAvailableConfigurationRecorderException, \ InvalidConfigurationRecorderNameException,
InvalidDeliveryChannelNameException, NoSuchBucketException, InvalidS3KeyPrefixException, \ NameTooLongException,
InvalidSNSTopicARNException, MaxNumberOfDeliveryChannelsExceededException, NoAvailableDeliveryChannelException, \ MaxNumberOfConfigurationRecordersExceededException,
NoSuchDeliveryChannelException, LastDeliveryChannelDeleteFailedException, TagKeyTooBig, \ InvalidRecordingGroupException,
TooManyTags, TagValueTooBig, TooManyAccountSources, InvalidParameterValueException, InvalidNextTokenException, \ NoSuchConfigurationRecorderException,
NoSuchConfigurationAggregatorException, InvalidTagCharacters, DuplicateTags, InvalidLimit, InvalidResourceParameters, \ NoAvailableConfigurationRecorderException,
TooManyResourceIds, ResourceNotDiscoveredException, TooManyResourceKeys InvalidDeliveryChannelNameException,
NoSuchBucketException,
InvalidS3KeyPrefixException,
InvalidSNSTopicARNException,
MaxNumberOfDeliveryChannelsExceededException,
NoAvailableDeliveryChannelException,
NoSuchDeliveryChannelException,
LastDeliveryChannelDeleteFailedException,
TagKeyTooBig,
TooManyTags,
TagValueTooBig,
TooManyAccountSources,
InvalidParameterValueException,
InvalidNextTokenException,
NoSuchConfigurationAggregatorException,
InvalidTagCharacters,
DuplicateTags,
InvalidLimit,
InvalidResourceParameters,
TooManyResourceIds,
ResourceNotDiscoveredException,
TooManyResourceKeys,
)
from moto.core import BaseBackend, BaseModel from moto.core import BaseBackend, BaseModel
from moto.s3.config import s3_config_query from moto.s3.config import s3_config_query
DEFAULT_ACCOUNT_ID = '123456789012' DEFAULT_ACCOUNT_ID = "123456789012"
POP_STRINGS = [ POP_STRINGS = [
'capitalizeStart', "capitalizeStart",
'CapitalizeStart', "CapitalizeStart",
'capitalizeArn', "capitalizeArn",
'CapitalizeArn', "CapitalizeArn",
'capitalizeARN', "capitalizeARN",
'CapitalizeARN' "CapitalizeARN",
] ]
DEFAULT_PAGE_SIZE = 100 DEFAULT_PAGE_SIZE = 100
# Map the Config resource type to a backend: # Map the Config resource type to a backend:
RESOURCE_MAP = { RESOURCE_MAP = {"AWS::S3::Bucket": s3_config_query}
'AWS::S3::Bucket': s3_config_query
}
def datetime2int(date): def datetime2int(date):
@ -45,12 +65,14 @@ def datetime2int(date):
def snake_to_camels(original, cap_start, cap_arn): def snake_to_camels(original, cap_start, cap_arn):
parts = original.split('_') parts = original.split("_")
camel_cased = parts[0].lower() + ''.join(p.title() for p in parts[1:]) camel_cased = parts[0].lower() + "".join(p.title() for p in parts[1:])
if cap_arn: if cap_arn:
camel_cased = camel_cased.replace('Arn', 'ARN') # Some config services use 'ARN' instead of 'Arn' camel_cased = camel_cased.replace(
"Arn", "ARN"
) # Some config services use 'ARN' instead of 'Arn'
if cap_start: if cap_start:
camel_cased = camel_cased[0].upper() + camel_cased[1::] camel_cased = camel_cased[0].upper() + camel_cased[1::]
@ -67,7 +89,7 @@ def random_string():
return "".join(chars) return "".join(chars)
def validate_tag_key(tag_key, exception_param='tags.X.member.key'): def validate_tag_key(tag_key, exception_param="tags.X.member.key"):
"""Validates the tag key. """Validates the tag key.
:param tag_key: The tag key to check against. :param tag_key: The tag key to check against.
@ -81,7 +103,7 @@ def validate_tag_key(tag_key, exception_param='tags.X.member.key'):
# Validate that the tag key fits the proper Regex: # Validate that the tag key fits the proper Regex:
# [\w\s_.:/=+\-@]+ SHOULD be the same as the Java regex on the AWS documentation: [\p{L}\p{Z}\p{N}_.:/=+\-@]+ # [\w\s_.:/=+\-@]+ SHOULD be the same as the Java regex on the AWS documentation: [\p{L}\p{Z}\p{N}_.:/=+\-@]+
match = re.findall(r'[\w\s_.:/=+\-@]+', tag_key) match = re.findall(r"[\w\s_.:/=+\-@]+", tag_key)
# Kudos if you can come up with a better way of doing a global search :) # Kudos if you can come up with a better way of doing a global search :)
if not len(match) or len(match[0]) < len(tag_key): if not len(match) or len(match[0]) < len(tag_key):
raise InvalidTagCharacters(tag_key, param=exception_param) raise InvalidTagCharacters(tag_key, param=exception_param)
@ -106,14 +128,14 @@ def validate_tags(tags):
for tag in tags: for tag in tags:
# Validate the Key: # Validate the Key:
validate_tag_key(tag['Key']) validate_tag_key(tag["Key"])
check_tag_duplicate(proper_tags, tag['Key']) check_tag_duplicate(proper_tags, tag["Key"])
# Validate the Value: # Validate the Value:
if len(tag['Value']) > 256: if len(tag["Value"]) > 256:
raise TagValueTooBig(tag['Value']) raise TagValueTooBig(tag["Value"])
proper_tags[tag['Key']] = tag['Value'] proper_tags[tag["Key"]] = tag["Value"]
return proper_tags return proper_tags
@ -134,9 +156,17 @@ class ConfigEmptyDictable(BaseModel):
for item, value in self.__dict__.items(): for item, value in self.__dict__.items():
if value is not None: if value is not None:
if isinstance(value, ConfigEmptyDictable): if isinstance(value, ConfigEmptyDictable):
data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value.to_dict() data[
snake_to_camels(
item, self.capitalize_start, self.capitalize_arn
)
] = value.to_dict()
else: else:
data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value data[
snake_to_camels(
item, self.capitalize_start, self.capitalize_arn
)
] = value
# Cleanse the extra properties: # Cleanse the extra properties:
for prop in POP_STRINGS: for prop in POP_STRINGS:
@ -146,7 +176,6 @@ class ConfigEmptyDictable(BaseModel):
class ConfigRecorderStatus(ConfigEmptyDictable): class ConfigRecorderStatus(ConfigEmptyDictable):
def __init__(self, name): def __init__(self, name):
super(ConfigRecorderStatus, self).__init__() super(ConfigRecorderStatus, self).__init__()
@ -161,7 +190,7 @@ class ConfigRecorderStatus(ConfigEmptyDictable):
def start(self): def start(self):
self.recording = True self.recording = True
self.last_status = 'PENDING' self.last_status = "PENDING"
self.last_start_time = datetime2int(datetime.utcnow()) self.last_start_time = datetime2int(datetime.utcnow())
self.last_status_change_time = datetime2int(datetime.utcnow()) self.last_status_change_time = datetime2int(datetime.utcnow())
@ -172,7 +201,6 @@ class ConfigRecorderStatus(ConfigEmptyDictable):
class ConfigDeliverySnapshotProperties(ConfigEmptyDictable): class ConfigDeliverySnapshotProperties(ConfigEmptyDictable):
def __init__(self, delivery_frequency): def __init__(self, delivery_frequency):
super(ConfigDeliverySnapshotProperties, self).__init__() super(ConfigDeliverySnapshotProperties, self).__init__()
@ -180,8 +208,9 @@ class ConfigDeliverySnapshotProperties(ConfigEmptyDictable):
class ConfigDeliveryChannel(ConfigEmptyDictable): class ConfigDeliveryChannel(ConfigEmptyDictable):
def __init__(
def __init__(self, name, s3_bucket_name, prefix=None, sns_arn=None, snapshot_properties=None): self, name, s3_bucket_name, prefix=None, sns_arn=None, snapshot_properties=None
):
super(ConfigDeliveryChannel, self).__init__() super(ConfigDeliveryChannel, self).__init__()
self.name = name self.name = name
@ -192,8 +221,12 @@ class ConfigDeliveryChannel(ConfigEmptyDictable):
class RecordingGroup(ConfigEmptyDictable): class RecordingGroup(ConfigEmptyDictable):
def __init__(
def __init__(self, all_supported=True, include_global_resource_types=False, resource_types=None): self,
all_supported=True,
include_global_resource_types=False,
resource_types=None,
):
super(RecordingGroup, self).__init__() super(RecordingGroup, self).__init__()
self.all_supported = all_supported self.all_supported = all_supported
@ -202,8 +235,7 @@ class RecordingGroup(ConfigEmptyDictable):
class ConfigRecorder(ConfigEmptyDictable): class ConfigRecorder(ConfigEmptyDictable):
def __init__(self, role_arn, recording_group, name="default", status=None):
def __init__(self, role_arn, recording_group, name='default', status=None):
super(ConfigRecorder, self).__init__() super(ConfigRecorder, self).__init__()
self.name = name self.name = name
@ -217,18 +249,21 @@ class ConfigRecorder(ConfigEmptyDictable):
class AccountAggregatorSource(ConfigEmptyDictable): class AccountAggregatorSource(ConfigEmptyDictable):
def __init__(self, account_ids, aws_regions=None, all_aws_regions=None): def __init__(self, account_ids, aws_regions=None, all_aws_regions=None):
super(AccountAggregatorSource, self).__init__(capitalize_start=True) super(AccountAggregatorSource, self).__init__(capitalize_start=True)
# Can't have both the regions and all_regions flag present -- also can't have them both missing: # Can't have both the regions and all_regions flag present -- also can't have them both missing:
if aws_regions and all_aws_regions: if aws_regions and all_aws_regions:
raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies ' raise InvalidParameterValueException(
'the use of all regions. You must choose one of these options.') "Your configuration aggregator contains a list of regions and also specifies "
"the use of all regions. You must choose one of these options."
)
if not (aws_regions or all_aws_regions): if not (aws_regions or all_aws_regions):
raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported ' raise InvalidParameterValueException(
'regions and try again.') "Your request does not specify any regions. Select AWS Config-supported "
"regions and try again."
)
self.account_ids = account_ids self.account_ids = account_ids
self.aws_regions = aws_regions self.aws_regions = aws_regions
@ -240,18 +275,23 @@ class AccountAggregatorSource(ConfigEmptyDictable):
class OrganizationAggregationSource(ConfigEmptyDictable): class OrganizationAggregationSource(ConfigEmptyDictable):
def __init__(self, role_arn, aws_regions=None, all_aws_regions=None): def __init__(self, role_arn, aws_regions=None, all_aws_regions=None):
super(OrganizationAggregationSource, self).__init__(capitalize_start=True, capitalize_arn=False) super(OrganizationAggregationSource, self).__init__(
capitalize_start=True, capitalize_arn=False
)
# Can't have both the regions and all_regions flag present -- also can't have them both missing: # Can't have both the regions and all_regions flag present -- also can't have them both missing:
if aws_regions and all_aws_regions: if aws_regions and all_aws_regions:
raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies ' raise InvalidParameterValueException(
'the use of all regions. You must choose one of these options.') "Your configuration aggregator contains a list of regions and also specifies "
"the use of all regions. You must choose one of these options."
)
if not (aws_regions or all_aws_regions): if not (aws_regions or all_aws_regions):
raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported ' raise InvalidParameterValueException(
'regions and try again.') "Your request does not specify any regions. Select AWS Config-supported "
"regions and try again."
)
self.role_arn = role_arn self.role_arn = role_arn
self.aws_regions = aws_regions self.aws_regions = aws_regions
@ -263,15 +303,14 @@ class OrganizationAggregationSource(ConfigEmptyDictable):
class ConfigAggregator(ConfigEmptyDictable): class ConfigAggregator(ConfigEmptyDictable):
def __init__(self, name, region, account_sources=None, org_source=None, tags=None): def __init__(self, name, region, account_sources=None, org_source=None, tags=None):
super(ConfigAggregator, self).__init__(capitalize_start=True, capitalize_arn=False) super(ConfigAggregator, self).__init__(
capitalize_start=True, capitalize_arn=False
)
self.configuration_aggregator_name = name self.configuration_aggregator_name = name
self.configuration_aggregator_arn = 'arn:aws:config:{region}:{id}:config-aggregator/config-aggregator-{random}'.format( self.configuration_aggregator_arn = "arn:aws:config:{region}:{id}:config-aggregator/config-aggregator-{random}".format(
region=region, region=region, id=DEFAULT_ACCOUNT_ID, random=random_string()
id=DEFAULT_ACCOUNT_ID,
random=random_string()
) )
self.account_aggregation_sources = account_sources self.account_aggregation_sources = account_sources
self.organization_aggregation_source = org_source self.organization_aggregation_source = org_source
@ -287,7 +326,9 @@ class ConfigAggregator(ConfigEmptyDictable):
# Override the account aggregation sources if present: # Override the account aggregation sources if present:
if self.account_aggregation_sources: if self.account_aggregation_sources:
result['AccountAggregationSources'] = [a.to_dict() for a in self.account_aggregation_sources] result["AccountAggregationSources"] = [
a.to_dict() for a in self.account_aggregation_sources
]
# Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to! # Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to!
# if self.tags: # if self.tags:
@ -297,15 +338,22 @@ class ConfigAggregator(ConfigEmptyDictable):
class ConfigAggregationAuthorization(ConfigEmptyDictable): class ConfigAggregationAuthorization(ConfigEmptyDictable):
def __init__(
self, current_region, authorized_account_id, authorized_aws_region, tags=None
):
super(ConfigAggregationAuthorization, self).__init__(
capitalize_start=True, capitalize_arn=False
)
def __init__(self, current_region, authorized_account_id, authorized_aws_region, tags=None): self.aggregation_authorization_arn = (
super(ConfigAggregationAuthorization, self).__init__(capitalize_start=True, capitalize_arn=False) "arn:aws:config:{region}:{id}:aggregation-authorization/"
"{auth_account}/{auth_region}".format(
self.aggregation_authorization_arn = 'arn:aws:config:{region}:{id}:aggregation-authorization/' \ region=current_region,
'{auth_account}/{auth_region}'.format(region=current_region,
id=DEFAULT_ACCOUNT_ID, id=DEFAULT_ACCOUNT_ID,
auth_account=authorized_account_id, auth_account=authorized_account_id,
auth_region=authorized_aws_region) auth_region=authorized_aws_region,
)
)
self.authorized_account_id = authorized_account_id self.authorized_account_id = authorized_account_id
self.authorized_aws_region = authorized_aws_region self.authorized_aws_region = authorized_aws_region
self.creation_time = datetime2int(datetime.utcnow()) self.creation_time = datetime2int(datetime.utcnow())
@ -315,7 +363,6 @@ class ConfigAggregationAuthorization(ConfigEmptyDictable):
class ConfigBackend(BaseBackend): class ConfigBackend(BaseBackend):
def __init__(self): def __init__(self):
self.recorders = {} self.recorders = {}
self.delivery_channels = {} self.delivery_channels = {}
@ -325,9 +372,11 @@ class ConfigBackend(BaseBackend):
@staticmethod @staticmethod
def _validate_resource_types(resource_list): def _validate_resource_types(resource_list):
# Load the service file: # Load the service file:
resource_package = 'botocore' resource_package = "botocore"
resource_path = '/'.join(('data', 'config', '2014-11-12', 'service-2.json')) resource_path = "/".join(("data", "config", "2014-11-12", "service-2.json"))
config_schema = json.loads(pkg_resources.resource_string(resource_package, resource_path)) config_schema = json.loads(
pkg_resources.resource_string(resource_package, resource_path)
)
# Verify that each entry exists in the supported list: # Verify that each entry exists in the supported list:
bad_list = [] bad_list = []
@ -335,72 +384,114 @@ class ConfigBackend(BaseBackend):
# For PY2: # For PY2:
r_str = str(resource) r_str = str(resource)
if r_str not in config_schema['shapes']['ResourceType']['enum']: if r_str not in config_schema["shapes"]["ResourceType"]["enum"]:
bad_list.append(r_str) bad_list.append(r_str)
if bad_list: if bad_list:
raise InvalidResourceTypeException(bad_list, config_schema['shapes']['ResourceType']['enum']) raise InvalidResourceTypeException(
bad_list, config_schema["shapes"]["ResourceType"]["enum"]
)
@staticmethod @staticmethod
def _validate_delivery_snapshot_properties(properties): def _validate_delivery_snapshot_properties(properties):
# Load the service file: # Load the service file:
resource_package = 'botocore' resource_package = "botocore"
resource_path = '/'.join(('data', 'config', '2014-11-12', 'service-2.json')) resource_path = "/".join(("data", "config", "2014-11-12", "service-2.json"))
conifg_schema = json.loads(pkg_resources.resource_string(resource_package, resource_path)) conifg_schema = json.loads(
pkg_resources.resource_string(resource_package, resource_path)
)
# Verify that the deliveryFrequency is set to an acceptable value: # Verify that the deliveryFrequency is set to an acceptable value:
if properties.get('deliveryFrequency', None) not in \ if (
conifg_schema['shapes']['MaximumExecutionFrequency']['enum']: properties.get("deliveryFrequency", None)
raise InvalidDeliveryFrequency(properties.get('deliveryFrequency', None), not in conifg_schema["shapes"]["MaximumExecutionFrequency"]["enum"]
conifg_schema['shapes']['MaximumExecutionFrequency']['enum']) ):
raise InvalidDeliveryFrequency(
properties.get("deliveryFrequency", None),
conifg_schema["shapes"]["MaximumExecutionFrequency"]["enum"],
)
def put_configuration_aggregator(self, config_aggregator, region): def put_configuration_aggregator(self, config_aggregator, region):
# Validate the name: # Validate the name:
if len(config_aggregator['ConfigurationAggregatorName']) > 256: if len(config_aggregator["ConfigurationAggregatorName"]) > 256:
raise NameTooLongException(config_aggregator['ConfigurationAggregatorName'], 'configurationAggregatorName') raise NameTooLongException(
config_aggregator["ConfigurationAggregatorName"],
"configurationAggregatorName",
)
account_sources = None account_sources = None
org_source = None org_source = None
# Tag validation: # Tag validation:
tags = validate_tags(config_aggregator.get('Tags', [])) tags = validate_tags(config_aggregator.get("Tags", []))
# Exception if both AccountAggregationSources and OrganizationAggregationSource are supplied: # Exception if both AccountAggregationSources and OrganizationAggregationSource are supplied:
if config_aggregator.get('AccountAggregationSources') and config_aggregator.get('OrganizationAggregationSource'): if config_aggregator.get("AccountAggregationSources") and config_aggregator.get(
raise InvalidParameterValueException('The configuration aggregator cannot be created because your request contains both the' "OrganizationAggregationSource"
' AccountAggregationSource and the OrganizationAggregationSource. Include only ' ):
'one aggregation source and try again.') raise InvalidParameterValueException(
"The configuration aggregator cannot be created because your request contains both the"
" AccountAggregationSource and the OrganizationAggregationSource. Include only "
"one aggregation source and try again."
)
# If neither are supplied: # If neither are supplied:
if not config_aggregator.get('AccountAggregationSources') and not config_aggregator.get('OrganizationAggregationSource'): if not config_aggregator.get(
raise InvalidParameterValueException('The configuration aggregator cannot be created because your request is missing either ' "AccountAggregationSources"
'the AccountAggregationSource or the OrganizationAggregationSource. Include the ' ) and not config_aggregator.get("OrganizationAggregationSource"):
'appropriate aggregation source and try again.') raise InvalidParameterValueException(
"The configuration aggregator cannot be created because your request is missing either "
"the AccountAggregationSource or the OrganizationAggregationSource. Include the "
"appropriate aggregation source and try again."
)
if config_aggregator.get('AccountAggregationSources'): if config_aggregator.get("AccountAggregationSources"):
# Currently, only 1 account aggregation source can be set: # Currently, only 1 account aggregation source can be set:
if len(config_aggregator['AccountAggregationSources']) > 1: if len(config_aggregator["AccountAggregationSources"]) > 1:
raise TooManyAccountSources(len(config_aggregator['AccountAggregationSources'])) raise TooManyAccountSources(
len(config_aggregator["AccountAggregationSources"])
)
account_sources = [] account_sources = []
for a in config_aggregator['AccountAggregationSources']: for a in config_aggregator["AccountAggregationSources"]:
account_sources.append(AccountAggregatorSource(a['AccountIds'], aws_regions=a.get('AwsRegions'), account_sources.append(
all_aws_regions=a.get('AllAwsRegions'))) AccountAggregatorSource(
a["AccountIds"],
aws_regions=a.get("AwsRegions"),
all_aws_regions=a.get("AllAwsRegions"),
)
)
else: else:
org_source = OrganizationAggregationSource(config_aggregator['OrganizationAggregationSource']['RoleArn'], org_source = OrganizationAggregationSource(
aws_regions=config_aggregator['OrganizationAggregationSource'].get('AwsRegions'), config_aggregator["OrganizationAggregationSource"]["RoleArn"],
all_aws_regions=config_aggregator['OrganizationAggregationSource'].get( aws_regions=config_aggregator["OrganizationAggregationSource"].get(
'AllAwsRegions')) "AwsRegions"
),
all_aws_regions=config_aggregator["OrganizationAggregationSource"].get(
"AllAwsRegions"
),
)
# Grab the existing one if it exists and update it: # Grab the existing one if it exists and update it:
if not self.config_aggregators.get(config_aggregator['ConfigurationAggregatorName']): if not self.config_aggregators.get(
aggregator = ConfigAggregator(config_aggregator['ConfigurationAggregatorName'], region, account_sources=account_sources, config_aggregator["ConfigurationAggregatorName"]
org_source=org_source, tags=tags) ):
self.config_aggregators[config_aggregator['ConfigurationAggregatorName']] = aggregator aggregator = ConfigAggregator(
config_aggregator["ConfigurationAggregatorName"],
region,
account_sources=account_sources,
org_source=org_source,
tags=tags,
)
self.config_aggregators[
config_aggregator["ConfigurationAggregatorName"]
] = aggregator
else: else:
aggregator = self.config_aggregators[config_aggregator['ConfigurationAggregatorName']] aggregator = self.config_aggregators[
config_aggregator["ConfigurationAggregatorName"]
]
aggregator.tags = tags aggregator.tags = tags
aggregator.account_aggregation_sources = account_sources aggregator.account_aggregation_sources = account_sources
aggregator.organization_aggregation_source = org_source aggregator.organization_aggregation_source = org_source
@ -411,7 +502,7 @@ class ConfigBackend(BaseBackend):
def describe_configuration_aggregators(self, names, token, limit): def describe_configuration_aggregators(self, names, token, limit):
limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit
agg_list = [] agg_list = []
result = {'ConfigurationAggregators': []} result = {"ConfigurationAggregators": []}
if names: if names:
for name in names: for name in names:
@ -441,11 +532,13 @@ class ConfigBackend(BaseBackend):
start = sorted_aggregators.index(token) start = sorted_aggregators.index(token)
# Get the list of items to collect: # Get the list of items to collect:
agg_list = sorted_aggregators[start:(start + limit)] agg_list = sorted_aggregators[start : (start + limit)]
result['ConfigurationAggregators'] = [self.config_aggregators[agg].to_dict() for agg in agg_list] result["ConfigurationAggregators"] = [
self.config_aggregators[agg].to_dict() for agg in agg_list
]
if len(sorted_aggregators) > (start + limit): if len(sorted_aggregators) > (start + limit):
result['NextToken'] = sorted_aggregators[start + limit] result["NextToken"] = sorted_aggregators[start + limit]
return result return result
@ -455,16 +548,22 @@ class ConfigBackend(BaseBackend):
del self.config_aggregators[config_aggregator] del self.config_aggregators[config_aggregator]
def put_aggregation_authorization(self, current_region, authorized_account, authorized_region, tags): def put_aggregation_authorization(
self, current_region, authorized_account, authorized_region, tags
):
# Tag validation: # Tag validation:
tags = validate_tags(tags or []) tags = validate_tags(tags or [])
# Does this already exist? # Does this already exist?
key = '{}/{}'.format(authorized_account, authorized_region) key = "{}/{}".format(authorized_account, authorized_region)
agg_auth = self.aggregation_authorizations.get(key) agg_auth = self.aggregation_authorizations.get(key)
if not agg_auth: if not agg_auth:
agg_auth = ConfigAggregationAuthorization(current_region, authorized_account, authorized_region, tags=tags) agg_auth = ConfigAggregationAuthorization(
self.aggregation_authorizations['{}/{}'.format(authorized_account, authorized_region)] = agg_auth current_region, authorized_account, authorized_region, tags=tags
)
self.aggregation_authorizations[
"{}/{}".format(authorized_account, authorized_region)
] = agg_auth
else: else:
# Only update the tags: # Only update the tags:
agg_auth.tags = tags agg_auth.tags = tags
@ -473,7 +572,7 @@ class ConfigBackend(BaseBackend):
def describe_aggregation_authorizations(self, token, limit): def describe_aggregation_authorizations(self, token, limit):
limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit
result = {'AggregationAuthorizations': []} result = {"AggregationAuthorizations": []}
if not self.aggregation_authorizations: if not self.aggregation_authorizations:
return result return result
@ -492,70 +591,82 @@ class ConfigBackend(BaseBackend):
start = sorted_authorizations.index(token) start = sorted_authorizations.index(token)
# Get the list of items to collect: # Get the list of items to collect:
auth_list = sorted_authorizations[start:(start + limit)] auth_list = sorted_authorizations[start : (start + limit)]
result['AggregationAuthorizations'] = [self.aggregation_authorizations[auth].to_dict() for auth in auth_list] result["AggregationAuthorizations"] = [
self.aggregation_authorizations[auth].to_dict() for auth in auth_list
]
if len(sorted_authorizations) > (start + limit): if len(sorted_authorizations) > (start + limit):
result['NextToken'] = sorted_authorizations[start + limit] result["NextToken"] = sorted_authorizations[start + limit]
return result return result
def delete_aggregation_authorization(self, authorized_account, authorized_region): def delete_aggregation_authorization(self, authorized_account, authorized_region):
# This will always return a 200 -- regardless if there is or isn't an existing # This will always return a 200 -- regardless if there is or isn't an existing
# aggregation authorization. # aggregation authorization.
key = '{}/{}'.format(authorized_account, authorized_region) key = "{}/{}".format(authorized_account, authorized_region)
self.aggregation_authorizations.pop(key, None) self.aggregation_authorizations.pop(key, None)
def put_configuration_recorder(self, config_recorder): def put_configuration_recorder(self, config_recorder):
# Validate the name: # Validate the name:
if not config_recorder.get('name'): if not config_recorder.get("name"):
raise InvalidConfigurationRecorderNameException(config_recorder.get('name')) raise InvalidConfigurationRecorderNameException(config_recorder.get("name"))
if len(config_recorder.get('name')) > 256: if len(config_recorder.get("name")) > 256:
raise NameTooLongException(config_recorder.get('name'), 'configurationRecorder.name') raise NameTooLongException(
config_recorder.get("name"), "configurationRecorder.name"
)
# We're going to assume that the passed in Role ARN is correct. # We're going to assume that the passed in Role ARN is correct.
# Config currently only allows 1 configuration recorder for an account: # Config currently only allows 1 configuration recorder for an account:
if len(self.recorders) == 1 and not self.recorders.get(config_recorder['name']): if len(self.recorders) == 1 and not self.recorders.get(config_recorder["name"]):
raise MaxNumberOfConfigurationRecordersExceededException(config_recorder['name']) raise MaxNumberOfConfigurationRecordersExceededException(
config_recorder["name"]
)
# Is this updating an existing one? # Is this updating an existing one?
recorder_status = None recorder_status = None
if self.recorders.get(config_recorder['name']): if self.recorders.get(config_recorder["name"]):
recorder_status = self.recorders[config_recorder['name']].status recorder_status = self.recorders[config_recorder["name"]].status
# Validate the Recording Group: # Validate the Recording Group:
if config_recorder.get('recordingGroup') is None: if config_recorder.get("recordingGroup") is None:
recording_group = RecordingGroup() recording_group = RecordingGroup()
else: else:
rg = config_recorder['recordingGroup'] rg = config_recorder["recordingGroup"]
# If an empty dict is passed in, then bad: # If an empty dict is passed in, then bad:
if not rg: if not rg:
raise InvalidRecordingGroupException() raise InvalidRecordingGroupException()
# Can't have both the resource types specified and the other flags as True. # Can't have both the resource types specified and the other flags as True.
if rg.get('resourceTypes') and ( if rg.get("resourceTypes") and (
rg.get('allSupported', False) or rg.get("allSupported", False)
rg.get('includeGlobalResourceTypes', False)): or rg.get("includeGlobalResourceTypes", False)
):
raise InvalidRecordingGroupException() raise InvalidRecordingGroupException()
# Must supply resourceTypes if 'allSupported' is not supplied: # Must supply resourceTypes if 'allSupported' is not supplied:
if not rg.get('allSupported') and not rg.get('resourceTypes'): if not rg.get("allSupported") and not rg.get("resourceTypes"):
raise InvalidRecordingGroupException() raise InvalidRecordingGroupException()
# Validate that the list provided is correct: # Validate that the list provided is correct:
self._validate_resource_types(rg.get('resourceTypes', [])) self._validate_resource_types(rg.get("resourceTypes", []))
recording_group = RecordingGroup( recording_group = RecordingGroup(
all_supported=rg.get('allSupported', True), all_supported=rg.get("allSupported", True),
include_global_resource_types=rg.get('includeGlobalResourceTypes', False), include_global_resource_types=rg.get(
resource_types=rg.get('resourceTypes', []) "includeGlobalResourceTypes", False
),
resource_types=rg.get("resourceTypes", []),
) )
self.recorders[config_recorder['name']] = \ self.recorders[config_recorder["name"]] = ConfigRecorder(
ConfigRecorder(config_recorder['roleARN'], recording_group, name=config_recorder['name'], config_recorder["roleARN"],
status=recorder_status) recording_group,
name=config_recorder["name"],
status=recorder_status,
)
def describe_configuration_recorders(self, recorder_names): def describe_configuration_recorders(self, recorder_names):
recorders = [] recorders = []
@ -597,43 +708,54 @@ class ConfigBackend(BaseBackend):
raise NoAvailableConfigurationRecorderException() raise NoAvailableConfigurationRecorderException()
# Validate the name: # Validate the name:
if not delivery_channel.get('name'): if not delivery_channel.get("name"):
raise InvalidDeliveryChannelNameException(delivery_channel.get('name')) raise InvalidDeliveryChannelNameException(delivery_channel.get("name"))
if len(delivery_channel.get('name')) > 256: if len(delivery_channel.get("name")) > 256:
raise NameTooLongException(delivery_channel.get('name'), 'deliveryChannel.name') raise NameTooLongException(
delivery_channel.get("name"), "deliveryChannel.name"
)
# We are going to assume that the bucket exists -- but will verify if the bucket provided is blank: # We are going to assume that the bucket exists -- but will verify if the bucket provided is blank:
if not delivery_channel.get('s3BucketName'): if not delivery_channel.get("s3BucketName"):
raise NoSuchBucketException() raise NoSuchBucketException()
# We are going to assume that the bucket has the correct policy attached to it. We are only going to verify # We are going to assume that the bucket has the correct policy attached to it. We are only going to verify
# if the prefix provided is not an empty string: # if the prefix provided is not an empty string:
if delivery_channel.get('s3KeyPrefix', None) == '': if delivery_channel.get("s3KeyPrefix", None) == "":
raise InvalidS3KeyPrefixException() raise InvalidS3KeyPrefixException()
# Ditto for SNS -- Only going to assume that the ARN provided is not an empty string: # Ditto for SNS -- Only going to assume that the ARN provided is not an empty string:
if delivery_channel.get('snsTopicARN', None) == '': if delivery_channel.get("snsTopicARN", None) == "":
raise InvalidSNSTopicARNException() raise InvalidSNSTopicARNException()
# Config currently only allows 1 delivery channel for an account: # Config currently only allows 1 delivery channel for an account:
if len(self.delivery_channels) == 1 and not self.delivery_channels.get(delivery_channel['name']): if len(self.delivery_channels) == 1 and not self.delivery_channels.get(
raise MaxNumberOfDeliveryChannelsExceededException(delivery_channel['name']) delivery_channel["name"]
):
raise MaxNumberOfDeliveryChannelsExceededException(delivery_channel["name"])
if not delivery_channel.get('configSnapshotDeliveryProperties'): if not delivery_channel.get("configSnapshotDeliveryProperties"):
dp = None dp = None
else: else:
# Validate the config snapshot delivery properties: # Validate the config snapshot delivery properties:
self._validate_delivery_snapshot_properties(delivery_channel['configSnapshotDeliveryProperties']) self._validate_delivery_snapshot_properties(
delivery_channel["configSnapshotDeliveryProperties"]
)
dp = ConfigDeliverySnapshotProperties( dp = ConfigDeliverySnapshotProperties(
delivery_channel['configSnapshotDeliveryProperties']['deliveryFrequency']) delivery_channel["configSnapshotDeliveryProperties"][
"deliveryFrequency"
]
)
self.delivery_channels[delivery_channel['name']] = \ self.delivery_channels[delivery_channel["name"]] = ConfigDeliveryChannel(
ConfigDeliveryChannel(delivery_channel['name'], delivery_channel['s3BucketName'], delivery_channel["name"],
prefix=delivery_channel.get('s3KeyPrefix', None), delivery_channel["s3BucketName"],
sns_arn=delivery_channel.get('snsTopicARN', None), prefix=delivery_channel.get("s3KeyPrefix", None),
snapshot_properties=dp) sns_arn=delivery_channel.get("snsTopicARN", None),
snapshot_properties=dp,
)
def describe_delivery_channels(self, channel_names): def describe_delivery_channels(self, channel_names):
channels = [] channels = []
@ -687,7 +809,15 @@ class ConfigBackend(BaseBackend):
del self.delivery_channels[channel_name] del self.delivery_channels[channel_name]
def list_discovered_resources(self, resource_type, backend_region, resource_ids, resource_name, limit, next_token): def list_discovered_resources(
self,
resource_type,
backend_region,
resource_ids,
resource_name,
limit,
next_token,
):
"""This will query against the mocked AWS Config (non-aggregated) listing function that must exist for the resource backend. """This will query against the mocked AWS Config (non-aggregated) listing function that must exist for the resource backend.
:param resource_type: :param resource_type:
@ -716,33 +846,45 @@ class ConfigBackend(BaseBackend):
# call upon the resource type's Config Query class to retrieve the list of resources that match the criteria: # call upon the resource type's Config Query class to retrieve the list of resources that match the criteria:
if RESOURCE_MAP.get(resource_type, {}): if RESOURCE_MAP.get(resource_type, {}):
# Is this a global resource type? -- if so, re-write the region to 'global': # Is this a global resource type? -- if so, re-write the region to 'global':
backend_query_region = backend_region # Always provide the backend this request arrived from. backend_query_region = (
if RESOURCE_MAP[resource_type].backends.get('global'): backend_region # Always provide the backend this request arrived from.
backend_region = 'global' )
if RESOURCE_MAP[resource_type].backends.get("global"):
backend_region = "global"
# For non-aggregated queries, the we only care about the backend_region. Need to verify that moto has implemented # For non-aggregated queries, the we only care about the backend_region. Need to verify that moto has implemented
# the region for the given backend: # the region for the given backend:
if RESOURCE_MAP[resource_type].backends.get(backend_region): if RESOURCE_MAP[resource_type].backends.get(backend_region):
# Fetch the resources for the backend's region: # Fetch the resources for the backend's region:
identifiers, new_token = \ identifiers, new_token = RESOURCE_MAP[
RESOURCE_MAP[resource_type].list_config_service_resources(resource_ids, resource_name, limit, next_token, resource_type
backend_region=backend_query_region) ].list_config_service_resources(
resource_ids,
resource_name,
limit,
next_token,
backend_region=backend_query_region,
)
result = {'resourceIdentifiers': [ result = {
"resourceIdentifiers": [
{ {
'resourceType': identifier['type'], "resourceType": identifier["type"],
'resourceId': identifier['id'], "resourceId": identifier["id"],
'resourceName': identifier['name'] "resourceName": identifier["name"],
} }
for identifier in identifiers] for identifier in identifiers
]
} }
if new_token: if new_token:
result['nextToken'] = new_token result["nextToken"] = new_token
return result return result
def list_aggregate_discovered_resources(self, aggregator_name, resource_type, filters, limit, next_token): def list_aggregate_discovered_resources(
self, aggregator_name, resource_type, filters, limit, next_token
):
"""This will query against the mocked AWS Config listing function that must exist for the resource backend. """This will query against the mocked AWS Config listing function that must exist for the resource backend.
As far a moto goes -- the only real difference between this function and the `list_discovered_resources` function is that As far a moto goes -- the only real difference between this function and the `list_discovered_resources` function is that
@ -770,27 +912,35 @@ class ConfigBackend(BaseBackend):
# call upon the resource type's Config Query class to retrieve the list of resources that match the criteria: # call upon the resource type's Config Query class to retrieve the list of resources that match the criteria:
if RESOURCE_MAP.get(resource_type, {}): if RESOURCE_MAP.get(resource_type, {}):
# We only care about a filter's Region, Resource Name, and Resource ID: # We only care about a filter's Region, Resource Name, and Resource ID:
resource_region = filters.get('Region') resource_region = filters.get("Region")
resource_id = [filters['ResourceId']] if filters.get('ResourceId') else None resource_id = [filters["ResourceId"]] if filters.get("ResourceId") else None
resource_name = filters.get('ResourceName') resource_name = filters.get("ResourceName")
identifiers, new_token = \ identifiers, new_token = RESOURCE_MAP[
RESOURCE_MAP[resource_type].list_config_service_resources(resource_id, resource_name, limit, next_token, resource_type
resource_region=resource_region) ].list_config_service_resources(
resource_id,
resource_name,
limit,
next_token,
resource_region=resource_region,
)
result = {'ResourceIdentifiers': [ result = {
"ResourceIdentifiers": [
{ {
'SourceAccountId': DEFAULT_ACCOUNT_ID, "SourceAccountId": DEFAULT_ACCOUNT_ID,
'SourceRegion': identifier['region'], "SourceRegion": identifier["region"],
'ResourceType': identifier['type'], "ResourceType": identifier["type"],
'ResourceId': identifier['id'], "ResourceId": identifier["id"],
'ResourceName': identifier['name'] "ResourceName": identifier["name"],
} }
for identifier in identifiers] for identifier in identifiers
]
} }
if new_token: if new_token:
result['NextToken'] = new_token result["NextToken"] = new_token
return result return result
@ -806,22 +956,26 @@ class ConfigBackend(BaseBackend):
raise ResourceNotDiscoveredException(resource_type, id) raise ResourceNotDiscoveredException(resource_type, id)
# Is the resource type global? # Is the resource type global?
backend_query_region = backend_region # Always provide the backend this request arrived from. backend_query_region = (
if RESOURCE_MAP[resource_type].backends.get('global'): backend_region # Always provide the backend this request arrived from.
backend_region = 'global' )
if RESOURCE_MAP[resource_type].backends.get("global"):
backend_region = "global"
# If the backend region isn't implemented then we won't find the item: # If the backend region isn't implemented then we won't find the item:
if not RESOURCE_MAP[resource_type].backends.get(backend_region): if not RESOURCE_MAP[resource_type].backends.get(backend_region):
raise ResourceNotDiscoveredException(resource_type, id) raise ResourceNotDiscoveredException(resource_type, id)
# Get the item: # Get the item:
item = RESOURCE_MAP[resource_type].get_config_resource(id, backend_region=backend_query_region) item = RESOURCE_MAP[resource_type].get_config_resource(
id, backend_region=backend_query_region
)
if not item: if not item:
raise ResourceNotDiscoveredException(resource_type, id) raise ResourceNotDiscoveredException(resource_type, id)
item['accountId'] = DEFAULT_ACCOUNT_ID item["accountId"] = DEFAULT_ACCOUNT_ID
return {'configurationItems': [item]} return {"configurationItems": [item]}
def batch_get_resource_config(self, resource_keys, backend_region): def batch_get_resource_config(self, resource_keys, backend_region):
"""Returns the configuration of an item in the AWS Config format of the resource for the current regional backend. """Returns the configuration of an item in the AWS Config format of the resource for the current regional backend.
@ -831,37 +985,50 @@ class ConfigBackend(BaseBackend):
""" """
# Can't have more than 100 items # Can't have more than 100 items
if len(resource_keys) > 100: if len(resource_keys) > 100:
raise TooManyResourceKeys(['com.amazonaws.starling.dove.ResourceKey@12345'] * len(resource_keys)) raise TooManyResourceKeys(
["com.amazonaws.starling.dove.ResourceKey@12345"] * len(resource_keys)
)
results = [] results = []
for resource in resource_keys: for resource in resource_keys:
# Does the resource type exist? # Does the resource type exist?
if not RESOURCE_MAP.get(resource['resourceType']): if not RESOURCE_MAP.get(resource["resourceType"]):
# Not found so skip. # Not found so skip.
continue continue
# Is the resource type global? # Is the resource type global?
config_backend_region = backend_region config_backend_region = backend_region
backend_query_region = backend_region # Always provide the backend this request arrived from. backend_query_region = (
if RESOURCE_MAP[resource['resourceType']].backends.get('global'): backend_region # Always provide the backend this request arrived from.
config_backend_region = 'global' )
if RESOURCE_MAP[resource["resourceType"]].backends.get("global"):
config_backend_region = "global"
# If the backend region isn't implemented then we won't find the item: # If the backend region isn't implemented then we won't find the item:
if not RESOURCE_MAP[resource['resourceType']].backends.get(config_backend_region): if not RESOURCE_MAP[resource["resourceType"]].backends.get(
config_backend_region
):
continue continue
# Get the item: # Get the item:
item = RESOURCE_MAP[resource['resourceType']].get_config_resource(resource['resourceId'], backend_region=backend_query_region) item = RESOURCE_MAP[resource["resourceType"]].get_config_resource(
resource["resourceId"], backend_region=backend_query_region
)
if not item: if not item:
continue continue
item['accountId'] = DEFAULT_ACCOUNT_ID item["accountId"] = DEFAULT_ACCOUNT_ID
results.append(item) results.append(item)
return {'baseConfigurationItems': results, 'unprocessedResourceKeys': []} # At this time, moto is not adding unprocessed items. return {
"baseConfigurationItems": results,
"unprocessedResourceKeys": [],
} # At this time, moto is not adding unprocessed items.
def batch_get_aggregate_resource_config(self, aggregator_name, resource_identifiers): def batch_get_aggregate_resource_config(
self, aggregator_name, resource_identifiers
):
"""Returns the configuration of an item in the AWS Config format of the resource for the current regional backend. """Returns the configuration of an item in the AWS Config format of the resource for the current regional backend.
As far a moto goes -- the only real difference between this function and the `batch_get_resource_config` function is that As far a moto goes -- the only real difference between this function and the `batch_get_resource_config` function is that
@ -874,15 +1041,18 @@ class ConfigBackend(BaseBackend):
# Can't have more than 100 items # Can't have more than 100 items
if len(resource_identifiers) > 100: if len(resource_identifiers) > 100:
raise TooManyResourceKeys(['com.amazonaws.starling.dove.AggregateResourceIdentifier@12345'] * len(resource_identifiers)) raise TooManyResourceKeys(
["com.amazonaws.starling.dove.AggregateResourceIdentifier@12345"]
* len(resource_identifiers)
)
found = [] found = []
not_found = [] not_found = []
for identifier in resource_identifiers: for identifier in resource_identifiers:
resource_type = identifier['ResourceType'] resource_type = identifier["ResourceType"]
resource_region = identifier['SourceRegion'] resource_region = identifier["SourceRegion"]
resource_id = identifier['ResourceId'] resource_id = identifier["ResourceId"]
resource_name = identifier.get('ResourceName', None) resource_name = identifier.get("ResourceName", None)
# Does the resource type exist? # Does the resource type exist?
if not RESOURCE_MAP.get(resource_type): if not RESOURCE_MAP.get(resource_type):
@ -890,23 +1060,29 @@ class ConfigBackend(BaseBackend):
continue continue
# Get the item: # Get the item:
item = RESOURCE_MAP[resource_type].get_config_resource(resource_id, resource_name=resource_name, item = RESOURCE_MAP[resource_type].get_config_resource(
resource_region=resource_region) resource_id,
resource_name=resource_name,
resource_region=resource_region,
)
if not item: if not item:
not_found.append(identifier) not_found.append(identifier)
continue continue
item['accountId'] = DEFAULT_ACCOUNT_ID item["accountId"] = DEFAULT_ACCOUNT_ID
# The 'tags' field is not included in aggregate results for some reason... # The 'tags' field is not included in aggregate results for some reason...
item.pop('tags', None) item.pop("tags", None)
found.append(item) found.append(item)
return {'BaseConfigurationItems': found, 'UnprocessedResourceIdentifiers': not_found} return {
"BaseConfigurationItems": found,
"UnprocessedResourceIdentifiers": not_found,
}
config_backends = {} config_backends = {}
boto3_session = Session() boto3_session = Session()
for region in boto3_session.get_available_regions('config'): for region in boto3_session.get_available_regions("config"):
config_backends[region] = ConfigBackend() config_backends[region] = ConfigBackend()

View File

@ -4,116 +4,150 @@ from .models import config_backends
class ConfigResponse(BaseResponse): class ConfigResponse(BaseResponse):
@property @property
def config_backend(self): def config_backend(self):
return config_backends[self.region] return config_backends[self.region]
def put_configuration_recorder(self): def put_configuration_recorder(self):
self.config_backend.put_configuration_recorder(self._get_param('ConfigurationRecorder')) self.config_backend.put_configuration_recorder(
self._get_param("ConfigurationRecorder")
)
return "" return ""
def put_configuration_aggregator(self): def put_configuration_aggregator(self):
aggregator = self.config_backend.put_configuration_aggregator(json.loads(self.body), self.region) aggregator = self.config_backend.put_configuration_aggregator(
schema = {'ConfigurationAggregator': aggregator} json.loads(self.body), self.region
)
schema = {"ConfigurationAggregator": aggregator}
return json.dumps(schema) return json.dumps(schema)
def describe_configuration_aggregators(self): def describe_configuration_aggregators(self):
aggregators = self.config_backend.describe_configuration_aggregators(self._get_param('ConfigurationAggregatorNames'), aggregators = self.config_backend.describe_configuration_aggregators(
self._get_param('NextToken'), self._get_param("ConfigurationAggregatorNames"),
self._get_param('Limit')) self._get_param("NextToken"),
self._get_param("Limit"),
)
return json.dumps(aggregators) return json.dumps(aggregators)
def delete_configuration_aggregator(self): def delete_configuration_aggregator(self):
self.config_backend.delete_configuration_aggregator(self._get_param('ConfigurationAggregatorName')) self.config_backend.delete_configuration_aggregator(
self._get_param("ConfigurationAggregatorName")
)
return "" return ""
def put_aggregation_authorization(self): def put_aggregation_authorization(self):
agg_auth = self.config_backend.put_aggregation_authorization(self.region, agg_auth = self.config_backend.put_aggregation_authorization(
self._get_param('AuthorizedAccountId'), self.region,
self._get_param('AuthorizedAwsRegion'), self._get_param("AuthorizedAccountId"),
self._get_param('Tags')) self._get_param("AuthorizedAwsRegion"),
schema = {'AggregationAuthorization': agg_auth} self._get_param("Tags"),
)
schema = {"AggregationAuthorization": agg_auth}
return json.dumps(schema) return json.dumps(schema)
def describe_aggregation_authorizations(self): def describe_aggregation_authorizations(self):
authorizations = self.config_backend.describe_aggregation_authorizations(self._get_param('NextToken'), self._get_param('Limit')) authorizations = self.config_backend.describe_aggregation_authorizations(
self._get_param("NextToken"), self._get_param("Limit")
)
return json.dumps(authorizations) return json.dumps(authorizations)
def delete_aggregation_authorization(self): def delete_aggregation_authorization(self):
self.config_backend.delete_aggregation_authorization(self._get_param('AuthorizedAccountId'), self._get_param('AuthorizedAwsRegion')) self.config_backend.delete_aggregation_authorization(
self._get_param("AuthorizedAccountId"),
self._get_param("AuthorizedAwsRegion"),
)
return "" return ""
def describe_configuration_recorders(self): def describe_configuration_recorders(self):
recorders = self.config_backend.describe_configuration_recorders(self._get_param('ConfigurationRecorderNames')) recorders = self.config_backend.describe_configuration_recorders(
schema = {'ConfigurationRecorders': recorders} self._get_param("ConfigurationRecorderNames")
)
schema = {"ConfigurationRecorders": recorders}
return json.dumps(schema) return json.dumps(schema)
def describe_configuration_recorder_status(self): def describe_configuration_recorder_status(self):
recorder_statuses = self.config_backend.describe_configuration_recorder_status( recorder_statuses = self.config_backend.describe_configuration_recorder_status(
self._get_param('ConfigurationRecorderNames')) self._get_param("ConfigurationRecorderNames")
schema = {'ConfigurationRecordersStatus': recorder_statuses} )
schema = {"ConfigurationRecordersStatus": recorder_statuses}
return json.dumps(schema) return json.dumps(schema)
def put_delivery_channel(self): def put_delivery_channel(self):
self.config_backend.put_delivery_channel(self._get_param('DeliveryChannel')) self.config_backend.put_delivery_channel(self._get_param("DeliveryChannel"))
return "" return ""
def describe_delivery_channels(self): def describe_delivery_channels(self):
delivery_channels = self.config_backend.describe_delivery_channels(self._get_param('DeliveryChannelNames')) delivery_channels = self.config_backend.describe_delivery_channels(
schema = {'DeliveryChannels': delivery_channels} self._get_param("DeliveryChannelNames")
)
schema = {"DeliveryChannels": delivery_channels}
return json.dumps(schema) return json.dumps(schema)
def describe_delivery_channel_status(self): def describe_delivery_channel_status(self):
raise NotImplementedError() raise NotImplementedError()
def delete_delivery_channel(self): def delete_delivery_channel(self):
self.config_backend.delete_delivery_channel(self._get_param('DeliveryChannelName')) self.config_backend.delete_delivery_channel(
self._get_param("DeliveryChannelName")
)
return "" return ""
def delete_configuration_recorder(self): def delete_configuration_recorder(self):
self.config_backend.delete_configuration_recorder(self._get_param('ConfigurationRecorderName')) self.config_backend.delete_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return "" return ""
def start_configuration_recorder(self): def start_configuration_recorder(self):
self.config_backend.start_configuration_recorder(self._get_param('ConfigurationRecorderName')) self.config_backend.start_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return "" return ""
def stop_configuration_recorder(self): def stop_configuration_recorder(self):
self.config_backend.stop_configuration_recorder(self._get_param('ConfigurationRecorderName')) self.config_backend.stop_configuration_recorder(
self._get_param("ConfigurationRecorderName")
)
return "" return ""
def list_discovered_resources(self): def list_discovered_resources(self):
schema = self.config_backend.list_discovered_resources(self._get_param('resourceType'), schema = self.config_backend.list_discovered_resources(
self._get_param("resourceType"),
self.region, self.region,
self._get_param('resourceIds'), self._get_param("resourceIds"),
self._get_param('resourceName'), self._get_param("resourceName"),
self._get_param('limit'), self._get_param("limit"),
self._get_param('nextToken')) self._get_param("nextToken"),
)
return json.dumps(schema) return json.dumps(schema)
def list_aggregate_discovered_resources(self): def list_aggregate_discovered_resources(self):
schema = self.config_backend.list_aggregate_discovered_resources(self._get_param('ConfigurationAggregatorName'), schema = self.config_backend.list_aggregate_discovered_resources(
self._get_param('ResourceType'), self._get_param("ConfigurationAggregatorName"),
self._get_param('Filters'), self._get_param("ResourceType"),
self._get_param('Limit'), self._get_param("Filters"),
self._get_param('NextToken')) self._get_param("Limit"),
self._get_param("NextToken"),
)
return json.dumps(schema) return json.dumps(schema)
def get_resource_config_history(self): def get_resource_config_history(self):
schema = self.config_backend.get_resource_config_history(self._get_param('resourceType'), schema = self.config_backend.get_resource_config_history(
self._get_param('resourceId'), self._get_param("resourceType"), self._get_param("resourceId"), self.region
self.region) )
return json.dumps(schema) return json.dumps(schema)
def batch_get_resource_config(self): def batch_get_resource_config(self):
schema = self.config_backend.batch_get_resource_config(self._get_param('resourceKeys'), schema = self.config_backend.batch_get_resource_config(
self.region) self._get_param("resourceKeys"), self.region
)
return json.dumps(schema) return json.dumps(schema)
def batch_get_aggregate_resource_config(self): def batch_get_aggregate_resource_config(self):
schema = self.config_backend.batch_get_aggregate_resource_config(self._get_param('ConfigurationAggregatorName'), schema = self.config_backend.batch_get_aggregate_resource_config(
self._get_param('ResourceIdentifiers')) self._get_param("ConfigurationAggregatorName"),
self._get_param("ResourceIdentifiers"),
)
return json.dumps(schema) return json.dumps(schema)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import ConfigResponse from .responses import ConfigResponse
url_bases = [ url_bases = ["https?://config.(.+).amazonaws.com"]
"https?://config.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": ConfigResponse.dispatch}
'{0}/$': ConfigResponse.dispatch,
}

View File

@ -4,4 +4,6 @@ from .models import BaseModel, BaseBackend, moto_api_backend # noqa
from .responses import ActionAuthenticatorMixin from .responses import ActionAuthenticatorMixin
moto_api_backends = {"global": moto_api_backend} moto_api_backends = {"global": moto_api_backend}
set_initial_no_auth_action_count = ActionAuthenticatorMixin.set_initial_no_auth_action_count set_initial_no_auth_action_count = (
ActionAuthenticatorMixin.set_initial_no_auth_action_count
)

View File

@ -26,7 +26,12 @@ from six import string_types
from moto.iam.models import ACCOUNT_ID, Policy from moto.iam.models import ACCOUNT_ID, Policy
from moto.iam import iam_backend from moto.iam import iam_backend
from moto.core.exceptions import SignatureDoesNotMatchError, AccessDeniedError, InvalidClientTokenIdError, AuthFailureError from moto.core.exceptions import (
SignatureDoesNotMatchError,
AccessDeniedError,
InvalidClientTokenIdError,
AuthFailureError,
)
from moto.s3.exceptions import ( from moto.s3.exceptions import (
BucketAccessDeniedError, BucketAccessDeniedError,
S3AccessDeniedError, S3AccessDeniedError,
@ -35,7 +40,7 @@ from moto.s3.exceptions import (
S3InvalidAccessKeyIdError, S3InvalidAccessKeyIdError,
BucketInvalidAccessKeyIdError, BucketInvalidAccessKeyIdError,
BucketSignatureDoesNotMatchError, BucketSignatureDoesNotMatchError,
S3SignatureDoesNotMatchError S3SignatureDoesNotMatchError,
) )
from moto.sts import sts_backend from moto.sts import sts_backend
@ -50,9 +55,8 @@ def create_access_key(access_key_id, headers):
class IAMUserAccessKey(object): class IAMUserAccessKey(object):
def __init__(self, access_key_id, headers): def __init__(self, access_key_id, headers):
iam_users = iam_backend.list_users('/', None, None) iam_users = iam_backend.list_users("/", None, None)
for iam_user in iam_users: for iam_user in iam_users:
for access_key in iam_user.access_keys: for access_key in iam_user.access_keys:
if access_key.access_key_id == access_key_id: if access_key.access_key_id == access_key_id:
@ -67,8 +71,7 @@ class IAMUserAccessKey(object):
@property @property
def arn(self): def arn(self):
return "arn:aws:iam::{account_id}:user/{iam_user_name}".format( return "arn:aws:iam::{account_id}:user/{iam_user_name}".format(
account_id=ACCOUNT_ID, account_id=ACCOUNT_ID, iam_user_name=self._owner_user_name
iam_user_name=self._owner_user_name
) )
def create_credentials(self): def create_credentials(self):
@ -79,27 +82,34 @@ class IAMUserAccessKey(object):
inline_policy_names = iam_backend.list_user_policies(self._owner_user_name) inline_policy_names = iam_backend.list_user_policies(self._owner_user_name)
for inline_policy_name in inline_policy_names: for inline_policy_name in inline_policy_names:
inline_policy = iam_backend.get_user_policy(self._owner_user_name, inline_policy_name) inline_policy = iam_backend.get_user_policy(
self._owner_user_name, inline_policy_name
)
user_policies.append(inline_policy) user_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_user_policies(self._owner_user_name) attached_policies, _ = iam_backend.list_attached_user_policies(
self._owner_user_name
)
user_policies += attached_policies user_policies += attached_policies
user_groups = iam_backend.get_groups_for_user(self._owner_user_name) user_groups = iam_backend.get_groups_for_user(self._owner_user_name)
for user_group in user_groups: for user_group in user_groups:
inline_group_policy_names = iam_backend.list_group_policies(user_group.name) inline_group_policy_names = iam_backend.list_group_policies(user_group.name)
for inline_group_policy_name in inline_group_policy_names: for inline_group_policy_name in inline_group_policy_names:
inline_user_group_policy = iam_backend.get_group_policy(user_group.name, inline_group_policy_name) inline_user_group_policy = iam_backend.get_group_policy(
user_group.name, inline_group_policy_name
)
user_policies.append(inline_user_group_policy) user_policies.append(inline_user_group_policy)
attached_group_policies, _ = iam_backend.list_attached_group_policies(user_group.name) attached_group_policies, _ = iam_backend.list_attached_group_policies(
user_group.name
)
user_policies += attached_group_policies user_policies += attached_group_policies
return user_policies return user_policies
class AssumedRoleAccessKey(object): class AssumedRoleAccessKey(object):
def __init__(self, access_key_id, headers): def __init__(self, access_key_id, headers):
for assumed_role in sts_backend.assumed_roles: for assumed_role in sts_backend.assumed_roles:
if assumed_role.access_key_id == access_key_id: if assumed_role.access_key_id == access_key_id:
@ -118,28 +128,33 @@ class AssumedRoleAccessKey(object):
return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format(
account_id=ACCOUNT_ID, account_id=ACCOUNT_ID,
role_name=self._owner_role_name, role_name=self._owner_role_name,
session_name=self._session_name session_name=self._session_name,
) )
def create_credentials(self): def create_credentials(self):
return Credentials(self._access_key_id, self._secret_access_key, self._session_token) return Credentials(
self._access_key_id, self._secret_access_key, self._session_token
)
def collect_policies(self): def collect_policies(self):
role_policies = [] role_policies = []
inline_policy_names = iam_backend.list_role_policies(self._owner_role_name) inline_policy_names = iam_backend.list_role_policies(self._owner_role_name)
for inline_policy_name in inline_policy_names: for inline_policy_name in inline_policy_names:
_, inline_policy = iam_backend.get_role_policy(self._owner_role_name, inline_policy_name) _, inline_policy = iam_backend.get_role_policy(
self._owner_role_name, inline_policy_name
)
role_policies.append(inline_policy) role_policies.append(inline_policy)
attached_policies, _ = iam_backend.list_attached_role_policies(self._owner_role_name) attached_policies, _ = iam_backend.list_attached_role_policies(
self._owner_role_name
)
role_policies += attached_policies role_policies += attached_policies
return role_policies return role_policies
class CreateAccessKeyFailure(Exception): class CreateAccessKeyFailure(Exception):
def __init__(self, reason, *args): def __init__(self, reason, *args):
super(CreateAccessKeyFailure, self).__init__(*args) super(CreateAccessKeyFailure, self).__init__(*args)
self.reason = reason self.reason = reason
@ -147,32 +162,54 @@ class CreateAccessKeyFailure(Exception):
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class IAMRequestBase(object): class IAMRequestBase(object):
def __init__(self, method, path, data, headers): def __init__(self, method, path, data, headers):
log.debug("Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format( log.debug(
class_name=self.__class__.__name__, method=method, path=path, data=data, headers=headers)) "Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format(
class_name=self.__class__.__name__,
method=method,
path=path,
data=data,
headers=headers,
)
)
self._method = method self._method = method
self._path = path self._path = path
self._data = data self._data = data
self._headers = headers self._headers = headers
credential_scope = self._get_string_between('Credential=', ',', self._headers['Authorization']) credential_scope = self._get_string_between(
credential_data = credential_scope.split('/') "Credential=", ",", self._headers["Authorization"]
)
credential_data = credential_scope.split("/")
self._region = credential_data[2] self._region = credential_data[2]
self._service = credential_data[3] self._service = credential_data[3]
self._action = self._service + ":" + (self._data["Action"][0] if isinstance(self._data["Action"], list) else self._data["Action"]) self._action = (
self._service
+ ":"
+ (
self._data["Action"][0]
if isinstance(self._data["Action"], list)
else self._data["Action"]
)
)
try: try:
self._access_key = create_access_key(access_key_id=credential_data[0], headers=headers) self._access_key = create_access_key(
access_key_id=credential_data[0], headers=headers
)
except CreateAccessKeyFailure as e: except CreateAccessKeyFailure as e:
self._raise_invalid_access_key(e.reason) self._raise_invalid_access_key(e.reason)
def check_signature(self): def check_signature(self):
original_signature = self._get_string_between('Signature=', ',', self._headers['Authorization']) original_signature = self._get_string_between(
"Signature=", ",", self._headers["Authorization"]
)
calculated_signature = self._calculate_signature() calculated_signature = self._calculate_signature()
if original_signature != calculated_signature: if original_signature != calculated_signature:
self._raise_signature_does_not_match() self._raise_signature_does_not_match()
def check_action_permitted(self): def check_action_permitted(self):
if self._action == 'sts:GetCallerIdentity': # always allowed, even if there's an explicit Deny for it if (
self._action == "sts:GetCallerIdentity"
): # always allowed, even if there's an explicit Deny for it
return True return True
policies = self._access_key.collect_policies() policies = self._access_key.collect_policies()
@ -213,10 +250,14 @@ class IAMRequestBase(object):
return headers return headers
def _create_aws_request(self): def _create_aws_request(self):
signed_headers = self._get_string_between('SignedHeaders=', ',', self._headers['Authorization']).split(';') signed_headers = self._get_string_between(
"SignedHeaders=", ",", self._headers["Authorization"]
).split(";")
headers = self._create_headers_for_aws_request(signed_headers, self._headers) headers = self._create_headers_for_aws_request(signed_headers, self._headers)
request = AWSRequest(method=self._method, url=self._path, data=self._data, headers=headers) request = AWSRequest(
request.context['timestamp'] = headers['X-Amz-Date'] method=self._method, url=self._path, data=self._data, headers=headers
)
request.context["timestamp"] = headers["X-Amz-Date"]
return request return request
@ -234,7 +275,6 @@ class IAMRequestBase(object):
class IAMRequest(IAMRequestBase): class IAMRequest(IAMRequestBase):
def _raise_signature_does_not_match(self): def _raise_signature_does_not_match(self):
if self._service == "ec2": if self._service == "ec2":
raise AuthFailureError() raise AuthFailureError()
@ -251,14 +291,10 @@ class IAMRequest(IAMRequestBase):
return SigV4Auth(credentials, self._service, self._region) return SigV4Auth(credentials, self._service, self._region)
def _raise_access_denied(self): def _raise_access_denied(self):
raise AccessDeniedError( raise AccessDeniedError(user_arn=self._access_key.arn, action=self._action)
user_arn=self._access_key.arn,
action=self._action
)
class S3IAMRequest(IAMRequestBase): class S3IAMRequest(IAMRequestBase):
def _raise_signature_does_not_match(self): def _raise_signature_does_not_match(self):
if "BucketName" in self._data: if "BucketName" in self._data:
raise BucketSignatureDoesNotMatchError(bucket=self._data["BucketName"]) raise BucketSignatureDoesNotMatchError(bucket=self._data["BucketName"])
@ -288,10 +324,13 @@ class S3IAMRequest(IAMRequestBase):
class IAMPolicy(object): class IAMPolicy(object):
def __init__(self, policy): def __init__(self, policy):
if isinstance(policy, Policy): if isinstance(policy, Policy):
default_version = next(policy_version for policy_version in policy.versions if policy_version.is_default) default_version = next(
policy_version
for policy_version in policy.versions
if policy_version.is_default
)
policy_document = default_version.document policy_document = default_version.document
elif isinstance(policy, string_types): elif isinstance(policy, string_types):
policy_document = policy policy_document = policy
@ -321,7 +360,6 @@ class IAMPolicy(object):
class IAMPolicyStatement(object): class IAMPolicyStatement(object):
def __init__(self, statement): def __init__(self, statement):
self._statement = statement self._statement = statement

View File

@ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException
from jinja2 import DictLoader, Environment from jinja2 import DictLoader, Environment
SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?> SINGLE_ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<Error> <Error>
<Code>{{error_type}}</Code> <Code>{{error_type}}</Code>
<Message>{{message}}</Message> <Message>{{message}}</Message>
@ -13,7 +13,7 @@ SINGLE_ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</Error> </Error>
""" """
ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?> ERROR_RESPONSE = """<?xml version="1.0" encoding="UTF-8"?>
<ErrorResponse> <ErrorResponse>
<Errors> <Errors>
<Error> <Error>
@ -26,7 +26,7 @@ ERROR_RESPONSE = u"""<?xml version="1.0" encoding="UTF-8"?>
</ErrorResponse> </ErrorResponse>
""" """
ERROR_JSON_RESPONSE = u"""{ ERROR_JSON_RESPONSE = """{
"message": "{{message}}", "message": "{{message}}",
"__type": "{{error_type}}" "__type": "{{error_type}}"
} }
@ -37,18 +37,19 @@ class RESTError(HTTPException):
code = 400 code = 400
templates = { templates = {
'single_error': SINGLE_ERROR_RESPONSE, "single_error": SINGLE_ERROR_RESPONSE,
'error': ERROR_RESPONSE, "error": ERROR_RESPONSE,
'error_json': ERROR_JSON_RESPONSE, "error_json": ERROR_JSON_RESPONSE,
} }
def __init__(self, error_type, message, template='error', **kwargs): def __init__(self, error_type, message, template="error", **kwargs):
super(RESTError, self).__init__() super(RESTError, self).__init__()
env = Environment(loader=DictLoader(self.templates)) env = Environment(loader=DictLoader(self.templates))
self.error_type = error_type self.error_type = error_type
self.message = message self.message = message
self.description = env.get_template(template).render( self.description = env.get_template(template).render(
error_type=error_type, message=message, **kwargs) error_type=error_type, message=message, **kwargs
)
class DryRunClientError(RESTError): class DryRunClientError(RESTError):
@ -56,12 +57,11 @@ class DryRunClientError(RESTError):
class JsonRESTError(RESTError): class JsonRESTError(RESTError):
def __init__(self, error_type, message, template='error_json', **kwargs): def __init__(self, error_type, message, template="error_json", **kwargs):
super(JsonRESTError, self).__init__( super(JsonRESTError, self).__init__(error_type, message, template, **kwargs)
error_type, message, template, **kwargs)
def get_headers(self, *args, **kwargs): def get_headers(self, *args, **kwargs):
return [('Content-Type', 'application/json')] return [("Content-Type", "application/json")]
def get_body(self, *args, **kwargs): def get_body(self, *args, **kwargs):
return self.description return self.description
@ -72,8 +72,9 @@ class SignatureDoesNotMatchError(RESTError):
def __init__(self): def __init__(self):
super(SignatureDoesNotMatchError, self).__init__( super(SignatureDoesNotMatchError, self).__init__(
'SignatureDoesNotMatch', "SignatureDoesNotMatch",
"The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.") "The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.",
)
class InvalidClientTokenIdError(RESTError): class InvalidClientTokenIdError(RESTError):
@ -81,8 +82,9 @@ class InvalidClientTokenIdError(RESTError):
def __init__(self): def __init__(self):
super(InvalidClientTokenIdError, self).__init__( super(InvalidClientTokenIdError, self).__init__(
'InvalidClientTokenId', "InvalidClientTokenId",
"The security token included in the request is invalid.") "The security token included in the request is invalid.",
)
class AccessDeniedError(RESTError): class AccessDeniedError(RESTError):
@ -90,11 +92,11 @@ class AccessDeniedError(RESTError):
def __init__(self, user_arn, action): def __init__(self, user_arn, action):
super(AccessDeniedError, self).__init__( super(AccessDeniedError, self).__init__(
'AccessDenied', "AccessDenied",
"User: {user_arn} is not authorized to perform: {operation}".format( "User: {user_arn} is not authorized to perform: {operation}".format(
user_arn=user_arn, user_arn=user_arn, operation=action
operation=action ),
)) )
class AuthFailureError(RESTError): class AuthFailureError(RESTError):
@ -102,13 +104,17 @@ class AuthFailureError(RESTError):
def __init__(self): def __init__(self):
super(AuthFailureError, self).__init__( super(AuthFailureError, self).__init__(
'AuthFailure', "AuthFailure",
"AWS was not able to validate the provided access credentials") "AWS was not able to validate the provided access credentials",
)
class InvalidNextTokenException(JsonRESTError): class InvalidNextTokenException(JsonRESTError):
"""For AWS Config resource listing. This will be used by many different resource types, and so it is in moto.core.""" """For AWS Config resource listing. This will be used by many different resource types, and so it is in moto.core."""
code = 400 code = 400
def __init__(self): def __init__(self):
super(InvalidNextTokenException, self).__init__('InvalidNextTokenException', 'The nextToken provided is invalid') super(InvalidNextTokenException, self).__init__(
"InvalidNextTokenException", "The nextToken provided is invalid"
)

View File

@ -31,15 +31,19 @@ class BaseMockAWS(object):
self.backends_for_urls = {} self.backends_for_urls = {}
from moto.backends import BACKENDS from moto.backends import BACKENDS
default_backends = { default_backends = {
"instance_metadata": BACKENDS['instance_metadata']['global'], "instance_metadata": BACKENDS["instance_metadata"]["global"],
"moto_api": BACKENDS['moto_api']['global'], "moto_api": BACKENDS["moto_api"]["global"],
} }
self.backends_for_urls.update(self.backends) self.backends_for_urls.update(self.backends)
self.backends_for_urls.update(default_backends) self.backends_for_urls.update(default_backends)
# "Mock" the AWS credentials as they can't be mocked in Botocore currently # "Mock" the AWS credentials as they can't be mocked in Botocore currently
FAKE_KEYS = {"AWS_ACCESS_KEY_ID": "foobar_key", "AWS_SECRET_ACCESS_KEY": "foobar_secret"} FAKE_KEYS = {
"AWS_ACCESS_KEY_ID": "foobar_key",
"AWS_SECRET_ACCESS_KEY": "foobar_secret",
}
self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS) self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS)
if self.__class__.nested_count == 0: if self.__class__.nested_count == 0:
@ -72,7 +76,7 @@ class BaseMockAWS(object):
self.__class__.nested_count -= 1 self.__class__.nested_count -= 1
if self.__class__.nested_count < 0: if self.__class__.nested_count < 0:
raise RuntimeError('Called stop() before start().') raise RuntimeError("Called stop() before start().")
if self.__class__.nested_count == 0: if self.__class__.nested_count == 0:
self.disable_patching() self.disable_patching()
@ -85,6 +89,7 @@ class BaseMockAWS(object):
finally: finally:
self.stop() self.stop()
return result return result
functools.update_wrapper(wrapper, func) functools.update_wrapper(wrapper, func)
wrapper.__wrapped__ = func wrapper.__wrapped__ = func
return wrapper return wrapper
@ -122,7 +127,6 @@ class BaseMockAWS(object):
class HttprettyMockAWS(BaseMockAWS): class HttprettyMockAWS(BaseMockAWS):
def reset(self): def reset(self):
HTTPretty.reset() HTTPretty.reset()
@ -144,18 +148,26 @@ class HttprettyMockAWS(BaseMockAWS):
HTTPretty.reset() HTTPretty.reset()
RESPONSES_METHODS = [responses.GET, responses.DELETE, responses.HEAD, RESPONSES_METHODS = [
responses.OPTIONS, responses.PATCH, responses.POST, responses.PUT] responses.GET,
responses.DELETE,
responses.HEAD,
responses.OPTIONS,
responses.PATCH,
responses.POST,
responses.PUT,
]
class CallbackResponse(responses.CallbackResponse): class CallbackResponse(responses.CallbackResponse):
''' """
Need to subclass so we can change a couple things Need to subclass so we can change a couple things
''' """
def get_response(self, request): def get_response(self, request):
''' """
Need to override this so we can pass decode_content=False Need to override this so we can pass decode_content=False
''' """
headers = self.get_headers() headers = self.get_headers()
result = self.callback(request) result = self.callback(request)
@ -177,17 +189,17 @@ class CallbackResponse(responses.CallbackResponse):
) )
def _url_matches(self, url, other, match_querystring=False): def _url_matches(self, url, other, match_querystring=False):
''' """
Need to override this so we can fix querystrings breaking regex matching Need to override this so we can fix querystrings breaking regex matching
''' """
if not match_querystring: if not match_querystring:
other = other.split('?', 1)[0] other = other.split("?", 1)[0]
if responses._is_string(url): if responses._is_string(url):
if responses._has_unicode(url): if responses._has_unicode(url):
url = responses._clean_unicode(url) url = responses._clean_unicode(url)
if not isinstance(other, six.text_type): if not isinstance(other, six.text_type):
other = other.encode('ascii').decode('utf8') other = other.encode("ascii").decode("utf8")
return self._url_matches_strict(url, other) return self._url_matches_strict(url, other)
elif isinstance(url, responses.Pattern) and url.match(other): elif isinstance(url, responses.Pattern) and url.match(other):
return True return True
@ -195,22 +207,23 @@ class CallbackResponse(responses.CallbackResponse):
return False return False
botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send') botocore_mock = responses.RequestsMock(
assert_all_requests_are_fired=False,
target="botocore.vendored.requests.adapters.HTTPAdapter.send",
)
responses_mock = responses._default_mock responses_mock = responses._default_mock
# Add passthrough to allow any other requests to work # Add passthrough to allow any other requests to work
# Since this uses .startswith, it applies to http and https requests. # Since this uses .startswith, it applies to http and https requests.
responses_mock.add_passthru("http") responses_mock.add_passthru("http")
BOTOCORE_HTTP_METHODS = [ BOTOCORE_HTTP_METHODS = ["GET", "DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"]
'GET', 'DELETE', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT'
]
class MockRawResponse(BytesIO): class MockRawResponse(BytesIO):
def __init__(self, input): def __init__(self, input):
if isinstance(input, six.text_type): if isinstance(input, six.text_type):
input = input.encode('utf-8') input = input.encode("utf-8")
super(MockRawResponse, self).__init__(input) super(MockRawResponse, self).__init__(input)
def stream(self, **kwargs): def stream(self, **kwargs):
@ -241,7 +254,7 @@ class BotocoreStubber(object):
found_index = None found_index = None
matchers = self.methods.get(request.method) matchers = self.methods.get(request.method)
base_url = request.url.split('?', 1)[0] base_url = request.url.split("?", 1)[0]
for i, (pattern, callback) in enumerate(matchers): for i, (pattern, callback) in enumerate(matchers):
if pattern.match(base_url): if pattern.match(base_url):
if found_index is None: if found_index is None:
@ -254,8 +267,10 @@ class BotocoreStubber(object):
if response_callback is not None: if response_callback is not None:
for header, value in request.headers.items(): for header, value in request.headers.items():
if isinstance(value, six.binary_type): if isinstance(value, six.binary_type):
request.headers[header] = value.decode('utf-8') request.headers[header] = value.decode("utf-8")
status, headers, body = response_callback(request, request.url, request.headers) status, headers, body = response_callback(
request, request.url, request.headers
)
body = MockRawResponse(body) body = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, body) response = AWSResponse(request.url, status, headers, body)
@ -263,7 +278,7 @@ class BotocoreStubber(object):
botocore_stubber = BotocoreStubber() botocore_stubber = BotocoreStubber()
BUILTIN_HANDLERS.append(('before-send', botocore_stubber)) BUILTIN_HANDLERS.append(("before-send", botocore_stubber))
def not_implemented_callback(request): def not_implemented_callback(request):
@ -287,7 +302,9 @@ class BotocoreEventMockAWS(BaseMockAWS):
pattern = re.compile(key) pattern = re.compile(key)
botocore_stubber.register_response(method, pattern, value) botocore_stubber.register_response(method, pattern, value)
if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'): if not hasattr(responses_mock, "_patcher") or not hasattr(
responses_mock._patcher, "target"
):
responses_mock.start() responses_mock.start()
for method in RESPONSES_METHODS: for method in RESPONSES_METHODS:
@ -336,9 +353,9 @@ MockAWS = BotocoreEventMockAWS
class ServerModeMockAWS(BaseMockAWS): class ServerModeMockAWS(BaseMockAWS):
def reset(self): def reset(self):
import requests import requests
requests.post("http://localhost:5000/moto-api/reset") requests.post("http://localhost:5000/moto-api/reset")
def enable_patching(self): def enable_patching(self):
@ -350,13 +367,13 @@ class ServerModeMockAWS(BaseMockAWS):
import mock import mock
def fake_boto3_client(*args, **kwargs): def fake_boto3_client(*args, **kwargs):
if 'endpoint_url' not in kwargs: if "endpoint_url" not in kwargs:
kwargs['endpoint_url'] = "http://localhost:5000" kwargs["endpoint_url"] = "http://localhost:5000"
return real_boto3_client(*args, **kwargs) return real_boto3_client(*args, **kwargs)
def fake_boto3_resource(*args, **kwargs): def fake_boto3_resource(*args, **kwargs):
if 'endpoint_url' not in kwargs: if "endpoint_url" not in kwargs:
kwargs['endpoint_url'] = "http://localhost:5000" kwargs["endpoint_url"] = "http://localhost:5000"
return real_boto3_resource(*args, **kwargs) return real_boto3_resource(*args, **kwargs)
def fake_httplib_send_output(self, message_body=None, *args, **kwargs): def fake_httplib_send_output(self, message_body=None, *args, **kwargs):
@ -364,7 +381,7 @@ class ServerModeMockAWS(BaseMockAWS):
bytes_buffer = [] bytes_buffer = []
for chunk in mixed_buffer: for chunk in mixed_buffer:
if isinstance(chunk, six.text_type): if isinstance(chunk, six.text_type):
bytes_buffer.append(chunk.encode('utf-8')) bytes_buffer.append(chunk.encode("utf-8"))
else: else:
bytes_buffer.append(chunk) bytes_buffer.append(chunk)
msg = b"\r\n".join(bytes_buffer) msg = b"\r\n".join(bytes_buffer)
@ -385,10 +402,12 @@ class ServerModeMockAWS(BaseMockAWS):
if message_body is not None: if message_body is not None:
self.send(message_body) self.send(message_body)
self._client_patcher = mock.patch('boto3.client', fake_boto3_client) self._client_patcher = mock.patch("boto3.client", fake_boto3_client)
self._resource_patcher = mock.patch('boto3.resource', fake_boto3_resource) self._resource_patcher = mock.patch("boto3.resource", fake_boto3_resource)
if six.PY2: if six.PY2:
self._httplib_patcher = mock.patch('httplib.HTTPConnection._send_output', fake_httplib_send_output) self._httplib_patcher = mock.patch(
"httplib.HTTPConnection._send_output", fake_httplib_send_output
)
self._client_patcher.start() self._client_patcher.start()
self._resource_patcher.start() self._resource_patcher.start()
@ -404,7 +423,6 @@ class ServerModeMockAWS(BaseMockAWS):
class Model(type): class Model(type):
def __new__(self, clsname, bases, namespace): def __new__(self, clsname, bases, namespace):
cls = super(Model, self).__new__(self, clsname, bases, namespace) cls = super(Model, self).__new__(self, clsname, bases, namespace)
cls.__models__ = {} cls.__models__ = {}
@ -419,9 +437,11 @@ class Model(type):
@staticmethod @staticmethod
def prop(model_name): def prop(model_name):
""" decorator to mark a class method as returning model values """ """ decorator to mark a class method as returning model values """
def dec(f): def dec(f):
f.__returns_model__ = model_name f.__returns_model__ = model_name
return f return f
return dec return dec
@ -431,7 +451,7 @@ model_data = defaultdict(dict)
class InstanceTrackerMeta(type): class InstanceTrackerMeta(type):
def __new__(meta, name, bases, dct): def __new__(meta, name, bases, dct):
cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct) cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct)
if name == 'BaseModel': if name == "BaseModel":
return cls return cls
service = cls.__module__.split(".")[1] service = cls.__module__.split(".")[1]
@ -450,7 +470,6 @@ class BaseModel(object):
class BaseBackend(object): class BaseBackend(object):
def _reset_model_refs(self): def _reset_model_refs(self):
# Remove all references to the models stored # Remove all references to the models stored
for service, models in model_data.items(): for service, models in model_data.items():
@ -466,8 +485,9 @@ class BaseBackend(object):
def _url_module(self): def _url_module(self):
backend_module = self.__class__.__module__ backend_module = self.__class__.__module__
backend_urls_module_name = backend_module.replace("models", "urls") backend_urls_module_name = backend_module.replace("models", "urls")
backend_urls_module = __import__(backend_urls_module_name, fromlist=[ backend_urls_module = __import__(
'url_bases', 'url_paths']) backend_urls_module_name, fromlist=["url_bases", "url_paths"]
)
return backend_urls_module return backend_urls_module
@property @property
@ -523,9 +543,9 @@ class BaseBackend(object):
def decorator(self, func=None): def decorator(self, func=None):
if settings.TEST_SERVER_MODE: if settings.TEST_SERVER_MODE:
mocked_backend = ServerModeMockAWS({'global': self}) mocked_backend = ServerModeMockAWS({"global": self})
else: else:
mocked_backend = MockAWS({'global': self}) mocked_backend = MockAWS({"global": self})
if func: if func:
return mocked_backend(func) return mocked_backend(func)
@ -534,9 +554,9 @@ class BaseBackend(object):
def deprecated_decorator(self, func=None): def deprecated_decorator(self, func=None):
if func: if func:
return HttprettyMockAWS({'global': self})(func) return HttprettyMockAWS({"global": self})(func)
else: else:
return HttprettyMockAWS({'global': self}) return HttprettyMockAWS({"global": self})
# def list_config_service_resources(self, resource_ids, resource_name, limit, next_token): # def list_config_service_resources(self, resource_ids, resource_name, limit, next_token):
# """For AWS Config. This will list all of the resources of the given type and optional resource name and region""" # """For AWS Config. This will list all of the resources of the given type and optional resource name and region"""
@ -544,12 +564,19 @@ class BaseBackend(object):
class ConfigQueryModel(object): class ConfigQueryModel(object):
def __init__(self, backends): def __init__(self, backends):
"""Inits based on the resource type's backends (1 for each region if applicable)""" """Inits based on the resource type's backends (1 for each region if applicable)"""
self.backends = backends self.backends = backends
def list_config_service_resources(self, resource_ids, resource_name, limit, next_token, backend_region=None, resource_region=None): def list_config_service_resources(
self,
resource_ids,
resource_name,
limit,
next_token,
backend_region=None,
resource_region=None,
):
"""For AWS Config. This will list all of the resources of the given type and optional resource name and region. """For AWS Config. This will list all of the resources of the given type and optional resource name and region.
This supports both aggregated and non-aggregated listing. The following notes the difference: This supports both aggregated and non-aggregated listing. The following notes the difference:
@ -593,7 +620,9 @@ class ConfigQueryModel(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def get_config_resource(self, resource_id, resource_name=None, backend_region=None, resource_region=None): def get_config_resource(
self, resource_id, resource_name=None, backend_region=None, resource_region=None
):
"""For AWS Config. This will query the backend for the specific resource type configuration. """For AWS Config. This will query the backend for the specific resource type configuration.
This supports both aggregated, and non-aggregated fetching -- for batched fetching -- the Config batching requests This supports both aggregated, and non-aggregated fetching -- for batched fetching -- the Config batching requests
@ -644,9 +673,9 @@ class deprecated_base_decorator(base_decorator):
class MotoAPIBackend(BaseBackend): class MotoAPIBackend(BaseBackend):
def reset(self): def reset(self):
from moto.backends import BACKENDS from moto.backends import BACKENDS
for name, backends in BACKENDS.items(): for name, backends in BACKENDS.items():
if name == "moto_api": if name == "moto_api":
continue continue

View File

@ -40,7 +40,7 @@ def _decode_dict(d):
newkey = [] newkey = []
for k in key: for k in key:
if isinstance(k, six.binary_type): if isinstance(k, six.binary_type):
newkey.append(k.decode('utf-8')) newkey.append(k.decode("utf-8"))
else: else:
newkey.append(k) newkey.append(k)
else: else:
@ -52,7 +52,7 @@ def _decode_dict(d):
newvalue = [] newvalue = []
for v in value: for v in value:
if isinstance(v, six.binary_type): if isinstance(v, six.binary_type):
newvalue.append(v.decode('utf-8')) newvalue.append(v.decode("utf-8"))
else: else:
newvalue.append(v) newvalue.append(v)
else: else:
@ -90,7 +90,8 @@ class _TemplateEnvironmentMixin(object):
super(_TemplateEnvironmentMixin, self).__init__() super(_TemplateEnvironmentMixin, self).__init__()
self.loader = DynamicDictLoader({}) self.loader = DynamicDictLoader({})
self.environment = Environment( self.environment = Environment(
loader=self.loader, autoescape=self.should_autoescape) loader=self.loader, autoescape=self.should_autoescape
)
@property @property
def should_autoescape(self): def should_autoescape(self):
@ -104,13 +105,15 @@ class _TemplateEnvironmentMixin(object):
template_id = id(source) template_id = id(source)
if not self.contains_template(template_id): if not self.contains_template(template_id):
collapsed = re.sub( collapsed = re.sub(
self.RIGHT_PATTERN, self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source)
">",
re.sub(self.LEFT_PATTERN, "<", source)
) )
self.loader.update({template_id: collapsed}) self.loader.update({template_id: collapsed})
self.environment = Environment(loader=self.loader, autoescape=self.should_autoescape, trim_blocks=True, self.environment = Environment(
lstrip_blocks=True) loader=self.loader,
autoescape=self.should_autoescape,
trim_blocks=True,
lstrip_blocks=True,
)
return self.environment.get_template(template_id) return self.environment.get_template(template_id)
@ -119,8 +122,13 @@ class ActionAuthenticatorMixin(object):
request_count = 0 request_count = 0
def _authenticate_and_authorize_action(self, iam_request_cls): def _authenticate_and_authorize_action(self, iam_request_cls):
if ActionAuthenticatorMixin.request_count >= settings.INITIAL_NO_AUTH_ACTION_COUNT: if (
iam_request = iam_request_cls(method=self.method, path=self.path, data=self.data, headers=self.headers) ActionAuthenticatorMixin.request_count
>= settings.INITIAL_NO_AUTH_ACTION_COUNT
):
iam_request = iam_request_cls(
method=self.method, path=self.path, data=self.data, headers=self.headers
)
iam_request.check_signature() iam_request.check_signature()
iam_request.check_action_permitted() iam_request.check_action_permitted()
else: else:
@ -137,10 +145,17 @@ class ActionAuthenticatorMixin(object):
def decorator(function): def decorator(function):
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if settings.TEST_SERVER_MODE: if settings.TEST_SERVER_MODE:
response = requests.post("http://localhost:5000/moto-api/reset-auth", data=str(initial_no_auth_action_count).encode()) response = requests.post(
original_initial_no_auth_action_count = response.json()['PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT'] "http://localhost:5000/moto-api/reset-auth",
data=str(initial_no_auth_action_count).encode(),
)
original_initial_no_auth_action_count = response.json()[
"PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT"
]
else: else:
original_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT original_initial_no_auth_action_count = (
settings.INITIAL_NO_AUTH_ACTION_COUNT
)
original_request_count = ActionAuthenticatorMixin.request_count original_request_count = ActionAuthenticatorMixin.request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = initial_no_auth_action_count settings.INITIAL_NO_AUTH_ACTION_COUNT = initial_no_auth_action_count
ActionAuthenticatorMixin.request_count = 0 ActionAuthenticatorMixin.request_count = 0
@ -148,10 +163,15 @@ class ActionAuthenticatorMixin(object):
result = function(*args, **kwargs) result = function(*args, **kwargs)
finally: finally:
if settings.TEST_SERVER_MODE: if settings.TEST_SERVER_MODE:
requests.post("http://localhost:5000/moto-api/reset-auth", data=str(original_initial_no_auth_action_count).encode()) requests.post(
"http://localhost:5000/moto-api/reset-auth",
data=str(original_initial_no_auth_action_count).encode(),
)
else: else:
ActionAuthenticatorMixin.request_count = original_request_count ActionAuthenticatorMixin.request_count = original_request_count
settings.INITIAL_NO_AUTH_ACTION_COUNT = original_initial_no_auth_action_count settings.INITIAL_NO_AUTH_ACTION_COUNT = (
original_initial_no_auth_action_count
)
return result return result
functools.update_wrapper(wrapper, function) functools.update_wrapper(wrapper, function)
@ -163,11 +183,13 @@ class ActionAuthenticatorMixin(object):
class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
default_region = 'us-east-1' default_region = "us-east-1"
# to extract region, use [^.] # to extract region, use [^.]
region_regex = re.compile(r'\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com') region_regex = re.compile(r"\.(?P<region>[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com")
param_list_regex = re.compile(r'(.*)\.(\d+)\.') param_list_regex = re.compile(r"(.*)\.(\d+)\.")
access_key_regex = re.compile(r'AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]') access_key_regex = re.compile(
r"AWS.*(?P<access_key>(?<![A-Z0-9])[A-Z0-9]{20}(?![A-Z0-9]))[:/]"
)
aws_service_spec = None aws_service_spec = None
@classmethod @classmethod
@ -176,7 +198,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def setup_class(self, request, full_url, headers): def setup_class(self, request, full_url, headers):
querystring = {} querystring = {}
if hasattr(request, 'body'): if hasattr(request, "body"):
# Boto # Boto
self.body = request.body self.body = request.body
else: else:
@ -189,24 +211,29 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
querystring = {} querystring = {}
for key, value in request.form.items(): for key, value in request.form.items():
querystring[key] = [value, ] querystring[key] = [value]
raw_body = self.body raw_body = self.body
if isinstance(self.body, six.binary_type): if isinstance(self.body, six.binary_type):
self.body = self.body.decode('utf-8') self.body = self.body.decode("utf-8")
if not querystring: if not querystring:
querystring.update( querystring.update(
parse_qs(urlparse(full_url).query, keep_blank_values=True)) parse_qs(urlparse(full_url).query, keep_blank_values=True)
)
if not querystring: if not querystring:
if 'json' in request.headers.get('content-type', []) and self.aws_service_spec: if (
"json" in request.headers.get("content-type", [])
and self.aws_service_spec
):
decoded = json.loads(self.body) decoded = json.loads(self.body)
target = request.headers.get( target = request.headers.get("x-amz-target") or request.headers.get(
'x-amz-target') or request.headers.get('X-Amz-Target') "X-Amz-Target"
service, method = target.split('.') )
service, method = target.split(".")
input_spec = self.aws_service_spec.input_spec(method) input_spec = self.aws_service_spec.input_spec(method)
flat = flatten_json_request_body('', decoded, input_spec) flat = flatten_json_request_body("", decoded, input_spec)
for key, value in flat.items(): for key, value in flat.items():
querystring[key] = [value] querystring[key] = [value]
elif self.body: elif self.body:
@ -231,17 +258,19 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
self.uri_match = None self.uri_match = None
self.headers = request.headers self.headers = request.headers
if 'host' not in self.headers: if "host" not in self.headers:
self.headers['host'] = urlparse(full_url).netloc self.headers["host"] = urlparse(full_url).netloc
self.response_headers = {"server": "amazon.com"} self.response_headers = {"server": "amazon.com"}
def get_region_from_url(self, request, full_url): def get_region_from_url(self, request, full_url):
match = self.region_regex.search(full_url) match = self.region_regex.search(full_url)
if match: if match:
region = match.group(1) region = match.group(1)
elif 'Authorization' in request.headers and 'AWS4' in request.headers['Authorization']: elif (
region = request.headers['Authorization'].split(",")[ "Authorization" in request.headers
0].split("/")[2] and "AWS4" in request.headers["Authorization"]
):
region = request.headers["Authorization"].split(",")[0].split("/")[2]
else: else:
region = self.default_region region = self.default_region
return region return region
@ -250,16 +279,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
""" """
Returns the access key id used in this request as the current user id Returns the access key id used in this request as the current user id
""" """
if 'Authorization' in self.headers: if "Authorization" in self.headers:
match = self.access_key_regex.search(self.headers['Authorization']) match = self.access_key_regex.search(self.headers["Authorization"])
if match: if match:
return match.group(1) return match.group(1)
if self.querystring.get('AWSAccessKeyId'): if self.querystring.get("AWSAccessKeyId"):
return self.querystring.get('AWSAccessKeyId') return self.querystring.get("AWSAccessKeyId")
else: else:
# Should we raise an unauthorized exception instead? # Should we raise an unauthorized exception instead?
return '111122223333' return "111122223333"
def _dispatch(self, request, full_url, headers): def _dispatch(self, request, full_url, headers):
self.setup_class(request, full_url, headers) self.setup_class(request, full_url, headers)
@ -274,17 +303,22 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
-> '^/cars/.*/drivers/.*/drive$' -> '^/cars/.*/drivers/.*/drive$'
""" """
def _convert(elem, is_last):
if not re.match('^{.*}$', elem):
return elem
name = elem.replace('{', '').replace('}', '')
if is_last:
return '(?P<%s>[^/]*)' % name
return '(?P<%s>.*)' % name
elems = uri.split('/') def _convert(elem, is_last):
if not re.match("^{.*}$", elem):
return elem
name = elem.replace("{", "").replace("}", "")
if is_last:
return "(?P<%s>[^/]*)" % name
return "(?P<%s>.*)" % name
elems = uri.split("/")
num_elems = len(elems) num_elems = len(elems)
regexp = '^{}$'.format('/'.join([_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)])) regexp = "^{}$".format(
"/".join(
[_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)]
)
)
return regexp return regexp
def _get_action_from_method_and_request_uri(self, method, request_uri): def _get_action_from_method_and_request_uri(self, method, request_uri):
@ -295,19 +329,19 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
# service response class should have 'SERVICE_NAME' class member, # service response class should have 'SERVICE_NAME' class member,
# if you want to get action from method and url # if you want to get action from method and url
if not hasattr(self, 'SERVICE_NAME'): if not hasattr(self, "SERVICE_NAME"):
return None return None
service = self.SERVICE_NAME service = self.SERVICE_NAME
conn = boto3.client(service, region_name=self.region) conn = boto3.client(service, region_name=self.region)
# make cache if it does not exist yet # make cache if it does not exist yet
if not hasattr(self, 'method_urls'): if not hasattr(self, "method_urls"):
self.method_urls = defaultdict(lambda: defaultdict(str)) self.method_urls = defaultdict(lambda: defaultdict(str))
op_names = conn._service_model.operation_names op_names = conn._service_model.operation_names
for op_name in op_names: for op_name in op_names:
op_model = conn._service_model.operation_model(op_name) op_model = conn._service_model.operation_model(op_name)
_method = op_model.http['method'] _method = op_model.http["method"]
uri_regexp = self.uri_to_regexp(op_model.http['requestUri']) uri_regexp = self.uri_to_regexp(op_model.http["requestUri"])
self.method_urls[_method][uri_regexp] = op_model.name self.method_urls[_method][uri_regexp] = op_model.name
regexp_and_names = self.method_urls[method] regexp_and_names = self.method_urls[method]
for regexp, name in regexp_and_names.items(): for regexp, name in regexp_and_names.items():
@ -318,11 +352,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return None return None
def _get_action(self): def _get_action(self):
action = self.querystring.get('Action', [""])[0] action = self.querystring.get("Action", [""])[0]
if not action: # Some services use a header for the action if not action: # Some services use a header for the action
# Headers are case-insensitive. Probably a better way to do this. # Headers are case-insensitive. Probably a better way to do this.
match = self.headers.get( match = self.headers.get("x-amz-target") or self.headers.get("X-Amz-Target")
'x-amz-target') or self.headers.get('X-Amz-Target')
if match: if match:
action = match.split(".")[-1] action = match.split(".")[-1]
# get action from method and uri # get action from method and uri
@ -354,10 +387,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return self._send_response(headers, response) return self._send_response(headers, response)
if not action: if not action:
return 404, headers, '' return 404, headers, ""
raise NotImplementedError( raise NotImplementedError(
"The {0} action has not been implemented".format(action)) "The {0} action has not been implemented".format(action)
)
@staticmethod @staticmethod
def _send_response(headers, response): def _send_response(headers, response):
@ -365,11 +399,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
body, new_headers = response body, new_headers = response
else: else:
status, new_headers, body = response status, new_headers, body = response
status = new_headers.get('status', 200) status = new_headers.get("status", 200)
headers.update(new_headers) headers.update(new_headers)
# Cast status to string # Cast status to string
if "status" in headers: if "status" in headers:
headers['status'] = str(headers['status']) headers["status"] = str(headers["status"])
return status, headers, body return status, headers, body
def _get_param(self, param_name, if_none=None): def _get_param(self, param_name, if_none=None):
@ -403,9 +437,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
def _get_bool_param(self, param_name, if_none=None): def _get_bool_param(self, param_name, if_none=None):
val = self._get_param(param_name) val = self._get_param(param_name)
if val is not None: if val is not None:
if val.lower() == 'true': if val.lower() == "true":
return True return True
elif val.lower() == 'false': elif val.lower() == "false":
return False return False
return if_none return if_none
@ -423,11 +457,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if is_tracked(name) or not name.startswith(param_prefix): if is_tracked(name) or not name.startswith(param_prefix):
continue continue
if len(name) > len(param_prefix) and \ if len(name) > len(param_prefix) and not name[
not name[len(param_prefix):].startswith('.'): len(param_prefix) :
].startswith("."):
continue continue
match = self.param_list_regex.search(name[len(param_prefix):]) if len(name) > len(param_prefix) else None match = (
self.param_list_regex.search(name[len(param_prefix) :])
if len(name) > len(param_prefix)
else None
)
if match: if match:
prefix = param_prefix + match.group(1) prefix = param_prefix + match.group(1)
value = self._get_multi_param(prefix) value = self._get_multi_param(prefix)
@ -442,7 +481,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
if len(value_dict) > 1: if len(value_dict) > 1:
# strip off period prefix # strip off period prefix
value_dict = {name[len(param_prefix) + 1:]: value for name, value in value_dict.items()} value_dict = {
name[len(param_prefix) + 1 :]: value
for name, value in value_dict.items()
}
else: else:
value_dict = list(value_dict.values())[0] value_dict = list(value_dict.values())[0]
@ -461,7 +503,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
index = 1 index = 1
while True: while True:
value_dict = self._get_multi_param_helper(prefix + str(index)) value_dict = self._get_multi_param_helper(prefix + str(index))
if not value_dict and value_dict != '': if not value_dict and value_dict != "":
break break
values.append(value_dict) values.append(value_dict)
@ -486,8 +528,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
params = {} params = {}
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if key.startswith(param_prefix): if key.startswith(param_prefix):
params[camelcase_to_underscores( params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[
key.replace(param_prefix, ""))] = value[0] 0
]
return params return params
def _get_list_prefix(self, param_prefix): def _get_list_prefix(self, param_prefix):
@ -520,19 +563,20 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
new_items = {} new_items = {}
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if key.startswith(index_prefix): if key.startswith(index_prefix):
new_items[camelcase_to_underscores( new_items[
key.replace(index_prefix, ""))] = value[0] camelcase_to_underscores(key.replace(index_prefix, ""))
] = value[0]
if not new_items: if not new_items:
break break
results.append(new_items) results.append(new_items)
param_index += 1 param_index += 1
return results return results
def _get_map_prefix(self, param_prefix, key_end='.key', value_end='.value'): def _get_map_prefix(self, param_prefix, key_end=".key", value_end=".value"):
results = {} results = {}
param_index = 1 param_index = 1
while 1: while 1:
index_prefix = '{0}.{1}.'.format(param_prefix, param_index) index_prefix = "{0}.{1}.".format(param_prefix, param_index)
k, v = None, None k, v = None, None
for key, value in self.querystring.items(): for key, value in self.querystring.items():
@ -559,8 +603,8 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
param_index = 1 param_index = 1
while True: while True:
key_name = 'tag.{0}._key'.format(param_index) key_name = "tag.{0}._key".format(param_index)
value_name = 'tag.{0}._value'.format(param_index) value_name = "tag.{0}._value".format(param_index)
try: try:
results[resource_type][tag[key_name]] = tag[value_name] results[resource_type][tag[key_name]] = tag[value_name]
@ -570,7 +614,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
return results return results
def _get_object_map(self, prefix, name='Name', value='Value'): def _get_object_map(self, prefix, name="Name", value="Value"):
""" """
Given a query dict like Given a query dict like
{ {
@ -598,15 +642,14 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
index = 1 index = 1
while True: while True:
# Loop through looking for keys representing object name # Loop through looking for keys representing object name
name_key = '{0}.{1}.{2}'.format(prefix, index, name) name_key = "{0}.{1}.{2}".format(prefix, index, name)
obj_name = self.querystring.get(name_key) obj_name = self.querystring.get(name_key)
if not obj_name: if not obj_name:
# Found all keys # Found all keys
break break
obj = {} obj = {}
value_key_prefix = '{0}.{1}.{2}.'.format( value_key_prefix = "{0}.{1}.{2}.".format(prefix, index, value)
prefix, index, value)
for k, v in self.querystring.items(): for k, v in self.querystring.items():
if k.startswith(value_key_prefix): if k.startswith(value_key_prefix):
_, value_key = k.split(value_key_prefix, 1) _, value_key = k.split(value_key_prefix, 1)
@ -620,31 +663,46 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin):
@property @property
def request_json(self): def request_json(self):
return 'JSON' in self.querystring.get('ContentType', []) return "JSON" in self.querystring.get("ContentType", [])
def is_not_dryrun(self, action): def is_not_dryrun(self, action):
if 'true' in self.querystring.get('DryRun', ['false']): if "true" in self.querystring.get("DryRun", ["false"]):
message = 'An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set' % action message = (
raise DryRunClientError( "An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set"
error_type="DryRunOperation", message=message) % action
)
raise DryRunClientError(error_type="DryRunOperation", message=message)
return True return True
class MotoAPIResponse(BaseResponse): class MotoAPIResponse(BaseResponse):
def reset_response(self, request, full_url, headers): def reset_response(self, request, full_url, headers):
if request.method == "POST": if request.method == "POST":
from .models import moto_api_backend from .models import moto_api_backend
moto_api_backend.reset() moto_api_backend.reset()
return 200, {}, json.dumps({"status": "ok"}) return 200, {}, json.dumps({"status": "ok"})
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"}) return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"})
def reset_auth_response(self, request, full_url, headers): def reset_auth_response(self, request, full_url, headers):
if request.method == "POST": if request.method == "POST":
previous_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT previous_initial_no_auth_action_count = (
settings.INITIAL_NO_AUTH_ACTION_COUNT
)
settings.INITIAL_NO_AUTH_ACTION_COUNT = float(request.data.decode()) settings.INITIAL_NO_AUTH_ACTION_COUNT = float(request.data.decode())
ActionAuthenticatorMixin.request_count = 0 ActionAuthenticatorMixin.request_count = 0
return 200, {}, json.dumps({"status": "ok", "PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(previous_initial_no_auth_action_count)}) return (
200,
{},
json.dumps(
{
"status": "ok",
"PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(
previous_initial_no_auth_action_count
),
}
),
)
return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"}) return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"})
def model_data(self, request, full_url, headers): def model_data(self, request, full_url, headers):
@ -672,7 +730,8 @@ class MotoAPIResponse(BaseResponse):
def dashboard(self, request, full_url, headers): def dashboard(self, request, full_url, headers):
from flask import render_template from flask import render_template
return render_template('dashboard.html')
return render_template("dashboard.html")
class _RecursiveDictRef(object): class _RecursiveDictRef(object):
@ -683,7 +742,7 @@ class _RecursiveDictRef(object):
self.dic = {} self.dic = {}
def __repr__(self): def __repr__(self):
return '{!r}'.format(self.dic) return "{!r}".format(self.dic)
def __getattr__(self, key): def __getattr__(self, key):
return self.dic.__getattr__(key) return self.dic.__getattr__(key)
@ -707,21 +766,21 @@ class AWSServiceSpec(object):
""" """
def __init__(self, path): def __init__(self, path):
self.path = resource_filename('botocore', path) self.path = resource_filename("botocore", path)
with io.open(self.path, 'r', encoding='utf-8') as f: with io.open(self.path, "r", encoding="utf-8") as f:
spec = json.load(f) spec = json.load(f)
self.metadata = spec['metadata'] self.metadata = spec["metadata"]
self.operations = spec['operations'] self.operations = spec["operations"]
self.shapes = spec['shapes'] self.shapes = spec["shapes"]
def input_spec(self, operation): def input_spec(self, operation):
try: try:
op = self.operations[operation] op = self.operations[operation]
except KeyError: except KeyError:
raise ValueError('Invalid operation: {}'.format(operation)) raise ValueError("Invalid operation: {}".format(operation))
if 'input' not in op: if "input" not in op:
return {} return {}
shape = self.shapes[op['input']['shape']] shape = self.shapes[op["input"]["shape"]]
return self._expand(shape) return self._expand(shape)
def output_spec(self, operation): def output_spec(self, operation):
@ -735,129 +794,133 @@ class AWSServiceSpec(object):
try: try:
op = self.operations[operation] op = self.operations[operation]
except KeyError: except KeyError:
raise ValueError('Invalid operation: {}'.format(operation)) raise ValueError("Invalid operation: {}".format(operation))
if 'output' not in op: if "output" not in op:
return {} return {}
shape = self.shapes[op['output']['shape']] shape = self.shapes[op["output"]["shape"]]
return self._expand(shape) return self._expand(shape)
def _expand(self, shape): def _expand(self, shape):
def expand(dic, seen=None): def expand(dic, seen=None):
seen = seen or {} seen = seen or {}
if dic['type'] == 'structure': if dic["type"] == "structure":
nodes = {} nodes = {}
for k, v in dic['members'].items(): for k, v in dic["members"].items():
seen_till_here = dict(seen) seen_till_here = dict(seen)
if k in seen_till_here: if k in seen_till_here:
nodes[k] = seen_till_here[k] nodes[k] = seen_till_here[k]
continue continue
seen_till_here[k] = _RecursiveDictRef() seen_till_here[k] = _RecursiveDictRef()
nodes[k] = expand(self.shapes[v['shape']], seen_till_here) nodes[k] = expand(self.shapes[v["shape"]], seen_till_here)
seen_till_here[k].set_reference(k, nodes[k]) seen_till_here[k].set_reference(k, nodes[k])
nodes['type'] = 'structure' nodes["type"] = "structure"
return nodes return nodes
elif dic['type'] == 'list': elif dic["type"] == "list":
seen_till_here = dict(seen) seen_till_here = dict(seen)
shape = dic['member']['shape'] shape = dic["member"]["shape"]
if shape in seen_till_here: if shape in seen_till_here:
return seen_till_here[shape] return seen_till_here[shape]
seen_till_here[shape] = _RecursiveDictRef() seen_till_here[shape] = _RecursiveDictRef()
expanded = expand(self.shapes[shape], seen_till_here) expanded = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, expanded) seen_till_here[shape].set_reference(shape, expanded)
return {'type': 'list', 'member': expanded} return {"type": "list", "member": expanded}
elif dic['type'] == 'map': elif dic["type"] == "map":
seen_till_here = dict(seen) seen_till_here = dict(seen)
node = {'type': 'map'} node = {"type": "map"}
if 'shape' in dic['key']: if "shape" in dic["key"]:
shape = dic['key']['shape'] shape = dic["key"]["shape"]
seen_till_here[shape] = _RecursiveDictRef() seen_till_here[shape] = _RecursiveDictRef()
node['key'] = expand(self.shapes[shape], seen_till_here) node["key"] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node['key']) seen_till_here[shape].set_reference(shape, node["key"])
else: else:
node['key'] = dic['key']['type'] node["key"] = dic["key"]["type"]
if 'shape' in dic['value']: if "shape" in dic["value"]:
shape = dic['value']['shape'] shape = dic["value"]["shape"]
seen_till_here[shape] = _RecursiveDictRef() seen_till_here[shape] = _RecursiveDictRef()
node['value'] = expand(self.shapes[shape], seen_till_here) node["value"] = expand(self.shapes[shape], seen_till_here)
seen_till_here[shape].set_reference(shape, node['value']) seen_till_here[shape].set_reference(shape, node["value"])
else: else:
node['value'] = dic['value']['type'] node["value"] = dic["value"]["type"]
return node return node
else: else:
return {'type': dic['type']} return {"type": dic["type"]}
return expand(shape) return expand(shape)
def to_str(value, spec): def to_str(value, spec):
vtype = spec['type'] vtype = spec["type"]
if vtype == 'boolean': if vtype == "boolean":
return 'true' if value else 'false' return "true" if value else "false"
elif vtype == 'integer': elif vtype == "integer":
return str(value) return str(value)
elif vtype == 'float': elif vtype == "float":
return str(value) return str(value)
elif vtype == 'double': elif vtype == "double":
return str(value) return str(value)
elif vtype == 'timestamp': elif vtype == "timestamp":
return datetime.datetime.utcfromtimestamp( return (
value).replace(tzinfo=pytz.utc).isoformat() datetime.datetime.utcfromtimestamp(value)
elif vtype == 'string': .replace(tzinfo=pytz.utc)
.isoformat()
)
elif vtype == "string":
return str(value) return str(value)
elif value is None: elif value is None:
return 'null' return "null"
else: else:
raise TypeError('Unknown type {}'.format(vtype)) raise TypeError("Unknown type {}".format(vtype))
def from_str(value, spec): def from_str(value, spec):
vtype = spec['type'] vtype = spec["type"]
if vtype == 'boolean': if vtype == "boolean":
return True if value == 'true' else False return True if value == "true" else False
elif vtype == 'integer': elif vtype == "integer":
return int(value) return int(value)
elif vtype == 'float': elif vtype == "float":
return float(value) return float(value)
elif vtype == 'double': elif vtype == "double":
return float(value) return float(value)
elif vtype == 'timestamp': elif vtype == "timestamp":
return value return value
elif vtype == 'string': elif vtype == "string":
return value return value
raise TypeError('Unknown type {}'.format(vtype)) raise TypeError("Unknown type {}".format(vtype))
def flatten_json_request_body(prefix, dict_body, spec): def flatten_json_request_body(prefix, dict_body, spec):
"""Convert a JSON request body into query params.""" """Convert a JSON request body into query params."""
if len(spec) == 1 and 'type' in spec: if len(spec) == 1 and "type" in spec:
return {prefix: to_str(dict_body, spec)} return {prefix: to_str(dict_body, spec)}
flat = {} flat = {}
for key, value in dict_body.items(): for key, value in dict_body.items():
node_type = spec[key]['type'] node_type = spec[key]["type"]
if node_type == 'list': if node_type == "list":
for idx, v in enumerate(value, 1): for idx, v in enumerate(value, 1):
pref = key + '.member.' + str(idx) pref = key + ".member." + str(idx)
flat.update(flatten_json_request_body( flat.update(flatten_json_request_body(pref, v, spec[key]["member"]))
pref, v, spec[key]['member'])) elif node_type == "map":
elif node_type == 'map':
for idx, (k, v) in enumerate(value.items(), 1): for idx, (k, v) in enumerate(value.items(), 1):
pref = key + '.entry.' + str(idx) pref = key + ".entry." + str(idx)
flat.update(flatten_json_request_body( flat.update(
pref + '.key', k, spec[key]['key'])) flatten_json_request_body(pref + ".key", k, spec[key]["key"])
flat.update(flatten_json_request_body( )
pref + '.value', v, spec[key]['value'])) flat.update(
flatten_json_request_body(pref + ".value", v, spec[key]["value"])
)
else: else:
flat.update(flatten_json_request_body(key, value, spec[key])) flat.update(flatten_json_request_body(key, value, spec[key]))
if prefix: if prefix:
prefix = prefix + '.' prefix = prefix + "."
return dict((prefix + k, v) for k, v in flat.items()) return dict((prefix + k, v) for k, v in flat.items())
@ -880,41 +943,40 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
od = OrderedDict() od = OrderedDict()
for k, v in value.items(): for k, v in value.items():
if k.startswith('@'): if k.startswith("@"):
continue continue
if k not in spec: if k not in spec:
# this can happen when with an older version of # this can happen when with an older version of
# botocore for which the node in XML template is not # botocore for which the node in XML template is not
# defined in service spec. # defined in service spec.
log.warning( log.warning("Field %s is not defined by the botocore version in use", k)
'Field %s is not defined by the botocore version in use', k)
continue continue
if spec[k]['type'] == 'list': if spec[k]["type"] == "list":
if v is None: if v is None:
od[k] = [] od[k] = []
elif len(spec[k]['member']) == 1: elif len(spec[k]["member"]) == 1:
if isinstance(v['member'], list): if isinstance(v["member"], list):
od[k] = transform(v['member'], spec[k]['member']) od[k] = transform(v["member"], spec[k]["member"])
else: else:
od[k] = [transform(v['member'], spec[k]['member'])] od[k] = [transform(v["member"], spec[k]["member"])]
elif isinstance(v['member'], list): elif isinstance(v["member"], list):
od[k] = [transform(o, spec[k]['member']) od[k] = [transform(o, spec[k]["member"]) for o in v["member"]]
for o in v['member']] elif isinstance(v["member"], OrderedDict):
elif isinstance(v['member'], OrderedDict): od[k] = [transform(v["member"], spec[k]["member"])]
od[k] = [transform(v['member'], spec[k]['member'])]
else: else:
raise ValueError('Malformatted input') raise ValueError("Malformatted input")
elif spec[k]['type'] == 'map': elif spec[k]["type"] == "map":
if v is None: if v is None:
od[k] = {} od[k] = {}
else: else:
items = ([v['entry']] if not isinstance(v['entry'], list) else items = (
v['entry']) [v["entry"]] if not isinstance(v["entry"], list) else v["entry"]
)
for item in items: for item in items:
key = from_str(item['key'], spec[k]['key']) key = from_str(item["key"], spec[k]["key"])
val = from_str(item['value'], spec[k]['value']) val = from_str(item["value"], spec[k]["value"])
if k not in od: if k not in od:
od[k] = {} od[k] = {}
od[k][key] = val od[k][key] = val
@ -928,7 +990,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None):
dic = xmltodict.parse(xml) dic = xmltodict.parse(xml)
output_spec = service_spec.output_spec(operation) output_spec = service_spec.output_spec(operation)
try: try:
for k in (result_node or (operation + 'Response', operation + 'Result')): for k in result_node or (operation + "Response", operation + "Result"):
dic = dic[k] dic = dic[k]
except KeyError: except KeyError:
return None return None

View File

@ -1,15 +1,13 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import MotoAPIResponse from .responses import MotoAPIResponse
url_bases = [ url_bases = ["https?://motoapi.amazonaws.com"]
"https?://motoapi.amazonaws.com"
]
response_instance = MotoAPIResponse() response_instance = MotoAPIResponse()
url_paths = { url_paths = {
'{0}/moto-api/$': response_instance.dashboard, "{0}/moto-api/$": response_instance.dashboard,
'{0}/moto-api/data.json': response_instance.model_data, "{0}/moto-api/data.json": response_instance.model_data,
'{0}/moto-api/reset': response_instance.reset_response, "{0}/moto-api/reset": response_instance.reset_response,
'{0}/moto-api/reset-auth': response_instance.reset_auth_response, "{0}/moto-api/reset-auth": response_instance.reset_auth_response,
} }

View File

@ -15,9 +15,9 @@ REQUEST_ID_LONG = string.digits + string.ascii_uppercase
def camelcase_to_underscores(argument): def camelcase_to_underscores(argument):
''' Converts a camelcase param like theNewAttribute to the equivalent """ Converts a camelcase param like theNewAttribute to the equivalent
python underscore variable like the_new_attribute''' python underscore variable like the_new_attribute"""
result = '' result = ""
prev_char_title = True prev_char_title = True
if not argument: if not argument:
return argument return argument
@ -41,18 +41,18 @@ def camelcase_to_underscores(argument):
def underscores_to_camelcase(argument): def underscores_to_camelcase(argument):
''' Converts a camelcase param like the_new_attribute to the equivalent """ Converts a camelcase param like the_new_attribute to the equivalent
camelcase version like theNewAttribute. Note that the first letter is camelcase version like theNewAttribute. Note that the first letter is
NOT capitalized by this function ''' NOT capitalized by this function """
result = '' result = ""
previous_was_underscore = False previous_was_underscore = False
for char in argument: for char in argument:
if char != '_': if char != "_":
if previous_was_underscore: if previous_was_underscore:
result += char.upper() result += char.upper()
else: else:
result += char result += char
previous_was_underscore = char == '_' previous_was_underscore = char == "_"
return result return result
@ -69,12 +69,18 @@ def method_names_from_class(clazz):
def get_random_hex(length=8): def get_random_hex(length=8):
chars = list(range(10)) + ['a', 'b', 'c', 'd', 'e', 'f'] chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"]
return ''.join(six.text_type(random.choice(chars)) for x in range(length)) return "".join(six.text_type(random.choice(chars)) for x in range(length))
def get_random_message_id(): def get_random_message_id():
return '{0}-{1}-{2}-{3}-{4}'.format(get_random_hex(8), get_random_hex(4), get_random_hex(4), get_random_hex(4), get_random_hex(12)) return "{0}-{1}-{2}-{3}-{4}".format(
get_random_hex(8),
get_random_hex(4),
get_random_hex(4),
get_random_hex(4),
get_random_hex(12),
)
def convert_regex_to_flask_path(url_path): def convert_regex_to_flask_path(url_path):
@ -97,7 +103,6 @@ def convert_regex_to_flask_path(url_path):
class convert_httpretty_response(object): class convert_httpretty_response(object):
def __init__(self, callback): def __init__(self, callback):
self.callback = callback self.callback = callback
@ -114,13 +119,12 @@ class convert_httpretty_response(object):
def __call__(self, request, url, headers, **kwargs): def __call__(self, request, url, headers, **kwargs):
result = self.callback(request, url, headers) result = self.callback(request, url, headers)
status, headers, response = result status, headers, response = result
if 'server' not in headers: if "server" not in headers:
headers["server"] = "amazon.com" headers["server"] = "amazon.com"
return status, headers, response return status, headers, response
class convert_flask_to_httpretty_response(object): class convert_flask_to_httpretty_response(object):
def __init__(self, callback): def __init__(self, callback):
self.callback = callback self.callback = callback
@ -145,13 +149,12 @@ class convert_flask_to_httpretty_response(object):
status, headers, content = 200, {}, result status, headers, content = 200, {}, result
response = Response(response=content, status=status, headers=headers) response = Response(response=content, status=status, headers=headers)
if request.method == "HEAD" and 'content-length' in headers: if request.method == "HEAD" and "content-length" in headers:
response.headers['Content-Length'] = headers['content-length'] response.headers["Content-Length"] = headers["content-length"]
return response return response
class convert_flask_to_responses_response(object): class convert_flask_to_responses_response(object):
def __init__(self, callback): def __init__(self, callback):
self.callback = callback self.callback = callback
@ -176,14 +179,14 @@ class convert_flask_to_responses_response(object):
def iso_8601_datetime_with_milliseconds(datetime): def iso_8601_datetime_with_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + 'Z' return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
def iso_8601_datetime_without_milliseconds(datetime): def iso_8601_datetime_without_milliseconds(datetime):
return datetime.strftime("%Y-%m-%dT%H:%M:%S") + 'Z' return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z"
RFC1123 = '%a, %d %b %Y %H:%M:%S GMT' RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
def rfc_1123_datetime(datetime): def rfc_1123_datetime(datetime):
@ -212,16 +215,16 @@ def gen_amz_crc32(response, headerdict=None):
crc = str(binascii.crc32(response)) crc = str(binascii.crc32(response))
if headerdict is not None and isinstance(headerdict, dict): if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amz-crc32': crc}) headerdict.update({"x-amz-crc32": crc})
return crc return crc
def gen_amzn_requestid_long(headerdict=None): def gen_amzn_requestid_long(headerdict=None):
req_id = ''.join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)]) req_id = "".join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)])
if headerdict is not None and isinstance(headerdict, dict): if headerdict is not None and isinstance(headerdict, dict):
headerdict.update({'x-amzn-requestid': req_id}) headerdict.update({"x-amzn-requestid": req_id})
return req_id return req_id
@ -239,13 +242,13 @@ def amz_crc32(f):
else: else:
if len(response) == 2: if len(response) == 2:
body, new_headers = response body, new_headers = response
status = new_headers.get('status', 200) status = new_headers.get("status", 200)
else: else:
status, new_headers, body = response status, new_headers, body = response
headers.update(new_headers) headers.update(new_headers)
# Cast status to string # Cast status to string
if "status" in headers: if "status" in headers:
headers['status'] = str(headers['status']) headers["status"] = str(headers["status"])
try: try:
# Doesnt work on python2 for some odd unicode strings # Doesnt work on python2 for some odd unicode strings
@ -271,7 +274,7 @@ def amzn_request_id(f):
else: else:
if len(response) == 2: if len(response) == 2:
body, new_headers = response body, new_headers = response
status = new_headers.get('status', 200) status = new_headers.get("status", 200)
else: else:
status, new_headers, body = response status, new_headers, body = response
headers.update(new_headers) headers.update(new_headers)
@ -280,7 +283,7 @@ def amzn_request_id(f):
# Update request ID in XML # Update request ID in XML
try: try:
body = re.sub(r'(?<=<RequestId>).*(?=<\/RequestId>)', request_id, body) body = re.sub(r"(?<=<RequestId>).*(?=<\/RequestId>)", request_id, body)
except Exception: # Will just ignore if it cant work on bytes (which are str's on python2) except Exception: # Will just ignore if it cant work on bytes (which are str's on python2)
pass pass
@ -293,7 +296,7 @@ def path_url(url):
parsed_url = urlparse(url) parsed_url = urlparse(url)
path = parsed_url.path path = parsed_url.path
if not path: if not path:
path = '/' path = "/"
if parsed_url.query: if parsed_url.query:
path = path + '?' + parsed_url.query path = path + "?" + parsed_url.query
return path return path

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import datapipeline_backends from .models import datapipeline_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
datapipeline_backend = datapipeline_backends['us-east-1'] datapipeline_backend = datapipeline_backends["us-east-1"]
mock_datapipeline = base_decorator(datapipeline_backends) mock_datapipeline = base_decorator(datapipeline_backends)
mock_datapipeline_deprecated = deprecated_base_decorator(datapipeline_backends) mock_datapipeline_deprecated = deprecated_base_decorator(datapipeline_backends)

View File

@ -8,85 +8,65 @@ from .utils import get_random_pipeline_id, remove_capitalization_of_dict_keys
class PipelineObject(BaseModel): class PipelineObject(BaseModel):
def __init__(self, object_id, name, fields): def __init__(self, object_id, name, fields):
self.object_id = object_id self.object_id = object_id
self.name = name self.name = name
self.fields = fields self.fields = fields
def to_json(self): def to_json(self):
return { return {"fields": self.fields, "id": self.object_id, "name": self.name}
"fields": self.fields,
"id": self.object_id,
"name": self.name,
}
class Pipeline(BaseModel): class Pipeline(BaseModel):
def __init__(self, name, unique_id, **kwargs): def __init__(self, name, unique_id, **kwargs):
self.name = name self.name = name
self.unique_id = unique_id self.unique_id = unique_id
self.description = kwargs.get('description', '') self.description = kwargs.get("description", "")
self.pipeline_id = get_random_pipeline_id() self.pipeline_id = get_random_pipeline_id()
self.creation_time = datetime.datetime.utcnow() self.creation_time = datetime.datetime.utcnow()
self.objects = [] self.objects = []
self.status = "PENDING" self.status = "PENDING"
self.tags = kwargs.get('tags', []) self.tags = kwargs.get("tags", [])
@property @property
def physical_resource_id(self): def physical_resource_id(self):
return self.pipeline_id return self.pipeline_id
def to_meta_json(self): def to_meta_json(self):
return { return {"id": self.pipeline_id, "name": self.name}
"id": self.pipeline_id,
"name": self.name,
}
def to_json(self): def to_json(self):
return { return {
"description": self.description, "description": self.description,
"fields": [{ "fields": [
"key": "@pipelineState", {"key": "@pipelineState", "stringValue": self.status},
"stringValue": self.status, {"key": "description", "stringValue": self.description},
}, { {"key": "name", "stringValue": self.name},
"key": "description", {
"stringValue": self.description
}, {
"key": "name",
"stringValue": self.name
}, {
"key": "@creationTime", "key": "@creationTime",
"stringValue": datetime.datetime.strftime(self.creation_time, '%Y-%m-%dT%H-%M-%S'), "stringValue": datetime.datetime.strftime(
}, { self.creation_time, "%Y-%m-%dT%H-%M-%S"
"key": "@id", ),
"stringValue": self.pipeline_id, },
}, { {"key": "@id", "stringValue": self.pipeline_id},
"key": "@sphere", {"key": "@sphere", "stringValue": "PIPELINE"},
"stringValue": "PIPELINE" {"key": "@version", "stringValue": "1"},
}, { {"key": "@userId", "stringValue": "924374875933"},
"key": "@version", {"key": "@accountId", "stringValue": "924374875933"},
"stringValue": "1" {"key": "uniqueId", "stringValue": self.unique_id},
}, { ],
"key": "@userId",
"stringValue": "924374875933"
}, {
"key": "@accountId",
"stringValue": "924374875933"
}, {
"key": "uniqueId",
"stringValue": self.unique_id
}],
"name": self.name, "name": self.name,
"pipelineId": self.pipeline_id, "pipelineId": self.pipeline_id,
"tags": self.tags "tags": self.tags,
} }
def set_pipeline_objects(self, pipeline_objects): def set_pipeline_objects(self, pipeline_objects):
self.objects = [ self.objects = [
PipelineObject(pipeline_object['id'], pipeline_object[ PipelineObject(
'name'], pipeline_object['fields']) pipeline_object["id"],
pipeline_object["name"],
pipeline_object["fields"],
)
for pipeline_object in remove_capitalization_of_dict_keys(pipeline_objects) for pipeline_object in remove_capitalization_of_dict_keys(pipeline_objects)
] ]
@ -94,15 +74,19 @@ class Pipeline(BaseModel):
self.status = "SCHEDULED" self.status = "SCHEDULED"
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(
cls, resource_name, cloudformation_json, region_name
):
datapipeline_backend = datapipeline_backends[region_name] datapipeline_backend = datapipeline_backends[region_name]
properties = cloudformation_json["Properties"] properties = cloudformation_json["Properties"]
cloudformation_unique_id = "cf-" + properties["Name"] cloudformation_unique_id = "cf-" + properties["Name"]
pipeline = datapipeline_backend.create_pipeline( pipeline = datapipeline_backend.create_pipeline(
properties["Name"], cloudformation_unique_id) properties["Name"], cloudformation_unique_id
)
datapipeline_backend.put_pipeline_definition( datapipeline_backend.put_pipeline_definition(
pipeline.pipeline_id, properties["PipelineObjects"]) pipeline.pipeline_id, properties["PipelineObjects"]
)
if properties["Activate"]: if properties["Activate"]:
pipeline.activate() pipeline.activate()
@ -110,7 +94,6 @@ class Pipeline(BaseModel):
class DataPipelineBackend(BaseBackend): class DataPipelineBackend(BaseBackend):
def __init__(self): def __init__(self):
self.pipelines = OrderedDict() self.pipelines = OrderedDict()
@ -123,8 +106,11 @@ class DataPipelineBackend(BaseBackend):
return self.pipelines.values() return self.pipelines.values()
def describe_pipelines(self, pipeline_ids): def describe_pipelines(self, pipeline_ids):
pipelines = [pipeline for pipeline in self.pipelines.values( pipelines = [
) if pipeline.pipeline_id in pipeline_ids] pipeline
for pipeline in self.pipelines.values()
if pipeline.pipeline_id in pipeline_ids
]
return pipelines return pipelines
def get_pipeline(self, pipeline_id): def get_pipeline(self, pipeline_id):
@ -144,7 +130,8 @@ class DataPipelineBackend(BaseBackend):
def describe_objects(self, object_ids, pipeline_id): def describe_objects(self, object_ids, pipeline_id):
pipeline = self.get_pipeline(pipeline_id) pipeline = self.get_pipeline(pipeline_id)
pipeline_objects = [ pipeline_objects = [
pipeline_object for pipeline_object in pipeline.objects pipeline_object
for pipeline_object in pipeline.objects
if pipeline_object.object_id in object_ids if pipeline_object.object_id in object_ids
] ]
return pipeline_objects return pipeline_objects

View File

@ -7,7 +7,6 @@ from .models import datapipeline_backends
class DataPipelineResponse(BaseResponse): class DataPipelineResponse(BaseResponse):
@property @property
def parameters(self): def parameters(self):
# TODO this should really be moved to core/responses.py # TODO this should really be moved to core/responses.py
@ -21,47 +20,47 @@ class DataPipelineResponse(BaseResponse):
return datapipeline_backends[self.region] return datapipeline_backends[self.region]
def create_pipeline(self): def create_pipeline(self):
name = self.parameters.get('name') name = self.parameters.get("name")
unique_id = self.parameters.get('uniqueId') unique_id = self.parameters.get("uniqueId")
description = self.parameters.get('description', '') description = self.parameters.get("description", "")
tags = self.parameters.get('tags', []) tags = self.parameters.get("tags", [])
pipeline = self.datapipeline_backend.create_pipeline(name, unique_id, description=description, tags=tags) pipeline = self.datapipeline_backend.create_pipeline(
return json.dumps({ name, unique_id, description=description, tags=tags
"pipelineId": pipeline.pipeline_id, )
}) return json.dumps({"pipelineId": pipeline.pipeline_id})
def list_pipelines(self): def list_pipelines(self):
pipelines = list(self.datapipeline_backend.list_pipelines()) pipelines = list(self.datapipeline_backend.list_pipelines())
pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines] pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines]
max_pipelines = 50 max_pipelines = 50
marker = self.parameters.get('marker') marker = self.parameters.get("marker")
if marker: if marker:
start = pipeline_ids.index(marker) + 1 start = pipeline_ids.index(marker) + 1
else: else:
start = 0 start = 0
pipelines_resp = pipelines[start:start + max_pipelines] pipelines_resp = pipelines[start : start + max_pipelines]
has_more_results = False has_more_results = False
marker = None marker = None
if start + max_pipelines < len(pipeline_ids) - 1: if start + max_pipelines < len(pipeline_ids) - 1:
has_more_results = True has_more_results = True
marker = pipelines_resp[-1].pipeline_id marker = pipelines_resp[-1].pipeline_id
return json.dumps({ return json.dumps(
{
"hasMoreResults": has_more_results, "hasMoreResults": has_more_results,
"marker": marker, "marker": marker,
"pipelineIdList": [ "pipelineIdList": [
pipeline.to_meta_json() for pipeline in pipelines_resp pipeline.to_meta_json() for pipeline in pipelines_resp
] ],
}) }
)
def describe_pipelines(self): def describe_pipelines(self):
pipeline_ids = self.parameters["pipelineIds"] pipeline_ids = self.parameters["pipelineIds"]
pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids) pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids)
return json.dumps({ return json.dumps(
"pipelineDescriptionList": [ {"pipelineDescriptionList": [pipeline.to_json() for pipeline in pipelines]}
pipeline.to_json() for pipeline in pipelines )
]
})
def delete_pipeline(self): def delete_pipeline(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]
@ -72,31 +71,38 @@ class DataPipelineResponse(BaseResponse):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]
pipeline_objects = self.parameters["pipelineObjects"] pipeline_objects = self.parameters["pipelineObjects"]
self.datapipeline_backend.put_pipeline_definition( self.datapipeline_backend.put_pipeline_definition(pipeline_id, pipeline_objects)
pipeline_id, pipeline_objects)
return json.dumps({"errored": False}) return json.dumps({"errored": False})
def get_pipeline_definition(self): def get_pipeline_definition(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]
pipeline_definition = self.datapipeline_backend.get_pipeline_definition( pipeline_definition = self.datapipeline_backend.get_pipeline_definition(
pipeline_id) pipeline_id
return json.dumps({ )
"pipelineObjects": [pipeline_object.to_json() for pipeline_object in pipeline_definition] return json.dumps(
}) {
"pipelineObjects": [
pipeline_object.to_json() for pipeline_object in pipeline_definition
]
}
)
def describe_objects(self): def describe_objects(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]
object_ids = self.parameters["objectIds"] object_ids = self.parameters["objectIds"]
pipeline_objects = self.datapipeline_backend.describe_objects( pipeline_objects = self.datapipeline_backend.describe_objects(
object_ids, pipeline_id) object_ids, pipeline_id
return json.dumps({ )
return json.dumps(
{
"hasMoreResults": False, "hasMoreResults": False,
"marker": None, "marker": None,
"pipelineObjects": [ "pipelineObjects": [
pipeline_object.to_json() for pipeline_object in pipeline_objects pipeline_object.to_json() for pipeline_object in pipeline_objects
] ],
}) }
)
def activate_pipeline(self): def activate_pipeline(self):
pipeline_id = self.parameters["pipelineId"] pipeline_id = self.parameters["pipelineId"]

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import DataPipelineResponse from .responses import DataPipelineResponse
url_bases = [ url_bases = ["https?://datapipeline.(.+).amazonaws.com"]
"https?://datapipeline.(.+).amazonaws.com",
]
url_paths = { url_paths = {"{0}/$": DataPipelineResponse.dispatch}
'{0}/$': DataPipelineResponse.dispatch,
}

View File

@ -14,7 +14,9 @@ def remove_capitalization_of_dict_keys(obj):
normalized_key = key[:1].lower() + key[1:] normalized_key = key[:1].lower() + key[1:]
result[normalized_key] = remove_capitalization_of_dict_keys(value) result[normalized_key] = remove_capitalization_of_dict_keys(value)
return result return result
elif isinstance(obj, collections.Iterable) and not isinstance(obj, six.string_types): elif isinstance(obj, collections.Iterable) and not isinstance(
obj, six.string_types
):
result = obj.__class__() result = obj.__class__()
for item in obj: for item in obj:
result += (remove_capitalization_of_dict_keys(item),) result += (remove_capitalization_of_dict_keys(item),)

View File

@ -1,19 +1,22 @@
from __future__ import unicode_literals from __future__ import unicode_literals
# TODO add tests for all of these # TODO add tests for all of these
COMPARISON_FUNCS = { COMPARISON_FUNCS = {
'EQ': lambda item_value, test_value: item_value == test_value, "EQ": lambda item_value, test_value: item_value == test_value,
'NE': lambda item_value, test_value: item_value != test_value, "NE": lambda item_value, test_value: item_value != test_value,
'LE': lambda item_value, test_value: item_value <= test_value, "LE": lambda item_value, test_value: item_value <= test_value,
'LT': lambda item_value, test_value: item_value < test_value, "LT": lambda item_value, test_value: item_value < test_value,
'GE': lambda item_value, test_value: item_value >= test_value, "GE": lambda item_value, test_value: item_value >= test_value,
'GT': lambda item_value, test_value: item_value > test_value, "GT": lambda item_value, test_value: item_value > test_value,
'NULL': lambda item_value: item_value is None, "NULL": lambda item_value: item_value is None,
'NOT_NULL': lambda item_value: item_value is not None, "NOT_NULL": lambda item_value: item_value is not None,
'CONTAINS': lambda item_value, test_value: test_value in item_value, "CONTAINS": lambda item_value, test_value: test_value in item_value,
'NOT_CONTAINS': lambda item_value, test_value: test_value not in item_value, "NOT_CONTAINS": lambda item_value, test_value: test_value not in item_value,
'BEGINS_WITH': lambda item_value, test_value: item_value.startswith(test_value), "BEGINS_WITH": lambda item_value, test_value: item_value.startswith(test_value),
'IN': lambda item_value, *test_values: item_value in test_values, "IN": lambda item_value, *test_values: item_value in test_values,
'BETWEEN': lambda item_value, lower_test_value, upper_test_value: lower_test_value <= item_value <= upper_test_value, "BETWEEN": lambda item_value, lower_test_value, upper_test_value: lower_test_value
<= item_value
<= upper_test_value,
} }

View File

@ -10,9 +10,8 @@ from .comparisons import get_comparison_func
class DynamoJsonEncoder(json.JSONEncoder): class DynamoJsonEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
if hasattr(obj, 'to_json'): if hasattr(obj, "to_json"):
return obj.to_json() return obj.to_json()
@ -33,10 +32,7 @@ class DynamoType(object):
return hash((self.type, self.value)) return hash((self.type, self.value))
def __eq__(self, other): def __eq__(self, other):
return ( return self.type == other.type and self.value == other.value
self.type == other.type and
self.value == other.value
)
def __repr__(self): def __repr__(self):
return "DynamoType: {0}".format(self.to_json()) return "DynamoType: {0}".format(self.to_json())
@ -54,7 +50,6 @@ class DynamoType(object):
class Item(BaseModel): class Item(BaseModel):
def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs): def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs):
self.hash_key = hash_key self.hash_key = hash_key
self.hash_key_type = hash_key_type self.hash_key_type = hash_key_type
@ -73,9 +68,7 @@ class Item(BaseModel):
for attribute_key, attribute in self.attrs.items(): for attribute_key, attribute in self.attrs.items():
attributes[attribute_key] = attribute.value attributes[attribute_key] = attribute.value
return { return {"Attributes": attributes}
"Attributes": attributes
}
def describe_attrs(self, attributes): def describe_attrs(self, attributes):
if attributes: if attributes:
@ -85,16 +78,20 @@ class Item(BaseModel):
included[key] = value included[key] = value
else: else:
included = self.attrs included = self.attrs
return { return {"Item": included}
"Item": included
}
class Table(BaseModel): class Table(BaseModel):
def __init__(
def __init__(self, name, hash_key_attr, hash_key_type, self,
range_key_attr=None, range_key_type=None, read_capacity=None, name,
write_capacity=None): hash_key_attr,
hash_key_type,
range_key_attr=None,
range_key_type=None,
read_capacity=None,
write_capacity=None,
):
self.name = name self.name = name
self.hash_key_attr = hash_key_attr self.hash_key_attr = hash_key_attr
self.hash_key_type = hash_key_type self.hash_key_type = hash_key_type
@ -117,12 +114,12 @@ class Table(BaseModel):
"KeySchema": { "KeySchema": {
"HashKeyElement": { "HashKeyElement": {
"AttributeName": self.hash_key_attr, "AttributeName": self.hash_key_attr,
"AttributeType": self.hash_key_type "AttributeType": self.hash_key_type,
}, }
}, },
"ProvisionedThroughput": { "ProvisionedThroughput": {
"ReadCapacityUnits": self.read_capacity, "ReadCapacityUnits": self.read_capacity,
"WriteCapacityUnits": self.write_capacity "WriteCapacityUnits": self.write_capacity,
}, },
"TableName": self.name, "TableName": self.name,
"TableStatus": "ACTIVE", "TableStatus": "ACTIVE",
@ -133,19 +130,29 @@ class Table(BaseModel):
if self.has_range_key: if self.has_range_key:
results["Table"]["KeySchema"]["RangeKeyElement"] = { results["Table"]["KeySchema"]["RangeKeyElement"] = {
"AttributeName": self.range_key_attr, "AttributeName": self.range_key_attr,
"AttributeType": self.range_key_type "AttributeType": self.range_key_type,
} }
return results return results
@classmethod @classmethod
def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): def create_from_cloudformation_json(
properties = cloudformation_json['Properties'] cls, resource_name, cloudformation_json, region_name
key_attr = [i['AttributeName'] for i in properties['KeySchema'] if i['KeyType'] == 'HASH'][0] ):
key_type = [i['AttributeType'] for i in properties['AttributeDefinitions'] if i['AttributeName'] == key_attr][0] properties = cloudformation_json["Properties"]
key_attr = [
i["AttributeName"]
for i in properties["KeySchema"]
if i["KeyType"] == "HASH"
][0]
key_type = [
i["AttributeType"]
for i in properties["AttributeDefinitions"]
if i["AttributeName"] == key_attr
][0]
spec = { spec = {
'name': properties['TableName'], "name": properties["TableName"],
'hash_key_attr': key_attr, "hash_key_attr": key_attr,
'hash_key_type': key_type "hash_key_type": key_type,
} }
# TODO: optional properties still missing: # TODO: optional properties still missing:
# range_key_attr, range_key_type, read_capacity, write_capacity # range_key_attr, range_key_type, read_capacity, write_capacity
@ -173,8 +180,9 @@ class Table(BaseModel):
else: else:
range_value = None range_value = None
item = Item(hash_value, self.hash_key_type, range_value, item = Item(
self.range_key_type, item_attrs) hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs
)
if range_value: if range_value:
self.items[hash_value][range_value] = item self.items[hash_value][range_value] = item
@ -185,7 +193,8 @@ class Table(BaseModel):
def get_item(self, hash_key, range_key): def get_item(self, hash_key, range_key):
if self.has_range_key and not range_key: if self.has_range_key and not range_key:
raise ValueError( raise ValueError(
"Table has a range key, but no range key was passed into get_item") "Table has a range key, but no range key was passed into get_item"
)
try: try:
if range_key: if range_key:
return self.items[hash_key][range_key] return self.items[hash_key][range_key]
@ -228,7 +237,10 @@ class Table(BaseModel):
for result in self.all_items(): for result in self.all_items():
scanned_count += 1 scanned_count += 1
passes_all_conditions = True passes_all_conditions = True
for attribute_name, (comparison_operator, comparison_objs) in filters.items(): for (
attribute_name,
(comparison_operator, comparison_objs),
) in filters.items():
attribute = result.attrs.get(attribute_name) attribute = result.attrs.get(attribute_name)
if attribute: if attribute:
@ -236,7 +248,7 @@ class Table(BaseModel):
if not attribute.compare(comparison_operator, comparison_objs): if not attribute.compare(comparison_operator, comparison_objs):
passes_all_conditions = False passes_all_conditions = False
break break
elif comparison_operator == 'NULL': elif comparison_operator == "NULL":
# Comparison is NULL and we don't have the attribute # Comparison is NULL and we don't have the attribute
continue continue
else: else:
@ -261,15 +273,17 @@ class Table(BaseModel):
def get_cfn_attribute(self, attribute_name): def get_cfn_attribute(self, attribute_name):
from moto.cloudformation.exceptions import UnformattedGetAttTemplateException from moto.cloudformation.exceptions import UnformattedGetAttTemplateException
if attribute_name == 'StreamArn':
region = 'us-east-1' if attribute_name == "StreamArn":
time = '2000-01-01T00:00:00.000' region = "us-east-1"
return 'arn:aws:dynamodb:{0}:123456789012:table/{1}/stream/{2}'.format(region, self.name, time) time = "2000-01-01T00:00:00.000"
return "arn:aws:dynamodb:{0}:123456789012:table/{1}/stream/{2}".format(
region, self.name, time
)
raise UnformattedGetAttTemplateException() raise UnformattedGetAttTemplateException()
class DynamoDBBackend(BaseBackend): class DynamoDBBackend(BaseBackend):
def __init__(self): def __init__(self):
self.tables = OrderedDict() self.tables = OrderedDict()
@ -310,8 +324,7 @@ class DynamoDBBackend(BaseBackend):
return None, None return None, None
hash_key = DynamoType(hash_key_dict) hash_key = DynamoType(hash_key_dict)
range_values = [DynamoType(range_value) range_values = [DynamoType(range_value) for range_value in range_value_dicts]
for range_value in range_value_dicts]
return table.query(hash_key, range_comparison, range_values) return table.query(hash_key, range_comparison, range_values)

View File

@ -8,7 +8,6 @@ from .models import dynamodb_backend, dynamo_json_dump
class DynamoHandler(BaseResponse): class DynamoHandler(BaseResponse):
def get_endpoint_name(self, headers): def get_endpoint_name(self, headers):
"""Parses request headers and extracts part od the X-Amz-Target """Parses request headers and extracts part od the X-Amz-Target
that corresponds to a method of DynamoHandler that corresponds to a method of DynamoHandler
@ -16,15 +15,15 @@ class DynamoHandler(BaseResponse):
ie: X-Amz-Target: DynamoDB_20111205.ListTables -> ListTables ie: X-Amz-Target: DynamoDB_20111205.ListTables -> ListTables
""" """
# Headers are case-insensitive. Probably a better way to do this. # Headers are case-insensitive. Probably a better way to do this.
match = headers.get('x-amz-target') or headers.get('X-Amz-Target') match = headers.get("x-amz-target") or headers.get("X-Amz-Target")
if match: if match:
return match.split(".")[1] return match.split(".")[1]
def error(self, type_, status=400): def error(self, type_, status=400):
return status, self.response_headers, dynamo_json_dump({'__type': type_}) return status, self.response_headers, dynamo_json_dump({"__type": type_})
def call_action(self): def call_action(self):
self.body = json.loads(self.body or '{}') self.body = json.loads(self.body or "{}")
endpoint = self.get_endpoint_name(self.headers) endpoint = self.get_endpoint_name(self.headers)
if endpoint: if endpoint:
endpoint = camelcase_to_underscores(endpoint) endpoint = camelcase_to_underscores(endpoint)
@ -41,7 +40,7 @@ class DynamoHandler(BaseResponse):
def list_tables(self): def list_tables(self):
body = self.body body = self.body
limit = body.get('Limit') limit = body.get("Limit")
if body.get("ExclusiveStartTableName"): if body.get("ExclusiveStartTableName"):
last = body.get("ExclusiveStartTableName") last = body.get("ExclusiveStartTableName")
start = list(dynamodb_backend.tables.keys()).index(last) + 1 start = list(dynamodb_backend.tables.keys()).index(last) + 1
@ -49,7 +48,7 @@ class DynamoHandler(BaseResponse):
start = 0 start = 0
all_tables = list(dynamodb_backend.tables.keys()) all_tables = list(dynamodb_backend.tables.keys())
if limit: if limit:
tables = all_tables[start:start + limit] tables = all_tables[start : start + limit]
else: else:
tables = all_tables[start:] tables = all_tables[start:]
response = {"TableNames": tables} response = {"TableNames": tables}
@ -59,16 +58,16 @@ class DynamoHandler(BaseResponse):
def create_table(self): def create_table(self):
body = self.body body = self.body
name = body['TableName'] name = body["TableName"]
key_schema = body['KeySchema'] key_schema = body["KeySchema"]
hash_key = key_schema['HashKeyElement'] hash_key = key_schema["HashKeyElement"]
hash_key_attr = hash_key['AttributeName'] hash_key_attr = hash_key["AttributeName"]
hash_key_type = hash_key['AttributeType'] hash_key_type = hash_key["AttributeType"]
range_key = key_schema.get('RangeKeyElement', {}) range_key = key_schema.get("RangeKeyElement", {})
range_key_attr = range_key.get('AttributeName') range_key_attr = range_key.get("AttributeName")
range_key_type = range_key.get('AttributeType') range_key_type = range_key.get("AttributeType")
throughput = body["ProvisionedThroughput"] throughput = body["ProvisionedThroughput"]
read_units = throughput["ReadCapacityUnits"] read_units = throughput["ReadCapacityUnits"]
@ -86,137 +85,131 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(table.describe) return dynamo_json_dump(table.describe)
def delete_table(self): def delete_table(self):
name = self.body['TableName'] name = self.body["TableName"]
table = dynamodb_backend.delete_table(name) table = dynamodb_backend.delete_table(name)
if table: if table:
return dynamo_json_dump(table.describe) return dynamo_json_dump(table.describe)
else: else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er) return self.error(er)
def update_table(self): def update_table(self):
name = self.body['TableName'] name = self.body["TableName"]
throughput = self.body["ProvisionedThroughput"] throughput = self.body["ProvisionedThroughput"]
new_read_units = throughput["ReadCapacityUnits"] new_read_units = throughput["ReadCapacityUnits"]
new_write_units = throughput["WriteCapacityUnits"] new_write_units = throughput["WriteCapacityUnits"]
table = dynamodb_backend.update_table_throughput( table = dynamodb_backend.update_table_throughput(
name, new_read_units, new_write_units) name, new_read_units, new_write_units
)
return dynamo_json_dump(table.describe) return dynamo_json_dump(table.describe)
def describe_table(self): def describe_table(self):
name = self.body['TableName'] name = self.body["TableName"]
try: try:
table = dynamodb_backend.tables[name] table = dynamodb_backend.tables[name]
except KeyError: except KeyError:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er) return self.error(er)
return dynamo_json_dump(table.describe) return dynamo_json_dump(table.describe)
def put_item(self): def put_item(self):
name = self.body['TableName'] name = self.body["TableName"]
item = self.body['Item'] item = self.body["Item"]
result = dynamodb_backend.put_item(name, item) result = dynamodb_backend.put_item(name, item)
if result: if result:
item_dict = result.to_json() item_dict = result.to_json()
item_dict['ConsumedCapacityUnits'] = 1 item_dict["ConsumedCapacityUnits"] = 1
return dynamo_json_dump(item_dict) return dynamo_json_dump(item_dict)
else: else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er) return self.error(er)
def batch_write_item(self): def batch_write_item(self):
table_batches = self.body['RequestItems'] table_batches = self.body["RequestItems"]
for table_name, table_requests in table_batches.items(): for table_name, table_requests in table_batches.items():
for table_request in table_requests: for table_request in table_requests:
request_type = list(table_request)[0] request_type = list(table_request)[0]
request = list(table_request.values())[0] request = list(table_request.values())[0]
if request_type == 'PutRequest': if request_type == "PutRequest":
item = request['Item'] item = request["Item"]
dynamodb_backend.put_item(table_name, item) dynamodb_backend.put_item(table_name, item)
elif request_type == 'DeleteRequest': elif request_type == "DeleteRequest":
key = request['Key'] key = request["Key"]
hash_key = key['HashKeyElement'] hash_key = key["HashKeyElement"]
range_key = key.get('RangeKeyElement') range_key = key.get("RangeKeyElement")
item = dynamodb_backend.delete_item( item = dynamodb_backend.delete_item(table_name, hash_key, range_key)
table_name, hash_key, range_key)
response = { response = {
"Responses": { "Responses": {
"Thread": { "Thread": {"ConsumedCapacityUnits": 1.0},
"ConsumedCapacityUnits": 1.0 "Reply": {"ConsumedCapacityUnits": 1.0},
}, },
"Reply": { "UnprocessedItems": {},
"ConsumedCapacityUnits": 1.0
}
},
"UnprocessedItems": {}
} }
return dynamo_json_dump(response) return dynamo_json_dump(response)
def get_item(self): def get_item(self):
name = self.body['TableName'] name = self.body["TableName"]
key = self.body['Key'] key = self.body["Key"]
hash_key = key['HashKeyElement'] hash_key = key["HashKeyElement"]
range_key = key.get('RangeKeyElement') range_key = key.get("RangeKeyElement")
attrs_to_get = self.body.get('AttributesToGet') attrs_to_get = self.body.get("AttributesToGet")
try: try:
item = dynamodb_backend.get_item(name, hash_key, range_key) item = dynamodb_backend.get_item(name, hash_key, range_key)
except ValueError: except ValueError:
er = 'com.amazon.coral.validate#ValidationException' er = "com.amazon.coral.validate#ValidationException"
return self.error(er, status=400) return self.error(er, status=400)
if item: if item:
item_dict = item.describe_attrs(attrs_to_get) item_dict = item.describe_attrs(attrs_to_get)
item_dict['ConsumedCapacityUnits'] = 0.5 item_dict["ConsumedCapacityUnits"] = 0.5
return dynamo_json_dump(item_dict) return dynamo_json_dump(item_dict)
else: else:
# Item not found # Item not found
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er, status=404) return self.error(er, status=404)
def batch_get_item(self): def batch_get_item(self):
table_batches = self.body['RequestItems'] table_batches = self.body["RequestItems"]
results = { results = {"Responses": {"UnprocessedKeys": {}}}
"Responses": {
"UnprocessedKeys": {}
}
}
for table_name, table_request in table_batches.items(): for table_name, table_request in table_batches.items():
items = [] items = []
keys = table_request['Keys'] keys = table_request["Keys"]
attributes_to_get = table_request.get('AttributesToGet') attributes_to_get = table_request.get("AttributesToGet")
for key in keys: for key in keys:
hash_key = key["HashKeyElement"] hash_key = key["HashKeyElement"]
range_key = key.get("RangeKeyElement") range_key = key.get("RangeKeyElement")
item = dynamodb_backend.get_item( item = dynamodb_backend.get_item(table_name, hash_key, range_key)
table_name, hash_key, range_key)
if item: if item:
item_describe = item.describe_attrs(attributes_to_get) item_describe = item.describe_attrs(attributes_to_get)
items.append(item_describe) items.append(item_describe)
results["Responses"][table_name] = { results["Responses"][table_name] = {
"Items": items, "ConsumedCapacityUnits": 1} "Items": items,
"ConsumedCapacityUnits": 1,
}
return dynamo_json_dump(results) return dynamo_json_dump(results)
def query(self): def query(self):
name = self.body['TableName'] name = self.body["TableName"]
hash_key = self.body['HashKeyValue'] hash_key = self.body["HashKeyValue"]
range_condition = self.body.get('RangeKeyCondition') range_condition = self.body.get("RangeKeyCondition")
if range_condition: if range_condition:
range_comparison = range_condition['ComparisonOperator'] range_comparison = range_condition["ComparisonOperator"]
range_values = range_condition['AttributeValueList'] range_values = range_condition["AttributeValueList"]
else: else:
range_comparison = None range_comparison = None
range_values = [] range_values = []
items, last_page = dynamodb_backend.query( items, last_page = dynamodb_backend.query(
name, hash_key, range_comparison, range_values) name, hash_key, range_comparison, range_values
)
if items is None: if items is None:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er) return self.error(er)
result = { result = {
@ -234,10 +227,10 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(result) return dynamo_json_dump(result)
def scan(self): def scan(self):
name = self.body['TableName'] name = self.body["TableName"]
filters = {} filters = {}
scan_filters = self.body.get('ScanFilter', {}) scan_filters = self.body.get("ScanFilter", {})
for attribute_name, scan_filter in scan_filters.items(): for attribute_name, scan_filter in scan_filters.items():
# Keys are attribute names. Values are tuples of (comparison, # Keys are attribute names. Values are tuples of (comparison,
# comparison_value) # comparison_value)
@ -248,14 +241,14 @@ class DynamoHandler(BaseResponse):
items, scanned_count, last_page = dynamodb_backend.scan(name, filters) items, scanned_count, last_page = dynamodb_backend.scan(name, filters)
if items is None: if items is None:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er) return self.error(er)
result = { result = {
"Count": len(items), "Count": len(items),
"Items": [item.attrs for item in items if item], "Items": [item.attrs for item in items if item],
"ConsumedCapacityUnits": 1, "ConsumedCapacityUnits": 1,
"ScannedCount": scanned_count "ScannedCount": scanned_count,
} }
# Implement this when we do pagination # Implement this when we do pagination
@ -267,19 +260,19 @@ class DynamoHandler(BaseResponse):
return dynamo_json_dump(result) return dynamo_json_dump(result)
def delete_item(self): def delete_item(self):
name = self.body['TableName'] name = self.body["TableName"]
key = self.body['Key'] key = self.body["Key"]
hash_key = key['HashKeyElement'] hash_key = key["HashKeyElement"]
range_key = key.get('RangeKeyElement') range_key = key.get("RangeKeyElement")
return_values = self.body.get('ReturnValues', '') return_values = self.body.get("ReturnValues", "")
item = dynamodb_backend.delete_item(name, hash_key, range_key) item = dynamodb_backend.delete_item(name, hash_key, range_key)
if item: if item:
if return_values == 'ALL_OLD': if return_values == "ALL_OLD":
item_dict = item.to_json() item_dict = item.to_json()
else: else:
item_dict = {'Attributes': []} item_dict = {"Attributes": []}
item_dict['ConsumedCapacityUnits'] = 0.5 item_dict["ConsumedCapacityUnits"] = 0.5
return dynamo_json_dump(item_dict) return dynamo_json_dump(item_dict)
else: else:
er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException"
return self.error(er) return self.error(er)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import DynamoHandler from .responses import DynamoHandler
url_bases = [ url_bases = ["https?://dynamodb.(.+).amazonaws.com"]
"https?://dynamodb.(.+).amazonaws.com"
]
url_paths = { url_paths = {"{0}/": DynamoHandler.dispatch}
"{0}/": DynamoHandler.dispatch,
}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import dynamodb_backends as dynamodb_backends2 from .models import dynamodb_backends as dynamodb_backends2
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
dynamodb_backend2 = dynamodb_backends2['us-east-1'] dynamodb_backend2 = dynamodb_backends2["us-east-1"]
mock_dynamodb2 = base_decorator(dynamodb_backends2) mock_dynamodb2 = base_decorator(dynamodb_backends2)
mock_dynamodb2_deprecated = deprecated_base_decorator(dynamodb_backends2) mock_dynamodb2_deprecated = deprecated_base_decorator(dynamodb_backends2)

File diff suppressed because it is too large Load Diff

View File

@ -7,4 +7,4 @@ class InvalidUpdateExpression(ValueError):
class ItemSizeTooLarge(Exception): class ItemSizeTooLarge(Exception):
message = 'Item size has exceeded the maximum allowed size' message = "Item size has exceeded the maximum allowed size"

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import DynamoHandler from .responses import DynamoHandler
url_bases = [ url_bases = ["https?://dynamodb.(.+).amazonaws.com"]
"https?://dynamodb.(.+).amazonaws.com"
]
url_paths = { url_paths = {"{0}/": DynamoHandler.dispatch}
"{0}/": DynamoHandler.dispatch,
}

View File

@ -2,5 +2,5 @@ from __future__ import unicode_literals
from .models import dynamodbstreams_backends from .models import dynamodbstreams_backends
from ..core.models import base_decorator from ..core.models import base_decorator
dynamodbstreams_backend = dynamodbstreams_backends['us-east-1'] dynamodbstreams_backend = dynamodbstreams_backends["us-east-1"]
mock_dynamodbstreams = base_decorator(dynamodbstreams_backends) mock_dynamodbstreams = base_decorator(dynamodbstreams_backends)

View File

@ -10,51 +10,59 @@ from moto.dynamodb2.models import dynamodb_backends
class ShardIterator(BaseModel): class ShardIterator(BaseModel):
def __init__(self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None): def __init__(
self.id = base64.b64encode(os.urandom(472)).decode('utf-8') self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None
):
self.id = base64.b64encode(os.urandom(472)).decode("utf-8")
self.streams_backend = streams_backend self.streams_backend = streams_backend
self.stream_shard = stream_shard self.stream_shard = stream_shard
self.shard_iterator_type = shard_iterator_type self.shard_iterator_type = shard_iterator_type
if shard_iterator_type == 'TRIM_HORIZON': if shard_iterator_type == "TRIM_HORIZON":
self.sequence_number = stream_shard.starting_sequence_number self.sequence_number = stream_shard.starting_sequence_number
elif shard_iterator_type == 'LATEST': elif shard_iterator_type == "LATEST":
self.sequence_number = stream_shard.starting_sequence_number + len(stream_shard.items) self.sequence_number = stream_shard.starting_sequence_number + len(
elif shard_iterator_type == 'AT_SEQUENCE_NUMBER': stream_shard.items
)
elif shard_iterator_type == "AT_SEQUENCE_NUMBER":
self.sequence_number = sequence_number self.sequence_number = sequence_number
elif shard_iterator_type == 'AFTER_SEQUENCE_NUMBER': elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER":
self.sequence_number = sequence_number + 1 self.sequence_number = sequence_number + 1
@property @property
def arn(self): def arn(self):
return '{}/stream/{}|1|{}'.format( return "{}/stream/{}|1|{}".format(
self.stream_shard.table.table_arn, self.stream_shard.table.table_arn,
self.stream_shard.table.latest_stream_label, self.stream_shard.table.latest_stream_label,
self.id) self.id,
)
def to_json(self): def to_json(self):
return { return {"ShardIterator": self.arn}
'ShardIterator': self.arn
}
def get(self, limit=1000): def get(self, limit=1000):
items = self.stream_shard.get(self.sequence_number, limit) items = self.stream_shard.get(self.sequence_number, limit)
try: try:
last_sequence_number = max(int(i['dynamodb']['SequenceNumber']) for i in items) last_sequence_number = max(
new_shard_iterator = ShardIterator(self.streams_backend, int(i["dynamodb"]["SequenceNumber"]) for i in items
)
new_shard_iterator = ShardIterator(
self.streams_backend,
self.stream_shard, self.stream_shard,
'AFTER_SEQUENCE_NUMBER', "AFTER_SEQUENCE_NUMBER",
last_sequence_number) last_sequence_number,
)
except ValueError: except ValueError:
new_shard_iterator = ShardIterator(self.streams_backend, new_shard_iterator = ShardIterator(
self.streams_backend,
self.stream_shard, self.stream_shard,
'AT_SEQUENCE_NUMBER', "AT_SEQUENCE_NUMBER",
self.sequence_number) self.sequence_number,
)
self.streams_backend.shard_iterators[new_shard_iterator.arn] = new_shard_iterator self.streams_backend.shard_iterators[
return { new_shard_iterator.arn
'NextShardIterator': new_shard_iterator.arn, ] = new_shard_iterator
'Records': items return {"NextShardIterator": new_shard_iterator.arn, "Records": items}
}
class DynamoDBStreamsBackend(BaseBackend): class DynamoDBStreamsBackend(BaseBackend):
@ -72,23 +80,27 @@ class DynamoDBStreamsBackend(BaseBackend):
return dynamodb_backends[self.region] return dynamodb_backends[self.region]
def _get_table_from_arn(self, arn): def _get_table_from_arn(self, arn):
table_name = arn.split(':', 6)[5].split('/')[1] table_name = arn.split(":", 6)[5].split("/")[1]
return self.dynamodb.get_table(table_name) return self.dynamodb.get_table(table_name)
def describe_stream(self, arn): def describe_stream(self, arn):
table = self._get_table_from_arn(arn) table = self._get_table_from_arn(arn)
resp = {'StreamDescription': { resp = {
'StreamArn': arn, "StreamDescription": {
'StreamLabel': table.latest_stream_label, "StreamArn": arn,
'StreamStatus': ('ENABLED' if table.latest_stream_label "StreamLabel": table.latest_stream_label,
else 'DISABLED'), "StreamStatus": (
'StreamViewType': table.stream_specification['StreamViewType'], "ENABLED" if table.latest_stream_label else "DISABLED"
'CreationRequestDateTime': table.stream_shard.created_on.isoformat(), ),
'TableName': table.name, "StreamViewType": table.stream_specification["StreamViewType"],
'KeySchema': table.schema, "CreationRequestDateTime": table.stream_shard.created_on.isoformat(),
'Shards': ([table.stream_shard.to_json()] if table.stream_shard "TableName": table.name,
else []) "KeySchema": table.schema,
}} "Shards": (
[table.stream_shard.to_json()] if table.stream_shard else []
),
}
}
return json.dumps(resp) return json.dumps(resp)
@ -98,22 +110,26 @@ class DynamoDBStreamsBackend(BaseBackend):
if table_name is not None and table.name != table_name: if table_name is not None and table.name != table_name:
continue continue
if table.latest_stream_label: if table.latest_stream_label:
d = table.describe(base_key='Table') d = table.describe(base_key="Table")
streams.append({ streams.append(
'StreamArn': d['Table']['LatestStreamArn'], {
'TableName': d['Table']['TableName'], "StreamArn": d["Table"]["LatestStreamArn"],
'StreamLabel': d['Table']['LatestStreamLabel'] "TableName": d["Table"]["TableName"],
}) "StreamLabel": d["Table"]["LatestStreamLabel"],
}
)
return json.dumps({'Streams': streams}) return json.dumps({"Streams": streams})
def get_shard_iterator(self, arn, shard_id, shard_iterator_type, sequence_number=None): def get_shard_iterator(
self, arn, shard_id, shard_iterator_type, sequence_number=None
):
table = self._get_table_from_arn(arn) table = self._get_table_from_arn(arn)
assert table.stream_shard.id == shard_id assert table.stream_shard.id == shard_id
shard_iterator = ShardIterator(self, table.stream_shard, shard_iterator = ShardIterator(
shard_iterator_type, self, table.stream_shard, shard_iterator_type, sequence_number
sequence_number) )
self.shard_iterators[shard_iterator.arn] = shard_iterator self.shard_iterators[shard_iterator.arn] = shard_iterator
return json.dumps(shard_iterator.to_json()) return json.dumps(shard_iterator.to_json())
@ -123,7 +139,7 @@ class DynamoDBStreamsBackend(BaseBackend):
return json.dumps(shard_iterator.get(limit)) return json.dumps(shard_iterator.get(limit))
available_regions = boto3.session.Session().get_available_regions( available_regions = boto3.session.Session().get_available_regions("dynamodbstreams")
'dynamodbstreams') dynamodbstreams_backends = {
dynamodbstreams_backends = {region: DynamoDBStreamsBackend(region=region) region: DynamoDBStreamsBackend(region=region) for region in available_regions
for region in available_regions} }

View File

@ -7,34 +7,34 @@ from six import string_types
class DynamoDBStreamsHandler(BaseResponse): class DynamoDBStreamsHandler(BaseResponse):
@property @property
def backend(self): def backend(self):
return dynamodbstreams_backends[self.region] return dynamodbstreams_backends[self.region]
def describe_stream(self): def describe_stream(self):
arn = self._get_param('StreamArn') arn = self._get_param("StreamArn")
return self.backend.describe_stream(arn) return self.backend.describe_stream(arn)
def list_streams(self): def list_streams(self):
table_name = self._get_param('TableName') table_name = self._get_param("TableName")
return self.backend.list_streams(table_name) return self.backend.list_streams(table_name)
def get_shard_iterator(self): def get_shard_iterator(self):
arn = self._get_param('StreamArn') arn = self._get_param("StreamArn")
shard_id = self._get_param('ShardId') shard_id = self._get_param("ShardId")
shard_iterator_type = self._get_param('ShardIteratorType') shard_iterator_type = self._get_param("ShardIteratorType")
sequence_number = self._get_param('SequenceNumber') sequence_number = self._get_param("SequenceNumber")
# according to documentation sequence_number param should be string # according to documentation sequence_number param should be string
if isinstance(sequence_number, string_types): if isinstance(sequence_number, string_types):
sequence_number = int(sequence_number) sequence_number = int(sequence_number)
return self.backend.get_shard_iterator(arn, shard_id, return self.backend.get_shard_iterator(
shard_iterator_type, sequence_number) arn, shard_id, shard_iterator_type, sequence_number
)
def get_records(self): def get_records(self):
arn = self._get_param('ShardIterator') arn = self._get_param("ShardIterator")
limit = self._get_param('Limit') limit = self._get_param("Limit")
if limit is None: if limit is None:
limit = 1000 limit = 1000
return self.backend.get_records(arn, limit) return self.backend.get_records(arn, limit)

View File

@ -1,10 +1,6 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from .responses import DynamoDBStreamsHandler from .responses import DynamoDBStreamsHandler
url_bases = [ url_bases = ["https?://streams.dynamodb.(.+).amazonaws.com"]
"https?://streams.dynamodb.(.+).amazonaws.com"
]
url_paths = { url_paths = {"{0}/$": DynamoDBStreamsHandler.dispatch}
"{0}/$": DynamoDBStreamsHandler.dispatch,
}

View File

@ -2,6 +2,6 @@ from __future__ import unicode_literals
from .models import ec2_backends from .models import ec2_backends
from ..core.models import base_decorator, deprecated_base_decorator from ..core.models import base_decorator, deprecated_base_decorator
ec2_backend = ec2_backends['us-east-1'] ec2_backend = ec2_backends["us-east-1"]
mock_ec2 = base_decorator(ec2_backends) mock_ec2 = base_decorator(ec2_backends)
mock_ec2_deprecated = deprecated_base_decorator(ec2_backends) mock_ec2_deprecated = deprecated_base_decorator(ec2_backends)

View File

@ -7,499 +7,468 @@ class EC2ClientError(RESTError):
class DependencyViolationError(EC2ClientError): class DependencyViolationError(EC2ClientError):
def __init__(self, message): def __init__(self, message):
super(DependencyViolationError, self).__init__( super(DependencyViolationError, self).__init__("DependencyViolation", message)
"DependencyViolation", message)
class MissingParameterError(EC2ClientError): class MissingParameterError(EC2ClientError):
def __init__(self, parameter): def __init__(self, parameter):
super(MissingParameterError, self).__init__( super(MissingParameterError, self).__init__(
"MissingParameter", "MissingParameter",
"The request must contain the parameter {0}" "The request must contain the parameter {0}".format(parameter),
.format(parameter)) )
class InvalidDHCPOptionsIdError(EC2ClientError): class InvalidDHCPOptionsIdError(EC2ClientError):
def __init__(self, dhcp_options_id): def __init__(self, dhcp_options_id):
super(InvalidDHCPOptionsIdError, self).__init__( super(InvalidDHCPOptionsIdError, self).__init__(
"InvalidDhcpOptionID.NotFound", "InvalidDhcpOptionID.NotFound",
"DhcpOptionID {0} does not exist." "DhcpOptionID {0} does not exist.".format(dhcp_options_id),
.format(dhcp_options_id)) )
class MalformedDHCPOptionsIdError(EC2ClientError): class MalformedDHCPOptionsIdError(EC2ClientError):
def __init__(self, dhcp_options_id): def __init__(self, dhcp_options_id):
super(MalformedDHCPOptionsIdError, self).__init__( super(MalformedDHCPOptionsIdError, self).__init__(
"InvalidDhcpOptionsId.Malformed", "InvalidDhcpOptionsId.Malformed",
"Invalid id: \"{0}\" (expecting \"dopt-...\")" 'Invalid id: "{0}" (expecting "dopt-...")'.format(dhcp_options_id),
.format(dhcp_options_id)) )
class InvalidKeyPairNameError(EC2ClientError): class InvalidKeyPairNameError(EC2ClientError):
def __init__(self, key): def __init__(self, key):
super(InvalidKeyPairNameError, self).__init__( super(InvalidKeyPairNameError, self).__init__(
"InvalidKeyPair.NotFound", "InvalidKeyPair.NotFound", "The keypair '{0}' does not exist.".format(key)
"The keypair '{0}' does not exist." )
.format(key))
class InvalidKeyPairDuplicateError(EC2ClientError): class InvalidKeyPairDuplicateError(EC2ClientError):
def __init__(self, key): def __init__(self, key):
super(InvalidKeyPairDuplicateError, self).__init__( super(InvalidKeyPairDuplicateError, self).__init__(
"InvalidKeyPair.Duplicate", "InvalidKeyPair.Duplicate", "The keypair '{0}' already exists.".format(key)
"The keypair '{0}' already exists." )
.format(key))
class InvalidKeyPairFormatError(EC2ClientError): class InvalidKeyPairFormatError(EC2ClientError):
def __init__(self): def __init__(self):
super(InvalidKeyPairFormatError, self).__init__( super(InvalidKeyPairFormatError, self).__init__(
"InvalidKeyPair.Format", "InvalidKeyPair.Format", "Key is not in valid OpenSSH public key format"
"Key is not in valid OpenSSH public key format") )
class InvalidVPCIdError(EC2ClientError): class InvalidVPCIdError(EC2ClientError):
def __init__(self, vpc_id): def __init__(self, vpc_id):
super(InvalidVPCIdError, self).__init__( super(InvalidVPCIdError, self).__init__(
"InvalidVpcID.NotFound", "InvalidVpcID.NotFound", "VpcID {0} does not exist.".format(vpc_id)
"VpcID {0} does not exist." )
.format(vpc_id))
class InvalidSubnetIdError(EC2ClientError): class InvalidSubnetIdError(EC2ClientError):
def __init__(self, subnet_id): def __init__(self, subnet_id):
super(InvalidSubnetIdError, self).__init__( super(InvalidSubnetIdError, self).__init__(
"InvalidSubnetID.NotFound", "InvalidSubnetID.NotFound",
"The subnet ID '{0}' does not exist" "The subnet ID '{0}' does not exist".format(subnet_id),
.format(subnet_id)) )
class InvalidNetworkAclIdError(EC2ClientError): class InvalidNetworkAclIdError(EC2ClientError):
def __init__(self, network_acl_id): def __init__(self, network_acl_id):
super(InvalidNetworkAclIdError, self).__init__( super(InvalidNetworkAclIdError, self).__init__(
"InvalidNetworkAclID.NotFound", "InvalidNetworkAclID.NotFound",
"The network acl ID '{0}' does not exist" "The network acl ID '{0}' does not exist".format(network_acl_id),
.format(network_acl_id)) )
class InvalidVpnGatewayIdError(EC2ClientError): class InvalidVpnGatewayIdError(EC2ClientError):
def __init__(self, network_acl_id): def __init__(self, network_acl_id):
super(InvalidVpnGatewayIdError, self).__init__( super(InvalidVpnGatewayIdError, self).__init__(
"InvalidVpnGatewayID.NotFound", "InvalidVpnGatewayID.NotFound",
"The virtual private gateway ID '{0}' does not exist" "The virtual private gateway ID '{0}' does not exist".format(
.format(network_acl_id)) network_acl_id
),
)
class InvalidVpnConnectionIdError(EC2ClientError): class InvalidVpnConnectionIdError(EC2ClientError):
def __init__(self, network_acl_id): def __init__(self, network_acl_id):
super(InvalidVpnConnectionIdError, self).__init__( super(InvalidVpnConnectionIdError, self).__init__(
"InvalidVpnConnectionID.NotFound", "InvalidVpnConnectionID.NotFound",
"The vpnConnection ID '{0}' does not exist" "The vpnConnection ID '{0}' does not exist".format(network_acl_id),
.format(network_acl_id)) )
class InvalidCustomerGatewayIdError(EC2ClientError): class InvalidCustomerGatewayIdError(EC2ClientError):
def __init__(self, customer_gateway_id): def __init__(self, customer_gateway_id):
super(InvalidCustomerGatewayIdError, self).__init__( super(InvalidCustomerGatewayIdError, self).__init__(
"InvalidCustomerGatewayID.NotFound", "InvalidCustomerGatewayID.NotFound",
"The customer gateway ID '{0}' does not exist" "The customer gateway ID '{0}' does not exist".format(customer_gateway_id),
.format(customer_gateway_id)) )
class InvalidNetworkInterfaceIdError(EC2ClientError): class InvalidNetworkInterfaceIdError(EC2ClientError):
def __init__(self, eni_id): def __init__(self, eni_id):
super(InvalidNetworkInterfaceIdError, self).__init__( super(InvalidNetworkInterfaceIdError, self).__init__(
"InvalidNetworkInterfaceID.NotFound", "InvalidNetworkInterfaceID.NotFound",
"The network interface ID '{0}' does not exist" "The network interface ID '{0}' does not exist".format(eni_id),
.format(eni_id)) )
class InvalidNetworkAttachmentIdError(EC2ClientError): class InvalidNetworkAttachmentIdError(EC2ClientError):
def __init__(self, attachment_id): def __init__(self, attachment_id):
super(InvalidNetworkAttachmentIdError, self).__init__( super(InvalidNetworkAttachmentIdError, self).__init__(
"InvalidAttachmentID.NotFound", "InvalidAttachmentID.NotFound",
"The network interface attachment ID '{0}' does not exist" "The network interface attachment ID '{0}' does not exist".format(
.format(attachment_id)) attachment_id
),
)
class InvalidSecurityGroupDuplicateError(EC2ClientError): class InvalidSecurityGroupDuplicateError(EC2ClientError):
def __init__(self, name): def __init__(self, name):
super(InvalidSecurityGroupDuplicateError, self).__init__( super(InvalidSecurityGroupDuplicateError, self).__init__(
"InvalidGroup.Duplicate", "InvalidGroup.Duplicate",
"The security group '{0}' already exists" "The security group '{0}' already exists".format(name),
.format(name)) )
class InvalidSecurityGroupNotFoundError(EC2ClientError): class InvalidSecurityGroupNotFoundError(EC2ClientError):
def __init__(self, name): def __init__(self, name):
super(InvalidSecurityGroupNotFoundError, self).__init__( super(InvalidSecurityGroupNotFoundError, self).__init__(
"InvalidGroup.NotFound", "InvalidGroup.NotFound",
"The security group '{0}' does not exist" "The security group '{0}' does not exist".format(name),
.format(name)) )
class InvalidPermissionNotFoundError(EC2ClientError): class InvalidPermissionNotFoundError(EC2ClientError):
def __init__(self): def __init__(self):
super(InvalidPermissionNotFoundError, self).__init__( super(InvalidPermissionNotFoundError, self).__init__(
"InvalidPermission.NotFound", "InvalidPermission.NotFound",
"The specified rule does not exist in this security group") "The specified rule does not exist in this security group",
)
class InvalidPermissionDuplicateError(EC2ClientError): class InvalidPermissionDuplicateError(EC2ClientError):
def __init__(self): def __init__(self):
super(InvalidPermissionDuplicateError, self).__init__( super(InvalidPermissionDuplicateError, self).__init__(
"InvalidPermission.Duplicate", "InvalidPermission.Duplicate", "The specified rule already exists"
"The specified rule already exists") )
class InvalidRouteTableIdError(EC2ClientError): class InvalidRouteTableIdError(EC2ClientError):
def __init__(self, route_table_id): def __init__(self, route_table_id):
super(InvalidRouteTableIdError, self).__init__( super(InvalidRouteTableIdError, self).__init__(
"InvalidRouteTableID.NotFound", "InvalidRouteTableID.NotFound",
"The routeTable ID '{0}' does not exist" "The routeTable ID '{0}' does not exist".format(route_table_id),
.format(route_table_id)) )
class InvalidRouteError(EC2ClientError): class InvalidRouteError(EC2ClientError):
def __init__(self, route_table_id, cidr): def __init__(self, route_table_id, cidr):
super(InvalidRouteError, self).__init__( super(InvalidRouteError, self).__init__(
"InvalidRoute.NotFound", "InvalidRoute.NotFound",
"no route with destination-cidr-block {0} in route table {1}" "no route with destination-cidr-block {0} in route table {1}".format(
.format(cidr, route_table_id)) cidr, route_table_id
),
)
class InvalidInstanceIdError(EC2ClientError): class InvalidInstanceIdError(EC2ClientError):
def __init__(self, instance_id): def __init__(self, instance_id):
super(InvalidInstanceIdError, self).__init__( super(InvalidInstanceIdError, self).__init__(
"InvalidInstanceID.NotFound", "InvalidInstanceID.NotFound",
"The instance ID '{0}' does not exist" "The instance ID '{0}' does not exist".format(instance_id),
.format(instance_id)) )
class InvalidAMIIdError(EC2ClientError): class InvalidAMIIdError(EC2ClientError):
def __init__(self, ami_id): def __init__(self, ami_id):
super(InvalidAMIIdError, self).__init__( super(InvalidAMIIdError, self).__init__(
"InvalidAMIID.NotFound", "InvalidAMIID.NotFound",
"The image id '[{0}]' does not exist" "The image id '[{0}]' does not exist".format(ami_id),
.format(ami_id)) )
class InvalidAMIAttributeItemValueError(EC2ClientError): class InvalidAMIAttributeItemValueError(EC2ClientError):
def __init__(self, attribute, value): def __init__(self, attribute, value):
super(InvalidAMIAttributeItemValueError, self).__init__( super(InvalidAMIAttributeItemValueError, self).__init__(
"InvalidAMIAttributeItemValue", "InvalidAMIAttributeItemValue",
"Invalid attribute item value \"{0}\" for {1} item type." 'Invalid attribute item value "{0}" for {1} item type.'.format(
.format(value, attribute)) value, attribute
),
)
class MalformedAMIIdError(EC2ClientError): class MalformedAMIIdError(EC2ClientError):
def __init__(self, ami_id): def __init__(self, ami_id):
super(MalformedAMIIdError, self).__init__( super(MalformedAMIIdError, self).__init__(
"InvalidAMIID.Malformed", "InvalidAMIID.Malformed",
"Invalid id: \"{0}\" (expecting \"ami-...\")" 'Invalid id: "{0}" (expecting "ami-...")'.format(ami_id),
.format(ami_id)) )
class InvalidSnapshotIdError(EC2ClientError): class InvalidSnapshotIdError(EC2ClientError):
def __init__(self, snapshot_id): def __init__(self, snapshot_id):
super(InvalidSnapshotIdError, self).__init__( super(InvalidSnapshotIdError, self).__init__(
"InvalidSnapshot.NotFound", "InvalidSnapshot.NotFound", ""
"") # Note: AWS returns empty message for this, as of 2014.08.22. ) # Note: AWS returns empty message for this, as of 2014.08.22.
class InvalidVolumeIdError(EC2ClientError): class InvalidVolumeIdError(EC2ClientError):
def __init__(self, volume_id): def __init__(self, volume_id):
super(InvalidVolumeIdError, self).__init__( super(InvalidVolumeIdError, self).__init__(
"InvalidVolume.NotFound", "InvalidVolume.NotFound",
"The volume '{0}' does not exist." "The volume '{0}' does not exist.".format(volume_id),
.format(volume_id)) )
class InvalidVolumeAttachmentError(EC2ClientError): class InvalidVolumeAttachmentError(EC2ClientError):
def __init__(self, volume_id, instance_id): def __init__(self, volume_id, instance_id):
super(InvalidVolumeAttachmentError, self).__init__( super(InvalidVolumeAttachmentError, self).__init__(
"InvalidAttachment.NotFound", "InvalidAttachment.NotFound",
"Volume {0} can not be detached from {1} because it is not attached" "Volume {0} can not be detached from {1} because it is not attached".format(
.format(volume_id, instance_id)) volume_id, instance_id
),
)
class InvalidDomainError(EC2ClientError): class InvalidDomainError(EC2ClientError):
def __init__(self, domain): def __init__(self, domain):
super(InvalidDomainError, self).__init__( super(InvalidDomainError, self).__init__(
"InvalidParameterValue", "InvalidParameterValue", "Invalid value '{0}' for domain.".format(domain)
"Invalid value '{0}' for domain." )
.format(domain))
class InvalidAddressError(EC2ClientError): class InvalidAddressError(EC2ClientError):
def __init__(self, ip): def __init__(self, ip):
super(InvalidAddressError, self).__init__( super(InvalidAddressError, self).__init__(
"InvalidAddress.NotFound", "InvalidAddress.NotFound", "Address '{0}' not found.".format(ip)
"Address '{0}' not found." )
.format(ip))
class InvalidAllocationIdError(EC2ClientError): class InvalidAllocationIdError(EC2ClientError):
def __init__(self, allocation_id): def __init__(self, allocation_id):
super(InvalidAllocationIdError, self).__init__( super(InvalidAllocationIdError, self).__init__(
"InvalidAllocationID.NotFound", "InvalidAllocationID.NotFound",
"Allocation ID '{0}' not found." "Allocation ID '{0}' not found.".format(allocation_id),
.format(allocation_id)) )
class InvalidAssociationIdError(EC2ClientError): class InvalidAssociationIdError(EC2ClientError):
def __init__(self, association_id): def __init__(self, association_id):
super(InvalidAssociationIdError, self).__init__( super(InvalidAssociationIdError, self).__init__(
"InvalidAssociationID.NotFound", "InvalidAssociationID.NotFound",
"Association ID '{0}' not found." "Association ID '{0}' not found.".format(association_id),
.format(association_id)) )
class InvalidVpcCidrBlockAssociationIdError(EC2ClientError): class InvalidVpcCidrBlockAssociationIdError(EC2ClientError):
def __init__(self, association_id): def __init__(self, association_id):
super(InvalidVpcCidrBlockAssociationIdError, self).__init__( super(InvalidVpcCidrBlockAssociationIdError, self).__init__(
"InvalidVpcCidrBlockAssociationIdError.NotFound", "InvalidVpcCidrBlockAssociationIdError.NotFound",
"The vpc CIDR block association ID '{0}' does not exist" "The vpc CIDR block association ID '{0}' does not exist".format(
.format(association_id)) association_id
),
)
class InvalidVPCPeeringConnectionIdError(EC2ClientError): class InvalidVPCPeeringConnectionIdError(EC2ClientError):
def __init__(self, vpc_peering_connection_id): def __init__(self, vpc_peering_connection_id):
super(InvalidVPCPeeringConnectionIdError, self).__init__( super(InvalidVPCPeeringConnectionIdError, self).__init__(
"InvalidVpcPeeringConnectionId.NotFound", "InvalidVpcPeeringConnectionId.NotFound",
"VpcPeeringConnectionID {0} does not exist." "VpcPeeringConnectionID {0} does not exist.".format(
.format(vpc_peering_connection_id)) vpc_peering_connection_id
),
)
class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError): class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError):
def __init__(self, vpc_peering_connection_id): def __init__(self, vpc_peering_connection_id):
super(InvalidVPCPeeringConnectionStateTransitionError, self).__init__( super(InvalidVPCPeeringConnectionStateTransitionError, self).__init__(
"InvalidStateTransition", "InvalidStateTransition",
"VpcPeeringConnectionID {0} is not in the correct state for the request." "VpcPeeringConnectionID {0} is not in the correct state for the request.".format(
.format(vpc_peering_connection_id)) vpc_peering_connection_id
),
)
class InvalidParameterValueError(EC2ClientError): class InvalidParameterValueError(EC2ClientError):
def __init__(self, parameter_value): def __init__(self, parameter_value):
super(InvalidParameterValueError, self).__init__( super(InvalidParameterValueError, self).__init__(
"InvalidParameterValue", "InvalidParameterValue",
"Value {0} is invalid for parameter." "Value {0} is invalid for parameter.".format(parameter_value),
.format(parameter_value)) )
class InvalidParameterValueErrorTagNull(EC2ClientError): class InvalidParameterValueErrorTagNull(EC2ClientError):
def __init__(self): def __init__(self):
super(InvalidParameterValueErrorTagNull, self).__init__( super(InvalidParameterValueErrorTagNull, self).__init__(
"InvalidParameterValue", "InvalidParameterValue",
"Tag value cannot be null. Use empty string instead.") "Tag value cannot be null. Use empty string instead.",
)
class InvalidParameterValueErrorUnknownAttribute(EC2ClientError): class InvalidParameterValueErrorUnknownAttribute(EC2ClientError):
def __init__(self, parameter_value): def __init__(self, parameter_value):
super(InvalidParameterValueErrorUnknownAttribute, self).__init__( super(InvalidParameterValueErrorUnknownAttribute, self).__init__(
"InvalidParameterValue", "InvalidParameterValue",
"Value ({0}) for parameter attribute is invalid. Unknown attribute." "Value ({0}) for parameter attribute is invalid. Unknown attribute.".format(
.format(parameter_value)) parameter_value
),
)
class InvalidInternetGatewayIdError(EC2ClientError): class InvalidInternetGatewayIdError(EC2ClientError):
def __init__(self, internet_gateway_id): def __init__(self, internet_gateway_id):
super(InvalidInternetGatewayIdError, self).__init__( super(InvalidInternetGatewayIdError, self).__init__(
"InvalidInternetGatewayID.NotFound", "InvalidInternetGatewayID.NotFound",
"InternetGatewayID {0} does not exist." "InternetGatewayID {0} does not exist.".format(internet_gateway_id),
.format(internet_gateway_id)) )
class GatewayNotAttachedError(EC2ClientError): class GatewayNotAttachedError(EC2ClientError):
def __init__(self, internet_gateway_id, vpc_id): def __init__(self, internet_gateway_id, vpc_id):
super(GatewayNotAttachedError, self).__init__( super(GatewayNotAttachedError, self).__init__(
"Gateway.NotAttached", "Gateway.NotAttached",
"InternetGatewayID {0} is not attached to a VPC {1}." "InternetGatewayID {0} is not attached to a VPC {1}.".format(
.format(internet_gateway_id, vpc_id)) internet_gateway_id, vpc_id
),
)
class ResourceAlreadyAssociatedError(EC2ClientError): class ResourceAlreadyAssociatedError(EC2ClientError):
def __init__(self, resource_id): def __init__(self, resource_id):
super(ResourceAlreadyAssociatedError, self).__init__( super(ResourceAlreadyAssociatedError, self).__init__(
"Resource.AlreadyAssociated", "Resource.AlreadyAssociated",
"Resource {0} is already associated." "Resource {0} is already associated.".format(resource_id),
.format(resource_id)) )
class TagLimitExceeded(EC2ClientError): class TagLimitExceeded(EC2ClientError):
def __init__(self): def __init__(self):
super(TagLimitExceeded, self).__init__( super(TagLimitExceeded, self).__init__(
"TagLimitExceeded", "TagLimitExceeded",
"The maximum number of Tags for a resource has been reached.") "The maximum number of Tags for a resource has been reached.",
)
class InvalidID(EC2ClientError): class InvalidID(EC2ClientError):
def __init__(self, resource_id): def __init__(self, resource_id):
super(InvalidID, self).__init__( super(InvalidID, self).__init__(
"InvalidID", "InvalidID", "The ID '{0}' is not valid".format(resource_id)
"The ID '{0}' is not valid" )
.format(resource_id))
class InvalidCIDRSubnetError(EC2ClientError): class InvalidCIDRSubnetError(EC2ClientError):
def __init__(self, cidr): def __init__(self, cidr):
super(InvalidCIDRSubnetError, self).__init__( super(InvalidCIDRSubnetError, self).__init__(
"InvalidParameterValue", "InvalidParameterValue",
"invalid CIDR subnet specification: {0}" "invalid CIDR subnet specification: {0}".format(cidr),
.format(cidr)) )
class RulesPerSecurityGroupLimitExceededError(EC2ClientError): class RulesPerSecurityGroupLimitExceededError(EC2ClientError):
def __init__(self): def __init__(self):
super(RulesPerSecurityGroupLimitExceededError, self).__init__( super(RulesPerSecurityGroupLimitExceededError, self).__init__(
"RulesPerSecurityGroupLimitExceeded", "RulesPerSecurityGroupLimitExceeded",
'The maximum number of rules per security group ' "The maximum number of rules per security group " "has been reached.",
'has been reached.') )
class MotoNotImplementedError(NotImplementedError): class MotoNotImplementedError(NotImplementedError):
def __init__(self, blurb): def __init__(self, blurb):
super(MotoNotImplementedError, self).__init__( super(MotoNotImplementedError, self).__init__(
"{0} has not been implemented in Moto yet." "{0} has not been implemented in Moto yet."
" Feel free to open an issue at" " Feel free to open an issue at"
" https://github.com/spulec/moto/issues".format(blurb)) " https://github.com/spulec/moto/issues".format(blurb)
)
class FilterNotImplementedError(MotoNotImplementedError): class FilterNotImplementedError(MotoNotImplementedError):
def __init__(self, filter_name, method_name): def __init__(self, filter_name, method_name):
super(FilterNotImplementedError, self).__init__( super(FilterNotImplementedError, self).__init__(
"The filter '{0}' for {1}".format( "The filter '{0}' for {1}".format(filter_name, method_name)
filter_name, method_name)) )
class CidrLimitExceeded(EC2ClientError): class CidrLimitExceeded(EC2ClientError):
def __init__(self, vpc_id, max_cidr_limit): def __init__(self, vpc_id, max_cidr_limit):
super(CidrLimitExceeded, self).__init__( super(CidrLimitExceeded, self).__init__(
"CidrLimitExceeded", "CidrLimitExceeded",
"This network '{0}' has met its maximum number of allowed CIDRs: {1}".format(vpc_id, max_cidr_limit) "This network '{0}' has met its maximum number of allowed CIDRs: {1}".format(
vpc_id, max_cidr_limit
),
) )
class OperationNotPermitted(EC2ClientError): class OperationNotPermitted(EC2ClientError):
def __init__(self, association_id): def __init__(self, association_id):
super(OperationNotPermitted, self).__init__( super(OperationNotPermitted, self).__init__(
"OperationNotPermitted", "OperationNotPermitted",
"The vpc CIDR block with association ID {} may not be disassociated. " "The vpc CIDR block with association ID {} may not be disassociated. "
"It is the primary IPv4 CIDR block of the VPC".format(association_id) "It is the primary IPv4 CIDR block of the VPC".format(association_id),
) )
class InvalidAvailabilityZoneError(EC2ClientError): class InvalidAvailabilityZoneError(EC2ClientError):
def __init__(self, availability_zone_value, valid_availability_zones): def __init__(self, availability_zone_value, valid_availability_zones):
super(InvalidAvailabilityZoneError, self).__init__( super(InvalidAvailabilityZoneError, self).__init__(
"InvalidParameterValue", "InvalidParameterValue",
"Value ({0}) for parameter availabilityZone is invalid. " "Value ({0}) for parameter availabilityZone is invalid. "
"Subnets can currently only be created in the following availability zones: {1}.".format(availability_zone_value, valid_availability_zones) "Subnets can currently only be created in the following availability zones: {1}.".format(
availability_zone_value, valid_availability_zones
),
) )
class NetworkAclEntryAlreadyExistsError(EC2ClientError): class NetworkAclEntryAlreadyExistsError(EC2ClientError):
def __init__(self, rule_number): def __init__(self, rule_number):
super(NetworkAclEntryAlreadyExistsError, self).__init__( super(NetworkAclEntryAlreadyExistsError, self).__init__(
"NetworkAclEntryAlreadyExists", "NetworkAclEntryAlreadyExists",
"The network acl entry identified by {} already exists.".format(rule_number) "The network acl entry identified by {} already exists.".format(
rule_number
),
) )
class InvalidSubnetRangeError(EC2ClientError): class InvalidSubnetRangeError(EC2ClientError):
def __init__(self, cidr_block): def __init__(self, cidr_block):
super(InvalidSubnetRangeError, self).__init__( super(InvalidSubnetRangeError, self).__init__(
"InvalidSubnet.Range", "InvalidSubnet.Range", "The CIDR '{}' is invalid.".format(cidr_block)
"The CIDR '{}' is invalid.".format(cidr_block)
) )
class InvalidCIDRBlockParameterError(EC2ClientError): class InvalidCIDRBlockParameterError(EC2ClientError):
def __init__(self, cidr_block): def __init__(self, cidr_block):
super(InvalidCIDRBlockParameterError, self).__init__( super(InvalidCIDRBlockParameterError, self).__init__(
"InvalidParameterValue", "InvalidParameterValue",
"Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format(cidr_block) "Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format(
cidr_block
),
) )
class InvalidDestinationCIDRBlockParameterError(EC2ClientError): class InvalidDestinationCIDRBlockParameterError(EC2ClientError):
def __init__(self, cidr_block): def __init__(self, cidr_block):
super(InvalidDestinationCIDRBlockParameterError, self).__init__( super(InvalidDestinationCIDRBlockParameterError, self).__init__(
"InvalidParameterValue", "InvalidParameterValue",
"Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format(cidr_block) "Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format(
cidr_block
),
) )
class InvalidSubnetConflictError(EC2ClientError): class InvalidSubnetConflictError(EC2ClientError):
def __init__(self, cidr_block): def __init__(self, cidr_block):
super(InvalidSubnetConflictError, self).__init__( super(InvalidSubnetConflictError, self).__init__(
"InvalidSubnet.Conflict", "InvalidSubnet.Conflict",
"The CIDR '{}' conflicts with another subnet".format(cidr_block) "The CIDR '{}' conflicts with another subnet".format(cidr_block),
) )
class InvalidVPCRangeError(EC2ClientError): class InvalidVPCRangeError(EC2ClientError):
def __init__(self, cidr_block): def __init__(self, cidr_block):
super(InvalidVPCRangeError, self).__init__( super(InvalidVPCRangeError, self).__init__(
"InvalidVpc.Range", "InvalidVpc.Range", "The CIDR '{}' is invalid.".format(cidr_block)
"The CIDR '{}' is invalid.".format(cidr_block)
) )
@ -509,7 +478,9 @@ class OperationNotPermitted2(EC2ClientError):
super(OperationNotPermitted2, self).__init__( super(OperationNotPermitted2, self).__init__(
"OperationNotPermitted", "OperationNotPermitted",
"Incorrect region ({0}) specified for this request." "Incorrect region ({0}) specified for this request."
"VPC peering connection {1} must be accepted in region {2}".format(client_region, pcx_id, acceptor_region) "VPC peering connection {1} must be accepted in region {2}".format(
client_region, pcx_id, acceptor_region
),
) )
@ -519,9 +490,9 @@ class OperationNotPermitted3(EC2ClientError):
super(OperationNotPermitted3, self).__init__( super(OperationNotPermitted3, self).__init__(
"OperationNotPermitted", "OperationNotPermitted",
"Incorrect region ({0}) specified for this request." "Incorrect region ({0}) specified for this request."
"VPC peering connection {1} must be accepted or rejected in region {2}".format(client_region, "VPC peering connection {1} must be accepted or rejected in region {2}".format(
pcx_id, client_region, pcx_id, acceptor_region
acceptor_region) ),
) )
@ -529,5 +500,5 @@ class InvalidLaunchTemplateNameError(EC2ClientError):
def __init__(self): def __init__(self):
super(InvalidLaunchTemplateNameError, self).__init__( super(InvalidLaunchTemplateNameError, self).__init__(
"InvalidLaunchTemplateName.AlreadyExistsException", "InvalidLaunchTemplateName.AlreadyExistsException",
"Launch template name already in use." "Launch template name already in use.",
) )

File diff suppressed because it is too large Load Diff

View File

@ -70,10 +70,10 @@ class EC2Response(
Windows, Windows,
NatGateways, NatGateways,
): ):
@property @property
def ec2_backend(self): def ec2_backend(self):
from moto.ec2.models import ec2_backends from moto.ec2.models import ec2_backends
return ec2_backends[self.region] return ec2_backends[self.region]
@property @property

View File

@ -3,13 +3,12 @@ from moto.core.responses import BaseResponse
class AccountAttributes(BaseResponse): class AccountAttributes(BaseResponse):
def describe_account_attributes(self): def describe_account_attributes(self):
template = self.response_template(DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT) template = self.response_template(DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT)
return template.render() return template.render()
DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT = u""" DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT = """
<DescribeAccountAttributesResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/"> <DescribeAccountAttributesResponse xmlns="http://ec2.amazonaws.com/doc/2016-11-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId> <requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<accountAttributeSet> <accountAttributeSet>

View File

@ -3,7 +3,7 @@ from moto.core.responses import BaseResponse
class AmazonDevPay(BaseResponse): class AmazonDevPay(BaseResponse):
def confirm_product_instance(self): def confirm_product_instance(self):
raise NotImplementedError( raise NotImplementedError(
'AmazonDevPay.confirm_product_instance is not yet implemented') "AmazonDevPay.confirm_product_instance is not yet implemented"
)

View File

@ -4,76 +4,83 @@ from moto.ec2.utils import filters_from_querystring
class AmisResponse(BaseResponse): class AmisResponse(BaseResponse):
def create_image(self): def create_image(self):
name = self.querystring.get('Name')[0] name = self.querystring.get("Name")[0]
description = self._get_param('Description', if_none='') description = self._get_param("Description", if_none="")
instance_id = self._get_param('InstanceId') instance_id = self._get_param("InstanceId")
if self.is_not_dryrun('CreateImage'): if self.is_not_dryrun("CreateImage"):
image = self.ec2_backend.create_image( image = self.ec2_backend.create_image(
instance_id, name, description, context=self) instance_id, name, description, context=self
)
template = self.response_template(CREATE_IMAGE_RESPONSE) template = self.response_template(CREATE_IMAGE_RESPONSE)
return template.render(image=image) return template.render(image=image)
def copy_image(self): def copy_image(self):
source_image_id = self._get_param('SourceImageId') source_image_id = self._get_param("SourceImageId")
source_region = self._get_param('SourceRegion') source_region = self._get_param("SourceRegion")
name = self._get_param('Name') name = self._get_param("Name")
description = self._get_param('Description') description = self._get_param("Description")
if self.is_not_dryrun('CopyImage'): if self.is_not_dryrun("CopyImage"):
image = self.ec2_backend.copy_image( image = self.ec2_backend.copy_image(
source_image_id, source_region, name, description) source_image_id, source_region, name, description
)
template = self.response_template(COPY_IMAGE_RESPONSE) template = self.response_template(COPY_IMAGE_RESPONSE)
return template.render(image=image) return template.render(image=image)
def deregister_image(self): def deregister_image(self):
ami_id = self._get_param('ImageId') ami_id = self._get_param("ImageId")
if self.is_not_dryrun('DeregisterImage'): if self.is_not_dryrun("DeregisterImage"):
success = self.ec2_backend.deregister_image(ami_id) success = self.ec2_backend.deregister_image(ami_id)
template = self.response_template(DEREGISTER_IMAGE_RESPONSE) template = self.response_template(DEREGISTER_IMAGE_RESPONSE)
return template.render(success=str(success).lower()) return template.render(success=str(success).lower())
def describe_images(self): def describe_images(self):
ami_ids = self._get_multi_param('ImageId') ami_ids = self._get_multi_param("ImageId")
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
owners = self._get_multi_param('Owner') owners = self._get_multi_param("Owner")
exec_users = self._get_multi_param('ExecutableBy') exec_users = self._get_multi_param("ExecutableBy")
images = self.ec2_backend.describe_images( images = self.ec2_backend.describe_images(
ami_ids=ami_ids, filters=filters, exec_users=exec_users, ami_ids=ami_ids,
owners=owners, context=self) filters=filters,
exec_users=exec_users,
owners=owners,
context=self,
)
template = self.response_template(DESCRIBE_IMAGES_RESPONSE) template = self.response_template(DESCRIBE_IMAGES_RESPONSE)
return template.render(images=images) return template.render(images=images)
def describe_image_attribute(self): def describe_image_attribute(self):
ami_id = self._get_param('ImageId') ami_id = self._get_param("ImageId")
groups = self.ec2_backend.get_launch_permission_groups(ami_id) groups = self.ec2_backend.get_launch_permission_groups(ami_id)
users = self.ec2_backend.get_launch_permission_users(ami_id) users = self.ec2_backend.get_launch_permission_users(ami_id)
template = self.response_template(DESCRIBE_IMAGE_ATTRIBUTES_RESPONSE) template = self.response_template(DESCRIBE_IMAGE_ATTRIBUTES_RESPONSE)
return template.render(ami_id=ami_id, groups=groups, users=users) return template.render(ami_id=ami_id, groups=groups, users=users)
def modify_image_attribute(self): def modify_image_attribute(self):
ami_id = self._get_param('ImageId') ami_id = self._get_param("ImageId")
operation_type = self._get_param('OperationType') operation_type = self._get_param("OperationType")
group = self._get_param('UserGroup.1') group = self._get_param("UserGroup.1")
user_ids = self._get_multi_param('UserId') user_ids = self._get_multi_param("UserId")
if self.is_not_dryrun('ModifyImageAttribute'): if self.is_not_dryrun("ModifyImageAttribute"):
if (operation_type == 'add'): if operation_type == "add":
self.ec2_backend.add_launch_permission( self.ec2_backend.add_launch_permission(
ami_id, user_ids=user_ids, group=group) ami_id, user_ids=user_ids, group=group
elif (operation_type == 'remove'): )
elif operation_type == "remove":
self.ec2_backend.remove_launch_permission( self.ec2_backend.remove_launch_permission(
ami_id, user_ids=user_ids, group=group) ami_id, user_ids=user_ids, group=group
)
return MODIFY_IMAGE_ATTRIBUTE_RESPONSE return MODIFY_IMAGE_ATTRIBUTE_RESPONSE
def register_image(self): def register_image(self):
if self.is_not_dryrun('RegisterImage'): if self.is_not_dryrun("RegisterImage"):
raise NotImplementedError( raise NotImplementedError("AMIs.register_image is not yet implemented")
'AMIs.register_image is not yet implemented')
def reset_image_attribute(self): def reset_image_attribute(self):
if self.is_not_dryrun('ResetImageAttribute'): if self.is_not_dryrun("ResetImageAttribute"):
raise NotImplementedError( raise NotImplementedError(
'AMIs.reset_image_attribute is not yet implemented') "AMIs.reset_image_attribute is not yet implemented"
)
CREATE_IMAGE_RESPONSE = """<CreateImageResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> CREATE_IMAGE_RESPONSE = """<CreateImageResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">

View File

@ -3,14 +3,13 @@ from moto.core.responses import BaseResponse
class AvailabilityZonesAndRegions(BaseResponse): class AvailabilityZonesAndRegions(BaseResponse):
def describe_availability_zones(self): def describe_availability_zones(self):
zones = self.ec2_backend.describe_availability_zones() zones = self.ec2_backend.describe_availability_zones()
template = self.response_template(DESCRIBE_ZONES_RESPONSE) template = self.response_template(DESCRIBE_ZONES_RESPONSE)
return template.render(zones=zones) return template.render(zones=zones)
def describe_regions(self): def describe_regions(self):
region_names = self._get_multi_param('RegionName') region_names = self._get_multi_param("RegionName")
regions = self.ec2_backend.describe_regions(region_names) regions = self.ec2_backend.describe_regions(region_names)
template = self.response_template(DESCRIBE_REGIONS_RESPONSE) template = self.response_template(DESCRIBE_REGIONS_RESPONSE)
return template.render(regions=regions) return template.render(regions=regions)

View File

@ -4,21 +4,20 @@ from moto.ec2.utils import filters_from_querystring
class CustomerGateways(BaseResponse): class CustomerGateways(BaseResponse):
def create_customer_gateway(self): def create_customer_gateway(self):
# raise NotImplementedError('CustomerGateways(AmazonVPC).create_customer_gateway is not yet implemented') # raise NotImplementedError('CustomerGateways(AmazonVPC).create_customer_gateway is not yet implemented')
type = self._get_param('Type') type = self._get_param("Type")
ip_address = self._get_param('IpAddress') ip_address = self._get_param("IpAddress")
bgp_asn = self._get_param('BgpAsn') bgp_asn = self._get_param("BgpAsn")
customer_gateway = self.ec2_backend.create_customer_gateway( customer_gateway = self.ec2_backend.create_customer_gateway(
type, ip_address=ip_address, bgp_asn=bgp_asn) type, ip_address=ip_address, bgp_asn=bgp_asn
)
template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE) template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=customer_gateway) return template.render(customer_gateway=customer_gateway)
def delete_customer_gateway(self): def delete_customer_gateway(self):
customer_gateway_id = self._get_param('CustomerGatewayId') customer_gateway_id = self._get_param("CustomerGatewayId")
delete_status = self.ec2_backend.delete_customer_gateway( delete_status = self.ec2_backend.delete_customer_gateway(customer_gateway_id)
customer_gateway_id)
template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE) template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE)
return template.render(customer_gateway=delete_status) return template.render(customer_gateway=delete_status)

View File

@ -1,15 +1,12 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import ( from moto.ec2.utils import filters_from_querystring, dhcp_configuration_from_querystring
filters_from_querystring,
dhcp_configuration_from_querystring)
class DHCPOptions(BaseResponse): class DHCPOptions(BaseResponse):
def associate_dhcp_options(self): def associate_dhcp_options(self):
dhcp_opt_id = self._get_param('DhcpOptionsId') dhcp_opt_id = self._get_param("DhcpOptionsId")
vpc_id = self._get_param('VpcId') vpc_id = self._get_param("VpcId")
dhcp_opt = self.ec2_backend.describe_dhcp_options([dhcp_opt_id])[0] dhcp_opt = self.ec2_backend.describe_dhcp_options([dhcp_opt_id])[0]
vpc = self.ec2_backend.get_vpc(vpc_id) vpc = self.ec2_backend.get_vpc(vpc_id)
@ -35,14 +32,14 @@ class DHCPOptions(BaseResponse):
domain_name=domain_name, domain_name=domain_name,
ntp_servers=ntp_servers, ntp_servers=ntp_servers,
netbios_name_servers=netbios_name_servers, netbios_name_servers=netbios_name_servers,
netbios_node_type=netbios_node_type netbios_node_type=netbios_node_type,
) )
template = self.response_template(CREATE_DHCP_OPTIONS_RESPONSE) template = self.response_template(CREATE_DHCP_OPTIONS_RESPONSE)
return template.render(dhcp_options_set=dhcp_options_set) return template.render(dhcp_options_set=dhcp_options_set)
def delete_dhcp_options(self): def delete_dhcp_options(self):
dhcp_opt_id = self._get_param('DhcpOptionsId') dhcp_opt_id = self._get_param("DhcpOptionsId")
delete_status = self.ec2_backend.delete_dhcp_options_set(dhcp_opt_id) delete_status = self.ec2_backend.delete_dhcp_options_set(dhcp_opt_id)
template = self.response_template(DELETE_DHCP_OPTIONS_RESPONSE) template = self.response_template(DELETE_DHCP_OPTIONS_RESPONSE)
return template.render(delete_status=delete_status) return template.render(delete_status=delete_status)
@ -50,13 +47,12 @@ class DHCPOptions(BaseResponse):
def describe_dhcp_options(self): def describe_dhcp_options(self):
dhcp_opt_ids = self._get_multi_param("DhcpOptionsId") dhcp_opt_ids = self._get_multi_param("DhcpOptionsId")
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
dhcp_opts = self.ec2_backend.get_all_dhcp_options( dhcp_opts = self.ec2_backend.get_all_dhcp_options(dhcp_opt_ids, filters)
dhcp_opt_ids, filters)
template = self.response_template(DESCRIBE_DHCP_OPTIONS_RESPONSE) template = self.response_template(DESCRIBE_DHCP_OPTIONS_RESPONSE)
return template.render(dhcp_options=dhcp_opts) return template.render(dhcp_options=dhcp_opts)
CREATE_DHCP_OPTIONS_RESPONSE = u""" CREATE_DHCP_OPTIONS_RESPONSE = """
<CreateDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> <CreateDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId> <requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<dhcpOptions> <dhcpOptions>
@ -92,14 +88,14 @@ CREATE_DHCP_OPTIONS_RESPONSE = u"""
</CreateDhcpOptionsResponse> </CreateDhcpOptionsResponse>
""" """
DELETE_DHCP_OPTIONS_RESPONSE = u""" DELETE_DHCP_OPTIONS_RESPONSE = """
<DeleteDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> <DeleteDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId> <requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<return>{{delete_status}}</return> <return>{{delete_status}}</return>
</DeleteDhcpOptionsResponse> </DeleteDhcpOptionsResponse>
""" """
DESCRIBE_DHCP_OPTIONS_RESPONSE = u""" DESCRIBE_DHCP_OPTIONS_RESPONSE = """
<DescribeDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> <DescribeDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId> <requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<dhcpOptionsSet> <dhcpOptionsSet>
@ -139,7 +135,7 @@ DESCRIBE_DHCP_OPTIONS_RESPONSE = u"""
</DescribeDhcpOptionsResponse> </DescribeDhcpOptionsResponse>
""" """
ASSOCIATE_DHCP_OPTIONS_RESPONSE = u""" ASSOCIATE_DHCP_OPTIONS_RESPONSE = """
<AssociateDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> <AssociateDhcpOptionsResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId> <requestId>7a62c49f-347e-4fc4-9331-6e8eEXAMPLE</requestId>
<return>true</return> <return>true</return>

View File

@ -4,137 +4,148 @@ from moto.ec2.utils import filters_from_querystring
class ElasticBlockStore(BaseResponse): class ElasticBlockStore(BaseResponse):
def attach_volume(self): def attach_volume(self):
volume_id = self._get_param('VolumeId') volume_id = self._get_param("VolumeId")
instance_id = self._get_param('InstanceId') instance_id = self._get_param("InstanceId")
device_path = self._get_param('Device') device_path = self._get_param("Device")
if self.is_not_dryrun('AttachVolume'): if self.is_not_dryrun("AttachVolume"):
attachment = self.ec2_backend.attach_volume( attachment = self.ec2_backend.attach_volume(
volume_id, instance_id, device_path) volume_id, instance_id, device_path
)
template = self.response_template(ATTACHED_VOLUME_RESPONSE) template = self.response_template(ATTACHED_VOLUME_RESPONSE)
return template.render(attachment=attachment) return template.render(attachment=attachment)
def copy_snapshot(self): def copy_snapshot(self):
source_snapshot_id = self._get_param('SourceSnapshotId') source_snapshot_id = self._get_param("SourceSnapshotId")
source_region = self._get_param('SourceRegion') source_region = self._get_param("SourceRegion")
description = self._get_param('Description') description = self._get_param("Description")
if self.is_not_dryrun('CopySnapshot'): if self.is_not_dryrun("CopySnapshot"):
snapshot = self.ec2_backend.copy_snapshot( snapshot = self.ec2_backend.copy_snapshot(
source_snapshot_id, source_region, description) source_snapshot_id, source_region, description
)
template = self.response_template(COPY_SNAPSHOT_RESPONSE) template = self.response_template(COPY_SNAPSHOT_RESPONSE)
return template.render(snapshot=snapshot) return template.render(snapshot=snapshot)
def create_snapshot(self): def create_snapshot(self):
volume_id = self._get_param('VolumeId') volume_id = self._get_param("VolumeId")
description = self._get_param('Description') description = self._get_param("Description")
tags = self._parse_tag_specification("TagSpecification") tags = self._parse_tag_specification("TagSpecification")
snapshot_tags = tags.get('snapshot', {}) snapshot_tags = tags.get("snapshot", {})
if self.is_not_dryrun('CreateSnapshot'): if self.is_not_dryrun("CreateSnapshot"):
snapshot = self.ec2_backend.create_snapshot(volume_id, description) snapshot = self.ec2_backend.create_snapshot(volume_id, description)
snapshot.add_tags(snapshot_tags) snapshot.add_tags(snapshot_tags)
template = self.response_template(CREATE_SNAPSHOT_RESPONSE) template = self.response_template(CREATE_SNAPSHOT_RESPONSE)
return template.render(snapshot=snapshot) return template.render(snapshot=snapshot)
def create_volume(self): def create_volume(self):
size = self._get_param('Size') size = self._get_param("Size")
zone = self._get_param('AvailabilityZone') zone = self._get_param("AvailabilityZone")
snapshot_id = self._get_param('SnapshotId') snapshot_id = self._get_param("SnapshotId")
tags = self._parse_tag_specification("TagSpecification") tags = self._parse_tag_specification("TagSpecification")
volume_tags = tags.get('volume', {}) volume_tags = tags.get("volume", {})
encrypted = self._get_param('Encrypted', if_none=False) encrypted = self._get_param("Encrypted", if_none=False)
if self.is_not_dryrun('CreateVolume'): if self.is_not_dryrun("CreateVolume"):
volume = self.ec2_backend.create_volume( volume = self.ec2_backend.create_volume(size, zone, snapshot_id, encrypted)
size, zone, snapshot_id, encrypted)
volume.add_tags(volume_tags) volume.add_tags(volume_tags)
template = self.response_template(CREATE_VOLUME_RESPONSE) template = self.response_template(CREATE_VOLUME_RESPONSE)
return template.render(volume=volume) return template.render(volume=volume)
def delete_snapshot(self): def delete_snapshot(self):
snapshot_id = self._get_param('SnapshotId') snapshot_id = self._get_param("SnapshotId")
if self.is_not_dryrun('DeleteSnapshot'): if self.is_not_dryrun("DeleteSnapshot"):
self.ec2_backend.delete_snapshot(snapshot_id) self.ec2_backend.delete_snapshot(snapshot_id)
return DELETE_SNAPSHOT_RESPONSE return DELETE_SNAPSHOT_RESPONSE
def delete_volume(self): def delete_volume(self):
volume_id = self._get_param('VolumeId') volume_id = self._get_param("VolumeId")
if self.is_not_dryrun('DeleteVolume'): if self.is_not_dryrun("DeleteVolume"):
self.ec2_backend.delete_volume(volume_id) self.ec2_backend.delete_volume(volume_id)
return DELETE_VOLUME_RESPONSE return DELETE_VOLUME_RESPONSE
def describe_snapshots(self): def describe_snapshots(self):
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
snapshot_ids = self._get_multi_param('SnapshotId') snapshot_ids = self._get_multi_param("SnapshotId")
snapshots = self.ec2_backend.describe_snapshots(snapshot_ids=snapshot_ids, filters=filters) snapshots = self.ec2_backend.describe_snapshots(
snapshot_ids=snapshot_ids, filters=filters
)
template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE) template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE)
return template.render(snapshots=snapshots) return template.render(snapshots=snapshots)
def describe_volumes(self): def describe_volumes(self):
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
volume_ids = self._get_multi_param('VolumeId') volume_ids = self._get_multi_param("VolumeId")
volumes = self.ec2_backend.describe_volumes(volume_ids=volume_ids, filters=filters) volumes = self.ec2_backend.describe_volumes(
volume_ids=volume_ids, filters=filters
)
template = self.response_template(DESCRIBE_VOLUMES_RESPONSE) template = self.response_template(DESCRIBE_VOLUMES_RESPONSE)
return template.render(volumes=volumes) return template.render(volumes=volumes)
def describe_volume_attribute(self): def describe_volume_attribute(self):
raise NotImplementedError( raise NotImplementedError(
'ElasticBlockStore.describe_volume_attribute is not yet implemented') "ElasticBlockStore.describe_volume_attribute is not yet implemented"
)
def describe_volume_status(self): def describe_volume_status(self):
raise NotImplementedError( raise NotImplementedError(
'ElasticBlockStore.describe_volume_status is not yet implemented') "ElasticBlockStore.describe_volume_status is not yet implemented"
)
def detach_volume(self): def detach_volume(self):
volume_id = self._get_param('VolumeId') volume_id = self._get_param("VolumeId")
instance_id = self._get_param('InstanceId') instance_id = self._get_param("InstanceId")
device_path = self._get_param('Device') device_path = self._get_param("Device")
if self.is_not_dryrun('DetachVolume'): if self.is_not_dryrun("DetachVolume"):
attachment = self.ec2_backend.detach_volume( attachment = self.ec2_backend.detach_volume(
volume_id, instance_id, device_path) volume_id, instance_id, device_path
)
template = self.response_template(DETATCH_VOLUME_RESPONSE) template = self.response_template(DETATCH_VOLUME_RESPONSE)
return template.render(attachment=attachment) return template.render(attachment=attachment)
def enable_volume_io(self): def enable_volume_io(self):
if self.is_not_dryrun('EnableVolumeIO'): if self.is_not_dryrun("EnableVolumeIO"):
raise NotImplementedError( raise NotImplementedError(
'ElasticBlockStore.enable_volume_io is not yet implemented') "ElasticBlockStore.enable_volume_io is not yet implemented"
)
def import_volume(self): def import_volume(self):
if self.is_not_dryrun('ImportVolume'): if self.is_not_dryrun("ImportVolume"):
raise NotImplementedError( raise NotImplementedError(
'ElasticBlockStore.import_volume is not yet implemented') "ElasticBlockStore.import_volume is not yet implemented"
)
def describe_snapshot_attribute(self): def describe_snapshot_attribute(self):
snapshot_id = self._get_param('SnapshotId') snapshot_id = self._get_param("SnapshotId")
groups = self.ec2_backend.get_create_volume_permission_groups( groups = self.ec2_backend.get_create_volume_permission_groups(snapshot_id)
snapshot_id) template = self.response_template(DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE)
template = self.response_template(
DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE)
return template.render(snapshot_id=snapshot_id, groups=groups) return template.render(snapshot_id=snapshot_id, groups=groups)
def modify_snapshot_attribute(self): def modify_snapshot_attribute(self):
snapshot_id = self._get_param('SnapshotId') snapshot_id = self._get_param("SnapshotId")
operation_type = self._get_param('OperationType') operation_type = self._get_param("OperationType")
group = self._get_param('UserGroup.1') group = self._get_param("UserGroup.1")
user_id = self._get_param('UserId.1') user_id = self._get_param("UserId.1")
if self.is_not_dryrun('ModifySnapshotAttribute'): if self.is_not_dryrun("ModifySnapshotAttribute"):
if (operation_type == 'add'): if operation_type == "add":
self.ec2_backend.add_create_volume_permission( self.ec2_backend.add_create_volume_permission(
snapshot_id, user_id=user_id, group=group) snapshot_id, user_id=user_id, group=group
elif (operation_type == 'remove'): )
elif operation_type == "remove":
self.ec2_backend.remove_create_volume_permission( self.ec2_backend.remove_create_volume_permission(
snapshot_id, user_id=user_id, group=group) snapshot_id, user_id=user_id, group=group
)
return MODIFY_SNAPSHOT_ATTRIBUTE_RESPONSE return MODIFY_SNAPSHOT_ATTRIBUTE_RESPONSE
def modify_volume_attribute(self): def modify_volume_attribute(self):
if self.is_not_dryrun('ModifyVolumeAttribute'): if self.is_not_dryrun("ModifyVolumeAttribute"):
raise NotImplementedError( raise NotImplementedError(
'ElasticBlockStore.modify_volume_attribute is not yet implemented') "ElasticBlockStore.modify_volume_attribute is not yet implemented"
)
def reset_snapshot_attribute(self): def reset_snapshot_attribute(self):
if self.is_not_dryrun('ResetSnapshotAttribute'): if self.is_not_dryrun("ResetSnapshotAttribute"):
raise NotImplementedError( raise NotImplementedError(
'ElasticBlockStore.reset_snapshot_attribute is not yet implemented') "ElasticBlockStore.reset_snapshot_attribute is not yet implemented"
)
CREATE_VOLUME_RESPONSE = """<CreateVolumeResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> CREATE_VOLUME_RESPONSE = """<CreateVolumeResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">

View File

@ -4,14 +4,14 @@ from moto.ec2.utils import filters_from_querystring
class ElasticIPAddresses(BaseResponse): class ElasticIPAddresses(BaseResponse):
def allocate_address(self): def allocate_address(self):
domain = self._get_param('Domain', if_none='standard') domain = self._get_param("Domain", if_none="standard")
reallocate_address = self._get_param('Address', if_none=None) reallocate_address = self._get_param("Address", if_none=None)
if self.is_not_dryrun('AllocateAddress'): if self.is_not_dryrun("AllocateAddress"):
if reallocate_address: if reallocate_address:
address = self.ec2_backend.allocate_address( address = self.ec2_backend.allocate_address(
domain, address=reallocate_address) domain, address=reallocate_address
)
else: else:
address = self.ec2_backend.allocate_address(domain) address = self.ec2_backend.allocate_address(domain)
template = self.response_template(ALLOCATE_ADDRESS_RESPONSE) template = self.response_template(ALLOCATE_ADDRESS_RESPONSE)
@ -21,73 +21,92 @@ class ElasticIPAddresses(BaseResponse):
instance = eni = None instance = eni = None
if "InstanceId" in self.querystring: if "InstanceId" in self.querystring:
instance = self.ec2_backend.get_instance( instance = self.ec2_backend.get_instance(self._get_param("InstanceId"))
self._get_param('InstanceId'))
elif "NetworkInterfaceId" in self.querystring: elif "NetworkInterfaceId" in self.querystring:
eni = self.ec2_backend.get_network_interface( eni = self.ec2_backend.get_network_interface(
self._get_param('NetworkInterfaceId')) self._get_param("NetworkInterfaceId")
)
else: else:
self.ec2_backend.raise_error( self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect InstanceId/NetworkId parameter.") "MissingParameter",
"Invalid request, expect InstanceId/NetworkId parameter.",
)
reassociate = False reassociate = False
if "AllowReassociation" in self.querystring: if "AllowReassociation" in self.querystring:
reassociate = self._get_param('AllowReassociation') == "true" reassociate = self._get_param("AllowReassociation") == "true"
if self.is_not_dryrun('AssociateAddress'): if self.is_not_dryrun("AssociateAddress"):
if instance or eni: if instance or eni:
if "PublicIp" in self.querystring: if "PublicIp" in self.querystring:
eip = self.ec2_backend.associate_address( eip = self.ec2_backend.associate_address(
instance=instance, eni=eni, instance=instance,
address=self._get_param('PublicIp'), reassociate=reassociate) eni=eni,
address=self._get_param("PublicIp"),
reassociate=reassociate,
)
elif "AllocationId" in self.querystring: elif "AllocationId" in self.querystring:
eip = self.ec2_backend.associate_address( eip = self.ec2_backend.associate_address(
instance=instance, eni=eni, instance=instance,
allocation_id=self._get_param('AllocationId'), reassociate=reassociate) eni=eni,
allocation_id=self._get_param("AllocationId"),
reassociate=reassociate,
)
else: else:
self.ec2_backend.raise_error( self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.") "MissingParameter",
"Invalid request, expect PublicIp/AllocationId parameter.",
)
else: else:
self.ec2_backend.raise_error( self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect either instance or ENI.") "MissingParameter",
"Invalid request, expect either instance or ENI.",
)
template = self.response_template(ASSOCIATE_ADDRESS_RESPONSE) template = self.response_template(ASSOCIATE_ADDRESS_RESPONSE)
return template.render(address=eip) return template.render(address=eip)
def describe_addresses(self): def describe_addresses(self):
allocation_ids = self._get_multi_param('AllocationId') allocation_ids = self._get_multi_param("AllocationId")
public_ips = self._get_multi_param('PublicIp') public_ips = self._get_multi_param("PublicIp")
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
addresses = self.ec2_backend.describe_addresses( addresses = self.ec2_backend.describe_addresses(
allocation_ids, public_ips, filters) allocation_ids, public_ips, filters
)
template = self.response_template(DESCRIBE_ADDRESS_RESPONSE) template = self.response_template(DESCRIBE_ADDRESS_RESPONSE)
return template.render(addresses=addresses) return template.render(addresses=addresses)
def disassociate_address(self): def disassociate_address(self):
if self.is_not_dryrun('DisAssociateAddress'): if self.is_not_dryrun("DisAssociateAddress"):
if "PublicIp" in self.querystring: if "PublicIp" in self.querystring:
self.ec2_backend.disassociate_address( self.ec2_backend.disassociate_address(
address=self._get_param('PublicIp')) address=self._get_param("PublicIp")
)
elif "AssociationId" in self.querystring: elif "AssociationId" in self.querystring:
self.ec2_backend.disassociate_address( self.ec2_backend.disassociate_address(
association_id=self._get_param('AssociationId')) association_id=self._get_param("AssociationId")
)
else: else:
self.ec2_backend.raise_error( self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AssociationId parameter.") "MissingParameter",
"Invalid request, expect PublicIp/AssociationId parameter.",
)
return self.response_template(DISASSOCIATE_ADDRESS_RESPONSE).render() return self.response_template(DISASSOCIATE_ADDRESS_RESPONSE).render()
def release_address(self): def release_address(self):
if self.is_not_dryrun('ReleaseAddress'): if self.is_not_dryrun("ReleaseAddress"):
if "PublicIp" in self.querystring: if "PublicIp" in self.querystring:
self.ec2_backend.release_address( self.ec2_backend.release_address(address=self._get_param("PublicIp"))
address=self._get_param('PublicIp'))
elif "AllocationId" in self.querystring: elif "AllocationId" in self.querystring:
self.ec2_backend.release_address( self.ec2_backend.release_address(
allocation_id=self._get_param('AllocationId')) allocation_id=self._get_param("AllocationId")
)
else: else:
self.ec2_backend.raise_error( self.ec2_backend.raise_error(
"MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.") "MissingParameter",
"Invalid request, expect PublicIp/AllocationId parameter.",
)
return self.response_template(RELEASE_ADDRESS_RESPONSE).render() return self.response_template(RELEASE_ADDRESS_RESPONSE).render()

View File

@ -4,71 +4,69 @@ from moto.ec2.utils import filters_from_querystring
class ElasticNetworkInterfaces(BaseResponse): class ElasticNetworkInterfaces(BaseResponse):
def create_network_interface(self): def create_network_interface(self):
subnet_id = self._get_param('SubnetId') subnet_id = self._get_param("SubnetId")
private_ip_address = self._get_param('PrivateIpAddress') private_ip_address = self._get_param("PrivateIpAddress")
groups = self._get_multi_param('SecurityGroupId') groups = self._get_multi_param("SecurityGroupId")
subnet = self.ec2_backend.get_subnet(subnet_id) subnet = self.ec2_backend.get_subnet(subnet_id)
description = self._get_param('Description') description = self._get_param("Description")
if self.is_not_dryrun('CreateNetworkInterface'): if self.is_not_dryrun("CreateNetworkInterface"):
eni = self.ec2_backend.create_network_interface( eni = self.ec2_backend.create_network_interface(
subnet, private_ip_address, groups, description) subnet, private_ip_address, groups, description
template = self.response_template( )
CREATE_NETWORK_INTERFACE_RESPONSE) template = self.response_template(CREATE_NETWORK_INTERFACE_RESPONSE)
return template.render(eni=eni) return template.render(eni=eni)
def delete_network_interface(self): def delete_network_interface(self):
eni_id = self._get_param('NetworkInterfaceId') eni_id = self._get_param("NetworkInterfaceId")
if self.is_not_dryrun('DeleteNetworkInterface'): if self.is_not_dryrun("DeleteNetworkInterface"):
self.ec2_backend.delete_network_interface(eni_id) self.ec2_backend.delete_network_interface(eni_id)
template = self.response_template( template = self.response_template(DELETE_NETWORK_INTERFACE_RESPONSE)
DELETE_NETWORK_INTERFACE_RESPONSE)
return template.render() return template.render()
def describe_network_interface_attribute(self): def describe_network_interface_attribute(self):
raise NotImplementedError( raise NotImplementedError(
'ElasticNetworkInterfaces(AmazonVPC).describe_network_interface_attribute is not yet implemented') "ElasticNetworkInterfaces(AmazonVPC).describe_network_interface_attribute is not yet implemented"
)
def describe_network_interfaces(self): def describe_network_interfaces(self):
eni_ids = self._get_multi_param('NetworkInterfaceId') eni_ids = self._get_multi_param("NetworkInterfaceId")
filters = filters_from_querystring(self.querystring) filters = filters_from_querystring(self.querystring)
enis = self.ec2_backend.get_all_network_interfaces(eni_ids, filters) enis = self.ec2_backend.get_all_network_interfaces(eni_ids, filters)
template = self.response_template(DESCRIBE_NETWORK_INTERFACES_RESPONSE) template = self.response_template(DESCRIBE_NETWORK_INTERFACES_RESPONSE)
return template.render(enis=enis) return template.render(enis=enis)
def attach_network_interface(self): def attach_network_interface(self):
eni_id = self._get_param('NetworkInterfaceId') eni_id = self._get_param("NetworkInterfaceId")
instance_id = self._get_param('InstanceId') instance_id = self._get_param("InstanceId")
device_index = self._get_param('DeviceIndex') device_index = self._get_param("DeviceIndex")
if self.is_not_dryrun('AttachNetworkInterface'): if self.is_not_dryrun("AttachNetworkInterface"):
attachment_id = self.ec2_backend.attach_network_interface( attachment_id = self.ec2_backend.attach_network_interface(
eni_id, instance_id, device_index) eni_id, instance_id, device_index
template = self.response_template( )
ATTACH_NETWORK_INTERFACE_RESPONSE) template = self.response_template(ATTACH_NETWORK_INTERFACE_RESPONSE)
return template.render(attachment_id=attachment_id) return template.render(attachment_id=attachment_id)
def detach_network_interface(self): def detach_network_interface(self):
attachment_id = self._get_param('AttachmentId') attachment_id = self._get_param("AttachmentId")
if self.is_not_dryrun('DetachNetworkInterface'): if self.is_not_dryrun("DetachNetworkInterface"):
self.ec2_backend.detach_network_interface(attachment_id) self.ec2_backend.detach_network_interface(attachment_id)
template = self.response_template( template = self.response_template(DETACH_NETWORK_INTERFACE_RESPONSE)
DETACH_NETWORK_INTERFACE_RESPONSE)
return template.render() return template.render()
def modify_network_interface_attribute(self): def modify_network_interface_attribute(self):
# Currently supports modifying one and only one security group # Currently supports modifying one and only one security group
eni_id = self._get_param('NetworkInterfaceId') eni_id = self._get_param("NetworkInterfaceId")
group_id = self._get_param('SecurityGroupId.1') group_id = self._get_param("SecurityGroupId.1")
if self.is_not_dryrun('ModifyNetworkInterface'): if self.is_not_dryrun("ModifyNetworkInterface"):
self.ec2_backend.modify_network_interface_attribute( self.ec2_backend.modify_network_interface_attribute(eni_id, group_id)
eni_id, group_id)
return MODIFY_NETWORK_INTERFACE_ATTRIBUTE_RESPONSE return MODIFY_NETWORK_INTERFACE_ATTRIBUTE_RESPONSE
def reset_network_interface_attribute(self): def reset_network_interface_attribute(self):
if self.is_not_dryrun('ResetNetworkInterface'): if self.is_not_dryrun("ResetNetworkInterface"):
raise NotImplementedError( raise NotImplementedError(
'ElasticNetworkInterfaces(AmazonVPC).reset_network_interface_attribute is not yet implemented') "ElasticNetworkInterfaces(AmazonVPC).reset_network_interface_attribute is not yet implemented"
)
CREATE_NETWORK_INTERFACE_RESPONSE = """ CREATE_NETWORK_INTERFACE_RESPONSE = """

View File

@ -3,20 +3,19 @@ from moto.core.responses import BaseResponse
class General(BaseResponse): class General(BaseResponse):
def get_console_output(self): def get_console_output(self):
instance_id = self._get_param('InstanceId') instance_id = self._get_param("InstanceId")
if not instance_id: if not instance_id:
# For compatibility with boto. # For compatibility with boto.
# See: https://github.com/spulec/moto/pull/1152#issuecomment-332487599 # See: https://github.com/spulec/moto/pull/1152#issuecomment-332487599
instance_id = self._get_multi_param('InstanceId')[0] instance_id = self._get_multi_param("InstanceId")[0]
instance = self.ec2_backend.get_instance(instance_id) instance = self.ec2_backend.get_instance(instance_id)
template = self.response_template(GET_CONSOLE_OUTPUT_RESULT) template = self.response_template(GET_CONSOLE_OUTPUT_RESULT)
return template.render(instance=instance) return template.render(instance=instance)
GET_CONSOLE_OUTPUT_RESULT = ''' GET_CONSOLE_OUTPUT_RESULT = """
<GetConsoleOutputResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> <GetConsoleOutputResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId> <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<instanceId>{{ instance.id }}</instanceId> <instanceId>{{ instance.id }}</instanceId>
@ -29,4 +28,4 @@ R0hNRU0gYXZhaWxhYmxlLgo3MjdNQiBMT1dNRU0gYXZhaWxhYmxlLgpOWCAoRXhlY3V0ZSBEaXNh
YmxlKSBwcm90ZWN0aW9uOiBhY3RpdmUKSVJRIGxvY2t1cCBkZXRlY3Rpb24gZGlzYWJsZWQKQnVp YmxlKSBwcm90ZWN0aW9uOiBhY3RpdmUKSVJRIGxvY2t1cCBkZXRlY3Rpb24gZGlzYWJsZWQKQnVp
bHQgMSB6b25lbGlzdHMKS2VybmVsIGNvbW1hbmQgbGluZTogcm9vdD0vZGV2L3NkYTEgcm8gNApF bHQgMSB6b25lbGlzdHMKS2VybmVsIGNvbW1hbmQgbGluZTogcm9vdD0vZGV2L3NkYTEgcm8gNApF
bmFibGluZyBmYXN0IEZQVSBzYXZlIGFuZCByZXN0b3JlLi4uIGRvbmUuCg==</output> bmFibGluZyBmYXN0IEZQVSBzYXZlIGFuZCByZXN0b3JlLi4uIGRvbmUuCg==</output>
</GetConsoleOutputResponse>''' </GetConsoleOutputResponse>"""

View File

@ -4,20 +4,19 @@ from boto.ec2.instancetype import InstanceType
from moto.autoscaling import autoscaling_backends from moto.autoscaling import autoscaling_backends
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.core.utils import camelcase_to_underscores from moto.core.utils import camelcase_to_underscores
from moto.ec2.utils import filters_from_querystring, \ from moto.ec2.utils import filters_from_querystring, dict_from_querystring
dict_from_querystring
from moto.elbv2 import elbv2_backends from moto.elbv2 import elbv2_backends
class InstanceResponse(BaseResponse): class InstanceResponse(BaseResponse):
def describe_instances(self): def describe_instances(self):
filter_dict = filters_from_querystring(self.querystring) filter_dict = filters_from_querystring(self.querystring)
instance_ids = self._get_multi_param('InstanceId') instance_ids = self._get_multi_param("InstanceId")
token = self._get_param("NextToken") token = self._get_param("NextToken")
if instance_ids: if instance_ids:
reservations = self.ec2_backend.get_reservations_by_instance_ids( reservations = self.ec2_backend.get_reservations_by_instance_ids(
instance_ids, filters=filter_dict) instance_ids, filters=filter_dict
)
else: else:
reservations = self.ec2_backend.all_reservations(filters=filter_dict) reservations = self.ec2_backend.all_reservations(filters=filter_dict)
@ -26,47 +25,66 @@ class InstanceResponse(BaseResponse):
start = reservation_ids.index(token) + 1 start = reservation_ids.index(token) + 1
else: else:
start = 0 start = 0
max_results = int(self._get_param('MaxResults', 100)) max_results = int(self._get_param("MaxResults", 100))
reservations_resp = reservations[start:start + max_results] reservations_resp = reservations[start : start + max_results]
next_token = None next_token = None
if max_results and len(reservations) > (start + max_results): if max_results and len(reservations) > (start + max_results):
next_token = reservations_resp[-1].id next_token = reservations_resp[-1].id
template = self.response_template(EC2_DESCRIBE_INSTANCES) template = self.response_template(EC2_DESCRIBE_INSTANCES)
return template.render(reservations=reservations_resp, next_token=next_token).replace('True', 'true').replace('False', 'false') return (
template.render(reservations=reservations_resp, next_token=next_token)
.replace("True", "true")
.replace("False", "false")
)
def run_instances(self): def run_instances(self):
min_count = int(self._get_param('MinCount', if_none='1')) min_count = int(self._get_param("MinCount", if_none="1"))
image_id = self._get_param('ImageId') image_id = self._get_param("ImageId")
owner_id = self._get_param('OwnerId') owner_id = self._get_param("OwnerId")
user_data = self._get_param('UserData') user_data = self._get_param("UserData")
security_group_names = self._get_multi_param('SecurityGroup') security_group_names = self._get_multi_param("SecurityGroup")
security_group_ids = self._get_multi_param('SecurityGroupId') security_group_ids = self._get_multi_param("SecurityGroupId")
nics = dict_from_querystring("NetworkInterface", self.querystring) nics = dict_from_querystring("NetworkInterface", self.querystring)
instance_type = self._get_param('InstanceType', if_none='m1.small') instance_type = self._get_param("InstanceType", if_none="m1.small")
placement = self._get_param('Placement.AvailabilityZone') placement = self._get_param("Placement.AvailabilityZone")
subnet_id = self._get_param('SubnetId') subnet_id = self._get_param("SubnetId")
private_ip = self._get_param('PrivateIpAddress') private_ip = self._get_param("PrivateIpAddress")
associate_public_ip = self._get_param('AssociatePublicIpAddress') associate_public_ip = self._get_param("AssociatePublicIpAddress")
key_name = self._get_param('KeyName') key_name = self._get_param("KeyName")
ebs_optimized = self._get_param('EbsOptimized') ebs_optimized = self._get_param("EbsOptimized")
instance_initiated_shutdown_behavior = self._get_param("InstanceInitiatedShutdownBehavior") instance_initiated_shutdown_behavior = self._get_param(
"InstanceInitiatedShutdownBehavior"
)
tags = self._parse_tag_specification("TagSpecification") tags = self._parse_tag_specification("TagSpecification")
region_name = self.region region_name = self.region
if self.is_not_dryrun('RunInstance'): if self.is_not_dryrun("RunInstance"):
new_reservation = self.ec2_backend.add_instances( new_reservation = self.ec2_backend.add_instances(
image_id, min_count, user_data, security_group_names, image_id,
instance_type=instance_type, placement=placement, region_name=region_name, subnet_id=subnet_id, min_count,
owner_id=owner_id, key_name=key_name, security_group_ids=security_group_ids, user_data,
nics=nics, private_ip=private_ip, associate_public_ip=associate_public_ip, security_group_names,
tags=tags, ebs_optimized=ebs_optimized, instance_initiated_shutdown_behavior=instance_initiated_shutdown_behavior) instance_type=instance_type,
placement=placement,
region_name=region_name,
subnet_id=subnet_id,
owner_id=owner_id,
key_name=key_name,
security_group_ids=security_group_ids,
nics=nics,
private_ip=private_ip,
associate_public_ip=associate_public_ip,
tags=tags,
ebs_optimized=ebs_optimized,
instance_initiated_shutdown_behavior=instance_initiated_shutdown_behavior,
)
template = self.response_template(EC2_RUN_INSTANCES) template = self.response_template(EC2_RUN_INSTANCES)
return template.render(reservation=new_reservation) return template.render(reservation=new_reservation)
def terminate_instances(self): def terminate_instances(self):
instance_ids = self._get_multi_param('InstanceId') instance_ids = self._get_multi_param("InstanceId")
if self.is_not_dryrun('TerminateInstance'): if self.is_not_dryrun("TerminateInstance"):
instances = self.ec2_backend.terminate_instances(instance_ids) instances = self.ec2_backend.terminate_instances(instance_ids)
autoscaling_backends[self.region].notify_terminate_instances(instance_ids) autoscaling_backends[self.region].notify_terminate_instances(instance_ids)
elbv2_backends[self.region].notify_terminate_instances(instance_ids) elbv2_backends[self.region].notify_terminate_instances(instance_ids)
@ -74,33 +92,32 @@ class InstanceResponse(BaseResponse):
return template.render(instances=instances) return template.render(instances=instances)
def reboot_instances(self): def reboot_instances(self):
instance_ids = self._get_multi_param('InstanceId') instance_ids = self._get_multi_param("InstanceId")
if self.is_not_dryrun('RebootInstance'): if self.is_not_dryrun("RebootInstance"):
instances = self.ec2_backend.reboot_instances(instance_ids) instances = self.ec2_backend.reboot_instances(instance_ids)
template = self.response_template(EC2_REBOOT_INSTANCES) template = self.response_template(EC2_REBOOT_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
def stop_instances(self): def stop_instances(self):
instance_ids = self._get_multi_param('InstanceId') instance_ids = self._get_multi_param("InstanceId")
if self.is_not_dryrun('StopInstance'): if self.is_not_dryrun("StopInstance"):
instances = self.ec2_backend.stop_instances(instance_ids) instances = self.ec2_backend.stop_instances(instance_ids)
template = self.response_template(EC2_STOP_INSTANCES) template = self.response_template(EC2_STOP_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
def start_instances(self): def start_instances(self):
instance_ids = self._get_multi_param('InstanceId') instance_ids = self._get_multi_param("InstanceId")
if self.is_not_dryrun('StartInstance'): if self.is_not_dryrun("StartInstance"):
instances = self.ec2_backend.start_instances(instance_ids) instances = self.ec2_backend.start_instances(instance_ids)
template = self.response_template(EC2_START_INSTANCES) template = self.response_template(EC2_START_INSTANCES)
return template.render(instances=instances) return template.render(instances=instances)
def describe_instance_status(self): def describe_instance_status(self):
instance_ids = self._get_multi_param('InstanceId') instance_ids = self._get_multi_param("InstanceId")
include_all_instances = self._get_param('IncludeAllInstances') == 'true' include_all_instances = self._get_param("IncludeAllInstances") == "true"
if instance_ids: if instance_ids:
instances = self.ec2_backend.get_multi_instances_by_id( instances = self.ec2_backend.get_multi_instances_by_id(instance_ids)
instance_ids)
elif include_all_instances: elif include_all_instances:
instances = self.ec2_backend.all_instances() instances = self.ec2_backend.all_instances()
else: else:
@ -110,40 +127,45 @@ class InstanceResponse(BaseResponse):
return template.render(instances=instances) return template.render(instances=instances)
def describe_instance_types(self): def describe_instance_types(self):
instance_types = [InstanceType( instance_types = [
name='t1.micro', cores=1, memory=644874240, disk=0)] InstanceType(name="t1.micro", cores=1, memory=644874240, disk=0)
]
template = self.response_template(EC2_DESCRIBE_INSTANCE_TYPES) template = self.response_template(EC2_DESCRIBE_INSTANCE_TYPES)
return template.render(instance_types=instance_types) return template.render(instance_types=instance_types)
def describe_instance_attribute(self): def describe_instance_attribute(self):
# TODO this and modify below should raise IncorrectInstanceState if # TODO this and modify below should raise IncorrectInstanceState if
# instance not in stopped state # instance not in stopped state
attribute = self._get_param('Attribute') attribute = self._get_param("Attribute")
instance_id = self._get_param('InstanceId') instance_id = self._get_param("InstanceId")
instance, value = self.ec2_backend.describe_instance_attribute( instance, value = self.ec2_backend.describe_instance_attribute(
instance_id, attribute) instance_id, attribute
)
if attribute == "groupSet": if attribute == "groupSet":
template = self.response_template( template = self.response_template(EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE)
EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE)
else: else:
template = self.response_template(EC2_DESCRIBE_INSTANCE_ATTRIBUTE) template = self.response_template(EC2_DESCRIBE_INSTANCE_ATTRIBUTE)
return template.render(instance=instance, attribute=attribute, value=value) return template.render(instance=instance, attribute=attribute, value=value)
def modify_instance_attribute(self): def modify_instance_attribute(self):
handlers = [self._dot_value_instance_attribute_handler, handlers = [
self._dot_value_instance_attribute_handler,
self._block_device_mapping_handler, self._block_device_mapping_handler,
self._security_grp_instance_attribute_handler] self._security_grp_instance_attribute_handler,
]
for handler in handlers: for handler in handlers:
success = handler() success = handler()
if success: if success:
return success return success
msg = "This specific call to ModifyInstanceAttribute has not been" \ msg = (
" implemented in Moto yet. Feel free to open an issue at" \ "This specific call to ModifyInstanceAttribute has not been"
" implemented in Moto yet. Feel free to open an issue at"
" https://github.com/spulec/moto/issues" " https://github.com/spulec/moto/issues"
)
raise NotImplementedError(msg) raise NotImplementedError(msg)
def _block_device_mapping_handler(self): def _block_device_mapping_handler(self):
@ -166,8 +188,8 @@ class InstanceResponse(BaseResponse):
configuration, but it should be trivial to add anything else. configuration, but it should be trivial to add anything else.
""" """
mapping_counter = 1 mapping_counter = 1
mapping_device_name_fmt = 'BlockDeviceMapping.%s.DeviceName' mapping_device_name_fmt = "BlockDeviceMapping.%s.DeviceName"
mapping_del_on_term_fmt = 'BlockDeviceMapping.%s.Ebs.DeleteOnTermination' mapping_del_on_term_fmt = "BlockDeviceMapping.%s.Ebs.DeleteOnTermination"
while True: while True:
mapping_device_name = mapping_device_name_fmt % mapping_counter mapping_device_name = mapping_device_name_fmt % mapping_counter
if mapping_device_name not in self.querystring.keys(): if mapping_device_name not in self.querystring.keys():
@ -175,15 +197,14 @@ class InstanceResponse(BaseResponse):
mapping_del_on_term = mapping_del_on_term_fmt % mapping_counter mapping_del_on_term = mapping_del_on_term_fmt % mapping_counter
del_on_term_value_str = self.querystring[mapping_del_on_term][0] del_on_term_value_str = self.querystring[mapping_del_on_term][0]
del_on_term_value = True if 'true' == del_on_term_value_str else False del_on_term_value = True if "true" == del_on_term_value_str else False
device_name_value = self.querystring[mapping_device_name][0] device_name_value = self.querystring[mapping_device_name][0]
instance_id = self._get_param('InstanceId') instance_id = self._get_param("InstanceId")
instance = self.ec2_backend.get_instance(instance_id) instance = self.ec2_backend.get_instance(instance_id)
if self.is_not_dryrun('ModifyInstanceAttribute'): if self.is_not_dryrun("ModifyInstanceAttribute"):
block_device_type = instance.block_device_mapping[ block_device_type = instance.block_device_mapping[device_name_value]
device_name_value]
block_device_type.delete_on_termination = del_on_term_value block_device_type.delete_on_termination = del_on_term_value
# +1 for the next device # +1 for the next device
@ -195,32 +216,33 @@ class InstanceResponse(BaseResponse):
def _dot_value_instance_attribute_handler(self): def _dot_value_instance_attribute_handler(self):
attribute_key = None attribute_key = None
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if '.Value' in key: if ".Value" in key:
attribute_key = key attribute_key = key
break break
if not attribute_key: if not attribute_key:
return return
if self.is_not_dryrun('Modify' + attribute_key.split(".")[0]): if self.is_not_dryrun("Modify" + attribute_key.split(".")[0]):
value = self.querystring.get(attribute_key)[0] value = self.querystring.get(attribute_key)[0]
normalized_attribute = camelcase_to_underscores( normalized_attribute = camelcase_to_underscores(attribute_key.split(".")[0])
attribute_key.split(".")[0]) instance_id = self._get_param("InstanceId")
instance_id = self._get_param('InstanceId')
self.ec2_backend.modify_instance_attribute( self.ec2_backend.modify_instance_attribute(
instance_id, normalized_attribute, value) instance_id, normalized_attribute, value
)
return EC2_MODIFY_INSTANCE_ATTRIBUTE return EC2_MODIFY_INSTANCE_ATTRIBUTE
def _security_grp_instance_attribute_handler(self): def _security_grp_instance_attribute_handler(self):
new_security_grp_list = [] new_security_grp_list = []
for key, value in self.querystring.items(): for key, value in self.querystring.items():
if 'GroupId.' in key: if "GroupId." in key:
new_security_grp_list.append(self.querystring.get(key)[0]) new_security_grp_list.append(self.querystring.get(key)[0])
instance_id = self._get_param('InstanceId') instance_id = self._get_param("InstanceId")
if self.is_not_dryrun('ModifyInstanceSecurityGroups'): if self.is_not_dryrun("ModifyInstanceSecurityGroups"):
self.ec2_backend.modify_instance_security_groups( self.ec2_backend.modify_instance_security_groups(
instance_id, new_security_grp_list) instance_id, new_security_grp_list
)
return EC2_MODIFY_INSTANCE_ATTRIBUTE return EC2_MODIFY_INSTANCE_ATTRIBUTE

View File

@ -1,29 +1,26 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from moto.core.responses import BaseResponse from moto.core.responses import BaseResponse
from moto.ec2.utils import ( from moto.ec2.utils import filters_from_querystring
filters_from_querystring,
)
class InternetGateways(BaseResponse): class InternetGateways(BaseResponse):
def attach_internet_gateway(self): def attach_internet_gateway(self):
igw_id = self._get_param('InternetGatewayId') igw_id = self._get_param("InternetGatewayId")
vpc_id = self._get_param('VpcId') vpc_id = self._get_param("VpcId")
if self.is_not_dryrun('AttachInternetGateway'): if self.is_not_dryrun("AttachInternetGateway"):
self.ec2_backend.attach_internet_gateway(igw_id, vpc_id) self.ec2_backend.attach_internet_gateway(igw_id, vpc_id)
template = self.response_template(ATTACH_INTERNET_GATEWAY_RESPONSE) template = self.response_template(ATTACH_INTERNET_GATEWAY_RESPONSE)
return template.render() return template.render()
def create_internet_gateway(self): def create_internet_gateway(self):
if self.is_not_dryrun('CreateInternetGateway'): if self.is_not_dryrun("CreateInternetGateway"):
igw = self.ec2_backend.create_internet_gateway() igw = self.ec2_backend.create_internet_gateway()
template = self.response_template(CREATE_INTERNET_GATEWAY_RESPONSE) template = self.response_template(CREATE_INTERNET_GATEWAY_RESPONSE)
return template.render(internet_gateway=igw) return template.render(internet_gateway=igw)
def delete_internet_gateway(self): def delete_internet_gateway(self):
igw_id = self._get_param('InternetGatewayId') igw_id = self._get_param("InternetGatewayId")
if self.is_not_dryrun('DeleteInternetGateway'): if self.is_not_dryrun("DeleteInternetGateway"):
self.ec2_backend.delete_internet_gateway(igw_id) self.ec2_backend.delete_internet_gateway(igw_id)
template = self.response_template(DELETE_INTERNET_GATEWAY_RESPONSE) template = self.response_template(DELETE_INTERNET_GATEWAY_RESPONSE)
return template.render() return template.render()
@ -33,10 +30,10 @@ class InternetGateways(BaseResponse):
if "InternetGatewayId.1" in self.querystring: if "InternetGatewayId.1" in self.querystring:
igw_ids = self._get_multi_param("InternetGatewayId") igw_ids = self._get_multi_param("InternetGatewayId")
igws = self.ec2_backend.describe_internet_gateways( igws = self.ec2_backend.describe_internet_gateways(
igw_ids, filters=filter_dict) igw_ids, filters=filter_dict
)
else: else:
igws = self.ec2_backend.describe_internet_gateways( igws = self.ec2_backend.describe_internet_gateways(filters=filter_dict)
filters=filter_dict)
template = self.response_template(DESCRIBE_INTERNET_GATEWAYS_RESPONSE) template = self.response_template(DESCRIBE_INTERNET_GATEWAYS_RESPONSE)
return template.render(internet_gateways=igws) return template.render(internet_gateways=igws)
@ -44,20 +41,20 @@ class InternetGateways(BaseResponse):
def detach_internet_gateway(self): def detach_internet_gateway(self):
# TODO validate no instances with EIPs in VPC before detaching # TODO validate no instances with EIPs in VPC before detaching
# raise else DependencyViolationError() # raise else DependencyViolationError()
igw_id = self._get_param('InternetGatewayId') igw_id = self._get_param("InternetGatewayId")
vpc_id = self._get_param('VpcId') vpc_id = self._get_param("VpcId")
if self.is_not_dryrun('DetachInternetGateway'): if self.is_not_dryrun("DetachInternetGateway"):
self.ec2_backend.detach_internet_gateway(igw_id, vpc_id) self.ec2_backend.detach_internet_gateway(igw_id, vpc_id)
template = self.response_template(DETACH_INTERNET_GATEWAY_RESPONSE) template = self.response_template(DETACH_INTERNET_GATEWAY_RESPONSE)
return template.render() return template.render()
# Static XML success body rendered by attach_internet_gateway().
ATTACH_INTERNET_GATEWAY_RESPONSE = """<AttachInternetGatewayResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
  <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
  <return>true</return>
</AttachInternetGatewayResponse>"""
CREATE_INTERNET_GATEWAY_RESPONSE = u"""<CreateInternetGatewayResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/"> CREATE_INTERNET_GATEWAY_RESPONSE = """<CreateInternetGatewayResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId> <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<internetGateway> <internetGateway>
<internetGatewayId>{{ internet_gateway.id }}</internetGatewayId> <internetGatewayId>{{ internet_gateway.id }}</internetGatewayId>
@ -75,12 +72,12 @@ CREATE_INTERNET_GATEWAY_RESPONSE = u"""<CreateInternetGatewayResponse xmlns="htt
</internetGateway> </internetGateway>
</CreateInternetGatewayResponse>""" </CreateInternetGatewayResponse>"""
# Static XML success body rendered by delete_internet_gateway().
DELETE_INTERNET_GATEWAY_RESPONSE = """<DeleteInternetGatewayResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
  <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
  <return>true</return>
</DeleteInternetGatewayResponse>"""
DESCRIBE_INTERNET_GATEWAYS_RESPONSE = u"""<DescribeInternetGatewaysResponse xmlns="http://ec2.amazonaws.com/doc/2013-10- DESCRIBE_INTERNET_GATEWAYS_RESPONSE = """<DescribeInternetGatewaysResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-
15/"> 15/">
<requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId> <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
<internetGatewaySet> <internetGatewaySet>
@ -112,7 +109,7 @@ DESCRIBE_INTERNET_GATEWAYS_RESPONSE = u"""<DescribeInternetGatewaysResponse xmln
</internetGatewaySet> </internetGatewaySet>
</DescribeInternetGatewaysResponse>""" </DescribeInternetGatewaysResponse>"""
# Static XML success body rendered by detach_internet_gateway().
DETACH_INTERNET_GATEWAY_RESPONSE = """<DetachInternetGatewayResponse xmlns="http://ec2.amazonaws.com/doc/2013-10-15/">
  <requestId>59dbff89-35bd-4eac-99ed-be587EXAMPLE</requestId>
  <return>true</return>
</DetachInternetGatewayResponse>"""

Some files were not shown because too many files have changed in this diff Show More