diff --git a/.travis.yml b/.travis.yml index 77dd2ae55..3e6eb3809 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,7 @@ python: - 2.7 - 3.6 - 3.7 +- 3.8 env: - TEST_SERVER_MODE=false - TEST_SERVER_MODE=true @@ -17,7 +18,14 @@ install: python setup.py sdist if [ "$TEST_SERVER_MODE" = "true" ]; then - docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${TRAVIS_PYTHON_VERSION}-stretch /moto/travis_moto_server.sh & + if [ "$TRAVIS_PYTHON_VERSION" = "3.8" ]; then + # Python 3.8 does not provide Stretch images yet [1] + # [1] https://github.com/docker-library/python/issues/428 + PYTHON_DOCKER_TAG=${TRAVIS_PYTHON_VERSION}-buster + else + PYTHON_DOCKER_TAG=${TRAVIS_PYTHON_VERSION}-stretch + fi + docker run --rm -t --name motoserver -e TEST_SERVER_MODE=true -e AWS_SECRET_ACCESS_KEY=server_secret -e AWS_ACCESS_KEY_ID=server_key -v `pwd`:/moto -p 5000:5000 -v /var/run/docker.sock:/var/run/docker.sock python:${PYTHON_DOCKER_TAG} /moto/travis_moto_server.sh & fi travis_retry pip install boto==2.45.0 travis_retry pip install boto3 @@ -29,7 +37,8 @@ install: python wait_for.py fi script: -- make test +- make test-only +- if [[ $TRAVIS_PYTHON_VERSION == "3.7" ]]; then make lint; fi after_success: - coveralls before_deploy: diff --git a/Makefile b/Makefile index 2a7249760..ca0286984 100644 --- a/Makefile +++ b/Makefile @@ -14,12 +14,16 @@ init: lint: flake8 moto + black --check moto/ tests/ -test: lint +test-only: rm -f .coverage rm -rf cover @nosetests -sv --with-coverage --cover-html ./tests/ $(TEST_EXCLUDE) + +test: lint test-only + test_server: @TEST_SERVER_MODE=true nosetests -sv --with-coverage --cover-html ./tests/ diff --git a/README.md b/README.md index 7642009f3..4024328a9 100644 --- a/README.md +++ b/README.md @@ -7,9 +7,9 @@ [![Docs](https://readthedocs.org/projects/pip/badge/?version=stable)](http://docs.getmoto.org) ![PyPI](https://img.shields.io/pypi/v/moto.svg) ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/moto.svg) -![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg) +![PyPI - Downloads](https://img.shields.io/pypi/dw/moto.svg) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -# In a nutshell +## In a nutshell Moto is a library that allows your tests to easily mock out AWS Services. diff --git a/moto/__init__.py b/moto/__init__.py index 8e915933a..ed64413f8 100644 --- a/moto/__init__.py +++ b/moto/__init__.py @@ -1,64 +1,73 @@ from __future__ import unicode_literals -import logging + +# import logging # logging.getLogger('boto').setLevel(logging.CRITICAL) -__title__ = 'moto' -__version__ = '1.3.14.dev' +__title__ = "moto" +__version__ = "1.3.14.dev" -from .acm import mock_acm # flake8: noqa -from .apigateway import mock_apigateway, mock_apigateway_deprecated # flake8: noqa -from .athena import mock_athena # flake8: noqa -from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # flake8: noqa -from .awslambda import mock_lambda, mock_lambda_deprecated # flake8: noqa -from .cloudformation import mock_cloudformation, mock_cloudformation_deprecated # flake8: noqa -from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # flake8: noqa -from .cognitoidentity import mock_cognitoidentity, mock_cognitoidentity_deprecated # flake8: noqa -from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # flake8: noqa -from .config import mock_config # flake8: noqa -from .datapipeline import mock_datapipeline, mock_datapipeline_deprecated # flake8: noqa -from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # flake8: noqa -from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # flake8: noqa -from .dynamodbstreams import mock_dynamodbstreams # flake8: noqa -from .ec2 import mock_ec2, mock_ec2_deprecated # flake8: noqa -from .ecr import mock_ecr, mock_ecr_deprecated # flake8: noqa -from .ecs import mock_ecs, mock_ecs_deprecated # flake8: noqa -from .elb import mock_elb, mock_elb_deprecated # flake8: noqa -from .elbv2 import mock_elbv2 # flake8: noqa -from .emr import mock_emr, mock_emr_deprecated # flake8: noqa -from .events import mock_events # flake8: noqa -from .glacier import mock_glacier, mock_glacier_deprecated # flake8: noqa -from .glue import mock_glue # flake8: noqa -from .iam import mock_iam, mock_iam_deprecated # flake8: noqa -from .kinesis import mock_kinesis, mock_kinesis_deprecated # flake8: noqa -from .kms import mock_kms, mock_kms_deprecated # flake8: noqa -from .organizations import mock_organizations # flake8: noqa -from .opsworks import mock_opsworks, mock_opsworks_deprecated # flake8: noqa -from .polly import mock_polly # flake8: noqa -from .rds import mock_rds, mock_rds_deprecated # flake8: noqa -from .rds2 import mock_rds2, mock_rds2_deprecated # flake8: noqa -from .redshift import mock_redshift, mock_redshift_deprecated # flake8: noqa -from .resourcegroups import mock_resourcegroups # flake8: noqa -from .s3 import mock_s3, mock_s3_deprecated # flake8: noqa -from .ses import mock_ses, mock_ses_deprecated # flake8: noqa -from .secretsmanager import mock_secretsmanager # flake8: noqa -from .sns import mock_sns, mock_sns_deprecated # flake8: noqa -from .sqs import mock_sqs, mock_sqs_deprecated # flake8: noqa -from .stepfunctions import mock_stepfunctions # flake8: noqa -from .sts import mock_sts, mock_sts_deprecated # flake8: noqa -from .ssm import mock_ssm # flake8: noqa -from .route53 import mock_route53, mock_route53_deprecated # flake8: noqa -from .swf import mock_swf, mock_swf_deprecated # flake8: noqa -from .xray import mock_xray, mock_xray_client, XRaySegment # flake8: noqa -from .logs import mock_logs, mock_logs_deprecated # flake8: noqa -from .batch import mock_batch # flake8: noqa -from .resourcegroupstaggingapi import mock_resourcegroupstaggingapi # flake8: noqa -from .iot import mock_iot # flake8: noqa -from .iotdata import mock_iotdata # flake8: noqa +from .acm import mock_acm # noqa +from .apigateway import mock_apigateway, mock_apigateway_deprecated # noqa +from .athena import mock_athena # noqa +from .autoscaling import mock_autoscaling, mock_autoscaling_deprecated # noqa +from .awslambda import mock_lambda, mock_lambda_deprecated # noqa +from .cloudformation import mock_cloudformation, mock_cloudformation_deprecated # noqa +from .cloudwatch import mock_cloudwatch, mock_cloudwatch_deprecated # noqa +from .cognitoidentity import ( # noqa + mock_cognitoidentity, + mock_cognitoidentity_deprecated, +) +from .cognitoidp import mock_cognitoidp, mock_cognitoidp_deprecated # noqa +from .config import mock_config # noqa +from .datapipeline import mock_datapipeline, mock_datapipeline_deprecated # noqa +from .dynamodb import mock_dynamodb, mock_dynamodb_deprecated # noqa +from .dynamodb2 import mock_dynamodb2, mock_dynamodb2_deprecated # noqa +from .dynamodbstreams import mock_dynamodbstreams # noqa +from .ec2 import mock_ec2, mock_ec2_deprecated # noqa +from .ecr import mock_ecr, mock_ecr_deprecated # noqa +from .ecs import mock_ecs, mock_ecs_deprecated # noqa +from .elb import mock_elb, mock_elb_deprecated # noqa +from .elbv2 import mock_elbv2 # noqa +from .emr import mock_emr, mock_emr_deprecated # noqa +from .events import mock_events # noqa +from .glacier import mock_glacier, mock_glacier_deprecated # noqa +from .glue import mock_glue # noqa +from .iam import mock_iam, mock_iam_deprecated # noqa +from .kinesis import mock_kinesis, mock_kinesis_deprecated # noqa +from .kms import mock_kms, mock_kms_deprecated # noqa +from .organizations import mock_organizations # noqa +from .opsworks import mock_opsworks, mock_opsworks_deprecated # noqa +from .polly import mock_polly # noqa +from .rds import mock_rds, mock_rds_deprecated # noqa +from .rds2 import mock_rds2, mock_rds2_deprecated # noqa +from .redshift import mock_redshift, mock_redshift_deprecated # noqa +from .resourcegroups import mock_resourcegroups # noqa +from .s3 import mock_s3, mock_s3_deprecated # noqa +from .ses import mock_ses, mock_ses_deprecated # noqa +from .secretsmanager import mock_secretsmanager # noqa +from .sns import mock_sns, mock_sns_deprecated # noqa +from .sqs import mock_sqs, mock_sqs_deprecated # noqa +from .stepfunctions import mock_stepfunctions # noqa +from .sts import mock_sts, mock_sts_deprecated # noqa +from .ssm import mock_ssm # noqa +from .route53 import mock_route53, mock_route53_deprecated # noqa +from .swf import mock_swf, mock_swf_deprecated # noqa +from .xray import mock_xray, mock_xray_client, XRaySegment # noqa +from .logs import mock_logs, mock_logs_deprecated # noqa +from .batch import mock_batch # noqa +from .resourcegroupstaggingapi import mock_resourcegroupstaggingapi # noqa +from .iot import mock_iot # noqa +from .iotdata import mock_iotdata # noqa try: # Need to monkey-patch botocore requests back to underlying urllib3 classes - from botocore.awsrequest import HTTPSConnectionPool, HTTPConnectionPool, HTTPConnection, VerifiedHTTPSConnection + from botocore.awsrequest import ( + HTTPSConnectionPool, + HTTPConnectionPool, + HTTPConnection, + VerifiedHTTPSConnection, + ) except ImportError: pass else: diff --git a/moto/acm/__init__.py b/moto/acm/__init__.py index 6cd8a4aa5..07804282e 100644 --- a/moto/acm/__init__.py +++ b/moto/acm/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import acm_backends from ..core.models import base_decorator -acm_backend = acm_backends['us-east-1'] +acm_backend = acm_backends["us-east-1"] mock_acm = base_decorator(acm_backends) diff --git a/moto/acm/models.py b/moto/acm/models.py index b25dbcdff..a85017040 100644 --- a/moto/acm/models.py +++ b/moto/acm/models.py @@ -57,20 +57,29 @@ class AWSError(Exception): self.message = message def response(self): - resp = {'__type': self.TYPE, 'message': self.message} + resp = {"__type": self.TYPE, "message": self.message} return json.dumps(resp), dict(status=self.STATUS) class AWSValidationException(AWSError): - TYPE = 'ValidationException' + TYPE = "ValidationException" class AWSResourceNotFoundException(AWSError): - TYPE = 'ResourceNotFoundException' + TYPE = "ResourceNotFoundException" class CertBundle(BaseModel): - def __init__(self, certificate, private_key, chain=None, region='us-east-1', arn=None, cert_type='IMPORTED', cert_status='ISSUED'): + def __init__( + self, + certificate, + private_key, + chain=None, + region="us-east-1", + arn=None, + cert_type="IMPORTED", + cert_status="ISSUED", + ): self.created_at = datetime.datetime.now() self.cert = certificate self._cert = None @@ -87,7 +96,7 @@ class CertBundle(BaseModel): if self.chain is None: self.chain = GOOGLE_ROOT_CA else: - self.chain += b'\n' + GOOGLE_ROOT_CA + self.chain += b"\n" + GOOGLE_ROOT_CA # Takes care of PEM checking self.validate_pk() @@ -114,149 +123,209 @@ class CertBundle(BaseModel): sans.add(domain_name) sans = [cryptography.x509.DNSName(item) for item in sans] - key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) - subject = cryptography.x509.Name([ - cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, u"CA"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.LOCALITY_NAME, u"San Francisco"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"My Company"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, domain_name), - ]) - issuer = cryptography.x509.Name([ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon - cryptography.x509.NameAttribute(cryptography.x509.NameOID.COUNTRY_NAME, u"US"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATION_NAME, u"Amazon"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, u"Server CA 1B"), - cryptography.x509.NameAttribute(cryptography.x509.NameOID.COMMON_NAME, u"Amazon"), - ]) - cert = cryptography.x509.CertificateBuilder().subject_name( - subject - ).issuer_name( - issuer - ).public_key( - key.public_key() - ).serial_number( - cryptography.x509.random_serial_number() - ).not_valid_before( - datetime.datetime.utcnow() - ).not_valid_after( - datetime.datetime.utcnow() + datetime.timedelta(days=365) - ).add_extension( - cryptography.x509.SubjectAlternativeName(sans), - critical=False, - ).sign(key, hashes.SHA512(), default_backend()) + key = cryptography.hazmat.primitives.asymmetric.rsa.generate_private_key( + public_exponent=65537, key_size=2048, backend=default_backend() + ) + subject = cryptography.x509.Name( + [ + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.COUNTRY_NAME, "US" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.STATE_OR_PROVINCE_NAME, "CA" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.LOCALITY_NAME, "San Francisco" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.ORGANIZATION_NAME, "My Company" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.COMMON_NAME, domain_name + ), + ] + ) + issuer = cryptography.x509.Name( + [ # C = US, O = Amazon, OU = Server CA 1B, CN = Amazon + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.COUNTRY_NAME, "US" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.ORGANIZATION_NAME, "Amazon" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.ORGANIZATIONAL_UNIT_NAME, "Server CA 1B" + ), + cryptography.x509.NameAttribute( + cryptography.x509.NameOID.COMMON_NAME, "Amazon" + ), + ] + ) + cert = ( + cryptography.x509.CertificateBuilder() + .subject_name(subject) + .issuer_name(issuer) + .public_key(key.public_key()) + .serial_number(cryptography.x509.random_serial_number()) + .not_valid_before(datetime.datetime.utcnow()) + .not_valid_after(datetime.datetime.utcnow() + datetime.timedelta(days=365)) + .add_extension( + cryptography.x509.SubjectAlternativeName(sans), critical=False + ) + .sign(key, hashes.SHA512(), default_backend()) + ) cert_armored = cert.public_bytes(serialization.Encoding.PEM) private_key = key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption() + encryption_algorithm=serialization.NoEncryption(), ) - return cls(cert_armored, private_key, cert_type='AMAZON_ISSUED', cert_status='PENDING_VALIDATION', region=region) + return cls( + cert_armored, + private_key, + cert_type="AMAZON_ISSUED", + cert_status="PENDING_VALIDATION", + region=region, + ) def validate_pk(self): try: - self._key = serialization.load_pem_private_key(self.key, password=None, backend=default_backend()) + self._key = serialization.load_pem_private_key( + self.key, password=None, backend=default_backend() + ) if self._key.key_size > 2048: - AWSValidationException('The private key length is not supported. Only 1024-bit and 2048-bit are allowed.') + AWSValidationException( + "The private key length is not supported. Only 1024-bit and 2048-bit are allowed." + ) except Exception as err: if isinstance(err, AWSValidationException): raise - raise AWSValidationException('The private key is not PEM-encoded or is not valid.') + raise AWSValidationException( + "The private key is not PEM-encoded or is not valid." + ) def validate_certificate(self): try: - self._cert = cryptography.x509.load_pem_x509_certificate(self.cert, default_backend()) + self._cert = cryptography.x509.load_pem_x509_certificate( + self.cert, default_backend() + ) now = datetime.datetime.utcnow() if self._cert.not_valid_after < now: - raise AWSValidationException('The certificate has expired, is not valid.') + raise AWSValidationException( + "The certificate has expired, is not valid." + ) if self._cert.not_valid_before > now: - raise AWSValidationException('The certificate is not in effect yet, is not valid.') + raise AWSValidationException( + "The certificate is not in effect yet, is not valid." + ) # Extracting some common fields for ease of use # Have to search through cert.subject for OIDs - self.common_name = self._cert.subject.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value + self.common_name = self._cert.subject.get_attributes_for_oid( + cryptography.x509.OID_COMMON_NAME + )[0].value except Exception as err: if isinstance(err, AWSValidationException): raise - raise AWSValidationException('The certificate is not PEM-encoded or is not valid.') + raise AWSValidationException( + "The certificate is not PEM-encoded or is not valid." + ) def validate_chain(self): try: self._chain = [] - for cert_armored in self.chain.split(b'-\n-'): + for cert_armored in self.chain.split(b"-\n-"): # Would leave encoded but Py2 does not have raw binary strings cert_armored = cert_armored.decode() # Fix missing -'s on split - cert_armored = re.sub(r'^----B', '-----B', cert_armored) - cert_armored = re.sub(r'E----$', 'E-----', cert_armored) - cert = cryptography.x509.load_pem_x509_certificate(cert_armored.encode(), default_backend()) + cert_armored = re.sub(r"^----B", "-----B", cert_armored) + cert_armored = re.sub(r"E----$", "E-----", cert_armored) + cert = cryptography.x509.load_pem_x509_certificate( + cert_armored.encode(), default_backend() + ) self._chain.append(cert) now = datetime.datetime.now() if self._cert.not_valid_after < now: - raise AWSValidationException('The certificate chain has expired, is not valid.') + raise AWSValidationException( + "The certificate chain has expired, is not valid." + ) if self._cert.not_valid_before > now: - raise AWSValidationException('The certificate chain is not in effect yet, is not valid.') + raise AWSValidationException( + "The certificate chain is not in effect yet, is not valid." + ) except Exception as err: if isinstance(err, AWSValidationException): raise - raise AWSValidationException('The certificate is not PEM-encoded or is not valid.') + raise AWSValidationException( + "The certificate is not PEM-encoded or is not valid." + ) def check(self): # Basically, if the certificate is pending, and then checked again after 1 min # It will appear as if its been validated - if self.type == 'AMAZON_ISSUED' and self.status == 'PENDING_VALIDATION' and \ - (datetime.datetime.now() - self.created_at).total_seconds() > 60: # 1min - self.status = 'ISSUED' + if ( + self.type == "AMAZON_ISSUED" + and self.status == "PENDING_VALIDATION" + and (datetime.datetime.now() - self.created_at).total_seconds() > 60 + ): # 1min + self.status = "ISSUED" def describe(self): # 'RenewalSummary': {}, # Only when cert is amazon issued if self._key.key_size == 1024: - key_algo = 'RSA_1024' + key_algo = "RSA_1024" elif self._key.key_size == 2048: - key_algo = 'RSA_2048' + key_algo = "RSA_2048" else: - key_algo = 'EC_prime256v1' + key_algo = "EC_prime256v1" # Look for SANs - san_obj = self._cert.extensions.get_extension_for_oid(cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME) + san_obj = self._cert.extensions.get_extension_for_oid( + cryptography.x509.OID_SUBJECT_ALTERNATIVE_NAME + ) sans = [] if san_obj is not None: sans = [item.value for item in san_obj.value] result = { - 'Certificate': { - 'CertificateArn': self.arn, - 'DomainName': self.common_name, - 'InUseBy': [], - 'Issuer': self._cert.issuer.get_attributes_for_oid(cryptography.x509.OID_COMMON_NAME)[0].value, - 'KeyAlgorithm': key_algo, - 'NotAfter': datetime_to_epoch(self._cert.not_valid_after), - 'NotBefore': datetime_to_epoch(self._cert.not_valid_before), - 'Serial': self._cert.serial_number, - 'SignatureAlgorithm': self._cert.signature_algorithm_oid._name.upper().replace('ENCRYPTION', ''), - 'Status': self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED. - 'Subject': 'CN={0}'.format(self.common_name), - 'SubjectAlternativeNames': sans, - 'Type': self.type # One of IMPORTED, AMAZON_ISSUED + "Certificate": { + "CertificateArn": self.arn, + "DomainName": self.common_name, + "InUseBy": [], + "Issuer": self._cert.issuer.get_attributes_for_oid( + cryptography.x509.OID_COMMON_NAME + )[0].value, + "KeyAlgorithm": key_algo, + "NotAfter": datetime_to_epoch(self._cert.not_valid_after), + "NotBefore": datetime_to_epoch(self._cert.not_valid_before), + "Serial": self._cert.serial_number, + "SignatureAlgorithm": self._cert.signature_algorithm_oid._name.upper().replace( + "ENCRYPTION", "" + ), + "Status": self.status, # One of PENDING_VALIDATION, ISSUED, INACTIVE, EXPIRED, VALIDATION_TIMED_OUT, REVOKED, FAILED. + "Subject": "CN={0}".format(self.common_name), + "SubjectAlternativeNames": sans, + "Type": self.type, # One of IMPORTED, AMAZON_ISSUED } } - if self.type == 'IMPORTED': - result['Certificate']['ImportedAt'] = datetime_to_epoch(self.created_at) + if self.type == "IMPORTED": + result["Certificate"]["ImportedAt"] = datetime_to_epoch(self.created_at) else: - result['Certificate']['CreatedAt'] = datetime_to_epoch(self.created_at) - result['Certificate']['IssuedAt'] = datetime_to_epoch(self.created_at) + result["Certificate"]["CreatedAt"] = datetime_to_epoch(self.created_at) + result["Certificate"]["IssuedAt"] = datetime_to_epoch(self.created_at) return result @@ -264,7 +333,7 @@ class CertBundle(BaseModel): return self.arn def __repr__(self): - return '' + return "" class AWSCertificateManagerBackend(BaseBackend): @@ -281,7 +350,9 @@ class AWSCertificateManagerBackend(BaseBackend): @staticmethod def _arn_not_found(arn): - msg = 'Certificate with arn {0} not found in account {1}'.format(arn, DEFAULT_ACCOUNT_ID) + msg = "Certificate with arn {0} not found in account {1}".format( + arn, DEFAULT_ACCOUNT_ID + ) return AWSResourceNotFoundException(msg) def _get_arn_from_idempotency_token(self, token): @@ -298,17 +369,20 @@ class AWSCertificateManagerBackend(BaseBackend): """ now = datetime.datetime.now() if token in self._idempotency_tokens: - if self._idempotency_tokens[token]['expires'] < now: + if self._idempotency_tokens[token]["expires"] < now: # Token has expired, new request del self._idempotency_tokens[token] return None else: - return self._idempotency_tokens[token]['arn'] + return self._idempotency_tokens[token]["arn"] return None def _set_idempotency_token_arn(self, token, arn): - self._idempotency_tokens[token] = {'arn': arn, 'expires': datetime.datetime.now() + datetime.timedelta(hours=1)} + self._idempotency_tokens[token] = { + "arn": arn, + "expires": datetime.datetime.now() + datetime.timedelta(hours=1), + } def import_cert(self, certificate, private_key, chain=None, arn=None): if arn is not None: @@ -316,7 +390,9 @@ class AWSCertificateManagerBackend(BaseBackend): raise self._arn_not_found(arn) else: # Will reuse provided ARN - bundle = CertBundle(certificate, private_key, chain=chain, region=region, arn=arn) + bundle = CertBundle( + certificate, private_key, chain=chain, region=region, arn=arn + ) else: # Will generate a random ARN bundle = CertBundle(certificate, private_key, chain=chain, region=region) @@ -351,13 +427,21 @@ class AWSCertificateManagerBackend(BaseBackend): del self._certificates[arn] - def request_certificate(self, domain_name, domain_validation_options, idempotency_token, subject_alt_names): + def request_certificate( + self, + domain_name, + domain_validation_options, + idempotency_token, + subject_alt_names, + ): if idempotency_token is not None: arn = self._get_arn_from_idempotency_token(idempotency_token) if arn is not None: return arn - cert = CertBundle.generate_cert(domain_name, region=self.region, sans=subject_alt_names) + cert = CertBundle.generate_cert( + domain_name, region=self.region, sans=subject_alt_names + ) if idempotency_token is not None: self._set_idempotency_token_arn(idempotency_token, cert.arn) self._certificates[cert.arn] = cert @@ -369,8 +453,8 @@ class AWSCertificateManagerBackend(BaseBackend): cert_bundle = self.get_certificate(arn) for tag in tags: - key = tag['Key'] - value = tag.get('Value', None) + key = tag["Key"] + value = tag.get("Value", None) cert_bundle.tags[key] = value def remove_tags_from_certificate(self, arn, tags): @@ -378,8 +462,8 @@ class AWSCertificateManagerBackend(BaseBackend): cert_bundle = self.get_certificate(arn) for tag in tags: - key = tag['Key'] - value = tag.get('Value', None) + key = tag["Key"] + value = tag.get("Value", None) try: # If value isnt provided, just delete key diff --git a/moto/acm/responses.py b/moto/acm/responses.py index 0d0ac640b..13b22fa95 100644 --- a/moto/acm/responses.py +++ b/moto/acm/responses.py @@ -7,7 +7,6 @@ from .models import acm_backends, AWSError, AWSValidationException class AWSCertificateManagerResponse(BaseResponse): - @property def acm_backend(self): """ @@ -29,40 +28,49 @@ class AWSCertificateManagerResponse(BaseResponse): return self.request_params.get(param, default) def add_tags_to_certificate(self): - arn = self._get_param('CertificateArn') - tags = self._get_param('Tags') + arn = self._get_param("CertificateArn") + tags = self._get_param("Tags") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: self.acm_backend.add_tags_to_certificate(arn, tags) except AWSError as err: return err.response() - return '' + return "" def delete_certificate(self): - arn = self._get_param('CertificateArn') + arn = self._get_param("CertificateArn") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: self.acm_backend.delete_certificate(arn) except AWSError as err: return err.response() - return '' + return "" def describe_certificate(self): - arn = self._get_param('CertificateArn') + arn = self._get_param("CertificateArn") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: cert_bundle = self.acm_backend.get_certificate(arn) @@ -72,11 +80,14 @@ class AWSCertificateManagerResponse(BaseResponse): return json.dumps(cert_bundle.describe()) def get_certificate(self): - arn = self._get_param('CertificateArn') + arn = self._get_param("CertificateArn") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: cert_bundle = self.acm_backend.get_certificate(arn) @@ -84,8 +95,8 @@ class AWSCertificateManagerResponse(BaseResponse): return err.response() result = { - 'Certificate': cert_bundle.cert.decode(), - 'CertificateChain': cert_bundle.chain.decode() + "Certificate": cert_bundle.cert.decode(), + "CertificateChain": cert_bundle.chain.decode(), } return json.dumps(result) @@ -102,104 +113,129 @@ class AWSCertificateManagerResponse(BaseResponse): :return: str(JSON) for response """ - certificate = self._get_param('Certificate') - private_key = self._get_param('PrivateKey') - chain = self._get_param('CertificateChain') # Optional - current_arn = self._get_param('CertificateArn') # Optional + certificate = self._get_param("Certificate") + private_key = self._get_param("PrivateKey") + chain = self._get_param("CertificateChain") # Optional + current_arn = self._get_param("CertificateArn") # Optional # Simple parameter decoding. Rather do it here as its a data transport decision not part of the # actual data try: certificate = base64.standard_b64decode(certificate) except Exception: - return AWSValidationException('The certificate is not PEM-encoded or is not valid.').response() + return AWSValidationException( + "The certificate is not PEM-encoded or is not valid." + ).response() try: private_key = base64.standard_b64decode(private_key) except Exception: - return AWSValidationException('The private key is not PEM-encoded or is not valid.').response() + return AWSValidationException( + "The private key is not PEM-encoded or is not valid." + ).response() if chain is not None: try: chain = base64.standard_b64decode(chain) except Exception: - return AWSValidationException('The certificate chain is not PEM-encoded or is not valid.').response() + return AWSValidationException( + "The certificate chain is not PEM-encoded or is not valid." + ).response() try: - arn = self.acm_backend.import_cert(certificate, private_key, chain=chain, arn=current_arn) + arn = self.acm_backend.import_cert( + certificate, private_key, chain=chain, arn=current_arn + ) except AWSError as err: return err.response() - return json.dumps({'CertificateArn': arn}) + return json.dumps({"CertificateArn": arn}) def list_certificates(self): certs = [] - statuses = self._get_param('CertificateStatuses') + statuses = self._get_param("CertificateStatuses") for cert_bundle in self.acm_backend.get_certificates_list(statuses): - certs.append({ - 'CertificateArn': cert_bundle.arn, - 'DomainName': cert_bundle.common_name - }) + certs.append( + { + "CertificateArn": cert_bundle.arn, + "DomainName": cert_bundle.common_name, + } + ) - result = {'CertificateSummaryList': certs} + result = {"CertificateSummaryList": certs} return json.dumps(result) def list_tags_for_certificate(self): - arn = self._get_param('CertificateArn') + arn = self._get_param("CertificateArn") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return {'__type': 'MissingParameter', 'message': msg}, dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return {"__type": "MissingParameter", "message": msg}, dict(status=400) try: cert_bundle = self.acm_backend.get_certificate(arn) except AWSError as err: return err.response() - result = {'Tags': []} + result = {"Tags": []} # Tag "objects" can not contain the Value part for key, value in cert_bundle.tags.items(): - tag_dict = {'Key': key} + tag_dict = {"Key": key} if value is not None: - tag_dict['Value'] = value - result['Tags'].append(tag_dict) + tag_dict["Value"] = value + result["Tags"].append(tag_dict) return json.dumps(result) def remove_tags_from_certificate(self): - arn = self._get_param('CertificateArn') - tags = self._get_param('Tags') + arn = self._get_param("CertificateArn") + tags = self._get_param("Tags") if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: self.acm_backend.remove_tags_from_certificate(arn, tags) except AWSError as err: return err.response() - return '' + return "" def request_certificate(self): - domain_name = self._get_param('DomainName') - domain_validation_options = self._get_param('DomainValidationOptions') # is ignored atm - idempotency_token = self._get_param('IdempotencyToken') - subject_alt_names = self._get_param('SubjectAlternativeNames') + domain_name = self._get_param("DomainName") + domain_validation_options = self._get_param( + "DomainValidationOptions" + ) # is ignored atm + idempotency_token = self._get_param("IdempotencyToken") + subject_alt_names = self._get_param("SubjectAlternativeNames") if subject_alt_names is not None and len(subject_alt_names) > 10: # There is initial AWS limit of 10 - msg = 'An ACM limit has been exceeded. Need to request SAN limit to be raised' - return json.dumps({'__type': 'LimitExceededException', 'message': msg}), dict(status=400) + msg = ( + "An ACM limit has been exceeded. Need to request SAN limit to be raised" + ) + return ( + json.dumps({"__type": "LimitExceededException", "message": msg}), + dict(status=400), + ) try: - arn = self.acm_backend.request_certificate(domain_name, domain_validation_options, idempotency_token, subject_alt_names) + arn = self.acm_backend.request_certificate( + domain_name, + domain_validation_options, + idempotency_token, + subject_alt_names, + ) except AWSError as err: return err.response() - return json.dumps({'CertificateArn': arn}) + return json.dumps({"CertificateArn": arn}) def resend_validation_email(self): - arn = self._get_param('CertificateArn') - domain = self._get_param('Domain') + arn = self._get_param("CertificateArn") + domain = self._get_param("Domain") # ValidationDomain not used yet. # Contains domain which is equal to or a subset of Domain # that AWS will send validation emails to @@ -207,18 +243,21 @@ class AWSCertificateManagerResponse(BaseResponse): # validation_domain = self._get_param('ValidationDomain') if arn is None: - msg = 'A required parameter for the specified action is not supplied.' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "A required parameter for the specified action is not supplied." + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: cert_bundle = self.acm_backend.get_certificate(arn) if cert_bundle.common_name != domain: - msg = 'Parameter Domain does not match certificate domain' - _type = 'InvalidDomainValidationOptionsException' - return json.dumps({'__type': _type, 'message': msg}), dict(status=400) + msg = "Parameter Domain does not match certificate domain" + _type = "InvalidDomainValidationOptionsException" + return json.dumps({"__type": _type, "message": msg}), dict(status=400) except AWSError as err: return err.response() - return '' + return "" diff --git a/moto/acm/urls.py b/moto/acm/urls.py index 20acbb3f4..8a8d3e2ef 100644 --- a/moto/acm/urls.py +++ b/moto/acm/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import AWSCertificateManagerResponse -url_bases = [ - "https?://acm.(.+).amazonaws.com", -] +url_bases = ["https?://acm.(.+).amazonaws.com"] -url_paths = { - '{0}/$': AWSCertificateManagerResponse.dispatch, -} +url_paths = {"{0}/$": AWSCertificateManagerResponse.dispatch} diff --git a/moto/acm/utils.py b/moto/acm/utils.py index b3c441454..6d695d95c 100644 --- a/moto/acm/utils.py +++ b/moto/acm/utils.py @@ -4,4 +4,6 @@ import uuid def make_arn_for_certificate(account_id, region_name): # Example # arn:aws:acm:eu-west-2:764371465172:certificate/c4b738b8-56fe-4b3a-b841-1c047654780b - return "arn:aws:acm:{0}:{1}:certificate/{2}".format(region_name, account_id, uuid.uuid4()) + return "arn:aws:acm:{0}:{1}:certificate/{2}".format( + region_name, account_id, uuid.uuid4() + ) diff --git a/moto/apigateway/__init__.py b/moto/apigateway/__init__.py index 98b2058d9..42da3db53 100644 --- a/moto/apigateway/__init__.py +++ b/moto/apigateway/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import apigateway_backends from ..core.models import base_decorator, deprecated_base_decorator -apigateway_backend = apigateway_backends['us-east-1'] +apigateway_backend = apigateway_backends["us-east-1"] mock_apigateway = base_decorator(apigateway_backends) mock_apigateway_deprecated = deprecated_base_decorator(apigateway_backends) diff --git a/moto/apigateway/exceptions.py b/moto/apigateway/exceptions.py index 62fa24392..98845d2f0 100644 --- a/moto/apigateway/exceptions.py +++ b/moto/apigateway/exceptions.py @@ -7,7 +7,8 @@ class StageNotFoundException(RESTError): def __init__(self): super(StageNotFoundException, self).__init__( - "NotFoundException", "Invalid stage identifier specified") + "NotFoundException", "Invalid stage identifier specified" + ) class ApiKeyNotFoundException(RESTError): @@ -15,4 +16,5 @@ class ApiKeyNotFoundException(RESTError): def __init__(self): super(ApiKeyNotFoundException, self).__init__( - "NotFoundException", "Invalid API Key identifier specified") + "NotFoundException", "Invalid API Key identifier specified" + ) diff --git a/moto/apigateway/models.py b/moto/apigateway/models.py index 6f7fc97a9..f7b26e5e2 100644 --- a/moto/apigateway/models.py +++ b/moto/apigateway/models.py @@ -17,39 +17,33 @@ STAGE_URL = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_nam class Deployment(BaseModel, dict): - def __init__(self, deployment_id, name, description=""): super(Deployment, self).__init__() - self['id'] = deployment_id - self['stageName'] = name - self['description'] = description - self['createdDate'] = int(time.time()) + self["id"] = deployment_id + self["stageName"] = name + self["description"] = description + self["createdDate"] = int(time.time()) class IntegrationResponse(BaseModel, dict): - def __init__(self, status_code, selection_pattern=None): - self['responseTemplates'] = {"application/json": None} - self['statusCode'] = status_code + self["responseTemplates"] = {"application/json": None} + self["statusCode"] = status_code if selection_pattern: - self['selectionPattern'] = selection_pattern + self["selectionPattern"] = selection_pattern class Integration(BaseModel, dict): - def __init__(self, integration_type, uri, http_method, request_templates=None): super(Integration, self).__init__() - self['type'] = integration_type - self['uri'] = uri - self['httpMethod'] = http_method - self['requestTemplates'] = request_templates - self["integrationResponses"] = { - "200": IntegrationResponse(200) - } + self["type"] = integration_type + self["uri"] = uri + self["httpMethod"] = http_method + self["requestTemplates"] = request_templates + self["integrationResponses"] = {"200": IntegrationResponse(200)} def create_integration_response(self, status_code, selection_pattern): - integration_response = IntegrationResponse( - status_code, selection_pattern) + integration_response = IntegrationResponse(status_code, selection_pattern) self["integrationResponses"][status_code] = integration_response return integration_response @@ -61,25 +55,25 @@ class Integration(BaseModel, dict): class MethodResponse(BaseModel, dict): - def __init__(self, status_code): super(MethodResponse, self).__init__() - self['statusCode'] = status_code + self["statusCode"] = status_code class Method(BaseModel, dict): - def __init__(self, method_type, authorization_type): super(Method, self).__init__() - self.update(dict( - httpMethod=method_type, - authorizationType=authorization_type, - authorizerId=None, - apiKeyRequired=None, - requestParameters=None, - requestModels=None, - methodIntegration=None, - )) + self.update( + dict( + httpMethod=method_type, + authorizationType=authorization_type, + authorizerId=None, + apiKeyRequired=None, + requestParameters=None, + requestModels=None, + methodIntegration=None, + ) + ) self.method_responses = {} def create_response(self, response_code): @@ -95,16 +89,13 @@ class Method(BaseModel, dict): class Resource(BaseModel): - def __init__(self, id, region_name, api_id, path_part, parent_id): self.id = id self.region_name = region_name self.api_id = api_id self.path_part = path_part self.parent_id = parent_id - self.resource_methods = { - 'GET': {} - } + self.resource_methods = {"GET": {}} def to_dict(self): response = { @@ -113,8 +104,8 @@ class Resource(BaseModel): "resourceMethods": self.resource_methods, } if self.parent_id: - response['parentId'] = self.parent_id - response['pathPart'] = self.path_part + response["parentId"] = self.parent_id + response["pathPart"] = self.path_part return response def get_path(self): @@ -125,102 +116,112 @@ class Resource(BaseModel): backend = apigateway_backends[self.region_name] parent = backend.get_resource(self.api_id, self.parent_id) parent_path = parent.get_path() - if parent_path != '/': # Root parent - parent_path += '/' + if parent_path != "/": # Root parent + parent_path += "/" return parent_path else: - return '' + return "" def get_response(self, request): integration = self.get_integration(request.method) - integration_type = integration['type'] + integration_type = integration["type"] - if integration_type == 'HTTP': - uri = integration['uri'] - requests_func = getattr(requests, integration[ - 'httpMethod'].lower()) + if integration_type == "HTTP": + uri = integration["uri"] + requests_func = getattr(requests, integration["httpMethod"].lower()) response = requests_func(uri) else: raise NotImplementedError( - "The {0} type has not been implemented".format(integration_type)) + "The {0} type has not been implemented".format(integration_type) + ) return response.status_code, response.text def add_method(self, method_type, authorization_type): - method = Method(method_type=method_type, - authorization_type=authorization_type) + method = Method(method_type=method_type, authorization_type=authorization_type) self.resource_methods[method_type] = method return method def get_method(self, method_type): return self.resource_methods[method_type] - def add_integration(self, method_type, integration_type, uri, request_templates=None): + def add_integration( + self, method_type, integration_type, uri, request_templates=None + ): integration = Integration( - integration_type, uri, method_type, request_templates=request_templates) - self.resource_methods[method_type]['methodIntegration'] = integration + integration_type, uri, method_type, request_templates=request_templates + ) + self.resource_methods[method_type]["methodIntegration"] = integration return integration def get_integration(self, method_type): - return self.resource_methods[method_type]['methodIntegration'] + return self.resource_methods[method_type]["methodIntegration"] def delete_integration(self, method_type): - return self.resource_methods[method_type].pop('methodIntegration') + return self.resource_methods[method_type].pop("methodIntegration") class Stage(BaseModel, dict): - - def __init__(self, name=None, deployment_id=None, variables=None, - description='', cacheClusterEnabled=False, cacheClusterSize=None): + def __init__( + self, + name=None, + deployment_id=None, + variables=None, + description="", + cacheClusterEnabled=False, + cacheClusterSize=None, + ): super(Stage, self).__init__() if variables is None: variables = {} - self['stageName'] = name - self['deploymentId'] = deployment_id - self['methodSettings'] = {} - self['variables'] = variables - self['description'] = description - self['cacheClusterEnabled'] = cacheClusterEnabled - if self['cacheClusterEnabled']: - self['cacheClusterSize'] = str(0.5) + self["stageName"] = name + self["deploymentId"] = deployment_id + self["methodSettings"] = {} + self["variables"] = variables + self["description"] = description + self["cacheClusterEnabled"] = cacheClusterEnabled + if self["cacheClusterEnabled"]: + self["cacheClusterSize"] = str(0.5) if cacheClusterSize is not None: - self['cacheClusterSize'] = str(cacheClusterSize) + self["cacheClusterSize"] = str(cacheClusterSize) def apply_operations(self, patch_operations): for op in patch_operations: - if 'variables/' in op['path']: + if "variables/" in op["path"]: self._apply_operation_to_variables(op) - elif '/cacheClusterEnabled' in op['path']: - self['cacheClusterEnabled'] = self._str2bool(op['value']) - if 'cacheClusterSize' not in self and self['cacheClusterEnabled']: - self['cacheClusterSize'] = str(0.5) - elif '/cacheClusterSize' in op['path']: - self['cacheClusterSize'] = str(float(op['value'])) - elif '/description' in op['path']: - self['description'] = op['value'] - elif '/deploymentId' in op['path']: - self['deploymentId'] = op['value'] - elif op['op'] == 'replace': + elif "/cacheClusterEnabled" in op["path"]: + self["cacheClusterEnabled"] = self._str2bool(op["value"]) + if "cacheClusterSize" not in self and self["cacheClusterEnabled"]: + self["cacheClusterSize"] = str(0.5) + elif "/cacheClusterSize" in op["path"]: + self["cacheClusterSize"] = str(float(op["value"])) + elif "/description" in op["path"]: + self["description"] = op["value"] + elif "/deploymentId" in op["path"]: + self["deploymentId"] = op["value"] + elif op["op"] == "replace": # Method Settings drop into here # (e.g., path could be '/*/*/logging/loglevel') - split_path = op['path'].split('/', 3) + split_path = op["path"].split("/", 3) if len(split_path) != 4: continue self._patch_method_setting( - '/'.join(split_path[1:3]), split_path[3], op['value']) + "/".join(split_path[1:3]), split_path[3], op["value"] + ) else: - raise Exception( - 'Patch operation "%s" not implemented' % op['op']) + raise Exception('Patch operation "%s" not implemented' % op["op"]) return self def _patch_method_setting(self, resource_path_and_method, key, value): updated_key = self._method_settings_translations(key) if updated_key is not None: - if resource_path_and_method not in self['methodSettings']: - self['methodSettings'][ - resource_path_and_method] = self._get_default_method_settings() - self['methodSettings'][resource_path_and_method][ - updated_key] = self._convert_to_type(updated_key, value) + if resource_path_and_method not in self["methodSettings"]: + self["methodSettings"][ + resource_path_and_method + ] = self._get_default_method_settings() + self["methodSettings"][resource_path_and_method][ + updated_key + ] = self._convert_to_type(updated_key, value) def _get_default_method_settings(self): return { @@ -232,21 +233,21 @@ class Stage(BaseModel, dict): "cacheDataEncrypted": True, "cachingEnabled": False, "throttlingBurstLimit": 2000, - "requireAuthorizationForCacheControl": True + "requireAuthorizationForCacheControl": True, } def _method_settings_translations(self, key): mappings = { - 'metrics/enabled': 'metricsEnabled', - 'logging/loglevel': 'loggingLevel', - 'logging/dataTrace': 'dataTraceEnabled', - 'throttling/burstLimit': 'throttlingBurstLimit', - 'throttling/rateLimit': 'throttlingRateLimit', - 'caching/enabled': 'cachingEnabled', - 'caching/ttlInSeconds': 'cacheTtlInSeconds', - 'caching/dataEncrypted': 'cacheDataEncrypted', - 'caching/requireAuthorizationForCacheControl': 'requireAuthorizationForCacheControl', - 'caching/unauthorizedCacheControlHeaderStrategy': 'unauthorizedCacheControlHeaderStrategy' + "metrics/enabled": "metricsEnabled", + "logging/loglevel": "loggingLevel", + "logging/dataTrace": "dataTraceEnabled", + "throttling/burstLimit": "throttlingBurstLimit", + "throttling/rateLimit": "throttlingRateLimit", + "caching/enabled": "cachingEnabled", + "caching/ttlInSeconds": "cacheTtlInSeconds", + "caching/dataEncrypted": "cacheDataEncrypted", + "caching/requireAuthorizationForCacheControl": "requireAuthorizationForCacheControl", + "caching/unauthorizedCacheControlHeaderStrategy": "unauthorizedCacheControlHeaderStrategy", } if key in mappings: @@ -259,26 +260,26 @@ class Stage(BaseModel, dict): def _convert_to_type(self, key, val): type_mappings = { - 'metricsEnabled': 'bool', - 'loggingLevel': 'str', - 'dataTraceEnabled': 'bool', - 'throttlingBurstLimit': 'int', - 'throttlingRateLimit': 'float', - 'cachingEnabled': 'bool', - 'cacheTtlInSeconds': 'int', - 'cacheDataEncrypted': 'bool', - 'requireAuthorizationForCacheControl': 'bool', - 'unauthorizedCacheControlHeaderStrategy': 'str' + "metricsEnabled": "bool", + "loggingLevel": "str", + "dataTraceEnabled": "bool", + "throttlingBurstLimit": "int", + "throttlingRateLimit": "float", + "cachingEnabled": "bool", + "cacheTtlInSeconds": "int", + "cacheDataEncrypted": "bool", + "requireAuthorizationForCacheControl": "bool", + "unauthorizedCacheControlHeaderStrategy": "str", } if key in type_mappings: type_value = type_mappings[key] - if type_value == 'bool': + if type_value == "bool": return self._str2bool(val) - elif type_value == 'int': + elif type_value == "int": return int(val) - elif type_value == 'float': + elif type_value == "float": return float(val) else: return str(val) @@ -286,44 +287,55 @@ class Stage(BaseModel, dict): return str(val) def _apply_operation_to_variables(self, op): - key = op['path'][op['path'].rindex("variables/") + 10:] - if op['op'] == 'remove': - self['variables'].pop(key, None) - elif op['op'] == 'replace': - self['variables'][key] = op['value'] + key = op["path"][op["path"].rindex("variables/") + 10 :] + if op["op"] == "remove": + self["variables"].pop(key, None) + elif op["op"] == "replace": + self["variables"][key] = op["value"] else: - raise Exception('Patch operation "%s" not implemented' % op['op']) + raise Exception('Patch operation "%s" not implemented' % op["op"]) class ApiKey(BaseModel, dict): - - def __init__(self, name=None, description=None, enabled=True, - generateDistinctId=False, value=None, stageKeys=None, tags=None, customerId=None): + def __init__( + self, + name=None, + description=None, + enabled=True, + generateDistinctId=False, + value=None, + stageKeys=None, + tags=None, + customerId=None, + ): super(ApiKey, self).__init__() - self['id'] = create_id() - self['value'] = value if value else ''.join(random.sample(string.ascii_letters + string.digits, 40)) - self['name'] = name - self['customerId'] = customerId - self['description'] = description - self['enabled'] = enabled - self['createdDate'] = self['lastUpdatedDate'] = int(time.time()) - self['stageKeys'] = stageKeys - self['tags'] = tags + self["id"] = create_id() + self["value"] = ( + value + if value + else "".join(random.sample(string.ascii_letters + string.digits, 40)) + ) + self["name"] = name + self["customerId"] = customerId + self["description"] = description + self["enabled"] = enabled + self["createdDate"] = self["lastUpdatedDate"] = int(time.time()) + self["stageKeys"] = stageKeys + self["tags"] = tags def update_operations(self, patch_operations): for op in patch_operations: - if op['op'] == 'replace': - if '/name' in op['path']: - self['name'] = op['value'] - elif '/customerId' in op['path']: - self['customerId'] = op['value'] - elif '/description' in op['path']: - self['description'] = op['value'] - elif '/enabled' in op['path']: - self['enabled'] = self._str2bool(op['value']) + if op["op"] == "replace": + if "/name" in op["path"]: + self["name"] = op["value"] + elif "/customerId" in op["path"]: + self["customerId"] = op["value"] + elif "/description" in op["path"]: + self["description"] = op["value"] + elif "/enabled" in op["path"]: + self["enabled"] = self._str2bool(op["value"]) else: - raise Exception( - 'Patch operation "%s" not implemented' % op['op']) + raise Exception('Patch operation "%s" not implemented' % op["op"]) return self def _str2bool(self, v): @@ -331,31 +343,35 @@ class ApiKey(BaseModel, dict): class UsagePlan(BaseModel, dict): - - def __init__(self, name=None, description=None, apiStages=None, - throttle=None, quota=None, tags=None): + def __init__( + self, + name=None, + description=None, + apiStages=None, + throttle=None, + quota=None, + tags=None, + ): super(UsagePlan, self).__init__() - self['id'] = create_id() - self['name'] = name - self['description'] = description - self['apiStages'] = apiStages if apiStages else [] - self['throttle'] = throttle - self['quota'] = quota - self['tags'] = tags + self["id"] = create_id() + self["name"] = name + self["description"] = description + self["apiStages"] = apiStages if apiStages else [] + self["throttle"] = throttle + self["quota"] = quota + self["tags"] = tags class UsagePlanKey(BaseModel, dict): - def __init__(self, id, type, name, value): super(UsagePlanKey, self).__init__() - self['id'] = id - self['name'] = name - self['type'] = type - self['value'] = value + self["id"] = id + self["name"] = name + self["type"] = type + self["value"] = value class RestAPI(BaseModel): - def __init__(self, id, region_name, name, description): self.id = id self.region_name = region_name @@ -367,7 +383,7 @@ class RestAPI(BaseModel): self.stages = {} self.resources = {} - self.add_child('/') # Add default child + self.add_child("/") # Add default child def __repr__(self): return str(self.id) @@ -382,8 +398,13 @@ class RestAPI(BaseModel): def add_child(self, path, parent_id=None): child_id = create_id() - child = Resource(id=child_id, region_name=self.region_name, - api_id=self.id, path_part=path, parent_id=parent_id) + child = Resource( + id=child_id, + region_name=self.region_name, + api_id=self.id, + path_part=path, + parent_id=parent_id, + ) self.resources[child_id] = child return child @@ -395,36 +416,53 @@ class RestAPI(BaseModel): def resource_callback(self, request): path = path_url(request.url) - path_after_stage_name = '/'.join(path.split("/")[2:]) + path_after_stage_name = "/".join(path.split("/")[2:]) if not path_after_stage_name: - path_after_stage_name = '/' + path_after_stage_name = "/" resource = self.get_resource_for_path(path_after_stage_name) status_code, response = resource.get_response(request) return status_code, {}, response def update_integration_mocks(self, stage_name): - stage_url_lower = STAGE_URL.format(api_id=self.id.lower(), - region_name=self.region_name, stage_name=stage_name) - stage_url_upper = STAGE_URL.format(api_id=self.id.upper(), - region_name=self.region_name, stage_name=stage_name) + stage_url_lower = STAGE_URL.format( + api_id=self.id.lower(), region_name=self.region_name, stage_name=stage_name + ) + stage_url_upper = STAGE_URL.format( + api_id=self.id.upper(), region_name=self.region_name, stage_name=stage_name + ) for url in [stage_url_lower, stage_url_upper]: - responses._default_mock._matches.insert(0, + responses._default_mock._matches.insert( + 0, responses.CallbackResponse( url=url, method=responses.GET, callback=self.resource_callback, content_type="text/plain", match_querystring=False, - ) + ), ) - def create_stage(self, name, deployment_id, variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None): + def create_stage( + self, + name, + deployment_id, + variables=None, + description="", + cacheClusterEnabled=None, + cacheClusterSize=None, + ): if variables is None: variables = {} - stage = Stage(name=name, deployment_id=deployment_id, variables=variables, - description=description, cacheClusterSize=cacheClusterSize, cacheClusterEnabled=cacheClusterEnabled) + stage = Stage( + name=name, + deployment_id=deployment_id, + variables=variables, + description=description, + cacheClusterSize=cacheClusterSize, + cacheClusterEnabled=cacheClusterEnabled, + ) self.stages[name] = stage self.update_integration_mocks(name) return stage @@ -436,7 +474,8 @@ class RestAPI(BaseModel): deployment = Deployment(deployment_id, name, description) self.deployments[deployment_id] = deployment self.stages[name] = Stage( - name=name, deployment_id=deployment_id, variables=stage_variables) + name=name, deployment_id=deployment_id, variables=stage_variables + ) self.update_integration_mocks(name) return deployment @@ -455,7 +494,6 @@ class RestAPI(BaseModel): class APIGatewayBackend(BaseBackend): - def __init__(self, region_name): super(APIGatewayBackend, self).__init__() self.apis = {} @@ -497,10 +535,7 @@ class APIGatewayBackend(BaseBackend): def create_resource(self, function_id, parent_resource_id, path_part): api = self.get_rest_api(function_id) - child = api.add_child( - path=path_part, - parent_id=parent_resource_id, - ) + child = api.add_child(path=path_part, parent_id=parent_resource_id) return child def delete_resource(self, function_id, resource_id): @@ -529,13 +564,27 @@ class APIGatewayBackend(BaseBackend): api = self.get_rest_api(function_id) return api.get_stages() - def create_stage(self, function_id, stage_name, deploymentId, - variables=None, description='', cacheClusterEnabled=None, cacheClusterSize=None): + def create_stage( + self, + function_id, + stage_name, + deploymentId, + variables=None, + description="", + cacheClusterEnabled=None, + cacheClusterSize=None, + ): if variables is None: variables = {} api = self.get_rest_api(function_id) - api.create_stage(stage_name, deploymentId, variables=variables, - description=description, cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize) + api.create_stage( + stage_name, + deploymentId, + variables=variables, + description=description, + cacheClusterEnabled=cacheClusterEnabled, + cacheClusterSize=cacheClusterSize, + ) return api.stages.get(stage_name) def update_stage(self, function_id, stage_name, patch_operations): @@ -550,21 +599,33 @@ class APIGatewayBackend(BaseBackend): method_response = method.get_response(response_code) return method_response - def create_method_response(self, function_id, resource_id, method_type, response_code): + def create_method_response( + self, function_id, resource_id, method_type, response_code + ): method = self.get_method(function_id, resource_id, method_type) method_response = method.create_response(response_code) return method_response - def delete_method_response(self, function_id, resource_id, method_type, response_code): + def delete_method_response( + self, function_id, resource_id, method_type, response_code + ): method = self.get_method(function_id, resource_id, method_type) method_response = method.delete_response(response_code) return method_response - def create_integration(self, function_id, resource_id, method_type, integration_type, uri, - request_templates=None): + def create_integration( + self, + function_id, + resource_id, + method_type, + integration_type, + uri, + request_templates=None, + ): resource = self.get_resource(function_id, resource_id) - integration = resource.add_integration(method_type, integration_type, uri, - request_templates=request_templates) + integration = resource.add_integration( + method_type, integration_type, uri, request_templates=request_templates + ) return integration def get_integration(self, function_id, resource_id, method_type): @@ -575,28 +636,32 @@ class APIGatewayBackend(BaseBackend): resource = self.get_resource(function_id, resource_id) return resource.delete_integration(method_type) - def create_integration_response(self, function_id, resource_id, method_type, status_code, selection_pattern): - integration = self.get_integration( - function_id, resource_id, method_type) + def create_integration_response( + self, function_id, resource_id, method_type, status_code, selection_pattern + ): + integration = self.get_integration(function_id, resource_id, method_type) integration_response = integration.create_integration_response( - status_code, selection_pattern) + status_code, selection_pattern + ) return integration_response - def get_integration_response(self, function_id, resource_id, method_type, status_code): - integration = self.get_integration( - function_id, resource_id, method_type) - integration_response = integration.get_integration_response( - status_code) + def get_integration_response( + self, function_id, resource_id, method_type, status_code + ): + integration = self.get_integration(function_id, resource_id, method_type) + integration_response = integration.get_integration_response(status_code) return integration_response - def delete_integration_response(self, function_id, resource_id, method_type, status_code): - integration = self.get_integration( - function_id, resource_id, method_type) - integration_response = integration.delete_integration_response( - status_code) + def delete_integration_response( + self, function_id, resource_id, method_type, status_code + ): + integration = self.get_integration(function_id, resource_id, method_type) + integration_response = integration.delete_integration_response(status_code) return integration_response - def create_deployment(self, function_id, name, description="", stage_variables=None): + def create_deployment( + self, function_id, name, description="", stage_variables=None + ): if stage_variables is None: stage_variables = {} api = self.get_rest_api(function_id) @@ -617,7 +682,7 @@ class APIGatewayBackend(BaseBackend): def create_apikey(self, payload): key = ApiKey(**payload) - self.keys[key['id']] = key + self.keys[key["id"]] = key return key def get_apikeys(self): @@ -636,7 +701,7 @@ class APIGatewayBackend(BaseBackend): def create_usage_plan(self, payload): plan = UsagePlan(**payload) - self.usage_plans[plan['id']] = plan + self.usage_plans[plan["id"]] = plan return plan def get_usage_plans(self, api_key_id=None): @@ -645,7 +710,7 @@ class APIGatewayBackend(BaseBackend): plans = [ plan for plan in plans - if self.usage_plan_keys.get(plan['id'], {}).get(api_key_id, False) + if self.usage_plan_keys.get(plan["id"], {}).get(api_key_id, False) ] return plans @@ -666,8 +731,13 @@ class APIGatewayBackend(BaseBackend): api_key = self.keys[key_id] - usage_plan_key = UsagePlanKey(id=key_id, type=payload["keyType"], name=api_key["name"], value=api_key["value"]) - self.usage_plan_keys[usage_plan_id][usage_plan_key['id']] = usage_plan_key + usage_plan_key = UsagePlanKey( + id=key_id, + type=payload["keyType"], + name=api_key["name"], + value=api_key["value"], + ) + self.usage_plan_keys[usage_plan_id][usage_plan_key["id"]] = usage_plan_key return usage_plan_key def get_usage_plan_keys(self, usage_plan_id): @@ -685,5 +755,5 @@ class APIGatewayBackend(BaseBackend): apigateway_backends = {} -for region_name in Session().get_available_regions('apigateway'): +for region_name in Session().get_available_regions("apigateway"): apigateway_backends[region_name] = APIGatewayBackend(region_name) diff --git a/moto/apigateway/responses.py b/moto/apigateway/responses.py index fa82705b1..db626eac8 100644 --- a/moto/apigateway/responses.py +++ b/moto/apigateway/responses.py @@ -8,7 +8,6 @@ from .exceptions import StageNotFoundException, ApiKeyNotFoundException class APIGatewayResponse(BaseResponse): - def _get_param(self, key): return json.loads(self.body).get(key) @@ -27,14 +26,12 @@ class APIGatewayResponse(BaseResponse): def restapis(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if self.method == 'GET': + if self.method == "GET": apis = self.backend.list_apis() - return 200, {}, json.dumps({"item": [ - api.to_dict() for api in apis - ]}) - elif self.method == 'POST': - name = self._get_param('name') - description = self._get_param('description') + return 200, {}, json.dumps({"item": [api.to_dict() for api in apis]}) + elif self.method == "POST": + name = self._get_param("name") + description = self._get_param("description") rest_api = self.backend.create_rest_api(name, description) return 200, {}, json.dumps(rest_api.to_dict()) @@ -42,10 +39,10 @@ class APIGatewayResponse(BaseResponse): self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] - if self.method == 'GET': + if self.method == "GET": rest_api = self.backend.get_rest_api(function_id) return 200, {}, json.dumps(rest_api.to_dict()) - elif self.method == 'DELETE': + elif self.method == "DELETE": rest_api = self.backend.delete_rest_api(function_id) return 200, {}, json.dumps(rest_api.to_dict()) @@ -53,24 +50,25 @@ class APIGatewayResponse(BaseResponse): self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] - if self.method == 'GET': + if self.method == "GET": resources = self.backend.list_resources(function_id) - return 200, {}, json.dumps({"item": [ - resource.to_dict() for resource in resources - ]}) + return ( + 200, + {}, + json.dumps({"item": [resource.to_dict() for resource in resources]}), + ) def resource_individual(self, request, full_url, headers): self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] resource_id = self.path.split("/")[-1] - if self.method == 'GET': + if self.method == "GET": resource = self.backend.get_resource(function_id, resource_id) - elif self.method == 'POST': + elif self.method == "POST": path_part = self._get_param("pathPart") - resource = self.backend.create_resource( - function_id, resource_id, path_part) - elif self.method == 'DELETE': + resource = self.backend.create_resource(function_id, resource_id, path_part) + elif self.method == "DELETE": resource = self.backend.delete_resource(function_id, resource_id) return 200, {}, json.dumps(resource.to_dict()) @@ -81,14 +79,14 @@ class APIGatewayResponse(BaseResponse): resource_id = url_path_parts[4] method_type = url_path_parts[6] - if self.method == 'GET': - method = self.backend.get_method( - function_id, resource_id, method_type) + if self.method == "GET": + method = self.backend.get_method(function_id, resource_id, method_type) return 200, {}, json.dumps(method) - elif self.method == 'PUT': + elif self.method == "PUT": authorization_type = self._get_param("authorizationType") method = self.backend.create_method( - function_id, resource_id, method_type, authorization_type) + function_id, resource_id, method_type, authorization_type + ) return 200, {}, json.dumps(method) def resource_method_responses(self, request, full_url, headers): @@ -99,15 +97,18 @@ class APIGatewayResponse(BaseResponse): method_type = url_path_parts[6] response_code = url_path_parts[8] - if self.method == 'GET': + if self.method == "GET": method_response = self.backend.get_method_response( - function_id, resource_id, method_type, response_code) - elif self.method == 'PUT': + function_id, resource_id, method_type, response_code + ) + elif self.method == "PUT": method_response = self.backend.create_method_response( - function_id, resource_id, method_type, response_code) - elif self.method == 'DELETE': + function_id, resource_id, method_type, response_code + ) + elif self.method == "DELETE": method_response = self.backend.delete_method_response( - function_id, resource_id, method_type, response_code) + function_id, resource_id, method_type, response_code + ) return 200, {}, json.dumps(method_response) def restapis_stages(self, request, full_url, headers): @@ -115,21 +116,28 @@ class APIGatewayResponse(BaseResponse): url_path_parts = self.path.split("/") function_id = url_path_parts[2] - if self.method == 'POST': + if self.method == "POST": stage_name = self._get_param("stageName") deployment_id = self._get_param("deploymentId") - stage_variables = self._get_param_with_default_value( - 'variables', {}) - description = self._get_param_with_default_value('description', '') + stage_variables = self._get_param_with_default_value("variables", {}) + description = self._get_param_with_default_value("description", "") cacheClusterEnabled = self._get_param_with_default_value( - 'cacheClusterEnabled', False) + "cacheClusterEnabled", False + ) cacheClusterSize = self._get_param_with_default_value( - 'cacheClusterSize', None) + "cacheClusterSize", None + ) - stage_response = self.backend.create_stage(function_id, stage_name, deployment_id, - variables=stage_variables, description=description, - cacheClusterEnabled=cacheClusterEnabled, cacheClusterSize=cacheClusterSize) - elif self.method == 'GET': + stage_response = self.backend.create_stage( + function_id, + stage_name, + deployment_id, + variables=stage_variables, + description=description, + cacheClusterEnabled=cacheClusterEnabled, + cacheClusterSize=cacheClusterSize, + ) + elif self.method == "GET": stages = self.backend.get_stages(function_id) return 200, {}, json.dumps({"item": stages}) @@ -141,16 +149,22 @@ class APIGatewayResponse(BaseResponse): function_id = url_path_parts[2] stage_name = url_path_parts[4] - if self.method == 'GET': + if self.method == "GET": try: - stage_response = self.backend.get_stage( - function_id, stage_name) + stage_response = self.backend.get_stage(function_id, stage_name) except StageNotFoundException as error: - return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type) - elif self.method == 'PATCH': - patch_operations = self._get_param('patchOperations') + return ( + error.code, + {}, + '{{"message":"{0}","code":"{1}"}}'.format( + error.message, error.error_type + ), + ) + elif self.method == "PATCH": + patch_operations = self._get_param("patchOperations") stage_response = self.backend.update_stage( - function_id, stage_name, patch_operations) + function_id, stage_name, patch_operations + ) return 200, {}, json.dumps(stage_response) def integrations(self, request, full_url, headers): @@ -160,18 +174,26 @@ class APIGatewayResponse(BaseResponse): resource_id = url_path_parts[4] method_type = url_path_parts[6] - if self.method == 'GET': + if self.method == "GET": integration_response = self.backend.get_integration( - function_id, resource_id, method_type) - elif self.method == 'PUT': - integration_type = self._get_param('type') - uri = self._get_param('uri') - request_templates = self._get_param('requestTemplates') + function_id, resource_id, method_type + ) + elif self.method == "PUT": + integration_type = self._get_param("type") + uri = self._get_param("uri") + request_templates = self._get_param("requestTemplates") integration_response = self.backend.create_integration( - function_id, resource_id, method_type, integration_type, uri, request_templates=request_templates) - elif self.method == 'DELETE': + function_id, + resource_id, + method_type, + integration_type, + uri, + request_templates=request_templates, + ) + elif self.method == "DELETE": integration_response = self.backend.delete_integration( - function_id, resource_id, method_type) + function_id, resource_id, method_type + ) return 200, {}, json.dumps(integration_response) def integration_responses(self, request, full_url, headers): @@ -182,16 +204,16 @@ class APIGatewayResponse(BaseResponse): method_type = url_path_parts[6] status_code = url_path_parts[9] - if self.method == 'GET': + if self.method == "GET": integration_response = self.backend.get_integration_response( function_id, resource_id, method_type, status_code ) - elif self.method == 'PUT': + elif self.method == "PUT": selection_pattern = self._get_param("selectionPattern") integration_response = self.backend.create_integration_response( function_id, resource_id, method_type, status_code, selection_pattern ) - elif self.method == 'DELETE': + elif self.method == "DELETE": integration_response = self.backend.delete_integration_response( function_id, resource_id, method_type, status_code ) @@ -201,16 +223,16 @@ class APIGatewayResponse(BaseResponse): self.setup_class(request, full_url, headers) function_id = self.path.replace("/restapis/", "", 1).split("/")[0] - if self.method == 'GET': + if self.method == "GET": deployments = self.backend.get_deployments(function_id) return 200, {}, json.dumps({"item": deployments}) - elif self.method == 'POST': + elif self.method == "POST": name = self._get_param("stageName") description = self._get_param_with_default_value("description", "") - stage_variables = self._get_param_with_default_value( - 'variables', {}) + stage_variables = self._get_param_with_default_value("variables", {}) deployment = self.backend.create_deployment( - function_id, name, description, stage_variables) + function_id, name, description, stage_variables + ) return 200, {}, json.dumps(deployment) def individual_deployment(self, request, full_url, headers): @@ -219,20 +241,18 @@ class APIGatewayResponse(BaseResponse): function_id = url_path_parts[2] deployment_id = url_path_parts[4] - if self.method == 'GET': - deployment = self.backend.get_deployment( - function_id, deployment_id) - elif self.method == 'DELETE': - deployment = self.backend.delete_deployment( - function_id, deployment_id) + if self.method == "GET": + deployment = self.backend.get_deployment(function_id, deployment_id) + elif self.method == "DELETE": + deployment = self.backend.delete_deployment(function_id, deployment_id) return 200, {}, json.dumps(deployment) def apikeys(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if self.method == 'POST': + if self.method == "POST": apikey_response = self.backend.create_apikey(json.loads(self.body)) - elif self.method == 'GET': + elif self.method == "GET": apikeys_response = self.backend.get_apikeys() return 200, {}, json.dumps({"item": apikeys_response}) return 200, {}, json.dumps(apikey_response) @@ -243,21 +263,21 @@ class APIGatewayResponse(BaseResponse): url_path_parts = self.path.split("/") apikey = url_path_parts[2] - if self.method == 'GET': + if self.method == "GET": apikey_response = self.backend.get_apikey(apikey) - elif self.method == 'PATCH': - patch_operations = self._get_param('patchOperations') + elif self.method == "PATCH": + patch_operations = self._get_param("patchOperations") apikey_response = self.backend.update_apikey(apikey, patch_operations) - elif self.method == 'DELETE': + elif self.method == "DELETE": apikey_response = self.backend.delete_apikey(apikey) return 200, {}, json.dumps(apikey_response) def usage_plans(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if self.method == 'POST': + if self.method == "POST": usage_plan_response = self.backend.create_usage_plan(json.loads(self.body)) - elif self.method == 'GET': + elif self.method == "GET": api_key_id = self.querystring.get("keyId", [None])[0] usage_plans_response = self.backend.get_usage_plans(api_key_id=api_key_id) return 200, {}, json.dumps({"item": usage_plans_response}) @@ -269,9 +289,9 @@ class APIGatewayResponse(BaseResponse): url_path_parts = self.path.split("/") usage_plan = url_path_parts[2] - if self.method == 'GET': + if self.method == "GET": usage_plan_response = self.backend.get_usage_plan(usage_plan) - elif self.method == 'DELETE': + elif self.method == "DELETE": usage_plan_response = self.backend.delete_usage_plan(usage_plan) return 200, {}, json.dumps(usage_plan_response) @@ -281,13 +301,21 @@ class APIGatewayResponse(BaseResponse): url_path_parts = self.path.split("/") usage_plan_id = url_path_parts[2] - if self.method == 'POST': + if self.method == "POST": try: - usage_plan_response = self.backend.create_usage_plan_key(usage_plan_id, json.loads(self.body)) + usage_plan_response = self.backend.create_usage_plan_key( + usage_plan_id, json.loads(self.body) + ) except ApiKeyNotFoundException as error: - return error.code, {}, '{{"message":"{0}","code":"{1}"}}'.format(error.message, error.error_type) + return ( + error.code, + {}, + '{{"message":"{0}","code":"{1}"}}'.format( + error.message, error.error_type + ), + ) - elif self.method == 'GET': + elif self.method == "GET": usage_plans_response = self.backend.get_usage_plan_keys(usage_plan_id) return 200, {}, json.dumps({"item": usage_plans_response}) @@ -300,8 +328,10 @@ class APIGatewayResponse(BaseResponse): usage_plan_id = url_path_parts[2] key_id = url_path_parts[4] - if self.method == 'GET': + if self.method == "GET": usage_plan_response = self.backend.get_usage_plan_key(usage_plan_id, key_id) - elif self.method == 'DELETE': - usage_plan_response = self.backend.delete_usage_plan_key(usage_plan_id, key_id) + elif self.method == "DELETE": + usage_plan_response = self.backend.delete_usage_plan_key( + usage_plan_id, key_id + ) return 200, {}, json.dumps(usage_plan_response) diff --git a/moto/apigateway/urls.py b/moto/apigateway/urls.py index 5c6d372fa..bb2b2d216 100644 --- a/moto/apigateway/urls.py +++ b/moto/apigateway/urls.py @@ -1,27 +1,25 @@ from __future__ import unicode_literals from .responses import APIGatewayResponse -url_bases = [ - "https?://apigateway.(.+).amazonaws.com" -] +url_bases = ["https?://apigateway.(.+).amazonaws.com"] url_paths = { - '{0}/restapis$': APIGatewayResponse().restapis, - '{0}/restapis/(?P[^/]+)/?$': APIGatewayResponse().restapis_individual, - '{0}/restapis/(?P[^/]+)/resources$': APIGatewayResponse().resources, - '{0}/restapis/(?P[^/]+)/stages$': APIGatewayResponse().restapis_stages, - '{0}/restapis/(?P[^/]+)/stages/(?P[^/]+)/?$': APIGatewayResponse().stages, - '{0}/restapis/(?P[^/]+)/deployments$': APIGatewayResponse().deployments, - '{0}/restapis/(?P[^/]+)/deployments/(?P[^/]+)/?$': APIGatewayResponse().individual_deployment, - '{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/?$': APIGatewayResponse().resource_individual, - '{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/?$': APIGatewayResponse().resource_methods, - '{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/responses/(?P\d+)$': APIGatewayResponse().resource_method_responses, - '{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/integration/?$': APIGatewayResponse().integrations, - '{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/integration/responses/(?P\d+)/?$': APIGatewayResponse().integration_responses, - '{0}/apikeys$': APIGatewayResponse().apikeys, - '{0}/apikeys/(?P[^/]+)': APIGatewayResponse().apikey_individual, - '{0}/usageplans$': APIGatewayResponse().usage_plans, - '{0}/usageplans/(?P[^/]+)/?$': APIGatewayResponse().usage_plan_individual, - '{0}/usageplans/(?P[^/]+)/keys$': APIGatewayResponse().usage_plan_keys, - '{0}/usageplans/(?P[^/]+)/keys/(?P[^/]+)/?$': APIGatewayResponse().usage_plan_key_individual, + "{0}/restapis$": APIGatewayResponse().restapis, + "{0}/restapis/(?P[^/]+)/?$": APIGatewayResponse().restapis_individual, + "{0}/restapis/(?P[^/]+)/resources$": APIGatewayResponse().resources, + "{0}/restapis/(?P[^/]+)/stages$": APIGatewayResponse().restapis_stages, + "{0}/restapis/(?P[^/]+)/stages/(?P[^/]+)/?$": APIGatewayResponse().stages, + "{0}/restapis/(?P[^/]+)/deployments$": APIGatewayResponse().deployments, + "{0}/restapis/(?P[^/]+)/deployments/(?P[^/]+)/?$": APIGatewayResponse().individual_deployment, + "{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/?$": APIGatewayResponse().resource_individual, + "{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/?$": APIGatewayResponse().resource_methods, + "{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/responses/(?P\d+)$": APIGatewayResponse().resource_method_responses, + "{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/integration/?$": APIGatewayResponse().integrations, + "{0}/restapis/(?P[^/]+)/resources/(?P[^/]+)/methods/(?P[^/]+)/integration/responses/(?P\d+)/?$": APIGatewayResponse().integration_responses, + "{0}/apikeys$": APIGatewayResponse().apikeys, + "{0}/apikeys/(?P[^/]+)": APIGatewayResponse().apikey_individual, + "{0}/usageplans$": APIGatewayResponse().usage_plans, + "{0}/usageplans/(?P[^/]+)/?$": APIGatewayResponse().usage_plan_individual, + "{0}/usageplans/(?P[^/]+)/keys$": APIGatewayResponse().usage_plan_keys, + "{0}/usageplans/(?P[^/]+)/keys/(?P[^/]+)/?$": APIGatewayResponse().usage_plan_key_individual, } diff --git a/moto/apigateway/utils.py b/moto/apigateway/utils.py index 31f8060b0..807848f66 100644 --- a/moto/apigateway/utils.py +++ b/moto/apigateway/utils.py @@ -7,4 +7,4 @@ import string def create_id(): size = 10 chars = list(range(10)) + list(string.ascii_lowercase) - return ''.join(six.text_type(random.choice(chars)) for x in range(size)) + return "".join(six.text_type(random.choice(chars)) for x in range(size)) diff --git a/moto/athena/__init__.py b/moto/athena/__init__.py index c7bfa2b1f..3c1dc15c5 100644 --- a/moto/athena/__init__.py +++ b/moto/athena/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import athena_backends from ..core.models import base_decorator, deprecated_base_decorator -athena_backend = athena_backends['us-east-1'] +athena_backend = athena_backends["us-east-1"] mock_athena = base_decorator(athena_backends) mock_athena_deprecated = deprecated_base_decorator(athena_backends) diff --git a/moto/athena/exceptions.py b/moto/athena/exceptions.py index 1faa54731..96b35556a 100644 --- a/moto/athena/exceptions.py +++ b/moto/athena/exceptions.py @@ -5,14 +5,15 @@ from werkzeug.exceptions import BadRequest class AthenaClientError(BadRequest): - def __init__(self, code, message): super(AthenaClientError, self).__init__() - self.description = json.dumps({ - "Error": { - "Code": code, - "Message": message, - 'Type': "InvalidRequestException", - }, - 'RequestId': '6876f774-7273-11e4-85dc-39e55ca848d1', - }) + self.description = json.dumps( + { + "Error": { + "Code": code, + "Message": message, + "Type": "InvalidRequestException", + }, + "RequestId": "6876f774-7273-11e4-85dc-39e55ca848d1", + } + ) diff --git a/moto/athena/models.py b/moto/athena/models.py index d6d88cde8..7353e6a6e 100644 --- a/moto/athena/models.py +++ b/moto/athena/models.py @@ -19,31 +19,30 @@ class TaggableResourceMixin(object): @property def arn(self): return "arn:aws:athena:{region}:{account_id}:{resource_name}".format( - region=self.region, - account_id=ACCOUNT_ID, - resource_name=self.resource_name) + region=self.region, account_id=ACCOUNT_ID, resource_name=self.resource_name + ) def create_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags - if tag_set['Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def delete_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags - if tag_set['Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] return self.tags class WorkGroup(TaggableResourceMixin, BaseModel): - resource_type = 'workgroup' - state = 'ENABLED' + resource_type = "workgroup" + state = "ENABLED" def __init__(self, athena_backend, name, configuration, description, tags): self.region_name = athena_backend.region_name - super(WorkGroup, self).__init__(self.region_name, "workgroup/{}".format(name), tags) + super(WorkGroup, self).__init__( + self.region_name, "workgroup/{}".format(name), tags + ) self.athena_backend = athena_backend self.name = name self.description = description @@ -66,14 +65,17 @@ class AthenaBackend(BaseBackend): return work_group def list_work_groups(self): - return [{ - 'Name': wg.name, - 'State': wg.state, - 'Description': wg.description, - 'CreationTime': time.time(), - } for wg in self.work_groups.values()] + return [ + { + "Name": wg.name, + "State": wg.state, + "Description": wg.description, + "CreationTime": time.time(), + } + for wg in self.work_groups.values() + ] athena_backends = {} -for region in boto3.Session().get_available_regions('athena'): +for region in boto3.Session().get_available_regions("athena"): athena_backends[region] = AthenaBackend(region) diff --git a/moto/athena/responses.py b/moto/athena/responses.py index 13d33c129..80cac5d62 100644 --- a/moto/athena/responses.py +++ b/moto/athena/responses.py @@ -5,31 +5,37 @@ from .models import athena_backends class AthenaResponse(BaseResponse): - @property def athena_backend(self): return athena_backends[self.region] def create_work_group(self): - name = self._get_param('Name') - description = self._get_param('Description') - configuration = self._get_param('Configuration') - tags = self._get_param('Tags') - work_group = self.athena_backend.create_work_group(name, configuration, description, tags) + name = self._get_param("Name") + description = self._get_param("Description") + configuration = self._get_param("Configuration") + tags = self._get_param("Tags") + work_group = self.athena_backend.create_work_group( + name, configuration, description, tags + ) if not work_group: - return json.dumps({ - '__type': 'InvalidRequestException', - 'Message': 'WorkGroup already exists', - }), dict(status=400) - return json.dumps({ - "CreateWorkGroupResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return ( + json.dumps( + { + "__type": "InvalidRequestException", + "Message": "WorkGroup already exists", + } + ), + dict(status=400), + ) + return json.dumps( + { + "CreateWorkGroupResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) + ) def list_work_groups(self): - return json.dumps({ - "WorkGroups": self.athena_backend.list_work_groups() - }) + return json.dumps({"WorkGroups": self.athena_backend.list_work_groups()}) diff --git a/moto/athena/urls.py b/moto/athena/urls.py index bdd4ebc1e..4f8fdf7ee 100644 --- a/moto/athena/urls.py +++ b/moto/athena/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import AthenaResponse -url_bases = [ - "https?://athena.(.+).amazonaws.com", -] +url_bases = ["https?://athena.(.+).amazonaws.com"] -url_paths = { - '{0}/$': AthenaResponse.dispatch, -} +url_paths = {"{0}/$": AthenaResponse.dispatch} diff --git a/moto/autoscaling/__init__.py b/moto/autoscaling/__init__.py index b2b8b0bae..13c1adb16 100644 --- a/moto/autoscaling/__init__.py +++ b/moto/autoscaling/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import autoscaling_backends from ..core.models import base_decorator, deprecated_base_decorator -autoscaling_backend = autoscaling_backends['us-east-1'] +autoscaling_backend = autoscaling_backends["us-east-1"] mock_autoscaling = base_decorator(autoscaling_backends) mock_autoscaling_deprecated = deprecated_base_decorator(autoscaling_backends) diff --git a/moto/autoscaling/exceptions.py b/moto/autoscaling/exceptions.py index 74f62241d..6f73eff8f 100644 --- a/moto/autoscaling/exceptions.py +++ b/moto/autoscaling/exceptions.py @@ -12,13 +12,12 @@ class ResourceContentionError(RESTError): def __init__(self): super(ResourceContentionError, self).__init__( "ResourceContentionError", - "You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).") + "You already have a pending update to an Auto Scaling resource (for example, a group, instance, or load balancer).", + ) class InvalidInstanceError(AutoscalingClientError): - def __init__(self, instance_id): super(InvalidInstanceError, self).__init__( - "ValidationError", - "Instance [{0}] is invalid." - .format(instance_id)) + "ValidationError", "Instance [{0}] is invalid.".format(instance_id) + ) diff --git a/moto/autoscaling/models.py b/moto/autoscaling/models.py index 422075951..45ee7d192 100644 --- a/moto/autoscaling/models.py +++ b/moto/autoscaling/models.py @@ -12,7 +12,9 @@ from moto.elb import elb_backends from moto.elbv2 import elbv2_backends from moto.elb.exceptions import LoadBalancerNotFoundError from .exceptions import ( - AutoscalingClientError, ResourceContentionError, InvalidInstanceError + AutoscalingClientError, + ResourceContentionError, + InvalidInstanceError, ) # http://docs.aws.amazon.com/AutoScaling/latest/DeveloperGuide/AS_Concepts.html#Cooldown @@ -22,8 +24,13 @@ ASG_NAME_TAG = "aws:autoscaling:groupName" class InstanceState(object): - def __init__(self, instance, lifecycle_state="InService", - health_status="Healthy", protected_from_scale_in=False): + def __init__( + self, + instance, + lifecycle_state="InService", + health_status="Healthy", + protected_from_scale_in=False, + ): self.instance = instance self.lifecycle_state = lifecycle_state self.health_status = health_status @@ -31,8 +38,16 @@ class InstanceState(object): class FakeScalingPolicy(BaseModel): - def __init__(self, name, policy_type, adjustment_type, as_name, scaling_adjustment, - cooldown, autoscaling_backend): + def __init__( + self, + name, + policy_type, + adjustment_type, + as_name, + scaling_adjustment, + cooldown, + autoscaling_backend, + ): self.name = name self.policy_type = policy_type self.adjustment_type = adjustment_type @@ -45,21 +60,38 @@ class FakeScalingPolicy(BaseModel): self.autoscaling_backend = autoscaling_backend def execute(self): - if self.adjustment_type == 'ExactCapacity': + if self.adjustment_type == "ExactCapacity": self.autoscaling_backend.set_desired_capacity( - self.as_name, self.scaling_adjustment) - elif self.adjustment_type == 'ChangeInCapacity': + self.as_name, self.scaling_adjustment + ) + elif self.adjustment_type == "ChangeInCapacity": self.autoscaling_backend.change_capacity( - self.as_name, self.scaling_adjustment) - elif self.adjustment_type == 'PercentChangeInCapacity': + self.as_name, self.scaling_adjustment + ) + elif self.adjustment_type == "PercentChangeInCapacity": self.autoscaling_backend.change_capacity_percent( - self.as_name, self.scaling_adjustment) + self.as_name, self.scaling_adjustment + ) class FakeLaunchConfiguration(BaseModel): - def __init__(self, name, image_id, key_name, ramdisk_id, kernel_id, security_groups, user_data, - instance_type, instance_monitoring, instance_profile_name, - spot_price, ebs_optimized, associate_public_ip_address, block_device_mapping_dict): + def __init__( + self, + name, + image_id, + key_name, + ramdisk_id, + kernel_id, + security_groups, + user_data, + instance_type, + instance_monitoring, + instance_profile_name, + spot_price, + ebs_optimized, + associate_public_ip_address, + block_device_mapping_dict, + ): self.name = name self.image_id = image_id self.key_name = key_name @@ -80,8 +112,8 @@ class FakeLaunchConfiguration(BaseModel): config = backend.create_launch_configuration( name=name, image_id=instance.image_id, - kernel_id='', - ramdisk_id='', + kernel_id="", + ramdisk_id="", key_name=instance.key_name, security_groups=instance.security_groups, user_data=instance.user_data, @@ -91,13 +123,15 @@ class FakeLaunchConfiguration(BaseModel): spot_price=None, ebs_optimized=instance.ebs_optimized, associate_public_ip_address=instance.associate_public_ip, - block_device_mappings=instance.block_device_mapping + block_device_mappings=instance.block_device_mapping, ) return config @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] instance_profile_name = properties.get("IamInstanceProfile") @@ -115,20 +149,26 @@ class FakeLaunchConfiguration(BaseModel): instance_profile_name=instance_profile_name, spot_price=properties.get("SpotPrice"), ebs_optimized=properties.get("EbsOptimized"), - associate_public_ip_address=properties.get( - "AssociatePublicIpAddress"), - block_device_mappings=properties.get("BlockDeviceMapping.member") + associate_public_ip_address=properties.get("AssociatePublicIpAddress"), + block_device_mappings=properties.get("BlockDeviceMapping.member"), ) return config @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name) - return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) + original_resource.name, cloudformation_json, region_name + ) + return cls.create_from_cloudformation_json( + new_resource_name, cloudformation_json, region_name + ) @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): backend = autoscaling_backends[region_name] try: backend.delete_launch_configuration(resource_name) @@ -153,34 +193,49 @@ class FakeLaunchConfiguration(BaseModel): @property def instance_monitoring_enabled(self): if self.instance_monitoring: - return 'true' - return 'false' + return "true" + return "false" def _parse_block_device_mappings(self): block_device_map = BlockDeviceMapping() for mapping in self.block_device_mapping_dict: block_type = BlockDeviceType() - mount_point = mapping.get('device_name') - if 'ephemeral' in mapping.get('virtual_name', ''): - block_type.ephemeral_name = mapping.get('virtual_name') + mount_point = mapping.get("device_name") + if "ephemeral" in mapping.get("virtual_name", ""): + block_type.ephemeral_name = mapping.get("virtual_name") else: - block_type.volume_type = mapping.get('ebs._volume_type') - block_type.snapshot_id = mapping.get('ebs._snapshot_id') + block_type.volume_type = mapping.get("ebs._volume_type") + block_type.snapshot_id = mapping.get("ebs._snapshot_id") block_type.delete_on_termination = mapping.get( - 'ebs._delete_on_termination') - block_type.size = mapping.get('ebs._volume_size') - block_type.iops = mapping.get('ebs._iops') + "ebs._delete_on_termination" + ) + block_type.size = mapping.get("ebs._volume_size") + block_type.iops = mapping.get("ebs._iops") block_device_map[mount_point] = block_type return block_device_map class FakeAutoScalingGroup(BaseModel): - def __init__(self, name, availability_zones, desired_capacity, max_size, - min_size, launch_config_name, vpc_zone_identifier, - default_cooldown, health_check_period, health_check_type, - load_balancers, target_group_arns, placement_group, termination_policies, - autoscaling_backend, tags, - new_instances_protected_from_scale_in=False): + def __init__( + self, + name, + availability_zones, + desired_capacity, + max_size, + min_size, + launch_config_name, + vpc_zone_identifier, + default_cooldown, + health_check_period, + health_check_type, + load_balancers, + target_group_arns, + placement_group, + termination_policies, + autoscaling_backend, + tags, + new_instances_protected_from_scale_in=False, + ): self.autoscaling_backend = autoscaling_backend self.name = name @@ -190,17 +245,22 @@ class FakeAutoScalingGroup(BaseModel): self.min_size = min_size self.launch_config = self.autoscaling_backend.launch_configurations[ - launch_config_name] + launch_config_name + ] self.launch_config_name = launch_config_name - self.default_cooldown = default_cooldown if default_cooldown else DEFAULT_COOLDOWN + self.default_cooldown = ( + default_cooldown if default_cooldown else DEFAULT_COOLDOWN + ) self.health_check_period = health_check_period self.health_check_type = health_check_type if health_check_type else "EC2" self.load_balancers = load_balancers self.target_group_arns = target_group_arns self.placement_group = placement_group self.termination_policies = termination_policies - self.new_instances_protected_from_scale_in = new_instances_protected_from_scale_in + self.new_instances_protected_from_scale_in = ( + new_instances_protected_from_scale_in + ) self.suspended_processes = [] self.instance_states = [] @@ -215,8 +275,10 @@ class FakeAutoScalingGroup(BaseModel): if vpc_zone_identifier: # extract azs for vpcs - subnet_ids = vpc_zone_identifier.split(',') - subnets = self.autoscaling_backend.ec2_backend.get_all_subnets(subnet_ids=subnet_ids) + subnet_ids = vpc_zone_identifier.split(",") + subnets = self.autoscaling_backend.ec2_backend.get_all_subnets( + subnet_ids=subnet_ids + ) vpc_zones = [subnet.availability_zone for subnet in subnets] if availability_zones and set(availability_zones) != set(vpc_zones): @@ -229,7 +291,7 @@ class FakeAutoScalingGroup(BaseModel): if not update: raise AutoscalingClientError( "ValidationError", - "At least one Availability Zone or VPC Subnet is required." + "At least one Availability Zone or VPC Subnet is required.", ) return @@ -237,8 +299,10 @@ class FakeAutoScalingGroup(BaseModel): self.vpc_zone_identifier = vpc_zone_identifier @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] launch_config_name = properties.get("LaunchConfigurationName") load_balancer_names = properties.get("LoadBalancerNames", []) @@ -253,7 +317,8 @@ class FakeAutoScalingGroup(BaseModel): min_size=properties.get("MinSize"), launch_config_name=launch_config_name, vpc_zone_identifier=( - ','.join(properties.get("VPCZoneIdentifier", [])) or None), + ",".join(properties.get("VPCZoneIdentifier", [])) or None + ), default_cooldown=properties.get("Cooldown"), health_check_period=properties.get("HealthCheckGracePeriod"), health_check_type=properties.get("HealthCheckType"), @@ -263,18 +328,26 @@ class FakeAutoScalingGroup(BaseModel): termination_policies=properties.get("TerminationPolicies", []), tags=properties.get("Tags", []), new_instances_protected_from_scale_in=properties.get( - "NewInstancesProtectedFromScaleIn", False) + "NewInstancesProtectedFromScaleIn", False + ), ) return group @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name) - return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) + original_resource.name, cloudformation_json, region_name + ) + return cls.create_from_cloudformation_json( + new_resource_name, cloudformation_json, region_name + ) @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): backend = autoscaling_backends[region_name] try: backend.delete_auto_scaling_group(resource_name) @@ -289,11 +362,21 @@ class FakeAutoScalingGroup(BaseModel): def physical_resource_id(self): return self.name - def update(self, availability_zones, desired_capacity, max_size, min_size, - launch_config_name, vpc_zone_identifier, default_cooldown, - health_check_period, health_check_type, - placement_group, termination_policies, - new_instances_protected_from_scale_in=None): + def update( + self, + availability_zones, + desired_capacity, + max_size, + min_size, + launch_config_name, + vpc_zone_identifier, + default_cooldown, + health_check_period, + health_check_type, + placement_group, + termination_policies, + new_instances_protected_from_scale_in=None, + ): self._set_azs_and_vpcs(availability_zones, vpc_zone_identifier, update=True) if max_size is not None: @@ -309,14 +392,17 @@ class FakeAutoScalingGroup(BaseModel): if launch_config_name: self.launch_config = self.autoscaling_backend.launch_configurations[ - launch_config_name] + launch_config_name + ] self.launch_config_name = launch_config_name if health_check_period is not None: self.health_check_period = health_check_period if health_check_type is not None: self.health_check_type = health_check_type if new_instances_protected_from_scale_in is not None: - self.new_instances_protected_from_scale_in = new_instances_protected_from_scale_in + self.new_instances_protected_from_scale_in = ( + new_instances_protected_from_scale_in + ) if desired_capacity is not None: self.set_desired_capacity(desired_capacity) @@ -342,25 +428,30 @@ class FakeAutoScalingGroup(BaseModel): # Need to remove some instances count_to_remove = curr_instance_count - self.desired_capacity instances_to_remove = [ # only remove unprotected - state for state in self.instance_states + state + for state in self.instance_states if not state.protected_from_scale_in ][:count_to_remove] if instances_to_remove: # just in case not instances to remove instance_ids_to_remove = [ - instance.instance.id for instance in instances_to_remove] + instance.instance.id for instance in instances_to_remove + ] self.autoscaling_backend.ec2_backend.terminate_instances( - instance_ids_to_remove) - self.instance_states = list(set(self.instance_states) - set(instances_to_remove)) + instance_ids_to_remove + ) + self.instance_states = list( + set(self.instance_states) - set(instances_to_remove) + ) def get_propagated_tags(self): propagated_tags = {} for tag in self.tags: # boto uses 'propagate_at_launch # boto3 and cloudformation use PropagateAtLaunch - if 'propagate_at_launch' in tag and tag['propagate_at_launch'] == 'true': - propagated_tags[tag['key']] = tag['value'] - if 'PropagateAtLaunch' in tag and tag['PropagateAtLaunch']: - propagated_tags[tag['Key']] = tag['Value'] + if "propagate_at_launch" in tag and tag["propagate_at_launch"] == "true": + propagated_tags[tag["key"]] = tag["value"] + if "PropagateAtLaunch" in tag and tag["PropagateAtLaunch"]: + propagated_tags[tag["Key"]] = tag["Value"] return propagated_tags def replace_autoscaling_group_instances(self, count_needed, propagated_tags): @@ -371,15 +462,17 @@ class FakeAutoScalingGroup(BaseModel): self.launch_config.user_data, self.launch_config.security_groups, instance_type=self.launch_config.instance_type, - tags={'instance': propagated_tags}, + tags={"instance": propagated_tags}, placement=random.choice(self.availability_zones), ) for instance in reservation.instances: instance.autoscaling_group = self - self.instance_states.append(InstanceState( - instance, - protected_from_scale_in=self.new_instances_protected_from_scale_in, - )) + self.instance_states.append( + InstanceState( + instance, + protected_from_scale_in=self.new_instances_protected_from_scale_in, + ) + ) def append_target_groups(self, target_group_arns): append = [x for x in target_group_arns if x not in self.target_group_arns] @@ -402,10 +495,23 @@ class AutoScalingBackend(BaseBackend): self.__dict__ = {} self.__init__(ec2_backend, elb_backend, elbv2_backend) - def create_launch_configuration(self, name, image_id, key_name, kernel_id, ramdisk_id, - security_groups, user_data, instance_type, - instance_monitoring, instance_profile_name, - spot_price, ebs_optimized, associate_public_ip_address, block_device_mappings): + def create_launch_configuration( + self, + name, + image_id, + key_name, + kernel_id, + ramdisk_id, + security_groups, + user_data, + instance_type, + instance_monitoring, + instance_profile_name, + spot_price, + ebs_optimized, + associate_public_ip_address, + block_device_mappings, + ): launch_configuration = FakeLaunchConfiguration( name=name, image_id=image_id, @@ -428,23 +534,37 @@ class AutoScalingBackend(BaseBackend): def describe_launch_configurations(self, names): configurations = self.launch_configurations.values() if names: - return [configuration for configuration in configurations if configuration.name in names] + return [ + configuration + for configuration in configurations + if configuration.name in names + ] else: return list(configurations) def delete_launch_configuration(self, launch_configuration_name): self.launch_configurations.pop(launch_configuration_name, None) - def create_auto_scaling_group(self, name, availability_zones, - desired_capacity, max_size, min_size, - launch_config_name, vpc_zone_identifier, - default_cooldown, health_check_period, - health_check_type, load_balancers, - target_group_arns, placement_group, - termination_policies, tags, - new_instances_protected_from_scale_in=False, - instance_id=None): - + def create_auto_scaling_group( + self, + name, + availability_zones, + desired_capacity, + max_size, + min_size, + launch_config_name, + vpc_zone_identifier, + default_cooldown, + health_check_period, + health_check_type, + load_balancers, + target_group_arns, + placement_group, + termination_policies, + tags, + new_instances_protected_from_scale_in=False, + instance_id=None, + ): def make_int(value): return int(value) if value is not None else value @@ -460,7 +580,9 @@ class AutoScalingBackend(BaseBackend): try: instance = self.ec2_backend.get_instance(instance_id) launch_config_name = name - FakeLaunchConfiguration.create_from_instance(launch_config_name, instance, self) + FakeLaunchConfiguration.create_from_instance( + launch_config_name, instance, self + ) except InvalidInstanceIdError: raise InvalidInstanceError(instance_id) @@ -489,19 +611,37 @@ class AutoScalingBackend(BaseBackend): self.update_attached_target_groups(group.name) return group - def update_auto_scaling_group(self, name, availability_zones, - desired_capacity, max_size, min_size, - launch_config_name, vpc_zone_identifier, - default_cooldown, health_check_period, - health_check_type, placement_group, - termination_policies, - new_instances_protected_from_scale_in=None): + def update_auto_scaling_group( + self, + name, + availability_zones, + desired_capacity, + max_size, + min_size, + launch_config_name, + vpc_zone_identifier, + default_cooldown, + health_check_period, + health_check_type, + placement_group, + termination_policies, + new_instances_protected_from_scale_in=None, + ): group = self.autoscaling_groups[name] - group.update(availability_zones, desired_capacity, max_size, - min_size, launch_config_name, vpc_zone_identifier, - default_cooldown, health_check_period, health_check_type, - placement_group, termination_policies, - new_instances_protected_from_scale_in=new_instances_protected_from_scale_in) + group.update( + availability_zones, + desired_capacity, + max_size, + min_size, + launch_config_name, + vpc_zone_identifier, + default_cooldown, + health_check_period, + health_check_type, + placement_group, + termination_policies, + new_instances_protected_from_scale_in=new_instances_protected_from_scale_in, + ) return group def describe_auto_scaling_groups(self, names): @@ -537,32 +677,48 @@ class AutoScalingBackend(BaseBackend): for x in instance_ids ] for instance in new_instances: - self.ec2_backend.create_tags([instance.instance.id], {ASG_NAME_TAG: group.name}) + self.ec2_backend.create_tags( + [instance.instance.id], {ASG_NAME_TAG: group.name} + ) group.instance_states.extend(new_instances) self.update_attached_elbs(group.name) - def set_instance_health(self, instance_id, health_status, should_respect_grace_period): + def set_instance_health( + self, instance_id, health_status, should_respect_grace_period + ): instance = self.ec2_backend.get_instance(instance_id) - instance_state = next(instance_state for group in self.autoscaling_groups.values() - for instance_state in group.instance_states if instance_state.instance.id == instance.id) + instance_state = next( + instance_state + for group in self.autoscaling_groups.values() + for instance_state in group.instance_states + if instance_state.instance.id == instance.id + ) instance_state.health_status = health_status def detach_instances(self, group_name, instance_ids, should_decrement): group = self.autoscaling_groups[group_name] original_size = len(group.instance_states) - detached_instances = [x for x in group.instance_states if x.instance.id in instance_ids] + detached_instances = [ + x for x in group.instance_states if x.instance.id in instance_ids + ] for instance in detached_instances: - self.ec2_backend.delete_tags([instance.instance.id], {ASG_NAME_TAG: group.name}) + self.ec2_backend.delete_tags( + [instance.instance.id], {ASG_NAME_TAG: group.name} + ) - new_instance_state = [x for x in group.instance_states if x.instance.id not in instance_ids] + new_instance_state = [ + x for x in group.instance_states if x.instance.id not in instance_ids + ] group.instance_states = new_instance_state if should_decrement: group.desired_capacity = original_size - len(instance_ids) else: count_needed = len(instance_ids) - group.replace_autoscaling_group_instances(count_needed, group.get_propagated_tags()) + group.replace_autoscaling_group_instances( + count_needed, group.get_propagated_tags() + ) self.update_attached_elbs(group_name) return detached_instances @@ -593,19 +749,32 @@ class AutoScalingBackend(BaseBackend): desired_capacity = int(desired_capacity) self.set_desired_capacity(group_name, desired_capacity) - def create_autoscaling_policy(self, name, policy_type, adjustment_type, as_name, - scaling_adjustment, cooldown): - policy = FakeScalingPolicy(name, policy_type, adjustment_type, as_name, - scaling_adjustment, cooldown, self) + def create_autoscaling_policy( + self, name, policy_type, adjustment_type, as_name, scaling_adjustment, cooldown + ): + policy = FakeScalingPolicy( + name, + policy_type, + adjustment_type, + as_name, + scaling_adjustment, + cooldown, + self, + ) self.policies[name] = policy return policy - def describe_policies(self, autoscaling_group_name=None, policy_names=None, policy_types=None): - return [policy for policy in self.policies.values() - if (not autoscaling_group_name or policy.as_name == autoscaling_group_name) and - (not policy_names or policy.name in policy_names) and - (not policy_types or policy.policy_type in policy_types)] + def describe_policies( + self, autoscaling_group_name=None, policy_names=None, policy_types=None + ): + return [ + policy + for policy in self.policies.values() + if (not autoscaling_group_name or policy.as_name == autoscaling_group_name) + and (not policy_names or policy.name in policy_names) + and (not policy_types or policy.policy_type in policy_types) + ] def delete_policy(self, group_name): self.policies.pop(group_name, None) @@ -616,16 +785,14 @@ class AutoScalingBackend(BaseBackend): def update_attached_elbs(self, group_name): group = self.autoscaling_groups[group_name] - group_instance_ids = set( - state.instance.id for state in group.instance_states) + group_instance_ids = set(state.instance.id for state in group.instance_states) # skip this if group.load_balancers is empty # otherwise elb_backend.describe_load_balancers returns all available load balancers if not group.load_balancers: return try: - elbs = self.elb_backend.describe_load_balancers( - names=group.load_balancers) + elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers) except LoadBalancerNotFoundError: # ELBs can be deleted before their autoscaling group return @@ -633,14 +800,15 @@ class AutoScalingBackend(BaseBackend): for elb in elbs: elb_instace_ids = set(elb.instance_ids) self.elb_backend.register_instances( - elb.name, group_instance_ids - elb_instace_ids) + elb.name, group_instance_ids - elb_instace_ids + ) self.elb_backend.deregister_instances( - elb.name, elb_instace_ids - group_instance_ids) + elb.name, elb_instace_ids - group_instance_ids + ) def update_attached_target_groups(self, group_name): group = self.autoscaling_groups[group_name] - group_instance_ids = set( - state.instance.id for state in group.instance_states) + group_instance_ids = set(state.instance.id for state in group.instance_states) # no action necessary if target_group_arns is empty if not group.target_group_arns: @@ -649,10 +817,13 @@ class AutoScalingBackend(BaseBackend): target_groups = self.elbv2_backend.describe_target_groups( target_group_arns=group.target_group_arns, load_balancer_arn=None, - names=None) + names=None, + ) for target_group in target_groups: - asg_targets = [{'id': x, 'port': target_group.port} for x in group_instance_ids] + asg_targets = [ + {"id": x, "port": target_group.port} for x in group_instance_ids + ] self.elbv2_backend.register_targets(target_group.arn, (asg_targets)) def create_or_update_tags(self, tags): @@ -670,7 +841,7 @@ class AutoScalingBackend(BaseBackend): new_tags.append(old_tag) # if key was never in old_tag's add it (create tag) - if not any(new_tag['key'] == tag['key'] for new_tag in new_tags): + if not any(new_tag["key"] == tag["key"] for new_tag in new_tags): new_tags.append(tag) group.tags = new_tags @@ -678,7 +849,8 @@ class AutoScalingBackend(BaseBackend): def attach_load_balancers(self, group_name, load_balancer_names): group = self.autoscaling_groups[group_name] group.load_balancers.extend( - [x for x in load_balancer_names if x not in group.load_balancers]) + [x for x in load_balancer_names if x not in group.load_balancers] + ) self.update_attached_elbs(group_name) def describe_load_balancers(self, group_name): @@ -686,13 +858,13 @@ class AutoScalingBackend(BaseBackend): def detach_load_balancers(self, group_name, load_balancer_names): group = self.autoscaling_groups[group_name] - group_instance_ids = set( - state.instance.id for state in group.instance_states) + group_instance_ids = set(state.instance.id for state in group.instance_states) elbs = self.elb_backend.describe_load_balancers(names=group.load_balancers) for elb in elbs: - self.elb_backend.deregister_instances( - elb.name, group_instance_ids) - group.load_balancers = [x for x in group.load_balancers if x not in load_balancer_names] + self.elb_backend.deregister_instances(elb.name, group_instance_ids) + group.load_balancers = [ + x for x in group.load_balancers if x not in load_balancer_names + ] def attach_load_balancer_target_groups(self, group_name, target_group_arns): group = self.autoscaling_groups[group_name] @@ -704,36 +876,51 @@ class AutoScalingBackend(BaseBackend): def detach_load_balancer_target_groups(self, group_name, target_group_arns): group = self.autoscaling_groups[group_name] - group.target_group_arns = [x for x in group.target_group_arns if x not in target_group_arns] + group.target_group_arns = [ + x for x in group.target_group_arns if x not in target_group_arns + ] for target_group in target_group_arns: - asg_targets = [{'id': x.instance.id} for x in group.instance_states] + asg_targets = [{"id": x.instance.id} for x in group.instance_states] self.elbv2_backend.deregister_targets(target_group, (asg_targets)) def suspend_processes(self, group_name, scaling_processes): group = self.autoscaling_groups[group_name] group.suspended_processes = scaling_processes or [] - def set_instance_protection(self, group_name, instance_ids, protected_from_scale_in): + def set_instance_protection( + self, group_name, instance_ids, protected_from_scale_in + ): group = self.autoscaling_groups[group_name] protected_instances = [ - x for x in group.instance_states if x.instance.id in instance_ids] + x for x in group.instance_states if x.instance.id in instance_ids + ] for instance in protected_instances: instance.protected_from_scale_in = protected_from_scale_in def notify_terminate_instances(self, instance_ids): - for autoscaling_group_name, autoscaling_group in self.autoscaling_groups.items(): + for ( + autoscaling_group_name, + autoscaling_group, + ) in self.autoscaling_groups.items(): original_instance_count = len(autoscaling_group.instance_states) - autoscaling_group.instance_states = list(filter( - lambda i_state: i_state.instance.id not in instance_ids, + autoscaling_group.instance_states = list( + filter( + lambda i_state: i_state.instance.id not in instance_ids, + autoscaling_group.instance_states, + ) + ) + difference = original_instance_count - len( autoscaling_group.instance_states - )) - difference = original_instance_count - len(autoscaling_group.instance_states) + ) if difference > 0: - autoscaling_group.replace_autoscaling_group_instances(difference, autoscaling_group.get_propagated_tags()) + autoscaling_group.replace_autoscaling_group_instances( + difference, autoscaling_group.get_propagated_tags() + ) self.update_attached_elbs(autoscaling_group_name) autoscaling_backends = {} for region, ec2_backend in ec2_backends.items(): autoscaling_backends[region] = AutoScalingBackend( - ec2_backend, elb_backends[region], elbv2_backends[region]) + ec2_backend, elb_backends[region], elbv2_backends[region] + ) diff --git a/moto/autoscaling/responses.py b/moto/autoscaling/responses.py index 5e409aafb..83e2f7d5a 100644 --- a/moto/autoscaling/responses.py +++ b/moto/autoscaling/responses.py @@ -6,88 +6,88 @@ from .models import autoscaling_backends class AutoScalingResponse(BaseResponse): - @property def autoscaling_backend(self): return autoscaling_backends[self.region] def create_launch_configuration(self): - instance_monitoring_string = self._get_param( - 'InstanceMonitoring.Enabled') - if instance_monitoring_string == 'true': + instance_monitoring_string = self._get_param("InstanceMonitoring.Enabled") + if instance_monitoring_string == "true": instance_monitoring = True else: instance_monitoring = False self.autoscaling_backend.create_launch_configuration( - name=self._get_param('LaunchConfigurationName'), - image_id=self._get_param('ImageId'), - key_name=self._get_param('KeyName'), - ramdisk_id=self._get_param('RamdiskId'), - kernel_id=self._get_param('KernelId'), - security_groups=self._get_multi_param('SecurityGroups.member'), - user_data=self._get_param('UserData'), - instance_type=self._get_param('InstanceType'), + name=self._get_param("LaunchConfigurationName"), + image_id=self._get_param("ImageId"), + key_name=self._get_param("KeyName"), + ramdisk_id=self._get_param("RamdiskId"), + kernel_id=self._get_param("KernelId"), + security_groups=self._get_multi_param("SecurityGroups.member"), + user_data=self._get_param("UserData"), + instance_type=self._get_param("InstanceType"), instance_monitoring=instance_monitoring, - instance_profile_name=self._get_param('IamInstanceProfile'), - spot_price=self._get_param('SpotPrice'), - ebs_optimized=self._get_param('EbsOptimized'), - associate_public_ip_address=self._get_param( - "AssociatePublicIpAddress"), - block_device_mappings=self._get_list_prefix( - 'BlockDeviceMappings.member') + instance_profile_name=self._get_param("IamInstanceProfile"), + spot_price=self._get_param("SpotPrice"), + ebs_optimized=self._get_param("EbsOptimized"), + associate_public_ip_address=self._get_param("AssociatePublicIpAddress"), + block_device_mappings=self._get_list_prefix("BlockDeviceMappings.member"), ) template = self.response_template(CREATE_LAUNCH_CONFIGURATION_TEMPLATE) return template.render() def describe_launch_configurations(self): - names = self._get_multi_param('LaunchConfigurationNames.member') - all_launch_configurations = self.autoscaling_backend.describe_launch_configurations(names) - marker = self._get_param('NextToken') + names = self._get_multi_param("LaunchConfigurationNames.member") + all_launch_configurations = self.autoscaling_backend.describe_launch_configurations( + names + ) + marker = self._get_param("NextToken") all_names = [lc.name for lc in all_launch_configurations] if marker: start = all_names.index(marker) + 1 else: start = 0 - max_records = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier - launch_configurations_resp = all_launch_configurations[start:start + max_records] + max_records = self._get_int_param( + "MaxRecords", 50 + ) # the default is 100, but using 50 to make testing easier + launch_configurations_resp = all_launch_configurations[ + start : start + max_records + ] next_token = None if len(all_launch_configurations) > start + max_records: next_token = launch_configurations_resp[-1].name - template = self.response_template( - DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE) - return template.render(launch_configurations=launch_configurations_resp, next_token=next_token) + template = self.response_template(DESCRIBE_LAUNCH_CONFIGURATIONS_TEMPLATE) + return template.render( + launch_configurations=launch_configurations_resp, next_token=next_token + ) def delete_launch_configuration(self): - launch_configurations_name = self.querystring.get( - 'LaunchConfigurationName')[0] - self.autoscaling_backend.delete_launch_configuration( - launch_configurations_name) + launch_configurations_name = self.querystring.get("LaunchConfigurationName")[0] + self.autoscaling_backend.delete_launch_configuration(launch_configurations_name) template = self.response_template(DELETE_LAUNCH_CONFIGURATION_TEMPLATE) return template.render() def create_auto_scaling_group(self): self.autoscaling_backend.create_auto_scaling_group( - name=self._get_param('AutoScalingGroupName'), - availability_zones=self._get_multi_param( - 'AvailabilityZones.member'), - desired_capacity=self._get_int_param('DesiredCapacity'), - max_size=self._get_int_param('MaxSize'), - min_size=self._get_int_param('MinSize'), - instance_id=self._get_param('InstanceId'), - launch_config_name=self._get_param('LaunchConfigurationName'), - vpc_zone_identifier=self._get_param('VPCZoneIdentifier'), - default_cooldown=self._get_int_param('DefaultCooldown'), - health_check_period=self._get_int_param('HealthCheckGracePeriod'), - health_check_type=self._get_param('HealthCheckType'), - load_balancers=self._get_multi_param('LoadBalancerNames.member'), - target_group_arns=self._get_multi_param('TargetGroupARNs.member'), - placement_group=self._get_param('PlacementGroup'), - termination_policies=self._get_multi_param( - 'TerminationPolicies.member'), - tags=self._get_list_prefix('Tags.member'), + name=self._get_param("AutoScalingGroupName"), + availability_zones=self._get_multi_param("AvailabilityZones.member"), + desired_capacity=self._get_int_param("DesiredCapacity"), + max_size=self._get_int_param("MaxSize"), + min_size=self._get_int_param("MinSize"), + instance_id=self._get_param("InstanceId"), + launch_config_name=self._get_param("LaunchConfigurationName"), + vpc_zone_identifier=self._get_param("VPCZoneIdentifier"), + default_cooldown=self._get_int_param("DefaultCooldown"), + health_check_period=self._get_int_param("HealthCheckGracePeriod"), + health_check_type=self._get_param("HealthCheckType"), + load_balancers=self._get_multi_param("LoadBalancerNames.member"), + target_group_arns=self._get_multi_param("TargetGroupARNs.member"), + placement_group=self._get_param("PlacementGroup"), + termination_policies=self._get_multi_param("TerminationPolicies.member"), + tags=self._get_list_prefix("Tags.member"), new_instances_protected_from_scale_in=self._get_bool_param( - 'NewInstancesProtectedFromScaleIn', False) + "NewInstancesProtectedFromScaleIn", False + ), ) template = self.response_template(CREATE_AUTOSCALING_GROUP_TEMPLATE) return template.render() @@ -95,68 +95,73 @@ class AutoScalingResponse(BaseResponse): @amz_crc32 @amzn_request_id def attach_instances(self): - group_name = self._get_param('AutoScalingGroupName') - instance_ids = self._get_multi_param('InstanceIds.member') - self.autoscaling_backend.attach_instances( - group_name, instance_ids) + group_name = self._get_param("AutoScalingGroupName") + instance_ids = self._get_multi_param("InstanceIds.member") + self.autoscaling_backend.attach_instances(group_name, instance_ids) template = self.response_template(ATTACH_INSTANCES_TEMPLATE) return template.render() @amz_crc32 @amzn_request_id def set_instance_health(self): - instance_id = self._get_param('InstanceId') + instance_id = self._get_param("InstanceId") health_status = self._get_param("HealthStatus") - if health_status not in ['Healthy', 'Unhealthy']: - raise ValueError('Valid instance health states are: [Healthy, Unhealthy]') + if health_status not in ["Healthy", "Unhealthy"]: + raise ValueError("Valid instance health states are: [Healthy, Unhealthy]") should_respect_grace_period = self._get_param("ShouldRespectGracePeriod") - self.autoscaling_backend.set_instance_health(instance_id, health_status, should_respect_grace_period) + self.autoscaling_backend.set_instance_health( + instance_id, health_status, should_respect_grace_period + ) template = self.response_template(SET_INSTANCE_HEALTH_TEMPLATE) return template.render() @amz_crc32 @amzn_request_id def detach_instances(self): - group_name = self._get_param('AutoScalingGroupName') - instance_ids = self._get_multi_param('InstanceIds.member') - should_decrement_string = self._get_param('ShouldDecrementDesiredCapacity') - if should_decrement_string == 'true': + group_name = self._get_param("AutoScalingGroupName") + instance_ids = self._get_multi_param("InstanceIds.member") + should_decrement_string = self._get_param("ShouldDecrementDesiredCapacity") + if should_decrement_string == "true": should_decrement = True else: should_decrement = False detached_instances = self.autoscaling_backend.detach_instances( - group_name, instance_ids, should_decrement) + group_name, instance_ids, should_decrement + ) template = self.response_template(DETACH_INSTANCES_TEMPLATE) return template.render(detached_instances=detached_instances) @amz_crc32 @amzn_request_id def attach_load_balancer_target_groups(self): - group_name = self._get_param('AutoScalingGroupName') - target_group_arns = self._get_multi_param('TargetGroupARNs.member') + group_name = self._get_param("AutoScalingGroupName") + target_group_arns = self._get_multi_param("TargetGroupARNs.member") self.autoscaling_backend.attach_load_balancer_target_groups( - group_name, target_group_arns) + group_name, target_group_arns + ) template = self.response_template(ATTACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE) return template.render() @amz_crc32 @amzn_request_id def describe_load_balancer_target_groups(self): - group_name = self._get_param('AutoScalingGroupName') + group_name = self._get_param("AutoScalingGroupName") target_group_arns = self.autoscaling_backend.describe_load_balancer_target_groups( - group_name) + group_name + ) template = self.response_template(DESCRIBE_LOAD_BALANCER_TARGET_GROUPS) return template.render(target_group_arns=target_group_arns) @amz_crc32 @amzn_request_id def detach_load_balancer_target_groups(self): - group_name = self._get_param('AutoScalingGroupName') - target_group_arns = self._get_multi_param('TargetGroupARNs.member') + group_name = self._get_param("AutoScalingGroupName") + target_group_arns = self._get_multi_param("TargetGroupARNs.member") self.autoscaling_backend.detach_load_balancer_target_groups( - group_name, target_group_arns) + group_name, target_group_arns + ) template = self.response_template(DETACH_LOAD_BALANCER_TARGET_GROUPS_TEMPLATE) return template.render() @@ -172,7 +177,7 @@ class AutoScalingResponse(BaseResponse): max_records = self._get_int_param("MaxRecords", 50) if max_records > 100: raise ValueError - groups = all_groups[start:start + max_records] + groups = all_groups[start : start + max_records] next_token = None if max_records and len(all_groups) > start + max_records: next_token = groups[-1].name @@ -181,42 +186,40 @@ class AutoScalingResponse(BaseResponse): def update_auto_scaling_group(self): self.autoscaling_backend.update_auto_scaling_group( - name=self._get_param('AutoScalingGroupName'), - availability_zones=self._get_multi_param( - 'AvailabilityZones.member'), - desired_capacity=self._get_int_param('DesiredCapacity'), - max_size=self._get_int_param('MaxSize'), - min_size=self._get_int_param('MinSize'), - launch_config_name=self._get_param('LaunchConfigurationName'), - vpc_zone_identifier=self._get_param('VPCZoneIdentifier'), - default_cooldown=self._get_int_param('DefaultCooldown'), - health_check_period=self._get_int_param('HealthCheckGracePeriod'), - health_check_type=self._get_param('HealthCheckType'), - placement_group=self._get_param('PlacementGroup'), - termination_policies=self._get_multi_param( - 'TerminationPolicies.member'), + name=self._get_param("AutoScalingGroupName"), + availability_zones=self._get_multi_param("AvailabilityZones.member"), + desired_capacity=self._get_int_param("DesiredCapacity"), + max_size=self._get_int_param("MaxSize"), + min_size=self._get_int_param("MinSize"), + launch_config_name=self._get_param("LaunchConfigurationName"), + vpc_zone_identifier=self._get_param("VPCZoneIdentifier"), + default_cooldown=self._get_int_param("DefaultCooldown"), + health_check_period=self._get_int_param("HealthCheckGracePeriod"), + health_check_type=self._get_param("HealthCheckType"), + placement_group=self._get_param("PlacementGroup"), + termination_policies=self._get_multi_param("TerminationPolicies.member"), new_instances_protected_from_scale_in=self._get_bool_param( - 'NewInstancesProtectedFromScaleIn', None) + "NewInstancesProtectedFromScaleIn", None + ), ) template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE) return template.render() def delete_auto_scaling_group(self): - group_name = self._get_param('AutoScalingGroupName') + group_name = self._get_param("AutoScalingGroupName") self.autoscaling_backend.delete_auto_scaling_group(group_name) template = self.response_template(DELETE_AUTOSCALING_GROUP_TEMPLATE) return template.render() def set_desired_capacity(self): - group_name = self._get_param('AutoScalingGroupName') - desired_capacity = self._get_int_param('DesiredCapacity') - self.autoscaling_backend.set_desired_capacity( - group_name, desired_capacity) + group_name = self._get_param("AutoScalingGroupName") + desired_capacity = self._get_int_param("DesiredCapacity") + self.autoscaling_backend.set_desired_capacity(group_name, desired_capacity) template = self.response_template(SET_DESIRED_CAPACITY_TEMPLATE) return template.render() def create_or_update_tags(self): - tags = self._get_list_prefix('Tags.member') + tags = self._get_list_prefix("Tags.member") self.autoscaling_backend.create_or_update_tags(tags) template = self.response_template(UPDATE_AUTOSCALING_GROUP_TEMPLATE) @@ -224,38 +227,38 @@ class AutoScalingResponse(BaseResponse): def describe_auto_scaling_instances(self): instance_states = self.autoscaling_backend.describe_auto_scaling_instances() - template = self.response_template( - DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE) + template = self.response_template(DESCRIBE_AUTOSCALING_INSTANCES_TEMPLATE) return template.render(instance_states=instance_states) def put_scaling_policy(self): policy = self.autoscaling_backend.create_autoscaling_policy( - name=self._get_param('PolicyName'), - policy_type=self._get_param('PolicyType'), - adjustment_type=self._get_param('AdjustmentType'), - as_name=self._get_param('AutoScalingGroupName'), - scaling_adjustment=self._get_int_param('ScalingAdjustment'), - cooldown=self._get_int_param('Cooldown'), + name=self._get_param("PolicyName"), + policy_type=self._get_param("PolicyType"), + adjustment_type=self._get_param("AdjustmentType"), + as_name=self._get_param("AutoScalingGroupName"), + scaling_adjustment=self._get_int_param("ScalingAdjustment"), + cooldown=self._get_int_param("Cooldown"), ) template = self.response_template(CREATE_SCALING_POLICY_TEMPLATE) return template.render(policy=policy) def describe_policies(self): policies = self.autoscaling_backend.describe_policies( - autoscaling_group_name=self._get_param('AutoScalingGroupName'), - policy_names=self._get_multi_param('PolicyNames.member'), - policy_types=self._get_multi_param('PolicyTypes.member')) + autoscaling_group_name=self._get_param("AutoScalingGroupName"), + policy_names=self._get_multi_param("PolicyNames.member"), + policy_types=self._get_multi_param("PolicyTypes.member"), + ) template = self.response_template(DESCRIBE_SCALING_POLICIES_TEMPLATE) return template.render(policies=policies) def delete_policy(self): - group_name = self._get_param('PolicyName') + group_name = self._get_param("PolicyName") self.autoscaling_backend.delete_policy(group_name) template = self.response_template(DELETE_POLICY_TEMPLATE) return template.render() def execute_policy(self): - group_name = self._get_param('PolicyName') + group_name = self._get_param("PolicyName") self.autoscaling_backend.execute_policy(group_name) template = self.response_template(EXECUTE_POLICY_TEMPLATE) return template.render() @@ -263,17 +266,16 @@ class AutoScalingResponse(BaseResponse): @amz_crc32 @amzn_request_id def attach_load_balancers(self): - group_name = self._get_param('AutoScalingGroupName') + group_name = self._get_param("AutoScalingGroupName") load_balancer_names = self._get_multi_param("LoadBalancerNames.member") - self.autoscaling_backend.attach_load_balancers( - group_name, load_balancer_names) + self.autoscaling_backend.attach_load_balancers(group_name, load_balancer_names) template = self.response_template(ATTACH_LOAD_BALANCERS_TEMPLATE) return template.render() @amz_crc32 @amzn_request_id def describe_load_balancers(self): - group_name = self._get_param('AutoScalingGroupName') + group_name = self._get_param("AutoScalingGroupName") load_balancers = self.autoscaling_backend.describe_load_balancers(group_name) template = self.response_template(DESCRIBE_LOAD_BALANCERS_TEMPLATE) return template.render(load_balancers=load_balancers) @@ -281,26 +283,28 @@ class AutoScalingResponse(BaseResponse): @amz_crc32 @amzn_request_id def detach_load_balancers(self): - group_name = self._get_param('AutoScalingGroupName') + group_name = self._get_param("AutoScalingGroupName") load_balancer_names = self._get_multi_param("LoadBalancerNames.member") - self.autoscaling_backend.detach_load_balancers( - group_name, load_balancer_names) + self.autoscaling_backend.detach_load_balancers(group_name, load_balancer_names) template = self.response_template(DETACH_LOAD_BALANCERS_TEMPLATE) return template.render() def suspend_processes(self): - autoscaling_group_name = self._get_param('AutoScalingGroupName') - scaling_processes = self._get_multi_param('ScalingProcesses.member') - self.autoscaling_backend.suspend_processes(autoscaling_group_name, scaling_processes) + autoscaling_group_name = self._get_param("AutoScalingGroupName") + scaling_processes = self._get_multi_param("ScalingProcesses.member") + self.autoscaling_backend.suspend_processes( + autoscaling_group_name, scaling_processes + ) template = self.response_template(SUSPEND_PROCESSES_TEMPLATE) return template.render() def set_instance_protection(self): - group_name = self._get_param('AutoScalingGroupName') - instance_ids = self._get_multi_param('InstanceIds.member') - protected_from_scale_in = self._get_bool_param('ProtectedFromScaleIn') + group_name = self._get_param("AutoScalingGroupName") + instance_ids = self._get_multi_param("InstanceIds.member") + protected_from_scale_in = self._get_bool_param("ProtectedFromScaleIn") self.autoscaling_backend.set_instance_protection( - group_name, instance_ids, protected_from_scale_in) + group_name, instance_ids, protected_from_scale_in + ) template = self.response_template(SET_INSTANCE_PROTECTION_TEMPLATE) return template.render() diff --git a/moto/autoscaling/urls.py b/moto/autoscaling/urls.py index 0743fdcf7..5fb33c25d 100644 --- a/moto/autoscaling/urls.py +++ b/moto/autoscaling/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import AutoScalingResponse -url_bases = [ - "https?://autoscaling.(.+).amazonaws.com", -] +url_bases = ["https?://autoscaling.(.+).amazonaws.com"] -url_paths = { - '{0}/$': AutoScalingResponse.dispatch, -} +url_paths = {"{0}/$": AutoScalingResponse.dispatch} diff --git a/moto/awslambda/__init__.py b/moto/awslambda/__init__.py index f0d694654..d40bf051a 100644 --- a/moto/awslambda/__init__.py +++ b/moto/awslambda/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import lambda_backends from ..core.models import base_decorator, deprecated_base_decorator -lambda_backend = lambda_backends['us-east-1'] +lambda_backend = lambda_backends["us-east-1"] mock_lambda = base_decorator(lambda_backends) mock_lambda_deprecated = deprecated_base_decorator(lambda_backends) diff --git a/moto/awslambda/models.py b/moto/awslambda/models.py index 234950125..00205df3e 100644 --- a/moto/awslambda/models.py +++ b/moto/awslambda/models.py @@ -38,7 +38,7 @@ from moto.dynamodbstreams import dynamodbstreams_backends logger = logging.getLogger(__name__) -ACCOUNT_ID = '123456789012' +ACCOUNT_ID = "123456789012" try: @@ -47,20 +47,22 @@ except ImportError: from backports.tempfile import TemporaryDirectory -_stderr_regex = re.compile(r'START|END|REPORT RequestId: .*') +_stderr_regex = re.compile(r"START|END|REPORT RequestId: .*") _orig_adapter_send = requests.adapters.HTTPAdapter.send -docker_3 = docker.__version__[0] >= '3' +docker_3 = docker.__version__[0] >= "3" def zip2tar(zip_bytes): with TemporaryDirectory() as td: - tarname = os.path.join(td, 'data.tar') - timeshift = int((datetime.datetime.now() - - datetime.datetime.utcnow()).total_seconds()) - with zipfile.ZipFile(io.BytesIO(zip_bytes), 'r') as zipf, \ - tarfile.TarFile(tarname, 'w') as tarf: + tarname = os.path.join(td, "data.tar") + timeshift = int( + (datetime.datetime.now() - datetime.datetime.utcnow()).total_seconds() + ) + with zipfile.ZipFile(io.BytesIO(zip_bytes), "r") as zipf, tarfile.TarFile( + tarname, "w" + ) as tarf: for zipinfo in zipf.infolist(): - if zipinfo.filename[-1] == '/': # is_dir() is py3.6+ + if zipinfo.filename[-1] == "/": # is_dir() is py3.6+ continue tarinfo = tarfile.TarInfo(name=zipinfo.filename) @@ -69,7 +71,7 @@ def zip2tar(zip_bytes): infile = zipf.open(zipinfo.filename) tarf.addfile(tarinfo, infile) - with open(tarname, 'rb') as f: + with open(tarname, "rb") as f: tar_data = f.read() return tar_data @@ -83,7 +85,9 @@ class _VolumeRefCount: class _DockerDataVolumeContext: - _data_vol_map = defaultdict(lambda: _VolumeRefCount(0, None)) # {sha256: _VolumeRefCount} + _data_vol_map = defaultdict( + lambda: _VolumeRefCount(0, None) + ) # {sha256: _VolumeRefCount} _lock = threading.Lock() def __init__(self, lambda_func): @@ -109,15 +113,19 @@ class _DockerDataVolumeContext: return self # It doesn't exist so we need to create it - self._vol_ref.volume = self._lambda_func.docker_client.volumes.create(self._lambda_func.code_sha_256) + self._vol_ref.volume = self._lambda_func.docker_client.volumes.create( + self._lambda_func.code_sha_256 + ) if docker_3: - volumes = {self.name: {'bind': '/tmp/data', 'mode': 'rw'}} + volumes = {self.name: {"bind": "/tmp/data", "mode": "rw"}} else: - volumes = {self.name: '/tmp/data'} - container = self._lambda_func.docker_client.containers.run('alpine', 'sleep 100', volumes=volumes, detach=True) + volumes = {self.name: "/tmp/data"} + container = self._lambda_func.docker_client.containers.run( + "alpine", "sleep 100", volumes=volumes, detach=True + ) try: tar_bytes = zip2tar(self._lambda_func.code_bytes) - container.put_archive('/tmp/data', tar_bytes) + container.put_archive("/tmp/data", tar_bytes) finally: container.remove(force=True) @@ -140,13 +148,13 @@ class LambdaFunction(BaseModel): def __init__(self, spec, region, validate_s3=True, version=1): # required self.region = region - self.code = spec['Code'] - self.function_name = spec['FunctionName'] - self.handler = spec['Handler'] - self.role = spec['Role'] - self.run_time = spec['Runtime'] + self.code = spec["Code"] + self.function_name = spec["FunctionName"] + self.handler = spec["Handler"] + self.role = spec["Role"] + self.run_time = spec["Runtime"] self.logs_backend = logs_backends[self.region] - self.environment_vars = spec.get('Environment', {}).get('Variables', {}) + self.environment_vars = spec.get("Environment", {}).get("Variables", {}) self.docker_client = docker.from_env() self.policy = "" @@ -161,77 +169,82 @@ class LambdaFunction(BaseModel): if isinstance(adapter, requests.adapters.HTTPAdapter): adapter.send = functools.partial(_orig_adapter_send, adapter) return adapter + self.docker_client.api.get_adapter = replace_adapter_send # optional - self.description = spec.get('Description', '') - self.memory_size = spec.get('MemorySize', 128) - self.publish = spec.get('Publish', False) # this is ignored currently - self.timeout = spec.get('Timeout', 3) + self.description = spec.get("Description", "") + self.memory_size = spec.get("MemorySize", 128) + self.publish = spec.get("Publish", False) # this is ignored currently + self.timeout = spec.get("Timeout", 3) - self.logs_group_name = '/aws/lambda/{}'.format(self.function_name) + self.logs_group_name = "/aws/lambda/{}".format(self.function_name) self.logs_backend.ensure_log_group(self.logs_group_name, []) # this isn't finished yet. it needs to find out the VpcId value self._vpc_config = spec.get( - 'VpcConfig', {'SubnetIds': [], 'SecurityGroupIds': []}) + "VpcConfig", {"SubnetIds": [], "SecurityGroupIds": []} + ) # auto-generated self.version = version - self.last_modified = datetime.datetime.utcnow().strftime( - '%Y-%m-%d %H:%M:%S') + self.last_modified = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") - if 'ZipFile' in self.code: + if "ZipFile" in self.code: # more hackery to handle unicode/bytes/str in python3 and python2 - # argh! try: - to_unzip_code = base64.b64decode( - bytes(self.code['ZipFile'], 'utf-8')) + to_unzip_code = base64.b64decode(bytes(self.code["ZipFile"], "utf-8")) except Exception: - to_unzip_code = base64.b64decode(self.code['ZipFile']) + to_unzip_code = base64.b64decode(self.code["ZipFile"]) self.code_bytes = to_unzip_code self.code_size = len(to_unzip_code) self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest() # TODO: we should be putting this in a lambda bucket - self.code['UUID'] = str(uuid.uuid4()) - self.code['S3Key'] = '{}-{}'.format(self.function_name, self.code['UUID']) + self.code["UUID"] = str(uuid.uuid4()) + self.code["S3Key"] = "{}-{}".format(self.function_name, self.code["UUID"]) else: # validate s3 bucket and key key = None try: # FIXME: does not validate bucket region - key = s3_backend.get_key( - self.code['S3Bucket'], self.code['S3Key']) + key = s3_backend.get_key(self.code["S3Bucket"], self.code["S3Key"]) except MissingBucket: if do_validate_s3(): raise ValueError( "InvalidParameterValueException", - "Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist") + "Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist", + ) except MissingKey: if do_validate_s3(): raise ValueError( "InvalidParameterValueException", - "Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.") + "Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.", + ) if key: self.code_bytes = key.value self.code_size = key.size self.code_sha_256 = hashlib.sha256(key.value).hexdigest() - self.function_arn = make_function_arn(self.region, ACCOUNT_ID, self.function_name) + self.function_arn = make_function_arn( + self.region, ACCOUNT_ID, self.function_name + ) self.tags = dict() def set_version(self, version): - self.function_arn = make_function_ver_arn(self.region, ACCOUNT_ID, self.function_name, version) + self.function_arn = make_function_ver_arn( + self.region, ACCOUNT_ID, self.function_name, version + ) self.version = version - self.last_modified = datetime.datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S') + self.last_modified = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") @property def vpc_config(self): config = self._vpc_config.copy() - if config['SecurityGroupIds']: + if config["SecurityGroupIds"]: config.update({"VpcId": "vpc-123abc"}) return config @@ -260,17 +273,17 @@ class LambdaFunction(BaseModel): } if self.environment_vars: - config['Environment'] = { - 'Variables': self.environment_vars - } + config["Environment"] = {"Variables": self.environment_vars} return config def get_code(self): return { "Code": { - "Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format(self.region, self.code['S3Key']), - "RepositoryType": "S3" + "Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/{1}".format( + self.region, self.code["S3Key"] + ), + "RepositoryType": "S3", }, "Configuration": self.get_configuration(), } @@ -297,43 +310,48 @@ class LambdaFunction(BaseModel): return self.get_configuration() def update_function_code(self, updated_spec): - if 'DryRun' in updated_spec and updated_spec['DryRun']: + if "DryRun" in updated_spec and updated_spec["DryRun"]: return self.get_configuration() - if 'ZipFile' in updated_spec: - self.code['ZipFile'] = updated_spec['ZipFile'] + if "ZipFile" in updated_spec: + self.code["ZipFile"] = updated_spec["ZipFile"] # using the "hackery" from __init__ because it seems to work # TODOs and FIXMEs included, because they'll need to be fixed # in both places now try: to_unzip_code = base64.b64decode( - bytes(updated_spec['ZipFile'], 'utf-8')) + bytes(updated_spec["ZipFile"], "utf-8") + ) except Exception: - to_unzip_code = base64.b64decode(updated_spec['ZipFile']) + to_unzip_code = base64.b64decode(updated_spec["ZipFile"]) self.code_bytes = to_unzip_code self.code_size = len(to_unzip_code) self.code_sha_256 = hashlib.sha256(to_unzip_code).hexdigest() # TODO: we should be putting this in a lambda bucket - self.code['UUID'] = str(uuid.uuid4()) - self.code['S3Key'] = '{}-{}'.format(self.function_name, self.code['UUID']) - elif 'S3Bucket' in updated_spec and 'S3Key' in updated_spec: + self.code["UUID"] = str(uuid.uuid4()) + self.code["S3Key"] = "{}-{}".format(self.function_name, self.code["UUID"]) + elif "S3Bucket" in updated_spec and "S3Key" in updated_spec: key = None try: # FIXME: does not validate bucket region - key = s3_backend.get_key(updated_spec['S3Bucket'], updated_spec['S3Key']) + key = s3_backend.get_key( + updated_spec["S3Bucket"], updated_spec["S3Key"] + ) except MissingBucket: if do_validate_s3(): raise ValueError( "InvalidParameterValueException", - "Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist") + "Error occurred while GetObject. S3 Error Code: NoSuchBucket. S3 Error Message: The specified bucket does not exist", + ) except MissingKey: if do_validate_s3(): raise ValueError( "InvalidParameterValueException", - "Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.") + "Error occurred while GetObject. S3 Error Code: NoSuchKey. S3 Error Message: The specified key does not exist.", + ) if key: self.code_bytes = key.value self.code_size = key.size @@ -344,7 +362,7 @@ class LambdaFunction(BaseModel): @staticmethod def convert(s): try: - return str(s, encoding='utf-8') + return str(s, encoding="utf-8") except Exception: return s @@ -372,12 +390,21 @@ class LambdaFunction(BaseModel): container = output = exit_code = None with _DockerDataVolumeContext(self) as data_vol: try: - run_kwargs = dict(links={'motoserver': 'motoserver'}) if settings.TEST_SERVER_MODE else {} + run_kwargs = ( + dict(links={"motoserver": "motoserver"}) + if settings.TEST_SERVER_MODE + else {} + ) container = self.docker_client.containers.run( "lambci/lambda:{}".format(self.run_time), - [self.handler, json.dumps(event)], remove=False, + [self.handler, json.dumps(event)], + remove=False, mem_limit="{}m".format(self.memory_size), - volumes=["{}:/var/task".format(data_vol.name)], environment=env_vars, detach=True, **run_kwargs) + volumes=["{}:/var/task".format(data_vol.name)], + environment=env_vars, + detach=True, + **run_kwargs + ) finally: if container: try: @@ -388,32 +415,43 @@ class LambdaFunction(BaseModel): container.kill() else: if docker_3: - exit_code = exit_code['StatusCode'] + exit_code = exit_code["StatusCode"] output = container.logs(stdout=False, stderr=True) output += container.logs(stdout=True, stderr=False) container.remove() - output = output.decode('utf-8') + output = output.decode("utf-8") # Send output to "logs" backend invoke_id = uuid.uuid4().hex log_stream_name = "{date.year}/{date.month:02d}/{date.day:02d}/[{version}]{invoke_id}".format( - date=datetime.datetime.utcnow(), version=self.version, invoke_id=invoke_id + date=datetime.datetime.utcnow(), + version=self.version, + invoke_id=invoke_id, ) self.logs_backend.create_log_stream(self.logs_group_name, log_stream_name) - log_events = [{'timestamp': unix_time_millis(), "message": line} - for line in output.splitlines()] - self.logs_backend.put_log_events(self.logs_group_name, log_stream_name, log_events, None) + log_events = [ + {"timestamp": unix_time_millis(), "message": line} + for line in output.splitlines() + ] + self.logs_backend.put_log_events( + self.logs_group_name, log_stream_name, log_events, None + ) if exit_code != 0: - raise Exception( - 'lambda invoke failed output: {}'.format(output)) + raise Exception("lambda invoke failed output: {}".format(output)) # strip out RequestId lines - output = os.linesep.join([line for line in self.convert(output).splitlines() if not _stderr_regex.match(line)]) + output = os.linesep.join( + [ + line + for line in self.convert(output).splitlines() + if not _stderr_regex.match(line) + ] + ) return output, False except BaseException as e: traceback.print_exc() @@ -428,31 +466,34 @@ class LambdaFunction(BaseModel): # Get the invocation type: res, errored = self._invoke_lambda(code=self.code, event=body) if request_headers.get("x-amz-invocation-type") == "RequestResponse": - encoded = base64.b64encode(res.encode('utf-8')) - response_headers["x-amz-log-result"] = encoded.decode('utf-8') - payload['result'] = response_headers["x-amz-log-result"] - result = res.encode('utf-8') + encoded = base64.b64encode(res.encode("utf-8")) + response_headers["x-amz-log-result"] = encoded.decode("utf-8") + payload["result"] = response_headers["x-amz-log-result"] + result = res.encode("utf-8") else: result = json.dumps(payload) if errored: - response_headers['x-amz-function-error'] = "Handled" + response_headers["x-amz-function-error"] = "Handled" return result @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, - region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] # required spec = { - 'Code': properties['Code'], - 'FunctionName': resource_name, - 'Handler': properties['Handler'], - 'Role': properties['Role'], - 'Runtime': properties['Runtime'], + "Code": properties["Code"], + "FunctionName": resource_name, + "Handler": properties["Handler"], + "Role": properties["Role"], + "Runtime": properties["Runtime"], } - optional_properties = 'Description MemorySize Publish Timeout VpcConfig Environment'.split() + optional_properties = ( + "Description MemorySize Publish Timeout VpcConfig Environment".split() + ) # NOTE: Not doing `properties.get(k, DEFAULT)` to avoid duplicating the # default logic for prop in optional_properties: @@ -462,27 +503,27 @@ class LambdaFunction(BaseModel): # when ZipFile is present in CloudFormation, per the official docs, # the code it's a plaintext code snippet up to 4096 bytes. # this snippet converts this plaintext code to a proper base64-encoded ZIP file. - if 'ZipFile' in properties['Code']: - spec['Code']['ZipFile'] = base64.b64encode( - cls._create_zipfile_from_plaintext_code( - spec['Code']['ZipFile'])) + if "ZipFile" in properties["Code"]: + spec["Code"]["ZipFile"] = base64.b64encode( + cls._create_zipfile_from_plaintext_code(spec["Code"]["ZipFile"]) + ) backend = lambda_backends[region_name] fn = backend.create_function(spec) return fn def get_cfn_attribute(self, attribute_name): - from moto.cloudformation.exceptions import \ - UnformattedGetAttTemplateException - if attribute_name == 'Arn': + from moto.cloudformation.exceptions import UnformattedGetAttTemplateException + + if attribute_name == "Arn": return make_function_arn(self.region, ACCOUNT_ID, self.function_name) raise UnformattedGetAttTemplateException() @staticmethod def _create_zipfile_from_plaintext_code(code): zip_output = io.BytesIO() - zip_file = zipfile.ZipFile(zip_output, 'w', zipfile.ZIP_DEFLATED) - zip_file.writestr('lambda_function.zip', code) + zip_file = zipfile.ZipFile(zip_output, "w", zipfile.ZIP_DEFLATED) + zip_file.writestr("lambda_function.zip", code) zip_file.close() zip_output.seek(0) return zip_output.read() @@ -491,61 +532,66 @@ class LambdaFunction(BaseModel): class EventSourceMapping(BaseModel): def __init__(self, spec): # required - self.function_arn = spec['FunctionArn'] - self.event_source_arn = spec['EventSourceArn'] + self.function_arn = spec["FunctionArn"] + self.event_source_arn = spec["EventSourceArn"] self.uuid = str(uuid.uuid4()) self.last_modified = time.mktime(datetime.datetime.utcnow().timetuple()) # BatchSize service default/max mapping batch_size_map = { - 'kinesis': (100, 10000), - 'dynamodb': (100, 1000), - 'sqs': (10, 10), + "kinesis": (100, 10000), + "dynamodb": (100, 1000), + "sqs": (10, 10), } source_type = self.event_source_arn.split(":")[2].lower() batch_size_entry = batch_size_map.get(source_type) if batch_size_entry: # Use service default if not provided - batch_size = int(spec.get('BatchSize', batch_size_entry[0])) + batch_size = int(spec.get("BatchSize", batch_size_entry[0])) if batch_size > batch_size_entry[1]: - raise ValueError("InvalidParameterValueException", - "BatchSize {} exceeds the max of {}".format(batch_size, batch_size_entry[1])) + raise ValueError( + "InvalidParameterValueException", + "BatchSize {} exceeds the max of {}".format( + batch_size, batch_size_entry[1] + ), + ) else: self.batch_size = batch_size else: - raise ValueError("InvalidParameterValueException", - "Unsupported event source type") + raise ValueError( + "InvalidParameterValueException", "Unsupported event source type" + ) # optional - self.starting_position = spec.get('StartingPosition', 'TRIM_HORIZON') - self.enabled = spec.get('Enabled', True) - self.starting_position_timestamp = spec.get('StartingPositionTimestamp', - None) + self.starting_position = spec.get("StartingPosition", "TRIM_HORIZON") + self.enabled = spec.get("Enabled", True) + self.starting_position_timestamp = spec.get("StartingPositionTimestamp", None) def get_configuration(self): return { - 'UUID': self.uuid, - 'BatchSize': self.batch_size, - 'EventSourceArn': self.event_source_arn, - 'FunctionArn': self.function_arn, - 'LastModified': self.last_modified, - 'LastProcessingResult': '', - 'State': 'Enabled' if self.enabled else 'Disabled', - 'StateTransitionReason': 'User initiated' + "UUID": self.uuid, + "BatchSize": self.batch_size, + "EventSourceArn": self.event_source_arn, + "FunctionArn": self.function_arn, + "LastModified": self.last_modified, + "LastProcessingResult": "", + "State": "Enabled" if self.enabled else "Disabled", + "StateTransitionReason": "User initiated", } @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, - region_name): - properties = cloudformation_json['Properties'] - func = lambda_backends[region_name].get_function(properties['FunctionName']) + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + func = lambda_backends[region_name].get_function(properties["FunctionName"]) spec = { - 'FunctionArn': func.function_arn, - 'EventSourceArn': properties['EventSourceArn'], - 'StartingPosition': properties['StartingPosition'], - 'BatchSize': properties.get('BatchSize', 100) + "FunctionArn": func.function_arn, + "EventSourceArn": properties["EventSourceArn"], + "StartingPosition": properties["StartingPosition"], + "BatchSize": properties.get("BatchSize", 100), } - optional_properties = 'BatchSize Enabled StartingPositionTimestamp'.split() + optional_properties = "BatchSize Enabled StartingPositionTimestamp".split() for prop in optional_properties: if prop in properties: spec[prop] = properties[prop] @@ -554,20 +600,19 @@ class EventSourceMapping(BaseModel): class LambdaVersion(BaseModel): def __init__(self, spec): - self.version = spec['Version'] + self.version = spec["Version"] def __repr__(self): return str(self.logical_resource_id) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, - region_name): - properties = cloudformation_json['Properties'] - function_name = properties['FunctionName'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + function_name = properties["FunctionName"] func = lambda_backends[region_name].publish_function(function_name) - spec = { - 'Version': func.version - } + spec = {"Version": func.version} return LambdaVersion(spec) @@ -578,18 +623,18 @@ class LambdaStorage(object): self._arns = weakref.WeakValueDictionary() def _get_latest(self, name): - return self._functions[name]['latest'] + return self._functions[name]["latest"] def _get_version(self, name, version): index = version - 1 try: - return self._functions[name]['versions'][index] + return self._functions[name]["versions"][index] except IndexError: return None def _get_alias(self, name, alias): - return self._functions[name]['alias'].get(alias, None) + return self._functions[name]["alias"].get(alias, None) def get_function_by_name(self, name, qualifier=None): if name not in self._functions: @@ -601,15 +646,15 @@ class LambdaStorage(object): try: return self._get_version(name, int(qualifier)) except ValueError: - return self._functions[name]['latest'] + return self._functions[name]["latest"] def list_versions_by_function(self, name): if name not in self._functions: return None - latest = copy.copy(self._functions[name]['latest']) - latest.function_arn += ':$LATEST' - return [latest] + self._functions[name]['versions'] + latest = copy.copy(self._functions[name]["latest"]) + latest.function_arn += ":$LATEST" + return [latest] + self._functions[name]["versions"] def get_arn(self, arn): return self._arns.get(arn, None) @@ -623,12 +668,12 @@ class LambdaStorage(object): :type fn: LambdaFunction """ if fn.function_name in self._functions: - self._functions[fn.function_name]['latest'] = fn + self._functions[fn.function_name]["latest"] = fn else: self._functions[fn.function_name] = { - 'latest': fn, - 'versions': [], - 'alias': weakref.WeakValueDictionary() + "latest": fn, + "versions": [], + "alias": weakref.WeakValueDictionary(), } self._arns[fn.function_arn] = fn @@ -636,14 +681,14 @@ class LambdaStorage(object): def publish_function(self, name): if name not in self._functions: return None - if not self._functions[name]['latest']: + if not self._functions[name]["latest"]: return None - new_version = len(self._functions[name]['versions']) + 1 - fn = copy.copy(self._functions[name]['latest']) + new_version = len(self._functions[name]["versions"]) + 1 + fn = copy.copy(self._functions[name]["latest"]) fn.set_version(new_version) - self._functions[name]['versions'].append(fn) + self._functions[name]["versions"].append(fn) self._arns[fn.function_arn] = fn return fn @@ -653,21 +698,24 @@ class LambdaStorage(object): name = function.function_name if not qualifier: # Something is still reffing this so delete all arns - latest = self._functions[name]['latest'].function_arn + latest = self._functions[name]["latest"].function_arn del self._arns[latest] - for fn in self._functions[name]['versions']: + for fn in self._functions[name]["versions"]: del self._arns[fn.function_arn] del self._functions[name] return True - elif qualifier == '$LATEST': - self._functions[name]['latest'] = None + elif qualifier == "$LATEST": + self._functions[name]["latest"] = None # If theres no functions left - if not self._functions[name]['versions'] and not self._functions[name]['latest']: + if ( + not self._functions[name]["versions"] + and not self._functions[name]["latest"] + ): del self._functions[name] return True @@ -675,10 +723,13 @@ class LambdaStorage(object): else: fn = self.get_function_by_name(name, qualifier) if fn: - self._functions[name]['versions'].remove(fn) + self._functions[name]["versions"].remove(fn) # If theres no functions left - if not self._functions[name]['versions'] and not self._functions[name]['latest']: + if ( + not self._functions[name]["versions"] + and not self._functions[name]["latest"] + ): del self._functions[name] return True @@ -689,10 +740,10 @@ class LambdaStorage(object): result = [] for function_group in self._functions.values(): - if function_group['latest'] is not None: - result.append(function_group['latest']) + if function_group["latest"] is not None: + result.append(function_group["latest"]) - result.extend(function_group['versions']) + result.extend(function_group["versions"]) return result @@ -709,44 +760,47 @@ class LambdaBackend(BaseBackend): self.__init__(region_name) def create_function(self, spec): - function_name = spec.get('FunctionName', None) + function_name = spec.get("FunctionName", None) if function_name is None: - raise RESTError('InvalidParameterValueException', 'Missing FunctionName') + raise RESTError("InvalidParameterValueException", "Missing FunctionName") - fn = LambdaFunction(spec, self.region_name, version='$LATEST') + fn = LambdaFunction(spec, self.region_name, version="$LATEST") self._lambdas.put_function(fn) - if spec.get('Publish'): + if spec.get("Publish"): ver = self.publish_function(function_name) fn.version = ver.version return fn def create_event_source_mapping(self, spec): - required = [ - 'EventSourceArn', - 'FunctionName', - ] + required = ["EventSourceArn", "FunctionName"] for param in required: if not spec.get(param): - raise RESTError('InvalidParameterValueException', 'Missing {}'.format(param)) + raise RESTError( + "InvalidParameterValueException", "Missing {}".format(param) + ) # Validate function name - func = self._lambdas.get_function_by_name_or_arn(spec.pop('FunctionName', '')) + func = self._lambdas.get_function_by_name_or_arn(spec.pop("FunctionName", "")) if not func: - raise RESTError('ResourceNotFoundException', 'Invalid FunctionName') + raise RESTError("ResourceNotFoundException", "Invalid FunctionName") # Validate queue for queue in sqs_backends[self.region_name].queues.values(): - if queue.queue_arn == spec['EventSourceArn']: - if queue.lambda_event_source_mappings.get('func.function_arn'): + if queue.queue_arn == spec["EventSourceArn"]: + if queue.lambda_event_source_mappings.get("func.function_arn"): # TODO: Correct exception? - raise RESTError('ResourceConflictException', 'The resource already exists.') + raise RESTError( + "ResourceConflictException", "The resource already exists." + ) if queue.fifo_queue: - raise RESTError('InvalidParameterValueException', - '{} is FIFO'.format(queue.queue_arn)) + raise RESTError( + "InvalidParameterValueException", + "{} is FIFO".format(queue.queue_arn), + ) else: - spec.update({'FunctionArn': func.function_arn}) + spec.update({"FunctionArn": func.function_arn}) esm = EventSourceMapping(spec) self._event_source_mappings[esm.uuid] = esm @@ -754,16 +808,18 @@ class LambdaBackend(BaseBackend): queue.lambda_event_source_mappings[esm.function_arn] = esm return esm - for stream in json.loads(dynamodbstreams_backends[self.region_name].list_streams())['Streams']: - if stream['StreamArn'] == spec['EventSourceArn']: - spec.update({'FunctionArn': func.function_arn}) + for stream in json.loads( + dynamodbstreams_backends[self.region_name].list_streams() + )["Streams"]: + if stream["StreamArn"] == spec["EventSourceArn"]: + spec.update({"FunctionArn": func.function_arn}) esm = EventSourceMapping(spec) self._event_source_mappings[esm.uuid] = esm - table_name = stream['TableName'] + table_name = stream["TableName"] table = dynamodb_backends2[self.region_name].get_table(table_name) table.lambda_event_source_mappings[esm.function_arn] = esm return esm - raise RESTError('ResourceNotFoundException', 'Invalid EventSourceArn') + raise RESTError("ResourceNotFoundException", "Invalid EventSourceArn") def publish_function(self, function_name): return self._lambdas.publish_function(function_name) @@ -783,13 +839,15 @@ class LambdaBackend(BaseBackend): def update_event_source_mapping(self, uuid, spec): esm = self.get_event_source_mapping(uuid) if esm: - if spec.get('FunctionName'): - func = self._lambdas.get_function_by_name_or_arn(spec.get('FunctionName')) + if spec.get("FunctionName"): + func = self._lambdas.get_function_by_name_or_arn( + spec.get("FunctionName") + ) esm.function_arn = func.function_arn - if 'BatchSize' in spec: - esm.batch_size = spec['BatchSize'] - if 'Enabled' in spec: - esm.enabled = spec['Enabled'] + if "BatchSize" in spec: + esm.batch_size = spec["BatchSize"] + if "Enabled" in spec: + esm.enabled = spec["Enabled"] return esm return False @@ -830,13 +888,13 @@ class LambdaBackend(BaseBackend): "ApproximateReceiveCount": "1", "SentTimestamp": "1545082649183", "SenderId": "AIDAIENQZJOLO23YVJ4VO", - "ApproximateFirstReceiveTimestamp": "1545082649185" + "ApproximateFirstReceiveTimestamp": "1545082649185", }, "messageAttributes": {}, "md5OfBody": "098f6bcd4621d373cade4e832627b4f6", "eventSource": "aws:sqs", "eventSourceARN": queue_arn, - "awsRegion": self.region_name + "awsRegion": self.region_name, } ] } @@ -844,7 +902,7 @@ class LambdaBackend(BaseBackend): request_headers = {} response_headers = {} func.invoke(json.dumps(event), request_headers, response_headers) - return 'x-amz-function-error' not in response_headers + return "x-amz-function-error" not in response_headers def send_sns_message(self, function_name, message, subject=None, qualifier=None): event = { @@ -861,37 +919,35 @@ class LambdaBackend(BaseBackend): "MessageId": "95df01b4-ee98-5cb9-9903-4c221d41eb5e", "Message": message, "MessageAttributes": { - "Test": { - "Type": "String", - "Value": "TestString" - }, - "TestBinary": { - "Type": "Binary", - "Value": "TestBinary" - } + "Test": {"Type": "String", "Value": "TestString"}, + "TestBinary": {"Type": "Binary", "Value": "TestBinary"}, }, "Type": "Notification", "UnsubscribeUrl": "EXAMPLE", "TopicArn": "arn:aws:sns:EXAMPLE", - "Subject": subject or "TestInvoke" - } + "Subject": subject or "TestInvoke", + }, } ] - } func = self._lambdas.get_function_by_name_or_arn(function_name, qualifier) func.invoke(json.dumps(event), {}, {}) def send_dynamodb_items(self, function_arn, items, source): - event = {'Records': [ - { - 'eventID': item.to_json()['eventID'], - 'eventName': 'INSERT', - 'eventVersion': item.to_json()['eventVersion'], - 'eventSource': item.to_json()['eventSource'], - 'awsRegion': self.region_name, - 'dynamodb': item.to_json()['dynamodb'], - 'eventSourceARN': source} for item in items]} + event = { + "Records": [ + { + "eventID": item.to_json()["eventID"], + "eventName": "INSERT", + "eventVersion": item.to_json()["eventVersion"], + "eventSource": item.to_json()["eventSource"], + "awsRegion": self.region_name, + "dynamodb": item.to_json()["dynamodb"], + "eventSourceARN": source, + } + for item in items + ] + } func = self._lambdas.get_arn(function_arn) func.invoke(json.dumps(event), {}, {}) @@ -923,12 +979,13 @@ class LambdaBackend(BaseBackend): def do_validate_s3(): - return os.environ.get('VALIDATE_LAMBDA_S3', '') in ['', '1', 'true'] + return os.environ.get("VALIDATE_LAMBDA_S3", "") in ["", "1", "true"] # Handle us forgotten regions, unless Lambda truly only runs out of US and -lambda_backends = {_region.name: LambdaBackend(_region.name) - for _region in boto.awslambda.regions()} +lambda_backends = { + _region.name: LambdaBackend(_region.name) for _region in boto.awslambda.regions() +} -lambda_backends['ap-southeast-2'] = LambdaBackend('ap-southeast-2') -lambda_backends['us-gov-west-1'] = LambdaBackend('us-gov-west-1') +lambda_backends["ap-southeast-2"] = LambdaBackend("ap-southeast-2") +lambda_backends["us-gov-west-1"] = LambdaBackend("us-gov-west-1") diff --git a/moto/awslambda/responses.py b/moto/awslambda/responses.py index 37016e718..62265b310 100644 --- a/moto/awslambda/responses.py +++ b/moto/awslambda/responses.py @@ -32,57 +32,57 @@ class LambdaResponse(BaseResponse): def root(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'GET': + if request.method == "GET": return self._list_functions(request, full_url, headers) - elif request.method == 'POST': + elif request.method == "POST": return self._create_function(request, full_url, headers) else: raise ValueError("Cannot handle request") def event_source_mappings(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'GET': + if request.method == "GET": querystring = self.querystring - event_source_arn = querystring.get('EventSourceArn', [None])[0] - function_name = querystring.get('FunctionName', [None])[0] + event_source_arn = querystring.get("EventSourceArn", [None])[0] + function_name = querystring.get("FunctionName", [None])[0] return self._list_event_source_mappings(event_source_arn, function_name) - elif request.method == 'POST': + elif request.method == "POST": return self._create_event_source_mapping(request, full_url, headers) else: raise ValueError("Cannot handle request") def event_source_mapping(self, request, full_url, headers): self.setup_class(request, full_url, headers) - path = request.path if hasattr(request, 'path') else path_url(request.url) - uuid = path.split('/')[-1] - if request.method == 'GET': + path = request.path if hasattr(request, "path") else path_url(request.url) + uuid = path.split("/")[-1] + if request.method == "GET": return self._get_event_source_mapping(uuid) - elif request.method == 'PUT': + elif request.method == "PUT": return self._update_event_source_mapping(uuid) - elif request.method == 'DELETE': + elif request.method == "DELETE": return self._delete_event_source_mapping(uuid) else: raise ValueError("Cannot handle request") def function(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'GET': + if request.method == "GET": return self._get_function(request, full_url, headers) - elif request.method == 'DELETE': + elif request.method == "DELETE": return self._delete_function(request, full_url, headers) else: raise ValueError("Cannot handle request") def versions(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'GET': + if request.method == "GET": # This is ListVersionByFunction - path = request.path if hasattr(request, 'path') else path_url(request.url) - function_name = path.split('/')[-2] + path = request.path if hasattr(request, "path") else path_url(request.url) + function_name = path.split("/")[-2] return self._list_versions_by_function(function_name) - elif request.method == 'POST': + elif request.method == "POST": return self._publish_function(request, full_url, headers) else: raise ValueError("Cannot handle request") @@ -91,7 +91,7 @@ class LambdaResponse(BaseResponse): @amzn_request_id def invoke(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'POST': + if request.method == "POST": return self._invoke(request, full_url) else: raise ValueError("Cannot handle request") @@ -100,46 +100,46 @@ class LambdaResponse(BaseResponse): @amzn_request_id def invoke_async(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'POST': + if request.method == "POST": return self._invoke_async(request, full_url) else: raise ValueError("Cannot handle request") def tag(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'GET': + if request.method == "GET": return self._list_tags(request, full_url) - elif request.method == 'POST': + elif request.method == "POST": return self._tag_resource(request, full_url) - elif request.method == 'DELETE': + elif request.method == "DELETE": return self._untag_resource(request, full_url) else: raise ValueError("Cannot handle {0} request".format(request.method)) def policy(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'GET': + if request.method == "GET": return self._get_policy(request, full_url, headers) - if request.method == 'POST': + if request.method == "POST": return self._add_policy(request, full_url, headers) def configuration(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'PUT': + if request.method == "PUT": return self._put_configuration(request) else: raise ValueError("Cannot handle request") def code(self, request, full_url, headers): self.setup_class(request, full_url, headers) - if request.method == 'PUT': + if request.method == "PUT": return self._put_code() else: raise ValueError("Cannot handle request") def _add_policy(self, request, full_url, headers): - path = request.path if hasattr(request, 'path') else path_url(request.url) - function_name = path.split('/')[-2] + path = request.path if hasattr(request, "path") else path_url(request.url) + function_name = path.split("/")[-2] if self.lambda_backend.get_function(function_name): policy = self.body self.lambda_backend.add_policy(function_name, policy) @@ -148,24 +148,30 @@ class LambdaResponse(BaseResponse): return 404, {}, "{}" def _get_policy(self, request, full_url, headers): - path = request.path if hasattr(request, 'path') else path_url(request.url) - function_name = path.split('/')[-2] + path = request.path if hasattr(request, "path") else path_url(request.url) + function_name = path.split("/")[-2] if self.lambda_backend.get_function(function_name): lambda_function = self.lambda_backend.get_function(function_name) - return 200, {}, json.dumps(dict(Policy="{\"Statement\":[" + lambda_function.policy + "]}")) + return ( + 200, + {}, + json.dumps( + dict(Policy='{"Statement":[' + lambda_function.policy + "]}") + ), + ) else: return 404, {}, "{}" def _invoke(self, request, full_url): response_headers = {} - function_name = self.path.rsplit('/', 2)[-2] - qualifier = self._get_param('qualifier') + function_name = self.path.rsplit("/", 2)[-2] + qualifier = self._get_param("qualifier") fn = self.lambda_backend.get_function(function_name, qualifier) if fn: payload = fn.invoke(self.body, self.headers, response_headers) - response_headers['Content-Length'] = str(len(payload)) + response_headers["Content-Length"] = str(len(payload)) return 202, response_headers, payload else: return 404, response_headers, "{}" @@ -173,38 +179,34 @@ class LambdaResponse(BaseResponse): def _invoke_async(self, request, full_url): response_headers = {} - function_name = self.path.rsplit('/', 3)[-3] + function_name = self.path.rsplit("/", 3)[-3] fn = self.lambda_backend.get_function(function_name, None) if fn: payload = fn.invoke(self.body, self.headers, response_headers) - response_headers['Content-Length'] = str(len(payload)) + response_headers["Content-Length"] = str(len(payload)) return 202, response_headers, payload else: return 404, response_headers, "{}" def _list_functions(self, request, full_url, headers): - result = { - 'Functions': [] - } + result = {"Functions": []} for fn in self.lambda_backend.list_functions(): json_data = fn.get_configuration() - json_data['Version'] = '$LATEST' - result['Functions'].append(json_data) + json_data["Version"] = "$LATEST" + result["Functions"].append(json_data) return 200, {}, json.dumps(result) def _list_versions_by_function(self, function_name): - result = { - 'Versions': [] - } + result = {"Versions": []} functions = self.lambda_backend.list_versions_by_function(function_name) if functions: for fn in functions: json_data = fn.get_configuration() - result['Versions'].append(json_data) + result["Versions"].append(json_data) return 200, {}, json.dumps(result) @@ -212,7 +214,11 @@ class LambdaResponse(BaseResponse): try: fn = self.lambda_backend.create_function(self.json_body) except ValueError as e: - return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}) + return ( + 400, + {}, + json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}), + ) else: config = fn.get_configuration() return 201, {}, json.dumps(config) @@ -221,16 +227,20 @@ class LambdaResponse(BaseResponse): try: fn = self.lambda_backend.create_event_source_mapping(self.json_body) except ValueError as e: - return 400, {}, json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}) + return ( + 400, + {}, + json.dumps({"Error": {"Code": e.args[0], "Message": e.args[1]}}), + ) else: config = fn.get_configuration() return 201, {}, json.dumps(config) def _list_event_source_mappings(self, event_source_arn, function_name): - esms = self.lambda_backend.list_event_source_mappings(event_source_arn, function_name) - result = { - 'EventSourceMappings': [esm.get_configuration() for esm in esms] - } + esms = self.lambda_backend.list_event_source_mappings( + event_source_arn, function_name + ) + result = {"EventSourceMappings": [esm.get_configuration() for esm in esms]} return 200, {}, json.dumps(result) def _get_event_source_mapping(self, uuid): @@ -251,13 +261,13 @@ class LambdaResponse(BaseResponse): esm = self.lambda_backend.delete_event_source_mapping(uuid) if esm: json_result = esm.get_configuration() - json_result.update({'State': 'Deleting'}) + json_result.update({"State": "Deleting"}) return 202, {}, json.dumps(json_result) else: return 404, {}, "{}" def _publish_function(self, request, full_url, headers): - function_name = self.path.rsplit('/', 2)[-2] + function_name = self.path.rsplit("/", 2)[-2] fn = self.lambda_backend.publish_function(function_name) if fn: @@ -267,8 +277,8 @@ class LambdaResponse(BaseResponse): return 404, {}, "{}" def _delete_function(self, request, full_url, headers): - function_name = unquote(self.path.rsplit('/', 1)[-1]) - qualifier = self._get_param('Qualifier', None) + function_name = unquote(self.path.rsplit("/", 1)[-1]) + qualifier = self._get_param("Qualifier", None) if self.lambda_backend.delete_function(function_name, qualifier): return 204, {}, "" @@ -276,17 +286,17 @@ class LambdaResponse(BaseResponse): return 404, {}, "{}" def _get_function(self, request, full_url, headers): - function_name = unquote(self.path.rsplit('/', 1)[-1]) - qualifier = self._get_param('Qualifier', None) + function_name = unquote(self.path.rsplit("/", 1)[-1]) + qualifier = self._get_param("Qualifier", None) fn = self.lambda_backend.get_function(function_name, qualifier) if fn: code = fn.get_code() - if qualifier is None or qualifier == '$LATEST': - code['Configuration']['Version'] = '$LATEST' - if qualifier == '$LATEST': - code['Configuration']['FunctionArn'] += ':$LATEST' + if qualifier is None or qualifier == "$LATEST": + code["Configuration"]["Version"] = "$LATEST" + if qualifier == "$LATEST": + code["Configuration"]["FunctionArn"] += ":$LATEST" return 200, {}, json.dumps(code) else: return 404, {}, "{}" @@ -299,25 +309,25 @@ class LambdaResponse(BaseResponse): return self.default_region def _list_tags(self, request, full_url): - function_arn = unquote(self.path.rsplit('/', 1)[-1]) + function_arn = unquote(self.path.rsplit("/", 1)[-1]) fn = self.lambda_backend.get_function_by_arn(function_arn) if fn: - return 200, {}, json.dumps({'Tags': fn.tags}) + return 200, {}, json.dumps({"Tags": fn.tags}) else: return 404, {}, "{}" def _tag_resource(self, request, full_url): - function_arn = unquote(self.path.rsplit('/', 1)[-1]) + function_arn = unquote(self.path.rsplit("/", 1)[-1]) - if self.lambda_backend.tag_resource(function_arn, self.json_body['Tags']): + if self.lambda_backend.tag_resource(function_arn, self.json_body["Tags"]): return 200, {}, "{}" else: return 404, {}, "{}" def _untag_resource(self, request, full_url): - function_arn = unquote(self.path.rsplit('/', 1)[-1]) - tag_keys = self.querystring['tagKeys'] + function_arn = unquote(self.path.rsplit("/", 1)[-1]) + tag_keys = self.querystring["tagKeys"] if self.lambda_backend.untag_resource(function_arn, tag_keys): return 204, {}, "{}" @@ -325,8 +335,8 @@ class LambdaResponse(BaseResponse): return 404, {}, "{}" def _put_configuration(self, request): - function_name = self.path.rsplit('/', 2)[-2] - qualifier = self._get_param('Qualifier', None) + function_name = self.path.rsplit("/", 2)[-2] + qualifier = self._get_param("Qualifier", None) fn = self.lambda_backend.get_function(function_name, qualifier) @@ -337,13 +347,13 @@ class LambdaResponse(BaseResponse): return 404, {}, "{}" def _put_code(self): - function_name = self.path.rsplit('/', 2)[-2] - qualifier = self._get_param('Qualifier', None) + function_name = self.path.rsplit("/", 2)[-2] + qualifier = self._get_param("Qualifier", None) fn = self.lambda_backend.get_function(function_name, qualifier) if fn: - if self.json_body.get('Publish', False): + if self.json_body.get("Publish", False): fn = self.lambda_backend.publish_function(function_name) config = fn.update_function_code(self.json_body) diff --git a/moto/awslambda/urls.py b/moto/awslambda/urls.py index 0ee8797d3..da7346817 100644 --- a/moto/awslambda/urls.py +++ b/moto/awslambda/urls.py @@ -1,22 +1,20 @@ from __future__ import unicode_literals from .responses import LambdaResponse -url_bases = [ - "https?://lambda.(.+).amazonaws.com", -] +url_bases = ["https?://lambda.(.+).amazonaws.com"] response = LambdaResponse() url_paths = { - '{0}/(?P[^/]+)/functions/?$': response.root, - r'{0}/(?P[^/]+)/functions/(?P[\w_:%-]+)/?$': response.function, - r'{0}/(?P[^/]+)/functions/(?P[\w_-]+)/versions/?$': response.versions, - r'{0}/(?P[^/]+)/event-source-mappings/?$': response.event_source_mappings, - r'{0}/(?P[^/]+)/event-source-mappings/(?P[\w_-]+)/?$': response.event_source_mapping, - r'{0}/(?P[^/]+)/functions/(?P[\w_-]+)/invocations/?$': response.invoke, - r'{0}/(?P[^/]+)/functions/(?P[\w_-]+)/invoke-async/?$': response.invoke_async, - r'{0}/(?P[^/]+)/tags/(?P.+)': response.tag, - r'{0}/(?P[^/]+)/functions/(?P[\w_-]+)/policy/?$': response.policy, - r'{0}/(?P[^/]+)/functions/(?P[\w_-]+)/configuration/?$': response.configuration, - r'{0}/(?P[^/]+)/functions/(?P[\w_-]+)/code/?$': response.code + "{0}/(?P[^/]+)/functions/?$": response.root, + r"{0}/(?P[^/]+)/functions/(?P[\w_:%-]+)/?$": response.function, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/versions/?$": response.versions, + r"{0}/(?P[^/]+)/event-source-mappings/?$": response.event_source_mappings, + r"{0}/(?P[^/]+)/event-source-mappings/(?P[\w_-]+)/?$": response.event_source_mapping, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/invocations/?$": response.invoke, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/invoke-async/?$": response.invoke_async, + r"{0}/(?P[^/]+)/tags/(?P.+)": response.tag, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/policy/?$": response.policy, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/configuration/?$": response.configuration, + r"{0}/(?P[^/]+)/functions/(?P[\w_-]+)/code/?$": response.code, } diff --git a/moto/awslambda/utils.py b/moto/awslambda/utils.py index 82027cb2f..e024b7b9b 100644 --- a/moto/awslambda/utils.py +++ b/moto/awslambda/utils.py @@ -1,20 +1,20 @@ from collections import namedtuple -ARN = namedtuple('ARN', ['region', 'account', 'function_name', 'version']) +ARN = namedtuple("ARN", ["region", "account", "function_name", "version"]) def make_function_arn(region, account, name): - return 'arn:aws:lambda:{0}:{1}:function:{2}'.format(region, account, name) + return "arn:aws:lambda:{0}:{1}:function:{2}".format(region, account, name) -def make_function_ver_arn(region, account, name, version='1'): +def make_function_ver_arn(region, account, name, version="1"): arn = make_function_arn(region, account, name) - return '{0}:{1}'.format(arn, version) + return "{0}:{1}".format(arn, version) def split_function_arn(arn): - arn = arn.replace('arn:aws:lambda:') + arn = arn.replace("arn:aws:lambda:") - region, account, _, name, version = arn.split(':') + region, account, _, name, version = arn.split(":") return ARN(region, account, name, version) diff --git a/moto/backends.py b/moto/backends.py index 0a387ac7e..bd91b1da2 100644 --- a/moto/backends.py +++ b/moto/backends.py @@ -52,57 +52,57 @@ from moto.resourcegroupstaggingapi import resourcegroupstaggingapi_backends from moto.config import config_backends BACKENDS = { - 'acm': acm_backends, - 'apigateway': apigateway_backends, - 'athena': athena_backends, - 'autoscaling': autoscaling_backends, - 'batch': batch_backends, - 'cloudformation': cloudformation_backends, - 'cloudwatch': cloudwatch_backends, - 'cognito-identity': cognitoidentity_backends, - 'cognito-idp': cognitoidp_backends, - 'config': config_backends, - 'datapipeline': datapipeline_backends, - 'dynamodb': dynamodb_backends, - 'dynamodb2': dynamodb_backends2, - 'dynamodbstreams': dynamodbstreams_backends, - 'ec2': ec2_backends, - 'ecr': ecr_backends, - 'ecs': ecs_backends, - 'elb': elb_backends, - 'elbv2': elbv2_backends, - 'events': events_backends, - 'emr': emr_backends, - 'glacier': glacier_backends, - 'glue': glue_backends, - 'iam': iam_backends, - 'moto_api': moto_api_backends, - 'instance_metadata': instance_metadata_backends, - 'logs': logs_backends, - 'kinesis': kinesis_backends, - 'kms': kms_backends, - 'opsworks': opsworks_backends, - 'organizations': organizations_backends, - 'polly': polly_backends, - 'redshift': redshift_backends, - 'resource-groups': resourcegroups_backends, - 'rds': rds2_backends, - 's3': s3_backends, - 's3bucket_path': s3_backends, - 'ses': ses_backends, - 'secretsmanager': secretsmanager_backends, - 'sns': sns_backends, - 'sqs': sqs_backends, - 'ssm': ssm_backends, - 'stepfunctions': stepfunction_backends, - 'sts': sts_backends, - 'swf': swf_backends, - 'route53': route53_backends, - 'lambda': lambda_backends, - 'xray': xray_backends, - 'resourcegroupstaggingapi': resourcegroupstaggingapi_backends, - 'iot': iot_backends, - 'iot-data': iotdata_backends, + "acm": acm_backends, + "apigateway": apigateway_backends, + "athena": athena_backends, + "autoscaling": autoscaling_backends, + "batch": batch_backends, + "cloudformation": cloudformation_backends, + "cloudwatch": cloudwatch_backends, + "cognito-identity": cognitoidentity_backends, + "cognito-idp": cognitoidp_backends, + "config": config_backends, + "datapipeline": datapipeline_backends, + "dynamodb": dynamodb_backends, + "dynamodb2": dynamodb_backends2, + "dynamodbstreams": dynamodbstreams_backends, + "ec2": ec2_backends, + "ecr": ecr_backends, + "ecs": ecs_backends, + "elb": elb_backends, + "elbv2": elbv2_backends, + "events": events_backends, + "emr": emr_backends, + "glacier": glacier_backends, + "glue": glue_backends, + "iam": iam_backends, + "moto_api": moto_api_backends, + "instance_metadata": instance_metadata_backends, + "logs": logs_backends, + "kinesis": kinesis_backends, + "kms": kms_backends, + "opsworks": opsworks_backends, + "organizations": organizations_backends, + "polly": polly_backends, + "redshift": redshift_backends, + "resource-groups": resourcegroups_backends, + "rds": rds2_backends, + "s3": s3_backends, + "s3bucket_path": s3_backends, + "ses": ses_backends, + "secretsmanager": secretsmanager_backends, + "sns": sns_backends, + "sqs": sqs_backends, + "ssm": ssm_backends, + "stepfunctions": stepfunction_backends, + "sts": sts_backends, + "swf": swf_backends, + "route53": route53_backends, + "lambda": lambda_backends, + "xray": xray_backends, + "resourcegroupstaggingapi": resourcegroupstaggingapi_backends, + "iot": iot_backends, + "iot-data": iotdata_backends, } @@ -110,6 +110,6 @@ def get_model(name, region_name): for backends in BACKENDS.values(): for region, backend in backends.items(): if region == region_name: - models = getattr(backend.__class__, '__models__', {}) + models = getattr(backend.__class__, "__models__", {}) if name in models: return list(getattr(backend, models[name])()) diff --git a/moto/batch/__init__.py b/moto/batch/__init__.py index 6002b6fc7..40144d35d 100644 --- a/moto/batch/__init__.py +++ b/moto/batch/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import batch_backends from ..core.models import base_decorator -batch_backend = batch_backends['us-east-1'] +batch_backend = batch_backends["us-east-1"] mock_batch = base_decorator(batch_backends) diff --git a/moto/batch/exceptions.py b/moto/batch/exceptions.py index a71e54ce3..c411f3fce 100644 --- a/moto/batch/exceptions.py +++ b/moto/batch/exceptions.py @@ -12,26 +12,29 @@ class AWSError(Exception): self.status = status if status is not None else self.STATUS def response(self): - return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status) + return ( + json.dumps({"__type": self.code, "message": self.message}), + dict(status=self.status), + ) class InvalidRequestException(AWSError): - CODE = 'InvalidRequestException' + CODE = "InvalidRequestException" class InvalidParameterValueException(AWSError): - CODE = 'InvalidParameterValue' + CODE = "InvalidParameterValue" class ValidationError(AWSError): - CODE = 'ValidationError' + CODE = "ValidationError" class InternalFailure(AWSError): - CODE = 'InternalFailure' + CODE = "InternalFailure" STATUS = 500 class ClientException(AWSError): - CODE = 'ClientException' + CODE = "ClientException" STATUS = 400 diff --git a/moto/batch/models.py b/moto/batch/models.py index caa442802..5c7fb4739 100644 --- a/moto/batch/models.py +++ b/moto/batch/models.py @@ -19,7 +19,12 @@ from moto.ecs import ecs_backends from moto.logs import logs_backends from .exceptions import InvalidParameterValueException, InternalFailure, ClientException -from .utils import make_arn_for_compute_env, make_arn_for_job_queue, make_arn_for_task_def, lowercase_first_key +from .utils import ( + make_arn_for_compute_env, + make_arn_for_job_queue, + make_arn_for_task_def, + lowercase_first_key, +) from moto.ec2.exceptions import InvalidSubnetIdError from moto.ec2.models import INSTANCE_TYPES as EC2_INSTANCE_TYPES from moto.iam.exceptions import IAMNotFoundException @@ -28,7 +33,9 @@ from moto.iam.exceptions import IAMNotFoundException _orig_adapter_send = requests.adapters.HTTPAdapter.send logger = logging.getLogger(__name__) DEFAULT_ACCOUNT_ID = 123456789012 -COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile(r'^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$') +COMPUTE_ENVIRONMENT_NAME_REGEX = re.compile( + r"^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$" +) def datetime2int(date): @@ -36,13 +43,23 @@ def datetime2int(date): class ComputeEnvironment(BaseModel): - def __init__(self, compute_environment_name, _type, state, compute_resources, service_role, region_name): + def __init__( + self, + compute_environment_name, + _type, + state, + compute_resources, + service_role, + region_name, + ): self.name = compute_environment_name self.env_type = _type self.state = state self.compute_resources = compute_resources self.service_role = service_role - self.arn = make_arn_for_compute_env(DEFAULT_ACCOUNT_ID, compute_environment_name, region_name) + self.arn = make_arn_for_compute_env( + DEFAULT_ACCOUNT_ID, compute_environment_name, region_name + ) self.instances = [] self.ecs_arn = None @@ -60,16 +77,18 @@ class ComputeEnvironment(BaseModel): return self.arn @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): backend = batch_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] env = backend.create_compute_environment( resource_name, - properties['Type'], - properties.get('State', 'ENABLED'), - lowercase_first_key(properties['ComputeResources']), - properties['ServiceRole'] + properties["Type"], + properties.get("State", "ENABLED"), + lowercase_first_key(properties["ComputeResources"]), + properties["ServiceRole"], ) arn = env[1] @@ -77,7 +96,9 @@ class ComputeEnvironment(BaseModel): class JobQueue(BaseModel): - def __init__(self, name, priority, state, environments, env_order_json, region_name): + def __init__( + self, name, priority, state, environments, env_order_json, region_name + ): """ :param name: Job queue name :type name: str @@ -98,18 +119,18 @@ class JobQueue(BaseModel): self.environments = environments self.env_order_json = env_order_json self.arn = make_arn_for_job_queue(DEFAULT_ACCOUNT_ID, name, region_name) - self.status = 'VALID' + self.status = "VALID" self.jobs = [] def describe(self): result = { - 'computeEnvironmentOrder': self.env_order_json, - 'jobQueueArn': self.arn, - 'jobQueueName': self.name, - 'priority': self.priority, - 'state': self.state, - 'status': self.status + "computeEnvironmentOrder": self.env_order_json, + "jobQueueArn": self.arn, + "jobQueueName": self.name, + "priority": self.priority, + "state": self.state, + "status": self.status, } return result @@ -119,19 +140,24 @@ class JobQueue(BaseModel): return self.arn @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): backend = batch_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] # Need to deal with difference case from cloudformation compute_resources, e.g. instanceRole vs InstanceRole # Hacky fix to normalise keys, is making me think I want to start spamming cAsEiNsEnSiTiVe dictionaries - compute_envs = [lowercase_first_key(dict_item) for dict_item in properties['ComputeEnvironmentOrder']] + compute_envs = [ + lowercase_first_key(dict_item) + for dict_item in properties["ComputeEnvironmentOrder"] + ] queue = backend.create_job_queue( queue_name=resource_name, - priority=properties['Priority'], - state=properties.get('State', 'ENABLED'), - compute_env_order=compute_envs + priority=properties["Priority"], + state=properties.get("State", "ENABLED"), + compute_env_order=compute_envs, ) arn = queue[1] @@ -139,7 +165,16 @@ class JobQueue(BaseModel): class JobDefinition(BaseModel): - def __init__(self, name, parameters, _type, container_properties, region_name, revision=0, retry_strategy=0): + def __init__( + self, + name, + parameters, + _type, + container_properties, + region_name, + revision=0, + retry_strategy=0, + ): self.name = name self.retries = retry_strategy self.type = _type @@ -147,7 +182,7 @@ class JobDefinition(BaseModel): self._region = region_name self.container_properties = container_properties self.arn = None - self.status = 'INACTIVE' + self.status = "INACTIVE" if parameters is None: parameters = {} @@ -158,31 +193,33 @@ class JobDefinition(BaseModel): def _update_arn(self): self.revision += 1 - self.arn = make_arn_for_task_def(DEFAULT_ACCOUNT_ID, self.name, self.revision, self._region) + self.arn = make_arn_for_task_def( + DEFAULT_ACCOUNT_ID, self.name, self.revision, self._region + ) def _validate(self): - if self.type not in ('container',): + if self.type not in ("container",): raise ClientException('type must be one of "container"') # For future use when containers arnt the only thing in batch - if self.type != 'container': + if self.type != "container": raise NotImplementedError() if not isinstance(self.parameters, dict): - raise ClientException('parameters must be a string to string map') + raise ClientException("parameters must be a string to string map") - if 'image' not in self.container_properties: - raise ClientException('containerProperties must contain image') + if "image" not in self.container_properties: + raise ClientException("containerProperties must contain image") - if 'memory' not in self.container_properties: - raise ClientException('containerProperties must contain memory') - if self.container_properties['memory'] < 4: - raise ClientException('container memory limit must be greater than 4') + if "memory" not in self.container_properties: + raise ClientException("containerProperties must contain memory") + if self.container_properties["memory"] < 4: + raise ClientException("container memory limit must be greater than 4") - if 'vcpus' not in self.container_properties: - raise ClientException('containerProperties must contain vcpus') - if self.container_properties['vcpus'] < 1: - raise ClientException('container vcpus limit must be greater than 0') + if "vcpus" not in self.container_properties: + raise ClientException("containerProperties must contain vcpus") + if self.container_properties["vcpus"] < 1: + raise ClientException("container vcpus limit must be greater than 0") def update(self, parameters, _type, container_properties, retry_strategy): if parameters is None: @@ -197,21 +234,29 @@ class JobDefinition(BaseModel): if retry_strategy is None: retry_strategy = self.retries - return JobDefinition(self.name, parameters, _type, container_properties, region_name=self._region, revision=self.revision, retry_strategy=retry_strategy) + return JobDefinition( + self.name, + parameters, + _type, + container_properties, + region_name=self._region, + revision=self.revision, + retry_strategy=retry_strategy, + ) def describe(self): result = { - 'jobDefinitionArn': self.arn, - 'jobDefinitionName': self.name, - 'parameters': self.parameters, - 'revision': self.revision, - 'status': self.status, - 'type': self.type + "jobDefinitionArn": self.arn, + "jobDefinitionName": self.name, + "parameters": self.parameters, + "revision": self.revision, + "status": self.status, + "type": self.type, } if self.container_properties is not None: - result['containerProperties'] = self.container_properties + result["containerProperties"] = self.container_properties if self.retries is not None and self.retries > 0: - result['retryStrategy'] = {'attempts': self.retries} + result["retryStrategy"] = {"attempts": self.retries} return result @@ -220,16 +265,18 @@ class JobDefinition(BaseModel): return self.arn @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): backend = batch_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] res = backend.register_job_definition( def_name=resource_name, - parameters=lowercase_first_key(properties.get('Parameters', {})), - _type='container', - retry_strategy=lowercase_first_key(properties['RetryStrategy']), - container_properties=lowercase_first_key(properties['ContainerProperties']) + parameters=lowercase_first_key(properties.get("Parameters", {})), + _type="container", + retry_strategy=lowercase_first_key(properties["RetryStrategy"]), + container_properties=lowercase_first_key(properties["ContainerProperties"]), ) arn = res[1] @@ -255,7 +302,7 @@ class Job(threading.Thread, BaseModel): self.job_id = str(uuid.uuid4()) self.job_definition = job_def self.job_queue = job_queue - self.job_state = 'SUBMITTED' # One of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED + self.job_state = "SUBMITTED" # One of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED self.job_queue.jobs.append(self) self.job_started_at = datetime.datetime(1970, 1, 1) self.job_stopped_at = datetime.datetime(1970, 1, 1) @@ -265,7 +312,7 @@ class Job(threading.Thread, BaseModel): self.stop = False self.daemon = True - self.name = 'MOTO-BATCH-' + self.job_id + self.name = "MOTO-BATCH-" + self.job_id self.docker_client = docker.from_env() self._log_backend = log_backend @@ -281,30 +328,33 @@ class Job(threading.Thread, BaseModel): if isinstance(adapter, requests.adapters.HTTPAdapter): adapter.send = functools.partial(_orig_adapter_send, adapter) return adapter + self.docker_client.api.get_adapter = replace_adapter_send def describe(self): result = { - 'jobDefinition': self.job_definition.arn, - 'jobId': self.job_id, - 'jobName': self.job_name, - 'jobQueue': self.job_queue.arn, - 'startedAt': datetime2int(self.job_started_at), - 'status': self.job_state, - 'dependsOn': [] + "jobDefinition": self.job_definition.arn, + "jobId": self.job_id, + "jobName": self.job_name, + "jobQueue": self.job_queue.arn, + "startedAt": datetime2int(self.job_started_at), + "status": self.job_state, + "dependsOn": [], } if self.job_stopped: - result['stoppedAt'] = datetime2int(self.job_stopped_at) - result['container'] = {} - result['container']['command'] = ['/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"'] - result['container']['privileged'] = False - result['container']['readonlyRootFilesystem'] = False - result['container']['ulimits'] = {} - result['container']['vcpus'] = 1 - result['container']['volumes'] = '' - result['container']['logStreamName'] = self.log_stream_name + result["stoppedAt"] = datetime2int(self.job_stopped_at) + result["container"] = {} + result["container"]["command"] = [ + '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"' + ] + result["container"]["privileged"] = False + result["container"]["readonlyRootFilesystem"] = False + result["container"]["ulimits"] = {} + result["container"]["vcpus"] = 1 + result["container"]["volumes"] = "" + result["container"]["logStreamName"] = self.log_stream_name if self.job_stopped_reason is not None: - result['statusReason'] = self.job_stopped_reason + result["statusReason"] = self.job_stopped_reason return result def run(self): @@ -322,24 +372,22 @@ class Job(threading.Thread, BaseModel): :return: """ try: - self.job_state = 'PENDING' + self.job_state = "PENDING" time.sleep(1) - image = 'alpine:latest' + image = "alpine:latest" cmd = '/bin/sh -c "for a in `seq 1 10`; do echo Hello World; sleep 1; done"' - name = '{0}-{1}'.format(self.job_name, self.job_id) + name = "{0}-{1}".format(self.job_name, self.job_id) - self.job_state = 'RUNNABLE' + self.job_state = "RUNNABLE" # TODO setup ecs container instance time.sleep(1) - self.job_state = 'STARTING' + self.job_state = "STARTING" container = self.docker_client.containers.run( - image, cmd, - detach=True, - name=name + image, cmd, detach=True, name=name ) - self.job_state = 'RUNNING' + self.job_state = "RUNNING" self.job_started_at = datetime.datetime.now() try: # Log collection @@ -353,53 +401,99 @@ class Job(threading.Thread, BaseModel): # events seem to be duplicated. now = datetime.datetime.now() i = 1 - while container.status == 'running' and not self.stop: + while container.status == "running" and not self.stop: time.sleep(0.15) if i % 10 == 0: - logs_stderr.extend(container.logs(stdout=False, stderr=True, timestamps=True, since=datetime2int(now)).decode().split('\n')) - logs_stdout.extend(container.logs(stdout=True, stderr=False, timestamps=True, since=datetime2int(now)).decode().split('\n')) + logs_stderr.extend( + container.logs( + stdout=False, + stderr=True, + timestamps=True, + since=datetime2int(now), + ) + .decode() + .split("\n") + ) + logs_stdout.extend( + container.logs( + stdout=True, + stderr=False, + timestamps=True, + since=datetime2int(now), + ) + .decode() + .split("\n") + ) now = datetime.datetime.now() container.reload() i += 1 # Container should be stopped by this point... unless asked to stop - if container.status == 'running': + if container.status == "running": container.kill() self.job_stopped_at = datetime.datetime.now() # Get final logs - logs_stderr.extend(container.logs(stdout=False, stderr=True, timestamps=True, since=datetime2int(now)).decode().split('\n')) - logs_stdout.extend(container.logs(stdout=True, stderr=False, timestamps=True, since=datetime2int(now)).decode().split('\n')) + logs_stderr.extend( + container.logs( + stdout=False, + stderr=True, + timestamps=True, + since=datetime2int(now), + ) + .decode() + .split("\n") + ) + logs_stdout.extend( + container.logs( + stdout=True, + stderr=False, + timestamps=True, + since=datetime2int(now), + ) + .decode() + .split("\n") + ) - self.job_state = 'SUCCEEDED' if not self.stop else 'FAILED' + self.job_state = "SUCCEEDED" if not self.stop else "FAILED" # Process logs logs_stdout = [x for x in logs_stdout if len(x) > 0] logs_stderr = [x for x in logs_stderr if len(x) > 0] logs = [] for line in logs_stdout + logs_stderr: - date, line = line.split(' ', 1) + date, line = line.split(" ", 1) date = dateutil.parser.parse(date) date = int(date.timestamp()) - logs.append({'timestamp': date, 'message': line.strip()}) + logs.append({"timestamp": date, "message": line.strip()}) # Send to cloudwatch - log_group = '/aws/batch/job' - stream_name = '{0}/default/{1}'.format(self.job_definition.name, self.job_id) + log_group = "/aws/batch/job" + stream_name = "{0}/default/{1}".format( + self.job_definition.name, self.job_id + ) self.log_stream_name = stream_name self._log_backend.ensure_log_group(log_group, None) self._log_backend.create_log_stream(log_group, stream_name) self._log_backend.put_log_events(log_group, stream_name, logs, None) except Exception as err: - logger.error('Failed to run AWS Batch container {0}. Error {1}'.format(self.name, err)) - self.job_state = 'FAILED' + logger.error( + "Failed to run AWS Batch container {0}. Error {1}".format( + self.name, err + ) + ) + self.job_state = "FAILED" container.kill() finally: container.remove() except Exception as err: - logger.error('Failed to run AWS Batch container {0}. Error {1}'.format(self.name, err)) - self.job_state = 'FAILED' + logger.error( + "Failed to run AWS Batch container {0}. Error {1}".format( + self.name, err + ) + ) + self.job_state = "FAILED" self.job_stopped = True self.job_stopped_at = datetime.datetime.now() @@ -426,7 +520,7 @@ class BatchBackend(BaseBackend): :return: IAM Backend :rtype: moto.iam.models.IAMBackend """ - return iam_backends['global'] + return iam_backends["global"] @property def ec2_backend(self): @@ -456,7 +550,7 @@ class BatchBackend(BaseBackend): region_name = self.region_name for job in self._jobs.values(): - if job.job_state not in ('FAILED', 'SUCCEEDED'): + if job.job_state not in ("FAILED", "SUCCEEDED"): job.stop = True # Try to join job.join(0.2) @@ -539,8 +633,10 @@ class BatchBackend(BaseBackend): """ job_def = self.get_job_definition_by_arn(identifier) if job_def is None: - if ':' in identifier: - job_def = self.get_job_definition_by_name_revision(*identifier.split(':', 1)) + if ":" in identifier: + job_def = self.get_job_definition_by_name_revision( + *identifier.split(":", 1) + ) else: job_def = self.get_job_definition_by_name(identifier) return job_def @@ -579,7 +675,9 @@ class BatchBackend(BaseBackend): except KeyError: return None - def describe_compute_environments(self, environments=None, max_results=None, next_token=None): + def describe_compute_environments( + self, environments=None, max_results=None, next_token=None + ): envs = set() if environments is not None: envs = set(environments) @@ -591,82 +689,107 @@ class BatchBackend(BaseBackend): continue json_part = { - 'computeEnvironmentArn': arn, - 'computeEnvironmentName': environment.name, - 'ecsClusterArn': environment.ecs_arn, - 'serviceRole': environment.service_role, - 'state': environment.state, - 'type': environment.env_type, - 'status': 'VALID' + "computeEnvironmentArn": arn, + "computeEnvironmentName": environment.name, + "ecsClusterArn": environment.ecs_arn, + "serviceRole": environment.service_role, + "state": environment.state, + "type": environment.env_type, + "status": "VALID", } - if environment.env_type == 'MANAGED': - json_part['computeResources'] = environment.compute_resources + if environment.env_type == "MANAGED": + json_part["computeResources"] = environment.compute_resources result.append(json_part) return result - def create_compute_environment(self, compute_environment_name, _type, state, compute_resources, service_role): + def create_compute_environment( + self, compute_environment_name, _type, state, compute_resources, service_role + ): # Validate if COMPUTE_ENVIRONMENT_NAME_REGEX.match(compute_environment_name) is None: - raise InvalidParameterValueException('Compute environment name does not match ^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$') + raise InvalidParameterValueException( + "Compute environment name does not match ^[A-Za-z0-9][A-Za-z0-9_-]{1,126}[A-Za-z0-9]$" + ) if self.get_compute_environment_by_name(compute_environment_name) is not None: - raise InvalidParameterValueException('A compute environment already exists with the name {0}'.format(compute_environment_name)) + raise InvalidParameterValueException( + "A compute environment already exists with the name {0}".format( + compute_environment_name + ) + ) # Look for IAM role try: self.iam_backend.get_role_by_arn(service_role) except IAMNotFoundException: - raise InvalidParameterValueException('Could not find IAM role {0}'.format(service_role)) + raise InvalidParameterValueException( + "Could not find IAM role {0}".format(service_role) + ) - if _type not in ('MANAGED', 'UNMANAGED'): - raise InvalidParameterValueException('type {0} must be one of MANAGED | UNMANAGED'.format(service_role)) + if _type not in ("MANAGED", "UNMANAGED"): + raise InvalidParameterValueException( + "type {0} must be one of MANAGED | UNMANAGED".format(service_role) + ) - if state is not None and state not in ('ENABLED', 'DISABLED'): - raise InvalidParameterValueException('state {0} must be one of ENABLED | DISABLED'.format(state)) + if state is not None and state not in ("ENABLED", "DISABLED"): + raise InvalidParameterValueException( + "state {0} must be one of ENABLED | DISABLED".format(state) + ) - if compute_resources is None and _type == 'MANAGED': - raise InvalidParameterValueException('computeResources must be specified when creating a MANAGED environment'.format(state)) + if compute_resources is None and _type == "MANAGED": + raise InvalidParameterValueException( + "computeResources must be specified when creating a MANAGED environment".format( + state + ) + ) elif compute_resources is not None: self._validate_compute_resources(compute_resources) # By here, all values except SPOT ones have been validated new_comp_env = ComputeEnvironment( - compute_environment_name, _type, state, - compute_resources, service_role, - region_name=self.region_name + compute_environment_name, + _type, + state, + compute_resources, + service_role, + region_name=self.region_name, ) self._compute_environments[new_comp_env.arn] = new_comp_env # Ok by this point, everything is legit, so if its Managed then start some instances - if _type == 'MANAGED': - cpus = int(compute_resources.get('desiredvCpus', compute_resources['minvCpus'])) - instance_types = compute_resources['instanceTypes'] - needed_instance_types = self.find_min_instances_to_meet_vcpus(instance_types, cpus) + if _type == "MANAGED": + cpus = int( + compute_resources.get("desiredvCpus", compute_resources["minvCpus"]) + ) + instance_types = compute_resources["instanceTypes"] + needed_instance_types = self.find_min_instances_to_meet_vcpus( + instance_types, cpus + ) # Create instances # Will loop over and over so we get decent subnet coverage - subnet_cycle = cycle(compute_resources['subnets']) + subnet_cycle = cycle(compute_resources["subnets"]) for instance_type in needed_instance_types: reservation = self.ec2_backend.add_instances( - image_id='ami-ecs-optimised', # Todo import AMIs + image_id="ami-ecs-optimised", # Todo import AMIs count=1, user_data=None, security_group_names=[], instance_type=instance_type, region_name=self.region_name, subnet_id=six.next(subnet_cycle), - key_name=compute_resources.get('ec2KeyPair', 'AWS_OWNED'), - security_group_ids=compute_resources['securityGroupIds'] + key_name=compute_resources.get("ec2KeyPair", "AWS_OWNED"), + security_group_ids=compute_resources["securityGroupIds"], ) new_comp_env.add_instance(reservation.instances[0]) # Create ECS cluster # Should be of format P2OnDemand_Batch_UUID - cluster_name = 'OnDemand_Batch_' + str(uuid.uuid4()) + cluster_name = "OnDemand_Batch_" + str(uuid.uuid4()) ecs_cluster = self.ecs_backend.create_cluster(cluster_name) new_comp_env.set_ecs(ecs_cluster.arn, cluster_name) @@ -679,47 +802,73 @@ class BatchBackend(BaseBackend): :param cr: computeResources :type cr: dict """ - for param in ('instanceRole', 'maxvCpus', 'minvCpus', 'instanceTypes', 'securityGroupIds', 'subnets', 'type'): + for param in ( + "instanceRole", + "maxvCpus", + "minvCpus", + "instanceTypes", + "securityGroupIds", + "subnets", + "type", + ): if param not in cr: - raise InvalidParameterValueException('computeResources must contain {0}'.format(param)) + raise InvalidParameterValueException( + "computeResources must contain {0}".format(param) + ) - if self.iam_backend.get_role_by_arn(cr['instanceRole']) is None: - raise InvalidParameterValueException('could not find instanceRole {0}'.format(cr['instanceRole'])) + if self.iam_backend.get_role_by_arn(cr["instanceRole"]) is None: + raise InvalidParameterValueException( + "could not find instanceRole {0}".format(cr["instanceRole"]) + ) - if cr['maxvCpus'] < 0: - raise InvalidParameterValueException('maxVCpus must be positive') - if cr['minvCpus'] < 0: - raise InvalidParameterValueException('minVCpus must be positive') - if cr['maxvCpus'] < cr['minvCpus']: - raise InvalidParameterValueException('maxVCpus must be greater than minvCpus') + if cr["maxvCpus"] < 0: + raise InvalidParameterValueException("maxVCpus must be positive") + if cr["minvCpus"] < 0: + raise InvalidParameterValueException("minVCpus must be positive") + if cr["maxvCpus"] < cr["minvCpus"]: + raise InvalidParameterValueException( + "maxVCpus must be greater than minvCpus" + ) - if len(cr['instanceTypes']) == 0: - raise InvalidParameterValueException('At least 1 instance type must be provided') - for instance_type in cr['instanceTypes']: - if instance_type == 'optimal': + if len(cr["instanceTypes"]) == 0: + raise InvalidParameterValueException( + "At least 1 instance type must be provided" + ) + for instance_type in cr["instanceTypes"]: + if instance_type == "optimal": pass # Optimal should pick from latest of current gen elif instance_type not in EC2_INSTANCE_TYPES: - raise InvalidParameterValueException('Instance type {0} does not exist'.format(instance_type)) + raise InvalidParameterValueException( + "Instance type {0} does not exist".format(instance_type) + ) - for sec_id in cr['securityGroupIds']: + for sec_id in cr["securityGroupIds"]: if self.ec2_backend.get_security_group_from_id(sec_id) is None: - raise InvalidParameterValueException('security group {0} does not exist'.format(sec_id)) - if len(cr['securityGroupIds']) == 0: - raise InvalidParameterValueException('At least 1 security group must be provided') + raise InvalidParameterValueException( + "security group {0} does not exist".format(sec_id) + ) + if len(cr["securityGroupIds"]) == 0: + raise InvalidParameterValueException( + "At least 1 security group must be provided" + ) - for subnet_id in cr['subnets']: + for subnet_id in cr["subnets"]: try: self.ec2_backend.get_subnet(subnet_id) except InvalidSubnetIdError: - raise InvalidParameterValueException('subnet {0} does not exist'.format(subnet_id)) - if len(cr['subnets']) == 0: - raise InvalidParameterValueException('At least 1 subnet must be provided') + raise InvalidParameterValueException( + "subnet {0} does not exist".format(subnet_id) + ) + if len(cr["subnets"]) == 0: + raise InvalidParameterValueException("At least 1 subnet must be provided") - if cr['type'] not in ('EC2', 'SPOT'): - raise InvalidParameterValueException('computeResources.type must be either EC2 | SPOT') + if cr["type"] not in ("EC2", "SPOT"): + raise InvalidParameterValueException( + "computeResources.type must be either EC2 | SPOT" + ) - if cr['type'] == 'SPOT': - raise InternalFailure('SPOT NOT SUPPORTED YET') + if cr["type"] == "SPOT": + raise InternalFailure("SPOT NOT SUPPORTED YET") @staticmethod def find_min_instances_to_meet_vcpus(instance_types, target): @@ -738,11 +887,11 @@ class BatchBackend(BaseBackend): instances = [] for instance_type in instance_types: - if instance_type == 'optimal': - instance_type = 'm4.4xlarge' + if instance_type == "optimal": + instance_type = "m4.4xlarge" instance_vcpus.append( - (EC2_INSTANCE_TYPES[instance_type]['vcpus'], instance_type) + (EC2_INSTANCE_TYPES[instance_type]["vcpus"], instance_type) ) instance_vcpus = sorted(instance_vcpus, key=lambda item: item[0], reverse=True) @@ -773,7 +922,7 @@ class BatchBackend(BaseBackend): def delete_compute_environment(self, compute_environment_name): if compute_environment_name is None: - raise InvalidParameterValueException('Missing computeEnvironment parameter') + raise InvalidParameterValueException("Missing computeEnvironment parameter") compute_env = self.get_compute_environment(compute_environment_name) @@ -784,29 +933,35 @@ class BatchBackend(BaseBackend): # Delete ECS cluster self.ecs_backend.delete_cluster(compute_env.ecs_name) - if compute_env.env_type == 'MANAGED': + if compute_env.env_type == "MANAGED": # Delete compute envrionment instance_ids = [instance.id for instance in compute_env.instances] self.ec2_backend.terminate_instances(instance_ids) - def update_compute_environment(self, compute_environment_name, state, compute_resources, service_role): + def update_compute_environment( + self, compute_environment_name, state, compute_resources, service_role + ): # Validate compute_env = self.get_compute_environment(compute_environment_name) if compute_env is None: - raise ClientException('Compute environment {0} does not exist') + raise ClientException("Compute environment {0} does not exist") # Look for IAM role if service_role is not None: try: role = self.iam_backend.get_role_by_arn(service_role) except IAMNotFoundException: - raise InvalidParameterValueException('Could not find IAM role {0}'.format(service_role)) + raise InvalidParameterValueException( + "Could not find IAM role {0}".format(service_role) + ) compute_env.service_role = role if state is not None: - if state not in ('ENABLED', 'DISABLED'): - raise InvalidParameterValueException('state {0} must be one of ENABLED | DISABLED'.format(state)) + if state not in ("ENABLED", "DISABLED"): + raise InvalidParameterValueException( + "state {0} must be one of ENABLED | DISABLED".format(state) + ) compute_env.state = state @@ -832,32 +987,51 @@ class BatchBackend(BaseBackend): :return: Tuple of Name, ARN :rtype: tuple of str """ - for variable, var_name in ((queue_name, 'jobQueueName'), (priority, 'priority'), (state, 'state'), (compute_env_order, 'computeEnvironmentOrder')): + for variable, var_name in ( + (queue_name, "jobQueueName"), + (priority, "priority"), + (state, "state"), + (compute_env_order, "computeEnvironmentOrder"), + ): if variable is None: - raise ClientException('{0} must be provided'.format(var_name)) + raise ClientException("{0} must be provided".format(var_name)) - if state not in ('ENABLED', 'DISABLED'): - raise ClientException('state {0} must be one of ENABLED | DISABLED'.format(state)) + if state not in ("ENABLED", "DISABLED"): + raise ClientException( + "state {0} must be one of ENABLED | DISABLED".format(state) + ) if self.get_job_queue_by_name(queue_name) is not None: - raise ClientException('Job queue {0} already exists'.format(queue_name)) + raise ClientException("Job queue {0} already exists".format(queue_name)) if len(compute_env_order) == 0: - raise ClientException('At least 1 compute environment must be provided') + raise ClientException("At least 1 compute environment must be provided") try: # orders and extracts computeEnvironment names - ordered_compute_environments = [item['computeEnvironment'] for item in sorted(compute_env_order, key=lambda x: x['order'])] + ordered_compute_environments = [ + item["computeEnvironment"] + for item in sorted(compute_env_order, key=lambda x: x["order"]) + ] env_objects = [] # Check each ARN exists, then make a list of compute env's for arn in ordered_compute_environments: env = self.get_compute_environment_by_arn(arn) if env is None: - raise ClientException('Compute environment {0} does not exist'.format(arn)) + raise ClientException( + "Compute environment {0} does not exist".format(arn) + ) env_objects.append(env) except Exception: - raise ClientException('computeEnvironmentOrder is malformed') + raise ClientException("computeEnvironmentOrder is malformed") # Create new Job Queue - queue = JobQueue(queue_name, priority, state, env_objects, compute_env_order, self.region_name) + queue = JobQueue( + queue_name, + priority, + state, + env_objects, + compute_env_order, + self.region_name, + ) self._job_queues[queue.arn] = queue return queue_name, queue.arn @@ -893,33 +1067,40 @@ class BatchBackend(BaseBackend): :rtype: tuple of str """ if queue_name is None: - raise ClientException('jobQueueName must be provided') + raise ClientException("jobQueueName must be provided") job_queue = self.get_job_queue(queue_name) if job_queue is None: - raise ClientException('Job queue {0} does not exist'.format(queue_name)) + raise ClientException("Job queue {0} does not exist".format(queue_name)) if state is not None: - if state not in ('ENABLED', 'DISABLED'): - raise ClientException('state {0} must be one of ENABLED | DISABLED'.format(state)) + if state not in ("ENABLED", "DISABLED"): + raise ClientException( + "state {0} must be one of ENABLED | DISABLED".format(state) + ) job_queue.state = state if compute_env_order is not None: if len(compute_env_order) == 0: - raise ClientException('At least 1 compute environment must be provided') + raise ClientException("At least 1 compute environment must be provided") try: # orders and extracts computeEnvironment names - ordered_compute_environments = [item['computeEnvironment'] for item in sorted(compute_env_order, key=lambda x: x['order'])] + ordered_compute_environments = [ + item["computeEnvironment"] + for item in sorted(compute_env_order, key=lambda x: x["order"]) + ] env_objects = [] # Check each ARN exists, then make a list of compute env's for arn in ordered_compute_environments: env = self.get_compute_environment_by_arn(arn) if env is None: - raise ClientException('Compute environment {0} does not exist'.format(arn)) + raise ClientException( + "Compute environment {0} does not exist".format(arn) + ) env_objects.append(env) except Exception: - raise ClientException('computeEnvironmentOrder is malformed') + raise ClientException("computeEnvironmentOrder is malformed") job_queue.env_order_json = compute_env_order job_queue.environments = env_objects @@ -935,22 +1116,33 @@ class BatchBackend(BaseBackend): if job_queue is not None: del self._job_queues[job_queue.arn] - def register_job_definition(self, def_name, parameters, _type, retry_strategy, container_properties): + def register_job_definition( + self, def_name, parameters, _type, retry_strategy, container_properties + ): if def_name is None: - raise ClientException('jobDefinitionName must be provided') + raise ClientException("jobDefinitionName must be provided") job_def = self.get_job_definition_by_name(def_name) if retry_strategy is not None: try: - retry_strategy = retry_strategy['attempts'] + retry_strategy = retry_strategy["attempts"] except Exception: - raise ClientException('retryStrategy is malformed') + raise ClientException("retryStrategy is malformed") if job_def is None: - job_def = JobDefinition(def_name, parameters, _type, container_properties, region_name=self.region_name, retry_strategy=retry_strategy) + job_def = JobDefinition( + def_name, + parameters, + _type, + container_properties, + region_name=self.region_name, + retry_strategy=retry_strategy, + ) else: # Make new jobdef - job_def = job_def.update(parameters, _type, container_properties, retry_strategy) + job_def = job_def.update( + parameters, _type, container_properties, retry_strategy + ) self._job_definitions[job_def.arn] = job_def @@ -958,14 +1150,21 @@ class BatchBackend(BaseBackend): def deregister_job_definition(self, def_name): job_def = self.get_job_definition_by_arn(def_name) - if job_def is None and ':' in def_name: - name, revision = def_name.split(':', 1) + if job_def is None and ":" in def_name: + name, revision = def_name.split(":", 1) job_def = self.get_job_definition_by_name_revision(name, revision) if job_def is not None: del self._job_definitions[job_def.arn] - def describe_job_definitions(self, job_def_name=None, job_def_list=None, status=None, max_results=None, next_token=None): + def describe_job_definitions( + self, + job_def_name=None, + job_def_list=None, + status=None, + max_results=None, + next_token=None, + ): jobs = [] # As a job name can reference multiple revisions, we get a list of them @@ -986,17 +1185,28 @@ class BatchBackend(BaseBackend): return [job for job in jobs if job.status == status] return jobs - def submit_job(self, job_name, job_def_id, job_queue, parameters=None, retries=None, depends_on=None, container_overrides=None): + def submit_job( + self, + job_name, + job_def_id, + job_queue, + parameters=None, + retries=None, + depends_on=None, + container_overrides=None, + ): # TODO parameters, retries (which is a dict raw from request), job dependancies and container overrides are ignored for now # Look for job definition job_def = self.get_job_definition(job_def_id) if job_def is None: - raise ClientException('Job definition {0} does not exist'.format(job_def_id)) + raise ClientException( + "Job definition {0} does not exist".format(job_def_id) + ) queue = self.get_job_queue(job_queue) if queue is None: - raise ClientException('Job queue {0} does not exist'.format(job_queue)) + raise ClientException("Job queue {0} does not exist".format(job_queue)) job = Job(job_name, job_def, queue, log_backend=self.logs_backend) self._jobs[job.job_id] = job @@ -1025,10 +1235,20 @@ class BatchBackend(BaseBackend): job_queue = self.get_job_queue(job_queue) if job_queue is None: - raise ClientException('Job queue {0} does not exist'.format(job_queue)) + raise ClientException("Job queue {0} does not exist".format(job_queue)) - if job_status is not None and job_status not in ('SUBMITTED', 'PENDING', 'RUNNABLE', 'STARTING', 'RUNNING', 'SUCCEEDED', 'FAILED'): - raise ClientException('Job status is not one of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED') + if job_status is not None and job_status not in ( + "SUBMITTED", + "PENDING", + "RUNNABLE", + "STARTING", + "RUNNING", + "SUCCEEDED", + "FAILED", + ): + raise ClientException( + "Job status is not one of SUBMITTED | PENDING | RUNNABLE | STARTING | RUNNING | SUCCEEDED | FAILED" + ) for job in job_queue.jobs: if job_status is not None and job.job_state != job_status: @@ -1040,16 +1260,18 @@ class BatchBackend(BaseBackend): def terminate_job(self, job_id, reason): if job_id is None: - raise ClientException('Job ID does not exist') + raise ClientException("Job ID does not exist") if reason is None: - raise ClientException('Reason does not exist') + raise ClientException("Reason does not exist") job = self.get_job_by_id(job_id) if job is None: - raise ClientException('Job not found') + raise ClientException("Job not found") job.terminate(reason) available_regions = boto3.session.Session().get_available_regions("batch") -batch_backends = {region: BatchBackend(region_name=region) for region in available_regions} +batch_backends = { + region: BatchBackend(region_name=region) for region in available_regions +} diff --git a/moto/batch/responses.py b/moto/batch/responses.py index 7fb606184..61b00e9c9 100644 --- a/moto/batch/responses.py +++ b/moto/batch/responses.py @@ -10,7 +10,7 @@ import json class BatchResponse(BaseResponse): def _error(self, code, message): - return json.dumps({'__type': code, 'message': message}), dict(status=400) + return json.dumps({"__type": code, "message": message}), dict(status=400) @property def batch_backend(self): @@ -22,9 +22,9 @@ class BatchResponse(BaseResponse): @property def json(self): - if self.body is None or self.body == '': + if self.body is None or self.body == "": self._json = {} - elif not hasattr(self, '_json'): + elif not hasattr(self, "_json"): try: self._json = json.loads(self.body) except ValueError: @@ -39,153 +39,146 @@ class BatchResponse(BaseResponse): def _get_action(self): # Return element after the /v1/* - return urlsplit(self.uri).path.lstrip('/').split('/')[1] + return urlsplit(self.uri).path.lstrip("/").split("/")[1] # CreateComputeEnvironment def createcomputeenvironment(self): - compute_env_name = self._get_param('computeEnvironmentName') - compute_resource = self._get_param('computeResources') - service_role = self._get_param('serviceRole') - state = self._get_param('state') - _type = self._get_param('type') + compute_env_name = self._get_param("computeEnvironmentName") + compute_resource = self._get_param("computeResources") + service_role = self._get_param("serviceRole") + state = self._get_param("state") + _type = self._get_param("type") try: name, arn = self.batch_backend.create_compute_environment( compute_environment_name=compute_env_name, - _type=_type, state=state, + _type=_type, + state=state, compute_resources=compute_resource, - service_role=service_role + service_role=service_role, ) except AWSError as err: return err.response() - result = { - 'computeEnvironmentArn': arn, - 'computeEnvironmentName': name - } + result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name} return json.dumps(result) # DescribeComputeEnvironments def describecomputeenvironments(self): - compute_environments = self._get_param('computeEnvironments') - max_results = self._get_param('maxResults') # Ignored, should be int - next_token = self._get_param('nextToken') # Ignored + compute_environments = self._get_param("computeEnvironments") + max_results = self._get_param("maxResults") # Ignored, should be int + next_token = self._get_param("nextToken") # Ignored - envs = self.batch_backend.describe_compute_environments(compute_environments, max_results=max_results, next_token=next_token) + envs = self.batch_backend.describe_compute_environments( + compute_environments, max_results=max_results, next_token=next_token + ) - result = {'computeEnvironments': envs} + result = {"computeEnvironments": envs} return json.dumps(result) # DeleteComputeEnvironment def deletecomputeenvironment(self): - compute_environment = self._get_param('computeEnvironment') + compute_environment = self._get_param("computeEnvironment") try: self.batch_backend.delete_compute_environment(compute_environment) except AWSError as err: return err.response() - return '' + return "" # UpdateComputeEnvironment def updatecomputeenvironment(self): - compute_env_name = self._get_param('computeEnvironment') - compute_resource = self._get_param('computeResources') - service_role = self._get_param('serviceRole') - state = self._get_param('state') + compute_env_name = self._get_param("computeEnvironment") + compute_resource = self._get_param("computeResources") + service_role = self._get_param("serviceRole") + state = self._get_param("state") try: name, arn = self.batch_backend.update_compute_environment( compute_environment_name=compute_env_name, compute_resources=compute_resource, service_role=service_role, - state=state + state=state, ) except AWSError as err: return err.response() - result = { - 'computeEnvironmentArn': arn, - 'computeEnvironmentName': name - } + result = {"computeEnvironmentArn": arn, "computeEnvironmentName": name} return json.dumps(result) # CreateJobQueue def createjobqueue(self): - compute_env_order = self._get_param('computeEnvironmentOrder') - queue_name = self._get_param('jobQueueName') - priority = self._get_param('priority') - state = self._get_param('state') + compute_env_order = self._get_param("computeEnvironmentOrder") + queue_name = self._get_param("jobQueueName") + priority = self._get_param("priority") + state = self._get_param("state") try: name, arn = self.batch_backend.create_job_queue( queue_name=queue_name, priority=priority, state=state, - compute_env_order=compute_env_order + compute_env_order=compute_env_order, ) except AWSError as err: return err.response() - result = { - 'jobQueueArn': arn, - 'jobQueueName': name - } + result = {"jobQueueArn": arn, "jobQueueName": name} return json.dumps(result) # DescribeJobQueues def describejobqueues(self): - job_queues = self._get_param('jobQueues') - max_results = self._get_param('maxResults') # Ignored, should be int - next_token = self._get_param('nextToken') # Ignored + job_queues = self._get_param("jobQueues") + max_results = self._get_param("maxResults") # Ignored, should be int + next_token = self._get_param("nextToken") # Ignored - queues = self.batch_backend.describe_job_queues(job_queues, max_results=max_results, next_token=next_token) + queues = self.batch_backend.describe_job_queues( + job_queues, max_results=max_results, next_token=next_token + ) - result = {'jobQueues': queues} + result = {"jobQueues": queues} return json.dumps(result) # UpdateJobQueue def updatejobqueue(self): - compute_env_order = self._get_param('computeEnvironmentOrder') - queue_name = self._get_param('jobQueue') - priority = self._get_param('priority') - state = self._get_param('state') + compute_env_order = self._get_param("computeEnvironmentOrder") + queue_name = self._get_param("jobQueue") + priority = self._get_param("priority") + state = self._get_param("state") try: name, arn = self.batch_backend.update_job_queue( queue_name=queue_name, priority=priority, state=state, - compute_env_order=compute_env_order + compute_env_order=compute_env_order, ) except AWSError as err: return err.response() - result = { - 'jobQueueArn': arn, - 'jobQueueName': name - } + result = {"jobQueueArn": arn, "jobQueueName": name} return json.dumps(result) # DeleteJobQueue def deletejobqueue(self): - queue_name = self._get_param('jobQueue') + queue_name = self._get_param("jobQueue") self.batch_backend.delete_job_queue(queue_name) - return '' + return "" # RegisterJobDefinition def registerjobdefinition(self): - container_properties = self._get_param('containerProperties') - def_name = self._get_param('jobDefinitionName') - parameters = self._get_param('parameters') - retry_strategy = self._get_param('retryStrategy') - _type = self._get_param('type') + container_properties = self._get_param("containerProperties") + def_name = self._get_param("jobDefinitionName") + parameters = self._get_param("parameters") + retry_strategy = self._get_param("retryStrategy") + _type = self._get_param("type") try: name, arn, revision = self.batch_backend.register_job_definition( @@ -193,104 +186,113 @@ class BatchResponse(BaseResponse): parameters=parameters, _type=_type, retry_strategy=retry_strategy, - container_properties=container_properties + container_properties=container_properties, ) except AWSError as err: return err.response() result = { - 'jobDefinitionArn': arn, - 'jobDefinitionName': name, - 'revision': revision + "jobDefinitionArn": arn, + "jobDefinitionName": name, + "revision": revision, } return json.dumps(result) # DeregisterJobDefinition def deregisterjobdefinition(self): - queue_name = self._get_param('jobDefinition') + queue_name = self._get_param("jobDefinition") self.batch_backend.deregister_job_definition(queue_name) - return '' + return "" # DescribeJobDefinitions def describejobdefinitions(self): - job_def_name = self._get_param('jobDefinitionName') - job_def_list = self._get_param('jobDefinitions') - max_results = self._get_param('maxResults') - next_token = self._get_param('nextToken') - status = self._get_param('status') + job_def_name = self._get_param("jobDefinitionName") + job_def_list = self._get_param("jobDefinitions") + max_results = self._get_param("maxResults") + next_token = self._get_param("nextToken") + status = self._get_param("status") - job_defs = self.batch_backend.describe_job_definitions(job_def_name, job_def_list, status, max_results, next_token) + job_defs = self.batch_backend.describe_job_definitions( + job_def_name, job_def_list, status, max_results, next_token + ) - result = {'jobDefinitions': [job.describe() for job in job_defs]} + result = {"jobDefinitions": [job.describe() for job in job_defs]} return json.dumps(result) # SubmitJob def submitjob(self): - container_overrides = self._get_param('containerOverrides') - depends_on = self._get_param('dependsOn') - job_def = self._get_param('jobDefinition') - job_name = self._get_param('jobName') - job_queue = self._get_param('jobQueue') - parameters = self._get_param('parameters') - retries = self._get_param('retryStrategy') + container_overrides = self._get_param("containerOverrides") + depends_on = self._get_param("dependsOn") + job_def = self._get_param("jobDefinition") + job_name = self._get_param("jobName") + job_queue = self._get_param("jobQueue") + parameters = self._get_param("parameters") + retries = self._get_param("retryStrategy") try: name, job_id = self.batch_backend.submit_job( - job_name, job_def, job_queue, + job_name, + job_def, + job_queue, parameters=parameters, retries=retries, depends_on=depends_on, - container_overrides=container_overrides + container_overrides=container_overrides, ) except AWSError as err: return err.response() - result = { - 'jobId': job_id, - 'jobName': name, - } + result = {"jobId": job_id, "jobName": name} return json.dumps(result) # DescribeJobs def describejobs(self): - jobs = self._get_param('jobs') + jobs = self._get_param("jobs") try: - return json.dumps({'jobs': self.batch_backend.describe_jobs(jobs)}) + return json.dumps({"jobs": self.batch_backend.describe_jobs(jobs)}) except AWSError as err: return err.response() # ListJobs def listjobs(self): - job_queue = self._get_param('jobQueue') - job_status = self._get_param('jobStatus') - max_results = self._get_param('maxResults') - next_token = self._get_param('nextToken') + job_queue = self._get_param("jobQueue") + job_status = self._get_param("jobStatus") + max_results = self._get_param("maxResults") + next_token = self._get_param("nextToken") try: - jobs = self.batch_backend.list_jobs(job_queue, job_status, max_results, next_token) + jobs = self.batch_backend.list_jobs( + job_queue, job_status, max_results, next_token + ) except AWSError as err: return err.response() - result = {'jobSummaryList': [{'jobId': job.job_id, 'jobName': job.job_name} for job in jobs]} + result = { + "jobSummaryList": [ + {"jobId": job.job_id, "jobName": job.job_name} for job in jobs + ] + } return json.dumps(result) # TerminateJob def terminatejob(self): - job_id = self._get_param('jobId') - reason = self._get_param('reason') + job_id = self._get_param("jobId") + reason = self._get_param("reason") try: self.batch_backend.terminate_job(job_id, reason) except AWSError as err: return err.response() - return '' + return "" # CancelJob - def canceljob(self): # Theres some AWS semantics on the differences but for us they're identical ;-) + def canceljob( + self, + ): # Theres some AWS semantics on the differences but for us they're identical ;-) return self.terminatejob() diff --git a/moto/batch/urls.py b/moto/batch/urls.py index c64086ef2..9dc507416 100644 --- a/moto/batch/urls.py +++ b/moto/batch/urls.py @@ -1,25 +1,23 @@ from __future__ import unicode_literals from .responses import BatchResponse -url_bases = [ - "https?://batch.(.+).amazonaws.com", -] +url_bases = ["https?://batch.(.+).amazonaws.com"] url_paths = { - '{0}/v1/createcomputeenvironment$': BatchResponse.dispatch, - '{0}/v1/describecomputeenvironments$': BatchResponse.dispatch, - '{0}/v1/deletecomputeenvironment': BatchResponse.dispatch, - '{0}/v1/updatecomputeenvironment': BatchResponse.dispatch, - '{0}/v1/createjobqueue': BatchResponse.dispatch, - '{0}/v1/describejobqueues': BatchResponse.dispatch, - '{0}/v1/updatejobqueue': BatchResponse.dispatch, - '{0}/v1/deletejobqueue': BatchResponse.dispatch, - '{0}/v1/registerjobdefinition': BatchResponse.dispatch, - '{0}/v1/deregisterjobdefinition': BatchResponse.dispatch, - '{0}/v1/describejobdefinitions': BatchResponse.dispatch, - '{0}/v1/submitjob': BatchResponse.dispatch, - '{0}/v1/describejobs': BatchResponse.dispatch, - '{0}/v1/listjobs': BatchResponse.dispatch, - '{0}/v1/terminatejob': BatchResponse.dispatch, - '{0}/v1/canceljob': BatchResponse.dispatch, + "{0}/v1/createcomputeenvironment$": BatchResponse.dispatch, + "{0}/v1/describecomputeenvironments$": BatchResponse.dispatch, + "{0}/v1/deletecomputeenvironment": BatchResponse.dispatch, + "{0}/v1/updatecomputeenvironment": BatchResponse.dispatch, + "{0}/v1/createjobqueue": BatchResponse.dispatch, + "{0}/v1/describejobqueues": BatchResponse.dispatch, + "{0}/v1/updatejobqueue": BatchResponse.dispatch, + "{0}/v1/deletejobqueue": BatchResponse.dispatch, + "{0}/v1/registerjobdefinition": BatchResponse.dispatch, + "{0}/v1/deregisterjobdefinition": BatchResponse.dispatch, + "{0}/v1/describejobdefinitions": BatchResponse.dispatch, + "{0}/v1/submitjob": BatchResponse.dispatch, + "{0}/v1/describejobs": BatchResponse.dispatch, + "{0}/v1/listjobs": BatchResponse.dispatch, + "{0}/v1/terminatejob": BatchResponse.dispatch, + "{0}/v1/canceljob": BatchResponse.dispatch, } diff --git a/moto/batch/utils.py b/moto/batch/utils.py index 829a55f12..ce9b2ffe8 100644 --- a/moto/batch/utils.py +++ b/moto/batch/utils.py @@ -2,7 +2,9 @@ from __future__ import unicode_literals def make_arn_for_compute_env(account_id, name, region_name): - return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format(region_name, account_id, name) + return "arn:aws:batch:{0}:{1}:compute-environment/{2}".format( + region_name, account_id, name + ) def make_arn_for_job_queue(account_id, name, region_name): @@ -10,7 +12,9 @@ def make_arn_for_job_queue(account_id, name, region_name): def make_arn_for_task_def(account_id, name, revision, region_name): - return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format(region_name, account_id, name, revision) + return "arn:aws:batch:{0}:{1}:job-definition/{2}:{3}".format( + region_name, account_id, name, revision + ) def lowercase_first_key(some_dict): diff --git a/moto/cloudformation/__init__.py b/moto/cloudformation/__init__.py index b73e3ab6c..351af146c 100644 --- a/moto/cloudformation/__init__.py +++ b/moto/cloudformation/__init__.py @@ -2,7 +2,6 @@ from __future__ import unicode_literals from .models import cloudformation_backends from ..core.models import base_decorator, deprecated_base_decorator -cloudformation_backend = cloudformation_backends['us-east-1'] +cloudformation_backend = cloudformation_backends["us-east-1"] mock_cloudformation = base_decorator(cloudformation_backends) -mock_cloudformation_deprecated = deprecated_base_decorator( - cloudformation_backends) +mock_cloudformation_deprecated = deprecated_base_decorator(cloudformation_backends) diff --git a/moto/cloudformation/exceptions.py b/moto/cloudformation/exceptions.py index 6ea15c5ca..10669ca56 100644 --- a/moto/cloudformation/exceptions.py +++ b/moto/cloudformation/exceptions.py @@ -4,26 +4,23 @@ from jinja2 import Template class UnformattedGetAttTemplateException(Exception): - description = 'Template error: resource {0} does not support attribute type {1} in Fn::GetAtt' + description = ( + "Template error: resource {0} does not support attribute type {1} in Fn::GetAtt" + ) status_code = 400 class ValidationError(BadRequest): - def __init__(self, name_or_id, message=None): if message is None: message = "Stack with id {0} does not exist".format(name_or_id) template = Template(ERROR_RESPONSE) super(ValidationError, self).__init__() - self.description = template.render( - code="ValidationError", - message=message, - ) + self.description = template.render(code="ValidationError", message=message) class MissingParameterError(BadRequest): - def __init__(self, parameter_name): template = Template(ERROR_RESPONSE) super(MissingParameterError, self).__init__() @@ -40,8 +37,8 @@ class ExportNotFound(BadRequest): template = Template(ERROR_RESPONSE) super(ExportNotFound, self).__init__() self.description = template.render( - code='ExportNotFound', - message="No export named {0} found.".format(export_name) + code="ExportNotFound", + message="No export named {0} found.".format(export_name), ) diff --git a/moto/cloudformation/models.py b/moto/cloudformation/models.py index 01e3113dd..71ceaf168 100644 --- a/moto/cloudformation/models.py +++ b/moto/cloudformation/models.py @@ -21,11 +21,19 @@ from .exceptions import ValidationError class FakeStackSet(BaseModel): - - def __init__(self, stackset_id, name, template, region='us-east-1', - status='ACTIVE', description=None, parameters=None, tags=None, - admin_role='AWSCloudFormationStackSetAdministrationRole', - execution_role='AWSCloudFormationStackSetExecutionRole'): + def __init__( + self, + stackset_id, + name, + template, + region="us-east-1", + status="ACTIVE", + description=None, + parameters=None, + tags=None, + admin_role="AWSCloudFormationStackSetAdministrationRole", + execution_role="AWSCloudFormationStackSetExecutionRole", + ): self.id = stackset_id self.arn = generate_stackset_arn(stackset_id, region) self.name = name @@ -42,12 +50,14 @@ class FakeStackSet(BaseModel): def _create_operation(self, operation_id, action, status, accounts=[], regions=[]): operation = { - 'OperationId': str(operation_id), - 'Action': action, - 'Status': status, - 'CreationTimestamp': datetime.now(), - 'EndTimestamp': datetime.now() + timedelta(minutes=2), - 'Instances': [{account: region} for account in accounts for region in regions], + "OperationId": str(operation_id), + "Action": action, + "Status": status, + "CreationTimestamp": datetime.now(), + "EndTimestamp": datetime.now() + timedelta(minutes=2), + "Instances": [ + {account: region} for account in accounts for region in regions + ], } self.operations += [operation] @@ -55,20 +65,30 @@ class FakeStackSet(BaseModel): def get_operation(self, operation_id): for operation in self.operations: - if operation_id == operation['OperationId']: + if operation_id == operation["OperationId"]: return operation raise ValidationError(operation_id) def update_operation(self, operation_id, status): operation = self.get_operation(operation_id) - operation['Status'] = status + operation["Status"] = status return operation_id def delete(self): - self.status = 'DELETED' + self.status = "DELETED" - def update(self, template, description, parameters, tags, admin_role, - execution_role, accounts, regions, operation_id=None): + def update( + self, + template, + description, + parameters, + tags, + admin_role, + execution_role, + accounts, + regions, + operation_id=None, + ): if not operation_id: operation_id = uuid.uuid4() @@ -82,9 +102,13 @@ class FakeStackSet(BaseModel): if accounts and regions: self.update_instances(accounts, regions, self.parameters) - operation = self._create_operation(operation_id=operation_id, - action='UPDATE', status='SUCCEEDED', accounts=accounts, - regions=regions) + operation = self._create_operation( + operation_id=operation_id, + action="UPDATE", + status="SUCCEEDED", + accounts=accounts, + regions=regions, + ) return operation def create_stack_instances(self, accounts, regions, parameters, operation_id=None): @@ -94,8 +118,13 @@ class FakeStackSet(BaseModel): parameters = self.parameters self.instances.create_instances(accounts, regions, parameters, operation_id) - self._create_operation(operation_id=operation_id, action='CREATE', - status='SUCCEEDED', accounts=accounts, regions=regions) + self._create_operation( + operation_id=operation_id, + action="CREATE", + status="SUCCEEDED", + accounts=accounts, + regions=regions, + ) def delete_stack_instances(self, accounts, regions, operation_id=None): if not operation_id: @@ -103,8 +132,13 @@ class FakeStackSet(BaseModel): self.instances.delete(accounts, regions) - operation = self._create_operation(operation_id=operation_id, action='DELETE', - status='SUCCEEDED', accounts=accounts, regions=regions) + operation = self._create_operation( + operation_id=operation_id, + action="DELETE", + status="SUCCEEDED", + accounts=accounts, + regions=regions, + ) return operation def update_instances(self, accounts, regions, parameters, operation_id=None): @@ -112,9 +146,13 @@ class FakeStackSet(BaseModel): operation_id = uuid.uuid4() self.instances.update(accounts, regions, parameters) - operation = self._create_operation(operation_id=operation_id, - action='UPDATE', status='SUCCEEDED', accounts=accounts, - regions=regions) + operation = self._create_operation( + operation_id=operation_id, + action="UPDATE", + status="SUCCEEDED", + accounts=accounts, + regions=regions, + ) return operation @@ -131,12 +169,12 @@ class FakeStackInstances(BaseModel): for region in regions: for account in accounts: instance = { - 'StackId': generate_stack_id(self.stack_name, region, account), - 'StackSetId': self.stackset_id, - 'Region': region, - 'Account': account, - 'Status': "CURRENT", - 'ParameterOverrides': parameters if parameters else [], + "StackId": generate_stack_id(self.stack_name, region, account), + "StackSetId": self.stackset_id, + "Region": region, + "Account": account, + "Status": "CURRENT", + "ParameterOverrides": parameters if parameters else [], } new_instances.append(instance) self.stack_instances += new_instances @@ -147,24 +185,35 @@ class FakeStackInstances(BaseModel): for region in regions: instance = self.get_instance(account, region) if parameters: - instance['ParameterOverrides'] = parameters + instance["ParameterOverrides"] = parameters else: - instance['ParameterOverrides'] = [] + instance["ParameterOverrides"] = [] def delete(self, accounts, regions): for i, instance in enumerate(self.stack_instances): - if instance['Region'] in regions and instance['Account'] in accounts: + if instance["Region"] in regions and instance["Account"] in accounts: self.stack_instances.pop(i) def get_instance(self, account, region): for i, instance in enumerate(self.stack_instances): - if instance['Region'] == region and instance['Account'] == account: + if instance["Region"] == region and instance["Account"] == account: return self.stack_instances[i] class FakeStack(BaseModel): - - def __init__(self, stack_id, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None, create_change_set=False): + def __init__( + self, + stack_id, + name, + template, + parameters, + region_name, + notification_arns=None, + tags=None, + role_arn=None, + cross_stack_resources=None, + create_change_set=False, + ): self.stack_id = stack_id self.name = name self.template = template @@ -176,22 +225,31 @@ class FakeStack(BaseModel): self.tags = tags if tags else {} self.events = [] if create_change_set: - self._add_stack_event("REVIEW_IN_PROGRESS", - resource_status_reason="User Initiated") + self._add_stack_event( + "REVIEW_IN_PROGRESS", resource_status_reason="User Initiated" + ) else: - self._add_stack_event("CREATE_IN_PROGRESS", - resource_status_reason="User Initiated") + self._add_stack_event( + "CREATE_IN_PROGRESS", resource_status_reason="User Initiated" + ) - self.description = self.template_dict.get('Description') + self.description = self.template_dict.get("Description") self.cross_stack_resources = cross_stack_resources or {} self.resource_map = self._create_resource_map() self.output_map = self._create_output_map() self._add_stack_event("CREATE_COMPLETE") - self.status = 'CREATE_COMPLETE' + self.status = "CREATE_COMPLETE" def _create_resource_map(self): resource_map = ResourceMap( - self.stack_id, self.name, self.parameters, self.tags, self.region_name, self.template_dict, self.cross_stack_resources) + self.stack_id, + self.name, + self.parameters, + self.tags, + self.region_name, + self.template_dict, + self.cross_stack_resources, + ) resource_map.create() return resource_map @@ -200,34 +258,46 @@ class FakeStack(BaseModel): output_map.create() return output_map - def _add_stack_event(self, resource_status, resource_status_reason=None, resource_properties=None): - self.events.append(FakeEvent( - stack_id=self.stack_id, - stack_name=self.name, - logical_resource_id=self.name, - physical_resource_id=self.stack_id, - resource_type="AWS::CloudFormation::Stack", - resource_status=resource_status, - resource_status_reason=resource_status_reason, - resource_properties=resource_properties, - )) + def _add_stack_event( + self, resource_status, resource_status_reason=None, resource_properties=None + ): + self.events.append( + FakeEvent( + stack_id=self.stack_id, + stack_name=self.name, + logical_resource_id=self.name, + physical_resource_id=self.stack_id, + resource_type="AWS::CloudFormation::Stack", + resource_status=resource_status, + resource_status_reason=resource_status_reason, + resource_properties=resource_properties, + ) + ) - def _add_resource_event(self, logical_resource_id, resource_status, resource_status_reason=None, resource_properties=None): + def _add_resource_event( + self, + logical_resource_id, + resource_status, + resource_status_reason=None, + resource_properties=None, + ): # not used yet... feel free to help yourself resource = self.resource_map[logical_resource_id] - self.events.append(FakeEvent( - stack_id=self.stack_id, - stack_name=self.name, - logical_resource_id=logical_resource_id, - physical_resource_id=resource.physical_resource_id, - resource_type=resource.type, - resource_status=resource_status, - resource_status_reason=resource_status_reason, - resource_properties=resource_properties, - )) + self.events.append( + FakeEvent( + stack_id=self.stack_id, + stack_name=self.name, + logical_resource_id=logical_resource_id, + physical_resource_id=resource.physical_resource_id, + resource_type=resource.type, + resource_status=resource_status, + resource_status_reason=resource_status_reason, + resource_properties=resource_properties, + ) + ) def _parse_template(self): - yaml.add_multi_constructor('', yaml_tag_constructor) + yaml.add_multi_constructor("", yaml_tag_constructor) try: self.template_dict = yaml.load(self.template, Loader=yaml.Loader) except yaml.parser.ParserError: @@ -250,7 +320,9 @@ class FakeStack(BaseModel): return self.output_map.exports def update(self, template, role_arn=None, parameters=None, tags=None): - self._add_stack_event("UPDATE_IN_PROGRESS", resource_status_reason="User Initiated") + self._add_stack_event( + "UPDATE_IN_PROGRESS", resource_status_reason="User Initiated" + ) self.template = template self._parse_template() self.resource_map.update(self.template_dict, parameters) @@ -264,15 +336,15 @@ class FakeStack(BaseModel): # TODO: update tags in the resource map def delete(self): - self._add_stack_event("DELETE_IN_PROGRESS", - resource_status_reason="User Initiated") + self._add_stack_event( + "DELETE_IN_PROGRESS", resource_status_reason="User Initiated" + ) self.resource_map.delete() self._add_stack_event("DELETE_COMPLETE") self.status = "DELETE_COMPLETE" class FakeChange(BaseModel): - def __init__(self, action, logical_resource_id, resource_type): self.action = action self.logical_resource_id = logical_resource_id @@ -280,8 +352,21 @@ class FakeChange(BaseModel): class FakeChangeSet(FakeStack): - - def __init__(self, stack_id, stack_name, stack_template, change_set_id, change_set_name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, cross_stack_resources=None): + def __init__( + self, + stack_id, + stack_name, + stack_template, + change_set_id, + change_set_name, + template, + parameters, + region_name, + notification_arns=None, + tags=None, + role_arn=None, + cross_stack_resources=None, + ): super(FakeChangeSet, self).__init__( stack_id, stack_name, @@ -306,17 +391,28 @@ class FakeChangeSet(FakeStack): resources_by_action = self.resource_map.diff(self.template_dict, parameters) for action, resources in resources_by_action.items(): for resource_name, resource in resources.items(): - changes.append(FakeChange( - action=action, - logical_resource_id=resource_name, - resource_type=resource['ResourceType'], - )) + changes.append( + FakeChange( + action=action, + logical_resource_id=resource_name, + resource_type=resource["ResourceType"], + ) + ) return changes class FakeEvent(BaseModel): - - def __init__(self, stack_id, stack_name, logical_resource_id, physical_resource_id, resource_type, resource_status, resource_status_reason=None, resource_properties=None): + def __init__( + self, + stack_id, + stack_name, + logical_resource_id, + physical_resource_id, + resource_type, + resource_status, + resource_status_reason=None, + resource_properties=None, + ): self.stack_id = stack_id self.stack_name = stack_name self.logical_resource_id = logical_resource_id @@ -330,7 +426,6 @@ class FakeEvent(BaseModel): class CloudFormationBackend(BaseBackend): - def __init__(self): self.stacks = OrderedDict() self.stacksets = OrderedDict() @@ -338,7 +433,17 @@ class CloudFormationBackend(BaseBackend): self.exports = OrderedDict() self.change_sets = OrderedDict() - def create_stack_set(self, name, template, parameters, tags=None, description=None, region='us-east-1', admin_role=None, execution_role=None): + def create_stack_set( + self, + name, + template, + parameters, + tags=None, + description=None, + region="us-east-1", + admin_role=None, + execution_role=None, + ): stackset_id = generate_stackset_id(name) new_stackset = FakeStackSet( stackset_id=stackset_id, @@ -366,7 +471,9 @@ class CloudFormationBackend(BaseBackend): if self.stacksets[stackset].name == name: self.stacksets[stackset].delete() - def create_stack_instances(self, stackset_name, accounts, regions, parameters, operation_id=None): + def create_stack_instances( + self, stackset_name, accounts, regions, parameters, operation_id=None + ): stackset = self.get_stack_set(stackset_name) stackset.create_stack_instances( @@ -377,9 +484,19 @@ class CloudFormationBackend(BaseBackend): ) return stackset - def update_stack_set(self, stackset_name, template=None, description=None, - parameters=None, tags=None, admin_role=None, execution_role=None, - accounts=None, regions=None, operation_id=None): + def update_stack_set( + self, + stackset_name, + template=None, + description=None, + parameters=None, + tags=None, + admin_role=None, + execution_role=None, + accounts=None, + regions=None, + operation_id=None, + ): stackset = self.get_stack_set(stackset_name) update = stackset.update( template=template, @@ -390,16 +507,28 @@ class CloudFormationBackend(BaseBackend): execution_role=execution_role, accounts=accounts, regions=regions, - operation_id=operation_id + operation_id=operation_id, ) return update - def delete_stack_instances(self, stackset_name, accounts, regions, operation_id=None): + def delete_stack_instances( + self, stackset_name, accounts, regions, operation_id=None + ): stackset = self.get_stack_set(stackset_name) stackset.delete_stack_instances(accounts, regions, operation_id) return stackset - def create_stack(self, name, template, parameters, region_name, notification_arns=None, tags=None, role_arn=None, create_change_set=False): + def create_stack( + self, + name, + template, + parameters, + region_name, + notification_arns=None, + tags=None, + role_arn=None, + create_change_set=False, + ): stack_id = generate_stack_id(name) new_stack = FakeStack( stack_id=stack_id, @@ -419,10 +548,21 @@ class CloudFormationBackend(BaseBackend): self.exports[export.name] = export return new_stack - def create_change_set(self, stack_name, change_set_name, template, parameters, region_name, change_set_type, notification_arns=None, tags=None, role_arn=None): + def create_change_set( + self, + stack_name, + change_set_name, + template, + parameters, + region_name, + change_set_type, + notification_arns=None, + tags=None, + role_arn=None, + ): stack_id = None stack_template = None - if change_set_type == 'UPDATE': + if change_set_type == "UPDATE": stacks = self.stacks.values() stack = None for s in stacks: @@ -449,7 +589,7 @@ class CloudFormationBackend(BaseBackend): notification_arns=notification_arns, tags=tags, role_arn=role_arn, - cross_stack_resources=self.exports + cross_stack_resources=self.exports, ) self.change_sets[change_set_id] = new_change_set self.stacks[stack_id] = new_change_set @@ -488,11 +628,11 @@ class CloudFormationBackend(BaseBackend): stack = self.change_sets[cs] if stack is None: raise ValidationError(stack_name) - if stack.events[-1].resource_status == 'REVIEW_IN_PROGRESS': - stack._add_stack_event('CREATE_COMPLETE') + if stack.events[-1].resource_status == "REVIEW_IN_PROGRESS": + stack._add_stack_event("CREATE_COMPLETE") else: - stack._add_stack_event('UPDATE_IN_PROGRESS') - stack._add_stack_event('UPDATE_COMPLETE') + stack._add_stack_event("UPDATE_IN_PROGRESS") + stack._add_stack_event("UPDATE_COMPLETE") return True def describe_stacks(self, name_or_stack_id): @@ -514,9 +654,7 @@ class CloudFormationBackend(BaseBackend): return self.change_sets.values() def list_stacks(self): - return [ - v for v in self.stacks.values() - ] + [ + return [v for v in self.stacks.values()] + [ v for v in self.deleted_stacks.values() ] @@ -558,10 +696,10 @@ class CloudFormationBackend(BaseBackend): all_exports = list(self.exports.values()) if token is None: exports = all_exports[0:100] - next_token = '100' if len(all_exports) > 100 else None + next_token = "100" if len(all_exports) > 100 else None else: token = int(token) - exports = all_exports[token:token + 100] + exports = all_exports[token : token + 100] next_token = str(token + 100) if len(all_exports) > token + 100 else None return exports, next_token @@ -572,7 +710,10 @@ class CloudFormationBackend(BaseBackend): new_stack_export_names = [x.name for x in stack.exports] export_names = self.exports.keys() if not set(export_names).isdisjoint(new_stack_export_names): - raise ValidationError(stack.stack_id, message='Export names must be unique across a given region') + raise ValidationError( + stack.stack_id, + message="Export names must be unique across a given region", + ) cloudformation_backends = {} diff --git a/moto/cloudformation/parsing.py b/moto/cloudformation/parsing.py index f2e03bd81..77e3c271c 100644 --- a/moto/cloudformation/parsing.py +++ b/moto/cloudformation/parsing.py @@ -28,7 +28,12 @@ from moto.s3 import models as s3_models from moto.sns import models as sns_models from moto.sqs import models as sqs_models from .utils import random_suffix -from .exceptions import ExportNotFound, MissingParameterError, UnformattedGetAttTemplateException, ValidationError +from .exceptions import ( + ExportNotFound, + MissingParameterError, + UnformattedGetAttTemplateException, + ValidationError, +) from boto.cloudformation.stack import Output MODEL_MAP = { @@ -100,7 +105,7 @@ NAME_TYPE_MAP = { "AWS::RDS::DBInstance": "DBInstanceIdentifier", "AWS::S3::Bucket": "BucketName", "AWS::SNS::Topic": "TopicName", - "AWS::SQS::Queue": "QueueName" + "AWS::SQS::Queue": "QueueName", } # Just ignore these models types for now @@ -109,13 +114,12 @@ NULL_MODELS = [ "AWS::CloudFormation::WaitConditionHandle", ] -DEFAULT_REGION = 'us-east-1' +DEFAULT_REGION = "us-east-1" logger = logging.getLogger("moto") class LazyDict(dict): - def __getitem__(self, key): val = dict.__getitem__(self, key) if callable(val): @@ -132,10 +136,10 @@ def clean_json(resource_json, resources_map): Eventually, this is where we would add things like function parsing (fn::) """ if isinstance(resource_json, dict): - if 'Ref' in resource_json: + if "Ref" in resource_json: # Parse resource reference - resource = resources_map[resource_json['Ref']] - if hasattr(resource, 'physical_resource_id'): + resource = resources_map[resource_json["Ref"]] + if hasattr(resource, "physical_resource_id"): return resource.physical_resource_id else: return resource @@ -148,74 +152,92 @@ def clean_json(resource_json, resources_map): result = result[clean_json(path, resources_map)] return result - if 'Fn::GetAtt' in resource_json: - resource = resources_map.get(resource_json['Fn::GetAtt'][0]) + if "Fn::GetAtt" in resource_json: + resource = resources_map.get(resource_json["Fn::GetAtt"][0]) if resource is None: return resource_json try: - return resource.get_cfn_attribute(resource_json['Fn::GetAtt'][1]) + return resource.get_cfn_attribute(resource_json["Fn::GetAtt"][1]) except NotImplementedError as n: - logger.warning(str(n).format( - resource_json['Fn::GetAtt'][0])) + logger.warning(str(n).format(resource_json["Fn::GetAtt"][0])) except UnformattedGetAttTemplateException: raise ValidationError( - 'Bad Request', + "Bad Request", UnformattedGetAttTemplateException.description.format( - resource_json['Fn::GetAtt'][0], resource_json['Fn::GetAtt'][1])) + resource_json["Fn::GetAtt"][0], resource_json["Fn::GetAtt"][1] + ), + ) - if 'Fn::If' in resource_json: - condition_name, true_value, false_value = resource_json['Fn::If'] + if "Fn::If" in resource_json: + condition_name, true_value, false_value = resource_json["Fn::If"] if resources_map.lazy_condition_map[condition_name]: return clean_json(true_value, resources_map) else: return clean_json(false_value, resources_map) - if 'Fn::Join' in resource_json: - join_list = clean_json(resource_json['Fn::Join'][1], resources_map) - return resource_json['Fn::Join'][0].join([str(x) for x in join_list]) + if "Fn::Join" in resource_json: + join_list = clean_json(resource_json["Fn::Join"][1], resources_map) + return resource_json["Fn::Join"][0].join([str(x) for x in join_list]) - if 'Fn::Split' in resource_json: - to_split = clean_json(resource_json['Fn::Split'][1], resources_map) - return to_split.split(resource_json['Fn::Split'][0]) + if "Fn::Split" in resource_json: + to_split = clean_json(resource_json["Fn::Split"][1], resources_map) + return to_split.split(resource_json["Fn::Split"][0]) - if 'Fn::Select' in resource_json: - select_index = int(resource_json['Fn::Select'][0]) - select_list = clean_json(resource_json['Fn::Select'][1], resources_map) + if "Fn::Select" in resource_json: + select_index = int(resource_json["Fn::Select"][0]) + select_list = clean_json(resource_json["Fn::Select"][1], resources_map) return select_list[select_index] - if 'Fn::Sub' in resource_json: - if isinstance(resource_json['Fn::Sub'], list): + if "Fn::Sub" in resource_json: + if isinstance(resource_json["Fn::Sub"], list): warnings.warn( - "Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation") + "Tried to parse Fn::Sub with variable mapping but it's not supported by moto's CloudFormation implementation" + ) else: - fn_sub_value = clean_json(resource_json['Fn::Sub'], resources_map) + fn_sub_value = clean_json(resource_json["Fn::Sub"], resources_map) to_sub = re.findall('(?=\${)[^!^"]*?}', fn_sub_value) literals = re.findall('(?=\${!)[^"]*?}', fn_sub_value) for sub in to_sub: - if '.' in sub: - cleaned_ref = clean_json({'Fn::GetAtt': re.findall('(?<=\${)[^"]*?(?=})', sub)[0].split('.')}, resources_map) + if "." in sub: + cleaned_ref = clean_json( + { + "Fn::GetAtt": re.findall('(?<=\${)[^"]*?(?=})', sub)[ + 0 + ].split(".") + }, + resources_map, + ) else: - cleaned_ref = clean_json({'Ref': re.findall('(?<=\${)[^"]*?(?=})', sub)[0]}, resources_map) + cleaned_ref = clean_json( + {"Ref": re.findall('(?<=\${)[^"]*?(?=})', sub)[0]}, + resources_map, + ) fn_sub_value = fn_sub_value.replace(sub, cleaned_ref) for literal in literals: - fn_sub_value = fn_sub_value.replace(literal, literal.replace('!', '')) + fn_sub_value = fn_sub_value.replace( + literal, literal.replace("!", "") + ) return fn_sub_value pass - if 'Fn::ImportValue' in resource_json: - cleaned_val = clean_json(resource_json['Fn::ImportValue'], resources_map) - values = [x.value for x in resources_map.cross_stack_resources.values() if x.name == cleaned_val] + if "Fn::ImportValue" in resource_json: + cleaned_val = clean_json(resource_json["Fn::ImportValue"], resources_map) + values = [ + x.value + for x in resources_map.cross_stack_resources.values() + if x.name == cleaned_val + ] if any(values): return values[0] else: raise ExportNotFound(cleaned_val) - if 'Fn::GetAZs' in resource_json: - region = resource_json.get('Fn::GetAZs') or DEFAULT_REGION + if "Fn::GetAZs" in resource_json: + region = resource_json.get("Fn::GetAZs") or DEFAULT_REGION result = [] # TODO: make this configurable, to reflect the real AWS AZs - for az in ('a', 'b', 'c', 'd'): - result.append('%s%s' % (region, az)) + for az in ("a", "b", "c", "d"): + result.append("%s%s" % (region, az)) return result cleaned_json = {} @@ -246,58 +268,69 @@ def resource_name_property_from_type(resource_type): def generate_resource_name(resource_type, stack_name, logical_id): - if resource_type in ["AWS::ElasticLoadBalancingV2::TargetGroup", - "AWS::ElasticLoadBalancingV2::LoadBalancer"]: + if resource_type in [ + "AWS::ElasticLoadBalancingV2::TargetGroup", + "AWS::ElasticLoadBalancingV2::LoadBalancer", + ]: # Target group names need to be less than 32 characters, so when cloudformation creates a name for you # it makes sure to stay under that limit - name_prefix = '{0}-{1}'.format(stack_name, logical_id) + name_prefix = "{0}-{1}".format(stack_name, logical_id) my_random_suffix = random_suffix() - truncated_name_prefix = name_prefix[0:32 - (len(my_random_suffix) + 1)] + truncated_name_prefix = name_prefix[0 : 32 - (len(my_random_suffix) + 1)] # if the truncated name ends in a dash, we'll end up with a double dash in the final name, which is # not allowed - if truncated_name_prefix.endswith('-'): + if truncated_name_prefix.endswith("-"): truncated_name_prefix = truncated_name_prefix[:-1] - return '{0}-{1}'.format(truncated_name_prefix, my_random_suffix) + return "{0}-{1}".format(truncated_name_prefix, my_random_suffix) else: - return '{0}-{1}-{2}'.format(stack_name, logical_id, random_suffix()) + return "{0}-{1}-{2}".format(stack_name, logical_id, random_suffix()) def parse_resource(logical_id, resource_json, resources_map): - resource_type = resource_json['Type'] + resource_type = resource_json["Type"] resource_class = resource_class_from_type(resource_type) if not resource_class: warnings.warn( - "Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format(resource_type)) + "Tried to parse {0} but it's not supported by moto's CloudFormation implementation".format( + resource_type + ) + ) return None resource_json = clean_json(resource_json, resources_map) resource_name_property = resource_name_property_from_type(resource_type) if resource_name_property: - if 'Properties' not in resource_json: - resource_json['Properties'] = dict() - if resource_name_property not in resource_json['Properties']: - resource_json['Properties'][resource_name_property] = generate_resource_name( - resource_type, resources_map.get('AWS::StackName'), logical_id) - resource_name = resource_json['Properties'][resource_name_property] + if "Properties" not in resource_json: + resource_json["Properties"] = dict() + if resource_name_property not in resource_json["Properties"]: + resource_json["Properties"][ + resource_name_property + ] = generate_resource_name( + resource_type, resources_map.get("AWS::StackName"), logical_id + ) + resource_name = resource_json["Properties"][resource_name_property] else: - resource_name = generate_resource_name(resource_type, resources_map.get('AWS::StackName'), logical_id) + resource_name = generate_resource_name( + resource_type, resources_map.get("AWS::StackName"), logical_id + ) return resource_class, resource_json, resource_name def parse_and_create_resource(logical_id, resource_json, resources_map, region_name): - condition = resource_json.get('Condition') + condition = resource_json.get("Condition") if condition and not resources_map.lazy_condition_map[condition]: # If this has a False condition, don't create the resource return None - resource_type = resource_json['Type'] + resource_type = resource_json["Type"] resource_tuple = parse_resource(logical_id, resource_json, resources_map) if not resource_tuple: return None resource_class, resource_json, resource_name = resource_tuple resource = resource_class.create_from_cloudformation_json( - resource_name, resource_json, region_name) + resource_name, resource_json, region_name + ) resource.type = resource_type resource.logical_resource_id = logical_id return resource @@ -305,24 +338,27 @@ def parse_and_create_resource(logical_id, resource_json, resources_map, region_n def parse_and_update_resource(logical_id, resource_json, resources_map, region_name): resource_class, new_resource_json, new_resource_name = parse_resource( - logical_id, resource_json, resources_map) + logical_id, resource_json, resources_map + ) original_resource = resources_map[logical_id] new_resource = resource_class.update_from_cloudformation_json( original_resource=original_resource, new_resource_name=new_resource_name, cloudformation_json=new_resource_json, - region_name=region_name + region_name=region_name, ) - new_resource.type = resource_json['Type'] + new_resource.type = resource_json["Type"] new_resource.logical_resource_id = logical_id return new_resource def parse_and_delete_resource(logical_id, resource_json, resources_map, region_name): resource_class, resource_json, resource_name = parse_resource( - logical_id, resource_json, resources_map) + logical_id, resource_json, resources_map + ) resource_class.delete_from_cloudformation_json( - resource_name, resource_json, region_name) + resource_name, resource_json, region_name + ) def parse_condition(condition, resources_map, condition_map): @@ -334,8 +370,8 @@ def parse_condition(condition, resources_map, condition_map): condition_values = [] for value in list(condition.values())[0]: # Check if we are referencing another Condition - if 'Condition' in value: - condition_values.append(condition_map[value['Condition']]) + if "Condition" in value: + condition_values.append(condition_map[value["Condition"]]) else: condition_values.append(clean_json(value, resources_map)) @@ -344,23 +380,27 @@ def parse_condition(condition, resources_map, condition_map): elif condition_operator == "Fn::Not": return not parse_condition(condition_values[0], resources_map, condition_map) elif condition_operator == "Fn::And": - return all([ - parse_condition(condition_value, resources_map, condition_map) - for condition_value - in condition_values]) + return all( + [ + parse_condition(condition_value, resources_map, condition_map) + for condition_value in condition_values + ] + ) elif condition_operator == "Fn::Or": - return any([ - parse_condition(condition_value, resources_map, condition_map) - for condition_value - in condition_values]) + return any( + [ + parse_condition(condition_value, resources_map, condition_map) + for condition_value in condition_values + ] + ) def parse_output(output_logical_id, output_json, resources_map): output_json = clean_json(output_json, resources_map) output = Output() output.key = output_logical_id - output.value = clean_json(output_json['Value'], resources_map) - output.description = output_json.get('Description') + output.value = clean_json(output_json["Value"], resources_map) + output.description = output_json.get("Description") return output @@ -371,9 +411,18 @@ class ResourceMap(collections.Mapping): each resources is passed this lazy map that it can grab dependencies from. """ - def __init__(self, stack_id, stack_name, parameters, tags, region_name, template, cross_stack_resources): + def __init__( + self, + stack_id, + stack_name, + parameters, + tags, + region_name, + template, + cross_stack_resources, + ): self._template = template - self._resource_json_map = template['Resources'] + self._resource_json_map = template["Resources"] self._region_name = region_name self.input_parameters = parameters self.tags = copy.deepcopy(tags) @@ -401,7 +450,8 @@ class ResourceMap(collections.Mapping): if not resource_json: raise KeyError(resource_logical_id) new_resource = parse_and_create_resource( - resource_logical_id, resource_json, self, self._region_name) + resource_logical_id, resource_json, self, self._region_name + ) if new_resource is not None: self._parsed_resources[resource_logical_id] = new_resource return new_resource @@ -417,13 +467,13 @@ class ResourceMap(collections.Mapping): return self._resource_json_map.keys() def load_mapping(self): - self._parsed_resources.update(self._template.get('Mappings', {})) + self._parsed_resources.update(self._template.get("Mappings", {})) def load_parameters(self): - parameter_slots = self._template.get('Parameters', {}) + parameter_slots = self._template.get("Parameters", {}) for parameter_name, parameter in parameter_slots.items(): # Set the default values. - self.resolved_parameters[parameter_name] = parameter.get('Default') + self.resolved_parameters[parameter_name] = parameter.get("Default") # Set any input parameters that were passed self.no_echo_parameter_keys = [] @@ -431,11 +481,11 @@ class ResourceMap(collections.Mapping): if key in self.resolved_parameters: parameter_slot = parameter_slots[key] - value_type = parameter_slot.get('Type', 'String') - if value_type == 'CommaDelimitedList' or value_type.startswith("List"): - value = value.split(',') + value_type = parameter_slot.get("Type", "String") + if value_type == "CommaDelimitedList" or value_type.startswith("List"): + value = value.split(",") - if parameter_slot.get('NoEcho'): + if parameter_slot.get("NoEcho"): self.no_echo_parameter_keys.append(key) self.resolved_parameters[key] = value @@ -449,11 +499,15 @@ class ResourceMap(collections.Mapping): self._parsed_resources.update(self.resolved_parameters) def load_conditions(self): - conditions = self._template.get('Conditions', {}) + conditions = self._template.get("Conditions", {}) self.lazy_condition_map = LazyDict() for condition_name, condition in conditions.items(): - self.lazy_condition_map[condition_name] = functools.partial(parse_condition, - condition, self._parsed_resources, self.lazy_condition_map) + self.lazy_condition_map[condition_name] = functools.partial( + parse_condition, + condition, + self._parsed_resources, + self.lazy_condition_map, + ) for condition_name in self.lazy_condition_map: self.lazy_condition_map[condition_name] @@ -465,13 +519,18 @@ class ResourceMap(collections.Mapping): # Since this is a lazy map, to create every object we just need to # iterate through self. - self.tags.update({'aws:cloudformation:stack-name': self.get('AWS::StackName'), - 'aws:cloudformation:stack-id': self.get('AWS::StackId')}) + self.tags.update( + { + "aws:cloudformation:stack-name": self.get("AWS::StackName"), + "aws:cloudformation:stack-id": self.get("AWS::StackId"), + } + ) for resource in self.resources: if isinstance(self[resource], ec2_models.TaggedEC2Resource): - self.tags['aws:cloudformation:logical-id'] = resource + self.tags["aws:cloudformation:logical-id"] = resource ec2_models.ec2_backends[self._region_name].create_tags( - [self[resource].physical_resource_id], self.tags) + [self[resource].physical_resource_id], self.tags + ) def diff(self, template, parameters=None): if parameters: @@ -481,36 +540,35 @@ class ResourceMap(collections.Mapping): self.load_conditions() old_template = self._resource_json_map - new_template = template['Resources'] + new_template = template["Resources"] resource_names_by_action = { - 'Add': set(new_template) - set(old_template), - 'Modify': set(name for name in new_template if name in old_template and new_template[ - name] != old_template[name]), - 'Remove': set(old_template) - set(new_template) - } - resources_by_action = { - 'Add': {}, - 'Modify': {}, - 'Remove': {}, + "Add": set(new_template) - set(old_template), + "Modify": set( + name + for name in new_template + if name in old_template and new_template[name] != old_template[name] + ), + "Remove": set(old_template) - set(new_template), } + resources_by_action = {"Add": {}, "Modify": {}, "Remove": {}} - for resource_name in resource_names_by_action['Add']: - resources_by_action['Add'][resource_name] = { - 'LogicalResourceId': resource_name, - 'ResourceType': new_template[resource_name]['Type'] + for resource_name in resource_names_by_action["Add"]: + resources_by_action["Add"][resource_name] = { + "LogicalResourceId": resource_name, + "ResourceType": new_template[resource_name]["Type"], } - for resource_name in resource_names_by_action['Modify']: - resources_by_action['Modify'][resource_name] = { - 'LogicalResourceId': resource_name, - 'ResourceType': new_template[resource_name]['Type'] + for resource_name in resource_names_by_action["Modify"]: + resources_by_action["Modify"][resource_name] = { + "LogicalResourceId": resource_name, + "ResourceType": new_template[resource_name]["Type"], } - for resource_name in resource_names_by_action['Remove']: - resources_by_action['Remove'][resource_name] = { - 'LogicalResourceId': resource_name, - 'ResourceType': old_template[resource_name]['Type'] + for resource_name in resource_names_by_action["Remove"]: + resources_by_action["Remove"][resource_name] = { + "LogicalResourceId": resource_name, + "ResourceType": old_template[resource_name]["Type"], } return resources_by_action @@ -519,35 +577,38 @@ class ResourceMap(collections.Mapping): resources_by_action = self.diff(template, parameters) old_template = self._resource_json_map - new_template = template['Resources'] + new_template = template["Resources"] self._resource_json_map = new_template - for resource_name, resource in resources_by_action['Add'].items(): + for resource_name, resource in resources_by_action["Add"].items(): resource_json = new_template[resource_name] new_resource = parse_and_create_resource( - resource_name, resource_json, self, self._region_name) + resource_name, resource_json, self, self._region_name + ) self._parsed_resources[resource_name] = new_resource - for resource_name, resource in resources_by_action['Remove'].items(): + for resource_name, resource in resources_by_action["Remove"].items(): resource_json = old_template[resource_name] parse_and_delete_resource( - resource_name, resource_json, self, self._region_name) + resource_name, resource_json, self, self._region_name + ) self._parsed_resources.pop(resource_name) tries = 1 - while resources_by_action['Modify'] and tries < 5: - for resource_name, resource in resources_by_action['Modify'].copy().items(): + while resources_by_action["Modify"] and tries < 5: + for resource_name, resource in resources_by_action["Modify"].copy().items(): resource_json = new_template[resource_name] try: changed_resource = parse_and_update_resource( - resource_name, resource_json, self, self._region_name) + resource_name, resource_json, self, self._region_name + ) except Exception as e: # skip over dependency violations, and try again in a # second pass last_exception = e else: self._parsed_resources[resource_name] = changed_resource - del resources_by_action['Modify'][resource_name] + del resources_by_action["Modify"][resource_name] tries += 1 if tries == 5: raise last_exception @@ -559,7 +620,7 @@ class ResourceMap(collections.Mapping): for resource in remaining_resources.copy(): parsed_resource = self._parsed_resources.get(resource) try: - if parsed_resource and hasattr(parsed_resource, 'delete'): + if parsed_resource and hasattr(parsed_resource, "delete"): parsed_resource.delete(self._region_name) except Exception as e: # skip over dependency violations, and try again in a @@ -573,11 +634,10 @@ class ResourceMap(collections.Mapping): class OutputMap(collections.Mapping): - def __init__(self, resources, template, stack_id): self._template = template self._stack_id = stack_id - self._output_json_map = template.get('Outputs') + self._output_json_map = template.get("Outputs") # Create the default resources self._resource_map = resources @@ -591,7 +651,8 @@ class OutputMap(collections.Mapping): else: output_json = self._output_json_map.get(output_logical_id) new_output = parse_output( - output_logical_id, output_json, self._resource_map) + output_logical_id, output_json, self._resource_map + ) self._parsed_outputs[output_logical_id] = new_output return new_output @@ -610,9 +671,11 @@ class OutputMap(collections.Mapping): exports = [] if self.outputs: for key, value in self._output_json_map.items(): - if value.get('Export'): - cleaned_name = clean_json(value['Export'].get('Name'), self._resource_map) - cleaned_value = clean_json(value.get('Value'), self._resource_map) + if value.get("Export"): + cleaned_name = clean_json( + value["Export"].get("Name"), self._resource_map + ) + cleaned_value = clean_json(value.get("Value"), self._resource_map) exports.append(Export(self._stack_id, cleaned_name, cleaned_value)) return exports @@ -622,7 +685,6 @@ class OutputMap(collections.Mapping): class Export(object): - def __init__(self, exporting_stack_id, name, value): self._exporting_stack_id = exporting_stack_id self._name = name diff --git a/moto/cloudformation/responses.py b/moto/cloudformation/responses.py index 80970262f..f5e094c15 100644 --- a/moto/cloudformation/responses.py +++ b/moto/cloudformation/responses.py @@ -12,7 +12,6 @@ from .exceptions import ValidationError class CloudFormationResponse(BaseResponse): - @property def cloudformation_backend(self): return cloudformation_backends[self.region] @@ -20,17 +19,18 @@ class CloudFormationResponse(BaseResponse): def _get_stack_from_s3_url(self, template_url): template_url_parts = urlparse(template_url) if "localhost" in template_url: - bucket_name, key_name = template_url_parts.path.lstrip( - "/").split("/", 1) + bucket_name, key_name = template_url_parts.path.lstrip("/").split("/", 1) else: - if template_url_parts.netloc.endswith('amazonaws.com') \ - and template_url_parts.netloc.startswith('s3'): + if template_url_parts.netloc.endswith( + "amazonaws.com" + ) and template_url_parts.netloc.startswith("s3"): # Handle when S3 url uses amazon url with bucket in path # Also handles getting region as technically s3 is region'd # region = template_url.netloc.split('.')[1] - bucket_name, key_name = template_url_parts.path.lstrip( - "/").split("/", 1) + bucket_name, key_name = template_url_parts.path.lstrip("/").split( + "/", 1 + ) else: bucket_name = template_url_parts.netloc.split(".")[0] key_name = template_url_parts.path.lstrip("/") @@ -39,24 +39,26 @@ class CloudFormationResponse(BaseResponse): return key.value.decode("utf-8") def create_stack(self): - stack_name = self._get_param('StackName') - stack_body = self._get_param('TemplateBody') - template_url = self._get_param('TemplateURL') - role_arn = self._get_param('RoleARN') + stack_name = self._get_param("StackName") + stack_body = self._get_param("TemplateBody") + template_url = self._get_param("TemplateURL") + role_arn = self._get_param("RoleARN") parameters_list = self._get_list_prefix("Parameters.member") - tags = dict((item['key'], item['value']) - for item in self._get_list_prefix("Tags.member")) + tags = dict( + (item["key"], item["value"]) + for item in self._get_list_prefix("Tags.member") + ) # Hack dict-comprehension - parameters = dict([ - (parameter['parameter_key'], parameter['parameter_value']) - for parameter - in parameters_list - ]) + parameters = dict( + [ + (parameter["parameter_key"], parameter["parameter_value"]) + for parameter in parameters_list + ] + ) if template_url: stack_body = self._get_stack_from_s3_url(template_url) - stack_notification_arns = self._get_multi_param( - 'NotificationARNs.member') + stack_notification_arns = self._get_multi_param("NotificationARNs.member") stack = self.cloudformation_backend.create_stack( name=stack_name, @@ -68,34 +70,37 @@ class CloudFormationResponse(BaseResponse): role_arn=role_arn, ) if self.request_json: - return json.dumps({ - 'CreateStackResponse': { - 'CreateStackResult': { - 'StackId': stack.stack_id, + return json.dumps( + { + "CreateStackResponse": { + "CreateStackResult": {"StackId": stack.stack_id} } } - }) + ) else: template = self.response_template(CREATE_STACK_RESPONSE_TEMPLATE) return template.render(stack=stack) @amzn_request_id def create_change_set(self): - stack_name = self._get_param('StackName') - change_set_name = self._get_param('ChangeSetName') - stack_body = self._get_param('TemplateBody') - template_url = self._get_param('TemplateURL') - role_arn = self._get_param('RoleARN') - update_or_create = self._get_param('ChangeSetType', 'CREATE') + stack_name = self._get_param("StackName") + change_set_name = self._get_param("ChangeSetName") + stack_body = self._get_param("TemplateBody") + template_url = self._get_param("TemplateURL") + role_arn = self._get_param("RoleARN") + update_or_create = self._get_param("ChangeSetType", "CREATE") parameters_list = self._get_list_prefix("Parameters.member") - tags = dict((item['key'], item['value']) - for item in self._get_list_prefix("Tags.member")) - parameters = {param['parameter_key']: param['parameter_value'] - for param in parameters_list} + tags = dict( + (item["key"], item["value"]) + for item in self._get_list_prefix("Tags.member") + ) + parameters = { + param["parameter_key"]: param["parameter_value"] + for param in parameters_list + } if template_url: stack_body = self._get_stack_from_s3_url(template_url) - stack_notification_arns = self._get_multi_param( - 'NotificationARNs.member') + stack_notification_arns = self._get_multi_param("NotificationARNs.member") change_set_id, stack_id = self.cloudformation_backend.create_change_set( stack_name=stack_name, change_set_name=change_set_name, @@ -108,66 +113,64 @@ class CloudFormationResponse(BaseResponse): change_set_type=update_or_create, ) if self.request_json: - return json.dumps({ - 'CreateChangeSetResponse': { - 'CreateChangeSetResult': { - 'Id': change_set_id, - 'StackId': stack_id, + return json.dumps( + { + "CreateChangeSetResponse": { + "CreateChangeSetResult": { + "Id": change_set_id, + "StackId": stack_id, + } } } - }) + ) else: template = self.response_template(CREATE_CHANGE_SET_RESPONSE_TEMPLATE) return template.render(stack_id=stack_id, change_set_id=change_set_id) def delete_change_set(self): - stack_name = self._get_param('StackName') - change_set_name = self._get_param('ChangeSetName') + stack_name = self._get_param("StackName") + change_set_name = self._get_param("ChangeSetName") - self.cloudformation_backend.delete_change_set(change_set_name=change_set_name, stack_name=stack_name) + self.cloudformation_backend.delete_change_set( + change_set_name=change_set_name, stack_name=stack_name + ) if self.request_json: - return json.dumps({ - 'DeleteChangeSetResponse': { - 'DeleteChangeSetResult': {}, - } - }) + return json.dumps( + {"DeleteChangeSetResponse": {"DeleteChangeSetResult": {}}} + ) else: template = self.response_template(DELETE_CHANGE_SET_RESPONSE_TEMPLATE) return template.render() def describe_change_set(self): - stack_name = self._get_param('StackName') - change_set_name = self._get_param('ChangeSetName') + stack_name = self._get_param("StackName") + change_set_name = self._get_param("ChangeSetName") change_set = self.cloudformation_backend.describe_change_set( - change_set_name=change_set_name, - stack_name=stack_name, + change_set_name=change_set_name, stack_name=stack_name ) template = self.response_template(DESCRIBE_CHANGE_SET_RESPONSE_TEMPLATE) return template.render(change_set=change_set) @amzn_request_id def execute_change_set(self): - stack_name = self._get_param('StackName') - change_set_name = self._get_param('ChangeSetName') + stack_name = self._get_param("StackName") + change_set_name = self._get_param("ChangeSetName") self.cloudformation_backend.execute_change_set( - stack_name=stack_name, - change_set_name=change_set_name, + stack_name=stack_name, change_set_name=change_set_name ) if self.request_json: - return json.dumps({ - 'ExecuteChangeSetResponse': { - 'ExecuteChangeSetResult': {}, - } - }) + return json.dumps( + {"ExecuteChangeSetResponse": {"ExecuteChangeSetResult": {}}} + ) else: template = self.response_template(EXECUTE_CHANGE_SET_RESPONSE_TEMPLATE) return template.render() def describe_stacks(self): stack_name_or_id = None - if self._get_param('StackName'): - stack_name_or_id = self.querystring.get('StackName')[0] - token = self._get_param('NextToken') + if self._get_param("StackName"): + stack_name_or_id = self.querystring.get("StackName")[0] + token = self._get_param("NextToken") stacks = self.cloudformation_backend.describe_stacks(stack_name_or_id) stack_ids = [stack.stack_id for stack in stacks] if token: @@ -175,7 +178,7 @@ class CloudFormationResponse(BaseResponse): else: start = 0 max_results = 50 # using this to mske testing of paginated stacks more convenient than default 1 MB - stacks_resp = stacks[start:start + max_results] + stacks_resp = stacks[start : start + max_results] next_token = None if len(stacks) > (start + max_results): next_token = stacks_resp[-1].stack_id @@ -183,9 +186,9 @@ class CloudFormationResponse(BaseResponse): return template.render(stacks=stacks_resp, next_token=next_token) def describe_stack_resource(self): - stack_name = self._get_param('StackName') + stack_name = self._get_param("StackName") stack = self.cloudformation_backend.get_stack(stack_name) - logical_resource_id = self._get_param('LogicalResourceId') + logical_resource_id = self._get_param("LogicalResourceId") for stack_resource in stack.stack_resources: if stack_resource.logical_resource_id == logical_resource_id: @@ -194,19 +197,18 @@ class CloudFormationResponse(BaseResponse): else: raise ValidationError(logical_resource_id) - template = self.response_template( - DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE) + template = self.response_template(DESCRIBE_STACK_RESOURCE_RESPONSE_TEMPLATE) return template.render(stack=stack, resource=resource) def describe_stack_resources(self): - stack_name = self._get_param('StackName') + stack_name = self._get_param("StackName") stack = self.cloudformation_backend.get_stack(stack_name) template = self.response_template(DESCRIBE_STACK_RESOURCES_RESPONSE) return template.render(stack=stack) def describe_stack_events(self): - stack_name = self._get_param('StackName') + stack_name = self._get_param("StackName") stack = self.cloudformation_backend.get_stack(stack_name) template = self.response_template(DESCRIBE_STACK_EVENTS_RESPONSE) @@ -223,68 +225,82 @@ class CloudFormationResponse(BaseResponse): return template.render(stacks=stacks) def list_stack_resources(self): - stack_name_or_id = self._get_param('StackName') - resources = self.cloudformation_backend.list_stack_resources( - stack_name_or_id) + stack_name_or_id = self._get_param("StackName") + resources = self.cloudformation_backend.list_stack_resources(stack_name_or_id) template = self.response_template(LIST_STACKS_RESOURCES_RESPONSE) return template.render(resources=resources) def get_template(self): - name_or_stack_id = self.querystring.get('StackName')[0] + name_or_stack_id = self.querystring.get("StackName")[0] stack = self.cloudformation_backend.get_stack(name_or_stack_id) if self.request_json: - return json.dumps({ - "GetTemplateResponse": { - "GetTemplateResult": { - "TemplateBody": stack.template, - "ResponseMetadata": { - "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + return json.dumps( + { + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": stack.template, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } } - }) + ) else: template = self.response_template(GET_TEMPLATE_RESPONSE_TEMPLATE) return template.render(stack=stack) def update_stack(self): - stack_name = self._get_param('StackName') - role_arn = self._get_param('RoleARN') - template_url = self._get_param('TemplateURL') - stack_body = self._get_param('TemplateBody') + stack_name = self._get_param("StackName") + role_arn = self._get_param("RoleARN") + template_url = self._get_param("TemplateURL") + stack_body = self._get_param("TemplateBody") stack = self.cloudformation_backend.get_stack(stack_name) - if self._get_param('UsePreviousTemplate') == "true": + if self._get_param("UsePreviousTemplate") == "true": stack_body = stack.template elif not stack_body and template_url: stack_body = self._get_stack_from_s3_url(template_url) incoming_params = self._get_list_prefix("Parameters.member") - parameters = dict([ - (parameter['parameter_key'], parameter['parameter_value']) - for parameter - in incoming_params if 'parameter_value' in parameter - ]) - previous = dict([ - (parameter['parameter_key'], stack.parameters[parameter['parameter_key']]) - for parameter - in incoming_params if 'use_previous_value' in parameter - ]) + parameters = dict( + [ + (parameter["parameter_key"], parameter["parameter_value"]) + for parameter in incoming_params + if "parameter_value" in parameter + ] + ) + previous = dict( + [ + ( + parameter["parameter_key"], + stack.parameters[parameter["parameter_key"]], + ) + for parameter in incoming_params + if "use_previous_value" in parameter + ] + ) parameters.update(previous) # boto3 is supposed to let you clear the tags by passing an empty value, but the request body doesn't # end up containing anything we can use to differentiate between passing an empty value versus not # passing anything. so until that changes, moto won't be able to clear tags, only update them. - tags = dict((item['key'], item['value']) - for item in self._get_list_prefix("Tags.member")) + tags = dict( + (item["key"], item["value"]) + for item in self._get_list_prefix("Tags.member") + ) # so that if we don't pass the parameter, we don't clear all the tags accidentally if not tags: tags = None stack = self.cloudformation_backend.get_stack(stack_name) - if stack.status == 'ROLLBACK_COMPLETE': + if stack.status == "ROLLBACK_COMPLETE": raise ValidationError( - stack.stack_id, message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format(stack.stack_id)) + stack.stack_id, + message="Stack:{0} is in ROLLBACK_COMPLETE state and can not be updated.".format( + stack.stack_id + ), + ) stack = self.cloudformation_backend.update_stack( name=stack_name, @@ -295,11 +311,7 @@ class CloudFormationResponse(BaseResponse): ) if self.request_json: stack_body = { - 'UpdateStackResponse': { - 'UpdateStackResult': { - 'StackId': stack.name, - } - } + "UpdateStackResponse": {"UpdateStackResult": {"StackId": stack.name}} } return json.dumps(stack_body) else: @@ -307,56 +319,57 @@ class CloudFormationResponse(BaseResponse): return template.render(stack=stack) def delete_stack(self): - name_or_stack_id = self.querystring.get('StackName')[0] + name_or_stack_id = self.querystring.get("StackName")[0] self.cloudformation_backend.delete_stack(name_or_stack_id) if self.request_json: - return json.dumps({ - 'DeleteStackResponse': { - 'DeleteStackResult': {}, - } - }) + return json.dumps({"DeleteStackResponse": {"DeleteStackResult": {}}}) else: template = self.response_template(DELETE_STACK_RESPONSE_TEMPLATE) return template.render() def list_exports(self): - token = self._get_param('NextToken') + token = self._get_param("NextToken") exports, next_token = self.cloudformation_backend.list_exports(token=token) template = self.response_template(LIST_EXPORTS_RESPONSE) return template.render(exports=exports, next_token=next_token) def validate_template(self): - cfn_lint = self.cloudformation_backend.validate_template(self._get_param('TemplateBody')) + cfn_lint = self.cloudformation_backend.validate_template( + self._get_param("TemplateBody") + ) if cfn_lint: raise ValidationError(cfn_lint[0].message) description = "" try: - description = json.loads(self._get_param('TemplateBody'))['Description'] + description = json.loads(self._get_param("TemplateBody"))["Description"] except (ValueError, KeyError): pass try: - description = yaml.load(self._get_param('TemplateBody'))['Description'] + description = yaml.load(self._get_param("TemplateBody"))["Description"] except (yaml.ParserError, KeyError): pass template = self.response_template(VALIDATE_STACK_RESPONSE_TEMPLATE) return template.render(description=description) def create_stack_set(self): - stackset_name = self._get_param('StackSetName') - stack_body = self._get_param('TemplateBody') - template_url = self._get_param('TemplateURL') + stackset_name = self._get_param("StackSetName") + stack_body = self._get_param("TemplateBody") + template_url = self._get_param("TemplateURL") # role_arn = self._get_param('RoleARN') parameters_list = self._get_list_prefix("Parameters.member") - tags = dict((item['key'], item['value']) - for item in self._get_list_prefix("Tags.member")) + tags = dict( + (item["key"], item["value"]) + for item in self._get_list_prefix("Tags.member") + ) # Copy-Pasta - Hack dict-comprehension - parameters = dict([ - (parameter['parameter_key'], parameter['parameter_value']) - for parameter - in parameters_list - ]) + parameters = dict( + [ + (parameter["parameter_key"], parameter["parameter_value"]) + for parameter in parameters_list + ] + ) if template_url: stack_body = self._get_stack_from_s3_url(template_url) @@ -368,59 +381,65 @@ class CloudFormationResponse(BaseResponse): # role_arn=role_arn, ) if self.request_json: - return json.dumps({ - 'CreateStackSetResponse': { - 'CreateStackSetResult': { - 'StackSetId': stackset.stackset_id, + return json.dumps( + { + "CreateStackSetResponse": { + "CreateStackSetResult": {"StackSetId": stackset.stackset_id} } } - }) + ) else: template = self.response_template(CREATE_STACK_SET_RESPONSE_TEMPLATE) return template.render(stackset=stackset) def create_stack_instances(self): - stackset_name = self._get_param('StackSetName') - accounts = self._get_multi_param('Accounts.member') - regions = self._get_multi_param('Regions.member') - parameters = self._get_multi_param('ParameterOverrides.member') - self.cloudformation_backend.create_stack_instances(stackset_name, accounts, regions, parameters) + stackset_name = self._get_param("StackSetName") + accounts = self._get_multi_param("Accounts.member") + regions = self._get_multi_param("Regions.member") + parameters = self._get_multi_param("ParameterOverrides.member") + self.cloudformation_backend.create_stack_instances( + stackset_name, accounts, regions, parameters + ) template = self.response_template(CREATE_STACK_INSTANCES_TEMPLATE) return template.render() def delete_stack_set(self): - stackset_name = self._get_param('StackSetName') + stackset_name = self._get_param("StackSetName") self.cloudformation_backend.delete_stack_set(stackset_name) template = self.response_template(DELETE_STACK_SET_RESPONSE_TEMPLATE) return template.render() def delete_stack_instances(self): - stackset_name = self._get_param('StackSetName') - accounts = self._get_multi_param('Accounts.member') - regions = self._get_multi_param('Regions.member') - operation = self.cloudformation_backend.delete_stack_instances(stackset_name, accounts, regions) + stackset_name = self._get_param("StackSetName") + accounts = self._get_multi_param("Accounts.member") + regions = self._get_multi_param("Regions.member") + operation = self.cloudformation_backend.delete_stack_instances( + stackset_name, accounts, regions + ) template = self.response_template(DELETE_STACK_INSTANCES_TEMPLATE) return template.render(operation=operation) def describe_stack_set(self): - stackset_name = self._get_param('StackSetName') + stackset_name = self._get_param("StackSetName") stackset = self.cloudformation_backend.get_stack_set(stackset_name) if not stackset.admin_role: - stackset.admin_role = 'arn:aws:iam::123456789012:role/AWSCloudFormationStackSetAdministrationRole' + stackset.admin_role = "arn:aws:iam::123456789012:role/AWSCloudFormationStackSetAdministrationRole" if not stackset.execution_role: - stackset.execution_role = 'AWSCloudFormationStackSetExecutionRole' + stackset.execution_role = "AWSCloudFormationStackSetExecutionRole" template = self.response_template(DESCRIBE_STACK_SET_RESPONSE_TEMPLATE) return template.render(stackset=stackset) def describe_stack_instance(self): - stackset_name = self._get_param('StackSetName') - account = self._get_param('StackInstanceAccount') - region = self._get_param('StackInstanceRegion') + stackset_name = self._get_param("StackSetName") + account = self._get_param("StackInstanceAccount") + region = self._get_param("StackInstanceRegion") - instance = self.cloudformation_backend.get_stack_set(stackset_name).instances.get_instance(account, region) + instance = self.cloudformation_backend.get_stack_set( + stackset_name + ).instances.get_instance(account, region) template = self.response_template(DESCRIBE_STACK_INSTANCE_TEMPLATE) rendered = template.render(instance=instance) return rendered @@ -431,61 +450,66 @@ class CloudFormationResponse(BaseResponse): return template.render(stacksets=stacksets) def list_stack_instances(self): - stackset_name = self._get_param('StackSetName') + stackset_name = self._get_param("StackSetName") stackset = self.cloudformation_backend.get_stack_set(stackset_name) template = self.response_template(LIST_STACK_INSTANCES_TEMPLATE) return template.render(stackset=stackset) def list_stack_set_operations(self): - stackset_name = self._get_param('StackSetName') + stackset_name = self._get_param("StackSetName") stackset = self.cloudformation_backend.get_stack_set(stackset_name) template = self.response_template(LIST_STACK_SET_OPERATIONS_RESPONSE_TEMPLATE) return template.render(stackset=stackset) def stop_stack_set_operation(self): - stackset_name = self._get_param('StackSetName') - operation_id = self._get_param('OperationId') + stackset_name = self._get_param("StackSetName") + operation_id = self._get_param("OperationId") stackset = self.cloudformation_backend.get_stack_set(stackset_name) - stackset.update_operation(operation_id, 'STOPPED') + stackset.update_operation(operation_id, "STOPPED") template = self.response_template(STOP_STACK_SET_OPERATION_RESPONSE_TEMPLATE) return template.render() def describe_stack_set_operation(self): - stackset_name = self._get_param('StackSetName') - operation_id = self._get_param('OperationId') + stackset_name = self._get_param("StackSetName") + operation_id = self._get_param("OperationId") stackset = self.cloudformation_backend.get_stack_set(stackset_name) operation = stackset.get_operation(operation_id) template = self.response_template(DESCRIBE_STACKSET_OPERATION_RESPONSE_TEMPLATE) return template.render(stackset=stackset, operation=operation) def list_stack_set_operation_results(self): - stackset_name = self._get_param('StackSetName') - operation_id = self._get_param('OperationId') + stackset_name = self._get_param("StackSetName") + operation_id = self._get_param("OperationId") stackset = self.cloudformation_backend.get_stack_set(stackset_name) operation = stackset.get_operation(operation_id) - template = self.response_template(LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE) + template = self.response_template( + LIST_STACK_SET_OPERATION_RESULTS_RESPONSE_TEMPLATE + ) return template.render(operation=operation) def update_stack_set(self): - stackset_name = self._get_param('StackSetName') - operation_id = self._get_param('OperationId') - description = self._get_param('Description') - execution_role = self._get_param('ExecutionRoleName') - admin_role = self._get_param('AdministrationRoleARN') - accounts = self._get_multi_param('Accounts.member') - regions = self._get_multi_param('Regions.member') - template_body = self._get_param('TemplateBody') - template_url = self._get_param('TemplateURL') + stackset_name = self._get_param("StackSetName") + operation_id = self._get_param("OperationId") + description = self._get_param("Description") + execution_role = self._get_param("ExecutionRoleName") + admin_role = self._get_param("AdministrationRoleARN") + accounts = self._get_multi_param("Accounts.member") + regions = self._get_multi_param("Regions.member") + template_body = self._get_param("TemplateBody") + template_url = self._get_param("TemplateURL") if template_url: template_body = self._get_stack_from_s3_url(template_url) - tags = dict((item['key'], item['value']) - for item in self._get_list_prefix("Tags.member")) + tags = dict( + (item["key"], item["value"]) + for item in self._get_list_prefix("Tags.member") + ) parameters_list = self._get_list_prefix("Parameters.member") - parameters = dict([ - (parameter['parameter_key'], parameter['parameter_value']) - for parameter - in parameters_list - ]) + parameters = dict( + [ + (parameter["parameter_key"], parameter["parameter_value"]) + for parameter in parameters_list + ] + ) operation = self.cloudformation_backend.update_stack_set( stackset_name=stackset_name, template=template_body, @@ -496,18 +520,20 @@ class CloudFormationResponse(BaseResponse): execution_role=execution_role, accounts=accounts, regions=regions, - operation_id=operation_id + operation_id=operation_id, ) template = self.response_template(UPDATE_STACK_SET_RESPONSE_TEMPLATE) return template.render(operation=operation) def update_stack_instances(self): - stackset_name = self._get_param('StackSetName') - accounts = self._get_multi_param('Accounts.member') - regions = self._get_multi_param('Regions.member') - parameters = self._get_multi_param('ParameterOverrides.member') - operation = self.cloudformation_backend.get_stack_set(stackset_name).update_instances(accounts, regions, parameters) + stackset_name = self._get_param("StackSetName") + accounts = self._get_multi_param("Accounts.member") + regions = self._get_multi_param("Regions.member") + parameters = self._get_multi_param("ParameterOverrides.member") + operation = self.cloudformation_backend.get_stack_set( + stackset_name + ).update_instances(accounts, regions, parameters) template = self.response_template(UPDATE_STACK_INSTANCES_RESPONSE_TEMPLATE) return template.render(operation=operation) diff --git a/moto/cloudformation/urls.py b/moto/cloudformation/urls.py index 468c68d98..84251e82b 100644 --- a/moto/cloudformation/urls.py +++ b/moto/cloudformation/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import CloudFormationResponse -url_bases = [ - "https?://cloudformation.(.+).amazonaws.com", -] +url_bases = ["https?://cloudformation.(.+).amazonaws.com"] -url_paths = { - '{0}/$': CloudFormationResponse.dispatch, -} +url_paths = {"{0}/$": CloudFormationResponse.dispatch} diff --git a/moto/cloudformation/utils.py b/moto/cloudformation/utils.py index e4290ce1a..42dfa0b63 100644 --- a/moto/cloudformation/utils.py +++ b/moto/cloudformation/utils.py @@ -11,44 +11,51 @@ from cfnlint import decode, core def generate_stack_id(stack_name, region="us-east-1", account="123456789"): random_id = uuid.uuid4() - return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format(region, account, stack_name, random_id) + return "arn:aws:cloudformation:{}:{}:stack/{}/{}".format( + region, account, stack_name, random_id + ) def generate_changeset_id(changeset_name, region_name): random_id = uuid.uuid4() - return 'arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}'.format(region_name, changeset_name, random_id) + return "arn:aws:cloudformation:{0}:123456789:changeSet/{1}/{2}".format( + region_name, changeset_name, random_id + ) def generate_stackset_id(stackset_name): random_id = uuid.uuid4() - return '{}:{}'.format(stackset_name, random_id) + return "{}:{}".format(stackset_name, random_id) def generate_stackset_arn(stackset_id, region_name): - return 'arn:aws:cloudformation:{}:123456789012:stackset/{}'.format(region_name, stackset_id) + return "arn:aws:cloudformation:{}:123456789012:stackset/{}".format( + region_name, stackset_id + ) def random_suffix(): size = 12 chars = list(range(10)) + list(string.ascii_uppercase) - return ''.join(six.text_type(random.choice(chars)) for x in range(size)) + return "".join(six.text_type(random.choice(chars)) for x in range(size)) def yaml_tag_constructor(loader, tag, node): """convert shorthand intrinsic function to full name """ + def _f(loader, tag, node): - if tag == '!GetAtt': - return node.value.split('.') + if tag == "!GetAtt": + return node.value.split(".") elif type(node) == yaml.SequenceNode: return loader.construct_sequence(node) else: return node.value - if tag == '!Ref': - key = 'Ref' + if tag == "!Ref": + key = "Ref" else: - key = 'Fn::{}'.format(tag[1:]) + key = "Fn::{}".format(tag[1:]) return {key: _f(loader, tag, node)} @@ -71,13 +78,9 @@ def validate_template_cfn_lint(template): rules = core.get_rules([], [], []) # Use us-east-1 region (spec file) for validation - regions = ['us-east-1'] + regions = ["us-east-1"] # Process all the rules and gather the errors - matches = core.run_checks( - abs_filename, - template, - rules, - regions) + matches = core.run_checks(abs_filename, template, rules, regions) return matches diff --git a/moto/cloudwatch/__init__.py b/moto/cloudwatch/__init__.py index 861fb703a..86a774933 100644 --- a/moto/cloudwatch/__init__.py +++ b/moto/cloudwatch/__init__.py @@ -1,6 +1,6 @@ from .models import cloudwatch_backends from ..core.models import base_decorator, deprecated_base_decorator -cloudwatch_backend = cloudwatch_backends['us-east-1'] +cloudwatch_backend = cloudwatch_backends["us-east-1"] mock_cloudwatch = base_decorator(cloudwatch_backends) mock_cloudwatch_deprecated = deprecated_base_decorator(cloudwatch_backends) diff --git a/moto/cloudwatch/models.py b/moto/cloudwatch/models.py index ed644f874..2f5a14890 100644 --- a/moto/cloudwatch/models.py +++ b/moto/cloudwatch/models.py @@ -1,4 +1,3 @@ - import json from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.core import BaseBackend, BaseModel @@ -14,7 +13,6 @@ _EMPTY_LIST = tuple() class Dimension(object): - def __init__(self, name, value): self.name = name self.value = value @@ -49,10 +47,23 @@ def daterange(start, stop, step=timedelta(days=1), inclusive=False): class FakeAlarm(BaseModel): - - def __init__(self, name, namespace, metric_name, comparison_operator, evaluation_periods, - period, threshold, statistic, description, dimensions, alarm_actions, - ok_actions, insufficient_data_actions, unit): + def __init__( + self, + name, + namespace, + metric_name, + comparison_operator, + evaluation_periods, + period, + threshold, + statistic, + description, + dimensions, + alarm_actions, + ok_actions, + insufficient_data_actions, + unit, + ): self.name = name self.namespace = namespace self.metric_name = metric_name @@ -62,8 +73,9 @@ class FakeAlarm(BaseModel): self.threshold = threshold self.statistic = statistic self.description = description - self.dimensions = [Dimension(dimension['name'], dimension[ - 'value']) for dimension in dimensions] + self.dimensions = [ + Dimension(dimension["name"], dimension["value"]) for dimension in dimensions + ] self.alarm_actions = alarm_actions self.ok_actions = ok_actions self.insufficient_data_actions = insufficient_data_actions @@ -72,15 +84,21 @@ class FakeAlarm(BaseModel): self.history = [] - self.state_reason = '' - self.state_reason_data = '{}' - self.state_value = 'OK' + self.state_reason = "" + self.state_reason_data = "{}" + self.state_value = "OK" self.state_updated_timestamp = datetime.utcnow() def update_state(self, reason, reason_data, state_value): # History type, that then decides what the rest of the items are, can be one of ConfigurationUpdate | StateUpdate | Action self.history.append( - ('StateUpdate', self.state_reason, self.state_reason_data, self.state_value, self.state_updated_timestamp) + ( + "StateUpdate", + self.state_reason, + self.state_reason_data, + self.state_value, + self.state_updated_timestamp, + ) ) self.state_reason = reason @@ -90,14 +108,14 @@ class FakeAlarm(BaseModel): class MetricDatum(BaseModel): - def __init__(self, namespace, name, value, dimensions, timestamp): self.namespace = namespace self.name = name self.value = value self.timestamp = timestamp or datetime.utcnow().replace(tzinfo=tzutc()) - self.dimensions = [Dimension(dimension['Name'], dimension[ - 'Value']) for dimension in dimensions] + self.dimensions = [ + Dimension(dimension["Name"], dimension["Value"]) for dimension in dimensions + ] class Dashboard(BaseModel): @@ -120,7 +138,7 @@ class Dashboard(BaseModel): return len(self.body) def __repr__(self): - return ''.format(self.name) + return "".format(self.name) class Statistics: @@ -131,7 +149,7 @@ class Statistics: @property def sample_count(self): - if 'SampleCount' not in self.stats: + if "SampleCount" not in self.stats: return None return len(self.values) @@ -142,28 +160,28 @@ class Statistics: @property def sum(self): - if 'Sum' not in self.stats: + if "Sum" not in self.stats: return None return sum(self.values) @property def minimum(self): - if 'Minimum' not in self.stats: + if "Minimum" not in self.stats: return None return min(self.values) @property def maximum(self): - if 'Maximum' not in self.stats: + if "Maximum" not in self.stats: return None return max(self.values) @property def average(self): - if 'Average' not in self.stats: + if "Average" not in self.stats: return None # when moto is 3.4+ we can switch to the statistics module @@ -171,18 +189,44 @@ class Statistics: class CloudWatchBackend(BaseBackend): - def __init__(self): self.alarms = {} self.dashboards = {} self.metric_data = [] - def put_metric_alarm(self, name, namespace, metric_name, comparison_operator, evaluation_periods, - period, threshold, statistic, description, dimensions, - alarm_actions, ok_actions, insufficient_data_actions, unit): - alarm = FakeAlarm(name, namespace, metric_name, comparison_operator, evaluation_periods, period, - threshold, statistic, description, dimensions, alarm_actions, - ok_actions, insufficient_data_actions, unit) + def put_metric_alarm( + self, + name, + namespace, + metric_name, + comparison_operator, + evaluation_periods, + period, + threshold, + statistic, + description, + dimensions, + alarm_actions, + ok_actions, + insufficient_data_actions, + unit, + ): + alarm = FakeAlarm( + name, + namespace, + metric_name, + comparison_operator, + evaluation_periods, + period, + threshold, + statistic, + description, + dimensions, + alarm_actions, + ok_actions, + insufficient_data_actions, + unit, + ) self.alarms[name] = alarm return alarm @@ -214,14 +258,12 @@ class CloudWatchBackend(BaseBackend): ] def get_alarms_by_alarm_names(self, alarm_names): - return [ - alarm - for alarm in self.alarms.values() - if alarm.name in alarm_names - ] + return [alarm for alarm in self.alarms.values() if alarm.name in alarm_names] def get_alarms_by_state_value(self, target_state): - return filter(lambda alarm: alarm.state_value == target_state, self.alarms.values()) + return filter( + lambda alarm: alarm.state_value == target_state, self.alarms.values() + ) def delete_alarms(self, alarm_names): for alarm_name in alarm_names: @@ -230,17 +272,31 @@ class CloudWatchBackend(BaseBackend): def put_metric_data(self, namespace, metric_data): for metric_member in metric_data: # Preserve "datetime" for get_metric_statistics comparisons - timestamp = metric_member.get('Timestamp') + timestamp = metric_member.get("Timestamp") if timestamp is not None and type(timestamp) != datetime: - timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ') + timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ") timestamp = timestamp.replace(tzinfo=tzutc()) - self.metric_data.append(MetricDatum( - namespace, metric_member['MetricName'], float(metric_member.get('Value', 0)), metric_member.get('Dimensions.member', _EMPTY_LIST), timestamp)) + self.metric_data.append( + MetricDatum( + namespace, + metric_member["MetricName"], + float(metric_member.get("Value", 0)), + metric_member.get("Dimensions.member", _EMPTY_LIST), + timestamp, + ) + ) - def get_metric_statistics(self, namespace, metric_name, start_time, end_time, period, stats): + def get_metric_statistics( + self, namespace, metric_name, start_time, end_time, period, stats + ): period_delta = timedelta(seconds=period) - filtered_data = [md for md in self.metric_data if - md.namespace == namespace and md.name == metric_name and start_time <= md.timestamp <= end_time] + filtered_data = [ + md + for md in self.metric_data + if md.namespace == namespace + and md.name == metric_name + and start_time <= md.timestamp <= end_time + ] # earliest to oldest filtered_data = sorted(filtered_data, key=lambda x: x.timestamp) @@ -249,9 +305,15 @@ class CloudWatchBackend(BaseBackend): idx = 0 data = list() - for dt in daterange(filtered_data[0].timestamp, filtered_data[-1].timestamp + period_delta, period_delta): + for dt in daterange( + filtered_data[0].timestamp, + filtered_data[-1].timestamp + period_delta, + period_delta, + ): s = Statistics(stats, dt) - while idx < len(filtered_data) and filtered_data[idx].timestamp < (dt + period_delta): + while idx < len(filtered_data) and filtered_data[idx].timestamp < ( + dt + period_delta + ): s.values.append(filtered_data[idx].value) idx += 1 @@ -268,7 +330,7 @@ class CloudWatchBackend(BaseBackend): def put_dashboard(self, name, body): self.dashboards[name] = Dashboard(name, body) - def list_dashboards(self, prefix=''): + def list_dashboards(self, prefix=""): for key, value in self.dashboards.items(): if key.startswith(prefix): yield value @@ -280,7 +342,12 @@ class CloudWatchBackend(BaseBackend): left_over = to_delete - all_dashboards if len(left_over) > 0: # Some dashboards are not found - return False, 'The specified dashboard does not exist. [{0}]'.format(', '.join(left_over)) + return ( + False, + "The specified dashboard does not exist. [{0}]".format( + ", ".join(left_over) + ), + ) for dashboard in to_delete: del self.dashboards[dashboard] @@ -295,32 +362,36 @@ class CloudWatchBackend(BaseBackend): if reason_data is not None: json.loads(reason_data) except ValueError: - raise RESTError('InvalidFormat', 'StateReasonData is invalid JSON') + raise RESTError("InvalidFormat", "StateReasonData is invalid JSON") if alarm_name not in self.alarms: - raise RESTError('ResourceNotFound', 'Alarm {0} not found'.format(alarm_name), status=404) + raise RESTError( + "ResourceNotFound", "Alarm {0} not found".format(alarm_name), status=404 + ) - if state_value not in ('OK', 'ALARM', 'INSUFFICIENT_DATA'): - raise RESTError('InvalidParameterValue', 'StateValue is not one of OK | ALARM | INSUFFICIENT_DATA') + if state_value not in ("OK", "ALARM", "INSUFFICIENT_DATA"): + raise RESTError( + "InvalidParameterValue", + "StateValue is not one of OK | ALARM | INSUFFICIENT_DATA", + ) self.alarms[alarm_name].update_state(reason, reason_data, state_value) class LogGroup(BaseModel): - def __init__(self, spec): # required - self.name = spec['LogGroupName'] + self.name = spec["LogGroupName"] # optional - self.tags = spec.get('Tags', []) + self.tags = spec.get("Tags", []) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - spec = { - 'LogGroupName': properties['LogGroupName'] - } - optional_properties = 'Tags'.split() + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + spec = {"LogGroupName": properties["LogGroupName"]} + optional_properties = "Tags".split() for prop in optional_properties: if prop in properties: spec[prop] = properties[prop] diff --git a/moto/cloudwatch/responses.py b/moto/cloudwatch/responses.py index bf176e1be..5c381f36b 100644 --- a/moto/cloudwatch/responses.py +++ b/moto/cloudwatch/responses.py @@ -6,7 +6,6 @@ from dateutil.parser import parse as dtparse class CloudWatchResponse(BaseResponse): - @property def cloudwatch_backend(self): return cloudwatch_backends[self.region] @@ -17,45 +16,54 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def put_metric_alarm(self): - name = self._get_param('AlarmName') - namespace = self._get_param('Namespace') - metric_name = self._get_param('MetricName') - comparison_operator = self._get_param('ComparisonOperator') - evaluation_periods = self._get_param('EvaluationPeriods') - period = self._get_param('Period') - threshold = self._get_param('Threshold') - statistic = self._get_param('Statistic') - description = self._get_param('AlarmDescription') - dimensions = self._get_list_prefix('Dimensions.member') - alarm_actions = self._get_multi_param('AlarmActions.member') - ok_actions = self._get_multi_param('OKActions.member') + name = self._get_param("AlarmName") + namespace = self._get_param("Namespace") + metric_name = self._get_param("MetricName") + comparison_operator = self._get_param("ComparisonOperator") + evaluation_periods = self._get_param("EvaluationPeriods") + period = self._get_param("Period") + threshold = self._get_param("Threshold") + statistic = self._get_param("Statistic") + description = self._get_param("AlarmDescription") + dimensions = self._get_list_prefix("Dimensions.member") + alarm_actions = self._get_multi_param("AlarmActions.member") + ok_actions = self._get_multi_param("OKActions.member") insufficient_data_actions = self._get_multi_param( - "InsufficientDataActions.member") - unit = self._get_param('Unit') - alarm = self.cloudwatch_backend.put_metric_alarm(name, namespace, metric_name, - comparison_operator, - evaluation_periods, period, - threshold, statistic, - description, dimensions, - alarm_actions, ok_actions, - insufficient_data_actions, - unit) + "InsufficientDataActions.member" + ) + unit = self._get_param("Unit") + alarm = self.cloudwatch_backend.put_metric_alarm( + name, + namespace, + metric_name, + comparison_operator, + evaluation_periods, + period, + threshold, + statistic, + description, + dimensions, + alarm_actions, + ok_actions, + insufficient_data_actions, + unit, + ) template = self.response_template(PUT_METRIC_ALARM_TEMPLATE) return template.render(alarm=alarm) @amzn_request_id def describe_alarms(self): - action_prefix = self._get_param('ActionPrefix') - alarm_name_prefix = self._get_param('AlarmNamePrefix') - alarm_names = self._get_multi_param('AlarmNames.member') - state_value = self._get_param('StateValue') + action_prefix = self._get_param("ActionPrefix") + alarm_name_prefix = self._get_param("AlarmNamePrefix") + alarm_names = self._get_multi_param("AlarmNames.member") + state_value = self._get_param("StateValue") if action_prefix: - alarms = self.cloudwatch_backend.get_alarms_by_action_prefix( - action_prefix) + alarms = self.cloudwatch_backend.get_alarms_by_action_prefix(action_prefix) elif alarm_name_prefix: alarms = self.cloudwatch_backend.get_alarms_by_alarm_name_prefix( - alarm_name_prefix) + alarm_name_prefix + ) elif alarm_names: alarms = self.cloudwatch_backend.get_alarms_by_alarm_names(alarm_names) elif state_value: @@ -68,15 +76,15 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def delete_alarms(self): - alarm_names = self._get_multi_param('AlarmNames.member') + alarm_names = self._get_multi_param("AlarmNames.member") self.cloudwatch_backend.delete_alarms(alarm_names) template = self.response_template(DELETE_METRIC_ALARMS_TEMPLATE) return template.render() @amzn_request_id def put_metric_data(self): - namespace = self._get_param('Namespace') - metric_data = self._get_multi_param('MetricData.member') + namespace = self._get_param("Namespace") + metric_data = self._get_multi_param("MetricData.member") self.cloudwatch_backend.put_metric_data(namespace, metric_data) template = self.response_template(PUT_METRIC_DATA_TEMPLATE) @@ -84,25 +92,29 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def get_metric_statistics(self): - namespace = self._get_param('Namespace') - metric_name = self._get_param('MetricName') - start_time = dtparse(self._get_param('StartTime')) - end_time = dtparse(self._get_param('EndTime')) - period = int(self._get_param('Period')) + namespace = self._get_param("Namespace") + metric_name = self._get_param("MetricName") + start_time = dtparse(self._get_param("StartTime")) + end_time = dtparse(self._get_param("EndTime")) + period = int(self._get_param("Period")) statistics = self._get_multi_param("Statistics.member") # Unsupported Parameters (To Be Implemented) - unit = self._get_param('Unit') - extended_statistics = self._get_param('ExtendedStatistics') - dimensions = self._get_param('Dimensions') + unit = self._get_param("Unit") + extended_statistics = self._get_param("ExtendedStatistics") + dimensions = self._get_param("Dimensions") if unit or extended_statistics or dimensions: - raise NotImplemented() + raise NotImplementedError() # TODO: this should instead throw InvalidParameterCombination if not statistics: - raise NotImplemented("Must specify either Statistics or ExtendedStatistics") + raise NotImplementedError( + "Must specify either Statistics or ExtendedStatistics" + ) - datapoints = self.cloudwatch_backend.get_metric_statistics(namespace, metric_name, start_time, end_time, period, statistics) + datapoints = self.cloudwatch_backend.get_metric_statistics( + namespace, metric_name, start_time, end_time, period, statistics + ) template = self.response_template(GET_METRIC_STATISTICS_TEMPLATE) return template.render(label=metric_name, datapoints=datapoints) @@ -114,13 +126,13 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def delete_dashboards(self): - dashboards = self._get_multi_param('DashboardNames.member') + dashboards = self._get_multi_param("DashboardNames.member") if dashboards is None: - return self._error('InvalidParameterValue', 'Need at least 1 dashboard') + return self._error("InvalidParameterValue", "Need at least 1 dashboard") status, error = self.cloudwatch_backend.delete_dashboards(dashboards) if not status: - return self._error('ResourceNotFound', error) + return self._error("ResourceNotFound", error) template = self.response_template(DELETE_DASHBOARD_TEMPLATE) return template.render() @@ -143,18 +155,18 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def get_dashboard(self): - dashboard_name = self._get_param('DashboardName') + dashboard_name = self._get_param("DashboardName") dashboard = self.cloudwatch_backend.get_dashboard(dashboard_name) if dashboard is None: - return self._error('ResourceNotFound', 'Dashboard does not exist') + return self._error("ResourceNotFound", "Dashboard does not exist") template = self.response_template(GET_DASHBOARD_TEMPLATE) return template.render(dashboard=dashboard) @amzn_request_id def list_dashboards(self): - prefix = self._get_param('DashboardNamePrefix', '') + prefix = self._get_param("DashboardNamePrefix", "") dashboards = self.cloudwatch_backend.list_dashboards(prefix) @@ -163,13 +175,13 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def put_dashboard(self): - name = self._get_param('DashboardName') - body = self._get_param('DashboardBody') + name = self._get_param("DashboardName") + body = self._get_param("DashboardBody") try: json.loads(body) except ValueError: - return self._error('InvalidParameterInput', 'Body is invalid JSON') + return self._error("InvalidParameterInput", "Body is invalid JSON") self.cloudwatch_backend.put_dashboard(name, body) @@ -178,12 +190,14 @@ class CloudWatchResponse(BaseResponse): @amzn_request_id def set_alarm_state(self): - alarm_name = self._get_param('AlarmName') - reason = self._get_param('StateReason') - reason_data = self._get_param('StateReasonData') - state_value = self._get_param('StateValue') + alarm_name = self._get_param("AlarmName") + reason = self._get_param("StateReason") + reason_data = self._get_param("StateReasonData") + state_value = self._get_param("StateValue") - self.cloudwatch_backend.set_alarm_state(alarm_name, reason, reason_data, state_value) + self.cloudwatch_backend.set_alarm_state( + alarm_name, reason, reason_data, state_value + ) template = self.response_template(SET_ALARM_STATE_TEMPLATE) return template.render() diff --git a/moto/cloudwatch/urls.py b/moto/cloudwatch/urls.py index 0a9101cfb..b97bcff05 100644 --- a/moto/cloudwatch/urls.py +++ b/moto/cloudwatch/urls.py @@ -1,9 +1,5 @@ from .responses import CloudWatchResponse -url_bases = [ - "https?://monitoring.(.+).amazonaws.com", -] +url_bases = ["https?://monitoring.(.+).amazonaws.com"] -url_paths = { - '{0}/$': CloudWatchResponse.dispatch, -} +url_paths = {"{0}/$": CloudWatchResponse.dispatch} diff --git a/moto/cognitoidentity/__init__.py b/moto/cognitoidentity/__init__.py index 2f040fa19..8045b9d0f 100644 --- a/moto/cognitoidentity/__init__.py +++ b/moto/cognitoidentity/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import cognitoidentity_backends from ..core.models import base_decorator, deprecated_base_decorator -cognitoidentity_backend = cognitoidentity_backends['us-east-1'] +cognitoidentity_backend = cognitoidentity_backends["us-east-1"] mock_cognitoidentity = base_decorator(cognitoidentity_backends) mock_cognitoidentity_deprecated = deprecated_base_decorator(cognitoidentity_backends) diff --git a/moto/cognitoidentity/exceptions.py b/moto/cognitoidentity/exceptions.py index ec22f3b42..44e391abd 100644 --- a/moto/cognitoidentity/exceptions.py +++ b/moto/cognitoidentity/exceptions.py @@ -6,10 +6,8 @@ from werkzeug.exceptions import BadRequest class ResourceNotFoundError(BadRequest): - def __init__(self, message): super(ResourceNotFoundError, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'ResourceNotFoundException', - }) + self.description = json.dumps( + {"message": message, "__type": "ResourceNotFoundException"} + ) diff --git a/moto/cognitoidentity/models.py b/moto/cognitoidentity/models.py index 6f752ab69..2a4f5d4bc 100644 --- a/moto/cognitoidentity/models.py +++ b/moto/cognitoidentity/models.py @@ -13,22 +13,24 @@ from .utils import get_random_identity_id class CognitoIdentity(BaseModel): - def __init__(self, region, identity_pool_name, **kwargs): self.identity_pool_name = identity_pool_name - self.allow_unauthenticated_identities = kwargs.get('allow_unauthenticated_identities', '') - self.supported_login_providers = kwargs.get('supported_login_providers', {}) - self.developer_provider_name = kwargs.get('developer_provider_name', '') - self.open_id_connect_provider_arns = kwargs.get('open_id_connect_provider_arns', []) - self.cognito_identity_providers = kwargs.get('cognito_identity_providers', []) - self.saml_provider_arns = kwargs.get('saml_provider_arns', []) + self.allow_unauthenticated_identities = kwargs.get( + "allow_unauthenticated_identities", "" + ) + self.supported_login_providers = kwargs.get("supported_login_providers", {}) + self.developer_provider_name = kwargs.get("developer_provider_name", "") + self.open_id_connect_provider_arns = kwargs.get( + "open_id_connect_provider_arns", [] + ) + self.cognito_identity_providers = kwargs.get("cognito_identity_providers", []) + self.saml_provider_arns = kwargs.get("saml_provider_arns", []) self.identity_pool_id = get_random_identity_id(region) self.creation_time = datetime.datetime.utcnow() class CognitoIdentityBackend(BaseBackend): - def __init__(self, region): super(CognitoIdentityBackend, self).__init__() self.region = region @@ -45,47 +47,61 @@ class CognitoIdentityBackend(BaseBackend): if not identity_pool: raise ResourceNotFoundError(identity_pool) - response = json.dumps({ - 'AllowUnauthenticatedIdentities': identity_pool.allow_unauthenticated_identities, - 'CognitoIdentityProviders': identity_pool.cognito_identity_providers, - 'DeveloperProviderName': identity_pool.developer_provider_name, - 'IdentityPoolId': identity_pool.identity_pool_id, - 'IdentityPoolName': identity_pool.identity_pool_name, - 'IdentityPoolTags': {}, - 'OpenIdConnectProviderARNs': identity_pool.open_id_connect_provider_arns, - 'SamlProviderARNs': identity_pool.saml_provider_arns, - 'SupportedLoginProviders': identity_pool.supported_login_providers - }) + response = json.dumps( + { + "AllowUnauthenticatedIdentities": identity_pool.allow_unauthenticated_identities, + "CognitoIdentityProviders": identity_pool.cognito_identity_providers, + "DeveloperProviderName": identity_pool.developer_provider_name, + "IdentityPoolId": identity_pool.identity_pool_id, + "IdentityPoolName": identity_pool.identity_pool_name, + "IdentityPoolTags": {}, + "OpenIdConnectProviderARNs": identity_pool.open_id_connect_provider_arns, + "SamlProviderARNs": identity_pool.saml_provider_arns, + "SupportedLoginProviders": identity_pool.supported_login_providers, + } + ) return response - def create_identity_pool(self, identity_pool_name, allow_unauthenticated_identities, - supported_login_providers, developer_provider_name, open_id_connect_provider_arns, - cognito_identity_providers, saml_provider_arns): - new_identity = CognitoIdentity(self.region, identity_pool_name, + def create_identity_pool( + self, + identity_pool_name, + allow_unauthenticated_identities, + supported_login_providers, + developer_provider_name, + open_id_connect_provider_arns, + cognito_identity_providers, + saml_provider_arns, + ): + new_identity = CognitoIdentity( + self.region, + identity_pool_name, allow_unauthenticated_identities=allow_unauthenticated_identities, supported_login_providers=supported_login_providers, developer_provider_name=developer_provider_name, open_id_connect_provider_arns=open_id_connect_provider_arns, cognito_identity_providers=cognito_identity_providers, - saml_provider_arns=saml_provider_arns) + saml_provider_arns=saml_provider_arns, + ) self.identity_pools[new_identity.identity_pool_id] = new_identity - response = json.dumps({ - 'IdentityPoolId': new_identity.identity_pool_id, - 'IdentityPoolName': new_identity.identity_pool_name, - 'AllowUnauthenticatedIdentities': new_identity.allow_unauthenticated_identities, - 'SupportedLoginProviders': new_identity.supported_login_providers, - 'DeveloperProviderName': new_identity.developer_provider_name, - 'OpenIdConnectProviderARNs': new_identity.open_id_connect_provider_arns, - 'CognitoIdentityProviders': new_identity.cognito_identity_providers, - 'SamlProviderARNs': new_identity.saml_provider_arns - }) + response = json.dumps( + { + "IdentityPoolId": new_identity.identity_pool_id, + "IdentityPoolName": new_identity.identity_pool_name, + "AllowUnauthenticatedIdentities": new_identity.allow_unauthenticated_identities, + "SupportedLoginProviders": new_identity.supported_login_providers, + "DeveloperProviderName": new_identity.developer_provider_name, + "OpenIdConnectProviderARNs": new_identity.open_id_connect_provider_arns, + "CognitoIdentityProviders": new_identity.cognito_identity_providers, + "SamlProviderARNs": new_identity.saml_provider_arns, + } + ) return response def get_id(self): - identity_id = {'IdentityId': get_random_identity_id(self.region)} + identity_id = {"IdentityId": get_random_identity_id(self.region)} return json.dumps(identity_id) def get_credentials_for_identity(self, identity_id): @@ -95,31 +111,26 @@ class CognitoIdentityBackend(BaseBackend): expiration_str = str(iso_8601_datetime_with_milliseconds(expiration)) response = json.dumps( { - "Credentials": - { - "AccessKeyId": "TESTACCESSKEY12345", - "Expiration": expiration_str, - "SecretKey": "ABCSECRETKEY", - "SessionToken": "ABC12345" - }, - "IdentityId": identity_id - }) + "Credentials": { + "AccessKeyId": "TESTACCESSKEY12345", + "Expiration": expiration_str, + "SecretKey": "ABCSECRETKEY", + "SessionToken": "ABC12345", + }, + "IdentityId": identity_id, + } + ) return response def get_open_id_token_for_developer_identity(self, identity_id): response = json.dumps( - { - "IdentityId": identity_id, - "Token": get_random_identity_id(self.region) - }) + {"IdentityId": identity_id, "Token": get_random_identity_id(self.region)} + ) return response def get_open_id_token(self, identity_id): response = json.dumps( - { - "IdentityId": identity_id, - "Token": get_random_identity_id(self.region) - } + {"IdentityId": identity_id, "Token": get_random_identity_id(self.region)} ) return response diff --git a/moto/cognitoidentity/responses.py b/moto/cognitoidentity/responses.py index 709fdb40a..5fd352ba4 100644 --- a/moto/cognitoidentity/responses.py +++ b/moto/cognitoidentity/responses.py @@ -6,15 +6,16 @@ from .utils import get_random_identity_id class CognitoIdentityResponse(BaseResponse): - def create_identity_pool(self): - identity_pool_name = self._get_param('IdentityPoolName') - allow_unauthenticated_identities = self._get_param('AllowUnauthenticatedIdentities') - supported_login_providers = self._get_param('SupportedLoginProviders') - developer_provider_name = self._get_param('DeveloperProviderName') - open_id_connect_provider_arns = self._get_param('OpenIdConnectProviderARNs') - cognito_identity_providers = self._get_param('CognitoIdentityProviders') - saml_provider_arns = self._get_param('SamlProviderARNs') + identity_pool_name = self._get_param("IdentityPoolName") + allow_unauthenticated_identities = self._get_param( + "AllowUnauthenticatedIdentities" + ) + supported_login_providers = self._get_param("SupportedLoginProviders") + developer_provider_name = self._get_param("DeveloperProviderName") + open_id_connect_provider_arns = self._get_param("OpenIdConnectProviderARNs") + cognito_identity_providers = self._get_param("CognitoIdentityProviders") + saml_provider_arns = self._get_param("SamlProviderARNs") return cognitoidentity_backends[self.region].create_identity_pool( identity_pool_name=identity_pool_name, @@ -23,20 +24,27 @@ class CognitoIdentityResponse(BaseResponse): developer_provider_name=developer_provider_name, open_id_connect_provider_arns=open_id_connect_provider_arns, cognito_identity_providers=cognito_identity_providers, - saml_provider_arns=saml_provider_arns) + saml_provider_arns=saml_provider_arns, + ) def get_id(self): return cognitoidentity_backends[self.region].get_id() def describe_identity_pool(self): - return cognitoidentity_backends[self.region].describe_identity_pool(self._get_param('IdentityPoolId')) + return cognitoidentity_backends[self.region].describe_identity_pool( + self._get_param("IdentityPoolId") + ) def get_credentials_for_identity(self): - return cognitoidentity_backends[self.region].get_credentials_for_identity(self._get_param('IdentityId')) + return cognitoidentity_backends[self.region].get_credentials_for_identity( + self._get_param("IdentityId") + ) def get_open_id_token_for_developer_identity(self): - return cognitoidentity_backends[self.region].get_open_id_token_for_developer_identity( - self._get_param('IdentityId') or get_random_identity_id(self.region) + return cognitoidentity_backends[ + self.region + ].get_open_id_token_for_developer_identity( + self._get_param("IdentityId") or get_random_identity_id(self.region) ) def get_open_id_token(self): diff --git a/moto/cognitoidentity/urls.py b/moto/cognitoidentity/urls.py index 3fe63ef07..e96c189b4 100644 --- a/moto/cognitoidentity/urls.py +++ b/moto/cognitoidentity/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import CognitoIdentityResponse -url_bases = [ - "https?://cognito-identity.(.+).amazonaws.com", -] +url_bases = ["https?://cognito-identity.(.+).amazonaws.com"] -url_paths = { - '{0}/$': CognitoIdentityResponse.dispatch, -} +url_paths = {"{0}/$": CognitoIdentityResponse.dispatch} diff --git a/moto/cognitoidp/exceptions.py b/moto/cognitoidp/exceptions.py index d05ee8f09..e52b7c49f 100644 --- a/moto/cognitoidp/exceptions.py +++ b/moto/cognitoidp/exceptions.py @@ -5,50 +5,40 @@ from werkzeug.exceptions import BadRequest class ResourceNotFoundError(BadRequest): - def __init__(self, message): super(ResourceNotFoundError, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'ResourceNotFoundException', - }) + self.description = json.dumps( + {"message": message, "__type": "ResourceNotFoundException"} + ) class UserNotFoundError(BadRequest): - def __init__(self, message): super(UserNotFoundError, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'UserNotFoundException', - }) + self.description = json.dumps( + {"message": message, "__type": "UserNotFoundException"} + ) class UsernameExistsException(BadRequest): - def __init__(self, message): super(UsernameExistsException, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'UsernameExistsException', - }) + self.description = json.dumps( + {"message": message, "__type": "UsernameExistsException"} + ) class GroupExistsException(BadRequest): - def __init__(self, message): super(GroupExistsException, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'GroupExistsException', - }) + self.description = json.dumps( + {"message": message, "__type": "GroupExistsException"} + ) class NotAuthorizedError(BadRequest): - def __init__(self, message): super(NotAuthorizedError, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'NotAuthorizedException', - }) + self.description = json.dumps( + {"message": message, "__type": "NotAuthorizedException"} + ) diff --git a/moto/cognitoidp/models.py b/moto/cognitoidp/models.py index 25a173f29..6700920ce 100644 --- a/moto/cognitoidp/models.py +++ b/moto/cognitoidp/models.py @@ -14,8 +14,13 @@ from jose import jws from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel -from .exceptions import GroupExistsException, NotAuthorizedError, ResourceNotFoundError, UserNotFoundError, \ - UsernameExistsException +from .exceptions import ( + GroupExistsException, + NotAuthorizedError, + ResourceNotFoundError, + UserNotFoundError, + UsernameExistsException, +) UserStatus = { "FORCE_CHANGE_PASSWORD": "FORCE_CHANGE_PASSWORD", @@ -45,19 +50,22 @@ def paginate(limit, start_arg="next_token", limit_arg="max_results"): def outer_wrapper(func): @functools.wraps(func) def wrapper(*args, **kwargs): - start = int(default_start if kwargs.get(start_arg) is None else kwargs[start_arg]) + start = int( + default_start if kwargs.get(start_arg) is None else kwargs[start_arg] + ) lim = int(limit if kwargs.get(limit_arg) is None else kwargs[limit_arg]) stop = start + lim result = func(*args, **kwargs) limited_results = list(itertools.islice(result, start, stop)) next_token = stop if stop < len(result) else None return limited_results, next_token + return wrapper + return outer_wrapper class CognitoIdpUserPool(BaseModel): - def __init__(self, region, name, extended_config): self.region = region self.id = "{}_{}".format(self.region, str(uuid.uuid4().hex)) @@ -75,7 +83,9 @@ class CognitoIdpUserPool(BaseModel): self.access_tokens = {} self.id_tokens = {} - with open(os.path.join(os.path.dirname(__file__), "resources/jwks-private.json")) as f: + with open( + os.path.join(os.path.dirname(__file__), "resources/jwks-private.json") + ) as f: self.json_web_key = json.loads(f.read()) def _base_json(self): @@ -92,14 +102,18 @@ class CognitoIdpUserPool(BaseModel): if extended: user_pool_json.update(self.extended_config) else: - user_pool_json["LambdaConfig"] = self.extended_config.get("LambdaConfig") or {} + user_pool_json["LambdaConfig"] = ( + self.extended_config.get("LambdaConfig") or {} + ) return user_pool_json def create_jwt(self, client_id, username, expires_in=60 * 60, extra_data={}): now = int(time.time()) payload = { - "iss": "https://cognito-idp.{}.amazonaws.com/{}".format(self.region, self.id), + "iss": "https://cognito-idp.{}.amazonaws.com/{}".format( + self.region, self.id + ), "sub": self.users[username].id, "aud": client_id, "token_use": "id", @@ -108,7 +122,7 @@ class CognitoIdpUserPool(BaseModel): } payload.update(extra_data) - return jws.sign(payload, self.json_web_key, algorithm='RS256'), expires_in + return jws.sign(payload, self.json_web_key, algorithm="RS256"), expires_in def create_id_token(self, client_id, username): id_token, expires_in = self.create_jwt(client_id, username) @@ -121,11 +135,10 @@ class CognitoIdpUserPool(BaseModel): return refresh_token def create_access_token(self, client_id, username): - extra_data = self.get_user_extra_data_by_client_id( - client_id, username + extra_data = self.get_user_extra_data_by_client_id(client_id, username) + access_token, expires_in = self.create_jwt( + client_id, username, extra_data=extra_data ) - access_token, expires_in = self.create_jwt(client_id, username, - extra_data=extra_data) self.access_tokens[access_token] = (client_id, username) return access_token, expires_in @@ -143,29 +156,27 @@ class CognitoIdpUserPool(BaseModel): current_client = self.clients.get(client_id, None) if current_client: for readable_field in current_client.get_readable_fields(): - attribute = list(filter( - lambda f: f['Name'] == readable_field, - self.users.get(username).attributes - )) + attribute = list( + filter( + lambda f: f["Name"] == readable_field, + self.users.get(username).attributes, + ) + ) if len(attribute) > 0: - extra_data.update({ - attribute[0]['Name']: attribute[0]['Value'] - }) + extra_data.update({attribute[0]["Name"]: attribute[0]["Value"]}) return extra_data class CognitoIdpUserPoolDomain(BaseModel): - def __init__(self, user_pool_id, domain, custom_domain_config=None): self.user_pool_id = user_pool_id self.domain = domain self.custom_domain_config = custom_domain_config or {} def _distribution_name(self): - if self.custom_domain_config and \ - 'CertificateArn' in self.custom_domain_config: + if self.custom_domain_config and "CertificateArn" in self.custom_domain_config: hash = hashlib.md5( - self.custom_domain_config['CertificateArn'].encode('utf-8') + self.custom_domain_config["CertificateArn"].encode("utf-8") ).hexdigest() return "{hash}.cloudfront.net".format(hash=hash[:16]) return None @@ -183,14 +194,11 @@ class CognitoIdpUserPoolDomain(BaseModel): "Version": None, } elif distribution: - return { - "CloudFrontDomain": distribution, - } + return {"CloudFrontDomain": distribution} return None class CognitoIdpUserPoolClient(BaseModel): - def __init__(self, user_pool_id, extended_config): self.user_pool_id = user_pool_id self.id = str(uuid.uuid4()) @@ -212,11 +220,10 @@ class CognitoIdpUserPoolClient(BaseModel): return user_pool_client_json def get_readable_fields(self): - return self.extended_config.get('ReadAttributes', []) + return self.extended_config.get("ReadAttributes", []) class CognitoIdpIdentityProvider(BaseModel): - def __init__(self, name, extended_config): self.name = name self.extended_config = extended_config or {} @@ -240,7 +247,6 @@ class CognitoIdpIdentityProvider(BaseModel): class CognitoIdpGroup(BaseModel): - def __init__(self, user_pool_id, group_name, description, role_arn, precedence): self.user_pool_id = user_pool_id self.group_name = group_name @@ -267,7 +273,6 @@ class CognitoIdpGroup(BaseModel): class CognitoIdpUser(BaseModel): - def __init__(self, user_pool_id, username, password, status, attributes): self.id = str(uuid.uuid4()) self.user_pool_id = user_pool_id @@ -300,19 +305,18 @@ class CognitoIdpUser(BaseModel): { "Enabled": self.enabled, attributes_key: self.attributes, - "MFAOptions": [] + "MFAOptions": [], } ) return user_json def update_attributes(self, new_attributes): - def flatten_attrs(attrs): - return {attr['Name']: attr['Value'] for attr in attrs} + return {attr["Name"]: attr["Value"] for attr in attrs} def expand_attrs(attrs): - return [{'Name': k, 'Value': v} for k, v in attrs.items()] + return [{"Name": k, "Value": v} for k, v in attrs.items()] flat_attributes = flatten_attrs(self.attributes) flat_attributes.update(flatten_attrs(new_attributes)) @@ -320,7 +324,6 @@ class CognitoIdpUser(BaseModel): class CognitoIdpBackend(BaseBackend): - def __init__(self, region): super(CognitoIdpBackend, self).__init__() self.region = region @@ -496,7 +499,9 @@ class CognitoIdpBackend(BaseBackend): if not user_pool: raise ResourceNotFoundError(user_pool_id) - group = CognitoIdpGroup(user_pool_id, group_name, description, role_arn, precedence) + group = CognitoIdpGroup( + user_pool_id, group_name, description, role_arn, precedence + ) if group.group_name in user_pool.groups: raise GroupExistsException("A group with the name already exists") user_pool.groups[group.group_name] = group @@ -565,7 +570,13 @@ class CognitoIdpBackend(BaseBackend): if username in user_pool.users: raise UsernameExistsException(username) - user = CognitoIdpUser(user_pool_id, username, temporary_password, UserStatus["FORCE_CHANGE_PASSWORD"], attributes) + user = CognitoIdpUser( + user_pool_id, + username, + temporary_password, + UserStatus["FORCE_CHANGE_PASSWORD"], + attributes, + ) user_pool.users[user.username] = user return user @@ -611,7 +622,9 @@ class CognitoIdpBackend(BaseBackend): def _log_user_in(self, user_pool, client, username): refresh_token = user_pool.create_refresh_token(client.id, username) - access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token) + access_token, id_token, expires_in = user_pool.create_tokens_from_refresh_token( + refresh_token + ) return { "AuthenticationResult": { @@ -654,7 +667,11 @@ class CognitoIdpBackend(BaseBackend): return self._log_user_in(user_pool, client, username) elif auth_flow == "REFRESH_TOKEN": refresh_token = auth_parameters.get("REFRESH_TOKEN") - id_token, access_token, expires_in = user_pool.create_tokens_from_refresh_token(refresh_token) + ( + id_token, + access_token, + expires_in, + ) = user_pool.create_tokens_from_refresh_token(refresh_token) return { "AuthenticationResult": { @@ -666,7 +683,9 @@ class CognitoIdpBackend(BaseBackend): else: return {} - def respond_to_auth_challenge(self, session, client_id, challenge_name, challenge_responses): + def respond_to_auth_challenge( + self, session, client_id, challenge_name, challenge_responses + ): user_pool = self.sessions.get(session) if not user_pool: raise ResourceNotFoundError(session) diff --git a/moto/cognitoidp/responses.py b/moto/cognitoidp/responses.py index 75dd8c181..80247b076 100644 --- a/moto/cognitoidp/responses.py +++ b/moto/cognitoidp/responses.py @@ -8,7 +8,6 @@ from .models import cognitoidp_backends, find_region_by_value class CognitoIdpResponse(BaseResponse): - @property def parameters(self): return json.loads(self.body) @@ -16,10 +15,10 @@ class CognitoIdpResponse(BaseResponse): # User pool def create_user_pool(self): name = self.parameters.pop("PoolName") - user_pool = cognitoidp_backends[self.region].create_user_pool(name, self.parameters) - return json.dumps({ - "UserPool": user_pool.to_json(extended=True) - }) + user_pool = cognitoidp_backends[self.region].create_user_pool( + name, self.parameters + ) + return json.dumps({"UserPool": user_pool.to_json(extended=True)}) def list_user_pools(self): max_results = self._get_param("MaxResults") @@ -27,9 +26,7 @@ class CognitoIdpResponse(BaseResponse): user_pools, next_token = cognitoidp_backends[self.region].list_user_pools( max_results=max_results, next_token=next_token ) - response = { - "UserPools": [user_pool.to_json() for user_pool in user_pools], - } + response = {"UserPools": [user_pool.to_json() for user_pool in user_pools]} if next_token: response["NextToken"] = str(next_token) return json.dumps(response) @@ -37,9 +34,7 @@ class CognitoIdpResponse(BaseResponse): def describe_user_pool(self): user_pool_id = self._get_param("UserPoolId") user_pool = cognitoidp_backends[self.region].describe_user_pool(user_pool_id) - return json.dumps({ - "UserPool": user_pool.to_json(extended=True) - }) + return json.dumps({"UserPool": user_pool.to_json(extended=True)}) def delete_user_pool(self): user_pool_id = self._get_param("UserPoolId") @@ -61,14 +56,14 @@ class CognitoIdpResponse(BaseResponse): def describe_user_pool_domain(self): domain = self._get_param("Domain") - user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain(domain) + user_pool_domain = cognitoidp_backends[self.region].describe_user_pool_domain( + domain + ) domain_description = {} if user_pool_domain: domain_description = user_pool_domain.to_json() - return json.dumps({ - "DomainDescription": domain_description - }) + return json.dumps({"DomainDescription": domain_description}) def delete_user_pool_domain(self): domain = self._get_param("Domain") @@ -89,19 +84,24 @@ class CognitoIdpResponse(BaseResponse): # User pool client def create_user_pool_client(self): user_pool_id = self.parameters.pop("UserPoolId") - user_pool_client = cognitoidp_backends[self.region].create_user_pool_client(user_pool_id, self.parameters) - return json.dumps({ - "UserPoolClient": user_pool_client.to_json(extended=True) - }) + user_pool_client = cognitoidp_backends[self.region].create_user_pool_client( + user_pool_id, self.parameters + ) + return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) def list_user_pool_clients(self): user_pool_id = self._get_param("UserPoolId") max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken", "0") - user_pool_clients, next_token = cognitoidp_backends[self.region].list_user_pool_clients(user_pool_id, - max_results=max_results, next_token=next_token) + user_pool_clients, next_token = cognitoidp_backends[ + self.region + ].list_user_pool_clients( + user_pool_id, max_results=max_results, next_token=next_token + ) response = { - "UserPoolClients": [user_pool_client.to_json() for user_pool_client in user_pool_clients] + "UserPoolClients": [ + user_pool_client.to_json() for user_pool_client in user_pool_clients + ] } if next_token: response["NextToken"] = str(next_token) @@ -110,43 +110,51 @@ class CognitoIdpResponse(BaseResponse): def describe_user_pool_client(self): user_pool_id = self._get_param("UserPoolId") client_id = self._get_param("ClientId") - user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client(user_pool_id, client_id) - return json.dumps({ - "UserPoolClient": user_pool_client.to_json(extended=True) - }) + user_pool_client = cognitoidp_backends[self.region].describe_user_pool_client( + user_pool_id, client_id + ) + return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) def update_user_pool_client(self): user_pool_id = self.parameters.pop("UserPoolId") client_id = self.parameters.pop("ClientId") - user_pool_client = cognitoidp_backends[self.region].update_user_pool_client(user_pool_id, client_id, self.parameters) - return json.dumps({ - "UserPoolClient": user_pool_client.to_json(extended=True) - }) + user_pool_client = cognitoidp_backends[self.region].update_user_pool_client( + user_pool_id, client_id, self.parameters + ) + return json.dumps({"UserPoolClient": user_pool_client.to_json(extended=True)}) def delete_user_pool_client(self): user_pool_id = self._get_param("UserPoolId") client_id = self._get_param("ClientId") - cognitoidp_backends[self.region].delete_user_pool_client(user_pool_id, client_id) + cognitoidp_backends[self.region].delete_user_pool_client( + user_pool_id, client_id + ) return "" # Identity provider def create_identity_provider(self): user_pool_id = self._get_param("UserPoolId") name = self.parameters.pop("ProviderName") - identity_provider = cognitoidp_backends[self.region].create_identity_provider(user_pool_id, name, self.parameters) - return json.dumps({ - "IdentityProvider": identity_provider.to_json(extended=True) - }) + identity_provider = cognitoidp_backends[self.region].create_identity_provider( + user_pool_id, name, self.parameters + ) + return json.dumps( + {"IdentityProvider": identity_provider.to_json(extended=True)} + ) def list_identity_providers(self): user_pool_id = self._get_param("UserPoolId") max_results = self._get_param("MaxResults") next_token = self._get_param("NextToken", "0") - identity_providers, next_token = cognitoidp_backends[self.region].list_identity_providers( + identity_providers, next_token = cognitoidp_backends[ + self.region + ].list_identity_providers( user_pool_id, max_results=max_results, next_token=next_token ) response = { - "Providers": [identity_provider.to_json() for identity_provider in identity_providers] + "Providers": [ + identity_provider.to_json() for identity_provider in identity_providers + ] } if next_token: response["NextToken"] = str(next_token) @@ -155,18 +163,22 @@ class CognitoIdpResponse(BaseResponse): def describe_identity_provider(self): user_pool_id = self._get_param("UserPoolId") name = self._get_param("ProviderName") - identity_provider = cognitoidp_backends[self.region].describe_identity_provider(user_pool_id, name) - return json.dumps({ - "IdentityProvider": identity_provider.to_json(extended=True) - }) + identity_provider = cognitoidp_backends[self.region].describe_identity_provider( + user_pool_id, name + ) + return json.dumps( + {"IdentityProvider": identity_provider.to_json(extended=True)} + ) def update_identity_provider(self): user_pool_id = self._get_param("UserPoolId") name = self._get_param("ProviderName") - identity_provider = cognitoidp_backends[self.region].update_identity_provider(user_pool_id, name, self.parameters) - return json.dumps({ - "IdentityProvider": identity_provider.to_json(extended=True) - }) + identity_provider = cognitoidp_backends[self.region].update_identity_provider( + user_pool_id, name, self.parameters + ) + return json.dumps( + {"IdentityProvider": identity_provider.to_json(extended=True)} + ) def delete_identity_provider(self): user_pool_id = self._get_param("UserPoolId") @@ -183,31 +195,21 @@ class CognitoIdpResponse(BaseResponse): precedence = self._get_param("Precedence") group = cognitoidp_backends[self.region].create_group( - user_pool_id, - group_name, - description, - role_arn, - precedence, + user_pool_id, group_name, description, role_arn, precedence ) - return json.dumps({ - "Group": group.to_json(), - }) + return json.dumps({"Group": group.to_json()}) def get_group(self): group_name = self._get_param("GroupName") user_pool_id = self._get_param("UserPoolId") group = cognitoidp_backends[self.region].get_group(user_pool_id, group_name) - return json.dumps({ - "Group": group.to_json(), - }) + return json.dumps({"Group": group.to_json()}) def list_groups(self): user_pool_id = self._get_param("UserPoolId") groups = cognitoidp_backends[self.region].list_groups(user_pool_id) - return json.dumps({ - "Groups": [group.to_json() for group in groups], - }) + return json.dumps({"Groups": [group.to_json() for group in groups]}) def delete_group(self): group_name = self._get_param("GroupName") @@ -221,9 +223,7 @@ class CognitoIdpResponse(BaseResponse): group_name = self._get_param("GroupName") cognitoidp_backends[self.region].admin_add_user_to_group( - user_pool_id, - group_name, - username, + user_pool_id, group_name, username ) return "" @@ -231,18 +231,18 @@ class CognitoIdpResponse(BaseResponse): def list_users_in_group(self): user_pool_id = self._get_param("UserPoolId") group_name = self._get_param("GroupName") - users = cognitoidp_backends[self.region].list_users_in_group(user_pool_id, group_name) - return json.dumps({ - "Users": [user.to_json(extended=True) for user in users], - }) + users = cognitoidp_backends[self.region].list_users_in_group( + user_pool_id, group_name + ) + return json.dumps({"Users": [user.to_json(extended=True) for user in users]}) def admin_list_groups_for_user(self): username = self._get_param("Username") user_pool_id = self._get_param("UserPoolId") - groups = cognitoidp_backends[self.region].admin_list_groups_for_user(user_pool_id, username) - return json.dumps({ - "Groups": [group.to_json() for group in groups], - }) + groups = cognitoidp_backends[self.region].admin_list_groups_for_user( + user_pool_id, username + ) + return json.dumps({"Groups": [group.to_json() for group in groups]}) def admin_remove_user_from_group(self): user_pool_id = self._get_param("UserPoolId") @@ -250,9 +250,7 @@ class CognitoIdpResponse(BaseResponse): group_name = self._get_param("GroupName") cognitoidp_backends[self.region].admin_remove_user_from_group( - user_pool_id, - group_name, - username, + user_pool_id, group_name, username ) return "" @@ -266,28 +264,24 @@ class CognitoIdpResponse(BaseResponse): user_pool_id, username, temporary_password, - self._get_param("UserAttributes", []) + self._get_param("UserAttributes", []), ) - return json.dumps({ - "User": user.to_json(extended=True) - }) + return json.dumps({"User": user.to_json(extended=True)}) def admin_get_user(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") user = cognitoidp_backends[self.region].admin_get_user(user_pool_id, username) - return json.dumps( - user.to_json(extended=True, attributes_key="UserAttributes") - ) + return json.dumps(user.to_json(extended=True, attributes_key="UserAttributes")) def list_users(self): user_pool_id = self._get_param("UserPoolId") limit = self._get_param("Limit") token = self._get_param("PaginationToken") - users, token = cognitoidp_backends[self.region].list_users(user_pool_id, - limit=limit, - pagination_token=token) + users, token = cognitoidp_backends[self.region].list_users( + user_pool_id, limit=limit, pagination_token=token + ) response = {"Users": [user.to_json(extended=True) for user in users]} if token: response["PaginationToken"] = str(token) @@ -318,10 +312,7 @@ class CognitoIdpResponse(BaseResponse): auth_parameters = self._get_param("AuthParameters") auth_result = cognitoidp_backends[self.region].admin_initiate_auth( - user_pool_id, - client_id, - auth_flow, - auth_parameters, + user_pool_id, client_id, auth_flow, auth_parameters ) return json.dumps(auth_result) @@ -332,21 +323,15 @@ class CognitoIdpResponse(BaseResponse): challenge_name = self._get_param("ChallengeName") challenge_responses = self._get_param("ChallengeResponses") auth_result = cognitoidp_backends[self.region].respond_to_auth_challenge( - session, - client_id, - challenge_name, - challenge_responses, + session, client_id, challenge_name, challenge_responses ) return json.dumps(auth_result) def forgot_password(self): - return json.dumps({ - "CodeDeliveryDetails": { - "DeliveryMedium": "EMAIL", - "Destination": "...", - } - }) + return json.dumps( + {"CodeDeliveryDetails": {"DeliveryMedium": "EMAIL", "Destination": "..."}} + ) # This endpoint receives no authorization header, so if moto-server is listening # on localhost (doesn't get a region in the host header), it doesn't know what @@ -357,7 +342,9 @@ class CognitoIdpResponse(BaseResponse): username = self._get_param("Username") password = self._get_param("Password") region = find_region_by_value("client_id", client_id) - cognitoidp_backends[region].confirm_forgot_password(client_id, username, password) + cognitoidp_backends[region].confirm_forgot_password( + client_id, username, password + ) return "" # Ditto the comment on confirm_forgot_password. @@ -366,21 +353,26 @@ class CognitoIdpResponse(BaseResponse): previous_password = self._get_param("PreviousPassword") proposed_password = self._get_param("ProposedPassword") region = find_region_by_value("access_token", access_token) - cognitoidp_backends[region].change_password(access_token, previous_password, proposed_password) + cognitoidp_backends[region].change_password( + access_token, previous_password, proposed_password + ) return "" def admin_update_user_attributes(self): user_pool_id = self._get_param("UserPoolId") username = self._get_param("Username") attributes = self._get_param("UserAttributes") - cognitoidp_backends[self.region].admin_update_user_attributes(user_pool_id, username, attributes) + cognitoidp_backends[self.region].admin_update_user_attributes( + user_pool_id, username, attributes + ) return "" class CognitoIdpJsonWebKeyResponse(BaseResponse): - def __init__(self): - with open(os.path.join(os.path.dirname(__file__), "resources/jwks-public.json")) as f: + with open( + os.path.join(os.path.dirname(__file__), "resources/jwks-public.json") + ) as f: self.json_web_key = f.read() def serve_json_web_key(self, request, full_url, headers): diff --git a/moto/cognitoidp/urls.py b/moto/cognitoidp/urls.py index 77441ed5e..5d1dff1d0 100644 --- a/moto/cognitoidp/urls.py +++ b/moto/cognitoidp/urls.py @@ -1,11 +1,9 @@ from __future__ import unicode_literals from .responses import CognitoIdpResponse, CognitoIdpJsonWebKeyResponse -url_bases = [ - "https?://cognito-idp.(.+).amazonaws.com", -] +url_bases = ["https?://cognito-idp.(.+).amazonaws.com"] url_paths = { - '{0}/$': CognitoIdpResponse.dispatch, - '{0}//.well-known/jwks.json$': CognitoIdpJsonWebKeyResponse().serve_json_web_key, + "{0}/$": CognitoIdpResponse.dispatch, + "{0}//.well-known/jwks.json$": CognitoIdpJsonWebKeyResponse().serve_json_web_key, } diff --git a/moto/compat.py b/moto/compat.py index a92a5f67b..d7f5ab5e6 100644 --- a/moto/compat.py +++ b/moto/compat.py @@ -1,5 +1,5 @@ try: - from collections import OrderedDict # flake8: noqa + from collections import OrderedDict # noqa except ImportError: # python 2.6 or earlier, use backport - from ordereddict import OrderedDict # flake8: noqa + from ordereddict import OrderedDict # noqa diff --git a/moto/config/exceptions.py b/moto/config/exceptions.py index 611f3640c..4a0dc0d73 100644 --- a/moto/config/exceptions.py +++ b/moto/config/exceptions.py @@ -6,8 +6,12 @@ class NameTooLongException(JsonRESTError): code = 400 def __init__(self, name, location): - message = '1 validation error detected: Value \'{name}\' at \'{location}\' failed to satisfy' \ - ' constraint: Member must have length less than or equal to 256'.format(name=name, location=location) + message = ( + "1 validation error detected: Value '{name}' at '{location}' failed to satisfy" + " constraint: Member must have length less than or equal to 256".format( + name=name, location=location + ) + ) super(NameTooLongException, self).__init__("ValidationException", message) @@ -15,41 +19,54 @@ class InvalidConfigurationRecorderNameException(JsonRESTError): code = 400 def __init__(self, name): - message = 'The configuration recorder name \'{name}\' is not valid, blank string.'.format(name=name) - super(InvalidConfigurationRecorderNameException, self).__init__("InvalidConfigurationRecorderNameException", - message) + message = "The configuration recorder name '{name}' is not valid, blank string.".format( + name=name + ) + super(InvalidConfigurationRecorderNameException, self).__init__( + "InvalidConfigurationRecorderNameException", message + ) class MaxNumberOfConfigurationRecordersExceededException(JsonRESTError): code = 400 def __init__(self, name): - message = 'Failed to put configuration recorder \'{name}\' because the maximum number of ' \ - 'configuration recorders: 1 is reached.'.format(name=name) + message = ( + "Failed to put configuration recorder '{name}' because the maximum number of " + "configuration recorders: 1 is reached.".format(name=name) + ) super(MaxNumberOfConfigurationRecordersExceededException, self).__init__( - "MaxNumberOfConfigurationRecordersExceededException", message) + "MaxNumberOfConfigurationRecordersExceededException", message + ) class InvalidRecordingGroupException(JsonRESTError): code = 400 def __init__(self): - message = 'The recording group provided is not valid' - super(InvalidRecordingGroupException, self).__init__("InvalidRecordingGroupException", message) + message = "The recording group provided is not valid" + super(InvalidRecordingGroupException, self).__init__( + "InvalidRecordingGroupException", message + ) class InvalidResourceTypeException(JsonRESTError): code = 400 def __init__(self, bad_list, good_list): - message = '{num} validation error detected: Value \'{bad_list}\' at ' \ - '\'configurationRecorder.recordingGroup.resourceTypes\' failed to satisfy constraint: ' \ - 'Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]'.format( - num=len(bad_list), bad_list=bad_list, good_list=good_list) + message = ( + "{num} validation error detected: Value '{bad_list}' at " + "'configurationRecorder.recordingGroup.resourceTypes' failed to satisfy constraint: " + "Member must satisfy constraint: [Member must satisfy enum value set: {good_list}]".format( + num=len(bad_list), bad_list=bad_list, good_list=good_list + ) + ) # For PY2: message = str(message) - super(InvalidResourceTypeException, self).__init__("ValidationException", message) + super(InvalidResourceTypeException, self).__init__( + "ValidationException", message + ) class NoSuchConfigurationAggregatorException(JsonRESTError): @@ -57,36 +74,48 @@ class NoSuchConfigurationAggregatorException(JsonRESTError): def __init__(self, number=1): if number == 1: - message = 'The configuration aggregator does not exist. Check the configuration aggregator name and try again.' + message = "The configuration aggregator does not exist. Check the configuration aggregator name and try again." else: - message = 'At least one of the configuration aggregators does not exist. Check the configuration aggregator' \ - ' names and try again.' - super(NoSuchConfigurationAggregatorException, self).__init__("NoSuchConfigurationAggregatorException", message) + message = ( + "At least one of the configuration aggregators does not exist. Check the configuration aggregator" + " names and try again." + ) + super(NoSuchConfigurationAggregatorException, self).__init__( + "NoSuchConfigurationAggregatorException", message + ) class NoSuchConfigurationRecorderException(JsonRESTError): code = 400 def __init__(self, name): - message = 'Cannot find configuration recorder with the specified name \'{name}\'.'.format(name=name) - super(NoSuchConfigurationRecorderException, self).__init__("NoSuchConfigurationRecorderException", message) + message = "Cannot find configuration recorder with the specified name '{name}'.".format( + name=name + ) + super(NoSuchConfigurationRecorderException, self).__init__( + "NoSuchConfigurationRecorderException", message + ) class InvalidDeliveryChannelNameException(JsonRESTError): code = 400 def __init__(self, name): - message = 'The delivery channel name \'{name}\' is not valid, blank string.'.format(name=name) - super(InvalidDeliveryChannelNameException, self).__init__("InvalidDeliveryChannelNameException", - message) + message = "The delivery channel name '{name}' is not valid, blank string.".format( + name=name + ) + super(InvalidDeliveryChannelNameException, self).__init__( + "InvalidDeliveryChannelNameException", message + ) class NoSuchBucketException(JsonRESTError): """We are *only* validating that there is value that is not '' here.""" + code = 400 def __init__(self): - message = 'Cannot find a S3 bucket with an empty bucket name.' + message = "Cannot find a S3 bucket with an empty bucket name." super(NoSuchBucketException, self).__init__("NoSuchBucketException", message) @@ -94,89 +123,120 @@ class InvalidNextTokenException(JsonRESTError): code = 400 def __init__(self): - message = 'The nextToken provided is invalid' - super(InvalidNextTokenException, self).__init__("InvalidNextTokenException", message) + message = "The nextToken provided is invalid" + super(InvalidNextTokenException, self).__init__( + "InvalidNextTokenException", message + ) class InvalidS3KeyPrefixException(JsonRESTError): code = 400 def __init__(self): - message = 'The s3 key prefix \'\' is not valid, empty s3 key prefix.' - super(InvalidS3KeyPrefixException, self).__init__("InvalidS3KeyPrefixException", message) + message = "The s3 key prefix '' is not valid, empty s3 key prefix." + super(InvalidS3KeyPrefixException, self).__init__( + "InvalidS3KeyPrefixException", message + ) class InvalidSNSTopicARNException(JsonRESTError): """We are *only* validating that there is value that is not '' here.""" + code = 400 def __init__(self): - message = 'The sns topic arn \'\' is not valid.' - super(InvalidSNSTopicARNException, self).__init__("InvalidSNSTopicARNException", message) + message = "The sns topic arn '' is not valid." + super(InvalidSNSTopicARNException, self).__init__( + "InvalidSNSTopicARNException", message + ) class InvalidDeliveryFrequency(JsonRESTError): code = 400 def __init__(self, value, good_list): - message = '1 validation error detected: Value \'{value}\' at ' \ - '\'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency\' failed to satisfy ' \ - 'constraint: Member must satisfy enum value set: {good_list}'.format(value=value, good_list=good_list) - super(InvalidDeliveryFrequency, self).__init__("InvalidDeliveryFrequency", message) + message = ( + "1 validation error detected: Value '{value}' at " + "'deliveryChannel.configSnapshotDeliveryProperties.deliveryFrequency' failed to satisfy " + "constraint: Member must satisfy enum value set: {good_list}".format( + value=value, good_list=good_list + ) + ) + super(InvalidDeliveryFrequency, self).__init__( + "InvalidDeliveryFrequency", message + ) class MaxNumberOfDeliveryChannelsExceededException(JsonRESTError): code = 400 def __init__(self, name): - message = 'Failed to put delivery channel \'{name}\' because the maximum number of ' \ - 'delivery channels: 1 is reached.'.format(name=name) + message = ( + "Failed to put delivery channel '{name}' because the maximum number of " + "delivery channels: 1 is reached.".format(name=name) + ) super(MaxNumberOfDeliveryChannelsExceededException, self).__init__( - "MaxNumberOfDeliveryChannelsExceededException", message) + "MaxNumberOfDeliveryChannelsExceededException", message + ) class NoSuchDeliveryChannelException(JsonRESTError): code = 400 def __init__(self, name): - message = 'Cannot find delivery channel with specified name \'{name}\'.'.format(name=name) - super(NoSuchDeliveryChannelException, self).__init__("NoSuchDeliveryChannelException", message) + message = "Cannot find delivery channel with specified name '{name}'.".format( + name=name + ) + super(NoSuchDeliveryChannelException, self).__init__( + "NoSuchDeliveryChannelException", message + ) class NoAvailableConfigurationRecorderException(JsonRESTError): code = 400 def __init__(self): - message = 'Configuration recorder is not available to put delivery channel.' - super(NoAvailableConfigurationRecorderException, self).__init__("NoAvailableConfigurationRecorderException", - message) + message = "Configuration recorder is not available to put delivery channel." + super(NoAvailableConfigurationRecorderException, self).__init__( + "NoAvailableConfigurationRecorderException", message + ) class NoAvailableDeliveryChannelException(JsonRESTError): code = 400 def __init__(self): - message = 'Delivery channel is not available to start configuration recorder.' - super(NoAvailableDeliveryChannelException, self).__init__("NoAvailableDeliveryChannelException", message) + message = "Delivery channel is not available to start configuration recorder." + super(NoAvailableDeliveryChannelException, self).__init__( + "NoAvailableDeliveryChannelException", message + ) class LastDeliveryChannelDeleteFailedException(JsonRESTError): code = 400 def __init__(self, name): - message = 'Failed to delete last specified delivery channel with name \'{name}\', because there, ' \ - 'because there is a running configuration recorder.'.format(name=name) - super(LastDeliveryChannelDeleteFailedException, self).__init__("LastDeliveryChannelDeleteFailedException", message) + message = ( + "Failed to delete last specified delivery channel with name '{name}', because there, " + "because there is a running configuration recorder.".format(name=name) + ) + super(LastDeliveryChannelDeleteFailedException, self).__init__( + "LastDeliveryChannelDeleteFailedException", message + ) class TooManyAccountSources(JsonRESTError): code = 400 def __init__(self, length): - locations = ['com.amazonaws.xyz'] * length + locations = ["com.amazonaws.xyz"] * length - message = 'Value \'[{locations}]\' at \'accountAggregationSources\' failed to satisfy constraint: ' \ - 'Member must have length less than or equal to 1'.format(locations=', '.join(locations)) + message = ( + "Value '[{locations}]' at 'accountAggregationSources' failed to satisfy constraint: " + "Member must have length less than or equal to 1".format( + locations=", ".join(locations) + ) + ) super(TooManyAccountSources, self).__init__("ValidationException", message) @@ -185,16 +245,22 @@ class DuplicateTags(JsonRESTError): def __init__(self): super(DuplicateTags, self).__init__( - 'InvalidInput', 'Duplicate tag keys found. Please note that Tag keys are case insensitive.') + "InvalidInput", + "Duplicate tag keys found. Please note that Tag keys are case insensitive.", + ) class TagKeyTooBig(JsonRESTError): code = 400 - def __init__(self, tag, param='tags.X.member.key'): + def __init__(self, tag, param="tags.X.member.key"): super(TagKeyTooBig, self).__init__( - 'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy " - "constraint: Member must have length less than or equal to 128".format(tag, param)) + "ValidationException", + "1 validation error detected: Value '{}' at '{}' failed to satisfy " + "constraint: Member must have length less than or equal to 128".format( + tag, param + ), + ) class TagValueTooBig(JsonRESTError): @@ -202,76 +268,100 @@ class TagValueTooBig(JsonRESTError): def __init__(self, tag): super(TagValueTooBig, self).__init__( - 'ValidationException', "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy " - "constraint: Member must have length less than or equal to 256".format(tag)) + "ValidationException", + "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy " + "constraint: Member must have length less than or equal to 256".format(tag), + ) class InvalidParameterValueException(JsonRESTError): code = 400 def __init__(self, message): - super(InvalidParameterValueException, self).__init__('InvalidParameterValueException', message) + super(InvalidParameterValueException, self).__init__( + "InvalidParameterValueException", message + ) class InvalidTagCharacters(JsonRESTError): code = 400 - def __init__(self, tag, param='tags.X.member.key'): - message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(tag, param) - message += 'constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+' + def __init__(self, tag, param="tags.X.member.key"): + message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format( + tag, param + ) + message += "constraint: Member must satisfy regular expression pattern: [\\\\p{L}\\\\p{Z}\\\\p{N}_.:/=+\\\\-@]+" - super(InvalidTagCharacters, self).__init__('ValidationException', message) + super(InvalidTagCharacters, self).__init__("ValidationException", message) class TooManyTags(JsonRESTError): code = 400 - def __init__(self, tags, param='tags'): + def __init__(self, tags, param="tags"): super(TooManyTags, self).__init__( - 'ValidationException', "1 validation error detected: Value '{}' at '{}' failed to satisfy " - "constraint: Member must have length less than or equal to 50.".format(tags, param)) + "ValidationException", + "1 validation error detected: Value '{}' at '{}' failed to satisfy " + "constraint: Member must have length less than or equal to 50.".format( + tags, param + ), + ) class InvalidResourceParameters(JsonRESTError): code = 400 def __init__(self): - super(InvalidResourceParameters, self).__init__('ValidationException', 'Both Resource ID and Resource Name ' - 'cannot be specified in the request') + super(InvalidResourceParameters, self).__init__( + "ValidationException", + "Both Resource ID and Resource Name " "cannot be specified in the request", + ) class InvalidLimit(JsonRESTError): code = 400 def __init__(self, value): - super(InvalidLimit, self).__init__('ValidationException', 'Value \'{value}\' at \'limit\' failed to satisify constraint: Member' - ' must have value less than or equal to 100'.format(value=value)) + super(InvalidLimit, self).__init__( + "ValidationException", + "Value '{value}' at 'limit' failed to satisify constraint: Member" + " must have value less than or equal to 100".format(value=value), + ) class TooManyResourceIds(JsonRESTError): code = 400 def __init__(self): - super(TooManyResourceIds, self).__init__('ValidationException', "The specified list had more than 20 resource ID's. " - "It must have '20' or less items") + super(TooManyResourceIds, self).__init__( + "ValidationException", + "The specified list had more than 20 resource ID's. " + "It must have '20' or less items", + ) class ResourceNotDiscoveredException(JsonRESTError): code = 400 def __init__(self, type, resource): - super(ResourceNotDiscoveredException, self).__init__('ResourceNotDiscoveredException', - 'Resource {resource} of resourceType:{type} is unknown or has not been ' - 'discovered'.format(resource=resource, type=type)) + super(ResourceNotDiscoveredException, self).__init__( + "ResourceNotDiscoveredException", + "Resource {resource} of resourceType:{type} is unknown or has not been " + "discovered".format(resource=resource, type=type), + ) class TooManyResourceKeys(JsonRESTError): code = 400 def __init__(self, bad_list): - message = '1 validation error detected: Value \'{bad_list}\' at ' \ - '\'resourceKeys\' failed to satisfy constraint: ' \ - 'Member must have length less than or equal to 100'.format(bad_list=bad_list) + message = ( + "1 validation error detected: Value '{bad_list}' at " + "'resourceKeys' failed to satisfy constraint: " + "Member must have length less than or equal to 100".format( + bad_list=bad_list + ) + ) # For PY2: message = str(message) diff --git a/moto/config/models.py b/moto/config/models.py index 3b3544f2d..f608b759a 100644 --- a/moto/config/models.py +++ b/moto/config/models.py @@ -9,35 +9,55 @@ from datetime import datetime from boto3 import Session -from moto.config.exceptions import InvalidResourceTypeException, InvalidDeliveryFrequency, \ - InvalidConfigurationRecorderNameException, NameTooLongException, \ - MaxNumberOfConfigurationRecordersExceededException, InvalidRecordingGroupException, \ - NoSuchConfigurationRecorderException, NoAvailableConfigurationRecorderException, \ - InvalidDeliveryChannelNameException, NoSuchBucketException, InvalidS3KeyPrefixException, \ - InvalidSNSTopicARNException, MaxNumberOfDeliveryChannelsExceededException, NoAvailableDeliveryChannelException, \ - NoSuchDeliveryChannelException, LastDeliveryChannelDeleteFailedException, TagKeyTooBig, \ - TooManyTags, TagValueTooBig, TooManyAccountSources, InvalidParameterValueException, InvalidNextTokenException, \ - NoSuchConfigurationAggregatorException, InvalidTagCharacters, DuplicateTags, InvalidLimit, InvalidResourceParameters, \ - TooManyResourceIds, ResourceNotDiscoveredException, TooManyResourceKeys +from moto.config.exceptions import ( + InvalidResourceTypeException, + InvalidDeliveryFrequency, + InvalidConfigurationRecorderNameException, + NameTooLongException, + MaxNumberOfConfigurationRecordersExceededException, + InvalidRecordingGroupException, + NoSuchConfigurationRecorderException, + NoAvailableConfigurationRecorderException, + InvalidDeliveryChannelNameException, + NoSuchBucketException, + InvalidS3KeyPrefixException, + InvalidSNSTopicARNException, + MaxNumberOfDeliveryChannelsExceededException, + NoAvailableDeliveryChannelException, + NoSuchDeliveryChannelException, + LastDeliveryChannelDeleteFailedException, + TagKeyTooBig, + TooManyTags, + TagValueTooBig, + TooManyAccountSources, + InvalidParameterValueException, + InvalidNextTokenException, + NoSuchConfigurationAggregatorException, + InvalidTagCharacters, + DuplicateTags, + InvalidLimit, + InvalidResourceParameters, + TooManyResourceIds, + ResourceNotDiscoveredException, + TooManyResourceKeys, +) from moto.core import BaseBackend, BaseModel from moto.s3.config import s3_config_query -DEFAULT_ACCOUNT_ID = '123456789012' +DEFAULT_ACCOUNT_ID = "123456789012" POP_STRINGS = [ - 'capitalizeStart', - 'CapitalizeStart', - 'capitalizeArn', - 'CapitalizeArn', - 'capitalizeARN', - 'CapitalizeARN' + "capitalizeStart", + "CapitalizeStart", + "capitalizeArn", + "CapitalizeArn", + "capitalizeARN", + "CapitalizeARN", ] DEFAULT_PAGE_SIZE = 100 # Map the Config resource type to a backend: -RESOURCE_MAP = { - 'AWS::S3::Bucket': s3_config_query -} +RESOURCE_MAP = {"AWS::S3::Bucket": s3_config_query} def datetime2int(date): @@ -45,12 +65,14 @@ def datetime2int(date): def snake_to_camels(original, cap_start, cap_arn): - parts = original.split('_') + parts = original.split("_") - camel_cased = parts[0].lower() + ''.join(p.title() for p in parts[1:]) + camel_cased = parts[0].lower() + "".join(p.title() for p in parts[1:]) if cap_arn: - camel_cased = camel_cased.replace('Arn', 'ARN') # Some config services use 'ARN' instead of 'Arn' + camel_cased = camel_cased.replace( + "Arn", "ARN" + ) # Some config services use 'ARN' instead of 'Arn' if cap_start: camel_cased = camel_cased[0].upper() + camel_cased[1::] @@ -67,7 +89,7 @@ def random_string(): return "".join(chars) -def validate_tag_key(tag_key, exception_param='tags.X.member.key'): +def validate_tag_key(tag_key, exception_param="tags.X.member.key"): """Validates the tag key. :param tag_key: The tag key to check against. @@ -81,7 +103,7 @@ def validate_tag_key(tag_key, exception_param='tags.X.member.key'): # Validate that the tag key fits the proper Regex: # [\w\s_.:/=+\-@]+ SHOULD be the same as the Java regex on the AWS documentation: [\p{L}\p{Z}\p{N}_.:/=+\-@]+ - match = re.findall(r'[\w\s_.:/=+\-@]+', tag_key) + match = re.findall(r"[\w\s_.:/=+\-@]+", tag_key) # Kudos if you can come up with a better way of doing a global search :) if not len(match) or len(match[0]) < len(tag_key): raise InvalidTagCharacters(tag_key, param=exception_param) @@ -106,14 +128,14 @@ def validate_tags(tags): for tag in tags: # Validate the Key: - validate_tag_key(tag['Key']) - check_tag_duplicate(proper_tags, tag['Key']) + validate_tag_key(tag["Key"]) + check_tag_duplicate(proper_tags, tag["Key"]) # Validate the Value: - if len(tag['Value']) > 256: - raise TagValueTooBig(tag['Value']) + if len(tag["Value"]) > 256: + raise TagValueTooBig(tag["Value"]) - proper_tags[tag['Key']] = tag['Value'] + proper_tags[tag["Key"]] = tag["Value"] return proper_tags @@ -134,9 +156,17 @@ class ConfigEmptyDictable(BaseModel): for item, value in self.__dict__.items(): if value is not None: if isinstance(value, ConfigEmptyDictable): - data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value.to_dict() + data[ + snake_to_camels( + item, self.capitalize_start, self.capitalize_arn + ) + ] = value.to_dict() else: - data[snake_to_camels(item, self.capitalize_start, self.capitalize_arn)] = value + data[ + snake_to_camels( + item, self.capitalize_start, self.capitalize_arn + ) + ] = value # Cleanse the extra properties: for prop in POP_STRINGS: @@ -146,7 +176,6 @@ class ConfigEmptyDictable(BaseModel): class ConfigRecorderStatus(ConfigEmptyDictable): - def __init__(self, name): super(ConfigRecorderStatus, self).__init__() @@ -161,7 +190,7 @@ class ConfigRecorderStatus(ConfigEmptyDictable): def start(self): self.recording = True - self.last_status = 'PENDING' + self.last_status = "PENDING" self.last_start_time = datetime2int(datetime.utcnow()) self.last_status_change_time = datetime2int(datetime.utcnow()) @@ -172,7 +201,6 @@ class ConfigRecorderStatus(ConfigEmptyDictable): class ConfigDeliverySnapshotProperties(ConfigEmptyDictable): - def __init__(self, delivery_frequency): super(ConfigDeliverySnapshotProperties, self).__init__() @@ -180,8 +208,9 @@ class ConfigDeliverySnapshotProperties(ConfigEmptyDictable): class ConfigDeliveryChannel(ConfigEmptyDictable): - - def __init__(self, name, s3_bucket_name, prefix=None, sns_arn=None, snapshot_properties=None): + def __init__( + self, name, s3_bucket_name, prefix=None, sns_arn=None, snapshot_properties=None + ): super(ConfigDeliveryChannel, self).__init__() self.name = name @@ -192,8 +221,12 @@ class ConfigDeliveryChannel(ConfigEmptyDictable): class RecordingGroup(ConfigEmptyDictable): - - def __init__(self, all_supported=True, include_global_resource_types=False, resource_types=None): + def __init__( + self, + all_supported=True, + include_global_resource_types=False, + resource_types=None, + ): super(RecordingGroup, self).__init__() self.all_supported = all_supported @@ -202,8 +235,7 @@ class RecordingGroup(ConfigEmptyDictable): class ConfigRecorder(ConfigEmptyDictable): - - def __init__(self, role_arn, recording_group, name='default', status=None): + def __init__(self, role_arn, recording_group, name="default", status=None): super(ConfigRecorder, self).__init__() self.name = name @@ -217,18 +249,21 @@ class ConfigRecorder(ConfigEmptyDictable): class AccountAggregatorSource(ConfigEmptyDictable): - def __init__(self, account_ids, aws_regions=None, all_aws_regions=None): super(AccountAggregatorSource, self).__init__(capitalize_start=True) # Can't have both the regions and all_regions flag present -- also can't have them both missing: if aws_regions and all_aws_regions: - raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies ' - 'the use of all regions. You must choose one of these options.') + raise InvalidParameterValueException( + "Your configuration aggregator contains a list of regions and also specifies " + "the use of all regions. You must choose one of these options." + ) if not (aws_regions or all_aws_regions): - raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported ' - 'regions and try again.') + raise InvalidParameterValueException( + "Your request does not specify any regions. Select AWS Config-supported " + "regions and try again." + ) self.account_ids = account_ids self.aws_regions = aws_regions @@ -240,18 +275,23 @@ class AccountAggregatorSource(ConfigEmptyDictable): class OrganizationAggregationSource(ConfigEmptyDictable): - def __init__(self, role_arn, aws_regions=None, all_aws_regions=None): - super(OrganizationAggregationSource, self).__init__(capitalize_start=True, capitalize_arn=False) + super(OrganizationAggregationSource, self).__init__( + capitalize_start=True, capitalize_arn=False + ) # Can't have both the regions and all_regions flag present -- also can't have them both missing: if aws_regions and all_aws_regions: - raise InvalidParameterValueException('Your configuration aggregator contains a list of regions and also specifies ' - 'the use of all regions. You must choose one of these options.') + raise InvalidParameterValueException( + "Your configuration aggregator contains a list of regions and also specifies " + "the use of all regions. You must choose one of these options." + ) if not (aws_regions or all_aws_regions): - raise InvalidParameterValueException('Your request does not specify any regions. Select AWS Config-supported ' - 'regions and try again.') + raise InvalidParameterValueException( + "Your request does not specify any regions. Select AWS Config-supported " + "regions and try again." + ) self.role_arn = role_arn self.aws_regions = aws_regions @@ -263,15 +303,14 @@ class OrganizationAggregationSource(ConfigEmptyDictable): class ConfigAggregator(ConfigEmptyDictable): - def __init__(self, name, region, account_sources=None, org_source=None, tags=None): - super(ConfigAggregator, self).__init__(capitalize_start=True, capitalize_arn=False) + super(ConfigAggregator, self).__init__( + capitalize_start=True, capitalize_arn=False + ) self.configuration_aggregator_name = name - self.configuration_aggregator_arn = 'arn:aws:config:{region}:{id}:config-aggregator/config-aggregator-{random}'.format( - region=region, - id=DEFAULT_ACCOUNT_ID, - random=random_string() + self.configuration_aggregator_arn = "arn:aws:config:{region}:{id}:config-aggregator/config-aggregator-{random}".format( + region=region, id=DEFAULT_ACCOUNT_ID, random=random_string() ) self.account_aggregation_sources = account_sources self.organization_aggregation_source = org_source @@ -287,7 +326,9 @@ class ConfigAggregator(ConfigEmptyDictable): # Override the account aggregation sources if present: if self.account_aggregation_sources: - result['AccountAggregationSources'] = [a.to_dict() for a in self.account_aggregation_sources] + result["AccountAggregationSources"] = [ + a.to_dict() for a in self.account_aggregation_sources + ] # Tags are listed in the list_tags_for_resource API call ... not implementing yet -- please feel free to! # if self.tags: @@ -297,15 +338,22 @@ class ConfigAggregator(ConfigEmptyDictable): class ConfigAggregationAuthorization(ConfigEmptyDictable): + def __init__( + self, current_region, authorized_account_id, authorized_aws_region, tags=None + ): + super(ConfigAggregationAuthorization, self).__init__( + capitalize_start=True, capitalize_arn=False + ) - def __init__(self, current_region, authorized_account_id, authorized_aws_region, tags=None): - super(ConfigAggregationAuthorization, self).__init__(capitalize_start=True, capitalize_arn=False) - - self.aggregation_authorization_arn = 'arn:aws:config:{region}:{id}:aggregation-authorization/' \ - '{auth_account}/{auth_region}'.format(region=current_region, - id=DEFAULT_ACCOUNT_ID, - auth_account=authorized_account_id, - auth_region=authorized_aws_region) + self.aggregation_authorization_arn = ( + "arn:aws:config:{region}:{id}:aggregation-authorization/" + "{auth_account}/{auth_region}".format( + region=current_region, + id=DEFAULT_ACCOUNT_ID, + auth_account=authorized_account_id, + auth_region=authorized_aws_region, + ) + ) self.authorized_account_id = authorized_account_id self.authorized_aws_region = authorized_aws_region self.creation_time = datetime2int(datetime.utcnow()) @@ -315,7 +363,6 @@ class ConfigAggregationAuthorization(ConfigEmptyDictable): class ConfigBackend(BaseBackend): - def __init__(self): self.recorders = {} self.delivery_channels = {} @@ -325,9 +372,11 @@ class ConfigBackend(BaseBackend): @staticmethod def _validate_resource_types(resource_list): # Load the service file: - resource_package = 'botocore' - resource_path = '/'.join(('data', 'config', '2014-11-12', 'service-2.json')) - config_schema = json.loads(pkg_resources.resource_string(resource_package, resource_path)) + resource_package = "botocore" + resource_path = "/".join(("data", "config", "2014-11-12", "service-2.json")) + config_schema = json.loads( + pkg_resources.resource_string(resource_package, resource_path) + ) # Verify that each entry exists in the supported list: bad_list = [] @@ -335,72 +384,114 @@ class ConfigBackend(BaseBackend): # For PY2: r_str = str(resource) - if r_str not in config_schema['shapes']['ResourceType']['enum']: + if r_str not in config_schema["shapes"]["ResourceType"]["enum"]: bad_list.append(r_str) if bad_list: - raise InvalidResourceTypeException(bad_list, config_schema['shapes']['ResourceType']['enum']) + raise InvalidResourceTypeException( + bad_list, config_schema["shapes"]["ResourceType"]["enum"] + ) @staticmethod def _validate_delivery_snapshot_properties(properties): # Load the service file: - resource_package = 'botocore' - resource_path = '/'.join(('data', 'config', '2014-11-12', 'service-2.json')) - conifg_schema = json.loads(pkg_resources.resource_string(resource_package, resource_path)) + resource_package = "botocore" + resource_path = "/".join(("data", "config", "2014-11-12", "service-2.json")) + conifg_schema = json.loads( + pkg_resources.resource_string(resource_package, resource_path) + ) # Verify that the deliveryFrequency is set to an acceptable value: - if properties.get('deliveryFrequency', None) not in \ - conifg_schema['shapes']['MaximumExecutionFrequency']['enum']: - raise InvalidDeliveryFrequency(properties.get('deliveryFrequency', None), - conifg_schema['shapes']['MaximumExecutionFrequency']['enum']) + if ( + properties.get("deliveryFrequency", None) + not in conifg_schema["shapes"]["MaximumExecutionFrequency"]["enum"] + ): + raise InvalidDeliveryFrequency( + properties.get("deliveryFrequency", None), + conifg_schema["shapes"]["MaximumExecutionFrequency"]["enum"], + ) def put_configuration_aggregator(self, config_aggregator, region): # Validate the name: - if len(config_aggregator['ConfigurationAggregatorName']) > 256: - raise NameTooLongException(config_aggregator['ConfigurationAggregatorName'], 'configurationAggregatorName') + if len(config_aggregator["ConfigurationAggregatorName"]) > 256: + raise NameTooLongException( + config_aggregator["ConfigurationAggregatorName"], + "configurationAggregatorName", + ) account_sources = None org_source = None # Tag validation: - tags = validate_tags(config_aggregator.get('Tags', [])) + tags = validate_tags(config_aggregator.get("Tags", [])) # Exception if both AccountAggregationSources and OrganizationAggregationSource are supplied: - if config_aggregator.get('AccountAggregationSources') and config_aggregator.get('OrganizationAggregationSource'): - raise InvalidParameterValueException('The configuration aggregator cannot be created because your request contains both the' - ' AccountAggregationSource and the OrganizationAggregationSource. Include only ' - 'one aggregation source and try again.') + if config_aggregator.get("AccountAggregationSources") and config_aggregator.get( + "OrganizationAggregationSource" + ): + raise InvalidParameterValueException( + "The configuration aggregator cannot be created because your request contains both the" + " AccountAggregationSource and the OrganizationAggregationSource. Include only " + "one aggregation source and try again." + ) # If neither are supplied: - if not config_aggregator.get('AccountAggregationSources') and not config_aggregator.get('OrganizationAggregationSource'): - raise InvalidParameterValueException('The configuration aggregator cannot be created because your request is missing either ' - 'the AccountAggregationSource or the OrganizationAggregationSource. Include the ' - 'appropriate aggregation source and try again.') + if not config_aggregator.get( + "AccountAggregationSources" + ) and not config_aggregator.get("OrganizationAggregationSource"): + raise InvalidParameterValueException( + "The configuration aggregator cannot be created because your request is missing either " + "the AccountAggregationSource or the OrganizationAggregationSource. Include the " + "appropriate aggregation source and try again." + ) - if config_aggregator.get('AccountAggregationSources'): + if config_aggregator.get("AccountAggregationSources"): # Currently, only 1 account aggregation source can be set: - if len(config_aggregator['AccountAggregationSources']) > 1: - raise TooManyAccountSources(len(config_aggregator['AccountAggregationSources'])) + if len(config_aggregator["AccountAggregationSources"]) > 1: + raise TooManyAccountSources( + len(config_aggregator["AccountAggregationSources"]) + ) account_sources = [] - for a in config_aggregator['AccountAggregationSources']: - account_sources.append(AccountAggregatorSource(a['AccountIds'], aws_regions=a.get('AwsRegions'), - all_aws_regions=a.get('AllAwsRegions'))) + for a in config_aggregator["AccountAggregationSources"]: + account_sources.append( + AccountAggregatorSource( + a["AccountIds"], + aws_regions=a.get("AwsRegions"), + all_aws_regions=a.get("AllAwsRegions"), + ) + ) else: - org_source = OrganizationAggregationSource(config_aggregator['OrganizationAggregationSource']['RoleArn'], - aws_regions=config_aggregator['OrganizationAggregationSource'].get('AwsRegions'), - all_aws_regions=config_aggregator['OrganizationAggregationSource'].get( - 'AllAwsRegions')) + org_source = OrganizationAggregationSource( + config_aggregator["OrganizationAggregationSource"]["RoleArn"], + aws_regions=config_aggregator["OrganizationAggregationSource"].get( + "AwsRegions" + ), + all_aws_regions=config_aggregator["OrganizationAggregationSource"].get( + "AllAwsRegions" + ), + ) # Grab the existing one if it exists and update it: - if not self.config_aggregators.get(config_aggregator['ConfigurationAggregatorName']): - aggregator = ConfigAggregator(config_aggregator['ConfigurationAggregatorName'], region, account_sources=account_sources, - org_source=org_source, tags=tags) - self.config_aggregators[config_aggregator['ConfigurationAggregatorName']] = aggregator + if not self.config_aggregators.get( + config_aggregator["ConfigurationAggregatorName"] + ): + aggregator = ConfigAggregator( + config_aggregator["ConfigurationAggregatorName"], + region, + account_sources=account_sources, + org_source=org_source, + tags=tags, + ) + self.config_aggregators[ + config_aggregator["ConfigurationAggregatorName"] + ] = aggregator else: - aggregator = self.config_aggregators[config_aggregator['ConfigurationAggregatorName']] + aggregator = self.config_aggregators[ + config_aggregator["ConfigurationAggregatorName"] + ] aggregator.tags = tags aggregator.account_aggregation_sources = account_sources aggregator.organization_aggregation_source = org_source @@ -411,7 +502,7 @@ class ConfigBackend(BaseBackend): def describe_configuration_aggregators(self, names, token, limit): limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit agg_list = [] - result = {'ConfigurationAggregators': []} + result = {"ConfigurationAggregators": []} if names: for name in names: @@ -441,11 +532,13 @@ class ConfigBackend(BaseBackend): start = sorted_aggregators.index(token) # Get the list of items to collect: - agg_list = sorted_aggregators[start:(start + limit)] - result['ConfigurationAggregators'] = [self.config_aggregators[agg].to_dict() for agg in agg_list] + agg_list = sorted_aggregators[start : (start + limit)] + result["ConfigurationAggregators"] = [ + self.config_aggregators[agg].to_dict() for agg in agg_list + ] if len(sorted_aggregators) > (start + limit): - result['NextToken'] = sorted_aggregators[start + limit] + result["NextToken"] = sorted_aggregators[start + limit] return result @@ -455,16 +548,22 @@ class ConfigBackend(BaseBackend): del self.config_aggregators[config_aggregator] - def put_aggregation_authorization(self, current_region, authorized_account, authorized_region, tags): + def put_aggregation_authorization( + self, current_region, authorized_account, authorized_region, tags + ): # Tag validation: tags = validate_tags(tags or []) # Does this already exist? - key = '{}/{}'.format(authorized_account, authorized_region) + key = "{}/{}".format(authorized_account, authorized_region) agg_auth = self.aggregation_authorizations.get(key) if not agg_auth: - agg_auth = ConfigAggregationAuthorization(current_region, authorized_account, authorized_region, tags=tags) - self.aggregation_authorizations['{}/{}'.format(authorized_account, authorized_region)] = agg_auth + agg_auth = ConfigAggregationAuthorization( + current_region, authorized_account, authorized_region, tags=tags + ) + self.aggregation_authorizations[ + "{}/{}".format(authorized_account, authorized_region) + ] = agg_auth else: # Only update the tags: agg_auth.tags = tags @@ -473,7 +572,7 @@ class ConfigBackend(BaseBackend): def describe_aggregation_authorizations(self, token, limit): limit = DEFAULT_PAGE_SIZE if not limit or limit < 0 else limit - result = {'AggregationAuthorizations': []} + result = {"AggregationAuthorizations": []} if not self.aggregation_authorizations: return result @@ -492,70 +591,82 @@ class ConfigBackend(BaseBackend): start = sorted_authorizations.index(token) # Get the list of items to collect: - auth_list = sorted_authorizations[start:(start + limit)] - result['AggregationAuthorizations'] = [self.aggregation_authorizations[auth].to_dict() for auth in auth_list] + auth_list = sorted_authorizations[start : (start + limit)] + result["AggregationAuthorizations"] = [ + self.aggregation_authorizations[auth].to_dict() for auth in auth_list + ] if len(sorted_authorizations) > (start + limit): - result['NextToken'] = sorted_authorizations[start + limit] + result["NextToken"] = sorted_authorizations[start + limit] return result def delete_aggregation_authorization(self, authorized_account, authorized_region): # This will always return a 200 -- regardless if there is or isn't an existing # aggregation authorization. - key = '{}/{}'.format(authorized_account, authorized_region) + key = "{}/{}".format(authorized_account, authorized_region) self.aggregation_authorizations.pop(key, None) def put_configuration_recorder(self, config_recorder): # Validate the name: - if not config_recorder.get('name'): - raise InvalidConfigurationRecorderNameException(config_recorder.get('name')) - if len(config_recorder.get('name')) > 256: - raise NameTooLongException(config_recorder.get('name'), 'configurationRecorder.name') + if not config_recorder.get("name"): + raise InvalidConfigurationRecorderNameException(config_recorder.get("name")) + if len(config_recorder.get("name")) > 256: + raise NameTooLongException( + config_recorder.get("name"), "configurationRecorder.name" + ) # We're going to assume that the passed in Role ARN is correct. # Config currently only allows 1 configuration recorder for an account: - if len(self.recorders) == 1 and not self.recorders.get(config_recorder['name']): - raise MaxNumberOfConfigurationRecordersExceededException(config_recorder['name']) + if len(self.recorders) == 1 and not self.recorders.get(config_recorder["name"]): + raise MaxNumberOfConfigurationRecordersExceededException( + config_recorder["name"] + ) # Is this updating an existing one? recorder_status = None - if self.recorders.get(config_recorder['name']): - recorder_status = self.recorders[config_recorder['name']].status + if self.recorders.get(config_recorder["name"]): + recorder_status = self.recorders[config_recorder["name"]].status # Validate the Recording Group: - if config_recorder.get('recordingGroup') is None: + if config_recorder.get("recordingGroup") is None: recording_group = RecordingGroup() else: - rg = config_recorder['recordingGroup'] + rg = config_recorder["recordingGroup"] # If an empty dict is passed in, then bad: if not rg: raise InvalidRecordingGroupException() # Can't have both the resource types specified and the other flags as True. - if rg.get('resourceTypes') and ( - rg.get('allSupported', False) or - rg.get('includeGlobalResourceTypes', False)): + if rg.get("resourceTypes") and ( + rg.get("allSupported", False) + or rg.get("includeGlobalResourceTypes", False) + ): raise InvalidRecordingGroupException() # Must supply resourceTypes if 'allSupported' is not supplied: - if not rg.get('allSupported') and not rg.get('resourceTypes'): + if not rg.get("allSupported") and not rg.get("resourceTypes"): raise InvalidRecordingGroupException() # Validate that the list provided is correct: - self._validate_resource_types(rg.get('resourceTypes', [])) + self._validate_resource_types(rg.get("resourceTypes", [])) recording_group = RecordingGroup( - all_supported=rg.get('allSupported', True), - include_global_resource_types=rg.get('includeGlobalResourceTypes', False), - resource_types=rg.get('resourceTypes', []) + all_supported=rg.get("allSupported", True), + include_global_resource_types=rg.get( + "includeGlobalResourceTypes", False + ), + resource_types=rg.get("resourceTypes", []), ) - self.recorders[config_recorder['name']] = \ - ConfigRecorder(config_recorder['roleARN'], recording_group, name=config_recorder['name'], - status=recorder_status) + self.recorders[config_recorder["name"]] = ConfigRecorder( + config_recorder["roleARN"], + recording_group, + name=config_recorder["name"], + status=recorder_status, + ) def describe_configuration_recorders(self, recorder_names): recorders = [] @@ -597,43 +708,54 @@ class ConfigBackend(BaseBackend): raise NoAvailableConfigurationRecorderException() # Validate the name: - if not delivery_channel.get('name'): - raise InvalidDeliveryChannelNameException(delivery_channel.get('name')) - if len(delivery_channel.get('name')) > 256: - raise NameTooLongException(delivery_channel.get('name'), 'deliveryChannel.name') + if not delivery_channel.get("name"): + raise InvalidDeliveryChannelNameException(delivery_channel.get("name")) + if len(delivery_channel.get("name")) > 256: + raise NameTooLongException( + delivery_channel.get("name"), "deliveryChannel.name" + ) # We are going to assume that the bucket exists -- but will verify if the bucket provided is blank: - if not delivery_channel.get('s3BucketName'): + if not delivery_channel.get("s3BucketName"): raise NoSuchBucketException() # We are going to assume that the bucket has the correct policy attached to it. We are only going to verify # if the prefix provided is not an empty string: - if delivery_channel.get('s3KeyPrefix', None) == '': + if delivery_channel.get("s3KeyPrefix", None) == "": raise InvalidS3KeyPrefixException() # Ditto for SNS -- Only going to assume that the ARN provided is not an empty string: - if delivery_channel.get('snsTopicARN', None) == '': + if delivery_channel.get("snsTopicARN", None) == "": raise InvalidSNSTopicARNException() # Config currently only allows 1 delivery channel for an account: - if len(self.delivery_channels) == 1 and not self.delivery_channels.get(delivery_channel['name']): - raise MaxNumberOfDeliveryChannelsExceededException(delivery_channel['name']) + if len(self.delivery_channels) == 1 and not self.delivery_channels.get( + delivery_channel["name"] + ): + raise MaxNumberOfDeliveryChannelsExceededException(delivery_channel["name"]) - if not delivery_channel.get('configSnapshotDeliveryProperties'): + if not delivery_channel.get("configSnapshotDeliveryProperties"): dp = None else: # Validate the config snapshot delivery properties: - self._validate_delivery_snapshot_properties(delivery_channel['configSnapshotDeliveryProperties']) + self._validate_delivery_snapshot_properties( + delivery_channel["configSnapshotDeliveryProperties"] + ) dp = ConfigDeliverySnapshotProperties( - delivery_channel['configSnapshotDeliveryProperties']['deliveryFrequency']) + delivery_channel["configSnapshotDeliveryProperties"][ + "deliveryFrequency" + ] + ) - self.delivery_channels[delivery_channel['name']] = \ - ConfigDeliveryChannel(delivery_channel['name'], delivery_channel['s3BucketName'], - prefix=delivery_channel.get('s3KeyPrefix', None), - sns_arn=delivery_channel.get('snsTopicARN', None), - snapshot_properties=dp) + self.delivery_channels[delivery_channel["name"]] = ConfigDeliveryChannel( + delivery_channel["name"], + delivery_channel["s3BucketName"], + prefix=delivery_channel.get("s3KeyPrefix", None), + sns_arn=delivery_channel.get("snsTopicARN", None), + snapshot_properties=dp, + ) def describe_delivery_channels(self, channel_names): channels = [] @@ -687,7 +809,15 @@ class ConfigBackend(BaseBackend): del self.delivery_channels[channel_name] - def list_discovered_resources(self, resource_type, backend_region, resource_ids, resource_name, limit, next_token): + def list_discovered_resources( + self, + resource_type, + backend_region, + resource_ids, + resource_name, + limit, + next_token, + ): """This will query against the mocked AWS Config (non-aggregated) listing function that must exist for the resource backend. :param resource_type: @@ -716,33 +846,45 @@ class ConfigBackend(BaseBackend): # call upon the resource type's Config Query class to retrieve the list of resources that match the criteria: if RESOURCE_MAP.get(resource_type, {}): # Is this a global resource type? -- if so, re-write the region to 'global': - backend_query_region = backend_region # Always provide the backend this request arrived from. - if RESOURCE_MAP[resource_type].backends.get('global'): - backend_region = 'global' + backend_query_region = ( + backend_region # Always provide the backend this request arrived from. + ) + if RESOURCE_MAP[resource_type].backends.get("global"): + backend_region = "global" # For non-aggregated queries, the we only care about the backend_region. Need to verify that moto has implemented # the region for the given backend: if RESOURCE_MAP[resource_type].backends.get(backend_region): # Fetch the resources for the backend's region: - identifiers, new_token = \ - RESOURCE_MAP[resource_type].list_config_service_resources(resource_ids, resource_name, limit, next_token, - backend_region=backend_query_region) + identifiers, new_token = RESOURCE_MAP[ + resource_type + ].list_config_service_resources( + resource_ids, + resource_name, + limit, + next_token, + backend_region=backend_query_region, + ) - result = {'resourceIdentifiers': [ - { - 'resourceType': identifier['type'], - 'resourceId': identifier['id'], - 'resourceName': identifier['name'] - } - for identifier in identifiers] + result = { + "resourceIdentifiers": [ + { + "resourceType": identifier["type"], + "resourceId": identifier["id"], + "resourceName": identifier["name"], + } + for identifier in identifiers + ] } if new_token: - result['nextToken'] = new_token + result["nextToken"] = new_token return result - def list_aggregate_discovered_resources(self, aggregator_name, resource_type, filters, limit, next_token): + def list_aggregate_discovered_resources( + self, aggregator_name, resource_type, filters, limit, next_token + ): """This will query against the mocked AWS Config listing function that must exist for the resource backend. As far a moto goes -- the only real difference between this function and the `list_discovered_resources` function is that @@ -770,27 +912,35 @@ class ConfigBackend(BaseBackend): # call upon the resource type's Config Query class to retrieve the list of resources that match the criteria: if RESOURCE_MAP.get(resource_type, {}): # We only care about a filter's Region, Resource Name, and Resource ID: - resource_region = filters.get('Region') - resource_id = [filters['ResourceId']] if filters.get('ResourceId') else None - resource_name = filters.get('ResourceName') + resource_region = filters.get("Region") + resource_id = [filters["ResourceId"]] if filters.get("ResourceId") else None + resource_name = filters.get("ResourceName") - identifiers, new_token = \ - RESOURCE_MAP[resource_type].list_config_service_resources(resource_id, resource_name, limit, next_token, - resource_region=resource_region) + identifiers, new_token = RESOURCE_MAP[ + resource_type + ].list_config_service_resources( + resource_id, + resource_name, + limit, + next_token, + resource_region=resource_region, + ) - result = {'ResourceIdentifiers': [ - { - 'SourceAccountId': DEFAULT_ACCOUNT_ID, - 'SourceRegion': identifier['region'], - 'ResourceType': identifier['type'], - 'ResourceId': identifier['id'], - 'ResourceName': identifier['name'] - } - for identifier in identifiers] + result = { + "ResourceIdentifiers": [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": identifier["region"], + "ResourceType": identifier["type"], + "ResourceId": identifier["id"], + "ResourceName": identifier["name"], + } + for identifier in identifiers + ] } if new_token: - result['NextToken'] = new_token + result["NextToken"] = new_token return result @@ -806,22 +956,26 @@ class ConfigBackend(BaseBackend): raise ResourceNotDiscoveredException(resource_type, id) # Is the resource type global? - backend_query_region = backend_region # Always provide the backend this request arrived from. - if RESOURCE_MAP[resource_type].backends.get('global'): - backend_region = 'global' + backend_query_region = ( + backend_region # Always provide the backend this request arrived from. + ) + if RESOURCE_MAP[resource_type].backends.get("global"): + backend_region = "global" # If the backend region isn't implemented then we won't find the item: if not RESOURCE_MAP[resource_type].backends.get(backend_region): raise ResourceNotDiscoveredException(resource_type, id) # Get the item: - item = RESOURCE_MAP[resource_type].get_config_resource(id, backend_region=backend_query_region) + item = RESOURCE_MAP[resource_type].get_config_resource( + id, backend_region=backend_query_region + ) if not item: raise ResourceNotDiscoveredException(resource_type, id) - item['accountId'] = DEFAULT_ACCOUNT_ID + item["accountId"] = DEFAULT_ACCOUNT_ID - return {'configurationItems': [item]} + return {"configurationItems": [item]} def batch_get_resource_config(self, resource_keys, backend_region): """Returns the configuration of an item in the AWS Config format of the resource for the current regional backend. @@ -831,37 +985,50 @@ class ConfigBackend(BaseBackend): """ # Can't have more than 100 items if len(resource_keys) > 100: - raise TooManyResourceKeys(['com.amazonaws.starling.dove.ResourceKey@12345'] * len(resource_keys)) + raise TooManyResourceKeys( + ["com.amazonaws.starling.dove.ResourceKey@12345"] * len(resource_keys) + ) results = [] for resource in resource_keys: # Does the resource type exist? - if not RESOURCE_MAP.get(resource['resourceType']): + if not RESOURCE_MAP.get(resource["resourceType"]): # Not found so skip. continue # Is the resource type global? config_backend_region = backend_region - backend_query_region = backend_region # Always provide the backend this request arrived from. - if RESOURCE_MAP[resource['resourceType']].backends.get('global'): - config_backend_region = 'global' + backend_query_region = ( + backend_region # Always provide the backend this request arrived from. + ) + if RESOURCE_MAP[resource["resourceType"]].backends.get("global"): + config_backend_region = "global" # If the backend region isn't implemented then we won't find the item: - if not RESOURCE_MAP[resource['resourceType']].backends.get(config_backend_region): + if not RESOURCE_MAP[resource["resourceType"]].backends.get( + config_backend_region + ): continue # Get the item: - item = RESOURCE_MAP[resource['resourceType']].get_config_resource(resource['resourceId'], backend_region=backend_query_region) + item = RESOURCE_MAP[resource["resourceType"]].get_config_resource( + resource["resourceId"], backend_region=backend_query_region + ) if not item: continue - item['accountId'] = DEFAULT_ACCOUNT_ID + item["accountId"] = DEFAULT_ACCOUNT_ID results.append(item) - return {'baseConfigurationItems': results, 'unprocessedResourceKeys': []} # At this time, moto is not adding unprocessed items. + return { + "baseConfigurationItems": results, + "unprocessedResourceKeys": [], + } # At this time, moto is not adding unprocessed items. - def batch_get_aggregate_resource_config(self, aggregator_name, resource_identifiers): + def batch_get_aggregate_resource_config( + self, aggregator_name, resource_identifiers + ): """Returns the configuration of an item in the AWS Config format of the resource for the current regional backend. As far a moto goes -- the only real difference between this function and the `batch_get_resource_config` function is that @@ -874,15 +1041,18 @@ class ConfigBackend(BaseBackend): # Can't have more than 100 items if len(resource_identifiers) > 100: - raise TooManyResourceKeys(['com.amazonaws.starling.dove.AggregateResourceIdentifier@12345'] * len(resource_identifiers)) + raise TooManyResourceKeys( + ["com.amazonaws.starling.dove.AggregateResourceIdentifier@12345"] + * len(resource_identifiers) + ) found = [] not_found = [] for identifier in resource_identifiers: - resource_type = identifier['ResourceType'] - resource_region = identifier['SourceRegion'] - resource_id = identifier['ResourceId'] - resource_name = identifier.get('ResourceName', None) + resource_type = identifier["ResourceType"] + resource_region = identifier["SourceRegion"] + resource_id = identifier["ResourceId"] + resource_name = identifier.get("ResourceName", None) # Does the resource type exist? if not RESOURCE_MAP.get(resource_type): @@ -890,23 +1060,29 @@ class ConfigBackend(BaseBackend): continue # Get the item: - item = RESOURCE_MAP[resource_type].get_config_resource(resource_id, resource_name=resource_name, - resource_region=resource_region) + item = RESOURCE_MAP[resource_type].get_config_resource( + resource_id, + resource_name=resource_name, + resource_region=resource_region, + ) if not item: not_found.append(identifier) continue - item['accountId'] = DEFAULT_ACCOUNT_ID + item["accountId"] = DEFAULT_ACCOUNT_ID # The 'tags' field is not included in aggregate results for some reason... - item.pop('tags', None) + item.pop("tags", None) found.append(item) - return {'BaseConfigurationItems': found, 'UnprocessedResourceIdentifiers': not_found} + return { + "BaseConfigurationItems": found, + "UnprocessedResourceIdentifiers": not_found, + } config_backends = {} boto3_session = Session() -for region in boto3_session.get_available_regions('config'): +for region in boto3_session.get_available_regions("config"): config_backends[region] = ConfigBackend() diff --git a/moto/config/responses.py b/moto/config/responses.py index f10d48b71..e977945c9 100644 --- a/moto/config/responses.py +++ b/moto/config/responses.py @@ -4,116 +4,150 @@ from .models import config_backends class ConfigResponse(BaseResponse): - @property def config_backend(self): return config_backends[self.region] def put_configuration_recorder(self): - self.config_backend.put_configuration_recorder(self._get_param('ConfigurationRecorder')) + self.config_backend.put_configuration_recorder( + self._get_param("ConfigurationRecorder") + ) return "" def put_configuration_aggregator(self): - aggregator = self.config_backend.put_configuration_aggregator(json.loads(self.body), self.region) - schema = {'ConfigurationAggregator': aggregator} + aggregator = self.config_backend.put_configuration_aggregator( + json.loads(self.body), self.region + ) + schema = {"ConfigurationAggregator": aggregator} return json.dumps(schema) def describe_configuration_aggregators(self): - aggregators = self.config_backend.describe_configuration_aggregators(self._get_param('ConfigurationAggregatorNames'), - self._get_param('NextToken'), - self._get_param('Limit')) + aggregators = self.config_backend.describe_configuration_aggregators( + self._get_param("ConfigurationAggregatorNames"), + self._get_param("NextToken"), + self._get_param("Limit"), + ) return json.dumps(aggregators) def delete_configuration_aggregator(self): - self.config_backend.delete_configuration_aggregator(self._get_param('ConfigurationAggregatorName')) + self.config_backend.delete_configuration_aggregator( + self._get_param("ConfigurationAggregatorName") + ) return "" def put_aggregation_authorization(self): - agg_auth = self.config_backend.put_aggregation_authorization(self.region, - self._get_param('AuthorizedAccountId'), - self._get_param('AuthorizedAwsRegion'), - self._get_param('Tags')) - schema = {'AggregationAuthorization': agg_auth} + agg_auth = self.config_backend.put_aggregation_authorization( + self.region, + self._get_param("AuthorizedAccountId"), + self._get_param("AuthorizedAwsRegion"), + self._get_param("Tags"), + ) + schema = {"AggregationAuthorization": agg_auth} return json.dumps(schema) def describe_aggregation_authorizations(self): - authorizations = self.config_backend.describe_aggregation_authorizations(self._get_param('NextToken'), self._get_param('Limit')) + authorizations = self.config_backend.describe_aggregation_authorizations( + self._get_param("NextToken"), self._get_param("Limit") + ) return json.dumps(authorizations) def delete_aggregation_authorization(self): - self.config_backend.delete_aggregation_authorization(self._get_param('AuthorizedAccountId'), self._get_param('AuthorizedAwsRegion')) + self.config_backend.delete_aggregation_authorization( + self._get_param("AuthorizedAccountId"), + self._get_param("AuthorizedAwsRegion"), + ) return "" def describe_configuration_recorders(self): - recorders = self.config_backend.describe_configuration_recorders(self._get_param('ConfigurationRecorderNames')) - schema = {'ConfigurationRecorders': recorders} + recorders = self.config_backend.describe_configuration_recorders( + self._get_param("ConfigurationRecorderNames") + ) + schema = {"ConfigurationRecorders": recorders} return json.dumps(schema) def describe_configuration_recorder_status(self): recorder_statuses = self.config_backend.describe_configuration_recorder_status( - self._get_param('ConfigurationRecorderNames')) - schema = {'ConfigurationRecordersStatus': recorder_statuses} + self._get_param("ConfigurationRecorderNames") + ) + schema = {"ConfigurationRecordersStatus": recorder_statuses} return json.dumps(schema) def put_delivery_channel(self): - self.config_backend.put_delivery_channel(self._get_param('DeliveryChannel')) + self.config_backend.put_delivery_channel(self._get_param("DeliveryChannel")) return "" def describe_delivery_channels(self): - delivery_channels = self.config_backend.describe_delivery_channels(self._get_param('DeliveryChannelNames')) - schema = {'DeliveryChannels': delivery_channels} + delivery_channels = self.config_backend.describe_delivery_channels( + self._get_param("DeliveryChannelNames") + ) + schema = {"DeliveryChannels": delivery_channels} return json.dumps(schema) def describe_delivery_channel_status(self): raise NotImplementedError() def delete_delivery_channel(self): - self.config_backend.delete_delivery_channel(self._get_param('DeliveryChannelName')) + self.config_backend.delete_delivery_channel( + self._get_param("DeliveryChannelName") + ) return "" def delete_configuration_recorder(self): - self.config_backend.delete_configuration_recorder(self._get_param('ConfigurationRecorderName')) + self.config_backend.delete_configuration_recorder( + self._get_param("ConfigurationRecorderName") + ) return "" def start_configuration_recorder(self): - self.config_backend.start_configuration_recorder(self._get_param('ConfigurationRecorderName')) + self.config_backend.start_configuration_recorder( + self._get_param("ConfigurationRecorderName") + ) return "" def stop_configuration_recorder(self): - self.config_backend.stop_configuration_recorder(self._get_param('ConfigurationRecorderName')) + self.config_backend.stop_configuration_recorder( + self._get_param("ConfigurationRecorderName") + ) return "" def list_discovered_resources(self): - schema = self.config_backend.list_discovered_resources(self._get_param('resourceType'), - self.region, - self._get_param('resourceIds'), - self._get_param('resourceName'), - self._get_param('limit'), - self._get_param('nextToken')) + schema = self.config_backend.list_discovered_resources( + self._get_param("resourceType"), + self.region, + self._get_param("resourceIds"), + self._get_param("resourceName"), + self._get_param("limit"), + self._get_param("nextToken"), + ) return json.dumps(schema) def list_aggregate_discovered_resources(self): - schema = self.config_backend.list_aggregate_discovered_resources(self._get_param('ConfigurationAggregatorName'), - self._get_param('ResourceType'), - self._get_param('Filters'), - self._get_param('Limit'), - self._get_param('NextToken')) + schema = self.config_backend.list_aggregate_discovered_resources( + self._get_param("ConfigurationAggregatorName"), + self._get_param("ResourceType"), + self._get_param("Filters"), + self._get_param("Limit"), + self._get_param("NextToken"), + ) return json.dumps(schema) def get_resource_config_history(self): - schema = self.config_backend.get_resource_config_history(self._get_param('resourceType'), - self._get_param('resourceId'), - self.region) + schema = self.config_backend.get_resource_config_history( + self._get_param("resourceType"), self._get_param("resourceId"), self.region + ) return json.dumps(schema) def batch_get_resource_config(self): - schema = self.config_backend.batch_get_resource_config(self._get_param('resourceKeys'), - self.region) + schema = self.config_backend.batch_get_resource_config( + self._get_param("resourceKeys"), self.region + ) return json.dumps(schema) def batch_get_aggregate_resource_config(self): - schema = self.config_backend.batch_get_aggregate_resource_config(self._get_param('ConfigurationAggregatorName'), - self._get_param('ResourceIdentifiers')) + schema = self.config_backend.batch_get_aggregate_resource_config( + self._get_param("ConfigurationAggregatorName"), + self._get_param("ResourceIdentifiers"), + ) return json.dumps(schema) diff --git a/moto/config/urls.py b/moto/config/urls.py index fd7b6969f..62cf34a52 100644 --- a/moto/config/urls.py +++ b/moto/config/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import ConfigResponse -url_bases = [ - "https?://config.(.+).amazonaws.com", -] +url_bases = ["https?://config.(.+).amazonaws.com"] -url_paths = { - '{0}/$': ConfigResponse.dispatch, -} +url_paths = {"{0}/$": ConfigResponse.dispatch} diff --git a/moto/core/__init__.py b/moto/core/__init__.py index 801e675df..4a4dfdfb6 100644 --- a/moto/core/__init__.py +++ b/moto/core/__init__.py @@ -1,7 +1,9 @@ from __future__ import unicode_literals -from .models import BaseModel, BaseBackend, moto_api_backend # flake8: noqa +from .models import BaseModel, BaseBackend, moto_api_backend # noqa from .responses import ActionAuthenticatorMixin moto_api_backends = {"global": moto_api_backend} -set_initial_no_auth_action_count = ActionAuthenticatorMixin.set_initial_no_auth_action_count +set_initial_no_auth_action_count = ( + ActionAuthenticatorMixin.set_initial_no_auth_action_count +) diff --git a/moto/core/access_control.py b/moto/core/access_control.py index 3fb11eebd..9991063f9 100644 --- a/moto/core/access_control.py +++ b/moto/core/access_control.py @@ -26,7 +26,12 @@ from six import string_types from moto.iam.models import ACCOUNT_ID, Policy from moto.iam import iam_backend -from moto.core.exceptions import SignatureDoesNotMatchError, AccessDeniedError, InvalidClientTokenIdError, AuthFailureError +from moto.core.exceptions import ( + SignatureDoesNotMatchError, + AccessDeniedError, + InvalidClientTokenIdError, + AuthFailureError, +) from moto.s3.exceptions import ( BucketAccessDeniedError, S3AccessDeniedError, @@ -35,7 +40,7 @@ from moto.s3.exceptions import ( S3InvalidAccessKeyIdError, BucketInvalidAccessKeyIdError, BucketSignatureDoesNotMatchError, - S3SignatureDoesNotMatchError + S3SignatureDoesNotMatchError, ) from moto.sts import sts_backend @@ -50,9 +55,8 @@ def create_access_key(access_key_id, headers): class IAMUserAccessKey(object): - def __init__(self, access_key_id, headers): - iam_users = iam_backend.list_users('/', None, None) + iam_users = iam_backend.list_users("/", None, None) for iam_user in iam_users: for access_key in iam_user.access_keys: if access_key.access_key_id == access_key_id: @@ -67,8 +71,7 @@ class IAMUserAccessKey(object): @property def arn(self): return "arn:aws:iam::{account_id}:user/{iam_user_name}".format( - account_id=ACCOUNT_ID, - iam_user_name=self._owner_user_name + account_id=ACCOUNT_ID, iam_user_name=self._owner_user_name ) def create_credentials(self): @@ -79,27 +82,34 @@ class IAMUserAccessKey(object): inline_policy_names = iam_backend.list_user_policies(self._owner_user_name) for inline_policy_name in inline_policy_names: - inline_policy = iam_backend.get_user_policy(self._owner_user_name, inline_policy_name) + inline_policy = iam_backend.get_user_policy( + self._owner_user_name, inline_policy_name + ) user_policies.append(inline_policy) - attached_policies, _ = iam_backend.list_attached_user_policies(self._owner_user_name) + attached_policies, _ = iam_backend.list_attached_user_policies( + self._owner_user_name + ) user_policies += attached_policies user_groups = iam_backend.get_groups_for_user(self._owner_user_name) for user_group in user_groups: inline_group_policy_names = iam_backend.list_group_policies(user_group.name) for inline_group_policy_name in inline_group_policy_names: - inline_user_group_policy = iam_backend.get_group_policy(user_group.name, inline_group_policy_name) + inline_user_group_policy = iam_backend.get_group_policy( + user_group.name, inline_group_policy_name + ) user_policies.append(inline_user_group_policy) - attached_group_policies, _ = iam_backend.list_attached_group_policies(user_group.name) + attached_group_policies, _ = iam_backend.list_attached_group_policies( + user_group.name + ) user_policies += attached_group_policies return user_policies class AssumedRoleAccessKey(object): - def __init__(self, access_key_id, headers): for assumed_role in sts_backend.assumed_roles: if assumed_role.access_key_id == access_key_id: @@ -118,28 +128,33 @@ class AssumedRoleAccessKey(object): return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( account_id=ACCOUNT_ID, role_name=self._owner_role_name, - session_name=self._session_name + session_name=self._session_name, ) def create_credentials(self): - return Credentials(self._access_key_id, self._secret_access_key, self._session_token) + return Credentials( + self._access_key_id, self._secret_access_key, self._session_token + ) def collect_policies(self): role_policies = [] inline_policy_names = iam_backend.list_role_policies(self._owner_role_name) for inline_policy_name in inline_policy_names: - _, inline_policy = iam_backend.get_role_policy(self._owner_role_name, inline_policy_name) + _, inline_policy = iam_backend.get_role_policy( + self._owner_role_name, inline_policy_name + ) role_policies.append(inline_policy) - attached_policies, _ = iam_backend.list_attached_role_policies(self._owner_role_name) + attached_policies, _ = iam_backend.list_attached_role_policies( + self._owner_role_name + ) role_policies += attached_policies return role_policies class CreateAccessKeyFailure(Exception): - def __init__(self, reason, *args): super(CreateAccessKeyFailure, self).__init__(*args) self.reason = reason @@ -147,32 +162,54 @@ class CreateAccessKeyFailure(Exception): @six.add_metaclass(ABCMeta) class IAMRequestBase(object): - def __init__(self, method, path, data, headers): - log.debug("Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format( - class_name=self.__class__.__name__, method=method, path=path, data=data, headers=headers)) + log.debug( + "Creating {class_name} with method={method}, path={path}, data={data}, headers={headers}".format( + class_name=self.__class__.__name__, + method=method, + path=path, + data=data, + headers=headers, + ) + ) self._method = method self._path = path self._data = data self._headers = headers - credential_scope = self._get_string_between('Credential=', ',', self._headers['Authorization']) - credential_data = credential_scope.split('/') + credential_scope = self._get_string_between( + "Credential=", ",", self._headers["Authorization"] + ) + credential_data = credential_scope.split("/") self._region = credential_data[2] self._service = credential_data[3] - self._action = self._service + ":" + (self._data["Action"][0] if isinstance(self._data["Action"], list) else self._data["Action"]) + self._action = ( + self._service + + ":" + + ( + self._data["Action"][0] + if isinstance(self._data["Action"], list) + else self._data["Action"] + ) + ) try: - self._access_key = create_access_key(access_key_id=credential_data[0], headers=headers) + self._access_key = create_access_key( + access_key_id=credential_data[0], headers=headers + ) except CreateAccessKeyFailure as e: self._raise_invalid_access_key(e.reason) def check_signature(self): - original_signature = self._get_string_between('Signature=', ',', self._headers['Authorization']) + original_signature = self._get_string_between( + "Signature=", ",", self._headers["Authorization"] + ) calculated_signature = self._calculate_signature() if original_signature != calculated_signature: self._raise_signature_does_not_match() def check_action_permitted(self): - if self._action == 'sts:GetCallerIdentity': # always allowed, even if there's an explicit Deny for it + if ( + self._action == "sts:GetCallerIdentity" + ): # always allowed, even if there's an explicit Deny for it return True policies = self._access_key.collect_policies() @@ -213,10 +250,14 @@ class IAMRequestBase(object): return headers def _create_aws_request(self): - signed_headers = self._get_string_between('SignedHeaders=', ',', self._headers['Authorization']).split(';') + signed_headers = self._get_string_between( + "SignedHeaders=", ",", self._headers["Authorization"] + ).split(";") headers = self._create_headers_for_aws_request(signed_headers, self._headers) - request = AWSRequest(method=self._method, url=self._path, data=self._data, headers=headers) - request.context['timestamp'] = headers['X-Amz-Date'] + request = AWSRequest( + method=self._method, url=self._path, data=self._data, headers=headers + ) + request.context["timestamp"] = headers["X-Amz-Date"] return request @@ -234,7 +275,6 @@ class IAMRequestBase(object): class IAMRequest(IAMRequestBase): - def _raise_signature_does_not_match(self): if self._service == "ec2": raise AuthFailureError() @@ -251,14 +291,10 @@ class IAMRequest(IAMRequestBase): return SigV4Auth(credentials, self._service, self._region) def _raise_access_denied(self): - raise AccessDeniedError( - user_arn=self._access_key.arn, - action=self._action - ) + raise AccessDeniedError(user_arn=self._access_key.arn, action=self._action) class S3IAMRequest(IAMRequestBase): - def _raise_signature_does_not_match(self): if "BucketName" in self._data: raise BucketSignatureDoesNotMatchError(bucket=self._data["BucketName"]) @@ -288,10 +324,13 @@ class S3IAMRequest(IAMRequestBase): class IAMPolicy(object): - def __init__(self, policy): if isinstance(policy, Policy): - default_version = next(policy_version for policy_version in policy.versions if policy_version.is_default) + default_version = next( + policy_version + for policy_version in policy.versions + if policy_version.is_default + ) policy_document = default_version.document elif isinstance(policy, string_types): policy_document = policy @@ -321,7 +360,6 @@ class IAMPolicy(object): class IAMPolicyStatement(object): - def __init__(self, statement): self._statement = statement diff --git a/moto/core/exceptions.py b/moto/core/exceptions.py index 4f5662bcf..ea91eda63 100644 --- a/moto/core/exceptions.py +++ b/moto/core/exceptions.py @@ -4,7 +4,7 @@ from werkzeug.exceptions import HTTPException from jinja2 import DictLoader, Environment -SINGLE_ERROR_RESPONSE = u""" +SINGLE_ERROR_RESPONSE = """ {{error_type}} {{message}} @@ -13,7 +13,7 @@ SINGLE_ERROR_RESPONSE = u""" """ -ERROR_RESPONSE = u""" +ERROR_RESPONSE = """ @@ -26,7 +26,7 @@ ERROR_RESPONSE = u""" """ -ERROR_JSON_RESPONSE = u"""{ +ERROR_JSON_RESPONSE = """{ "message": "{{message}}", "__type": "{{error_type}}" } @@ -37,18 +37,19 @@ class RESTError(HTTPException): code = 400 templates = { - 'single_error': SINGLE_ERROR_RESPONSE, - 'error': ERROR_RESPONSE, - 'error_json': ERROR_JSON_RESPONSE, + "single_error": SINGLE_ERROR_RESPONSE, + "error": ERROR_RESPONSE, + "error_json": ERROR_JSON_RESPONSE, } - def __init__(self, error_type, message, template='error', **kwargs): + def __init__(self, error_type, message, template="error", **kwargs): super(RESTError, self).__init__() env = Environment(loader=DictLoader(self.templates)) self.error_type = error_type self.message = message self.description = env.get_template(template).render( - error_type=error_type, message=message, **kwargs) + error_type=error_type, message=message, **kwargs + ) class DryRunClientError(RESTError): @@ -56,12 +57,11 @@ class DryRunClientError(RESTError): class JsonRESTError(RESTError): - def __init__(self, error_type, message, template='error_json', **kwargs): - super(JsonRESTError, self).__init__( - error_type, message, template, **kwargs) + def __init__(self, error_type, message, template="error_json", **kwargs): + super(JsonRESTError, self).__init__(error_type, message, template, **kwargs) def get_headers(self, *args, **kwargs): - return [('Content-Type', 'application/json')] + return [("Content-Type", "application/json")] def get_body(self, *args, **kwargs): return self.description @@ -72,8 +72,9 @@ class SignatureDoesNotMatchError(RESTError): def __init__(self): super(SignatureDoesNotMatchError, self).__init__( - 'SignatureDoesNotMatch', - "The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.") + "SignatureDoesNotMatch", + "The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.", + ) class InvalidClientTokenIdError(RESTError): @@ -81,8 +82,9 @@ class InvalidClientTokenIdError(RESTError): def __init__(self): super(InvalidClientTokenIdError, self).__init__( - 'InvalidClientTokenId', - "The security token included in the request is invalid.") + "InvalidClientTokenId", + "The security token included in the request is invalid.", + ) class AccessDeniedError(RESTError): @@ -90,11 +92,11 @@ class AccessDeniedError(RESTError): def __init__(self, user_arn, action): super(AccessDeniedError, self).__init__( - 'AccessDenied', + "AccessDenied", "User: {user_arn} is not authorized to perform: {operation}".format( - user_arn=user_arn, - operation=action - )) + user_arn=user_arn, operation=action + ), + ) class AuthFailureError(RESTError): @@ -102,13 +104,17 @@ class AuthFailureError(RESTError): def __init__(self): super(AuthFailureError, self).__init__( - 'AuthFailure', - "AWS was not able to validate the provided access credentials") + "AuthFailure", + "AWS was not able to validate the provided access credentials", + ) class InvalidNextTokenException(JsonRESTError): """For AWS Config resource listing. This will be used by many different resource types, and so it is in moto.core.""" + code = 400 def __init__(self): - super(InvalidNextTokenException, self).__init__('InvalidNextTokenException', 'The nextToken provided is invalid') + super(InvalidNextTokenException, self).__init__( + "InvalidNextTokenException", "The nextToken provided is invalid" + ) diff --git a/moto/core/models.py b/moto/core/models.py index e0ff5ba42..e0eae5858 100644 --- a/moto/core/models.py +++ b/moto/core/models.py @@ -31,15 +31,19 @@ class BaseMockAWS(object): self.backends_for_urls = {} from moto.backends import BACKENDS + default_backends = { - "instance_metadata": BACKENDS['instance_metadata']['global'], - "moto_api": BACKENDS['moto_api']['global'], + "instance_metadata": BACKENDS["instance_metadata"]["global"], + "moto_api": BACKENDS["moto_api"]["global"], } self.backends_for_urls.update(self.backends) self.backends_for_urls.update(default_backends) # "Mock" the AWS credentials as they can't be mocked in Botocore currently - FAKE_KEYS = {"AWS_ACCESS_KEY_ID": "foobar_key", "AWS_SECRET_ACCESS_KEY": "foobar_secret"} + FAKE_KEYS = { + "AWS_ACCESS_KEY_ID": "foobar_key", + "AWS_SECRET_ACCESS_KEY": "foobar_secret", + } self.env_variables_mocks = mock.patch.dict(os.environ, FAKE_KEYS) if self.__class__.nested_count == 0: @@ -72,7 +76,7 @@ class BaseMockAWS(object): self.__class__.nested_count -= 1 if self.__class__.nested_count < 0: - raise RuntimeError('Called stop() before start().') + raise RuntimeError("Called stop() before start().") if self.__class__.nested_count == 0: self.disable_patching() @@ -85,6 +89,7 @@ class BaseMockAWS(object): finally: self.stop() return result + functools.update_wrapper(wrapper, func) wrapper.__wrapped__ = func return wrapper @@ -122,7 +127,6 @@ class BaseMockAWS(object): class HttprettyMockAWS(BaseMockAWS): - def reset(self): HTTPretty.reset() @@ -144,18 +148,26 @@ class HttprettyMockAWS(BaseMockAWS): HTTPretty.reset() -RESPONSES_METHODS = [responses.GET, responses.DELETE, responses.HEAD, - responses.OPTIONS, responses.PATCH, responses.POST, responses.PUT] +RESPONSES_METHODS = [ + responses.GET, + responses.DELETE, + responses.HEAD, + responses.OPTIONS, + responses.PATCH, + responses.POST, + responses.PUT, +] class CallbackResponse(responses.CallbackResponse): - ''' + """ Need to subclass so we can change a couple things - ''' + """ + def get_response(self, request): - ''' + """ Need to override this so we can pass decode_content=False - ''' + """ headers = self.get_headers() result = self.callback(request) @@ -177,17 +189,17 @@ class CallbackResponse(responses.CallbackResponse): ) def _url_matches(self, url, other, match_querystring=False): - ''' + """ Need to override this so we can fix querystrings breaking regex matching - ''' + """ if not match_querystring: - other = other.split('?', 1)[0] + other = other.split("?", 1)[0] if responses._is_string(url): if responses._has_unicode(url): url = responses._clean_unicode(url) if not isinstance(other, six.text_type): - other = other.encode('ascii').decode('utf8') + other = other.encode("ascii").decode("utf8") return self._url_matches_strict(url, other) elif isinstance(url, responses.Pattern) and url.match(other): return True @@ -195,22 +207,23 @@ class CallbackResponse(responses.CallbackResponse): return False -botocore_mock = responses.RequestsMock(assert_all_requests_are_fired=False, target='botocore.vendored.requests.adapters.HTTPAdapter.send') +botocore_mock = responses.RequestsMock( + assert_all_requests_are_fired=False, + target="botocore.vendored.requests.adapters.HTTPAdapter.send", +) responses_mock = responses._default_mock # Add passthrough to allow any other requests to work # Since this uses .startswith, it applies to http and https requests. responses_mock.add_passthru("http") -BOTOCORE_HTTP_METHODS = [ - 'GET', 'DELETE', 'HEAD', 'OPTIONS', 'PATCH', 'POST', 'PUT' -] +BOTOCORE_HTTP_METHODS = ["GET", "DELETE", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"] class MockRawResponse(BytesIO): def __init__(self, input): if isinstance(input, six.text_type): - input = input.encode('utf-8') + input = input.encode("utf-8") super(MockRawResponse, self).__init__(input) def stream(self, **kwargs): @@ -241,7 +254,7 @@ class BotocoreStubber(object): found_index = None matchers = self.methods.get(request.method) - base_url = request.url.split('?', 1)[0] + base_url = request.url.split("?", 1)[0] for i, (pattern, callback) in enumerate(matchers): if pattern.match(base_url): if found_index is None: @@ -254,8 +267,10 @@ class BotocoreStubber(object): if response_callback is not None: for header, value in request.headers.items(): if isinstance(value, six.binary_type): - request.headers[header] = value.decode('utf-8') - status, headers, body = response_callback(request, request.url, request.headers) + request.headers[header] = value.decode("utf-8") + status, headers, body = response_callback( + request, request.url, request.headers + ) body = MockRawResponse(body) response = AWSResponse(request.url, status, headers, body) @@ -263,7 +278,7 @@ class BotocoreStubber(object): botocore_stubber = BotocoreStubber() -BUILTIN_HANDLERS.append(('before-send', botocore_stubber)) +BUILTIN_HANDLERS.append(("before-send", botocore_stubber)) def not_implemented_callback(request): @@ -287,7 +302,9 @@ class BotocoreEventMockAWS(BaseMockAWS): pattern = re.compile(key) botocore_stubber.register_response(method, pattern, value) - if not hasattr(responses_mock, '_patcher') or not hasattr(responses_mock._patcher, 'target'): + if not hasattr(responses_mock, "_patcher") or not hasattr( + responses_mock._patcher, "target" + ): responses_mock.start() for method in RESPONSES_METHODS: @@ -336,9 +353,9 @@ MockAWS = BotocoreEventMockAWS class ServerModeMockAWS(BaseMockAWS): - def reset(self): import requests + requests.post("http://localhost:5000/moto-api/reset") def enable_patching(self): @@ -350,13 +367,13 @@ class ServerModeMockAWS(BaseMockAWS): import mock def fake_boto3_client(*args, **kwargs): - if 'endpoint_url' not in kwargs: - kwargs['endpoint_url'] = "http://localhost:5000" + if "endpoint_url" not in kwargs: + kwargs["endpoint_url"] = "http://localhost:5000" return real_boto3_client(*args, **kwargs) def fake_boto3_resource(*args, **kwargs): - if 'endpoint_url' not in kwargs: - kwargs['endpoint_url'] = "http://localhost:5000" + if "endpoint_url" not in kwargs: + kwargs["endpoint_url"] = "http://localhost:5000" return real_boto3_resource(*args, **kwargs) def fake_httplib_send_output(self, message_body=None, *args, **kwargs): @@ -364,7 +381,7 @@ class ServerModeMockAWS(BaseMockAWS): bytes_buffer = [] for chunk in mixed_buffer: if isinstance(chunk, six.text_type): - bytes_buffer.append(chunk.encode('utf-8')) + bytes_buffer.append(chunk.encode("utf-8")) else: bytes_buffer.append(chunk) msg = b"\r\n".join(bytes_buffer) @@ -385,10 +402,12 @@ class ServerModeMockAWS(BaseMockAWS): if message_body is not None: self.send(message_body) - self._client_patcher = mock.patch('boto3.client', fake_boto3_client) - self._resource_patcher = mock.patch('boto3.resource', fake_boto3_resource) + self._client_patcher = mock.patch("boto3.client", fake_boto3_client) + self._resource_patcher = mock.patch("boto3.resource", fake_boto3_resource) if six.PY2: - self._httplib_patcher = mock.patch('httplib.HTTPConnection._send_output', fake_httplib_send_output) + self._httplib_patcher = mock.patch( + "httplib.HTTPConnection._send_output", fake_httplib_send_output + ) self._client_patcher.start() self._resource_patcher.start() @@ -404,7 +423,6 @@ class ServerModeMockAWS(BaseMockAWS): class Model(type): - def __new__(self, clsname, bases, namespace): cls = super(Model, self).__new__(self, clsname, bases, namespace) cls.__models__ = {} @@ -419,9 +437,11 @@ class Model(type): @staticmethod def prop(model_name): """ decorator to mark a class method as returning model values """ + def dec(f): f.__returns_model__ = model_name return f + return dec @@ -431,7 +451,7 @@ model_data = defaultdict(dict) class InstanceTrackerMeta(type): def __new__(meta, name, bases, dct): cls = super(InstanceTrackerMeta, meta).__new__(meta, name, bases, dct) - if name == 'BaseModel': + if name == "BaseModel": return cls service = cls.__module__.split(".")[1] @@ -450,7 +470,6 @@ class BaseModel(object): class BaseBackend(object): - def _reset_model_refs(self): # Remove all references to the models stored for service, models in model_data.items(): @@ -466,8 +485,9 @@ class BaseBackend(object): def _url_module(self): backend_module = self.__class__.__module__ backend_urls_module_name = backend_module.replace("models", "urls") - backend_urls_module = __import__(backend_urls_module_name, fromlist=[ - 'url_bases', 'url_paths']) + backend_urls_module = __import__( + backend_urls_module_name, fromlist=["url_bases", "url_paths"] + ) return backend_urls_module @property @@ -523,9 +543,9 @@ class BaseBackend(object): def decorator(self, func=None): if settings.TEST_SERVER_MODE: - mocked_backend = ServerModeMockAWS({'global': self}) + mocked_backend = ServerModeMockAWS({"global": self}) else: - mocked_backend = MockAWS({'global': self}) + mocked_backend = MockAWS({"global": self}) if func: return mocked_backend(func) @@ -534,9 +554,9 @@ class BaseBackend(object): def deprecated_decorator(self, func=None): if func: - return HttprettyMockAWS({'global': self})(func) + return HttprettyMockAWS({"global": self})(func) else: - return HttprettyMockAWS({'global': self}) + return HttprettyMockAWS({"global": self}) # def list_config_service_resources(self, resource_ids, resource_name, limit, next_token): # """For AWS Config. This will list all of the resources of the given type and optional resource name and region""" @@ -544,12 +564,19 @@ class BaseBackend(object): class ConfigQueryModel(object): - def __init__(self, backends): """Inits based on the resource type's backends (1 for each region if applicable)""" self.backends = backends - def list_config_service_resources(self, resource_ids, resource_name, limit, next_token, backend_region=None, resource_region=None): + def list_config_service_resources( + self, + resource_ids, + resource_name, + limit, + next_token, + backend_region=None, + resource_region=None, + ): """For AWS Config. This will list all of the resources of the given type and optional resource name and region. This supports both aggregated and non-aggregated listing. The following notes the difference: @@ -593,7 +620,9 @@ class ConfigQueryModel(object): """ raise NotImplementedError() - def get_config_resource(self, resource_id, resource_name=None, backend_region=None, resource_region=None): + def get_config_resource( + self, resource_id, resource_name=None, backend_region=None, resource_region=None + ): """For AWS Config. This will query the backend for the specific resource type configuration. This supports both aggregated, and non-aggregated fetching -- for batched fetching -- the Config batching requests @@ -644,9 +673,9 @@ class deprecated_base_decorator(base_decorator): class MotoAPIBackend(BaseBackend): - def reset(self): from moto.backends import BACKENDS + for name, backends in BACKENDS.items(): if name == "moto_api": continue diff --git a/moto/core/responses.py b/moto/core/responses.py index 213fa278c..bf4af902a 100644 --- a/moto/core/responses.py +++ b/moto/core/responses.py @@ -40,7 +40,7 @@ def _decode_dict(d): newkey = [] for k in key: if isinstance(k, six.binary_type): - newkey.append(k.decode('utf-8')) + newkey.append(k.decode("utf-8")) else: newkey.append(k) else: @@ -52,7 +52,7 @@ def _decode_dict(d): newvalue = [] for v in value: if isinstance(v, six.binary_type): - newvalue.append(v.decode('utf-8')) + newvalue.append(v.decode("utf-8")) else: newvalue.append(v) else: @@ -83,12 +83,15 @@ class DynamicDictLoader(DictLoader): class _TemplateEnvironmentMixin(object): + LEFT_PATTERN = re.compile(r"[\s\n]+<") + RIGHT_PATTERN = re.compile(r">[\s\n]+") def __init__(self): super(_TemplateEnvironmentMixin, self).__init__() self.loader = DynamicDictLoader({}) self.environment = Environment( - loader=self.loader, autoescape=self.should_autoescape) + loader=self.loader, autoescape=self.should_autoescape + ) @property def should_autoescape(self): @@ -101,9 +104,16 @@ class _TemplateEnvironmentMixin(object): def response_template(self, source): template_id = id(source) if not self.contains_template(template_id): - self.loader.update({template_id: source}) - self.environment = Environment(loader=self.loader, autoescape=self.should_autoescape, trim_blocks=True, - lstrip_blocks=True) + collapsed = re.sub( + self.RIGHT_PATTERN, ">", re.sub(self.LEFT_PATTERN, "<", source) + ) + self.loader.update({template_id: collapsed}) + self.environment = Environment( + loader=self.loader, + autoescape=self.should_autoescape, + trim_blocks=True, + lstrip_blocks=True, + ) return self.environment.get_template(template_id) @@ -112,8 +122,13 @@ class ActionAuthenticatorMixin(object): request_count = 0 def _authenticate_and_authorize_action(self, iam_request_cls): - if ActionAuthenticatorMixin.request_count >= settings.INITIAL_NO_AUTH_ACTION_COUNT: - iam_request = iam_request_cls(method=self.method, path=self.path, data=self.data, headers=self.headers) + if ( + ActionAuthenticatorMixin.request_count + >= settings.INITIAL_NO_AUTH_ACTION_COUNT + ): + iam_request = iam_request_cls( + method=self.method, path=self.path, data=self.data, headers=self.headers + ) iam_request.check_signature() iam_request.check_action_permitted() else: @@ -130,10 +145,17 @@ class ActionAuthenticatorMixin(object): def decorator(function): def wrapper(*args, **kwargs): if settings.TEST_SERVER_MODE: - response = requests.post("http://localhost:5000/moto-api/reset-auth", data=str(initial_no_auth_action_count).encode()) - original_initial_no_auth_action_count = response.json()['PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT'] + response = requests.post( + "http://localhost:5000/moto-api/reset-auth", + data=str(initial_no_auth_action_count).encode(), + ) + original_initial_no_auth_action_count = response.json()[ + "PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT" + ] else: - original_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT + original_initial_no_auth_action_count = ( + settings.INITIAL_NO_AUTH_ACTION_COUNT + ) original_request_count = ActionAuthenticatorMixin.request_count settings.INITIAL_NO_AUTH_ACTION_COUNT = initial_no_auth_action_count ActionAuthenticatorMixin.request_count = 0 @@ -141,10 +163,15 @@ class ActionAuthenticatorMixin(object): result = function(*args, **kwargs) finally: if settings.TEST_SERVER_MODE: - requests.post("http://localhost:5000/moto-api/reset-auth", data=str(original_initial_no_auth_action_count).encode()) + requests.post( + "http://localhost:5000/moto-api/reset-auth", + data=str(original_initial_no_auth_action_count).encode(), + ) else: ActionAuthenticatorMixin.request_count = original_request_count - settings.INITIAL_NO_AUTH_ACTION_COUNT = original_initial_no_auth_action_count + settings.INITIAL_NO_AUTH_ACTION_COUNT = ( + original_initial_no_auth_action_count + ) return result functools.update_wrapper(wrapper, function) @@ -156,11 +183,13 @@ class ActionAuthenticatorMixin(object): class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): - default_region = 'us-east-1' + default_region = "us-east-1" # to extract region, use [^.] - region_regex = re.compile(r'\.(?P[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com') - param_list_regex = re.compile(r'(.*)\.(\d+)\.') - access_key_regex = re.compile(r'AWS.*(?P(?[a-z]{2}-[a-z]+-\d{1})\.amazonaws\.com") + param_list_regex = re.compile(r"(.*)\.(\d+)\.") + access_key_regex = re.compile( + r"AWS.*(?P(? '^/cars/.*/drivers/.*/drive$' """ - def _convert(elem, is_last): - if not re.match('^{.*}$', elem): - return elem - name = elem.replace('{', '').replace('}', '') - if is_last: - return '(?P<%s>[^/]*)' % name - return '(?P<%s>.*)' % name - elems = uri.split('/') + def _convert(elem, is_last): + if not re.match("^{.*}$", elem): + return elem + name = elem.replace("{", "").replace("}", "") + if is_last: + return "(?P<%s>[^/]*)" % name + return "(?P<%s>.*)" % name + + elems = uri.split("/") num_elems = len(elems) - regexp = '^{}$'.format('/'.join([_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)])) + regexp = "^{}$".format( + "/".join( + [_convert(elem, (i == num_elems - 1)) for i, elem in enumerate(elems)] + ) + ) return regexp def _get_action_from_method_and_request_uri(self, method, request_uri): @@ -288,19 +329,19 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): # service response class should have 'SERVICE_NAME' class member, # if you want to get action from method and url - if not hasattr(self, 'SERVICE_NAME'): + if not hasattr(self, "SERVICE_NAME"): return None service = self.SERVICE_NAME conn = boto3.client(service, region_name=self.region) # make cache if it does not exist yet - if not hasattr(self, 'method_urls'): + if not hasattr(self, "method_urls"): self.method_urls = defaultdict(lambda: defaultdict(str)) op_names = conn._service_model.operation_names for op_name in op_names: op_model = conn._service_model.operation_model(op_name) - _method = op_model.http['method'] - uri_regexp = self.uri_to_regexp(op_model.http['requestUri']) + _method = op_model.http["method"] + uri_regexp = self.uri_to_regexp(op_model.http["requestUri"]) self.method_urls[_method][uri_regexp] = op_model.name regexp_and_names = self.method_urls[method] for regexp, name in regexp_and_names.items(): @@ -311,11 +352,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return None def _get_action(self): - action = self.querystring.get('Action', [""])[0] + action = self.querystring.get("Action", [""])[0] if not action: # Some services use a header for the action # Headers are case-insensitive. Probably a better way to do this. - match = self.headers.get( - 'x-amz-target') or self.headers.get('X-Amz-Target') + match = self.headers.get("x-amz-target") or self.headers.get("X-Amz-Target") if match: action = match.split(".")[-1] # get action from method and uri @@ -347,10 +387,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return self._send_response(headers, response) if not action: - return 404, headers, '' + return 404, headers, "" raise NotImplementedError( - "The {0} action has not been implemented".format(action)) + "The {0} action has not been implemented".format(action) + ) @staticmethod def _send_response(headers, response): @@ -358,11 +399,11 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): body, new_headers = response else: status, new_headers, body = response - status = new_headers.get('status', 200) + status = new_headers.get("status", 200) headers.update(new_headers) # Cast status to string if "status" in headers: - headers['status'] = str(headers['status']) + headers["status"] = str(headers["status"]) return status, headers, body def _get_param(self, param_name, if_none=None): @@ -396,9 +437,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): def _get_bool_param(self, param_name, if_none=None): val = self._get_param(param_name) if val is not None: - if val.lower() == 'true': + if val.lower() == "true": return True - elif val.lower() == 'false': + elif val.lower() == "false": return False return if_none @@ -416,11 +457,16 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if is_tracked(name) or not name.startswith(param_prefix): continue - if len(name) > len(param_prefix) and \ - not name[len(param_prefix):].startswith('.'): + if len(name) > len(param_prefix) and not name[ + len(param_prefix) : + ].startswith("."): continue - match = self.param_list_regex.search(name[len(param_prefix):]) if len(name) > len(param_prefix) else None + match = ( + self.param_list_regex.search(name[len(param_prefix) :]) + if len(name) > len(param_prefix) + else None + ) if match: prefix = param_prefix + match.group(1) value = self._get_multi_param(prefix) @@ -435,7 +481,10 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if len(value_dict) > 1: # strip off period prefix - value_dict = {name[len(param_prefix) + 1:]: value for name, value in value_dict.items()} + value_dict = { + name[len(param_prefix) + 1 :]: value + for name, value in value_dict.items() + } else: value_dict = list(value_dict.values())[0] @@ -454,7 +503,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): index = 1 while True: value_dict = self._get_multi_param_helper(prefix + str(index)) - if not value_dict and value_dict != '': + if not value_dict and value_dict != "": break values.append(value_dict) @@ -479,8 +528,9 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): params = {} for key, value in self.querystring.items(): if key.startswith(param_prefix): - params[camelcase_to_underscores( - key.replace(param_prefix, ""))] = value[0] + params[camelcase_to_underscores(key.replace(param_prefix, ""))] = value[ + 0 + ] return params def _get_list_prefix(self, param_prefix): @@ -513,19 +563,20 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): new_items = {} for key, value in self.querystring.items(): if key.startswith(index_prefix): - new_items[camelcase_to_underscores( - key.replace(index_prefix, ""))] = value[0] + new_items[ + camelcase_to_underscores(key.replace(index_prefix, "")) + ] = value[0] if not new_items: break results.append(new_items) param_index += 1 return results - def _get_map_prefix(self, param_prefix, key_end='.key', value_end='.value'): + def _get_map_prefix(self, param_prefix, key_end=".key", value_end=".value"): results = {} param_index = 1 while 1: - index_prefix = '{0}.{1}.'.format(param_prefix, param_index) + index_prefix = "{0}.{1}.".format(param_prefix, param_index) k, v = None, None for key, value in self.querystring.items(): @@ -552,8 +603,8 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): param_index = 1 while True: - key_name = 'tag.{0}._key'.format(param_index) - value_name = 'tag.{0}._value'.format(param_index) + key_name = "tag.{0}._key".format(param_index) + value_name = "tag.{0}._value".format(param_index) try: results[resource_type][tag[key_name]] = tag[value_name] @@ -563,7 +614,7 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return results - def _get_object_map(self, prefix, name='Name', value='Value'): + def _get_object_map(self, prefix, name="Name", value="Value"): """ Given a query dict like { @@ -591,15 +642,14 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): index = 1 while True: # Loop through looking for keys representing object name - name_key = '{0}.{1}.{2}'.format(prefix, index, name) + name_key = "{0}.{1}.{2}".format(prefix, index, name) obj_name = self.querystring.get(name_key) if not obj_name: # Found all keys break obj = {} - value_key_prefix = '{0}.{1}.{2}.'.format( - prefix, index, value) + value_key_prefix = "{0}.{1}.{2}.".format(prefix, index, value) for k, v in self.querystring.items(): if k.startswith(value_key_prefix): _, value_key = k.split(value_key_prefix, 1) @@ -613,31 +663,46 @@ class BaseResponse(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): @property def request_json(self): - return 'JSON' in self.querystring.get('ContentType', []) + return "JSON" in self.querystring.get("ContentType", []) def is_not_dryrun(self, action): - if 'true' in self.querystring.get('DryRun', ['false']): - message = 'An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set' % action - raise DryRunClientError( - error_type="DryRunOperation", message=message) + if "true" in self.querystring.get("DryRun", ["false"]): + message = ( + "An error occurred (DryRunOperation) when calling the %s operation: Request would have succeeded, but DryRun flag is set" + % action + ) + raise DryRunClientError(error_type="DryRunOperation", message=message) return True class MotoAPIResponse(BaseResponse): - def reset_response(self, request, full_url, headers): if request.method == "POST": from .models import moto_api_backend + moto_api_backend.reset() return 200, {}, json.dumps({"status": "ok"}) return 400, {}, json.dumps({"Error": "Need to POST to reset Moto"}) def reset_auth_response(self, request, full_url, headers): if request.method == "POST": - previous_initial_no_auth_action_count = settings.INITIAL_NO_AUTH_ACTION_COUNT + previous_initial_no_auth_action_count = ( + settings.INITIAL_NO_AUTH_ACTION_COUNT + ) settings.INITIAL_NO_AUTH_ACTION_COUNT = float(request.data.decode()) ActionAuthenticatorMixin.request_count = 0 - return 200, {}, json.dumps({"status": "ok", "PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str(previous_initial_no_auth_action_count)}) + return ( + 200, + {}, + json.dumps( + { + "status": "ok", + "PREVIOUS_INITIAL_NO_AUTH_ACTION_COUNT": str( + previous_initial_no_auth_action_count + ), + } + ), + ) return 400, {}, json.dumps({"Error": "Need to POST to reset Moto Auth"}) def model_data(self, request, full_url, headers): @@ -665,7 +730,8 @@ class MotoAPIResponse(BaseResponse): def dashboard(self, request, full_url, headers): from flask import render_template - return render_template('dashboard.html') + + return render_template("dashboard.html") class _RecursiveDictRef(object): @@ -676,7 +742,7 @@ class _RecursiveDictRef(object): self.dic = {} def __repr__(self): - return '{!r}'.format(self.dic) + return "{!r}".format(self.dic) def __getattr__(self, key): return self.dic.__getattr__(key) @@ -700,21 +766,21 @@ class AWSServiceSpec(object): """ def __init__(self, path): - self.path = resource_filename('botocore', path) - with io.open(self.path, 'r', encoding='utf-8') as f: + self.path = resource_filename("botocore", path) + with io.open(self.path, "r", encoding="utf-8") as f: spec = json.load(f) - self.metadata = spec['metadata'] - self.operations = spec['operations'] - self.shapes = spec['shapes'] + self.metadata = spec["metadata"] + self.operations = spec["operations"] + self.shapes = spec["shapes"] def input_spec(self, operation): try: op = self.operations[operation] except KeyError: - raise ValueError('Invalid operation: {}'.format(operation)) - if 'input' not in op: + raise ValueError("Invalid operation: {}".format(operation)) + if "input" not in op: return {} - shape = self.shapes[op['input']['shape']] + shape = self.shapes[op["input"]["shape"]] return self._expand(shape) def output_spec(self, operation): @@ -728,129 +794,133 @@ class AWSServiceSpec(object): try: op = self.operations[operation] except KeyError: - raise ValueError('Invalid operation: {}'.format(operation)) - if 'output' not in op: + raise ValueError("Invalid operation: {}".format(operation)) + if "output" not in op: return {} - shape = self.shapes[op['output']['shape']] + shape = self.shapes[op["output"]["shape"]] return self._expand(shape) def _expand(self, shape): def expand(dic, seen=None): seen = seen or {} - if dic['type'] == 'structure': + if dic["type"] == "structure": nodes = {} - for k, v in dic['members'].items(): + for k, v in dic["members"].items(): seen_till_here = dict(seen) if k in seen_till_here: nodes[k] = seen_till_here[k] continue seen_till_here[k] = _RecursiveDictRef() - nodes[k] = expand(self.shapes[v['shape']], seen_till_here) + nodes[k] = expand(self.shapes[v["shape"]], seen_till_here) seen_till_here[k].set_reference(k, nodes[k]) - nodes['type'] = 'structure' + nodes["type"] = "structure" return nodes - elif dic['type'] == 'list': + elif dic["type"] == "list": seen_till_here = dict(seen) - shape = dic['member']['shape'] + shape = dic["member"]["shape"] if shape in seen_till_here: return seen_till_here[shape] seen_till_here[shape] = _RecursiveDictRef() expanded = expand(self.shapes[shape], seen_till_here) seen_till_here[shape].set_reference(shape, expanded) - return {'type': 'list', 'member': expanded} + return {"type": "list", "member": expanded} - elif dic['type'] == 'map': + elif dic["type"] == "map": seen_till_here = dict(seen) - node = {'type': 'map'} + node = {"type": "map"} - if 'shape' in dic['key']: - shape = dic['key']['shape'] + if "shape" in dic["key"]: + shape = dic["key"]["shape"] seen_till_here[shape] = _RecursiveDictRef() - node['key'] = expand(self.shapes[shape], seen_till_here) - seen_till_here[shape].set_reference(shape, node['key']) + node["key"] = expand(self.shapes[shape], seen_till_here) + seen_till_here[shape].set_reference(shape, node["key"]) else: - node['key'] = dic['key']['type'] + node["key"] = dic["key"]["type"] - if 'shape' in dic['value']: - shape = dic['value']['shape'] + if "shape" in dic["value"]: + shape = dic["value"]["shape"] seen_till_here[shape] = _RecursiveDictRef() - node['value'] = expand(self.shapes[shape], seen_till_here) - seen_till_here[shape].set_reference(shape, node['value']) + node["value"] = expand(self.shapes[shape], seen_till_here) + seen_till_here[shape].set_reference(shape, node["value"]) else: - node['value'] = dic['value']['type'] + node["value"] = dic["value"]["type"] return node else: - return {'type': dic['type']} + return {"type": dic["type"]} return expand(shape) def to_str(value, spec): - vtype = spec['type'] - if vtype == 'boolean': - return 'true' if value else 'false' - elif vtype == 'integer': + vtype = spec["type"] + if vtype == "boolean": + return "true" if value else "false" + elif vtype == "integer": return str(value) - elif vtype == 'float': + elif vtype == "float": return str(value) - elif vtype == 'double': + elif vtype == "double": return str(value) - elif vtype == 'timestamp': - return datetime.datetime.utcfromtimestamp( - value).replace(tzinfo=pytz.utc).isoformat() - elif vtype == 'string': + elif vtype == "timestamp": + return ( + datetime.datetime.utcfromtimestamp(value) + .replace(tzinfo=pytz.utc) + .isoformat() + ) + elif vtype == "string": return str(value) elif value is None: - return 'null' + return "null" else: - raise TypeError('Unknown type {}'.format(vtype)) + raise TypeError("Unknown type {}".format(vtype)) def from_str(value, spec): - vtype = spec['type'] - if vtype == 'boolean': - return True if value == 'true' else False - elif vtype == 'integer': + vtype = spec["type"] + if vtype == "boolean": + return True if value == "true" else False + elif vtype == "integer": return int(value) - elif vtype == 'float': + elif vtype == "float": return float(value) - elif vtype == 'double': + elif vtype == "double": return float(value) - elif vtype == 'timestamp': + elif vtype == "timestamp": return value - elif vtype == 'string': + elif vtype == "string": return value - raise TypeError('Unknown type {}'.format(vtype)) + raise TypeError("Unknown type {}".format(vtype)) def flatten_json_request_body(prefix, dict_body, spec): """Convert a JSON request body into query params.""" - if len(spec) == 1 and 'type' in spec: + if len(spec) == 1 and "type" in spec: return {prefix: to_str(dict_body, spec)} flat = {} for key, value in dict_body.items(): - node_type = spec[key]['type'] - if node_type == 'list': + node_type = spec[key]["type"] + if node_type == "list": for idx, v in enumerate(value, 1): - pref = key + '.member.' + str(idx) - flat.update(flatten_json_request_body( - pref, v, spec[key]['member'])) - elif node_type == 'map': + pref = key + ".member." + str(idx) + flat.update(flatten_json_request_body(pref, v, spec[key]["member"])) + elif node_type == "map": for idx, (k, v) in enumerate(value.items(), 1): - pref = key + '.entry.' + str(idx) - flat.update(flatten_json_request_body( - pref + '.key', k, spec[key]['key'])) - flat.update(flatten_json_request_body( - pref + '.value', v, spec[key]['value'])) + pref = key + ".entry." + str(idx) + flat.update( + flatten_json_request_body(pref + ".key", k, spec[key]["key"]) + ) + flat.update( + flatten_json_request_body(pref + ".value", v, spec[key]["value"]) + ) else: flat.update(flatten_json_request_body(key, value, spec[key])) if prefix: - prefix = prefix + '.' + prefix = prefix + "." return dict((prefix + k, v) for k, v in flat.items()) @@ -873,41 +943,40 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None): od = OrderedDict() for k, v in value.items(): - if k.startswith('@'): + if k.startswith("@"): continue if k not in spec: # this can happen when with an older version of # botocore for which the node in XML template is not # defined in service spec. - log.warning( - 'Field %s is not defined by the botocore version in use', k) + log.warning("Field %s is not defined by the botocore version in use", k) continue - if spec[k]['type'] == 'list': + if spec[k]["type"] == "list": if v is None: od[k] = [] - elif len(spec[k]['member']) == 1: - if isinstance(v['member'], list): - od[k] = transform(v['member'], spec[k]['member']) + elif len(spec[k]["member"]) == 1: + if isinstance(v["member"], list): + od[k] = transform(v["member"], spec[k]["member"]) else: - od[k] = [transform(v['member'], spec[k]['member'])] - elif isinstance(v['member'], list): - od[k] = [transform(o, spec[k]['member']) - for o in v['member']] - elif isinstance(v['member'], OrderedDict): - od[k] = [transform(v['member'], spec[k]['member'])] + od[k] = [transform(v["member"], spec[k]["member"])] + elif isinstance(v["member"], list): + od[k] = [transform(o, spec[k]["member"]) for o in v["member"]] + elif isinstance(v["member"], OrderedDict): + od[k] = [transform(v["member"], spec[k]["member"])] else: - raise ValueError('Malformatted input') - elif spec[k]['type'] == 'map': + raise ValueError("Malformatted input") + elif spec[k]["type"] == "map": if v is None: od[k] = {} else: - items = ([v['entry']] if not isinstance(v['entry'], list) else - v['entry']) + items = ( + [v["entry"]] if not isinstance(v["entry"], list) else v["entry"] + ) for item in items: - key = from_str(item['key'], spec[k]['key']) - val = from_str(item['value'], spec[k]['value']) + key = from_str(item["key"], spec[k]["key"]) + val = from_str(item["value"], spec[k]["value"]) if k not in od: od[k] = {} od[k][key] = val @@ -921,7 +990,7 @@ def xml_to_json_response(service_spec, operation, xml, result_node=None): dic = xmltodict.parse(xml) output_spec = service_spec.output_spec(operation) try: - for k in (result_node or (operation + 'Response', operation + 'Result')): + for k in result_node or (operation + "Response", operation + "Result"): dic = dic[k] except KeyError: return None diff --git a/moto/core/urls.py b/moto/core/urls.py index 46025221e..12036b5c3 100644 --- a/moto/core/urls.py +++ b/moto/core/urls.py @@ -1,15 +1,13 @@ from __future__ import unicode_literals from .responses import MotoAPIResponse -url_bases = [ - "https?://motoapi.amazonaws.com" -] +url_bases = ["https?://motoapi.amazonaws.com"] response_instance = MotoAPIResponse() url_paths = { - '{0}/moto-api/$': response_instance.dashboard, - '{0}/moto-api/data.json': response_instance.model_data, - '{0}/moto-api/reset': response_instance.reset_response, - '{0}/moto-api/reset-auth': response_instance.reset_auth_response, + "{0}/moto-api/$": response_instance.dashboard, + "{0}/moto-api/data.json": response_instance.model_data, + "{0}/moto-api/reset": response_instance.reset_response, + "{0}/moto-api/reset-auth": response_instance.reset_auth_response, } diff --git a/moto/core/utils.py b/moto/core/utils.py index ca670e871..a15b7cd1e 100644 --- a/moto/core/utils.py +++ b/moto/core/utils.py @@ -15,9 +15,9 @@ REQUEST_ID_LONG = string.digits + string.ascii_uppercase def camelcase_to_underscores(argument): - ''' Converts a camelcase param like theNewAttribute to the equivalent - python underscore variable like the_new_attribute''' - result = '' + """ Converts a camelcase param like theNewAttribute to the equivalent + python underscore variable like the_new_attribute""" + result = "" prev_char_title = True if not argument: return argument @@ -41,18 +41,18 @@ def camelcase_to_underscores(argument): def underscores_to_camelcase(argument): - ''' Converts a camelcase param like the_new_attribute to the equivalent + """ Converts a camelcase param like the_new_attribute to the equivalent camelcase version like theNewAttribute. Note that the first letter is - NOT capitalized by this function ''' - result = '' + NOT capitalized by this function """ + result = "" previous_was_underscore = False for char in argument: - if char != '_': + if char != "_": if previous_was_underscore: result += char.upper() else: result += char - previous_was_underscore = char == '_' + previous_was_underscore = char == "_" return result @@ -69,12 +69,18 @@ def method_names_from_class(clazz): def get_random_hex(length=8): - chars = list(range(10)) + ['a', 'b', 'c', 'd', 'e', 'f'] - return ''.join(six.text_type(random.choice(chars)) for x in range(length)) + chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"] + return "".join(six.text_type(random.choice(chars)) for x in range(length)) def get_random_message_id(): - return '{0}-{1}-{2}-{3}-{4}'.format(get_random_hex(8), get_random_hex(4), get_random_hex(4), get_random_hex(4), get_random_hex(12)) + return "{0}-{1}-{2}-{3}-{4}".format( + get_random_hex(8), + get_random_hex(4), + get_random_hex(4), + get_random_hex(4), + get_random_hex(12), + ) def convert_regex_to_flask_path(url_path): @@ -97,7 +103,6 @@ def convert_regex_to_flask_path(url_path): class convert_httpretty_response(object): - def __init__(self, callback): self.callback = callback @@ -114,13 +119,12 @@ class convert_httpretty_response(object): def __call__(self, request, url, headers, **kwargs): result = self.callback(request, url, headers) status, headers, response = result - if 'server' not in headers: + if "server" not in headers: headers["server"] = "amazon.com" return status, headers, response class convert_flask_to_httpretty_response(object): - def __init__(self, callback): self.callback = callback @@ -145,13 +149,12 @@ class convert_flask_to_httpretty_response(object): status, headers, content = 200, {}, result response = Response(response=content, status=status, headers=headers) - if request.method == "HEAD" and 'content-length' in headers: - response.headers['Content-Length'] = headers['content-length'] + if request.method == "HEAD" and "content-length" in headers: + response.headers["Content-Length"] = headers["content-length"] return response class convert_flask_to_responses_response(object): - def __init__(self, callback): self.callback = callback @@ -176,14 +179,14 @@ class convert_flask_to_responses_response(object): def iso_8601_datetime_with_milliseconds(datetime): - return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + 'Z' + return datetime.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" def iso_8601_datetime_without_milliseconds(datetime): - return datetime.strftime("%Y-%m-%dT%H:%M:%S") + 'Z' + return datetime.strftime("%Y-%m-%dT%H:%M:%S") + "Z" -RFC1123 = '%a, %d %b %Y %H:%M:%S GMT' +RFC1123 = "%a, %d %b %Y %H:%M:%S GMT" def rfc_1123_datetime(datetime): @@ -212,16 +215,16 @@ def gen_amz_crc32(response, headerdict=None): crc = str(binascii.crc32(response)) if headerdict is not None and isinstance(headerdict, dict): - headerdict.update({'x-amz-crc32': crc}) + headerdict.update({"x-amz-crc32": crc}) return crc def gen_amzn_requestid_long(headerdict=None): - req_id = ''.join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)]) + req_id = "".join([random.choice(REQUEST_ID_LONG) for _ in range(0, 52)]) if headerdict is not None and isinstance(headerdict, dict): - headerdict.update({'x-amzn-requestid': req_id}) + headerdict.update({"x-amzn-requestid": req_id}) return req_id @@ -239,13 +242,13 @@ def amz_crc32(f): else: if len(response) == 2: body, new_headers = response - status = new_headers.get('status', 200) + status = new_headers.get("status", 200) else: status, new_headers, body = response headers.update(new_headers) # Cast status to string if "status" in headers: - headers['status'] = str(headers['status']) + headers["status"] = str(headers["status"]) try: # Doesnt work on python2 for some odd unicode strings @@ -271,7 +274,7 @@ def amzn_request_id(f): else: if len(response) == 2: body, new_headers = response - status = new_headers.get('status', 200) + status = new_headers.get("status", 200) else: status, new_headers, body = response headers.update(new_headers) @@ -280,7 +283,7 @@ def amzn_request_id(f): # Update request ID in XML try: - body = re.sub(r'(?<=).*(?=<\/RequestId>)', request_id, body) + body = re.sub(r"(?<=).*(?=<\/RequestId>)", request_id, body) except Exception: # Will just ignore if it cant work on bytes (which are str's on python2) pass @@ -293,7 +296,7 @@ def path_url(url): parsed_url = urlparse(url) path = parsed_url.path if not path: - path = '/' + path = "/" if parsed_url.query: - path = path + '?' + parsed_url.query + path = path + "?" + parsed_url.query return path diff --git a/moto/datapipeline/__init__.py b/moto/datapipeline/__init__.py index 2565ddd5a..42ee5d6ff 100644 --- a/moto/datapipeline/__init__.py +++ b/moto/datapipeline/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import datapipeline_backends from ..core.models import base_decorator, deprecated_base_decorator -datapipeline_backend = datapipeline_backends['us-east-1'] +datapipeline_backend = datapipeline_backends["us-east-1"] mock_datapipeline = base_decorator(datapipeline_backends) mock_datapipeline_deprecated = deprecated_base_decorator(datapipeline_backends) diff --git a/moto/datapipeline/models.py b/moto/datapipeline/models.py index bb8417a20..cc1fe777e 100644 --- a/moto/datapipeline/models.py +++ b/moto/datapipeline/models.py @@ -8,85 +8,65 @@ from .utils import get_random_pipeline_id, remove_capitalization_of_dict_keys class PipelineObject(BaseModel): - def __init__(self, object_id, name, fields): self.object_id = object_id self.name = name self.fields = fields def to_json(self): - return { - "fields": self.fields, - "id": self.object_id, - "name": self.name, - } + return {"fields": self.fields, "id": self.object_id, "name": self.name} class Pipeline(BaseModel): - def __init__(self, name, unique_id, **kwargs): self.name = name self.unique_id = unique_id - self.description = kwargs.get('description', '') + self.description = kwargs.get("description", "") self.pipeline_id = get_random_pipeline_id() self.creation_time = datetime.datetime.utcnow() self.objects = [] self.status = "PENDING" - self.tags = kwargs.get('tags', []) + self.tags = kwargs.get("tags", []) @property def physical_resource_id(self): return self.pipeline_id def to_meta_json(self): - return { - "id": self.pipeline_id, - "name": self.name, - } + return {"id": self.pipeline_id, "name": self.name} def to_json(self): return { "description": self.description, - "fields": [{ - "key": "@pipelineState", - "stringValue": self.status, - }, { - "key": "description", - "stringValue": self.description - }, { - "key": "name", - "stringValue": self.name - }, { - "key": "@creationTime", - "stringValue": datetime.datetime.strftime(self.creation_time, '%Y-%m-%dT%H-%M-%S'), - }, { - "key": "@id", - "stringValue": self.pipeline_id, - }, { - "key": "@sphere", - "stringValue": "PIPELINE" - }, { - "key": "@version", - "stringValue": "1" - }, { - "key": "@userId", - "stringValue": "924374875933" - }, { - "key": "@accountId", - "stringValue": "924374875933" - }, { - "key": "uniqueId", - "stringValue": self.unique_id - }], + "fields": [ + {"key": "@pipelineState", "stringValue": self.status}, + {"key": "description", "stringValue": self.description}, + {"key": "name", "stringValue": self.name}, + { + "key": "@creationTime", + "stringValue": datetime.datetime.strftime( + self.creation_time, "%Y-%m-%dT%H-%M-%S" + ), + }, + {"key": "@id", "stringValue": self.pipeline_id}, + {"key": "@sphere", "stringValue": "PIPELINE"}, + {"key": "@version", "stringValue": "1"}, + {"key": "@userId", "stringValue": "924374875933"}, + {"key": "@accountId", "stringValue": "924374875933"}, + {"key": "uniqueId", "stringValue": self.unique_id}, + ], "name": self.name, "pipelineId": self.pipeline_id, - "tags": self.tags + "tags": self.tags, } def set_pipeline_objects(self, pipeline_objects): self.objects = [ - PipelineObject(pipeline_object['id'], pipeline_object[ - 'name'], pipeline_object['fields']) + PipelineObject( + pipeline_object["id"], + pipeline_object["name"], + pipeline_object["fields"], + ) for pipeline_object in remove_capitalization_of_dict_keys(pipeline_objects) ] @@ -94,15 +74,19 @@ class Pipeline(BaseModel): self.status = "SCHEDULED" @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): datapipeline_backend = datapipeline_backends[region_name] properties = cloudformation_json["Properties"] cloudformation_unique_id = "cf-" + properties["Name"] pipeline = datapipeline_backend.create_pipeline( - properties["Name"], cloudformation_unique_id) + properties["Name"], cloudformation_unique_id + ) datapipeline_backend.put_pipeline_definition( - pipeline.pipeline_id, properties["PipelineObjects"]) + pipeline.pipeline_id, properties["PipelineObjects"] + ) if properties["Activate"]: pipeline.activate() @@ -110,7 +94,6 @@ class Pipeline(BaseModel): class DataPipelineBackend(BaseBackend): - def __init__(self): self.pipelines = OrderedDict() @@ -123,8 +106,11 @@ class DataPipelineBackend(BaseBackend): return self.pipelines.values() def describe_pipelines(self, pipeline_ids): - pipelines = [pipeline for pipeline in self.pipelines.values( - ) if pipeline.pipeline_id in pipeline_ids] + pipelines = [ + pipeline + for pipeline in self.pipelines.values() + if pipeline.pipeline_id in pipeline_ids + ] return pipelines def get_pipeline(self, pipeline_id): @@ -144,7 +130,8 @@ class DataPipelineBackend(BaseBackend): def describe_objects(self, object_ids, pipeline_id): pipeline = self.get_pipeline(pipeline_id) pipeline_objects = [ - pipeline_object for pipeline_object in pipeline.objects + pipeline_object + for pipeline_object in pipeline.objects if pipeline_object.object_id in object_ids ] return pipeline_objects diff --git a/moto/datapipeline/responses.py b/moto/datapipeline/responses.py index e462e3981..42e1ff2c3 100644 --- a/moto/datapipeline/responses.py +++ b/moto/datapipeline/responses.py @@ -7,7 +7,6 @@ from .models import datapipeline_backends class DataPipelineResponse(BaseResponse): - @property def parameters(self): # TODO this should really be moved to core/responses.py @@ -21,47 +20,47 @@ class DataPipelineResponse(BaseResponse): return datapipeline_backends[self.region] def create_pipeline(self): - name = self.parameters.get('name') - unique_id = self.parameters.get('uniqueId') - description = self.parameters.get('description', '') - tags = self.parameters.get('tags', []) - pipeline = self.datapipeline_backend.create_pipeline(name, unique_id, description=description, tags=tags) - return json.dumps({ - "pipelineId": pipeline.pipeline_id, - }) + name = self.parameters.get("name") + unique_id = self.parameters.get("uniqueId") + description = self.parameters.get("description", "") + tags = self.parameters.get("tags", []) + pipeline = self.datapipeline_backend.create_pipeline( + name, unique_id, description=description, tags=tags + ) + return json.dumps({"pipelineId": pipeline.pipeline_id}) def list_pipelines(self): pipelines = list(self.datapipeline_backend.list_pipelines()) pipeline_ids = [pipeline.pipeline_id for pipeline in pipelines] max_pipelines = 50 - marker = self.parameters.get('marker') + marker = self.parameters.get("marker") if marker: start = pipeline_ids.index(marker) + 1 else: start = 0 - pipelines_resp = pipelines[start:start + max_pipelines] + pipelines_resp = pipelines[start : start + max_pipelines] has_more_results = False marker = None if start + max_pipelines < len(pipeline_ids) - 1: has_more_results = True marker = pipelines_resp[-1].pipeline_id - return json.dumps({ - "hasMoreResults": has_more_results, - "marker": marker, - "pipelineIdList": [ - pipeline.to_meta_json() for pipeline in pipelines_resp - ] - }) + return json.dumps( + { + "hasMoreResults": has_more_results, + "marker": marker, + "pipelineIdList": [ + pipeline.to_meta_json() for pipeline in pipelines_resp + ], + } + ) def describe_pipelines(self): pipeline_ids = self.parameters["pipelineIds"] pipelines = self.datapipeline_backend.describe_pipelines(pipeline_ids) - return json.dumps({ - "pipelineDescriptionList": [ - pipeline.to_json() for pipeline in pipelines - ] - }) + return json.dumps( + {"pipelineDescriptionList": [pipeline.to_json() for pipeline in pipelines]} + ) def delete_pipeline(self): pipeline_id = self.parameters["pipelineId"] @@ -72,31 +71,38 @@ class DataPipelineResponse(BaseResponse): pipeline_id = self.parameters["pipelineId"] pipeline_objects = self.parameters["pipelineObjects"] - self.datapipeline_backend.put_pipeline_definition( - pipeline_id, pipeline_objects) + self.datapipeline_backend.put_pipeline_definition(pipeline_id, pipeline_objects) return json.dumps({"errored": False}) def get_pipeline_definition(self): pipeline_id = self.parameters["pipelineId"] pipeline_definition = self.datapipeline_backend.get_pipeline_definition( - pipeline_id) - return json.dumps({ - "pipelineObjects": [pipeline_object.to_json() for pipeline_object in pipeline_definition] - }) + pipeline_id + ) + return json.dumps( + { + "pipelineObjects": [ + pipeline_object.to_json() for pipeline_object in pipeline_definition + ] + } + ) def describe_objects(self): pipeline_id = self.parameters["pipelineId"] object_ids = self.parameters["objectIds"] pipeline_objects = self.datapipeline_backend.describe_objects( - object_ids, pipeline_id) - return json.dumps({ - "hasMoreResults": False, - "marker": None, - "pipelineObjects": [ - pipeline_object.to_json() for pipeline_object in pipeline_objects - ] - }) + object_ids, pipeline_id + ) + return json.dumps( + { + "hasMoreResults": False, + "marker": None, + "pipelineObjects": [ + pipeline_object.to_json() for pipeline_object in pipeline_objects + ], + } + ) def activate_pipeline(self): pipeline_id = self.parameters["pipelineId"] diff --git a/moto/datapipeline/urls.py b/moto/datapipeline/urls.py index 40805874b..078b44b19 100644 --- a/moto/datapipeline/urls.py +++ b/moto/datapipeline/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import DataPipelineResponse -url_bases = [ - "https?://datapipeline.(.+).amazonaws.com", -] +url_bases = ["https?://datapipeline.(.+).amazonaws.com"] -url_paths = { - '{0}/$': DataPipelineResponse.dispatch, -} +url_paths = {"{0}/$": DataPipelineResponse.dispatch} diff --git a/moto/datapipeline/utils.py b/moto/datapipeline/utils.py index 75df4a9a5..9135181e7 100644 --- a/moto/datapipeline/utils.py +++ b/moto/datapipeline/utils.py @@ -14,7 +14,9 @@ def remove_capitalization_of_dict_keys(obj): normalized_key = key[:1].lower() + key[1:] result[normalized_key] = remove_capitalization_of_dict_keys(value) return result - elif isinstance(obj, collections.Iterable) and not isinstance(obj, six.string_types): + elif isinstance(obj, collections.Iterable) and not isinstance( + obj, six.string_types + ): result = obj.__class__() for item in obj: result += (remove_capitalization_of_dict_keys(item),) diff --git a/moto/dynamodb/comparisons.py b/moto/dynamodb/comparisons.py index d9b391557..5418f906f 100644 --- a/moto/dynamodb/comparisons.py +++ b/moto/dynamodb/comparisons.py @@ -1,19 +1,22 @@ from __future__ import unicode_literals + # TODO add tests for all of these COMPARISON_FUNCS = { - 'EQ': lambda item_value, test_value: item_value == test_value, - 'NE': lambda item_value, test_value: item_value != test_value, - 'LE': lambda item_value, test_value: item_value <= test_value, - 'LT': lambda item_value, test_value: item_value < test_value, - 'GE': lambda item_value, test_value: item_value >= test_value, - 'GT': lambda item_value, test_value: item_value > test_value, - 'NULL': lambda item_value: item_value is None, - 'NOT_NULL': lambda item_value: item_value is not None, - 'CONTAINS': lambda item_value, test_value: test_value in item_value, - 'NOT_CONTAINS': lambda item_value, test_value: test_value not in item_value, - 'BEGINS_WITH': lambda item_value, test_value: item_value.startswith(test_value), - 'IN': lambda item_value, *test_values: item_value in test_values, - 'BETWEEN': lambda item_value, lower_test_value, upper_test_value: lower_test_value <= item_value <= upper_test_value, + "EQ": lambda item_value, test_value: item_value == test_value, + "NE": lambda item_value, test_value: item_value != test_value, + "LE": lambda item_value, test_value: item_value <= test_value, + "LT": lambda item_value, test_value: item_value < test_value, + "GE": lambda item_value, test_value: item_value >= test_value, + "GT": lambda item_value, test_value: item_value > test_value, + "NULL": lambda item_value: item_value is None, + "NOT_NULL": lambda item_value: item_value is not None, + "CONTAINS": lambda item_value, test_value: test_value in item_value, + "NOT_CONTAINS": lambda item_value, test_value: test_value not in item_value, + "BEGINS_WITH": lambda item_value, test_value: item_value.startswith(test_value), + "IN": lambda item_value, *test_values: item_value in test_values, + "BETWEEN": lambda item_value, lower_test_value, upper_test_value: lower_test_value + <= item_value + <= upper_test_value, } diff --git a/moto/dynamodb/models.py b/moto/dynamodb/models.py index 300189a0e..f00f6042d 100644 --- a/moto/dynamodb/models.py +++ b/moto/dynamodb/models.py @@ -10,9 +10,8 @@ from .comparisons import get_comparison_func class DynamoJsonEncoder(json.JSONEncoder): - def default(self, obj): - if hasattr(obj, 'to_json'): + if hasattr(obj, "to_json"): return obj.to_json() @@ -33,10 +32,7 @@ class DynamoType(object): return hash((self.type, self.value)) def __eq__(self, other): - return ( - self.type == other.type and - self.value == other.value - ) + return self.type == other.type and self.value == other.value def __repr__(self): return "DynamoType: {0}".format(self.to_json()) @@ -54,7 +50,6 @@ class DynamoType(object): class Item(BaseModel): - def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs): self.hash_key = hash_key self.hash_key_type = hash_key_type @@ -73,9 +68,7 @@ class Item(BaseModel): for attribute_key, attribute in self.attrs.items(): attributes[attribute_key] = attribute.value - return { - "Attributes": attributes - } + return {"Attributes": attributes} def describe_attrs(self, attributes): if attributes: @@ -85,16 +78,20 @@ class Item(BaseModel): included[key] = value else: included = self.attrs - return { - "Item": included - } + return {"Item": included} class Table(BaseModel): - - def __init__(self, name, hash_key_attr, hash_key_type, - range_key_attr=None, range_key_type=None, read_capacity=None, - write_capacity=None): + def __init__( + self, + name, + hash_key_attr, + hash_key_type, + range_key_attr=None, + range_key_type=None, + read_capacity=None, + write_capacity=None, + ): self.name = name self.hash_key_attr = hash_key_attr self.hash_key_type = hash_key_type @@ -117,12 +114,12 @@ class Table(BaseModel): "KeySchema": { "HashKeyElement": { "AttributeName": self.hash_key_attr, - "AttributeType": self.hash_key_type - }, + "AttributeType": self.hash_key_type, + } }, "ProvisionedThroughput": { "ReadCapacityUnits": self.read_capacity, - "WriteCapacityUnits": self.write_capacity + "WriteCapacityUnits": self.write_capacity, }, "TableName": self.name, "TableStatus": "ACTIVE", @@ -133,19 +130,29 @@ class Table(BaseModel): if self.has_range_key: results["Table"]["KeySchema"]["RangeKeyElement"] = { "AttributeName": self.range_key_attr, - "AttributeType": self.range_key_type + "AttributeType": self.range_key_type, } return results @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - key_attr = [i['AttributeName'] for i in properties['KeySchema'] if i['KeyType'] == 'HASH'][0] - key_type = [i['AttributeType'] for i in properties['AttributeDefinitions'] if i['AttributeName'] == key_attr][0] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + key_attr = [ + i["AttributeName"] + for i in properties["KeySchema"] + if i["KeyType"] == "HASH" + ][0] + key_type = [ + i["AttributeType"] + for i in properties["AttributeDefinitions"] + if i["AttributeName"] == key_attr + ][0] spec = { - 'name': properties['TableName'], - 'hash_key_attr': key_attr, - 'hash_key_type': key_type + "name": properties["TableName"], + "hash_key_attr": key_attr, + "hash_key_type": key_type, } # TODO: optional properties still missing: # range_key_attr, range_key_type, read_capacity, write_capacity @@ -173,8 +180,9 @@ class Table(BaseModel): else: range_value = None - item = Item(hash_value, self.hash_key_type, range_value, - self.range_key_type, item_attrs) + item = Item( + hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs + ) if range_value: self.items[hash_value][range_value] = item @@ -185,7 +193,8 @@ class Table(BaseModel): def get_item(self, hash_key, range_key): if self.has_range_key and not range_key: raise ValueError( - "Table has a range key, but no range key was passed into get_item") + "Table has a range key, but no range key was passed into get_item" + ) try: if range_key: return self.items[hash_key][range_key] @@ -228,7 +237,10 @@ class Table(BaseModel): for result in self.all_items(): scanned_count += 1 passes_all_conditions = True - for attribute_name, (comparison_operator, comparison_objs) in filters.items(): + for ( + attribute_name, + (comparison_operator, comparison_objs), + ) in filters.items(): attribute = result.attrs.get(attribute_name) if attribute: @@ -236,7 +248,7 @@ class Table(BaseModel): if not attribute.compare(comparison_operator, comparison_objs): passes_all_conditions = False break - elif comparison_operator == 'NULL': + elif comparison_operator == "NULL": # Comparison is NULL and we don't have the attribute continue else: @@ -261,15 +273,17 @@ class Table(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'StreamArn': - region = 'us-east-1' - time = '2000-01-01T00:00:00.000' - return 'arn:aws:dynamodb:{0}:123456789012:table/{1}/stream/{2}'.format(region, self.name, time) + + if attribute_name == "StreamArn": + region = "us-east-1" + time = "2000-01-01T00:00:00.000" + return "arn:aws:dynamodb:{0}:123456789012:table/{1}/stream/{2}".format( + region, self.name, time + ) raise UnformattedGetAttTemplateException() class DynamoDBBackend(BaseBackend): - def __init__(self): self.tables = OrderedDict() @@ -310,8 +324,7 @@ class DynamoDBBackend(BaseBackend): return None, None hash_key = DynamoType(hash_key_dict) - range_values = [DynamoType(range_value) - for range_value in range_value_dicts] + range_values = [DynamoType(range_value) for range_value in range_value_dicts] return table.query(hash_key, range_comparison, range_values) diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index 990069a46..85ae58fc5 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -8,7 +8,6 @@ from .models import dynamodb_backend, dynamo_json_dump class DynamoHandler(BaseResponse): - def get_endpoint_name(self, headers): """Parses request headers and extracts part od the X-Amz-Target that corresponds to a method of DynamoHandler @@ -16,15 +15,15 @@ class DynamoHandler(BaseResponse): ie: X-Amz-Target: DynamoDB_20111205.ListTables -> ListTables """ # Headers are case-insensitive. Probably a better way to do this. - match = headers.get('x-amz-target') or headers.get('X-Amz-Target') + match = headers.get("x-amz-target") or headers.get("X-Amz-Target") if match: return match.split(".")[1] def error(self, type_, status=400): - return status, self.response_headers, dynamo_json_dump({'__type': type_}) + return status, self.response_headers, dynamo_json_dump({"__type": type_}) def call_action(self): - self.body = json.loads(self.body or '{}') + self.body = json.loads(self.body or "{}") endpoint = self.get_endpoint_name(self.headers) if endpoint: endpoint = camelcase_to_underscores(endpoint) @@ -41,7 +40,7 @@ class DynamoHandler(BaseResponse): def list_tables(self): body = self.body - limit = body.get('Limit') + limit = body.get("Limit") if body.get("ExclusiveStartTableName"): last = body.get("ExclusiveStartTableName") start = list(dynamodb_backend.tables.keys()).index(last) + 1 @@ -49,7 +48,7 @@ class DynamoHandler(BaseResponse): start = 0 all_tables = list(dynamodb_backend.tables.keys()) if limit: - tables = all_tables[start:start + limit] + tables = all_tables[start : start + limit] else: tables = all_tables[start:] response = {"TableNames": tables} @@ -59,16 +58,16 @@ class DynamoHandler(BaseResponse): def create_table(self): body = self.body - name = body['TableName'] + name = body["TableName"] - key_schema = body['KeySchema'] - hash_key = key_schema['HashKeyElement'] - hash_key_attr = hash_key['AttributeName'] - hash_key_type = hash_key['AttributeType'] + key_schema = body["KeySchema"] + hash_key = key_schema["HashKeyElement"] + hash_key_attr = hash_key["AttributeName"] + hash_key_type = hash_key["AttributeType"] - range_key = key_schema.get('RangeKeyElement', {}) - range_key_attr = range_key.get('AttributeName') - range_key_type = range_key.get('AttributeType') + range_key = key_schema.get("RangeKeyElement", {}) + range_key_attr = range_key.get("AttributeName") + range_key_type = range_key.get("AttributeType") throughput = body["ProvisionedThroughput"] read_units = throughput["ReadCapacityUnits"] @@ -86,137 +85,131 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(table.describe) def delete_table(self): - name = self.body['TableName'] + name = self.body["TableName"] table = dynamodb_backend.delete_table(name) if table: return dynamo_json_dump(table.describe) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) def update_table(self): - name = self.body['TableName'] + name = self.body["TableName"] throughput = self.body["ProvisionedThroughput"] new_read_units = throughput["ReadCapacityUnits"] new_write_units = throughput["WriteCapacityUnits"] table = dynamodb_backend.update_table_throughput( - name, new_read_units, new_write_units) + name, new_read_units, new_write_units + ) return dynamo_json_dump(table.describe) def describe_table(self): - name = self.body['TableName'] + name = self.body["TableName"] try: table = dynamodb_backend.tables[name] except KeyError: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) return dynamo_json_dump(table.describe) def put_item(self): - name = self.body['TableName'] - item = self.body['Item'] + name = self.body["TableName"] + item = self.body["Item"] result = dynamodb_backend.put_item(name, item) if result: item_dict = result.to_json() - item_dict['ConsumedCapacityUnits'] = 1 + item_dict["ConsumedCapacityUnits"] = 1 return dynamo_json_dump(item_dict) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) def batch_write_item(self): - table_batches = self.body['RequestItems'] + table_batches = self.body["RequestItems"] for table_name, table_requests in table_batches.items(): for table_request in table_requests: request_type = list(table_request)[0] request = list(table_request.values())[0] - if request_type == 'PutRequest': - item = request['Item'] + if request_type == "PutRequest": + item = request["Item"] dynamodb_backend.put_item(table_name, item) - elif request_type == 'DeleteRequest': - key = request['Key'] - hash_key = key['HashKeyElement'] - range_key = key.get('RangeKeyElement') - item = dynamodb_backend.delete_item( - table_name, hash_key, range_key) + elif request_type == "DeleteRequest": + key = request["Key"] + hash_key = key["HashKeyElement"] + range_key = key.get("RangeKeyElement") + item = dynamodb_backend.delete_item(table_name, hash_key, range_key) response = { "Responses": { - "Thread": { - "ConsumedCapacityUnits": 1.0 - }, - "Reply": { - "ConsumedCapacityUnits": 1.0 - } + "Thread": {"ConsumedCapacityUnits": 1.0}, + "Reply": {"ConsumedCapacityUnits": 1.0}, }, - "UnprocessedItems": {} + "UnprocessedItems": {}, } return dynamo_json_dump(response) def get_item(self): - name = self.body['TableName'] - key = self.body['Key'] - hash_key = key['HashKeyElement'] - range_key = key.get('RangeKeyElement') - attrs_to_get = self.body.get('AttributesToGet') + name = self.body["TableName"] + key = self.body["Key"] + hash_key = key["HashKeyElement"] + range_key = key.get("RangeKeyElement") + attrs_to_get = self.body.get("AttributesToGet") try: item = dynamodb_backend.get_item(name, hash_key, range_key) except ValueError: - er = 'com.amazon.coral.validate#ValidationException' + er = "com.amazon.coral.validate#ValidationException" return self.error(er, status=400) if item: item_dict = item.describe_attrs(attrs_to_get) - item_dict['ConsumedCapacityUnits'] = 0.5 + item_dict["ConsumedCapacityUnits"] = 0.5 return dynamo_json_dump(item_dict) else: # Item not found - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er, status=404) def batch_get_item(self): - table_batches = self.body['RequestItems'] + table_batches = self.body["RequestItems"] - results = { - "Responses": { - "UnprocessedKeys": {} - } - } + results = {"Responses": {"UnprocessedKeys": {}}} for table_name, table_request in table_batches.items(): items = [] - keys = table_request['Keys'] - attributes_to_get = table_request.get('AttributesToGet') + keys = table_request["Keys"] + attributes_to_get = table_request.get("AttributesToGet") for key in keys: hash_key = key["HashKeyElement"] range_key = key.get("RangeKeyElement") - item = dynamodb_backend.get_item( - table_name, hash_key, range_key) + item = dynamodb_backend.get_item(table_name, hash_key, range_key) if item: item_describe = item.describe_attrs(attributes_to_get) items.append(item_describe) results["Responses"][table_name] = { - "Items": items, "ConsumedCapacityUnits": 1} + "Items": items, + "ConsumedCapacityUnits": 1, + } return dynamo_json_dump(results) def query(self): - name = self.body['TableName'] - hash_key = self.body['HashKeyValue'] - range_condition = self.body.get('RangeKeyCondition') + name = self.body["TableName"] + hash_key = self.body["HashKeyValue"] + range_condition = self.body.get("RangeKeyCondition") if range_condition: - range_comparison = range_condition['ComparisonOperator'] - range_values = range_condition['AttributeValueList'] + range_comparison = range_condition["ComparisonOperator"] + range_values = range_condition["AttributeValueList"] else: range_comparison = None range_values = [] items, last_page = dynamodb_backend.query( - name, hash_key, range_comparison, range_values) + name, hash_key, range_comparison, range_values + ) if items is None: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) result = { @@ -234,10 +227,10 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(result) def scan(self): - name = self.body['TableName'] + name = self.body["TableName"] filters = {} - scan_filters = self.body.get('ScanFilter', {}) + scan_filters = self.body.get("ScanFilter", {}) for attribute_name, scan_filter in scan_filters.items(): # Keys are attribute names. Values are tuples of (comparison, # comparison_value) @@ -248,14 +241,14 @@ class DynamoHandler(BaseResponse): items, scanned_count, last_page = dynamodb_backend.scan(name, filters) if items is None: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) result = { "Count": len(items), "Items": [item.attrs for item in items if item], "ConsumedCapacityUnits": 1, - "ScannedCount": scanned_count + "ScannedCount": scanned_count, } # Implement this when we do pagination @@ -267,19 +260,19 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(result) def delete_item(self): - name = self.body['TableName'] - key = self.body['Key'] - hash_key = key['HashKeyElement'] - range_key = key.get('RangeKeyElement') - return_values = self.body.get('ReturnValues', '') + name = self.body["TableName"] + key = self.body["Key"] + hash_key = key["HashKeyElement"] + range_key = key.get("RangeKeyElement") + return_values = self.body.get("ReturnValues", "") item = dynamodb_backend.delete_item(name, hash_key, range_key) if item: - if return_values == 'ALL_OLD': + if return_values == "ALL_OLD": item_dict = item.to_json() else: - item_dict = {'Attributes': []} - item_dict['ConsumedCapacityUnits'] = 0.5 + item_dict = {"Attributes": []} + item_dict["ConsumedCapacityUnits"] = 0.5 return dynamo_json_dump(item_dict) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" return self.error(er) diff --git a/moto/dynamodb/urls.py b/moto/dynamodb/urls.py index 6988f6e15..26f0701a2 100644 --- a/moto/dynamodb/urls.py +++ b/moto/dynamodb/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import DynamoHandler -url_bases = [ - "https?://dynamodb.(.+).amazonaws.com" -] +url_bases = ["https?://dynamodb.(.+).amazonaws.com"] -url_paths = { - "{0}/": DynamoHandler.dispatch, -} +url_paths = {"{0}/": DynamoHandler.dispatch} diff --git a/moto/dynamodb2/__init__.py b/moto/dynamodb2/__init__.py index a56a83b35..3d6e8ec1f 100644 --- a/moto/dynamodb2/__init__.py +++ b/moto/dynamodb2/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import dynamodb_backends as dynamodb_backends2 from ..core.models import base_decorator, deprecated_base_decorator -dynamodb_backend2 = dynamodb_backends2['us-east-1'] +dynamodb_backend2 = dynamodb_backends2["us-east-1"] mock_dynamodb2 = base_decorator(dynamodb_backends2) mock_dynamodb2_deprecated = deprecated_base_decorator(dynamodb_backends2) diff --git a/moto/dynamodb2/comparisons.py b/moto/dynamodb2/comparisons.py index c6aee7a68..69d7f74e0 100644 --- a/moto/dynamodb2/comparisons.py +++ b/moto/dynamodb2/comparisons.py @@ -1,7 +1,5 @@ from __future__ import unicode_literals import re -import six -import re from collections import deque from collections import namedtuple @@ -27,37 +25,35 @@ def get_expected(expected): expr = 'Id > 5 AND Subs < 7' """ ops = { - 'EQ': OpEqual, - 'NE': OpNotEqual, - 'LE': OpLessThanOrEqual, - 'LT': OpLessThan, - 'GE': OpGreaterThanOrEqual, - 'GT': OpGreaterThan, - 'NOT_NULL': FuncAttrExists, - 'NULL': FuncAttrNotExists, - 'CONTAINS': FuncContains, - 'NOT_CONTAINS': FuncNotContains, - 'BEGINS_WITH': FuncBeginsWith, - 'IN': FuncIn, - 'BETWEEN': FuncBetween, + "EQ": OpEqual, + "NE": OpNotEqual, + "LE": OpLessThanOrEqual, + "LT": OpLessThan, + "GE": OpGreaterThanOrEqual, + "GT": OpGreaterThan, + "NOT_NULL": FuncAttrExists, + "NULL": FuncAttrNotExists, + "CONTAINS": FuncContains, + "NOT_CONTAINS": FuncNotContains, + "BEGINS_WITH": FuncBeginsWith, + "IN": FuncIn, + "BETWEEN": FuncBetween, } # NOTE: Always uses ConditionalOperator=AND conditions = [] for key, cond in expected.items(): path = AttributePath([key]) - if 'Exists' in cond: - if cond['Exists']: - conditions.append(FuncAttrExists(path)) + if "Exists" in cond: + if cond["Exists"]: + conditions.append(FuncAttrExists(path)) else: - conditions.append(FuncAttrNotExists(path)) - elif 'Value' in cond: - conditions.append(OpEqual(path, AttributeValue(cond['Value']))) - elif 'ComparisonOperator' in cond: - operator_name = cond['ComparisonOperator'] - values = [ - AttributeValue(v) - for v in cond.get("AttributeValueList", [])] + conditions.append(FuncAttrNotExists(path)) + elif "Value" in cond: + conditions.append(OpEqual(path, AttributeValue(cond["Value"]))) + elif "ComparisonOperator" in cond: + operator_name = cond["ComparisonOperator"] + values = [AttributeValue(v) for v in cond.get("AttributeValueList", [])] OpClass = ops[operator_name] conditions.append(OpClass(path, *values)) @@ -77,7 +73,8 @@ class Op(object): """ Base class for a FilterExpression operator """ - OP = '' + + OP = "" def __init__(self, lhs, rhs): self.lhs = lhs @@ -87,45 +84,42 @@ class Op(object): raise NotImplementedError("Expr not defined for {0}".format(type(self))) def __repr__(self): - return '({0} {1} {2})'.format(self.lhs, self.OP, self.rhs) + return "({0} {1} {2})".format(self.lhs, self.OP, self.rhs) + # TODO add tests for all of these -EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # flake8: noqa -NE_FUNCTION = lambda item_value, test_value: item_value != test_value # flake8: noqa -LE_FUNCTION = lambda item_value, test_value: item_value <= test_value # flake8: noqa -LT_FUNCTION = lambda item_value, test_value: item_value < test_value # flake8: noqa -GE_FUNCTION = lambda item_value, test_value: item_value >= test_value # flake8: noqa -GT_FUNCTION = lambda item_value, test_value: item_value > test_value # flake8: noqa +EQ_FUNCTION = lambda item_value, test_value: item_value == test_value # noqa +NE_FUNCTION = lambda item_value, test_value: item_value != test_value # noqa +LE_FUNCTION = lambda item_value, test_value: item_value <= test_value # noqa +LT_FUNCTION = lambda item_value, test_value: item_value < test_value # noqa +GE_FUNCTION = lambda item_value, test_value: item_value >= test_value # noqa +GT_FUNCTION = lambda item_value, test_value: item_value > test_value # noqa COMPARISON_FUNCS = { - 'EQ': EQ_FUNCTION, - '=': EQ_FUNCTION, - - 'NE': NE_FUNCTION, - '!=': NE_FUNCTION, - - 'LE': LE_FUNCTION, - '<=': LE_FUNCTION, - - 'LT': LT_FUNCTION, - '<': LT_FUNCTION, - - 'GE': GE_FUNCTION, - '>=': GE_FUNCTION, - - 'GT': GT_FUNCTION, - '>': GT_FUNCTION, - + "EQ": EQ_FUNCTION, + "=": EQ_FUNCTION, + "NE": NE_FUNCTION, + "!=": NE_FUNCTION, + "LE": LE_FUNCTION, + "<=": LE_FUNCTION, + "LT": LT_FUNCTION, + "<": LT_FUNCTION, + "GE": GE_FUNCTION, + ">=": GE_FUNCTION, + "GT": GT_FUNCTION, + ">": GT_FUNCTION, # NULL means the value should not exist at all - 'NULL': lambda item_value: False, + "NULL": lambda item_value: False, # NOT_NULL means the value merely has to exist, and values of None are valid - 'NOT_NULL': lambda item_value: True, - 'CONTAINS': lambda item_value, test_value: test_value in item_value, - 'NOT_CONTAINS': lambda item_value, test_value: test_value not in item_value, - 'BEGINS_WITH': lambda item_value, test_value: item_value.startswith(test_value), - 'IN': lambda item_value, *test_values: item_value in test_values, - 'BETWEEN': lambda item_value, lower_test_value, upper_test_value: lower_test_value <= item_value <= upper_test_value, + "NOT_NULL": lambda item_value: True, + "CONTAINS": lambda item_value, test_value: test_value in item_value, + "NOT_CONTAINS": lambda item_value, test_value: test_value not in item_value, + "BEGINS_WITH": lambda item_value, test_value: item_value.startswith(test_value), + "IN": lambda item_value, *test_values: item_value in test_values, + "BETWEEN": lambda item_value, lower_test_value, upper_test_value: lower_test_value + <= item_value + <= upper_test_value, } @@ -138,8 +132,12 @@ class RecursionStopIteration(StopIteration): class ConditionExpressionParser: - def __init__(self, condition_expression, expression_attribute_names, - expression_attribute_values): + def __init__( + self, + condition_expression, + expression_attribute_names, + expression_attribute_values, + ): self.condition_expression = condition_expression self.expression_attribute_names = expression_attribute_names self.expression_attribute_values = expression_attribute_values @@ -203,52 +201,49 @@ class ConditionExpressionParser: # Condition nodes # --------------- - OR = 'OR' - AND = 'AND' - NOT = 'NOT' - PARENTHESES = 'PARENTHESES' - FUNCTION = 'FUNCTION' - BETWEEN = 'BETWEEN' - IN = 'IN' - COMPARISON = 'COMPARISON' + OR = "OR" + AND = "AND" + NOT = "NOT" + PARENTHESES = "PARENTHESES" + FUNCTION = "FUNCTION" + BETWEEN = "BETWEEN" + IN = "IN" + COMPARISON = "COMPARISON" # Operand nodes # ------------- - EXPRESSION_ATTRIBUTE_VALUE = 'EXPRESSION_ATTRIBUTE_VALUE' - PATH = 'PATH' + EXPRESSION_ATTRIBUTE_VALUE = "EXPRESSION_ATTRIBUTE_VALUE" + PATH = "PATH" # Literal nodes # -------------- - LITERAL = 'LITERAL' - + LITERAL = "LITERAL" class Nonterminal: """Enum defining nonterminals for productions.""" - CONDITION = 'CONDITION' - OPERAND = 'OPERAND' - COMPARATOR = 'COMPARATOR' - FUNCTION_NAME = 'FUNCTION_NAME' - IDENTIFIER = 'IDENTIFIER' - AND = 'AND' - OR = 'OR' - NOT = 'NOT' - BETWEEN = 'BETWEEN' - IN = 'IN' - COMMA = 'COMMA' - LEFT_PAREN = 'LEFT_PAREN' - RIGHT_PAREN = 'RIGHT_PAREN' - WHITESPACE = 'WHITESPACE' + CONDITION = "CONDITION" + OPERAND = "OPERAND" + COMPARATOR = "COMPARATOR" + FUNCTION_NAME = "FUNCTION_NAME" + IDENTIFIER = "IDENTIFIER" + AND = "AND" + OR = "OR" + NOT = "NOT" + BETWEEN = "BETWEEN" + IN = "IN" + COMMA = "COMMA" + LEFT_PAREN = "LEFT_PAREN" + RIGHT_PAREN = "RIGHT_PAREN" + WHITESPACE = "WHITESPACE" - - Node = namedtuple('Node', ['nonterminal', 'kind', 'text', 'value', 'children']) + Node = namedtuple("Node", ["nonterminal", "kind", "text", "value", "children"]) def _lex_condition_expression(self): nodes = deque() remaining_expression = self.condition_expression while remaining_expression: - node, remaining_expression = \ - self._lex_one_node(remaining_expression) + node, remaining_expression = self._lex_one_node(remaining_expression) if node.nonterminal == self.Nonterminal.WHITESPACE: continue nodes.append(node) @@ -256,49 +251,52 @@ class ConditionExpressionParser: def _lex_one_node(self, remaining_expression): # TODO: Handle indexing like [1] - attribute_regex = '(:|#)?[A-z0-9\-_]+' - patterns = [( - self.Nonterminal.WHITESPACE, re.compile('^ +') - ), ( - self.Nonterminal.COMPARATOR, re.compile( - '^(' - # Put long expressions first for greedy matching - '<>|' - '<=|' - '>=|' - '=|' - '<|' - '>)'), - ), ( - self.Nonterminal.OPERAND, re.compile( - '^' + - attribute_regex + '(\.' + attribute_regex + '|\[[0-9]\])*') - ), ( - self.Nonterminal.COMMA, re.compile('^,') - ), ( - self.Nonterminal.LEFT_PAREN, re.compile('^\(') - ), ( - self.Nonterminal.RIGHT_PAREN, re.compile('^\)') - )] + attribute_regex = "(:|#)?[A-z0-9\-_]+" + patterns = [ + (self.Nonterminal.WHITESPACE, re.compile("^ +")), + ( + self.Nonterminal.COMPARATOR, + re.compile( + "^(" + # Put long expressions first for greedy matching + "<>|" + "<=|" + ">=|" + "=|" + "<|" + ">)" + ), + ), + ( + self.Nonterminal.OPERAND, + re.compile( + "^" + attribute_regex + "(\." + attribute_regex + "|\[[0-9]\])*" + ), + ), + (self.Nonterminal.COMMA, re.compile("^,")), + (self.Nonterminal.LEFT_PAREN, re.compile("^\(")), + (self.Nonterminal.RIGHT_PAREN, re.compile("^\)")), + ] for nonterminal, pattern in patterns: match = pattern.match(remaining_expression) if match: match_text = match.group() break - else: # pragma: no cover - raise ValueError("Cannot parse condition starting at: " + - remaining_expression) + else: # pragma: no cover + raise ValueError( + "Cannot parse condition starting at: " + remaining_expression + ) - value = match_text node = self.Node( nonterminal=nonterminal, kind=self.Kind.LITERAL, text=match_text, value=match_text, - children=[]) + children=[], + ) - remaining_expression = remaining_expression[len(match_text):] + remaining_expression = remaining_expression[len(match_text) :] return node, remaining_expression @@ -309,10 +307,8 @@ class ConditionExpressionParser: node = nodes.popleft() if node.nonterminal == self.Nonterminal.OPERAND: - path = node.value.replace('[', '.[').split('.') - children = [ - self._parse_path_element(name) - for name in path] + path = node.value.replace("[", ".[").split(".") + children = [self._parse_path_element(name) for name in path] if len(children) == 1: child = children[0] if child.nonterminal != self.Nonterminal.IDENTIFIER: @@ -322,36 +318,40 @@ class ConditionExpressionParser: for child in children: self._assert( child.nonterminal == self.Nonterminal.IDENTIFIER, - "Cannot use %s in path" % child.text, [node]) - output.append(self.Node( - nonterminal=self.Nonterminal.OPERAND, - kind=self.Kind.PATH, - text=node.text, - value=None, - children=children)) + "Cannot use %s in path" % child.text, + [node], + ) + output.append( + self.Node( + nonterminal=self.Nonterminal.OPERAND, + kind=self.Kind.PATH, + text=node.text, + value=None, + children=children, + ) + ) else: output.append(node) return output def _parse_path_element(self, name): reserved = { - 'and': self.Nonterminal.AND, - 'or': self.Nonterminal.OR, - 'in': self.Nonterminal.IN, - 'between': self.Nonterminal.BETWEEN, - 'not': self.Nonterminal.NOT, + "and": self.Nonterminal.AND, + "or": self.Nonterminal.OR, + "in": self.Nonterminal.IN, + "between": self.Nonterminal.BETWEEN, + "not": self.Nonterminal.NOT, } functions = { - 'attribute_exists', - 'attribute_not_exists', - 'attribute_type', - 'begins_with', - 'contains', - 'size', + "attribute_exists", + "attribute_not_exists", + "attribute_type", + "begins_with", + "contains", + "size", } - if name.lower() in reserved: # e.g. AND nonterminal = reserved[name.lower()] @@ -360,7 +360,8 @@ class ConditionExpressionParser: kind=self.Kind.LITERAL, text=name, value=name, - children=[]) + children=[], + ) elif name in functions: # e.g. attribute_exists return self.Node( @@ -368,33 +369,37 @@ class ConditionExpressionParser: kind=self.Kind.LITERAL, text=name, value=name, - children=[]) - elif name.startswith(':'): + children=[], + ) + elif name.startswith(":"): # e.g. :value0 return self.Node( nonterminal=self.Nonterminal.OPERAND, kind=self.Kind.EXPRESSION_ATTRIBUTE_VALUE, text=name, value=self._lookup_expression_attribute_value(name), - children=[]) - elif name.startswith('#'): + children=[], + ) + elif name.startswith("#"): # e.g. #name0 return self.Node( nonterminal=self.Nonterminal.IDENTIFIER, kind=self.Kind.LITERAL, text=name, value=self._lookup_expression_attribute_name(name), - children=[]) - elif name.startswith('['): + children=[], + ) + elif name.startswith("["): # e.g. [123] - if not name.endswith(']'): # pragma: no cover + if not name.endswith("]"): # pragma: no cover raise ValueError("Bad path element %s" % name) return self.Node( nonterminal=self.Nonterminal.IDENTIFIER, kind=self.Kind.LITERAL, text=name, value=int(name[1:-1]), - children=[]) + children=[], + ) else: # e.g. ItemId return self.Node( @@ -402,7 +407,8 @@ class ConditionExpressionParser: kind=self.Kind.LITERAL, text=name, value=name, - children=[]) + children=[], + ) def _lookup_expression_attribute_value(self, name): return self.expression_attribute_values[name] @@ -465,7 +471,7 @@ class ConditionExpressionParser: if len(nodes) < len(production): return False for i in range(len(production)): - if production[i] == '*': + if production[i] == "*": continue expected = getattr(self.Nonterminal, production[i]) if nodes[i].nonterminal != expected: @@ -477,22 +483,24 @@ class ConditionExpressionParser: output = deque() while nodes: - if self._matches(nodes, ['*', 'COMPARATOR']): + if self._matches(nodes, ["*", "COMPARATOR"]): self._assert( - self._matches(nodes, ['OPERAND', 'COMPARATOR', 'OPERAND']), - "Bad comparison", list(nodes)[:3]) + self._matches(nodes, ["OPERAND", "COMPARATOR", "OPERAND"]), + "Bad comparison", + list(nodes)[:3], + ) lhs = nodes.popleft() comparator = nodes.popleft() rhs = nodes.popleft() - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.COMPARISON, - text=" ".join([ - lhs.text, - comparator.text, - rhs.text]), - value=None, - children=[lhs, comparator, rhs])) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.COMPARISON, + text=" ".join([lhs.text, comparator.text, rhs.text]), + value=None, + children=[lhs, comparator, rhs], + ) + ) else: output.append(nodes.popleft()) return output @@ -501,37 +509,40 @@ class ConditionExpressionParser: """Apply condition := operand IN ( operand , ... ).""" output = deque() while nodes: - if self._matches(nodes, ['*', 'IN']): + if self._matches(nodes, ["*", "IN"]): self._assert( - self._matches(nodes, ['OPERAND', 'IN', 'LEFT_PAREN']), - "Bad IN expression", list(nodes)[:3]) + self._matches(nodes, ["OPERAND", "IN", "LEFT_PAREN"]), + "Bad IN expression", + list(nodes)[:3], + ) lhs = nodes.popleft() in_node = nodes.popleft() left_paren = nodes.popleft() all_children = [lhs, in_node, left_paren] rhs = [] while True: - if self._matches(nodes, ['OPERAND', 'COMMA']): + if self._matches(nodes, ["OPERAND", "COMMA"]): operand = nodes.popleft() separator = nodes.popleft() all_children += [operand, separator] rhs.append(operand) - elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + elif self._matches(nodes, ["OPERAND", "RIGHT_PAREN"]): operand = nodes.popleft() separator = nodes.popleft() all_children += [operand, separator] rhs.append(operand) break # Close else: - self._assert( - False, - "Bad IN expression starting at", nodes) - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.IN, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs] + rhs)) + self._assert(False, "Bad IN expression starting at", nodes) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.IN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs] + rhs, + ) + ) else: output.append(nodes.popleft()) return output @@ -540,23 +551,29 @@ class ConditionExpressionParser: """Apply condition := operand BETWEEN operand AND operand.""" output = deque() while nodes: - if self._matches(nodes, ['*', 'BETWEEN']): + if self._matches(nodes, ["*", "BETWEEN"]): self._assert( - self._matches(nodes, ['OPERAND', 'BETWEEN', 'OPERAND', - 'AND', 'OPERAND']), - "Bad BETWEEN expression", list(nodes)[:5]) + self._matches( + nodes, ["OPERAND", "BETWEEN", "OPERAND", "AND", "OPERAND"] + ), + "Bad BETWEEN expression", + list(nodes)[:5], + ) lhs = nodes.popleft() between_node = nodes.popleft() low = nodes.popleft() and_node = nodes.popleft() high = nodes.popleft() all_children = [lhs, between_node, low, and_node, high] - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.BETWEEN, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, low, high])) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.BETWEEN, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, low, high], + ) + ) else: output.append(nodes.popleft()) return output @@ -566,30 +583,33 @@ class ConditionExpressionParser: output = deque() either_kind = {self.Kind.PATH, self.Kind.EXPRESSION_ATTRIBUTE_VALUE} expected_argument_kind_map = { - 'attribute_exists': [{self.Kind.PATH}], - 'attribute_not_exists': [{self.Kind.PATH}], - 'attribute_type': [either_kind, {self.Kind.EXPRESSION_ATTRIBUTE_VALUE}], - 'begins_with': [either_kind, either_kind], - 'contains': [either_kind, either_kind], - 'size': [{self.Kind.PATH}], + "attribute_exists": [{self.Kind.PATH}], + "attribute_not_exists": [{self.Kind.PATH}], + "attribute_type": [either_kind, {self.Kind.EXPRESSION_ATTRIBUTE_VALUE}], + "begins_with": [either_kind, either_kind], + "contains": [either_kind, either_kind], + "size": [{self.Kind.PATH}], } while nodes: - if self._matches(nodes, ['FUNCTION_NAME']): + if self._matches(nodes, ["FUNCTION_NAME"]): self._assert( - self._matches(nodes, ['FUNCTION_NAME', 'LEFT_PAREN', - 'OPERAND', '*']), - "Bad function expression at", list(nodes)[:4]) + self._matches( + nodes, ["FUNCTION_NAME", "LEFT_PAREN", "OPERAND", "*"] + ), + "Bad function expression at", + list(nodes)[:4], + ) function_name = nodes.popleft() left_paren = nodes.popleft() all_children = [function_name, left_paren] arguments = [] while True: - if self._matches(nodes, ['OPERAND', 'COMMA']): + if self._matches(nodes, ["OPERAND", "COMMA"]): operand = nodes.popleft() separator = nodes.popleft() all_children += [operand, separator] arguments.append(operand) - elif self._matches(nodes, ['OPERAND', 'RIGHT_PAREN']): + elif self._matches(nodes, ["OPERAND", "RIGHT_PAREN"]): operand = nodes.popleft() separator = nodes.popleft() all_children += [operand, separator] @@ -598,25 +618,34 @@ class ConditionExpressionParser: else: self._assert( False, - "Bad function expression", all_children + list(nodes)[:2]) + "Bad function expression", + all_children + list(nodes)[:2], + ) expected_kinds = expected_argument_kind_map[function_name.value] self._assert( len(arguments) == len(expected_kinds), - "Wrong number of arguments in", all_children) + "Wrong number of arguments in", + all_children, + ) for i in range(len(expected_kinds)): self._assert( arguments[i].kind in expected_kinds[i], - "Wrong type for argument %d in" % i, all_children) - if function_name.value == 'size': + "Wrong type for argument %d in" % i, + all_children, + ) + if function_name.value == "size": nonterminal = self.Nonterminal.OPERAND else: nonterminal = self.Nonterminal.CONDITION - nodes.appendleft(self.Node( - nonterminal=nonterminal, - kind=self.Kind.FUNCTION, - text=" ".join([t.text for t in all_children]), - value=None, - children=[function_name] + arguments)) + nodes.appendleft( + self.Node( + nonterminal=nonterminal, + kind=self.Kind.FUNCTION, + text=" ".join([t.text for t in all_children]), + value=None, + children=[function_name] + arguments, + ) + ) else: output.append(nodes.popleft()) return output @@ -625,38 +654,40 @@ class ConditionExpressionParser: """Apply condition := ( condition ) and booleans.""" output = deque() while nodes: - if self._matches(nodes, ['LEFT_PAREN']): - parsed = self._apply_parens_and_booleans(nodes, left_paren=nodes.popleft()) - self._assert( - len(parsed) >= 1, - "Failed to close parentheses at", nodes) + if self._matches(nodes, ["LEFT_PAREN"]): + parsed = self._apply_parens_and_booleans( + nodes, left_paren=nodes.popleft() + ) + self._assert(len(parsed) >= 1, "Failed to close parentheses at", nodes) parens = parsed.popleft() self._assert( parens.kind == self.Kind.PARENTHESES, - "Failed to close parentheses at", nodes) + "Failed to close parentheses at", + nodes, + ) output.append(parens) nodes = parsed - elif self._matches(nodes, ['RIGHT_PAREN']): - self._assert( - left_paren is not None, - "Unmatched ) at", nodes) + elif self._matches(nodes, ["RIGHT_PAREN"]): + self._assert(left_paren is not None, "Unmatched ) at", nodes) close_paren = nodes.popleft() children = self._apply_booleans(output) all_children = [left_paren] + list(children) + [close_paren] - return deque([ - self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.PARENTHESES, - text=" ".join([t.text for t in all_children]), - value=None, - children=list(children), - )] + list(nodes)) + return deque( + [ + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.PARENTHESES, + text=" ".join([t.text for t in all_children]), + value=None, + children=list(children), + ) + ] + + list(nodes) + ) else: output.append(nodes.popleft()) - self._assert( - left_paren is None, - "Unmatched ( at", list(output)) + self._assert(left_paren is None, "Unmatched ( at", list(output)) return self._apply_booleans(output) def _apply_booleans(self, nodes): @@ -665,30 +696,35 @@ class ConditionExpressionParser: nodes = self._apply_and(nodes) nodes = self._apply_or(nodes) # The expression should reduce to a single condition - self._assert( - len(nodes) == 1, - "Unexpected expression at", list(nodes)[1:]) + self._assert(len(nodes) == 1, "Unexpected expression at", list(nodes)[1:]) self._assert( nodes[0].nonterminal == self.Nonterminal.CONDITION, - "Incomplete condition", nodes) + "Incomplete condition", + nodes, + ) return nodes def _apply_not(self, nodes): """Apply condition := NOT condition.""" output = deque() while nodes: - if self._matches(nodes, ['NOT']): + if self._matches(nodes, ["NOT"]): self._assert( - self._matches(nodes, ['NOT', 'CONDITION']), - "Bad NOT expression", list(nodes)[:2]) + self._matches(nodes, ["NOT", "CONDITION"]), + "Bad NOT expression", + list(nodes)[:2], + ) not_node = nodes.popleft() child = nodes.popleft() - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.NOT, - text=" ".join([not_node.text, child.text]), - value=None, - children=[child])) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.NOT, + text=" ".join([not_node.text, child.text]), + value=None, + children=[child], + ) + ) else: output.append(nodes.popleft()) @@ -698,20 +734,25 @@ class ConditionExpressionParser: """Apply condition := condition AND condition.""" output = deque() while nodes: - if self._matches(nodes, ['*', 'AND']): + if self._matches(nodes, ["*", "AND"]): self._assert( - self._matches(nodes, ['CONDITION', 'AND', 'CONDITION']), - "Bad AND expression", list(nodes)[:3]) + self._matches(nodes, ["CONDITION", "AND", "CONDITION"]), + "Bad AND expression", + list(nodes)[:3], + ) lhs = nodes.popleft() and_node = nodes.popleft() rhs = nodes.popleft() all_children = [lhs, and_node, rhs] - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.AND, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, rhs])) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.AND, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs], + ) + ) else: output.append(nodes.popleft()) @@ -721,20 +762,25 @@ class ConditionExpressionParser: """Apply condition := condition OR condition.""" output = deque() while nodes: - if self._matches(nodes, ['*', 'OR']): + if self._matches(nodes, ["*", "OR"]): self._assert( - self._matches(nodes, ['CONDITION', 'OR', 'CONDITION']), - "Bad OR expression", list(nodes)[:3]) + self._matches(nodes, ["CONDITION", "OR", "CONDITION"]), + "Bad OR expression", + list(nodes)[:3], + ) lhs = nodes.popleft() or_node = nodes.popleft() rhs = nodes.popleft() all_children = [lhs, or_node, rhs] - nodes.appendleft(self.Node( - nonterminal=self.Nonterminal.CONDITION, - kind=self.Kind.OR, - text=" ".join([t.text for t in all_children]), - value=None, - children=[lhs, rhs])) + nodes.appendleft( + self.Node( + nonterminal=self.Nonterminal.CONDITION, + kind=self.Kind.OR, + text=" ".join([t.text for t in all_children]), + value=None, + children=[lhs, rhs], + ) + ) else: output.append(nodes.popleft()) @@ -748,30 +794,25 @@ class ConditionExpressionParser: elif node.kind == self.Kind.FUNCTION: # size() function_node = node.children[0] - arguments = node.children[1:] + arguments = node.children[1:] function_name = function_node.value arguments = [self._make_operand(arg) for arg in arguments] return FUNC_CLASS[function_name](*arguments) - else: # pragma: no cover + else: # pragma: no cover raise ValueError("Unknown operand: %r" % node) - def _make_op_condition(self, node): if node.kind == self.Kind.OR: lhs, rhs = node.children - return OpOr( - self._make_op_condition(lhs), - self._make_op_condition(rhs)) + return OpOr(self._make_op_condition(lhs), self._make_op_condition(rhs)) elif node.kind == self.Kind.AND: lhs, rhs = node.children - return OpAnd( - self._make_op_condition(lhs), - self._make_op_condition(rhs)) + return OpAnd(self._make_op_condition(lhs), self._make_op_condition(rhs)) elif node.kind == self.Kind.NOT: - child, = node.children + (child,) = node.children return OpNot(self._make_op_condition(child)) elif node.kind == self.Kind.PARENTHESES: - child, = node.children + (child,) = node.children return self._make_op_condition(child) elif node.kind == self.Kind.FUNCTION: function_node = node.children[0] @@ -784,7 +825,8 @@ class ConditionExpressionParser: return FuncBetween( self._make_operand(query), self._make_operand(low), - self._make_operand(high)) + self._make_operand(high), + ) elif node.kind == self.Kind.IN: query = node.children[0] possible_values = node.children[1:] @@ -794,9 +836,9 @@ class ConditionExpressionParser: elif node.kind == self.Kind.COMPARISON: lhs, comparator, rhs = node.children return COMPARATOR_CLASS[comparator.value]( - self._make_operand(lhs), - self._make_operand(rhs)) - else: # pragma: no cover + self._make_operand(lhs), self._make_operand(rhs) + ) + else: # pragma: no cover raise ValueError("Unknown expression node kind %r" % node.kind) def _assert(self, condition, message, nodes): @@ -873,21 +915,20 @@ class AttributeValue(Operand): def expr(self, item): # TODO: Reuse DynamoType code - if self.type == 'N': + if self.type == "N": try: return int(self.value) except ValueError: return float(self.value) - elif self.type in ['SS', 'NS', 'BS']: + elif self.type in ["SS", "NS", "BS"]: sub_type = self.type[0] - return set([AttributeValue({sub_type: v}).expr(item) - for v in self.value]) - elif self.type == 'L': + return set([AttributeValue({sub_type: v}).expr(item) for v in self.value]) + elif self.type == "L": return [AttributeValue(v).expr(item) for v in self.value] - elif self.type == 'M': - return dict([ - (k, AttributeValue(v).expr(item)) - for k, v in self.value.items()]) + elif self.type == "M": + return dict( + [(k, AttributeValue(v).expr(item)) for k, v in self.value.items()] + ) else: return self.value return self.value @@ -900,7 +941,7 @@ class AttributeValue(Operand): class OpDefault(Op): - OP = 'NONE' + OP = "NONE" def expr(self, item): """If no condition is specified, always True.""" @@ -908,7 +949,7 @@ class OpDefault(Op): class OpNot(Op): - OP = 'NOT' + OP = "NOT" def __init__(self, lhs): super(OpNot, self).__init__(lhs, None) @@ -918,11 +959,11 @@ class OpNot(Op): return not lhs def __str__(self): - return '({0} {1})'.format(self.OP, self.lhs) + return "({0} {1})".format(self.OP, self.lhs) class OpAnd(Op): - OP = 'AND' + OP = "AND" def expr(self, item): lhs = self.lhs.expr(item) @@ -930,7 +971,7 @@ class OpAnd(Op): class OpLessThan(Op): - OP = '<' + OP = "<" def expr(self, item): lhs = self.lhs.expr(item) @@ -945,7 +986,7 @@ class OpLessThan(Op): class OpGreaterThan(Op): - OP = '>' + OP = ">" def expr(self, item): lhs = self.lhs.expr(item) @@ -960,7 +1001,7 @@ class OpGreaterThan(Op): class OpEqual(Op): - OP = '=' + OP = "=" def expr(self, item): lhs = self.lhs.expr(item) @@ -969,7 +1010,7 @@ class OpEqual(Op): class OpNotEqual(Op): - OP = '<>' + OP = "<>" def expr(self, item): lhs = self.lhs.expr(item) @@ -978,7 +1019,7 @@ class OpNotEqual(Op): class OpLessThanOrEqual(Op): - OP = '<=' + OP = "<=" def expr(self, item): lhs = self.lhs.expr(item) @@ -993,7 +1034,7 @@ class OpLessThanOrEqual(Op): class OpGreaterThanOrEqual(Op): - OP = '>=' + OP = ">=" def expr(self, item): lhs = self.lhs.expr(item) @@ -1008,7 +1049,7 @@ class OpGreaterThanOrEqual(Op): class OpOr(Op): - OP = 'OR' + OP = "OR" def expr(self, item): lhs = self.lhs.expr(item) @@ -1019,7 +1060,8 @@ class Func(object): """ Base class for a FilterExpression function """ - FUNC = 'Unknown' + + FUNC = "Unknown" def __init__(self, *arguments): self.arguments = arguments @@ -1028,13 +1070,13 @@ class Func(object): raise NotImplementedError def __repr__(self): - return '{0}({1})'.format( - self.FUNC, - " ".join([repr(arg) for arg in self.arguments])) + return "{0}({1})".format( + self.FUNC, " ".join([repr(arg) for arg in self.arguments]) + ) class FuncAttrExists(Func): - FUNC = 'attribute_exists' + FUNC = "attribute_exists" def __init__(self, attribute): self.attr = attribute @@ -1049,7 +1091,7 @@ def FuncAttrNotExists(attribute): class FuncAttrType(Func): - FUNC = 'attribute_type' + FUNC = "attribute_type" def __init__(self, attribute, _type): self.attr = attribute @@ -1061,7 +1103,7 @@ class FuncAttrType(Func): class FuncBeginsWith(Func): - FUNC = 'begins_with' + FUNC = "begins_with" def __init__(self, attribute, substr): self.attr = attribute @@ -1069,15 +1111,15 @@ class FuncBeginsWith(Func): super(FuncBeginsWith, self).__init__(attribute, substr) def expr(self, item): - if self.attr.get_type(item) != 'S': + if self.attr.get_type(item) != "S": return False - if self.substr.get_type(item) != 'S': + if self.substr.get_type(item) != "S": return False return self.attr.expr(item).startswith(self.substr.expr(item)) class FuncContains(Func): - FUNC = 'contains' + FUNC = "contains" def __init__(self, attribute, operand): self.attr = attribute @@ -1085,7 +1127,7 @@ class FuncContains(Func): super(FuncContains, self).__init__(attribute, operand) def expr(self, item): - if self.attr.get_type(item) in ('S', 'SS', 'NS', 'BS', 'L'): + if self.attr.get_type(item) in ("S", "SS", "NS", "BS", "L"): try: return self.operand.expr(item) in self.attr.expr(item) except TypeError: @@ -1098,7 +1140,7 @@ def FuncNotContains(attribute, operand): class FuncSize(Func): - FUNC = 'size' + FUNC = "size" def __init__(self, attribute): self.attr = attribute @@ -1106,15 +1148,15 @@ class FuncSize(Func): def expr(self, item): if self.attr.get_type(item) is None: - raise ValueError('Invalid attribute name {0}'.format(self.attr)) + raise ValueError("Invalid attribute name {0}".format(self.attr)) - if self.attr.get_type(item) in ('S', 'SS', 'NS', 'B', 'BS', 'L', 'M'): + if self.attr.get_type(item) in ("S", "SS", "NS", "B", "BS", "L", "M"): return len(self.attr.expr(item)) - raise ValueError('Invalid filter expression') + raise ValueError("Invalid filter expression") class FuncBetween(Func): - FUNC = 'BETWEEN' + FUNC = "BETWEEN" def __init__(self, attribute, start, end): self.attr = attribute @@ -1139,7 +1181,7 @@ class FuncBetween(Func): class FuncIn(Func): - FUNC = 'IN' + FUNC = "IN" def __init__(self, attribute, *possible_values): self.attr = attribute @@ -1155,20 +1197,20 @@ class FuncIn(Func): COMPARATOR_CLASS = { - '<': OpLessThan, - '>': OpGreaterThan, - '<=': OpLessThanOrEqual, - '>=': OpGreaterThanOrEqual, - '=': OpEqual, - '<>': OpNotEqual + "<": OpLessThan, + ">": OpGreaterThan, + "<=": OpLessThanOrEqual, + ">=": OpGreaterThanOrEqual, + "=": OpEqual, + "<>": OpNotEqual, } FUNC_CLASS = { - 'attribute_exists': FuncAttrExists, - 'attribute_not_exists': FuncAttrNotExists, - 'attribute_type': FuncAttrType, - 'begins_with': FuncBeginsWith, - 'contains': FuncContains, - 'size': FuncSize, - 'between': FuncBetween + "attribute_exists": FuncAttrExists, + "attribute_not_exists": FuncAttrNotExists, + "attribute_type": FuncAttrType, + "begins_with": FuncBeginsWith, + "contains": FuncContains, + "size": FuncSize, + "between": FuncBetween, } diff --git a/moto/dynamodb2/exceptions.py b/moto/dynamodb2/exceptions.py index ef5d2b982..1f3b5f974 100644 --- a/moto/dynamodb2/exceptions.py +++ b/moto/dynamodb2/exceptions.py @@ -7,4 +7,4 @@ class InvalidUpdateExpression(ValueError): class ItemSizeTooLarge(Exception): - message = 'Item size has exceeded the maximum allowed size' + message = "Item size has exceeded the maximum allowed size" diff --git a/moto/dynamodb2/models.py b/moto/dynamodb2/models.py index 0b8fdfbf2..cd49d7b1f 100644 --- a/moto/dynamodb2/models.py +++ b/moto/dynamodb2/models.py @@ -21,9 +21,8 @@ from .exceptions import InvalidIndexNameError, InvalidUpdateExpression, ItemSize class DynamoJsonEncoder(json.JSONEncoder): - def default(self, obj): - if hasattr(obj, 'to_json'): + if hasattr(obj, "to_json"): return obj.to_json() @@ -32,7 +31,7 @@ def dynamo_json_dump(dynamo_object): def bytesize(val): - return len(str(val).encode('utf-8')) + return len(str(val).encode("utf-8")) def attribute_is_list(attr): @@ -41,7 +40,7 @@ def attribute_is_list(attr): :param attr: attr or attr[index] :return: attr, index or None """ - list_index_update = re.match('(.+)\\[([0-9]+)\\]', attr) + list_index_update = re.match("(.+)\\[([0-9]+)\\]", attr) if list_index_update: attr = list_index_update.group(1) return attr, list_index_update.group(2) if list_index_update else None @@ -74,35 +73,37 @@ class DynamoType(object): # {'L': [DynamoType, ..]} ==> DynamoType.set() self.value[min(index, len(self.value) - 1)].set(key, new_value) else: - attr = (key or '').split('.').pop(0) + attr = (key or "").split(".").pop(0) attr, list_index = attribute_is_list(attr) if not key: # {'S': value} ==> {'S': new_value} self.value = new_value.value else: if attr not in self.value: # nonexistingattribute - type_of_new_attr = 'M' if '.' in key else new_value.type + type_of_new_attr = "M" if "." in key else new_value.type self.value[attr] = DynamoType({type_of_new_attr: {}}) # {'M': {'foo': DynamoType}} ==> DynamoType.set(new_value) - self.value[attr].set('.'.join(key.split('.')[1:]), new_value, list_index) + self.value[attr].set( + ".".join(key.split(".")[1:]), new_value, list_index + ) def delete(self, key, index=None): if index: if not key: if int(index) < len(self.value): del self.value[int(index)] - elif '.' in key: - self.value[int(index)].delete('.'.join(key.split('.')[1:])) + elif "." in key: + self.value[int(index)].delete(".".join(key.split(".")[1:])) else: self.value[int(index)].delete(key) else: - attr = key.split('.')[0] + attr = key.split(".")[0] attr, list_index = attribute_is_list(attr) if list_index: - self.value[attr].delete('.'.join(key.split('.')[1:]), list_index) - elif '.' in key: - self.value[attr].delete('.'.join(key.split('.')[1:])) + self.value[attr].delete(".".join(key.split(".")[1:]), list_index) + elif "." in key: + self.value[attr].delete(".".join(key.split(".")[1:])) else: self.value.pop(key) @@ -110,10 +111,7 @@ class DynamoType(object): return hash((self.type, self.value)) def __eq__(self, other): - return ( - self.type == other.type and - self.value == other.value - ) + return self.type == other.type and self.value == other.value def __lt__(self, other): return self.cast_value < other.cast_value @@ -139,14 +137,11 @@ class DynamoType(object): return float(self.value) elif self.is_set(): sub_type = self.type[0] - return set([DynamoType({sub_type: v}).cast_value - for v in self.value]) + return set([DynamoType({sub_type: v}).cast_value for v in self.value]) elif self.is_list(): return [DynamoType(v).cast_value for v in self.value] elif self.is_map(): - return dict([ - (k, DynamoType(v).cast_value) - for k, v in self.value.items()]) + return dict([(k, DynamoType(v).cast_value) for k, v in self.value.items()]) else: return self.value @@ -175,7 +170,9 @@ class DynamoType(object): elif self.is_list(): value_size = sum([v.size() for v in self.value]) elif self.is_map(): - value_size = sum([bytesize(k) + DynamoType(v).size() for k, v in self.value.items()]) + value_size = sum( + [bytesize(k) + DynamoType(v).size() for k, v in self.value.items()] + ) elif type(self.value) == bool: value_size = 1 else: @@ -194,16 +191,16 @@ class DynamoType(object): return comparison_func(self.cast_value, *range_values) def is_number(self): - return self.type == 'N' + return self.type == "N" def is_set(self): - return self.type == 'SS' or self.type == 'NS' or self.type == 'BS' + return self.type == "SS" or self.type == "NS" or self.type == "BS" def is_list(self): - return self.type == 'L' + return self.type == "L" def is_map(self): - return self.type == 'M' + return self.type == "M" def same_type(self, other): return self.type == other.type @@ -216,8 +213,15 @@ class LimitedSizeDict(dict): self.update(*args, **kwargs) def __setitem__(self, key, value): - current_item_size = sum([item.size() if type(item) == DynamoType else bytesize(str(item)) for item in (list(self.keys()) + list(self.values()))]) - new_item_size = bytesize(key) + (value.size() if type(value) == DynamoType else bytesize(str(value))) + current_item_size = sum( + [ + item.size() if type(item) == DynamoType else bytesize(str(item)) + for item in (list(self.keys()) + list(self.values())) + ] + ) + new_item_size = bytesize(key) + ( + value.size() if type(value) == DynamoType else bytesize(str(value)) + ) # Official limit is set to 400000 (400KB) # Manual testing confirms that the actual limit is between 409 and 410KB # We'll set the limit to something in between to be safe @@ -227,7 +231,6 @@ class LimitedSizeDict(dict): class Item(BaseModel): - def __init__(self, hash_key, hash_key_type, range_key, range_key_type, attrs): self.hash_key = hash_key self.hash_key_type = hash_key_type @@ -244,13 +247,9 @@ class Item(BaseModel): def to_json(self): attributes = {} for attribute_key, attribute in self.attrs.items(): - attributes[attribute_key] = { - attribute.type: attribute.value - } + attributes[attribute_key] = {attribute.type: attribute.value} - return { - "Attributes": attributes - } + return {"Attributes": attributes} def describe_attrs(self, attributes): if attributes: @@ -260,31 +259,41 @@ class Item(BaseModel): included[key] = value else: included = self.attrs - return { - "Item": included - } + return {"Item": included} - def update(self, update_expression, expression_attribute_names, expression_attribute_values): + def update( + self, update_expression, expression_attribute_names, expression_attribute_values + ): # Update subexpressions are identifiable by the operator keyword, so split on that and # get rid of the empty leading string. - parts = [p for p in re.split(r'\b(SET|REMOVE|ADD|DELETE)\b', update_expression, flags=re.I) if p] + parts = [ + p + for p in re.split( + r"\b(SET|REMOVE|ADD|DELETE)\b", update_expression, flags=re.I + ) + if p + ] # make sure that we correctly found only operator/value pairs - assert len(parts) % 2 == 0, "Mismatched operators and values in update expression: '{}'".format(update_expression) + assert ( + len(parts) % 2 == 0 + ), "Mismatched operators and values in update expression: '{}'".format( + update_expression + ) for action, valstr in zip(parts[:-1:2], parts[1::2]): action = action.upper() # "Should" retain arguments in side (...) - values = re.split(r',(?![^(]*\))', valstr) + values = re.split(r",(?![^(]*\))", valstr) for value in values: # A Real value value = value.lstrip(":").rstrip(",").strip() for k, v in expression_attribute_names.items(): - value = re.sub(r'{0}\b'.format(k), v, value) + value = re.sub(r"{0}\b".format(k), v, value) if action == "REMOVE": key = value - attr, list_index = attribute_is_list(key.split('.')[0]) - if '.' not in key: + attr, list_index = attribute_is_list(key.split(".")[0]) + if "." not in key: if list_index: new_list = DynamoType(self.attrs[attr]) new_list.delete(None, list_index) @@ -293,14 +302,14 @@ class Item(BaseModel): self.attrs.pop(value, None) else: # Handle nested dict updates - self.attrs[attr].delete('.'.join(key.split('.')[1:])) - elif action == 'SET': + self.attrs[attr].delete(".".join(key.split(".")[1:])) + elif action == "SET": key, value = value.split("=", 1) key = key.strip() value = value.strip() # check whether key is a list - attr, list_index = attribute_is_list(key.split('.')[0]) + attr, list_index = attribute_is_list(key.split(".")[0]) # If value not exists, changes value to a default if needed, else its the same as it was value = self._get_default(value) # If operation == list_append, get the original value and append it @@ -314,14 +323,16 @@ class Item(BaseModel): else: dyn_value = value - if '.' in key and attr not in self.attrs: + if "." in key and attr not in self.attrs: raise ValueError # Setting nested attr not allowed if first attr does not exist yet elif attr not in self.attrs: self.attrs[attr] = dyn_value # set new top-level attribute else: - self.attrs[attr].set('.'.join(key.split('.')[1:]), dyn_value, list_index) # set value recursively + self.attrs[attr].set( + ".".join(key.split(".")[1:]), dyn_value, list_index + ) # set value recursively - elif action == 'ADD': + elif action == "ADD": key, value = value.split(" ", 1) key = key.strip() value_str = value.strip() @@ -333,13 +344,17 @@ class Item(BaseModel): # Handle adding numbers - value gets added to existing value, # or added to 0 if it doesn't exist yet if dyn_value.is_number(): - existing = self.attrs.get(key, DynamoType({"N": '0'})) + existing = self.attrs.get(key, DynamoType({"N": "0"})) if not existing.same_type(dyn_value): raise TypeError() - self.attrs[key] = DynamoType({"N": str( - decimal.Decimal(existing.value) + - decimal.Decimal(dyn_value.value) - )}) + self.attrs[key] = DynamoType( + { + "N": str( + decimal.Decimal(existing.value) + + decimal.Decimal(dyn_value.value) + ) + } + ) # Handle adding sets - value is added to the set, or set is # created with only this value if it doesn't exist yet @@ -353,7 +368,7 @@ class Item(BaseModel): else: # Number and Sets are the only supported types for ADD raise TypeError - elif action == 'DELETE': + elif action == "DELETE": key, value = value.split(" ", 1) key = key.strip() value_str = value.strip() @@ -371,24 +386,28 @@ class Item(BaseModel): new_set = set(existing.value).difference(dyn_value.value) self.attrs[key] = DynamoType({existing.type: list(new_set)}) else: - raise NotImplementedError('{} update action not yet supported'.format(action)) + raise NotImplementedError( + "{} update action not yet supported".format(action) + ) def _get_appended_list(self, value, expression_attribute_values): if type(value) != DynamoType: - list_append_re = re.match('list_append\\((.+),(.+)\\)', value) + list_append_re = re.match("list_append\\((.+),(.+)\\)", value) if list_append_re: new_value = expression_attribute_values[list_append_re.group(2).strip()] old_list = self.attrs[list_append_re.group(1)] if not old_list.is_list(): raise ParamValidationError - old_list.value.extend(new_value['L']) + old_list.value.extend(new_value["L"]) value = old_list return value def _get_default(self, value): - if value.startswith('if_not_exists'): + if value.startswith("if_not_exists"): # Function signature - match = re.match(r'.*if_not_exists\s*\((?P.+),\s*(?P.+)\).*', value) + match = re.match( + r".*if_not_exists\s*\((?P.+),\s*(?P.+)\).*", value + ) if not match: raise TypeError @@ -401,13 +420,13 @@ class Item(BaseModel): def update_with_attribute_updates(self, attribute_updates): for attribute_name, update_action in attribute_updates.items(): - action = update_action['Action'] - if action == 'DELETE' and 'Value' not in update_action: + action = update_action["Action"] + if action == "DELETE" and "Value" not in update_action: if attribute_name in self.attrs: del self.attrs[attribute_name] continue - new_value = list(update_action['Value'].values())[0] - if action == 'PUT': + new_value = list(update_action["Value"].values())[0] + if action == "PUT": # TODO deal with other types if isinstance(new_value, list): self.attrs[attribute_name] = DynamoType({"L": new_value}) @@ -415,50 +434,54 @@ class Item(BaseModel): self.attrs[attribute_name] = DynamoType({"SS": new_value}) elif isinstance(new_value, dict): self.attrs[attribute_name] = DynamoType({"M": new_value}) - elif set(update_action['Value'].keys()) == set(['N']): + elif set(update_action["Value"].keys()) == set(["N"]): self.attrs[attribute_name] = DynamoType({"N": new_value}) - elif set(update_action['Value'].keys()) == set(['NULL']): + elif set(update_action["Value"].keys()) == set(["NULL"]): if attribute_name in self.attrs: del self.attrs[attribute_name] else: self.attrs[attribute_name] = DynamoType({"S": new_value}) - elif action == 'ADD': - if set(update_action['Value'].keys()) == set(['N']): - existing = self.attrs.get( - attribute_name, DynamoType({"N": '0'})) - self.attrs[attribute_name] = DynamoType({"N": str( - decimal.Decimal(existing.value) + - decimal.Decimal(new_value) - )}) - elif set(update_action['Value'].keys()) == set(['SS']): + elif action == "ADD": + if set(update_action["Value"].keys()) == set(["N"]): + existing = self.attrs.get(attribute_name, DynamoType({"N": "0"})) + self.attrs[attribute_name] = DynamoType( + { + "N": str( + decimal.Decimal(existing.value) + + decimal.Decimal(new_value) + ) + } + ) + elif set(update_action["Value"].keys()) == set(["SS"]): existing = self.attrs.get(attribute_name, DynamoType({"SS": {}})) new_set = set(existing.value).union(set(new_value)) - self.attrs[attribute_name] = DynamoType({ - "SS": list(new_set) - }) + self.attrs[attribute_name] = DynamoType({"SS": list(new_set)}) else: # TODO: implement other data types raise NotImplementedError( - 'ADD not supported for %s' % ', '.join(update_action['Value'].keys())) - elif action == 'DELETE': - if set(update_action['Value'].keys()) == set(['SS']): + "ADD not supported for %s" + % ", ".join(update_action["Value"].keys()) + ) + elif action == "DELETE": + if set(update_action["Value"].keys()) == set(["SS"]): existing = self.attrs.get(attribute_name, DynamoType({"SS": {}})) new_set = set(existing.value).difference(set(new_value)) - self.attrs[attribute_name] = DynamoType({ - "SS": list(new_set) - }) + self.attrs[attribute_name] = DynamoType({"SS": list(new_set)}) else: raise NotImplementedError( - 'ADD not supported for %s' % ', '.join(update_action['Value'].keys())) + "ADD not supported for %s" + % ", ".join(update_action["Value"].keys()) + ) else: raise NotImplementedError( - '%s action not support for update_with_attribute_updates' % action) + "%s action not support for update_with_attribute_updates" % action + ) class StreamRecord(BaseModel): def __init__(self, table, stream_type, event_name, old, new, seq): - old_a = old.to_json()['Attributes'] if old is not None else {} - new_a = new.to_json()['Attributes'] if new is not None else {} + old_a = old.to_json()["Attributes"] if old is not None else {} + new_a = new.to_json()["Attributes"] if new is not None else {} rec = old if old is not None else new keys = {table.hash_key_attr: rec.hash_key.to_json()} @@ -466,28 +489,27 @@ class StreamRecord(BaseModel): keys[table.range_key_attr] = rec.range_key.to_json() self.record = { - 'eventID': uuid.uuid4().hex, - 'eventName': event_name, - 'eventSource': 'aws:dynamodb', - 'eventVersion': '1.0', - 'awsRegion': 'us-east-1', - 'dynamodb': { - 'StreamViewType': stream_type, - 'ApproximateCreationDateTime': datetime.datetime.utcnow().isoformat(), - 'SequenceNumber': str(seq), - 'SizeBytes': 1, - 'Keys': keys - } + "eventID": uuid.uuid4().hex, + "eventName": event_name, + "eventSource": "aws:dynamodb", + "eventVersion": "1.0", + "awsRegion": "us-east-1", + "dynamodb": { + "StreamViewType": stream_type, + "ApproximateCreationDateTime": datetime.datetime.utcnow().isoformat(), + "SequenceNumber": str(seq), + "SizeBytes": 1, + "Keys": keys, + }, } - if stream_type in ('NEW_IMAGE', 'NEW_AND_OLD_IMAGES'): - self.record['dynamodb']['NewImage'] = new_a - if stream_type in ('OLD_IMAGE', 'NEW_AND_OLD_IMAGES'): - self.record['dynamodb']['OldImage'] = old_a + if stream_type in ("NEW_IMAGE", "NEW_AND_OLD_IMAGES"): + self.record["dynamodb"]["NewImage"] = new_a + if stream_type in ("OLD_IMAGE", "NEW_AND_OLD_IMAGES"): + self.record["dynamodb"]["OldImage"] = old_a # This is a substantial overestimate but it's the easiest to do now - self.record['dynamodb']['SizeBytes'] = len( - json.dumps(self.record['dynamodb'])) + self.record["dynamodb"]["SizeBytes"] = len(json.dumps(self.record["dynamodb"])) def to_json(self): return self.record @@ -496,36 +518,40 @@ class StreamRecord(BaseModel): class StreamShard(BaseModel): def __init__(self, table): self.table = table - self.id = 'shardId-00000001541626099285-f35f62ef' + self.id = "shardId-00000001541626099285-f35f62ef" self.starting_sequence_number = 1100000000017454423009 self.items = [] self.created_on = datetime.datetime.utcnow() def to_json(self): return { - 'ShardId': self.id, - 'SequenceNumberRange': { - 'StartingSequenceNumber': str(self.starting_sequence_number) - } + "ShardId": self.id, + "SequenceNumberRange": { + "StartingSequenceNumber": str(self.starting_sequence_number) + }, } def add(self, old, new): - t = self.table.stream_specification['StreamViewType'] + t = self.table.stream_specification["StreamViewType"] if old is None: - event_name = 'INSERT' + event_name = "INSERT" elif new is None: - event_name = 'DELETE' + event_name = "DELETE" else: - event_name = 'MODIFY' + event_name = "MODIFY" seq = len(self.items) + self.starting_sequence_number - self.items.append( - StreamRecord(self.table, t, event_name, old, new, seq)) + self.items.append(StreamRecord(self.table, t, event_name, old, new, seq)) result = None from moto.awslambda import lambda_backends - for arn, esm in self.table.lambda_event_source_mappings.items(): - region = arn[len('arn:aws:lambda:'):arn.index(':', len('arn:aws:lambda:'))] - result = lambda_backends[region].send_dynamodb_items(arn, self.items, esm.event_source_arn) + for arn, esm in self.table.lambda_event_source_mappings.items(): + region = arn[ + len("arn:aws:lambda:") : arn.index(":", len("arn:aws:lambda:")) + ] + + result = lambda_backends[region].send_dynamodb_items( + arn, self.items, esm.event_source_arn + ) if result: self.items = [] @@ -538,8 +564,16 @@ class StreamShard(BaseModel): class Table(BaseModel): - - def __init__(self, table_name, schema=None, attr=None, throughput=None, indexes=None, global_indexes=None, streams=None): + def __init__( + self, + table_name, + schema=None, + attr=None, + throughput=None, + indexes=None, + global_indexes=None, + streams=None, + ): self.name = table_name self.attr = attr self.schema = schema @@ -555,8 +589,7 @@ class Table(BaseModel): self.range_key_attr = elem["AttributeName"] self.range_key_type = elem["KeyType"] if throughput is None: - self.throughput = { - 'WriteCapacityUnits': 10, 'ReadCapacityUnits': 10} + self.throughput = {"WriteCapacityUnits": 10, "ReadCapacityUnits": 10} else: self.throughput = throughput self.throughput["NumberOfDecreasesToday"] = 0 @@ -567,66 +600,72 @@ class Table(BaseModel): self.table_arn = self._generate_arn(table_name) self.tags = [] self.ttl = { - 'TimeToLiveStatus': 'DISABLED' # One of 'ENABLING'|'DISABLING'|'ENABLED'|'DISABLED', + "TimeToLiveStatus": "DISABLED" # One of 'ENABLING'|'DISABLING'|'ENABLED'|'DISABLED', # 'AttributeName': 'string' # Can contain this } self.set_stream_specification(streams) self.lambda_event_source_mappings = {} @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] params = {} - if 'KeySchema' in properties: - params['schema'] = properties['KeySchema'] - if 'AttributeDefinitions' in properties: - params['attr'] = properties['AttributeDefinitions'] - if 'GlobalSecondaryIndexes' in properties: - params['global_indexes'] = properties['GlobalSecondaryIndexes'] - if 'ProvisionedThroughput' in properties: - params['throughput'] = properties['ProvisionedThroughput'] - if 'LocalSecondaryIndexes' in properties: - params['indexes'] = properties['LocalSecondaryIndexes'] + if "KeySchema" in properties: + params["schema"] = properties["KeySchema"] + if "AttributeDefinitions" in properties: + params["attr"] = properties["AttributeDefinitions"] + if "GlobalSecondaryIndexes" in properties: + params["global_indexes"] = properties["GlobalSecondaryIndexes"] + if "ProvisionedThroughput" in properties: + params["throughput"] = properties["ProvisionedThroughput"] + if "LocalSecondaryIndexes" in properties: + params["indexes"] = properties["LocalSecondaryIndexes"] - table = dynamodb_backends[region_name].create_table(name=properties['TableName'], **params) + table = dynamodb_backends[region_name].create_table( + name=properties["TableName"], **params + ) return table def _generate_arn(self, name): - return 'arn:aws:dynamodb:us-east-1:123456789011:table/' + name + return "arn:aws:dynamodb:us-east-1:123456789011:table/" + name def set_stream_specification(self, streams): self.stream_specification = streams - if streams and (streams.get('StreamEnabled') or streams.get('StreamViewType')): - self.stream_specification['StreamEnabled'] = True + if streams and (streams.get("StreamEnabled") or streams.get("StreamViewType")): + self.stream_specification["StreamEnabled"] = True self.latest_stream_label = datetime.datetime.utcnow().isoformat() self.stream_shard = StreamShard(self) else: - self.stream_specification = {'StreamEnabled': False} + self.stream_specification = {"StreamEnabled": False} self.latest_stream_label = None self.stream_shard = None - def describe(self, base_key='TableDescription'): + def describe(self, base_key="TableDescription"): results = { base_key: { - 'AttributeDefinitions': self.attr, - 'ProvisionedThroughput': self.throughput, - 'TableSizeBytes': 0, - 'TableName': self.name, - 'TableStatus': 'ACTIVE', - 'TableArn': self.table_arn, - 'KeySchema': self.schema, - 'ItemCount': len(self), - 'CreationDateTime': unix_time(self.created_at), - 'GlobalSecondaryIndexes': [index for index in self.global_indexes], - 'LocalSecondaryIndexes': [index for index in self.indexes], + "AttributeDefinitions": self.attr, + "ProvisionedThroughput": self.throughput, + "TableSizeBytes": 0, + "TableName": self.name, + "TableStatus": "ACTIVE", + "TableArn": self.table_arn, + "KeySchema": self.schema, + "ItemCount": len(self), + "CreationDateTime": unix_time(self.created_at), + "GlobalSecondaryIndexes": [index for index in self.global_indexes], + "LocalSecondaryIndexes": [index for index in self.indexes], } } - if self.stream_specification and self.stream_specification['StreamEnabled']: - results[base_key]['StreamSpecification'] = self.stream_specification + if self.stream_specification and self.stream_specification["StreamEnabled"]: + results[base_key]["StreamSpecification"] = self.stream_specification if self.latest_stream_label: - results[base_key]['LatestStreamLabel'] = self.latest_stream_label - results[base_key]['LatestStreamArn'] = self.table_arn + '/stream/' + self.latest_stream_label + results[base_key]["LatestStreamLabel"] = self.latest_stream_label + results[base_key]["LatestStreamArn"] = ( + self.table_arn + "/stream/" + self.latest_stream_label + ) return results def __len__(self): @@ -643,9 +682,9 @@ class Table(BaseModel): keys = [self.hash_key_attr] for index in self.global_indexes: hash_key = None - for key in index['KeySchema']: - if key['KeyType'] == 'HASH': - hash_key = key['AttributeName'] + for key in index["KeySchema"]: + if key["KeyType"] == "HASH": + hash_key = key["AttributeName"] keys.append(hash_key) return keys @@ -654,15 +693,21 @@ class Table(BaseModel): keys = [self.range_key_attr] for index in self.global_indexes: range_key = None - for key in index['KeySchema']: - if key['KeyType'] == 'RANGE': - range_key = keys.append(key['AttributeName']) + for key in index["KeySchema"]: + if key["KeyType"] == "RANGE": + range_key = keys.append(key["AttributeName"]) keys.append(range_key) return keys - def put_item(self, item_attrs, expected=None, condition_expression=None, - expression_attribute_names=None, - expression_attribute_values=None, overwrite=False): + def put_item( + self, + item_attrs, + expected=None, + condition_expression=None, + expression_attribute_names=None, + expression_attribute_values=None, + overwrite=False, + ): hash_value = DynamoType(item_attrs.get(self.hash_key_attr)) if self.has_range_key: range_value = DynamoType(item_attrs.get(self.range_key_attr)) @@ -673,26 +718,27 @@ class Table(BaseModel): expected = {} lookup_range_value = range_value else: - expected_range_value = expected.get( - self.range_key_attr, {}).get("Value") - if(expected_range_value is None): + expected_range_value = expected.get(self.range_key_attr, {}).get("Value") + if expected_range_value is None: lookup_range_value = range_value else: lookup_range_value = DynamoType(expected_range_value) current = self.get_item(hash_value, lookup_range_value) - item = Item(hash_value, self.hash_key_type, range_value, - self.range_key_type, item_attrs) + item = Item( + hash_value, self.hash_key_type, range_value, self.range_key_type, item_attrs + ) if not overwrite: if not get_expected(expected).expr(current): - raise ValueError('The conditional request failed') + raise ValueError("The conditional request failed") condition_op = get_filter_expression( condition_expression, expression_attribute_names, - expression_attribute_values) + expression_attribute_values, + ) if not condition_op.expr(current): - raise ValueError('The conditional request failed') + raise ValueError("The conditional request failed") if range_value: self.items[hash_value][range_value] = item @@ -717,7 +763,8 @@ class Table(BaseModel): def get_item(self, hash_key, range_key=None, projection_expression=None): if self.has_range_key and not range_key: raise ValueError( - "Table has a range key, but no range key was passed into get_item") + "Table has a range key, but no range key was passed into get_item" + ) try: result = None @@ -727,7 +774,7 @@ class Table(BaseModel): result = self.items[hash_key] if projection_expression and result: - expressions = [x.strip() for x in projection_expression.split(',')] + expressions = [x.strip() for x in projection_expression.split(",")] result = copy.deepcopy(result) for attr in list(result.attrs): if attr not in expressions: @@ -754,30 +801,42 @@ class Table(BaseModel): except KeyError: return None - def query(self, hash_key, range_comparison, range_objs, limit, - exclusive_start_key, scan_index_forward, projection_expression, - index_name=None, filter_expression=None, **filter_kwargs): + def query( + self, + hash_key, + range_comparison, + range_objs, + limit, + exclusive_start_key, + scan_index_forward, + projection_expression, + index_name=None, + filter_expression=None, + **filter_kwargs + ): results = [] if index_name: all_indexes = self.all_indexes() - indexes_by_name = dict((i['IndexName'], i) for i in all_indexes) + indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) if index_name not in indexes_by_name: - raise ValueError('Invalid index: %s for table: %s. Available indexes are: %s' % ( - index_name, self.name, ', '.join(indexes_by_name.keys()) - )) + raise ValueError( + "Invalid index: %s for table: %s. Available indexes are: %s" + % (index_name, self.name, ", ".join(indexes_by_name.keys())) + ) index = indexes_by_name[index_name] try: - index_hash_key = [key for key in index[ - 'KeySchema'] if key['KeyType'] == 'HASH'][0] + index_hash_key = [ + key for key in index["KeySchema"] if key["KeyType"] == "HASH" + ][0] except IndexError: - raise ValueError('Missing Hash Key. KeySchema: %s' % - index['KeySchema']) + raise ValueError("Missing Hash Key. KeySchema: %s" % index["KeySchema"]) try: - index_range_key = [key for key in index[ - 'KeySchema'] if key['KeyType'] == 'RANGE'][0] + index_range_key = [ + key for key in index["KeySchema"] if key["KeyType"] == "RANGE" + ][0] except IndexError: index_range_key = None @@ -785,25 +844,32 @@ class Table(BaseModel): for item in self.all_items(): if not isinstance(item, Item): continue - item_hash_key = item.attrs.get(index_hash_key['AttributeName']) + item_hash_key = item.attrs.get(index_hash_key["AttributeName"]) if index_range_key is None: if item_hash_key and item_hash_key == hash_key: possible_results.append(item) else: - item_range_key = item.attrs.get(index_range_key['AttributeName']) + item_range_key = item.attrs.get(index_range_key["AttributeName"]) if item_hash_key and item_hash_key == hash_key and item_range_key: possible_results.append(item) else: - possible_results = [item for item in list(self.all_items()) if isinstance( - item, Item) and item.hash_key == hash_key] + possible_results = [ + item + for item in list(self.all_items()) + if isinstance(item, Item) and item.hash_key == hash_key + ] if range_comparison: if index_name and not index_range_key: raise ValueError( - 'Range Key comparison but no range key found for index: %s' % index_name) + "Range Key comparison but no range key found for index: %s" + % index_name + ) elif index_name: for result in possible_results: - if result.attrs.get(index_range_key['AttributeName']).compare(range_comparison, range_objs): + if result.attrs.get(index_range_key["AttributeName"]).compare( + range_comparison, range_objs + ): results.append(result) else: for result in possible_results: @@ -813,9 +879,12 @@ class Table(BaseModel): if filter_kwargs: for result in possible_results: for field, value in filter_kwargs.items(): - dynamo_types = [DynamoType(ele) for ele in value[ - "AttributeValueList"]] - if result.attrs.get(field).compare(value['ComparisonOperator'], dynamo_types): + dynamo_types = [ + DynamoType(ele) for ele in value["AttributeValueList"] + ] + if result.attrs.get(field).compare( + value["ComparisonOperator"], dynamo_types + ): results.append(result) if not range_comparison and not filter_kwargs: @@ -826,8 +895,11 @@ class Table(BaseModel): if index_name: if index_range_key: - results.sort(key=lambda item: item.attrs[index_range_key['AttributeName']].value - if item.attrs.get(index_range_key['AttributeName']) else None) + results.sort( + key=lambda item: item.attrs[index_range_key["AttributeName"]].value + if item.attrs.get(index_range_key["AttributeName"]) + else None + ) else: results.sort(key=lambda item: item.range_key) @@ -840,15 +912,16 @@ class Table(BaseModel): results = [item for item in results if filter_expression.expr(item)] if projection_expression: - expressions = [x.strip() for x in projection_expression.split(',')] + expressions = [x.strip() for x in projection_expression.split(",")] results = copy.deepcopy(results) for result in results: for attr in list(result.attrs): if attr not in expressions: result.attrs.pop(attr) - results, last_evaluated_key = self._trim_results(results, limit, - exclusive_start_key) + results, last_evaluated_key = self._trim_results( + results, limit, exclusive_start_key + ) return results, scanned_count, last_evaluated_key def all_items(self): @@ -865,9 +938,9 @@ class Table(BaseModel): def has_idx_items(self, index_name): all_indexes = self.all_indexes() - indexes_by_name = dict((i['IndexName'], i) for i in all_indexes) + indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) idx = indexes_by_name[index_name] - idx_col_set = set([i['AttributeName'] for i in idx['KeySchema']]) + idx_col_set = set([i["AttributeName"] for i in idx["KeySchema"]]) for hash_set in self.items.values(): if self.range_key_attr: @@ -878,15 +951,25 @@ class Table(BaseModel): if idx_col_set.issubset(set(hash_set.attrs)): yield hash_set - def scan(self, filters, limit, exclusive_start_key, filter_expression=None, index_name=None, projection_expression=None): + def scan( + self, + filters, + limit, + exclusive_start_key, + filter_expression=None, + index_name=None, + projection_expression=None, + ): results = [] scanned_count = 0 all_indexes = self.all_indexes() - indexes_by_name = dict((i['IndexName'], i) for i in all_indexes) + indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) if index_name: if index_name not in indexes_by_name: - raise InvalidIndexNameError('The table does not have the specified index: %s' % index_name) + raise InvalidIndexNameError( + "The table does not have the specified index: %s" % index_name + ) items = self.has_idx_items(index_name) else: items = self.all_items() @@ -894,7 +977,10 @@ class Table(BaseModel): for item in items: scanned_count += 1 passes_all_conditions = True - for attribute_name, (comparison_operator, comparison_objs) in filters.items(): + for ( + attribute_name, + (comparison_operator, comparison_objs), + ) in filters.items(): attribute = item.attrs.get(attribute_name) if attribute: @@ -902,7 +988,7 @@ class Table(BaseModel): if not attribute.compare(comparison_operator, comparison_objs): passes_all_conditions = False break - elif comparison_operator == 'NULL': + elif comparison_operator == "NULL": # Comparison is NULL and we don't have the attribute continue else: @@ -918,15 +1004,16 @@ class Table(BaseModel): results.append(item) if projection_expression: - expressions = [x.strip() for x in projection_expression.split(',')] + expressions = [x.strip() for x in projection_expression.split(",")] results = copy.deepcopy(results) for result in results: for attr in list(result.attrs): if attr not in expressions: result.attrs.pop(attr) - results, last_evaluated_key = self._trim_results(results, limit, - exclusive_start_key, index_name) + results, last_evaluated_key = self._trim_results( + results, limit, exclusive_start_key, index_name + ) return results, scanned_count, last_evaluated_key def _trim_results(self, results, limit, exclusive_start_key, scanned_index=None): @@ -936,24 +1023,25 @@ class Table(BaseModel): if range_key is not None: range_key = DynamoType(range_key) for i in range(len(results)): - if results[i].hash_key == hash_key and results[i].range_key == range_key: - results = results[i + 1:] + if ( + results[i].hash_key == hash_key + and results[i].range_key == range_key + ): + results = results[i + 1 :] break last_evaluated_key = None if limit and len(results) > limit: results = results[:limit] - last_evaluated_key = { - self.hash_key_attr: results[-1].hash_key - } + last_evaluated_key = {self.hash_key_attr: results[-1].hash_key} if results[-1].range_key is not None: last_evaluated_key[self.range_key_attr] = results[-1].range_key if scanned_index: all_indexes = self.all_indexes() - indexes_by_name = dict((i['IndexName'], i) for i in all_indexes) + indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) idx = indexes_by_name[scanned_index] - idx_col_list = [i['AttributeName'] for i in idx['KeySchema']] + idx_col_list = [i["AttributeName"] for i in idx["KeySchema"]] for col in idx_col_list: last_evaluated_key[col] = results[-1].attrs[col] @@ -971,7 +1059,6 @@ class Table(BaseModel): class DynamoDBBackend(BaseBackend): - def __init__(self, region_name=None): self.region_name = region_name self.tables = OrderedDict() @@ -1000,7 +1087,9 @@ class DynamoDBBackend(BaseBackend): def untag_resource(self, table_arn, tag_keys): for table in self.tables: if self.tables[table].table_arn == table_arn: - self.tables[table].tags = [tag for tag in self.tables[table].tags if tag['Key'] not in tag_keys] + self.tables[table].tags = [ + tag for tag in self.tables[table].tags if tag["Key"] not in tag_keys + ] def list_tags_of_resource(self, table_arn): required_table = None @@ -1016,55 +1105,76 @@ class DynamoDBBackend(BaseBackend): def update_table_streams(self, name, stream_specification): table = self.tables[name] - if (stream_specification.get('StreamEnabled') or stream_specification.get('StreamViewType')) and table.latest_stream_label: - raise ValueError('Table already has stream enabled') + if ( + stream_specification.get("StreamEnabled") + or stream_specification.get("StreamViewType") + ) and table.latest_stream_label: + raise ValueError("Table already has stream enabled") table.set_stream_specification(stream_specification) return table def update_table_global_indexes(self, name, global_index_updates): table = self.tables[name] - gsis_by_name = dict((i['IndexName'], i) for i in table.global_indexes) + gsis_by_name = dict((i["IndexName"], i) for i in table.global_indexes) for gsi_update in global_index_updates: - gsi_to_create = gsi_update.get('Create') - gsi_to_update = gsi_update.get('Update') - gsi_to_delete = gsi_update.get('Delete') + gsi_to_create = gsi_update.get("Create") + gsi_to_update = gsi_update.get("Update") + gsi_to_delete = gsi_update.get("Delete") if gsi_to_delete: - index_name = gsi_to_delete['IndexName'] + index_name = gsi_to_delete["IndexName"] if index_name not in gsis_by_name: - raise ValueError('Global Secondary Index does not exist, but tried to delete: %s' % - gsi_to_delete['IndexName']) + raise ValueError( + "Global Secondary Index does not exist, but tried to delete: %s" + % gsi_to_delete["IndexName"] + ) del gsis_by_name[index_name] if gsi_to_update: - index_name = gsi_to_update['IndexName'] + index_name = gsi_to_update["IndexName"] if index_name not in gsis_by_name: - raise ValueError('Global Secondary Index does not exist, but tried to update: %s' % - gsi_to_update['IndexName']) + raise ValueError( + "Global Secondary Index does not exist, but tried to update: %s" + % gsi_to_update["IndexName"] + ) gsis_by_name[index_name].update(gsi_to_update) if gsi_to_create: - if gsi_to_create['IndexName'] in gsis_by_name: + if gsi_to_create["IndexName"] in gsis_by_name: raise ValueError( - 'Global Secondary Index already exists: %s' % gsi_to_create['IndexName']) + "Global Secondary Index already exists: %s" + % gsi_to_create["IndexName"] + ) - gsis_by_name[gsi_to_create['IndexName']] = gsi_to_create + gsis_by_name[gsi_to_create["IndexName"]] = gsi_to_create # in python 3.6, dict.values() returns a dict_values object, but we expect it to be a list in other # parts of the codebase table.global_indexes = list(gsis_by_name.values()) return table - def put_item(self, table_name, item_attrs, expected=None, - condition_expression=None, expression_attribute_names=None, - expression_attribute_values=None, overwrite=False): + def put_item( + self, + table_name, + item_attrs, + expected=None, + condition_expression=None, + expression_attribute_names=None, + expression_attribute_values=None, + overwrite=False, + ): table = self.tables.get(table_name) if not table: return None - return table.put_item(item_attrs, expected, condition_expression, - expression_attribute_names, - expression_attribute_values, overwrite) + return table.put_item( + item_attrs, + expected, + condition_expression, + expression_attribute_names, + expression_attribute_values, + overwrite, + ) def get_table_keys_name(self, table_name, keys): """ @@ -1090,12 +1200,16 @@ class DynamoDBBackend(BaseBackend): return potential_hash, potential_range def get_keys_value(self, table, keys): - if table.hash_key_attr not in keys or (table.has_range_key and table.range_key_attr not in keys): + if table.hash_key_attr not in keys or ( + table.has_range_key and table.range_key_attr not in keys + ): raise ValueError( - "Table has a range key, but no range key was passed into get_item") + "Table has a range key, but no range key was passed into get_item" + ) hash_key = DynamoType(keys[table.hash_key_attr]) - range_key = DynamoType( - keys[table.range_key_attr]) if table.has_range_key else None + range_key = ( + DynamoType(keys[table.range_key_attr]) if table.has_range_key else None + ) return hash_key, range_key def get_table(self, table_name): @@ -1108,24 +1222,58 @@ class DynamoDBBackend(BaseBackend): hash_key, range_key = self.get_keys_value(table, keys) return table.get_item(hash_key, range_key, projection_expression) - def query(self, table_name, hash_key_dict, range_comparison, range_value_dicts, - limit, exclusive_start_key, scan_index_forward, projection_expression, index_name=None, - expr_names=None, expr_values=None, filter_expression=None, - **filter_kwargs): + def query( + self, + table_name, + hash_key_dict, + range_comparison, + range_value_dicts, + limit, + exclusive_start_key, + scan_index_forward, + projection_expression, + index_name=None, + expr_names=None, + expr_values=None, + filter_expression=None, + **filter_kwargs + ): table = self.tables.get(table_name) if not table: return None, None hash_key = DynamoType(hash_key_dict) - range_values = [DynamoType(range_value) - for range_value in range_value_dicts] + range_values = [DynamoType(range_value) for range_value in range_value_dicts] - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) + filter_expression = get_filter_expression( + filter_expression, expr_names, expr_values + ) - return table.query(hash_key, range_comparison, range_values, limit, - exclusive_start_key, scan_index_forward, projection_expression, index_name, filter_expression, **filter_kwargs) + return table.query( + hash_key, + range_comparison, + range_values, + limit, + exclusive_start_key, + scan_index_forward, + projection_expression, + index_name, + filter_expression, + **filter_kwargs + ) - def scan(self, table_name, filters, limit, exclusive_start_key, filter_expression, expr_names, expr_values, index_name, projection_expression): + def scan( + self, + table_name, + filters, + limit, + exclusive_start_key, + filter_expression, + expr_names, + expr_values, + index_name, + projection_expression, + ): table = self.tables.get(table_name) if not table: return None, None, None @@ -1135,14 +1283,37 @@ class DynamoDBBackend(BaseBackend): dynamo_types = [DynamoType(value) for value in comparison_values] scan_filters[key] = (comparison_operator, dynamo_types) - filter_expression = get_filter_expression(filter_expression, expr_names, expr_values) + filter_expression = get_filter_expression( + filter_expression, expr_names, expr_values + ) - projection_expression = ','.join([expr_names.get(attr, attr) for attr in projection_expression.replace(' ', '').split(',')]) + projection_expression = ",".join( + [ + expr_names.get(attr, attr) + for attr in projection_expression.replace(" ", "").split(",") + ] + ) - return table.scan(scan_filters, limit, exclusive_start_key, filter_expression, index_name, projection_expression) + return table.scan( + scan_filters, + limit, + exclusive_start_key, + filter_expression, + index_name, + projection_expression, + ) - def update_item(self, table_name, key, update_expression, attribute_updates, expression_attribute_names, - expression_attribute_values, expected=None, condition_expression=None): + def update_item( + self, + table_name, + key, + update_expression, + attribute_updates, + expression_attribute_names, + expression_attribute_values, + expected=None, + condition_expression=None, + ): table = self.get_table(table_name) if all([table.hash_key_attr in key, table.range_key_attr in key]): @@ -1165,40 +1336,44 @@ class DynamoDBBackend(BaseBackend): expected = {} if not get_expected(expected).expr(item): - raise ValueError('The conditional request failed') + raise ValueError("The conditional request failed") condition_op = get_filter_expression( condition_expression, expression_attribute_names, - expression_attribute_values) + expression_attribute_values, + ) if not condition_op.expr(item): - raise ValueError('The conditional request failed') + raise ValueError("The conditional request failed") # Update does not fail on new items, so create one if item is None: - data = { - table.hash_key_attr: { - hash_value.type: hash_value.value, - }, - } + data = {table.hash_key_attr: {hash_value.type: hash_value.value}} if range_value: - data.update({ - table.range_key_attr: { - range_value.type: range_value.value, - } - }) + data.update( + {table.range_key_attr: {range_value.type: range_value.value}} + ) table.put_item(data) item = table.get_item(hash_value, range_value) if update_expression: - item.update(update_expression, expression_attribute_names, - expression_attribute_values) + item.update( + update_expression, + expression_attribute_names, + expression_attribute_values, + ) else: item.update_with_attribute_updates(attribute_updates) return item - def delete_item(self, table_name, key, expression_attribute_names=None, expression_attribute_values=None, - condition_expression=None): + def delete_item( + self, + table_name, + key, + expression_attribute_names=None, + expression_attribute_values=None, + condition_expression=None, + ): table = self.get_table(table_name) if not table: return None @@ -1209,34 +1384,39 @@ class DynamoDBBackend(BaseBackend): condition_op = get_filter_expression( condition_expression, expression_attribute_names, - expression_attribute_values) + expression_attribute_values, + ) if not condition_op.expr(item): - raise ValueError('The conditional request failed') + raise ValueError("The conditional request failed") return table.delete_item(hash_value, range_value) def update_ttl(self, table_name, ttl_spec): table = self.tables.get(table_name) if table is None: - raise JsonRESTError('ResourceNotFound', 'Table not found') + raise JsonRESTError("ResourceNotFound", "Table not found") - if 'Enabled' not in ttl_spec or 'AttributeName' not in ttl_spec: - raise JsonRESTError('InvalidParameterValue', - 'TimeToLiveSpecification does not contain Enabled and AttributeName') + if "Enabled" not in ttl_spec or "AttributeName" not in ttl_spec: + raise JsonRESTError( + "InvalidParameterValue", + "TimeToLiveSpecification does not contain Enabled and AttributeName", + ) - if ttl_spec['Enabled']: - table.ttl['TimeToLiveStatus'] = 'ENABLED' + if ttl_spec["Enabled"]: + table.ttl["TimeToLiveStatus"] = "ENABLED" else: - table.ttl['TimeToLiveStatus'] = 'DISABLED' - table.ttl['AttributeName'] = ttl_spec['AttributeName'] + table.ttl["TimeToLiveStatus"] = "DISABLED" + table.ttl["AttributeName"] = ttl_spec["AttributeName"] def describe_ttl(self, table_name): table = self.tables.get(table_name) if table is None: - raise JsonRESTError('ResourceNotFound', 'Table not found') + raise JsonRESTError("ResourceNotFound", "Table not found") return table.ttl available_regions = boto3.session.Session().get_available_regions("dynamodb") -dynamodb_backends = {region: DynamoDBBackend(region_name=region) for region in available_regions} +dynamodb_backends = { + region: DynamoDBBackend(region_name=region) for region in available_regions +} diff --git a/moto/dynamodb2/responses.py b/moto/dynamodb2/responses.py index b5e5e11a8..fd1d19ff6 100644 --- a/moto/dynamodb2/responses.py +++ b/moto/dynamodb2/responses.py @@ -16,25 +16,30 @@ def has_empty_keys_or_values(_dict): if not isinstance(_dict, dict): return False return any( - key == '' or value == '' or - has_empty_keys_or_values(value) + key == "" or value == "" or has_empty_keys_or_values(value) for key, value in _dict.items() ) def get_empty_str_error(): - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return (400, - {'server': 'amazon.com'}, - dynamo_json_dump({'__type': er, - 'message': ('One or more parameter values were ' - 'invalid: An AttributeValue may not ' - 'contain an empty string')} - )) + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return ( + 400, + {"server": "amazon.com"}, + dynamo_json_dump( + { + "__type": er, + "message": ( + "One or more parameter values were " + "invalid: An AttributeValue may not " + "contain an empty string" + ), + } + ), + ) class DynamoHandler(BaseResponse): - def get_endpoint_name(self, headers): """Parses request headers and extracts part od the X-Amz-Target that corresponds to a method of DynamoHandler @@ -42,12 +47,16 @@ class DynamoHandler(BaseResponse): ie: X-Amz-Target: DynamoDB_20111205.ListTables -> ListTables """ # Headers are case-insensitive. Probably a better way to do this. - match = headers.get('x-amz-target') or headers.get('X-Amz-Target') + match = headers.get("x-amz-target") or headers.get("X-Amz-Target") if match: return match.split(".")[1] def error(self, type_, message, status=400): - return status, self.response_headers, dynamo_json_dump({'__type': type_, 'message': message}) + return ( + status, + self.response_headers, + dynamo_json_dump({"__type": type_, "message": message}), + ) @property def dynamodb_backend(self): @@ -59,7 +68,7 @@ class DynamoHandler(BaseResponse): @amzn_request_id def call_action(self): - self.body = json.loads(self.body or '{}') + self.body = json.loads(self.body or "{}") endpoint = self.get_endpoint_name(self.headers) if endpoint: endpoint = camelcase_to_underscores(endpoint) @@ -76,7 +85,7 @@ class DynamoHandler(BaseResponse): def list_tables(self): body = self.body - limit = body.get('Limit', 100) + limit = body.get("Limit", 100) if body.get("ExclusiveStartTableName"): last = body.get("ExclusiveStartTableName") start = list(self.dynamodb_backend.tables.keys()).index(last) + 1 @@ -84,7 +93,7 @@ class DynamoHandler(BaseResponse): start = 0 all_tables = list(self.dynamodb_backend.tables.keys()) if limit: - tables = all_tables[start:start + limit] + tables = all_tables[start : start + limit] else: tables = all_tables[start:] response = {"TableNames": tables} @@ -96,19 +105,21 @@ class DynamoHandler(BaseResponse): def create_table(self): body = self.body # get the table name - table_name = body['TableName'] + table_name = body["TableName"] # check billing mode and get the throughput if "BillingMode" in body.keys() and body["BillingMode"] == "PAY_PER_REQUEST": if "ProvisionedThroughput" in body.keys(): - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, - 'ProvisionedThroughput cannot be specified \ - when BillingMode is PAY_PER_REQUEST') + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error( + er, + "ProvisionedThroughput cannot be specified \ + when BillingMode is PAY_PER_REQUEST", + ) throughput = None - else: # Provisioned (default billing mode) + else: # Provisioned (default billing mode) throughput = body.get("ProvisionedThroughput") # getting the schema - key_schema = body['KeySchema'] + key_schema = body["KeySchema"] # getting attribute definition attr = body["AttributeDefinitions"] # getting the indexes @@ -116,200 +127,224 @@ class DynamoHandler(BaseResponse): local_secondary_indexes = body.get("LocalSecondaryIndexes", []) # Verify AttributeDefinitions list all expected_attrs = [] - expected_attrs.extend([key['AttributeName'] for key in key_schema]) - expected_attrs.extend(schema['AttributeName'] for schema in itertools.chain(*list(idx['KeySchema'] for idx in local_secondary_indexes))) - expected_attrs.extend(schema['AttributeName'] for schema in itertools.chain(*list(idx['KeySchema'] for idx in global_indexes))) + expected_attrs.extend([key["AttributeName"] for key in key_schema]) + expected_attrs.extend( + schema["AttributeName"] + for schema in itertools.chain( + *list(idx["KeySchema"] for idx in local_secondary_indexes) + ) + ) + expected_attrs.extend( + schema["AttributeName"] + for schema in itertools.chain( + *list(idx["KeySchema"] for idx in global_indexes) + ) + ) expected_attrs = list(set(expected_attrs)) expected_attrs.sort() - actual_attrs = [item['AttributeName'] for item in attr] + actual_attrs = [item["AttributeName"] for item in attr] actual_attrs.sort() if actual_attrs != expected_attrs: - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, - 'One or more parameter values were invalid: ' - 'Some index key attributes are not defined in AttributeDefinitions. ' - 'Keys: ' + str(expected_attrs) + ', AttributeDefinitions: ' + str(actual_attrs)) + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error( + er, + "One or more parameter values were invalid: " + "Some index key attributes are not defined in AttributeDefinitions. " + "Keys: " + + str(expected_attrs) + + ", AttributeDefinitions: " + + str(actual_attrs), + ) # get the stream specification streams = body.get("StreamSpecification") - table = self.dynamodb_backend.create_table(table_name, - schema=key_schema, - throughput=throughput, - attr=attr, - global_indexes=global_indexes, - indexes=local_secondary_indexes, - streams=streams) + table = self.dynamodb_backend.create_table( + table_name, + schema=key_schema, + throughput=throughput, + attr=attr, + global_indexes=global_indexes, + indexes=local_secondary_indexes, + streams=streams, + ) if table is not None: return dynamo_json_dump(table.describe()) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceInUseException' - return self.error(er, 'Resource in use') + er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException" + return self.error(er, "Resource in use") def delete_table(self): - name = self.body['TableName'] + name = self.body["TableName"] table = self.dynamodb_backend.delete_table(name) if table is not None: return dynamo_json_dump(table.describe()) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") def tag_resource(self): - table_arn = self.body['ResourceArn'] - tags = self.body['Tags'] + table_arn = self.body["ResourceArn"] + tags = self.body["Tags"] self.dynamodb_backend.tag_resource(table_arn, tags) - return '' + return "" def untag_resource(self): - table_arn = self.body['ResourceArn'] - tags = self.body['TagKeys'] + table_arn = self.body["ResourceArn"] + tags = self.body["TagKeys"] self.dynamodb_backend.untag_resource(table_arn, tags) - return '' + return "" def list_tags_of_resource(self): try: - table_arn = self.body['ResourceArn'] + table_arn = self.body["ResourceArn"] all_tags = self.dynamodb_backend.list_tags_of_resource(table_arn) - all_tag_keys = [tag['Key'] for tag in all_tags] - marker = self.body.get('NextToken') + all_tag_keys = [tag["Key"] for tag in all_tags] + marker = self.body.get("NextToken") if marker: start = all_tag_keys.index(marker) + 1 else: start = 0 max_items = 10 # there is no default, but using 10 to make testing easier - tags_resp = all_tags[start:start + max_items] + tags_resp = all_tags[start : start + max_items] next_marker = None if len(all_tags) > start + max_items: - next_marker = tags_resp[-1]['Key'] + next_marker = tags_resp[-1]["Key"] if next_marker: - return json.dumps({'Tags': tags_resp, - 'NextToken': next_marker}) - return json.dumps({'Tags': tags_resp}) + return json.dumps({"Tags": tags_resp, "NextToken": next_marker}) + return json.dumps({"Tags": tags_resp}) except AttributeError: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") def update_table(self): - name = self.body['TableName'] + name = self.body["TableName"] table = self.dynamodb_backend.get_table(name) - if 'GlobalSecondaryIndexUpdates' in self.body: + if "GlobalSecondaryIndexUpdates" in self.body: table = self.dynamodb_backend.update_table_global_indexes( - name, self.body['GlobalSecondaryIndexUpdates']) - if 'ProvisionedThroughput' in self.body: + name, self.body["GlobalSecondaryIndexUpdates"] + ) + if "ProvisionedThroughput" in self.body: throughput = self.body["ProvisionedThroughput"] table = self.dynamodb_backend.update_table_throughput(name, throughput) - if 'StreamSpecification' in self.body: + if "StreamSpecification" in self.body: try: - table = self.dynamodb_backend.update_table_streams(name, self.body['StreamSpecification']) + table = self.dynamodb_backend.update_table_streams( + name, self.body["StreamSpecification"] + ) except ValueError: - er = 'com.amazonaws.dynamodb.v20111205#ResourceInUseException' - return self.error(er, 'Cannot enable stream') + er = "com.amazonaws.dynamodb.v20111205#ResourceInUseException" + return self.error(er, "Cannot enable stream") return dynamo_json_dump(table.describe()) def describe_table(self): - name = self.body['TableName'] + name = self.body["TableName"] try: table = self.dynamodb_backend.tables[name] except KeyError: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') - return dynamo_json_dump(table.describe(base_key='Table')) + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") + return dynamo_json_dump(table.describe(base_key="Table")) def put_item(self): - name = self.body['TableName'] - item = self.body['Item'] - return_values = self.body.get('ReturnValues', 'NONE') + name = self.body["TableName"] + item = self.body["Item"] + return_values = self.body.get("ReturnValues", "NONE") - if return_values not in ('ALL_OLD', 'NONE'): - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, 'Return values set to invalid value') + if return_values not in ("ALL_OLD", "NONE"): + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error(er, "Return values set to invalid value") if has_empty_keys_or_values(item): return get_empty_str_error() - overwrite = 'Expected' not in self.body + overwrite = "Expected" not in self.body if not overwrite: - expected = self.body['Expected'] + expected = self.body["Expected"] else: expected = None - if return_values == 'ALL_OLD': + if return_values == "ALL_OLD": existing_item = self.dynamodb_backend.get_item(name, item) if existing_item: - existing_attributes = existing_item.to_json()['Attributes'] + existing_attributes = existing_item.to_json()["Attributes"] else: existing_attributes = {} # Attempt to parse simple ConditionExpressions into an Expected # expression - condition_expression = self.body.get('ConditionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) + condition_expression = self.body.get("ConditionExpression") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) if condition_expression: overwrite = False try: result = self.dynamodb_backend.put_item( - name, item, expected, condition_expression, - expression_attribute_names, expression_attribute_values, - overwrite) + name, + item, + expected, + condition_expression, + expression_attribute_names, + expression_attribute_values, + overwrite, + ) except ItemSizeTooLarge: - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' + er = "com.amazonaws.dynamodb.v20111205#ValidationException" return self.error(er, ItemSizeTooLarge.message) except ValueError: - er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' - return self.error(er, 'A condition specified in the operation could not be evaluated.') + er = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException" + return self.error( + er, "A condition specified in the operation could not be evaluated." + ) if result: item_dict = result.to_json() - item_dict['ConsumedCapacity'] = { - 'TableName': name, - 'CapacityUnits': 1 - } - if return_values == 'ALL_OLD': - item_dict['Attributes'] = existing_attributes + item_dict["ConsumedCapacity"] = {"TableName": name, "CapacityUnits": 1} + if return_values == "ALL_OLD": + item_dict["Attributes"] = existing_attributes else: - item_dict.pop('Attributes', None) + item_dict.pop("Attributes", None) return dynamo_json_dump(item_dict) else: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") def batch_write_item(self): - table_batches = self.body['RequestItems'] + table_batches = self.body["RequestItems"] for table_name, table_requests in table_batches.items(): for table_request in table_requests: request_type = list(table_request.keys())[0] request = list(table_request.values())[0] - if request_type == 'PutRequest': - item = request['Item'] + if request_type == "PutRequest": + item = request["Item"] self.dynamodb_backend.put_item(table_name, item) - elif request_type == 'DeleteRequest': - keys = request['Key'] + elif request_type == "DeleteRequest": + keys = request["Key"] item = self.dynamodb_backend.delete_item(table_name, keys) response = { "ConsumedCapacity": [ { - 'TableName': table_name, - 'CapacityUnits': 1.0, - 'Table': {'CapacityUnits': 1.0} - } for table_name, table_requests in table_batches.items() + "TableName": table_name, + "CapacityUnits": 1.0, + "Table": {"CapacityUnits": 1.0}, + } + for table_name, table_requests in table_batches.items() ], "ItemCollectionMetrics": {}, - "UnprocessedItems": {} + "UnprocessedItems": {}, } return dynamo_json_dump(response) def get_item(self): - name = self.body['TableName'] - key = self.body['Key'] - projection_expression = self.body.get('ProjectionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) + name = self.body["TableName"] + key = self.body["Key"] + projection_expression = self.body.get("ProjectionExpression") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) projection_expression = self._adjust_projection_expression( projection_expression, expression_attribute_names @@ -318,38 +353,31 @@ class DynamoHandler(BaseResponse): try: item = self.dynamodb_backend.get_item(name, key, projection_expression) except ValueError: - er = 'com.amazon.coral.validate#ValidationException' - return self.error(er, 'Validation Exception') + er = "com.amazon.coral.validate#ValidationException" + return self.error(er, "Validation Exception") if item: item_dict = item.describe_attrs(attributes=None) - item_dict['ConsumedCapacity'] = { - 'TableName': name, - 'CapacityUnits': 0.5 - } + item_dict["ConsumedCapacity"] = {"TableName": name, "CapacityUnits": 0.5} return dynamo_json_dump(item_dict) else: # Item not found - return 200, self.response_headers, '{}' + return 200, self.response_headers, "{}" def batch_get_item(self): - table_batches = self.body['RequestItems'] + table_batches = self.body["RequestItems"] - results = { - "ConsumedCapacity": [], - "Responses": { - }, - "UnprocessedKeys": { - } - } + results = {"ConsumedCapacity": [], "Responses": {}, "UnprocessedKeys": {}} for table_name, table_request in table_batches.items(): - keys = table_request['Keys'] + keys = table_request["Keys"] if self._contains_duplicates(keys): - er = 'com.amazon.coral.validate#ValidationException' - return self.error(er, 'Provided list of item keys contains duplicates') - attributes_to_get = table_request.get('AttributesToGet') - projection_expression = table_request.get('ProjectionExpression') - expression_attribute_names = table_request.get('ExpressionAttributeNames', {}) + er = "com.amazon.coral.validate#ValidationException" + return self.error(er, "Provided list of item keys contains duplicates") + attributes_to_get = table_request.get("AttributesToGet") + projection_expression = table_request.get("ProjectionExpression") + expression_attribute_names = table_request.get( + "ExpressionAttributeNames", {} + ) projection_expression = self._adjust_projection_expression( projection_expression, expression_attribute_names @@ -357,16 +385,16 @@ class DynamoHandler(BaseResponse): results["Responses"][table_name] = [] for key in keys: - item = self.dynamodb_backend.get_item(table_name, key, projection_expression) + item = self.dynamodb_backend.get_item( + table_name, key, projection_expression + ) if item: item_describe = item.describe_attrs(attributes_to_get) - results["Responses"][table_name].append( - item_describe["Item"]) + results["Responses"][table_name].append(item_describe["Item"]) - results["ConsumedCapacity"].append({ - "CapacityUnits": len(keys), - "TableName": table_name - }) + results["ConsumedCapacity"].append( + {"CapacityUnits": len(keys), "TableName": table_name} + ) return dynamo_json_dump(results) def _contains_duplicates(self, keys): @@ -379,13 +407,13 @@ class DynamoHandler(BaseResponse): return False def query(self): - name = self.body['TableName'] + name = self.body["TableName"] # {u'KeyConditionExpression': u'#n0 = :v0', u'ExpressionAttributeValues': {u':v0': {u'S': u'johndoe'}}, u'ExpressionAttributeNames': {u'#n0': u'username'}} - key_condition_expression = self.body.get('KeyConditionExpression') - projection_expression = self.body.get('ProjectionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - filter_expression = self.body.get('FilterExpression') - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) + key_condition_expression = self.body.get("KeyConditionExpression") + projection_expression = self.body.get("ProjectionExpression") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + filter_expression = self.body.get("FilterExpression") + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) projection_expression = self._adjust_projection_expression( projection_expression, expression_attribute_names @@ -394,133 +422,148 @@ class DynamoHandler(BaseResponse): filter_kwargs = {} if key_condition_expression: - value_alias_map = self.body.get('ExpressionAttributeValues', {}) + value_alias_map = self.body.get("ExpressionAttributeValues", {}) table = self.dynamodb_backend.get_table(name) # If table does not exist if table is None: - return self.error('com.amazonaws.dynamodb.v20120810#ResourceNotFoundException', - 'Requested resource not found') + return self.error( + "com.amazonaws.dynamodb.v20120810#ResourceNotFoundException", + "Requested resource not found", + ) - index_name = self.body.get('IndexName') + index_name = self.body.get("IndexName") if index_name: - all_indexes = (table.global_indexes or []) + \ - (table.indexes or []) - indexes_by_name = dict((i['IndexName'], i) - for i in all_indexes) + all_indexes = (table.global_indexes or []) + (table.indexes or []) + indexes_by_name = dict((i["IndexName"], i) for i in all_indexes) if index_name not in indexes_by_name: - raise ValueError('Invalid index: %s for table: %s. Available indexes are: %s' % ( - index_name, name, ', '.join(indexes_by_name.keys()) - )) + raise ValueError( + "Invalid index: %s for table: %s. Available indexes are: %s" + % (index_name, name, ", ".join(indexes_by_name.keys())) + ) - index = indexes_by_name[index_name]['KeySchema'] + index = indexes_by_name[index_name]["KeySchema"] else: index = table.schema - reverse_attribute_lookup = dict((v, k) for k, v in - six.iteritems(self.body.get('ExpressionAttributeNames', {}))) + reverse_attribute_lookup = dict( + (v, k) + for k, v in six.iteritems(self.body.get("ExpressionAttributeNames", {})) + ) if " AND " in key_condition_expression: expressions = key_condition_expression.split(" AND ", 1) - index_hash_key = [key for key in index if key['KeyType'] == 'HASH'][0] - hash_key_var = reverse_attribute_lookup.get(index_hash_key['AttributeName'], - index_hash_key['AttributeName']) - hash_key_regex = r'(^|[\s(]){0}\b'.format(hash_key_var) - i, hash_key_expression = next((i, e) for i, e in enumerate(expressions) - if re.search(hash_key_regex, e)) - hash_key_expression = hash_key_expression.strip('()') + index_hash_key = [key for key in index if key["KeyType"] == "HASH"][0] + hash_key_var = reverse_attribute_lookup.get( + index_hash_key["AttributeName"], index_hash_key["AttributeName"] + ) + hash_key_regex = r"(^|[\s(]){0}\b".format(hash_key_var) + i, hash_key_expression = next( + (i, e) + for i, e in enumerate(expressions) + if re.search(hash_key_regex, e) + ) + hash_key_expression = hash_key_expression.strip("()") expressions.pop(i) # TODO implement more than one range expression and OR operators - range_key_expression = expressions[0].strip('()') + range_key_expression = expressions[0].strip("()") range_key_expression_components = range_key_expression.split() range_comparison = range_key_expression_components[1] - if 'AND' in range_key_expression: - range_comparison = 'BETWEEN' + if "AND" in range_key_expression: + range_comparison = "BETWEEN" range_values = [ value_alias_map[range_key_expression_components[2]], value_alias_map[range_key_expression_components[4]], ] - elif 'begins_with' in range_key_expression: - range_comparison = 'BEGINS_WITH' - range_values = [ - value_alias_map[range_key_expression_components[1]], - ] + elif "begins_with" in range_key_expression: + range_comparison = "BEGINS_WITH" + range_values = [value_alias_map[range_key_expression_components[1]]] else: - range_values = [value_alias_map[ - range_key_expression_components[2]]] + range_values = [value_alias_map[range_key_expression_components[2]]] else: - hash_key_expression = key_condition_expression.strip('()') + hash_key_expression = key_condition_expression.strip("()") range_comparison = None range_values = [] - if '=' not in hash_key_expression: - return self.error('com.amazonaws.dynamodb.v20111205#ValidationException', - 'Query key condition not supported') + if "=" not in hash_key_expression: + return self.error( + "com.amazonaws.dynamodb.v20111205#ValidationException", + "Query key condition not supported", + ) hash_key_value_alias = hash_key_expression.split("=")[1].strip() # Temporary fix until we get proper KeyConditionExpression function - hash_key = value_alias_map.get(hash_key_value_alias, {'S': hash_key_value_alias}) + hash_key = value_alias_map.get( + hash_key_value_alias, {"S": hash_key_value_alias} + ) else: # 'KeyConditions': {u'forum_name': {u'ComparisonOperator': u'EQ', u'AttributeValueList': [{u'S': u'the-key'}]}} - key_conditions = self.body.get('KeyConditions') + key_conditions = self.body.get("KeyConditions") query_filters = self.body.get("QueryFilter") if key_conditions: - hash_key_name, range_key_name = self.dynamodb_backend.get_table_keys_name( - name, key_conditions.keys()) + ( + hash_key_name, + range_key_name, + ) = self.dynamodb_backend.get_table_keys_name( + name, key_conditions.keys() + ) for key, value in key_conditions.items(): if key not in (hash_key_name, range_key_name): filter_kwargs[key] = value if hash_key_name is None: er = "'com.amazonaws.dynamodb.v20120810#ResourceNotFoundException" - return self.error(er, 'Requested resource not found') - hash_key = key_conditions[hash_key_name][ - 'AttributeValueList'][0] + return self.error(er, "Requested resource not found") + hash_key = key_conditions[hash_key_name]["AttributeValueList"][0] if len(key_conditions) == 1: range_comparison = None range_values = [] else: if range_key_name is None and not filter_kwargs: er = "com.amazon.coral.validate#ValidationException" - return self.error(er, 'Validation Exception') + return self.error(er, "Validation Exception") else: range_condition = key_conditions.get(range_key_name) if range_condition: - range_comparison = range_condition[ - 'ComparisonOperator'] - range_values = range_condition[ - 'AttributeValueList'] + range_comparison = range_condition["ComparisonOperator"] + range_values = range_condition["AttributeValueList"] else: range_comparison = None range_values = [] if query_filters: filter_kwargs.update(query_filters) - index_name = self.body.get('IndexName') - exclusive_start_key = self.body.get('ExclusiveStartKey') + index_name = self.body.get("IndexName") + exclusive_start_key = self.body.get("ExclusiveStartKey") limit = self.body.get("Limit") scan_index_forward = self.body.get("ScanIndexForward") items, scanned_count, last_evaluated_key = self.dynamodb_backend.query( - name, hash_key, range_comparison, range_values, limit, - exclusive_start_key, scan_index_forward, projection_expression, index_name=index_name, - expr_names=expression_attribute_names, expr_values=expression_attribute_values, - filter_expression=filter_expression, **filter_kwargs + name, + hash_key, + range_comparison, + range_values, + limit, + exclusive_start_key, + scan_index_forward, + projection_expression, + index_name=index_name, + expr_names=expression_attribute_names, + expr_values=expression_attribute_values, + filter_expression=filter_expression, + **filter_kwargs ) if items is None: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") result = { "Count": len(items), - 'ConsumedCapacity': { - 'TableName': name, - 'CapacityUnits': 1, - }, - "ScannedCount": scanned_count + "ConsumedCapacity": {"TableName": name, "CapacityUnits": 1}, + "ScannedCount": scanned_count, } - if self.body.get('Select', '').upper() != 'COUNT': + if self.body.get("Select", "").upper() != "COUNT": result["Items"] = [item.attrs for item in items] if last_evaluated_key is not None: @@ -528,9 +571,11 @@ class DynamoHandler(BaseResponse): return dynamo_json_dump(result) - def _adjust_projection_expression(self, projection_expression, expression_attribute_names): + def _adjust_projection_expression( + self, projection_expression, expression_attribute_names + ): if projection_expression and expression_attribute_names: - expressions = [x.strip() for x in projection_expression.split(',')] + expressions = [x.strip() for x in projection_expression.split(",")] projection_expr = None for expression in expressions: if projection_expr is not None: @@ -539,8 +584,9 @@ class DynamoHandler(BaseResponse): projection_expr = "" if expression in expression_attribute_names: - projection_expr = projection_expr + \ - expression_attribute_names[expression] + projection_expr = ( + projection_expr + expression_attribute_names[expression] + ) else: projection_expr = projection_expr + expression return projection_expr @@ -548,10 +594,10 @@ class DynamoHandler(BaseResponse): return projection_expression def scan(self): - name = self.body['TableName'] + name = self.body["TableName"] filters = {} - scan_filters = self.body.get('ScanFilter', {}) + scan_filters = self.body.get("ScanFilter", {}) for attribute_name, scan_filter in scan_filters.items(): # Keys are attribute names. Values are tuples of (comparison, # comparison_value) @@ -559,193 +605,217 @@ class DynamoHandler(BaseResponse): comparison_values = scan_filter.get("AttributeValueList", []) filters[attribute_name] = (comparison_operator, comparison_values) - filter_expression = self.body.get('FilterExpression') - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - projection_expression = self.body.get('ProjectionExpression', '') - exclusive_start_key = self.body.get('ExclusiveStartKey') + filter_expression = self.body.get("FilterExpression") + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + projection_expression = self.body.get("ProjectionExpression", "") + exclusive_start_key = self.body.get("ExclusiveStartKey") limit = self.body.get("Limit") - index_name = self.body.get('IndexName') + index_name = self.body.get("IndexName") try: - items, scanned_count, last_evaluated_key = self.dynamodb_backend.scan(name, filters, - limit, - exclusive_start_key, - filter_expression, - expression_attribute_names, - expression_attribute_values, - index_name, - projection_expression) + items, scanned_count, last_evaluated_key = self.dynamodb_backend.scan( + name, + filters, + limit, + exclusive_start_key, + filter_expression, + expression_attribute_names, + expression_attribute_values, + index_name, + projection_expression, + ) except InvalidIndexNameError as err: - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' + er = "com.amazonaws.dynamodb.v20111205#ValidationException" return self.error(er, str(err)) except ValueError as err: - er = 'com.amazonaws.dynamodb.v20111205#ValidationError' - return self.error(er, 'Bad Filter Expression: {0}'.format(err)) + er = "com.amazonaws.dynamodb.v20111205#ValidationError" + return self.error(er, "Bad Filter Expression: {0}".format(err)) except Exception as err: - er = 'com.amazonaws.dynamodb.v20111205#InternalFailure' - return self.error(er, 'Internal error. {0}'.format(err)) + er = "com.amazonaws.dynamodb.v20111205#InternalFailure" + return self.error(er, "Internal error. {0}".format(err)) # Items should be a list, at least an empty one. Is None if table does not exist. # Should really check this at the beginning if items is None: - er = 'com.amazonaws.dynamodb.v20111205#ResourceNotFoundException' - return self.error(er, 'Requested resource not found') + er = "com.amazonaws.dynamodb.v20111205#ResourceNotFoundException" + return self.error(er, "Requested resource not found") result = { "Count": len(items), "Items": [item.attrs for item in items], - 'ConsumedCapacity': { - 'TableName': name, - 'CapacityUnits': 1, - }, - "ScannedCount": scanned_count + "ConsumedCapacity": {"TableName": name, "CapacityUnits": 1}, + "ScannedCount": scanned_count, } if last_evaluated_key is not None: result["LastEvaluatedKey"] = last_evaluated_key return dynamo_json_dump(result) def delete_item(self): - name = self.body['TableName'] - key = self.body['Key'] - return_values = self.body.get('ReturnValues', 'NONE') - if return_values not in ('ALL_OLD', 'NONE'): - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, 'Return values set to invalid value') + name = self.body["TableName"] + key = self.body["Key"] + return_values = self.body.get("ReturnValues", "NONE") + if return_values not in ("ALL_OLD", "NONE"): + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error(er, "Return values set to invalid value") table = self.dynamodb_backend.get_table(name) if not table: - er = 'com.amazonaws.dynamodb.v20120810#ConditionalCheckFailedException' - return self.error(er, 'A condition specified in the operation could not be evaluated.') + er = "com.amazonaws.dynamodb.v20120810#ConditionalCheckFailedException" + return self.error( + er, "A condition specified in the operation could not be evaluated." + ) # Attempt to parse simple ConditionExpressions into an Expected # expression - condition_expression = self.body.get('ConditionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) + condition_expression = self.body.get("ConditionExpression") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) try: item = self.dynamodb_backend.delete_item( - name, key, expression_attribute_names, expression_attribute_values, - condition_expression + name, + key, + expression_attribute_names, + expression_attribute_values, + condition_expression, ) except ValueError: - er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' - return self.error(er, 'A condition specified in the operation could not be evaluated.') + er = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException" + return self.error( + er, "A condition specified in the operation could not be evaluated." + ) - if item and return_values == 'ALL_OLD': + if item and return_values == "ALL_OLD": item_dict = item.to_json() else: - item_dict = {'Attributes': {}} - item_dict['ConsumedCapacityUnits'] = 0.5 + item_dict = {"Attributes": {}} + item_dict["ConsumedCapacityUnits"] = 0.5 return dynamo_json_dump(item_dict) def update_item(self): - name = self.body['TableName'] - key = self.body['Key'] - return_values = self.body.get('ReturnValues', 'NONE') - update_expression = self.body.get('UpdateExpression', '').strip() - attribute_updates = self.body.get('AttributeUpdates') - expression_attribute_names = self.body.get( - 'ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get( - 'ExpressionAttributeValues', {}) + name = self.body["TableName"] + key = self.body["Key"] + return_values = self.body.get("ReturnValues", "NONE") + update_expression = self.body.get("UpdateExpression", "").strip() + attribute_updates = self.body.get("AttributeUpdates") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) existing_item = self.dynamodb_backend.get_item(name, key) if existing_item: - existing_attributes = existing_item.to_json()['Attributes'] + existing_attributes = existing_item.to_json()["Attributes"] else: existing_attributes = {} - if return_values not in ('NONE', 'ALL_OLD', 'ALL_NEW', 'UPDATED_OLD', - 'UPDATED_NEW'): - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, 'Return values set to invalid value') + if return_values not in ( + "NONE", + "ALL_OLD", + "ALL_NEW", + "UPDATED_OLD", + "UPDATED_NEW", + ): + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error(er, "Return values set to invalid value") if has_empty_keys_or_values(expression_attribute_values): return get_empty_str_error() - if 'Expected' in self.body: - expected = self.body['Expected'] + if "Expected" in self.body: + expected = self.body["Expected"] else: expected = None # Attempt to parse simple ConditionExpressions into an Expected # expression - condition_expression = self.body.get('ConditionExpression') - expression_attribute_names = self.body.get('ExpressionAttributeNames', {}) - expression_attribute_values = self.body.get('ExpressionAttributeValues', {}) + condition_expression = self.body.get("ConditionExpression") + expression_attribute_names = self.body.get("ExpressionAttributeNames", {}) + expression_attribute_values = self.body.get("ExpressionAttributeValues", {}) # Support spaces between operators in an update expression # E.g. `a = b + c` -> `a=b+c` if update_expression: - update_expression = re.sub( - r'\s*([=\+-])\s*', '\\1', update_expression) + update_expression = re.sub(r"\s*([=\+-])\s*", "\\1", update_expression) try: item = self.dynamodb_backend.update_item( - name, key, update_expression, attribute_updates, expression_attribute_names, - expression_attribute_values, expected, condition_expression + name, + key, + update_expression, + attribute_updates, + expression_attribute_names, + expression_attribute_values, + expected, + condition_expression, ) except InvalidUpdateExpression: - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, 'The document path provided in the update expression is invalid for update') + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error( + er, + "The document path provided in the update expression is invalid for update", + ) except ItemSizeTooLarge: - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' + er = "com.amazonaws.dynamodb.v20111205#ValidationException" return self.error(er, ItemSizeTooLarge.message) except ValueError: - er = 'com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException' - return self.error(er, 'A condition specified in the operation could not be evaluated.') + er = "com.amazonaws.dynamodb.v20111205#ConditionalCheckFailedException" + return self.error( + er, "A condition specified in the operation could not be evaluated." + ) except TypeError: - er = 'com.amazonaws.dynamodb.v20111205#ValidationException' - return self.error(er, 'Validation Exception') + er = "com.amazonaws.dynamodb.v20111205#ValidationException" + return self.error(er, "Validation Exception") item_dict = item.to_json() - item_dict['ConsumedCapacity'] = { - 'TableName': name, - 'CapacityUnits': 0.5 - } + item_dict["ConsumedCapacity"] = {"TableName": name, "CapacityUnits": 0.5} unchanged_attributes = { - k for k in existing_attributes.keys() - if existing_attributes[k] == item_dict['Attributes'].get(k) + k + for k in existing_attributes.keys() + if existing_attributes[k] == item_dict["Attributes"].get(k) } - changed_attributes = set(existing_attributes.keys()).union(item_dict['Attributes'].keys()).difference(unchanged_attributes) + changed_attributes = ( + set(existing_attributes.keys()) + .union(item_dict["Attributes"].keys()) + .difference(unchanged_attributes) + ) - if return_values == 'NONE': - item_dict['Attributes'] = {} - elif return_values == 'ALL_OLD': - item_dict['Attributes'] = existing_attributes - elif return_values == 'UPDATED_OLD': - item_dict['Attributes'] = { - k: v for k, v in existing_attributes.items() - if k in changed_attributes + if return_values == "NONE": + item_dict["Attributes"] = {} + elif return_values == "ALL_OLD": + item_dict["Attributes"] = existing_attributes + elif return_values == "UPDATED_OLD": + item_dict["Attributes"] = { + k: v for k, v in existing_attributes.items() if k in changed_attributes } - elif return_values == 'UPDATED_NEW': - item_dict['Attributes'] = { - k: v for k, v in item_dict['Attributes'].items() + elif return_values == "UPDATED_NEW": + item_dict["Attributes"] = { + k: v + for k, v in item_dict["Attributes"].items() if k in changed_attributes } return dynamo_json_dump(item_dict) def describe_limits(self): - return json.dumps({ - 'AccountMaxReadCapacityUnits': 20000, - 'TableMaxWriteCapacityUnits': 10000, - 'AccountMaxWriteCapacityUnits': 20000, - 'TableMaxReadCapacityUnits': 10000 - }) + return json.dumps( + { + "AccountMaxReadCapacityUnits": 20000, + "TableMaxWriteCapacityUnits": 10000, + "AccountMaxWriteCapacityUnits": 20000, + "TableMaxReadCapacityUnits": 10000, + } + ) def update_time_to_live(self): - name = self.body['TableName'] - ttl_spec = self.body['TimeToLiveSpecification'] + name = self.body["TableName"] + ttl_spec = self.body["TimeToLiveSpecification"] self.dynamodb_backend.update_ttl(name, ttl_spec) - return json.dumps({'TimeToLiveSpecification': ttl_spec}) + return json.dumps({"TimeToLiveSpecification": ttl_spec}) def describe_time_to_live(self): - name = self.body['TableName'] + name = self.body["TableName"] ttl_spec = self.dynamodb_backend.describe_ttl(name) - return json.dumps({'TimeToLiveDescription': ttl_spec}) + return json.dumps({"TimeToLiveDescription": ttl_spec}) diff --git a/moto/dynamodb2/urls.py b/moto/dynamodb2/urls.py index 6988f6e15..26f0701a2 100644 --- a/moto/dynamodb2/urls.py +++ b/moto/dynamodb2/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import DynamoHandler -url_bases = [ - "https?://dynamodb.(.+).amazonaws.com" -] +url_bases = ["https?://dynamodb.(.+).amazonaws.com"] -url_paths = { - "{0}/": DynamoHandler.dispatch, -} +url_paths = {"{0}/": DynamoHandler.dispatch} diff --git a/moto/dynamodbstreams/__init__.py b/moto/dynamodbstreams/__init__.py index b35879eba..85dd5404c 100644 --- a/moto/dynamodbstreams/__init__.py +++ b/moto/dynamodbstreams/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import dynamodbstreams_backends from ..core.models import base_decorator -dynamodbstreams_backend = dynamodbstreams_backends['us-east-1'] +dynamodbstreams_backend = dynamodbstreams_backends["us-east-1"] mock_dynamodbstreams = base_decorator(dynamodbstreams_backends) diff --git a/moto/dynamodbstreams/models.py b/moto/dynamodbstreams/models.py index 3e20ae13f..6e99d8ef6 100644 --- a/moto/dynamodbstreams/models.py +++ b/moto/dynamodbstreams/models.py @@ -10,51 +10,59 @@ from moto.dynamodb2.models import dynamodb_backends class ShardIterator(BaseModel): - def __init__(self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None): - self.id = base64.b64encode(os.urandom(472)).decode('utf-8') + def __init__( + self, streams_backend, stream_shard, shard_iterator_type, sequence_number=None + ): + self.id = base64.b64encode(os.urandom(472)).decode("utf-8") self.streams_backend = streams_backend self.stream_shard = stream_shard self.shard_iterator_type = shard_iterator_type - if shard_iterator_type == 'TRIM_HORIZON': + if shard_iterator_type == "TRIM_HORIZON": self.sequence_number = stream_shard.starting_sequence_number - elif shard_iterator_type == 'LATEST': - self.sequence_number = stream_shard.starting_sequence_number + len(stream_shard.items) - elif shard_iterator_type == 'AT_SEQUENCE_NUMBER': + elif shard_iterator_type == "LATEST": + self.sequence_number = stream_shard.starting_sequence_number + len( + stream_shard.items + ) + elif shard_iterator_type == "AT_SEQUENCE_NUMBER": self.sequence_number = sequence_number - elif shard_iterator_type == 'AFTER_SEQUENCE_NUMBER': + elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER": self.sequence_number = sequence_number + 1 @property def arn(self): - return '{}/stream/{}|1|{}'.format( + return "{}/stream/{}|1|{}".format( self.stream_shard.table.table_arn, self.stream_shard.table.latest_stream_label, - self.id) + self.id, + ) def to_json(self): - return { - 'ShardIterator': self.arn - } + return {"ShardIterator": self.arn} def get(self, limit=1000): items = self.stream_shard.get(self.sequence_number, limit) try: - last_sequence_number = max(int(i['dynamodb']['SequenceNumber']) for i in items) - new_shard_iterator = ShardIterator(self.streams_backend, - self.stream_shard, - 'AFTER_SEQUENCE_NUMBER', - last_sequence_number) + last_sequence_number = max( + int(i["dynamodb"]["SequenceNumber"]) for i in items + ) + new_shard_iterator = ShardIterator( + self.streams_backend, + self.stream_shard, + "AFTER_SEQUENCE_NUMBER", + last_sequence_number, + ) except ValueError: - new_shard_iterator = ShardIterator(self.streams_backend, - self.stream_shard, - 'AT_SEQUENCE_NUMBER', - self.sequence_number) + new_shard_iterator = ShardIterator( + self.streams_backend, + self.stream_shard, + "AT_SEQUENCE_NUMBER", + self.sequence_number, + ) - self.streams_backend.shard_iterators[new_shard_iterator.arn] = new_shard_iterator - return { - 'NextShardIterator': new_shard_iterator.arn, - 'Records': items - } + self.streams_backend.shard_iterators[ + new_shard_iterator.arn + ] = new_shard_iterator + return {"NextShardIterator": new_shard_iterator.arn, "Records": items} class DynamoDBStreamsBackend(BaseBackend): @@ -72,23 +80,27 @@ class DynamoDBStreamsBackend(BaseBackend): return dynamodb_backends[self.region] def _get_table_from_arn(self, arn): - table_name = arn.split(':', 6)[5].split('/')[1] + table_name = arn.split(":", 6)[5].split("/")[1] return self.dynamodb.get_table(table_name) def describe_stream(self, arn): table = self._get_table_from_arn(arn) - resp = {'StreamDescription': { - 'StreamArn': arn, - 'StreamLabel': table.latest_stream_label, - 'StreamStatus': ('ENABLED' if table.latest_stream_label - else 'DISABLED'), - 'StreamViewType': table.stream_specification['StreamViewType'], - 'CreationRequestDateTime': table.stream_shard.created_on.isoformat(), - 'TableName': table.name, - 'KeySchema': table.schema, - 'Shards': ([table.stream_shard.to_json()] if table.stream_shard - else []) - }} + resp = { + "StreamDescription": { + "StreamArn": arn, + "StreamLabel": table.latest_stream_label, + "StreamStatus": ( + "ENABLED" if table.latest_stream_label else "DISABLED" + ), + "StreamViewType": table.stream_specification["StreamViewType"], + "CreationRequestDateTime": table.stream_shard.created_on.isoformat(), + "TableName": table.name, + "KeySchema": table.schema, + "Shards": ( + [table.stream_shard.to_json()] if table.stream_shard else [] + ), + } + } return json.dumps(resp) @@ -98,22 +110,26 @@ class DynamoDBStreamsBackend(BaseBackend): if table_name is not None and table.name != table_name: continue if table.latest_stream_label: - d = table.describe(base_key='Table') - streams.append({ - 'StreamArn': d['Table']['LatestStreamArn'], - 'TableName': d['Table']['TableName'], - 'StreamLabel': d['Table']['LatestStreamLabel'] - }) + d = table.describe(base_key="Table") + streams.append( + { + "StreamArn": d["Table"]["LatestStreamArn"], + "TableName": d["Table"]["TableName"], + "StreamLabel": d["Table"]["LatestStreamLabel"], + } + ) - return json.dumps({'Streams': streams}) + return json.dumps({"Streams": streams}) - def get_shard_iterator(self, arn, shard_id, shard_iterator_type, sequence_number=None): + def get_shard_iterator( + self, arn, shard_id, shard_iterator_type, sequence_number=None + ): table = self._get_table_from_arn(arn) assert table.stream_shard.id == shard_id - shard_iterator = ShardIterator(self, table.stream_shard, - shard_iterator_type, - sequence_number) + shard_iterator = ShardIterator( + self, table.stream_shard, shard_iterator_type, sequence_number + ) self.shard_iterators[shard_iterator.arn] = shard_iterator return json.dumps(shard_iterator.to_json()) @@ -123,7 +139,7 @@ class DynamoDBStreamsBackend(BaseBackend): return json.dumps(shard_iterator.get(limit)) -available_regions = boto3.session.Session().get_available_regions( - 'dynamodbstreams') -dynamodbstreams_backends = {region: DynamoDBStreamsBackend(region=region) - for region in available_regions} +available_regions = boto3.session.Session().get_available_regions("dynamodbstreams") +dynamodbstreams_backends = { + region: DynamoDBStreamsBackend(region=region) for region in available_regions +} diff --git a/moto/dynamodbstreams/responses.py b/moto/dynamodbstreams/responses.py index 7774f3239..d4f5c78a6 100644 --- a/moto/dynamodbstreams/responses.py +++ b/moto/dynamodbstreams/responses.py @@ -7,34 +7,34 @@ from six import string_types class DynamoDBStreamsHandler(BaseResponse): - @property def backend(self): return dynamodbstreams_backends[self.region] def describe_stream(self): - arn = self._get_param('StreamArn') + arn = self._get_param("StreamArn") return self.backend.describe_stream(arn) def list_streams(self): - table_name = self._get_param('TableName') + table_name = self._get_param("TableName") return self.backend.list_streams(table_name) def get_shard_iterator(self): - arn = self._get_param('StreamArn') - shard_id = self._get_param('ShardId') - shard_iterator_type = self._get_param('ShardIteratorType') - sequence_number = self._get_param('SequenceNumber') + arn = self._get_param("StreamArn") + shard_id = self._get_param("ShardId") + shard_iterator_type = self._get_param("ShardIteratorType") + sequence_number = self._get_param("SequenceNumber") # according to documentation sequence_number param should be string if isinstance(sequence_number, string_types): sequence_number = int(sequence_number) - return self.backend.get_shard_iterator(arn, shard_id, - shard_iterator_type, sequence_number) + return self.backend.get_shard_iterator( + arn, shard_id, shard_iterator_type, sequence_number + ) def get_records(self): - arn = self._get_param('ShardIterator') - limit = self._get_param('Limit') + arn = self._get_param("ShardIterator") + limit = self._get_param("Limit") if limit is None: limit = 1000 return self.backend.get_records(arn, limit) diff --git a/moto/dynamodbstreams/urls.py b/moto/dynamodbstreams/urls.py index 1d0f94c35..a7589ae13 100644 --- a/moto/dynamodbstreams/urls.py +++ b/moto/dynamodbstreams/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import DynamoDBStreamsHandler -url_bases = [ - "https?://streams.dynamodb.(.+).amazonaws.com" -] +url_bases = ["https?://streams.dynamodb.(.+).amazonaws.com"] -url_paths = { - "{0}/$": DynamoDBStreamsHandler.dispatch, -} +url_paths = {"{0}/$": DynamoDBStreamsHandler.dispatch} diff --git a/moto/ec2/__init__.py b/moto/ec2/__init__.py index ba8cbe0a0..c16912f57 100644 --- a/moto/ec2/__init__.py +++ b/moto/ec2/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import ec2_backends from ..core.models import base_decorator, deprecated_base_decorator -ec2_backend = ec2_backends['us-east-1'] +ec2_backend = ec2_backends["us-east-1"] mock_ec2 = base_decorator(ec2_backends) mock_ec2_deprecated = deprecated_base_decorator(ec2_backends) diff --git a/moto/ec2/exceptions.py b/moto/ec2/exceptions.py index b7a49cc57..b2c1792f2 100644 --- a/moto/ec2/exceptions.py +++ b/moto/ec2/exceptions.py @@ -7,499 +7,468 @@ class EC2ClientError(RESTError): class DependencyViolationError(EC2ClientError): - def __init__(self, message): - super(DependencyViolationError, self).__init__( - "DependencyViolation", message) + super(DependencyViolationError, self).__init__("DependencyViolation", message) class MissingParameterError(EC2ClientError): - def __init__(self, parameter): super(MissingParameterError, self).__init__( "MissingParameter", - "The request must contain the parameter {0}" - .format(parameter)) + "The request must contain the parameter {0}".format(parameter), + ) class InvalidDHCPOptionsIdError(EC2ClientError): - def __init__(self, dhcp_options_id): super(InvalidDHCPOptionsIdError, self).__init__( "InvalidDhcpOptionID.NotFound", - "DhcpOptionID {0} does not exist." - .format(dhcp_options_id)) + "DhcpOptionID {0} does not exist.".format(dhcp_options_id), + ) class MalformedDHCPOptionsIdError(EC2ClientError): - def __init__(self, dhcp_options_id): super(MalformedDHCPOptionsIdError, self).__init__( "InvalidDhcpOptionsId.Malformed", - "Invalid id: \"{0}\" (expecting \"dopt-...\")" - .format(dhcp_options_id)) + 'Invalid id: "{0}" (expecting "dopt-...")'.format(dhcp_options_id), + ) class InvalidKeyPairNameError(EC2ClientError): - def __init__(self, key): super(InvalidKeyPairNameError, self).__init__( - "InvalidKeyPair.NotFound", - "The keypair '{0}' does not exist." - .format(key)) + "InvalidKeyPair.NotFound", "The keypair '{0}' does not exist.".format(key) + ) class InvalidKeyPairDuplicateError(EC2ClientError): - def __init__(self, key): super(InvalidKeyPairDuplicateError, self).__init__( - "InvalidKeyPair.Duplicate", - "The keypair '{0}' already exists." - .format(key)) + "InvalidKeyPair.Duplicate", "The keypair '{0}' already exists.".format(key) + ) class InvalidKeyPairFormatError(EC2ClientError): - def __init__(self): super(InvalidKeyPairFormatError, self).__init__( - "InvalidKeyPair.Format", - "Key is not in valid OpenSSH public key format") + "InvalidKeyPair.Format", "Key is not in valid OpenSSH public key format" + ) class InvalidVPCIdError(EC2ClientError): - def __init__(self, vpc_id): super(InvalidVPCIdError, self).__init__( - "InvalidVpcID.NotFound", - "VpcID {0} does not exist." - .format(vpc_id)) + "InvalidVpcID.NotFound", "VpcID {0} does not exist.".format(vpc_id) + ) class InvalidSubnetIdError(EC2ClientError): - def __init__(self, subnet_id): super(InvalidSubnetIdError, self).__init__( "InvalidSubnetID.NotFound", - "The subnet ID '{0}' does not exist" - .format(subnet_id)) + "The subnet ID '{0}' does not exist".format(subnet_id), + ) class InvalidNetworkAclIdError(EC2ClientError): - def __init__(self, network_acl_id): super(InvalidNetworkAclIdError, self).__init__( "InvalidNetworkAclID.NotFound", - "The network acl ID '{0}' does not exist" - .format(network_acl_id)) + "The network acl ID '{0}' does not exist".format(network_acl_id), + ) class InvalidVpnGatewayIdError(EC2ClientError): - def __init__(self, network_acl_id): super(InvalidVpnGatewayIdError, self).__init__( "InvalidVpnGatewayID.NotFound", - "The virtual private gateway ID '{0}' does not exist" - .format(network_acl_id)) + "The virtual private gateway ID '{0}' does not exist".format( + network_acl_id + ), + ) class InvalidVpnConnectionIdError(EC2ClientError): - def __init__(self, network_acl_id): super(InvalidVpnConnectionIdError, self).__init__( "InvalidVpnConnectionID.NotFound", - "The vpnConnection ID '{0}' does not exist" - .format(network_acl_id)) + "The vpnConnection ID '{0}' does not exist".format(network_acl_id), + ) class InvalidCustomerGatewayIdError(EC2ClientError): - def __init__(self, customer_gateway_id): super(InvalidCustomerGatewayIdError, self).__init__( "InvalidCustomerGatewayID.NotFound", - "The customer gateway ID '{0}' does not exist" - .format(customer_gateway_id)) + "The customer gateway ID '{0}' does not exist".format(customer_gateway_id), + ) class InvalidNetworkInterfaceIdError(EC2ClientError): - def __init__(self, eni_id): super(InvalidNetworkInterfaceIdError, self).__init__( "InvalidNetworkInterfaceID.NotFound", - "The network interface ID '{0}' does not exist" - .format(eni_id)) + "The network interface ID '{0}' does not exist".format(eni_id), + ) class InvalidNetworkAttachmentIdError(EC2ClientError): - def __init__(self, attachment_id): super(InvalidNetworkAttachmentIdError, self).__init__( "InvalidAttachmentID.NotFound", - "The network interface attachment ID '{0}' does not exist" - .format(attachment_id)) + "The network interface attachment ID '{0}' does not exist".format( + attachment_id + ), + ) class InvalidSecurityGroupDuplicateError(EC2ClientError): - def __init__(self, name): super(InvalidSecurityGroupDuplicateError, self).__init__( "InvalidGroup.Duplicate", - "The security group '{0}' already exists" - .format(name)) + "The security group '{0}' already exists".format(name), + ) class InvalidSecurityGroupNotFoundError(EC2ClientError): - def __init__(self, name): super(InvalidSecurityGroupNotFoundError, self).__init__( "InvalidGroup.NotFound", - "The security group '{0}' does not exist" - .format(name)) + "The security group '{0}' does not exist".format(name), + ) class InvalidPermissionNotFoundError(EC2ClientError): - def __init__(self): super(InvalidPermissionNotFoundError, self).__init__( "InvalidPermission.NotFound", - "The specified rule does not exist in this security group") + "The specified rule does not exist in this security group", + ) class InvalidPermissionDuplicateError(EC2ClientError): - def __init__(self): super(InvalidPermissionDuplicateError, self).__init__( - "InvalidPermission.Duplicate", - "The specified rule already exists") + "InvalidPermission.Duplicate", "The specified rule already exists" + ) class InvalidRouteTableIdError(EC2ClientError): - def __init__(self, route_table_id): super(InvalidRouteTableIdError, self).__init__( "InvalidRouteTableID.NotFound", - "The routeTable ID '{0}' does not exist" - .format(route_table_id)) + "The routeTable ID '{0}' does not exist".format(route_table_id), + ) class InvalidRouteError(EC2ClientError): - def __init__(self, route_table_id, cidr): super(InvalidRouteError, self).__init__( "InvalidRoute.NotFound", - "no route with destination-cidr-block {0} in route table {1}" - .format(cidr, route_table_id)) + "no route with destination-cidr-block {0} in route table {1}".format( + cidr, route_table_id + ), + ) class InvalidInstanceIdError(EC2ClientError): - def __init__(self, instance_id): super(InvalidInstanceIdError, self).__init__( "InvalidInstanceID.NotFound", - "The instance ID '{0}' does not exist" - .format(instance_id)) + "The instance ID '{0}' does not exist".format(instance_id), + ) class InvalidAMIIdError(EC2ClientError): - def __init__(self, ami_id): super(InvalidAMIIdError, self).__init__( "InvalidAMIID.NotFound", - "The image id '[{0}]' does not exist" - .format(ami_id)) + "The image id '[{0}]' does not exist".format(ami_id), + ) class InvalidAMIAttributeItemValueError(EC2ClientError): - def __init__(self, attribute, value): super(InvalidAMIAttributeItemValueError, self).__init__( "InvalidAMIAttributeItemValue", - "Invalid attribute item value \"{0}\" for {1} item type." - .format(value, attribute)) + 'Invalid attribute item value "{0}" for {1} item type.'.format( + value, attribute + ), + ) class MalformedAMIIdError(EC2ClientError): - def __init__(self, ami_id): super(MalformedAMIIdError, self).__init__( "InvalidAMIID.Malformed", - "Invalid id: \"{0}\" (expecting \"ami-...\")" - .format(ami_id)) + 'Invalid id: "{0}" (expecting "ami-...")'.format(ami_id), + ) class InvalidSnapshotIdError(EC2ClientError): - def __init__(self, snapshot_id): super(InvalidSnapshotIdError, self).__init__( - "InvalidSnapshot.NotFound", - "") # Note: AWS returns empty message for this, as of 2014.08.22. + "InvalidSnapshot.NotFound", "" + ) # Note: AWS returns empty message for this, as of 2014.08.22. class InvalidVolumeIdError(EC2ClientError): - def __init__(self, volume_id): super(InvalidVolumeIdError, self).__init__( "InvalidVolume.NotFound", - "The volume '{0}' does not exist." - .format(volume_id)) + "The volume '{0}' does not exist.".format(volume_id), + ) class InvalidVolumeAttachmentError(EC2ClientError): - def __init__(self, volume_id, instance_id): super(InvalidVolumeAttachmentError, self).__init__( "InvalidAttachment.NotFound", - "Volume {0} can not be detached from {1} because it is not attached" - .format(volume_id, instance_id)) + "Volume {0} can not be detached from {1} because it is not attached".format( + volume_id, instance_id + ), + ) class InvalidDomainError(EC2ClientError): - def __init__(self, domain): super(InvalidDomainError, self).__init__( - "InvalidParameterValue", - "Invalid value '{0}' for domain." - .format(domain)) + "InvalidParameterValue", "Invalid value '{0}' for domain.".format(domain) + ) class InvalidAddressError(EC2ClientError): - def __init__(self, ip): super(InvalidAddressError, self).__init__( - "InvalidAddress.NotFound", - "Address '{0}' not found." - .format(ip)) + "InvalidAddress.NotFound", "Address '{0}' not found.".format(ip) + ) class InvalidAllocationIdError(EC2ClientError): - def __init__(self, allocation_id): super(InvalidAllocationIdError, self).__init__( "InvalidAllocationID.NotFound", - "Allocation ID '{0}' not found." - .format(allocation_id)) + "Allocation ID '{0}' not found.".format(allocation_id), + ) class InvalidAssociationIdError(EC2ClientError): - def __init__(self, association_id): super(InvalidAssociationIdError, self).__init__( "InvalidAssociationID.NotFound", - "Association ID '{0}' not found." - .format(association_id)) + "Association ID '{0}' not found.".format(association_id), + ) class InvalidVpcCidrBlockAssociationIdError(EC2ClientError): - def __init__(self, association_id): super(InvalidVpcCidrBlockAssociationIdError, self).__init__( "InvalidVpcCidrBlockAssociationIdError.NotFound", - "The vpc CIDR block association ID '{0}' does not exist" - .format(association_id)) + "The vpc CIDR block association ID '{0}' does not exist".format( + association_id + ), + ) class InvalidVPCPeeringConnectionIdError(EC2ClientError): - def __init__(self, vpc_peering_connection_id): super(InvalidVPCPeeringConnectionIdError, self).__init__( "InvalidVpcPeeringConnectionId.NotFound", - "VpcPeeringConnectionID {0} does not exist." - .format(vpc_peering_connection_id)) + "VpcPeeringConnectionID {0} does not exist.".format( + vpc_peering_connection_id + ), + ) class InvalidVPCPeeringConnectionStateTransitionError(EC2ClientError): - def __init__(self, vpc_peering_connection_id): super(InvalidVPCPeeringConnectionStateTransitionError, self).__init__( "InvalidStateTransition", - "VpcPeeringConnectionID {0} is not in the correct state for the request." - .format(vpc_peering_connection_id)) + "VpcPeeringConnectionID {0} is not in the correct state for the request.".format( + vpc_peering_connection_id + ), + ) class InvalidParameterValueError(EC2ClientError): - def __init__(self, parameter_value): super(InvalidParameterValueError, self).__init__( "InvalidParameterValue", - "Value {0} is invalid for parameter." - .format(parameter_value)) + "Value {0} is invalid for parameter.".format(parameter_value), + ) class InvalidParameterValueErrorTagNull(EC2ClientError): - def __init__(self): super(InvalidParameterValueErrorTagNull, self).__init__( "InvalidParameterValue", - "Tag value cannot be null. Use empty string instead.") + "Tag value cannot be null. Use empty string instead.", + ) class InvalidParameterValueErrorUnknownAttribute(EC2ClientError): - def __init__(self, parameter_value): super(InvalidParameterValueErrorUnknownAttribute, self).__init__( "InvalidParameterValue", - "Value ({0}) for parameter attribute is invalid. Unknown attribute." - .format(parameter_value)) + "Value ({0}) for parameter attribute is invalid. Unknown attribute.".format( + parameter_value + ), + ) class InvalidInternetGatewayIdError(EC2ClientError): - def __init__(self, internet_gateway_id): super(InvalidInternetGatewayIdError, self).__init__( "InvalidInternetGatewayID.NotFound", - "InternetGatewayID {0} does not exist." - .format(internet_gateway_id)) + "InternetGatewayID {0} does not exist.".format(internet_gateway_id), + ) class GatewayNotAttachedError(EC2ClientError): - def __init__(self, internet_gateway_id, vpc_id): super(GatewayNotAttachedError, self).__init__( "Gateway.NotAttached", - "InternetGatewayID {0} is not attached to a VPC {1}." - .format(internet_gateway_id, vpc_id)) + "InternetGatewayID {0} is not attached to a VPC {1}.".format( + internet_gateway_id, vpc_id + ), + ) class ResourceAlreadyAssociatedError(EC2ClientError): - def __init__(self, resource_id): super(ResourceAlreadyAssociatedError, self).__init__( "Resource.AlreadyAssociated", - "Resource {0} is already associated." - .format(resource_id)) + "Resource {0} is already associated.".format(resource_id), + ) class TagLimitExceeded(EC2ClientError): - def __init__(self): super(TagLimitExceeded, self).__init__( "TagLimitExceeded", - "The maximum number of Tags for a resource has been reached.") + "The maximum number of Tags for a resource has been reached.", + ) class InvalidID(EC2ClientError): - def __init__(self, resource_id): super(InvalidID, self).__init__( - "InvalidID", - "The ID '{0}' is not valid" - .format(resource_id)) + "InvalidID", "The ID '{0}' is not valid".format(resource_id) + ) class InvalidCIDRSubnetError(EC2ClientError): - def __init__(self, cidr): super(InvalidCIDRSubnetError, self).__init__( "InvalidParameterValue", - "invalid CIDR subnet specification: {0}" - .format(cidr)) + "invalid CIDR subnet specification: {0}".format(cidr), + ) class RulesPerSecurityGroupLimitExceededError(EC2ClientError): - def __init__(self): super(RulesPerSecurityGroupLimitExceededError, self).__init__( "RulesPerSecurityGroupLimitExceeded", - 'The maximum number of rules per security group ' - 'has been reached.') + "The maximum number of rules per security group " "has been reached.", + ) class MotoNotImplementedError(NotImplementedError): - def __init__(self, blurb): super(MotoNotImplementedError, self).__init__( "{0} has not been implemented in Moto yet." " Feel free to open an issue at" - " https://github.com/spulec/moto/issues".format(blurb)) + " https://github.com/spulec/moto/issues".format(blurb) + ) class FilterNotImplementedError(MotoNotImplementedError): - def __init__(self, filter_name, method_name): super(FilterNotImplementedError, self).__init__( - "The filter '{0}' for {1}".format( - filter_name, method_name)) + "The filter '{0}' for {1}".format(filter_name, method_name) + ) class CidrLimitExceeded(EC2ClientError): - def __init__(self, vpc_id, max_cidr_limit): super(CidrLimitExceeded, self).__init__( "CidrLimitExceeded", - "This network '{0}' has met its maximum number of allowed CIDRs: {1}".format(vpc_id, max_cidr_limit) + "This network '{0}' has met its maximum number of allowed CIDRs: {1}".format( + vpc_id, max_cidr_limit + ), ) class OperationNotPermitted(EC2ClientError): - def __init__(self, association_id): super(OperationNotPermitted, self).__init__( "OperationNotPermitted", "The vpc CIDR block with association ID {} may not be disassociated. " - "It is the primary IPv4 CIDR block of the VPC".format(association_id) + "It is the primary IPv4 CIDR block of the VPC".format(association_id), ) class InvalidAvailabilityZoneError(EC2ClientError): - def __init__(self, availability_zone_value, valid_availability_zones): super(InvalidAvailabilityZoneError, self).__init__( "InvalidParameterValue", "Value ({0}) for parameter availabilityZone is invalid. " - "Subnets can currently only be created in the following availability zones: {1}.".format(availability_zone_value, valid_availability_zones) + "Subnets can currently only be created in the following availability zones: {1}.".format( + availability_zone_value, valid_availability_zones + ), ) class NetworkAclEntryAlreadyExistsError(EC2ClientError): - def __init__(self, rule_number): super(NetworkAclEntryAlreadyExistsError, self).__init__( "NetworkAclEntryAlreadyExists", - "The network acl entry identified by {} already exists.".format(rule_number) + "The network acl entry identified by {} already exists.".format( + rule_number + ), ) class InvalidSubnetRangeError(EC2ClientError): - def __init__(self, cidr_block): super(InvalidSubnetRangeError, self).__init__( - "InvalidSubnet.Range", - "The CIDR '{}' is invalid.".format(cidr_block) + "InvalidSubnet.Range", "The CIDR '{}' is invalid.".format(cidr_block) ) class InvalidCIDRBlockParameterError(EC2ClientError): - def __init__(self, cidr_block): super(InvalidCIDRBlockParameterError, self).__init__( "InvalidParameterValue", - "Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format(cidr_block) + "Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format( + cidr_block + ), ) class InvalidDestinationCIDRBlockParameterError(EC2ClientError): - def __init__(self, cidr_block): super(InvalidDestinationCIDRBlockParameterError, self).__init__( "InvalidParameterValue", - "Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format(cidr_block) + "Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format( + cidr_block + ), ) class InvalidSubnetConflictError(EC2ClientError): - def __init__(self, cidr_block): super(InvalidSubnetConflictError, self).__init__( "InvalidSubnet.Conflict", - "The CIDR '{}' conflicts with another subnet".format(cidr_block) + "The CIDR '{}' conflicts with another subnet".format(cidr_block), ) class InvalidVPCRangeError(EC2ClientError): - def __init__(self, cidr_block): super(InvalidVPCRangeError, self).__init__( - "InvalidVpc.Range", - "The CIDR '{}' is invalid.".format(cidr_block) + "InvalidVpc.Range", "The CIDR '{}' is invalid.".format(cidr_block) ) @@ -509,7 +478,9 @@ class OperationNotPermitted2(EC2ClientError): super(OperationNotPermitted2, self).__init__( "OperationNotPermitted", "Incorrect region ({0}) specified for this request." - "VPC peering connection {1} must be accepted in region {2}".format(client_region, pcx_id, acceptor_region) + "VPC peering connection {1} must be accepted in region {2}".format( + client_region, pcx_id, acceptor_region + ), ) @@ -519,9 +490,9 @@ class OperationNotPermitted3(EC2ClientError): super(OperationNotPermitted3, self).__init__( "OperationNotPermitted", "Incorrect region ({0}) specified for this request." - "VPC peering connection {1} must be accepted or rejected in region {2}".format(client_region, - pcx_id, - acceptor_region) + "VPC peering connection {1} must be accepted or rejected in region {2}".format( + client_region, pcx_id, acceptor_region + ), ) @@ -529,5 +500,5 @@ class InvalidLaunchTemplateNameError(EC2ClientError): def __init__(self): super(InvalidLaunchTemplateNameError, self).__init__( "InvalidLaunchTemplateName.AlreadyExistsException", - "Launch template name already in use." + "Launch template name already in use.", ) diff --git a/moto/ec2/models.py b/moto/ec2/models.py index 10d6f2b28..efbbeb6fe 100644 --- a/moto/ec2/models.py +++ b/moto/ec2/models.py @@ -23,7 +23,10 @@ from boto.ec2.launchspecification import LaunchSpecification from moto.compat import OrderedDict from moto.core import BaseBackend from moto.core.models import Model, BaseModel -from moto.core.utils import iso_8601_datetime_with_milliseconds, camelcase_to_underscores +from moto.core.utils import ( + iso_8601_datetime_with_milliseconds, + camelcase_to_underscores, +) from .exceptions import ( CidrLimitExceeded, DependencyViolationError, @@ -84,7 +87,8 @@ from .exceptions import ( OperationNotPermitted3, ResourceAlreadyAssociatedError, RulesPerSecurityGroupLimitExceededError, - TagLimitExceeded) + TagLimitExceeded, +) from .utils import ( EC2_RESOURCE_TO_PREFIX, EC2_PREFIX_TO_RESOURCE, @@ -132,27 +136,30 @@ from .utils import ( is_tag_filter, tag_filter_matches, rsa_public_key_parse, - rsa_public_key_fingerprint + rsa_public_key_fingerprint, ) INSTANCE_TYPES = json.load( - open(resource_filename(__name__, 'resources/instance_types.json'), 'r') + open(resource_filename(__name__, "resources/instance_types.json"), "r") ) AMIS = json.load( - open(os.environ.get('MOTO_AMIS_PATH') or resource_filename( - __name__, 'resources/amis.json'), 'r') + open( + os.environ.get("MOTO_AMIS_PATH") + or resource_filename(__name__, "resources/amis.json"), + "r", + ) ) OWNER_ID = "111122223333" def utc_date_and_time(): - return datetime.utcnow().strftime('%Y-%m-%dT%H:%M:%S.000Z') + return datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.000Z") def validate_resource_ids(resource_ids): if not resource_ids: - raise MissingParameterError(parameter='resourceIdSet') + raise MissingParameterError(parameter="resourceIdSet") for resource_id in resource_ids: if not is_valid_resource_id(resource_id): raise InvalidID(resource_id=resource_id) @@ -160,7 +167,7 @@ def validate_resource_ids(resource_ids): class InstanceState(object): - def __init__(self, name='pending', code=0): + def __init__(self, name="pending", code=0): self.name = name self.code = code @@ -173,8 +180,7 @@ class StateReason(object): class TaggedEC2Resource(BaseModel): def get_tags(self, *args, **kwargs): - tags = self.ec2_backend.describe_tags( - filters={'resource-id': [self.id]}) + tags = self.ec2_backend.describe_tags(filters={"resource-id": [self.id]}) return tags def add_tag(self, key, value): @@ -187,24 +193,32 @@ class TaggedEC2Resource(BaseModel): def get_filter_value(self, filter_name, method_name=None): tags = self.get_tags() - if filter_name.startswith('tag:'): - tagname = filter_name.replace('tag:', '', 1) + if filter_name.startswith("tag:"): + tagname = filter_name.replace("tag:", "", 1) for tag in tags: - if tag['key'] == tagname: - return tag['value'] + if tag["key"] == tagname: + return tag["value"] - return '' - elif filter_name == 'tag-key': - return [tag['key'] for tag in tags] - elif filter_name == 'tag-value': - return [tag['value'] for tag in tags] + return "" + elif filter_name == "tag-key": + return [tag["key"] for tag in tags] + elif filter_name == "tag-value": + return [tag["value"] for tag in tags] else: raise FilterNotImplementedError(filter_name, method_name) class NetworkInterface(TaggedEC2Resource): - def __init__(self, ec2_backend, subnet, private_ip_address, device_index=0, - public_ip_auto_assign=True, group_ids=None, description=None): + def __init__( + self, + ec2_backend, + subnet, + private_ip_address, + device_index=0, + public_ip_auto_assign=True, + group_ids=None, + description=None, + ): self.ec2_backend = ec2_backend self.id = random_eni_id() self.device_index = device_index @@ -231,32 +245,39 @@ class NetworkInterface(TaggedEC2Resource): if not group: # Create with specific group ID. group = SecurityGroup( - self.ec2_backend, group_id, group_id, group_id, vpc_id=subnet.vpc_id) + self.ec2_backend, + group_id, + group_id, + group_id, + vpc_id=subnet.vpc_id, + ) self.ec2_backend.groups[subnet.vpc_id][group_id] = group if group: self._group_set.append(group) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - security_group_ids = properties.get('SecurityGroups', []) + security_group_ids = properties.get("SecurityGroups", []) ec2_backend = ec2_backends[region_name] - subnet_id = properties.get('SubnetId') + subnet_id = properties.get("SubnetId") if subnet_id: subnet = ec2_backend.get_subnet(subnet_id) else: subnet = None - private_ip_address = properties.get('PrivateIpAddress', None) - description = properties.get('Description', None) + private_ip_address = properties.get("PrivateIpAddress", None) + description = properties.get("Description", None) network_interface = ec2_backend.create_network_interface( subnet, private_ip_address, group_ids=security_group_ids, - description=description + description=description, ) return network_interface @@ -280,11 +301,13 @@ class NetworkInterface(TaggedEC2Resource): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'PrimaryPrivateIpAddress': + + if attribute_name == "PrimaryPrivateIpAddress": return self.private_ip_address - elif attribute_name == 'SecondaryPrivateIpAddresses': + elif attribute_name == "SecondaryPrivateIpAddresses": raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "SecondaryPrivateIpAddresses" ]"') + '"Fn::GetAtt" : [ "{0}" , "SecondaryPrivateIpAddresses" ]"' + ) raise UnformattedGetAttTemplateException() @property @@ -292,23 +315,24 @@ class NetworkInterface(TaggedEC2Resource): return self.id def get_filter_value(self, filter_name): - if filter_name == 'network-interface-id': + if filter_name == "network-interface-id": return self.id - elif filter_name in ('addresses.private-ip-address', 'private-ip-address'): + elif filter_name in ("addresses.private-ip-address", "private-ip-address"): return self.private_ip_address - elif filter_name == 'subnet-id': + elif filter_name == "subnet-id": return self.subnet.id - elif filter_name == 'vpc-id': + elif filter_name == "vpc-id": return self.subnet.vpc_id - elif filter_name == 'group-id': + elif filter_name == "group-id": return [group.id for group in self._group_set] - elif filter_name == 'availability-zone': + elif filter_name == "availability-zone": return self.subnet.availability_zone - elif filter_name == 'description': + elif filter_name == "description": return self.description else: return super(NetworkInterface, self).get_filter_value( - filter_name, 'DescribeNetworkInterfaces') + filter_name, "DescribeNetworkInterfaces" + ) class NetworkInterfaceBackend(object): @@ -316,9 +340,17 @@ class NetworkInterfaceBackend(object): self.enis = {} super(NetworkInterfaceBackend, self).__init__() - def create_network_interface(self, subnet, private_ip_address, group_ids=None, description=None, **kwargs): + def create_network_interface( + self, subnet, private_ip_address, group_ids=None, description=None, **kwargs + ): eni = NetworkInterface( - self, subnet, private_ip_address, group_ids=group_ids, description=description, **kwargs) + self, + subnet, + private_ip_address, + group_ids=group_ids, + description=description, + **kwargs + ) self.enis[eni.id] = eni return eni @@ -339,11 +371,12 @@ class NetworkInterfaceBackend(object): if filters: for (_filter, _filter_value) in filters.items(): - if _filter == 'network-interface-id': - _filter = 'id' - enis = [eni for eni in enis if getattr( - eni, _filter) in _filter_value] - elif _filter == 'group-id': + if _filter == "network-interface-id": + _filter = "id" + enis = [ + eni for eni in enis if getattr(eni, _filter) in _filter_value + ] + elif _filter == "group-id": original_enis = enis enis = [] for eni in original_enis: @@ -351,15 +384,18 @@ class NetworkInterfaceBackend(object): if group.id in _filter_value: enis.append(eni) break - elif _filter == 'private-ip-address:': - enis = [eni for eni in enis if eni.private_ip_address in _filter_value] - elif _filter == 'subnet-id': + elif _filter == "private-ip-address:": + enis = [ + eni for eni in enis if eni.private_ip_address in _filter_value + ] + elif _filter == "subnet-id": enis = [eni for eni in enis if eni.subnet.id in _filter_value] - elif _filter == 'description': + elif _filter == "description": enis = [eni for eni in enis if eni.description in _filter_value] else: self.raise_not_implemented_error( - "The filter '{0}' for DescribeNetworkInterfaces".format(_filter)) + "The filter '{0}' for DescribeNetworkInterfaces".format(_filter) + ) return enis def attach_network_interface(self, eni_id, instance_id, device_index): @@ -390,17 +426,30 @@ class NetworkInterfaceBackend(object): if eni_ids: enis = [eni for eni in enis if eni.id in eni_ids] if len(enis) != len(eni_ids): - invalid_id = list(set(eni_ids).difference( - set([eni.id for eni in enis])))[0] + invalid_id = list( + set(eni_ids).difference(set([eni.id for eni in enis])) + )[0] raise InvalidNetworkInterfaceIdError(invalid_id) return generic_filter(filters, enis) class Instance(TaggedEC2Resource, BotoInstance): - VALID_ATTRIBUTES = {'instanceType', 'kernel', 'ramdisk', 'userData', 'disableApiTermination', - 'instanceInitiatedShutdownBehavior', 'rootDeviceName', 'blockDeviceMapping', - 'productCodes', 'sourceDestCheck', 'groupSet', 'ebsOptimized', 'sriovNetSupport'} + VALID_ATTRIBUTES = { + "instanceType", + "kernel", + "ramdisk", + "userData", + "disableApiTermination", + "instanceInitiatedShutdownBehavior", + "rootDeviceName", + "blockDeviceMapping", + "productCodes", + "sourceDestCheck", + "groupSet", + "ebsOptimized", + "sriovNetSupport", + } def __init__(self, ec2_backend, image_id, user_data, security_groups, **kwargs): super(Instance, self).__init__() @@ -424,7 +473,9 @@ class Instance(TaggedEC2Resource, BotoInstance): self.launch_time = utc_date_and_time() self.ami_launch_index = kwargs.get("ami_launch_index", 0) self.disable_api_termination = kwargs.get("disable_api_termination", False) - self.instance_initiated_shutdown_behavior = kwargs.get("instance_initiated_shutdown_behavior", "stop") + self.instance_initiated_shutdown_behavior = kwargs.get( + "instance_initiated_shutdown_behavior", "stop" + ) self.sriov_net_support = "simple" self._spot_fleet_id = kwargs.get("spot_fleet_id", None) self.associate_public_ip = kwargs.get("associate_public_ip", False) @@ -432,29 +483,31 @@ class Instance(TaggedEC2Resource, BotoInstance): # If we are in EC2-Classic, autoassign a public IP self.associate_public_ip = True - amis = self.ec2_backend.describe_images(filters={'image-id': image_id}) + amis = self.ec2_backend.describe_images(filters={"image-id": image_id}) ami = amis[0] if amis else None if ami is None: - warnings.warn('Could not find AMI with image-id:{0}, ' - 'in the near future this will ' - 'cause an error.\n' - 'Use ec2_backend.describe_images() to ' - 'find suitable image for your test'.format(image_id), - PendingDeprecationWarning) + warnings.warn( + "Could not find AMI with image-id:{0}, " + "in the near future this will " + "cause an error.\n" + "Use ec2_backend.describe_images() to " + "find suitable image for your test".format(image_id), + PendingDeprecationWarning, + ) self.platform = ami.platform if ami else None - self.virtualization_type = ami.virtualization_type if ami else 'paravirtual' - self.architecture = ami.architecture if ami else 'x86_64' + self.virtualization_type = ami.virtualization_type if ami else "paravirtual" + self.architecture = ami.architecture if ami else "x86_64" # handle weird bug around user_data -- something grabs the repr(), so # it must be clean if isinstance(self.user_data, list) and len(self.user_data) > 0: if six.PY3 and isinstance(self.user_data[0], six.binary_type): # string will have a "b" prefix -- need to get rid of it - self.user_data[0] = self.user_data[0].decode('utf-8') + self.user_data[0] = self.user_data[0].decode("utf-8") elif six.PY2 and isinstance(self.user_data[0], six.text_type): # string will have a "u" prefix -- need to get rid of it - self.user_data[0] = self.user_data[0].encode('utf-8') + self.user_data[0] = self.user_data[0].encode("utf-8") if self.subnet_id: subnet = ec2_backend.get_subnet(self.subnet_id) @@ -463,11 +516,11 @@ class Instance(TaggedEC2Resource, BotoInstance): if self.associate_public_ip is None: # Mapping public ip hasnt been explicitly enabled or disabled - self.associate_public_ip = subnet.map_public_ip_on_launch == 'true' + self.associate_public_ip = subnet.map_public_ip_on_launch == "true" elif placement: self._placement.zone = placement else: - self._placement.zone = ec2_backend.region_name + 'a' + self._placement.zone = ec2_backend.region_name + "a" self.block_device_mapping = BlockDeviceMapping() @@ -475,7 +528,7 @@ class Instance(TaggedEC2Resource, BotoInstance): self.prep_nics( kwargs.get("nics", {}), private_ip=kwargs.get("private_ip"), - associate_public_ip=self.associate_public_ip + associate_public_ip=self.associate_public_ip, ) def __del__(self): @@ -491,12 +544,12 @@ class Instance(TaggedEC2Resource, BotoInstance): def setup_defaults(self): # Default have an instance with root volume should you not wish to # override with attach volume cmd. - volume = self.ec2_backend.create_volume(8, 'us-east-1a') - self.ec2_backend.attach_volume(volume.id, self.id, '/dev/sda1') + volume = self.ec2_backend.create_volume(8, "us-east-1a") + self.ec2_backend.attach_volume(volume.id, self.id, "/dev/sda1") def teardown_defaults(self): - volume_id = self.block_device_mapping['/dev/sda1'].volume_id - self.ec2_backend.detach_volume(volume_id, self.id, '/dev/sda1') + volume_id = self.block_device_mapping["/dev/sda1"].volume_id + self.ec2_backend.detach_volume(volume_id, self.id, "/dev/sda1") self.ec2_backend.delete_volume(volume_id) @property @@ -509,7 +562,7 @@ class Instance(TaggedEC2Resource, BotoInstance): @property def private_dns(self): - formatted_ip = self.private_ip.replace('.', '-') + formatted_ip = self.private_ip.replace(".", "-") if self.region_name == "us-east-1": return "ip-{0}.ec2.internal".format(formatted_ip) else: @@ -522,30 +575,36 @@ class Instance(TaggedEC2Resource, BotoInstance): @property def public_dns(self): if self.public_ip: - formatted_ip = self.public_ip.replace('.', '-') + formatted_ip = self.public_ip.replace(".", "-") if self.region_name == "us-east-1": return "ec2-{0}.compute-1.amazonaws.com".format(formatted_ip) else: - return "ec2-{0}.{1}.compute.amazonaws.com".format(formatted_ip, self.region_name) + return "ec2-{0}.{1}.compute.amazonaws.com".format( + formatted_ip, self.region_name + ) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] - security_group_ids = properties.get('SecurityGroups', []) - group_names = [ec2_backend.get_security_group_from_id( - group_id).name for group_id in security_group_ids] + security_group_ids = properties.get("SecurityGroups", []) + group_names = [ + ec2_backend.get_security_group_from_id(group_id).name + for group_id in security_group_ids + ] reservation = ec2_backend.add_instances( - image_id=properties['ImageId'], - user_data=properties.get('UserData'), + image_id=properties["ImageId"], + user_data=properties.get("UserData"), count=1, security_group_names=group_names, instance_type=properties.get("InstanceType", "m1.small"), subnet_id=properties.get("SubnetId"), key_name=properties.get("KeyName"), - private_ip=properties.get('PrivateIpAddress'), + private_ip=properties.get("PrivateIpAddress"), ) instance = reservation.instances[0] for tag in properties.get("Tags", []): @@ -553,19 +612,24 @@ class Instance(TaggedEC2Resource, BotoInstance): return instance @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): ec2_backend = ec2_backends[region_name] all_instances = ec2_backend.all_instances() # the resource_name for instances is the stack name, logical id, and random suffix separated # by hyphens. So to lookup the instances using the 'aws:cloudformation:logical-id' tag, we need to # extract the logical-id from the resource_name - logical_id = resource_name.split('-')[1] + logical_id = resource_name.split("-")[1] for instance in all_instances: instance_tags = instance.get_tags() for tag in instance_tags: - if tag['key'] == 'aws:cloudformation:logical-id' and tag['value'] == logical_id: + if ( + tag["key"] == "aws:cloudformation:logical-id" + and tag["value"] == logical_id + ): instance.delete(region_name) @property @@ -590,9 +654,12 @@ class Instance(TaggedEC2Resource, BotoInstance): self._state.code = 80 self._reason = "User initiated ({0})".format( - datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')) - self._state_reason = StateReason("Client.UserInitiatedShutdown: User initiated shutdown", - "Client.UserInitiatedShutdown") + datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC") + ) + self._state_reason = StateReason( + "Client.UserInitiatedShutdown: User initiated shutdown", + "Client.UserInitiatedShutdown", + ) def delete(self, region): self.terminate() @@ -606,18 +673,26 @@ class Instance(TaggedEC2Resource, BotoInstance): if self._spot_fleet_id: spot_fleet = self.ec2_backend.get_spot_fleet_request(self._spot_fleet_id) for spec in spot_fleet.launch_specs: - if spec.instance_type == self.instance_type and spec.subnet_id == self.subnet_id: + if ( + spec.instance_type == self.instance_type + and spec.subnet_id == self.subnet_id + ): break spot_fleet.fulfilled_capacity -= spec.weighted_capacity - spot_fleet.spot_requests = [req for req in spot_fleet.spot_requests if req.instance != self] + spot_fleet.spot_requests = [ + req for req in spot_fleet.spot_requests if req.instance != self + ] self._state.name = "terminated" self._state.code = 48 self._reason = "User initiated ({0})".format( - datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S UTC')) - self._state_reason = StateReason("Client.UserInitiatedShutdown: User initiated shutdown", - "Client.UserInitiatedShutdown") + datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC") + ) + self._state_reason = StateReason( + "Client.UserInitiatedShutdown: User initiated shutdown", + "Client.UserInitiatedShutdown", + ) def reboot(self, *args, **kwargs): self._state.name = "running" @@ -653,22 +728,24 @@ class Instance(TaggedEC2Resource, BotoInstance): private_ip = random_private_ip() # Primary NIC defaults - primary_nic = {'SubnetId': self.subnet_id, - 'PrivateIpAddress': private_ip, - 'AssociatePublicIpAddress': associate_public_ip} + primary_nic = { + "SubnetId": self.subnet_id, + "PrivateIpAddress": private_ip, + "AssociatePublicIpAddress": associate_public_ip, + } primary_nic = dict((k, v) for k, v in primary_nic.items() if v) # If empty NIC spec but primary NIC values provided, create NIC from # them. if primary_nic and not nic_spec: nic_spec[0] = primary_nic - nic_spec[0]['DeviceIndex'] = 0 + nic_spec[0]["DeviceIndex"] = 0 # Flesh out data structures and associations for nic in nic_spec.values(): - device_index = int(nic.get('DeviceIndex')) + device_index = int(nic.get("DeviceIndex")) - nic_id = nic.get('NetworkInterfaceId') + nic_id = nic.get("NetworkInterfaceId") if nic_id: # If existing NIC found, use it. use_nic = self.ec2_backend.get_network_interface(nic_id) @@ -680,21 +757,21 @@ class Instance(TaggedEC2Resource, BotoInstance): if device_index == 0 and primary_nic: nic.update(primary_nic) - if 'SubnetId' in nic: - subnet = self.ec2_backend.get_subnet(nic['SubnetId']) + if "SubnetId" in nic: + subnet = self.ec2_backend.get_subnet(nic["SubnetId"]) else: subnet = None - group_id = nic.get('SecurityGroupId') + group_id = nic.get("SecurityGroupId") group_ids = [group_id] if group_id else [] - use_nic = self.ec2_backend.create_network_interface(subnet, - nic.get( - 'PrivateIpAddress'), - device_index=device_index, - public_ip_auto_assign=nic.get( - 'AssociatePublicIpAddress', False), - group_ids=group_ids) + use_nic = self.ec2_backend.create_network_interface( + subnet, + nic.get("PrivateIpAddress"), + device_index=device_index, + public_ip_auto_assign=nic.get("AssociatePublicIpAddress", False), + group_ids=group_ids, + ) self.attach_eni(use_nic, device_index) @@ -717,15 +794,16 @@ class Instance(TaggedEC2Resource, BotoInstance): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'AvailabilityZone': + + if attribute_name == "AvailabilityZone": return self.placement - elif attribute_name == 'PrivateDnsName': + elif attribute_name == "PrivateDnsName": return self.private_dns - elif attribute_name == 'PublicDnsName': + elif attribute_name == "PublicDnsName": return self.public_dns - elif attribute_name == 'PrivateIp': + elif attribute_name == "PrivateIp": return self.private_ip - elif attribute_name == 'PublicIp': + elif attribute_name == "PublicIp": return self.public_ip raise UnformattedGetAttTemplateException() @@ -741,28 +819,26 @@ class InstanceBackend(object): return instance raise InvalidInstanceIdError(instance_id) - def add_instances(self, image_id, count, user_data, security_group_names, - **kwargs): + def add_instances(self, image_id, count, user_data, security_group_names, **kwargs): new_reservation = Reservation() new_reservation.id = random_reservation_id() - security_groups = [self.get_security_group_from_name(name) - for name in security_group_names] - security_groups.extend(self.get_security_group_from_id(sg_id) - for sg_id in kwargs.pop("security_group_ids", [])) + security_groups = [ + self.get_security_group_from_name(name) for name in security_group_names + ] + security_groups.extend( + self.get_security_group_from_id(sg_id) + for sg_id in kwargs.pop("security_group_ids", []) + ) self.reservations[new_reservation.id] = new_reservation tags = kwargs.pop("tags", {}) - instance_tags = tags.get('instance', {}) + instance_tags = tags.get("instance", {}) for index in range(count): kwargs["ami_launch_index"] = index new_instance = Instance( - self, - image_id, - user_data, - security_groups, - **kwargs + self, image_id, user_data, security_groups, **kwargs ) new_reservation.instances.append(new_instance) new_instance.add_tags(instance_tags) @@ -789,7 +865,8 @@ class InstanceBackend(object): terminated_instances = [] if not instance_ids: raise EC2ClientError( - "InvalidParameterCombination", "No instances specified") + "InvalidParameterCombination", "No instances specified" + ) for instance in self.get_multi_instances_by_id(instance_ids): instance.terminate() terminated_instances.append(instance) @@ -814,15 +891,15 @@ class InstanceBackend(object): new_group_list = [] for new_group_id in new_group_id_list: new_group_list.append(self.get_security_group_from_id(new_group_id)) - setattr(instance, 'security_groups', new_group_list) + setattr(instance, "security_groups", new_group_list) return instance def describe_instance_attribute(self, instance_id, attribute): if attribute not in Instance.VALID_ATTRIBUTES: raise InvalidParameterValueErrorUnknownAttribute(attribute) - if attribute == 'groupSet': - key = 'security_groups' + if attribute == "groupSet": + key = "security_groups" else: key = camelcase_to_underscores(attribute) instance = self.get_instance(instance_id) @@ -875,25 +952,34 @@ class InstanceBackend(object): reservations = [] for reservation in self.all_reservations(): reservation_instance_ids = [ - instance.id for instance in reservation.instances] + instance.id for instance in reservation.instances + ] matching_reservation = any( - instance_id in reservation_instance_ids for instance_id in instance_ids) + instance_id in reservation_instance_ids for instance_id in instance_ids + ) if matching_reservation: reservation.instances = [ - instance for instance in reservation.instances if instance.id in instance_ids] + instance + for instance in reservation.instances + if instance.id in instance_ids + ] reservations.append(reservation) found_instance_ids = [ - instance.id for reservation in reservations for instance in reservation.instances] + instance.id + for reservation in reservations + for instance in reservation.instances + ] if len(found_instance_ids) != len(instance_ids): - invalid_id = list(set(instance_ids).difference( - set(found_instance_ids)))[0] + invalid_id = list(set(instance_ids).difference(set(found_instance_ids)))[0] raise InvalidInstanceIdError(invalid_id) if filters is not None: reservations = filter_reservations(reservations, filters) return reservations def all_reservations(self, filters=None): - reservations = [copy.copy(reservation) for reservation in self.reservations.values()] + reservations = [ + copy.copy(reservation) for reservation in self.reservations.values() + ] if filters is not None: reservations = filter_reservations(reservations, filters) return reservations @@ -906,12 +992,12 @@ class KeyPair(object): self.material = material def get_filter_value(self, filter_name): - if filter_name == 'key-name': + if filter_name == "key-name": return self.name - elif filter_name == 'fingerprint': + elif filter_name == "fingerprint": return self.fingerprint else: - raise FilterNotImplementedError(filter_name, 'DescribeKeyPairs') + raise FilterNotImplementedError(filter_name, "DescribeKeyPairs") class KeyPairBackend(object): @@ -934,8 +1020,11 @@ class KeyPairBackend(object): def describe_key_pairs(self, key_names=None, filters=None): results = [] if key_names: - results = [keypair for keypair in self.keypairs.values() - if keypair.name in key_names] + results = [ + keypair + for keypair in self.keypairs.values() + if keypair.name in key_names + ] if len(key_names) > len(results): unknown_keys = set(key_names) - set(results) raise InvalidKeyPairNameError(unknown_keys) @@ -957,35 +1046,35 @@ class KeyPairBackend(object): raise InvalidKeyPairFormatError() fingerprint = rsa_public_key_fingerprint(rsa_public_key) - keypair = KeyPair(key_name, material=public_key_material, fingerprint=fingerprint) + keypair = KeyPair( + key_name, material=public_key_material, fingerprint=fingerprint + ) self.keypairs[key_name] = keypair return keypair class TagBackend(object): - VALID_TAG_FILTERS = ['key', - 'resource-id', - 'resource-type', - 'value'] + VALID_TAG_FILTERS = ["key", "resource-id", "resource-type", "value"] - VALID_TAG_RESOURCE_FILTER_TYPES = ['customer-gateway', - 'dhcp-options', - 'image', - 'instance', - 'internet-gateway', - 'network-acl', - 'network-interface', - 'reserved-instances', - 'route-table', - 'security-group', - 'snapshot', - 'spot-instances-request', - 'subnet', - 'volume', - 'vpc', - 'vpc-peering-connection' - 'vpn-connection', - 'vpn-gateway'] + VALID_TAG_RESOURCE_FILTER_TYPES = [ + "customer-gateway", + "dhcp-options", + "image", + "instance", + "internet-gateway", + "network-acl", + "network-interface", + "reserved-instances", + "route-table", + "security-group", + "snapshot", + "spot-instances-request", + "subnet", + "volume", + "vpc", + "vpc-peering-connection" "vpn-connection", + "vpn-gateway", + ] def __init__(self): self.tags = defaultdict(dict) @@ -996,7 +1085,11 @@ class TagBackend(object): raise InvalidParameterValueErrorTagNull() for resource_id in resource_ids: if resource_id in self.tags: - if len(self.tags[resource_id]) + len([tag for tag in tags if not tag.startswith("aws:")]) > 50: + if ( + len(self.tags[resource_id]) + + len([tag for tag in tags if not tag.startswith("aws:")]) + > 50 + ): raise TagLimitExceeded() elif len([tag for tag in tags if not tag.startswith("aws:")]) > 50: raise TagLimitExceeded() @@ -1017,6 +1110,7 @@ class TagBackend(object): def describe_tags(self, filters=None): import re + results = [] key_filters = [] resource_id_filters = [] @@ -1025,21 +1119,24 @@ class TagBackend(object): if filters is not None: for tag_filter in filters: if tag_filter in self.VALID_TAG_FILTERS: - if tag_filter == 'key': + if tag_filter == "key": for value in filters[tag_filter]: - key_filters.append(re.compile( - simple_aws_filter_to_re(value))) - if tag_filter == 'resource-id': + key_filters.append( + re.compile(simple_aws_filter_to_re(value)) + ) + if tag_filter == "resource-id": for value in filters[tag_filter]: resource_id_filters.append( - re.compile(simple_aws_filter_to_re(value))) - if tag_filter == 'resource-type': + re.compile(simple_aws_filter_to_re(value)) + ) + if tag_filter == "resource-type": for value in filters[tag_filter]: resource_type_filters.append(value) - if tag_filter == 'value': + if tag_filter == "value": for value in filters[tag_filter]: - value_filters.append(re.compile( - simple_aws_filter_to_re(value))) + value_filters.append( + re.compile(simple_aws_filter_to_re(value)) + ) for resource_id, tags in self.tags.items(): for key, value in tags.items(): add_result = False @@ -1064,7 +1161,10 @@ class TagBackend(object): id_pass = True if resource_type_filters: for resource_type in resource_type_filters: - if EC2_PREFIX_TO_RESOURCE[get_prefix(resource_id)] == resource_type: + if ( + EC2_PREFIX_TO_RESOURCE[get_prefix(resource_id)] + == resource_type + ): type_pass = True else: type_pass = True @@ -1079,24 +1179,41 @@ class TagBackend(object): # If we're not filtering, or we are filtering and this if add_result: result = { - 'resource_id': resource_id, - 'key': key, - 'value': value, - 'resource_type': EC2_PREFIX_TO_RESOURCE[get_prefix(resource_id)], + "resource_id": resource_id, + "key": key, + "value": value, + "resource_type": EC2_PREFIX_TO_RESOURCE[ + get_prefix(resource_id) + ], } results.append(result) return results class Ami(TaggedEC2Resource): - def __init__(self, ec2_backend, ami_id, instance=None, source_ami=None, - name=None, description=None, owner_id=OWNER_ID, - public=False, virtualization_type=None, architecture=None, - state='available', creation_date=None, platform=None, - image_type='machine', image_location=None, hypervisor=None, - root_device_type='standard', root_device_name='/dev/sda1', sriov='simple', - region_name='us-east-1a' - ): + def __init__( + self, + ec2_backend, + ami_id, + instance=None, + source_ami=None, + name=None, + description=None, + owner_id=OWNER_ID, + public=False, + virtualization_type=None, + architecture=None, + state="available", + creation_date=None, + platform=None, + image_type="machine", + image_location=None, + hypervisor=None, + root_device_type="standard", + root_device_name="/dev/sda1", + sriov="simple", + region_name="us-east-1a", + ): self.ec2_backend = ec2_backend self.id = ami_id self.state = state @@ -1113,7 +1230,9 @@ class Ami(TaggedEC2Resource): self.root_device_name = root_device_name self.root_device_type = root_device_type self.sriov = sriov - self.creation_date = utc_date_and_time() if creation_date is None else creation_date + self.creation_date = ( + utc_date_and_time() if creation_date is None else creation_date + ) if instance: self.instance = instance @@ -1142,42 +1261,42 @@ class Ami(TaggedEC2Resource): self.launch_permission_users = set() if public: - self.launch_permission_groups.add('all') + self.launch_permission_groups.add("all") # AWS auto-creates these, we should reflect the same. volume = self.ec2_backend.create_volume(15, region_name) self.ebs_snapshot = self.ec2_backend.create_snapshot( - volume.id, "Auto-created snapshot for AMI %s" % self.id, owner_id) + volume.id, "Auto-created snapshot for AMI %s" % self.id, owner_id + ) self.ec2_backend.delete_volume(volume.id) @property def is_public(self): - return 'all' in self.launch_permission_groups + return "all" in self.launch_permission_groups @property def is_public_string(self): return str(self.is_public).lower() def get_filter_value(self, filter_name): - if filter_name == 'virtualization-type': + if filter_name == "virtualization-type": return self.virtualization_type - elif filter_name == 'kernel-id': + elif filter_name == "kernel-id": return self.kernel_id - elif filter_name in ['architecture', 'platform']: + elif filter_name in ["architecture", "platform"]: return getattr(self, filter_name) - elif filter_name == 'image-id': + elif filter_name == "image-id": return self.id - elif filter_name == 'is-public': + elif filter_name == "is-public": return self.is_public_string - elif filter_name == 'state': + elif filter_name == "state": return self.state - elif filter_name == 'name': + elif filter_name == "name": return self.name - elif filter_name == 'owner-id': + elif filter_name == "owner-id": return self.owner_id else: - return super(Ami, self).get_filter_value( - filter_name, 'DescribeImages') + return super(Ami, self).get_filter_value(filter_name, "DescribeImages") class AmiBackend(object): @@ -1193,7 +1312,7 @@ class AmiBackend(object): def _load_amis(self): for ami in AMIS: - ami_id = ami['ami_id'] + ami_id = ami["ami_id"] self.amis[ami_id] = Ami(self, **ami) def create_image(self, instance_id, name=None, description=None, context=None): @@ -1201,35 +1320,51 @@ class AmiBackend(object): ami_id = random_ami_id() instance = self.get_instance(instance_id) - ami = Ami(self, ami_id, instance=instance, source_ami=None, - name=name, description=description, - owner_id=context.get_current_user() if context else OWNER_ID) + ami = Ami( + self, + ami_id, + instance=instance, + source_ami=None, + name=name, + description=description, + owner_id=context.get_current_user() if context else OWNER_ID, + ) self.amis[ami_id] = ami return ami def copy_image(self, source_image_id, source_region, name=None, description=None): source_ami = ec2_backends[source_region].describe_images( - ami_ids=[source_image_id])[0] + ami_ids=[source_image_id] + )[0] ami_id = random_ami_id() - ami = Ami(self, ami_id, instance=None, source_ami=source_ami, - name=name, description=description) + ami = Ami( + self, + ami_id, + instance=None, + source_ami=source_ami, + name=name, + description=description, + ) self.amis[ami_id] = ami return ami - def describe_images(self, ami_ids=(), filters=None, exec_users=None, owners=None, - context=None): + def describe_images( + self, ami_ids=(), filters=None, exec_users=None, owners=None, context=None + ): images = self.amis.values() if len(ami_ids): # boto3 seems to default to just searching based on ami ids if that parameter is passed # and if no images are found, it raises an errors - malformed_ami_ids = [ami_id for ami_id in ami_ids if not ami_id.startswith('ami-')] + malformed_ami_ids = [ + ami_id for ami_id in ami_ids if not ami_id.startswith("ami-") + ] if malformed_ami_ids: raise MalformedAMIIdError(malformed_ami_ids) images = [ami for ami in images if ami.id in ami_ids] if len(images) == 0: - raise InvalidAMIIdError(ami_ids) + raise InvalidAMIIdError(ami_ids) else: # Limit images by launch permissions if exec_users: @@ -1243,10 +1378,14 @@ class AmiBackend(object): # Limit by owner ids if owners: # support filtering by Owners=['self'] - owners = list(map( - lambda o: context.get_current_user() - if context and o == 'self' else o, - owners)) + owners = list( + map( + lambda o: context.get_current_user() + if context and o == "self" + else o, + owners, + ) + ) images = [ami for ami in images if ami.owner_id in owners] # Generic filters @@ -1281,7 +1420,7 @@ class AmiBackend(object): if len(user_id) != 12 or not user_id.isdigit(): raise InvalidAMIAttributeItemValueError("userId", user_id) - if group and group != 'all': + if group and group != "all": raise InvalidAMIAttributeItemValueError("UserGroup", group) def add_launch_permission(self, ami_id, user_ids=None, group=None): @@ -1328,96 +1467,150 @@ class RegionsAndZonesBackend(object): regions = [Region(ri.name, ri.endpoint) for ri in boto.ec2.regions()] zones = { - 'ap-south-1': [ + "ap-south-1": [ Zone(region_name="ap-south-1", name="ap-south-1a", zone_id="aps1-az1"), - Zone(region_name="ap-south-1", name="ap-south-1b", zone_id="aps1-az3") + Zone(region_name="ap-south-1", name="ap-south-1b", zone_id="aps1-az3"), ], - 'eu-west-3': [ + "eu-west-3": [ Zone(region_name="eu-west-3", name="eu-west-3a", zone_id="euw3-az1"), Zone(region_name="eu-west-3", name="eu-west-3b", zone_id="euw3-az2"), - Zone(region_name="eu-west-3", name="eu-west-3c", zone_id="euw3-az3") + Zone(region_name="eu-west-3", name="eu-west-3c", zone_id="euw3-az3"), ], - 'eu-north-1': [ + "eu-north-1": [ Zone(region_name="eu-north-1", name="eu-north-1a", zone_id="eun1-az1"), Zone(region_name="eu-north-1", name="eu-north-1b", zone_id="eun1-az2"), - Zone(region_name="eu-north-1", name="eu-north-1c", zone_id="eun1-az3") + Zone(region_name="eu-north-1", name="eu-north-1c", zone_id="eun1-az3"), ], - 'eu-west-2': [ + "eu-west-2": [ Zone(region_name="eu-west-2", name="eu-west-2a", zone_id="euw2-az2"), Zone(region_name="eu-west-2", name="eu-west-2b", zone_id="euw2-az3"), - Zone(region_name="eu-west-2", name="eu-west-2c", zone_id="euw2-az1") + Zone(region_name="eu-west-2", name="eu-west-2c", zone_id="euw2-az1"), ], - 'eu-west-1': [ + "eu-west-1": [ Zone(region_name="eu-west-1", name="eu-west-1a", zone_id="euw1-az3"), Zone(region_name="eu-west-1", name="eu-west-1b", zone_id="euw1-az1"), - Zone(region_name="eu-west-1", name="eu-west-1c", zone_id="euw1-az2") + Zone(region_name="eu-west-1", name="eu-west-1c", zone_id="euw1-az2"), ], - 'ap-northeast-3': [ - Zone(region_name="ap-northeast-3", name="ap-northeast-2a", zone_id="apne3-az1") + "ap-northeast-3": [ + Zone( + region_name="ap-northeast-3", + name="ap-northeast-2a", + zone_id="apne3-az1", + ) ], - 'ap-northeast-2': [ - Zone(region_name="ap-northeast-2", name="ap-northeast-2a", zone_id="apne2-az1"), - Zone(region_name="ap-northeast-2", name="ap-northeast-2c", zone_id="apne2-az3") + "ap-northeast-2": [ + Zone( + region_name="ap-northeast-2", + name="ap-northeast-2a", + zone_id="apne2-az1", + ), + Zone( + region_name="ap-northeast-2", + name="ap-northeast-2c", + zone_id="apne2-az3", + ), ], - 'ap-northeast-1': [ - Zone(region_name="ap-northeast-1", name="ap-northeast-1a", zone_id="apne1-az4"), - Zone(region_name="ap-northeast-1", name="ap-northeast-1c", zone_id="apne1-az1"), - Zone(region_name="ap-northeast-1", name="ap-northeast-1d", zone_id="apne1-az2") + "ap-northeast-1": [ + Zone( + region_name="ap-northeast-1", + name="ap-northeast-1a", + zone_id="apne1-az4", + ), + Zone( + region_name="ap-northeast-1", + name="ap-northeast-1c", + zone_id="apne1-az1", + ), + Zone( + region_name="ap-northeast-1", + name="ap-northeast-1d", + zone_id="apne1-az2", + ), ], - 'sa-east-1': [ + "sa-east-1": [ Zone(region_name="sa-east-1", name="sa-east-1a", zone_id="sae1-az1"), - Zone(region_name="sa-east-1", name="sa-east-1c", zone_id="sae1-az3") + Zone(region_name="sa-east-1", name="sa-east-1c", zone_id="sae1-az3"), ], - 'ca-central-1': [ + "ca-central-1": [ Zone(region_name="ca-central-1", name="ca-central-1a", zone_id="cac1-az1"), - Zone(region_name="ca-central-1", name="ca-central-1b", zone_id="cac1-az2") + Zone(region_name="ca-central-1", name="ca-central-1b", zone_id="cac1-az2"), ], - 'ap-southeast-1': [ - Zone(region_name="ap-southeast-1", name="ap-southeast-1a", zone_id="apse1-az1"), - Zone(region_name="ap-southeast-1", name="ap-southeast-1b", zone_id="apse1-az2"), - Zone(region_name="ap-southeast-1", name="ap-southeast-1c", zone_id="apse1-az3") + "ap-southeast-1": [ + Zone( + region_name="ap-southeast-1", + name="ap-southeast-1a", + zone_id="apse1-az1", + ), + Zone( + region_name="ap-southeast-1", + name="ap-southeast-1b", + zone_id="apse1-az2", + ), + Zone( + region_name="ap-southeast-1", + name="ap-southeast-1c", + zone_id="apse1-az3", + ), ], - 'ap-southeast-2': [ - Zone(region_name="ap-southeast-2", name="ap-southeast-2a", zone_id="apse2-az1"), - Zone(region_name="ap-southeast-2", name="ap-southeast-2b", zone_id="apse2-az3"), - Zone(region_name="ap-southeast-2", name="ap-southeast-2c", zone_id="apse2-az2") + "ap-southeast-2": [ + Zone( + region_name="ap-southeast-2", + name="ap-southeast-2a", + zone_id="apse2-az1", + ), + Zone( + region_name="ap-southeast-2", + name="ap-southeast-2b", + zone_id="apse2-az3", + ), + Zone( + region_name="ap-southeast-2", + name="ap-southeast-2c", + zone_id="apse2-az2", + ), ], - 'eu-central-1': [ + "eu-central-1": [ Zone(region_name="eu-central-1", name="eu-central-1a", zone_id="euc1-az2"), Zone(region_name="eu-central-1", name="eu-central-1b", zone_id="euc1-az3"), - Zone(region_name="eu-central-1", name="eu-central-1c", zone_id="euc1-az1") + Zone(region_name="eu-central-1", name="eu-central-1c", zone_id="euc1-az1"), ], - 'us-east-1': [ + "us-east-1": [ Zone(region_name="us-east-1", name="us-east-1a", zone_id="use1-az6"), Zone(region_name="us-east-1", name="us-east-1b", zone_id="use1-az1"), Zone(region_name="us-east-1", name="us-east-1c", zone_id="use1-az2"), Zone(region_name="us-east-1", name="us-east-1d", zone_id="use1-az4"), Zone(region_name="us-east-1", name="us-east-1e", zone_id="use1-az3"), - Zone(region_name="us-east-1", name="us-east-1f", zone_id="use1-az5") + Zone(region_name="us-east-1", name="us-east-1f", zone_id="use1-az5"), ], - 'us-east-2': [ + "us-east-2": [ Zone(region_name="us-east-2", name="us-east-2a", zone_id="use2-az1"), Zone(region_name="us-east-2", name="us-east-2b", zone_id="use2-az2"), - Zone(region_name="us-east-2", name="us-east-2c", zone_id="use2-az3") + Zone(region_name="us-east-2", name="us-east-2c", zone_id="use2-az3"), ], - 'us-west-1': [ + "us-west-1": [ Zone(region_name="us-west-1", name="us-west-1a", zone_id="usw1-az3"), - Zone(region_name="us-west-1", name="us-west-1b", zone_id="usw1-az1") + Zone(region_name="us-west-1", name="us-west-1b", zone_id="usw1-az1"), ], - 'us-west-2': [ + "us-west-2": [ Zone(region_name="us-west-2", name="us-west-2a", zone_id="usw2-az2"), Zone(region_name="us-west-2", name="us-west-2b", zone_id="usw2-az1"), - Zone(region_name="us-west-2", name="us-west-2c", zone_id="usw2-az3") + Zone(region_name="us-west-2", name="us-west-2c", zone_id="usw2-az3"), ], - 'cn-north-1': [ + "cn-north-1": [ Zone(region_name="cn-north-1", name="cn-north-1a", zone_id="cnn1-az1"), - Zone(region_name="cn-north-1", name="cn-north-1b", zone_id="cnn1-az2") + Zone(region_name="cn-north-1", name="cn-north-1b", zone_id="cnn1-az2"), + ], + "us-gov-west-1": [ + Zone( + region_name="us-gov-west-1", name="us-gov-west-1a", zone_id="usgw1-az1" + ), + Zone( + region_name="us-gov-west-1", name="us-gov-west-1b", zone_id="usgw1-az2" + ), + Zone( + region_name="us-gov-west-1", name="us-gov-west-1c", zone_id="usgw1-az3" + ), ], - 'us-gov-west-1': [ - Zone(region_name="us-gov-west-1", name="us-gov-west-1a", zone_id="usgw1-az1"), - Zone(region_name="us-gov-west-1", name="us-gov-west-1b", zone_id="usgw1-az2"), - Zone(region_name="us-gov-west-1", name="us-gov-west-1c", zone_id="usgw1-az3") - ] } def describe_regions(self, region_names=[]): @@ -1454,7 +1647,7 @@ class SecurityRule(object): self.from_port, self.to_port, self.ip_ranges, - self.source_groups + self.source_groups, ) def __eq__(self, other): @@ -1468,20 +1661,22 @@ class SecurityGroup(TaggedEC2Resource): self.name = name self.description = description self.ingress_rules = [] - self.egress_rules = [SecurityRule(-1, None, None, ['0.0.0.0/0'], [])] + self.egress_rules = [SecurityRule(-1, None, None, ["0.0.0.0/0"], [])] self.enis = {} self.vpc_id = vpc_id self.owner_id = OWNER_ID @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] - vpc_id = properties.get('VpcId') + vpc_id = properties.get("VpcId") security_group = ec2_backend.create_security_group( name=resource_name, - description=properties.get('GroupDescription'), + description=properties.get("GroupDescription"), vpc_id=vpc_id, ) @@ -1490,15 +1685,15 @@ class SecurityGroup(TaggedEC2Resource): tag_value = tag["Value"] security_group.add_tag(tag_key, tag_value) - for ingress_rule in properties.get('SecurityGroupIngress', []): - source_group_id = ingress_rule.get('SourceSecurityGroupId') + for ingress_rule in properties.get("SecurityGroupIngress", []): + source_group_id = ingress_rule.get("SourceSecurityGroupId") ec2_backend.authorize_security_group_ingress( group_name_or_id=security_group.id, - ip_protocol=ingress_rule['IpProtocol'], - from_port=ingress_rule['FromPort'], - to_port=ingress_rule['ToPort'], - ip_ranges=ingress_rule.get('CidrIp'), + ip_protocol=ingress_rule["IpProtocol"], + from_port=ingress_rule["FromPort"], + to_port=ingress_rule["ToPort"], + ip_ranges=ingress_rule.get("CidrIp"), source_group_ids=[source_group_id], vpc_id=vpc_id, ) @@ -1506,28 +1701,33 @@ class SecurityGroup(TaggedEC2Resource): return security_group @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): cls._delete_security_group_given_vpc_id( - original_resource.name, original_resource.vpc_id, region_name) - return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) + original_resource.name, original_resource.vpc_id, region_name + ) + return cls.create_from_cloudformation_json( + new_resource_name, cloudformation_json, region_name + ) @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - vpc_id = properties.get('VpcId') - cls._delete_security_group_given_vpc_id( - resource_name, vpc_id, region_name) + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + vpc_id = properties.get("VpcId") + cls._delete_security_group_given_vpc_id(resource_name, vpc_id, region_name) @classmethod def _delete_security_group_given_vpc_id(cls, resource_name, vpc_id, region_name): ec2_backend = ec2_backends[region_name] - security_group = ec2_backend.get_security_group_from_name( - resource_name, vpc_id) + security_group = ec2_backend.get_security_group_from_name(resource_name, vpc_id) if security_group: security_group.delete(region_name) def delete(self, region_name): - ''' Not exposed as part of the ELB API - used for CloudFormation. ''' + """ Not exposed as part of the ELB API - used for CloudFormation. """ self.ec2_backend.delete_security_group(group_id=self.id) @property @@ -1538,18 +1738,18 @@ class SecurityGroup(TaggedEC2Resource): def to_attr(filter_name): attr = None - if filter_name == 'group-name': - attr = 'name' - elif filter_name == 'group-id': - attr = 'id' - elif filter_name == 'vpc-id': - attr = 'vpc_id' + if filter_name == "group-name": + attr = "name" + elif filter_name == "group-id": + attr = "id" + elif filter_name == "vpc-id": + attr = "vpc_id" else: - attr = filter_name.replace('-', '_') + attr = filter_name.replace("-", "_") return attr - if key.startswith('ip-permission'): + if key.startswith("ip-permission"): match = re.search(r"ip-permission.(*)", key) ingress_attr = to_attr(match.groups()[0]) @@ -1575,7 +1775,8 @@ class SecurityGroup(TaggedEC2Resource): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'GroupId': + + if attribute_name == "GroupId": return self.id raise UnformattedGetAttTemplateException() @@ -1590,13 +1791,13 @@ class SecurityGroup(TaggedEC2Resource): def get_number_of_ingress_rules(self): return sum( - len(rule.ip_ranges) + len(rule.source_groups) - for rule in self.ingress_rules) + len(rule.ip_ranges) + len(rule.source_groups) for rule in self.ingress_rules + ) def get_number_of_egress_rules(self): return sum( - len(rule.ip_ranges) + len(rule.source_groups) - for rule in self.egress_rules) + len(rule.ip_ranges) + len(rule.source_groups) for rule in self.egress_rules + ) class SecurityGroupBackend(object): @@ -1611,7 +1812,7 @@ class SecurityGroupBackend(object): def create_security_group(self, name, description, vpc_id=None, force=False): if not description: - raise MissingParameterError('GroupDescription') + raise MissingParameterError("GroupDescription") group_id = random_security_group_id() if not force: @@ -1624,30 +1825,27 @@ class SecurityGroupBackend(object): return group def describe_security_groups(self, group_ids=None, groupnames=None, filters=None): - matches = itertools.chain(*[x.values() - for x in self.groups.values()]) + matches = itertools.chain(*[x.values() for x in self.groups.values()]) if group_ids: - matches = [grp for grp in matches - if grp.id in group_ids] + matches = [grp for grp in matches if grp.id in group_ids] if len(group_ids) > len(matches): unknown_ids = set(group_ids) - set(matches) raise InvalidSecurityGroupNotFoundError(unknown_ids) if groupnames: - matches = [grp for grp in matches - if grp.name in groupnames] + matches = [grp for grp in matches if grp.name in groupnames] if len(groupnames) > len(matches): unknown_names = set(groupnames) - set(matches) raise InvalidSecurityGroupNotFoundError(unknown_names) if filters: - matches = [grp for grp in matches - if grp.matches_filters(filters)] + matches = [grp for grp in matches if grp.matches_filters(filters)] return matches def _delete_security_group(self, vpc_id, group_id): if self.groups[vpc_id][group_id].enis: raise DependencyViolationError( - "{0} is being utilized by {1}".format(group_id, 'ENIs')) + "{0} is being utilized by {1}".format(group_id, "ENIs") + ) return self.groups[vpc_id].pop(group_id) def delete_security_group(self, name=None, group_id=None): @@ -1668,7 +1866,8 @@ class SecurityGroupBackend(object): def get_security_group_from_id(self, group_id): # 2 levels of chaining necessary since it's a complex structure all_groups = itertools.chain.from_iterable( - [x.values() for x in self.groups.values()]) + [x.values() for x in self.groups.values()] + ) for group in all_groups: if group.id == group_id: return group @@ -1685,15 +1884,17 @@ class SecurityGroupBackend(object): group = self.get_security_group_from_name(group_name_or_id, vpc_id) return group - def authorize_security_group_ingress(self, - group_name_or_id, - ip_protocol, - from_port, - to_port, - ip_ranges, - source_group_names=None, - source_group_ids=None, - vpc_id=None): + def authorize_security_group_ingress( + self, + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_group_names=None, + source_group_ids=None, + vpc_id=None, + ): group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) if ip_ranges and not isinstance(ip_ranges, list): ip_ranges = [ip_ranges] @@ -1703,16 +1904,19 @@ class SecurityGroupBackend(object): raise InvalidCIDRSubnetError(cidr=cidr) self._verify_group_will_respect_rule_count_limit( - group, group.get_number_of_ingress_rules(), - ip_ranges, source_group_names, source_group_ids) + group, + group.get_number_of_ingress_rules(), + ip_ranges, + source_group_names, + source_group_ids, + ) source_group_names = source_group_names if source_group_names else [] source_group_ids = source_group_ids if source_group_ids else [] source_groups = [] for source_group_name in source_group_names: - source_group = self.get_security_group_from_name( - source_group_name, vpc_id) + source_group = self.get_security_group_from_name(source_group_name, vpc_id) if source_group: source_groups.append(source_group) @@ -1723,25 +1927,27 @@ class SecurityGroupBackend(object): source_groups.append(source_group) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, source_groups) + ip_protocol, from_port, to_port, ip_ranges, source_groups + ) group.add_ingress_rule(security_rule) - def revoke_security_group_ingress(self, - group_name_or_id, - ip_protocol, - from_port, - to_port, - ip_ranges, - source_group_names=None, - source_group_ids=None, - vpc_id=None): + def revoke_security_group_ingress( + self, + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_group_names=None, + source_group_ids=None, + vpc_id=None, + ): group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) source_groups = [] for source_group_name in source_group_names: - source_group = self.get_security_group_from_name( - source_group_name, vpc_id) + source_group = self.get_security_group_from_name(source_group_name, vpc_id) if source_group: source_groups.append(source_group) @@ -1751,21 +1957,24 @@ class SecurityGroupBackend(object): source_groups.append(source_group) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, source_groups) + ip_protocol, from_port, to_port, ip_ranges, source_groups + ) if security_rule in group.ingress_rules: group.ingress_rules.remove(security_rule) return security_rule raise InvalidPermissionNotFoundError() - def authorize_security_group_egress(self, - group_name_or_id, - ip_protocol, - from_port, - to_port, - ip_ranges, - source_group_names=None, - source_group_ids=None, - vpc_id=None): + def authorize_security_group_egress( + self, + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_group_names=None, + source_group_ids=None, + vpc_id=None, + ): group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) if ip_ranges and not isinstance(ip_ranges, list): @@ -1776,16 +1985,19 @@ class SecurityGroupBackend(object): raise InvalidCIDRSubnetError(cidr=cidr) self._verify_group_will_respect_rule_count_limit( - group, group.get_number_of_egress_rules(), - ip_ranges, source_group_names, source_group_ids) + group, + group.get_number_of_egress_rules(), + ip_ranges, + source_group_names, + source_group_ids, + ) source_group_names = source_group_names if source_group_names else [] source_group_ids = source_group_ids if source_group_ids else [] source_groups = [] for source_group_name in source_group_names: - source_group = self.get_security_group_from_name( - source_group_name, vpc_id) + source_group = self.get_security_group_from_name(source_group_name, vpc_id) if source_group: source_groups.append(source_group) @@ -1796,25 +2008,27 @@ class SecurityGroupBackend(object): source_groups.append(source_group) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, source_groups) + ip_protocol, from_port, to_port, ip_ranges, source_groups + ) group.add_egress_rule(security_rule) - def revoke_security_group_egress(self, - group_name_or_id, - ip_protocol, - from_port, - to_port, - ip_ranges, - source_group_names=None, - source_group_ids=None, - vpc_id=None): + def revoke_security_group_egress( + self, + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_group_names=None, + source_group_ids=None, + vpc_id=None, + ): group = self.get_security_group_by_name_or_id(group_name_or_id, vpc_id) source_groups = [] for source_group_name in source_group_names: - source_group = self.get_security_group_from_name( - source_group_name, vpc_id) + source_group = self.get_security_group_from_name(source_group_name, vpc_id) if source_group: source_groups.append(source_group) @@ -1824,15 +2038,21 @@ class SecurityGroupBackend(object): source_groups.append(source_group) security_rule = SecurityRule( - ip_protocol, from_port, to_port, ip_ranges, source_groups) + ip_protocol, from_port, to_port, ip_ranges, source_groups + ) if security_rule in group.egress_rules: group.egress_rules.remove(security_rule) return security_rule raise InvalidPermissionNotFoundError() def _verify_group_will_respect_rule_count_limit( - self, group, current_rule_nb, - ip_ranges, source_group_names=None, source_group_ids=None): + self, + group, + current_rule_nb, + ip_ranges, + source_group_names=None, + source_group_ids=None, + ): max_nb_rules = 50 if group.vpc_id else 100 future_group_nb_rules = current_rule_nb if ip_ranges: @@ -1851,12 +2071,14 @@ class SecurityGroupIngress(object): self.properties = properties @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] - group_name = properties.get('GroupName') - group_id = properties.get('GroupId') + group_name = properties.get("GroupName") + group_id = properties.get("GroupId") ip_protocol = properties.get("IpProtocol") cidr_ip = properties.get("CidrIp") cidr_ipv6 = properties.get("CidrIpv6") @@ -1868,7 +2090,12 @@ class SecurityGroupIngress(object): to_port = properties.get("ToPort") assert group_id or group_name - assert source_security_group_name or cidr_ip or cidr_ipv6 or source_security_group_id + assert ( + source_security_group_name + or cidr_ip + or cidr_ipv6 + or source_security_group_id + ) assert ip_protocol if source_security_group_id: @@ -1886,10 +2113,12 @@ class SecurityGroupIngress(object): if group_id: security_group = ec2_backend.describe_security_groups(group_ids=[group_id])[ - 0] + 0 + ] else: security_group = ec2_backend.describe_security_groups( - groupnames=[group_name])[0] + groupnames=[group_name] + )[0] ec2_backend.authorize_security_group_ingress( group_name_or_id=security_group.id, @@ -1913,23 +2142,27 @@ class VolumeAttachment(object): self.status = status @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - instance_id = properties['InstanceId'] - volume_id = properties['VolumeId'] + instance_id = properties["InstanceId"] + volume_id = properties["VolumeId"] ec2_backend = ec2_backends[region_name] attachment = ec2_backend.attach_volume( volume_id=volume_id, instance_id=instance_id, - device_path=properties['Device'], + device_path=properties["Device"], ) return attachment class Volume(TaggedEC2Resource): - def __init__(self, ec2_backend, volume_id, size, zone, snapshot_id=None, encrypted=False): + def __init__( + self, ec2_backend, volume_id, size, zone, snapshot_id=None, encrypted=False + ): self.id = volume_id self.size = size self.zone = zone @@ -1940,13 +2173,14 @@ class Volume(TaggedEC2Resource): self.encrypted = encrypted @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] volume = ec2_backend.create_volume( - size=properties.get('Size'), - zone_name=properties.get('AvailabilityZone'), + size=properties.get("Size"), zone_name=properties.get("AvailabilityZone") ) return volume @@ -1957,72 +2191,80 @@ class Volume(TaggedEC2Resource): @property def status(self): if self.attachment: - return 'in-use' + return "in-use" else: - return 'available' + return "available" def get_filter_value(self, filter_name): - if filter_name.startswith('attachment') and not self.attachment: + if filter_name.startswith("attachment") and not self.attachment: return None - elif filter_name == 'attachment.attach-time': + elif filter_name == "attachment.attach-time": return self.attachment.attach_time - elif filter_name == 'attachment.device': + elif filter_name == "attachment.device": return self.attachment.device - elif filter_name == 'attachment.instance-id': + elif filter_name == "attachment.instance-id": return self.attachment.instance.id - elif filter_name == 'attachment.status': + elif filter_name == "attachment.status": return self.attachment.status - elif filter_name == 'create-time': + elif filter_name == "create-time": return self.create_time - elif filter_name == 'size': + elif filter_name == "size": return self.size - elif filter_name == 'snapshot-id': + elif filter_name == "snapshot-id": return self.snapshot_id - elif filter_name == 'status': + elif filter_name == "status": return self.status - elif filter_name == 'volume-id': + elif filter_name == "volume-id": return self.id - elif filter_name == 'encrypted': + elif filter_name == "encrypted": return str(self.encrypted).lower() - elif filter_name == 'availability-zone': + elif filter_name == "availability-zone": return self.zone.name else: - return super(Volume, self).get_filter_value( - filter_name, 'DescribeVolumes') + return super(Volume, self).get_filter_value(filter_name, "DescribeVolumes") class Snapshot(TaggedEC2Resource): - def __init__(self, ec2_backend, snapshot_id, volume, description, encrypted=False, owner_id=OWNER_ID): + def __init__( + self, + ec2_backend, + snapshot_id, + volume, + description, + encrypted=False, + owner_id=OWNER_ID, + ): self.id = snapshot_id self.volume = volume self.description = description self.start_time = utc_date_and_time() self.create_volume_permission_groups = set() self.ec2_backend = ec2_backend - self.status = 'completed' + self.status = "completed" self.encrypted = encrypted self.owner_id = owner_id def get_filter_value(self, filter_name): - if filter_name == 'description': + if filter_name == "description": return self.description - elif filter_name == 'snapshot-id': + elif filter_name == "snapshot-id": return self.id - elif filter_name == 'start-time': + elif filter_name == "start-time": return self.start_time - elif filter_name == 'volume-id': + elif filter_name == "volume-id": return self.volume.id - elif filter_name == 'volume-size': + elif filter_name == "volume-size": return self.volume.size - elif filter_name == 'encrypted': + elif filter_name == "encrypted": return str(self.encrypted).lower() - elif filter_name == 'status': + elif filter_name == "status": return self.status - elif filter_name == 'owner-id': + elif filter_name == "owner-id": return self.owner_id else: return super(Snapshot, self).get_filter_value( - filter_name, 'DescribeSnapshots') + filter_name, "DescribeSnapshots" + ) class EBSBackend(object): @@ -2048,8 +2290,7 @@ class EBSBackend(object): def describe_volumes(self, volume_ids=None, filters=None): matches = self.volumes.values() if volume_ids: - matches = [vol for vol in matches - if vol.id in volume_ids] + matches = [vol for vol in matches if vol.id in volume_ids] if len(volume_ids) > len(matches): unknown_ids = set(volume_ids) - set(matches) raise InvalidVolumeIdError(unknown_ids) @@ -2075,11 +2316,14 @@ class EBSBackend(object): if not volume or not instance: return False - volume.attachment = VolumeAttachment( - volume, instance, device_path, 'attached') + volume.attachment = VolumeAttachment(volume, instance, device_path, "attached") # Modify instance to capture mount of block device. - bdt = BlockDeviceType(volume_id=volume_id, status=volume.status, size=volume.size, - attach_time=utc_date_and_time()) + bdt = BlockDeviceType( + volume_id=volume_id, + status=volume.status, + size=volume.size, + attach_time=utc_date_and_time(), + ) instance.block_device_mapping[device_path] = bdt return volume.attachment @@ -2090,7 +2334,7 @@ class EBSBackend(object): old_attachment = volume.attachment if not old_attachment: raise InvalidVolumeAttachmentError(volume_id, instance_id) - old_attachment.status = 'detached' + old_attachment.status = "detached" volume.attachment = None return old_attachment @@ -2108,8 +2352,7 @@ class EBSBackend(object): def describe_snapshots(self, snapshot_ids=None, filters=None): matches = self.snapshots.values() if snapshot_ids: - matches = [snap for snap in matches - if snap.id in snapshot_ids] + matches = [snap for snap in matches if snap.id in snapshot_ids] if len(snapshot_ids) > len(matches): unknown_ids = set(snapshot_ids) - set(matches) raise InvalidSnapshotIdError(unknown_ids) @@ -2119,10 +2362,16 @@ class EBSBackend(object): def copy_snapshot(self, source_snapshot_id, source_region, description=None): source_snapshot = ec2_backends[source_region].describe_snapshots( - snapshot_ids=[source_snapshot_id])[0] + snapshot_ids=[source_snapshot_id] + )[0] snapshot_id = random_snapshot_id() - snapshot = Snapshot(self, snapshot_id, volume=source_snapshot.volume, - description=description, encrypted=source_snapshot.encrypted) + snapshot = Snapshot( + self, + snapshot_id, + volume=source_snapshot.volume, + description=description, + encrypted=source_snapshot.encrypted, + ) self.snapshots[snapshot_id] = snapshot return snapshot @@ -2144,9 +2393,10 @@ class EBSBackend(object): def add_create_volume_permission(self, snapshot_id, user_id=None, group=None): if user_id: self.raise_not_implemented_error( - "The UserId parameter for ModifySnapshotAttribute") + "The UserId parameter for ModifySnapshotAttribute" + ) - if group != 'all': + if group != "all": raise InvalidAMIAttributeItemValueError("UserGroup", group) snapshot = self.get_snapshot(snapshot_id) snapshot.create_volume_permission_groups.add(group) @@ -2155,9 +2405,10 @@ class EBSBackend(object): def remove_create_volume_permission(self, snapshot_id, user_id=None, group=None): if user_id: self.raise_not_implemented_error( - "The UserId parameter for ModifySnapshotAttribute") + "The UserId parameter for ModifySnapshotAttribute" + ) - if group != 'all': + if group != "all": raise InvalidAMIAttributeItemValueError("UserGroup", group) snapshot = self.get_snapshot(snapshot_id) snapshot.create_volume_permission_groups.discard(group) @@ -2165,34 +2416,46 @@ class EBSBackend(object): class VPC(TaggedEC2Resource): - def __init__(self, ec2_backend, vpc_id, cidr_block, is_default, instance_tenancy='default', - amazon_provided_ipv6_cidr_block=False): + def __init__( + self, + ec2_backend, + vpc_id, + cidr_block, + is_default, + instance_tenancy="default", + amazon_provided_ipv6_cidr_block=False, + ): self.ec2_backend = ec2_backend self.id = vpc_id self.cidr_block = cidr_block self.cidr_block_association_set = {} self.dhcp_options = None - self.state = 'available' + self.state = "available" self.instance_tenancy = instance_tenancy - self.is_default = 'true' if is_default else 'false' - self.enable_dns_support = 'true' + self.is_default = "true" if is_default else "false" + self.enable_dns_support = "true" # This attribute is set to 'true' only for default VPCs # or VPCs created using the wizard of the VPC console - self.enable_dns_hostnames = 'true' if is_default else 'false' + self.enable_dns_hostnames = "true" if is_default else "false" self.associate_vpc_cidr_block(cidr_block) if amazon_provided_ipv6_cidr_block: - self.associate_vpc_cidr_block(cidr_block, amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_block) + self.associate_vpc_cidr_block( + cidr_block, + amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_block, + ) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] vpc = ec2_backend.create_vpc( - cidr_block=properties['CidrBlock'], - instance_tenancy=properties.get('InstanceTenancy', 'default') + cidr_block=properties["CidrBlock"], + instance_tenancy=properties.get("InstanceTenancy", "default"), ) for tag in properties.get("Tags", []): tag_key = tag["Key"] @@ -2206,58 +2469,86 @@ class VPC(TaggedEC2Resource): return self.id def get_filter_value(self, filter_name): - if filter_name in ('vpc-id', 'vpcId'): + if filter_name in ("vpc-id", "vpcId"): return self.id - elif filter_name in ('cidr', 'cidr-block', 'cidrBlock'): + elif filter_name in ("cidr", "cidr-block", "cidrBlock"): return self.cidr_block - elif filter_name in ('cidr-block-association.cidr-block', 'ipv6-cidr-block-association.ipv6-cidr-block'): - return [c['cidr_block'] for c in self.get_cidr_block_association_set(ipv6='ipv6' in filter_name)] - elif filter_name in ('cidr-block-association.association-id', 'ipv6-cidr-block-association.association-id'): + elif filter_name in ( + "cidr-block-association.cidr-block", + "ipv6-cidr-block-association.ipv6-cidr-block", + ): + return [ + c["cidr_block"] + for c in self.get_cidr_block_association_set(ipv6="ipv6" in filter_name) + ] + elif filter_name in ( + "cidr-block-association.association-id", + "ipv6-cidr-block-association.association-id", + ): return self.cidr_block_association_set.keys() - elif filter_name in ('cidr-block-association.state', 'ipv6-cidr-block-association.state'): - return [c['cidr_block_state']['state'] for c in self.get_cidr_block_association_set(ipv6='ipv6' in filter_name)] - elif filter_name in ('instance_tenancy', 'InstanceTenancy'): + elif filter_name in ( + "cidr-block-association.state", + "ipv6-cidr-block-association.state", + ): + return [ + c["cidr_block_state"]["state"] + for c in self.get_cidr_block_association_set(ipv6="ipv6" in filter_name) + ] + elif filter_name in ("instance_tenancy", "InstanceTenancy"): return self.instance_tenancy - elif filter_name in ('is-default', 'isDefault'): + elif filter_name in ("is-default", "isDefault"): return self.is_default - elif filter_name == 'state': + elif filter_name == "state": return self.state - elif filter_name in ('dhcp-options-id', 'dhcpOptionsId'): + elif filter_name in ("dhcp-options-id", "dhcpOptionsId"): if not self.dhcp_options: return None return self.dhcp_options.id else: - return super(VPC, self).get_filter_value(filter_name, 'DescribeVpcs') + return super(VPC, self).get_filter_value(filter_name, "DescribeVpcs") - def associate_vpc_cidr_block(self, cidr_block, amazon_provided_ipv6_cidr_block=False): + def associate_vpc_cidr_block( + self, cidr_block, amazon_provided_ipv6_cidr_block=False + ): max_associations = 5 if not amazon_provided_ipv6_cidr_block else 1 - if len(self.get_cidr_block_association_set(amazon_provided_ipv6_cidr_block)) >= max_associations: + if ( + len(self.get_cidr_block_association_set(amazon_provided_ipv6_cidr_block)) + >= max_associations + ): raise CidrLimitExceeded(self.id, max_associations) association_id = random_vpc_cidr_association_id() association_set = { - 'association_id': association_id, - 'cidr_block_state': {'state': 'associated', 'StatusMessage': ''} + "association_id": association_id, + "cidr_block_state": {"state": "associated", "StatusMessage": ""}, } - association_set['cidr_block'] = random_ipv6_cidr() if amazon_provided_ipv6_cidr_block else cidr_block + association_set["cidr_block"] = ( + random_ipv6_cidr() if amazon_provided_ipv6_cidr_block else cidr_block + ) self.cidr_block_association_set[association_id] = association_set return association_set def disassociate_vpc_cidr_block(self, association_id): - if self.cidr_block == self.cidr_block_association_set.get(association_id, {}).get('cidr_block'): + if self.cidr_block == self.cidr_block_association_set.get( + association_id, {} + ).get("cidr_block"): raise OperationNotPermitted(association_id) response = self.cidr_block_association_set.pop(association_id, {}) if response: - response['vpc_id'] = self.id - response['cidr_block_state']['state'] = 'disassociating' + response["vpc_id"] = self.id + response["cidr_block_state"]["state"] = "disassociating" return response def get_cidr_block_association_set(self, ipv6=False): - return [c for c in self.cidr_block_association_set.values() if ('::/' if ipv6 else '.') in c.get('cidr_block')] + return [ + c + for c in self.cidr_block_association_set.values() + if ("::/" if ipv6 else ".") in c.get("cidr_block") + ] class VPCBackend(object): @@ -2275,15 +2566,29 @@ class VPCBackend(object): if inst is not None: yield inst - def create_vpc(self, cidr_block, instance_tenancy='default', amazon_provided_ipv6_cidr_block=False): + def create_vpc( + self, + cidr_block, + instance_tenancy="default", + amazon_provided_ipv6_cidr_block=False, + ): vpc_id = random_vpc_id() try: - vpc_cidr_block = ipaddress.IPv4Network(six.text_type(cidr_block), strict=False) + vpc_cidr_block = ipaddress.IPv4Network( + six.text_type(cidr_block), strict=False + ) except ValueError: raise InvalidCIDRBlockParameterError(cidr_block) if vpc_cidr_block.prefixlen < 16 or vpc_cidr_block.prefixlen > 28: raise InvalidVPCRangeError(cidr_block) - vpc = VPC(self, vpc_id, cidr_block, len(self.vpcs) == 0, instance_tenancy, amazon_provided_ipv6_cidr_block) + vpc = VPC( + self, + vpc_id, + cidr_block, + len(self.vpcs) == 0, + instance_tenancy, + amazon_provided_ipv6_cidr_block, + ) self.vpcs[vpc_id] = vpc # AWS creates a default main route table and security group. @@ -2292,10 +2597,11 @@ class VPCBackend(object): # AWS creates a default Network ACL self.create_network_acl(vpc_id, default=True) - default = self.get_security_group_from_name('default', vpc_id=vpc_id) + default = self.get_security_group_from_name("default", vpc_id=vpc_id) if not default: self.create_security_group( - 'default', 'default VPC security group', vpc_id=vpc_id) + "default", "default VPC security group", vpc_id=vpc_id + ) return vpc @@ -2314,8 +2620,7 @@ class VPCBackend(object): def get_all_vpcs(self, vpc_ids=None, filters=None): matches = self.vpcs.values() if vpc_ids: - matches = [vpc for vpc in matches - if vpc.id in vpc_ids] + matches = [vpc for vpc in matches if vpc.id in vpc_ids] if len(vpc_ids) > len(matches): unknown_ids = set(vpc_ids) - set(matches) raise InvalidVPCIdError(unknown_ids) @@ -2325,7 +2630,7 @@ class VPCBackend(object): def delete_vpc(self, vpc_id): # Delete route table if only main route table remains. - route_tables = self.get_all_route_tables(filters={'vpc-id': vpc_id}) + route_tables = self.get_all_route_tables(filters={"vpc-id": vpc_id}) if len(route_tables) > 1: raise DependencyViolationError( "The vpc {0} has dependencies and cannot be deleted.".format(vpc_id) @@ -2334,7 +2639,7 @@ class VPCBackend(object): self.delete_route_table(route_table.id) # Delete default security group if exists. - default = self.get_security_group_from_name('default', vpc_id=vpc_id) + default = self.get_security_group_from_name("default", vpc_id=vpc_id) if default: self.delete_security_group(group_id=default.id) @@ -2351,14 +2656,14 @@ class VPCBackend(object): def describe_vpc_attribute(self, vpc_id, attr_name): vpc = self.get_vpc(vpc_id) - if attr_name in ('enable_dns_support', 'enable_dns_hostnames'): + if attr_name in ("enable_dns_support", "enable_dns_hostnames"): return getattr(vpc, attr_name) else: raise InvalidParameterValueError(attr_name) def modify_vpc_attribute(self, vpc_id, attr_name, attr_value): vpc = self.get_vpc(vpc_id) - if attr_name in ('enable_dns_support', 'enable_dns_hostnames'): + if attr_name in ("enable_dns_support", "enable_dns_hostnames"): setattr(vpc, attr_name, attr_value) else: raise InvalidParameterValueError(attr_name) @@ -2371,35 +2676,37 @@ class VPCBackend(object): else: raise InvalidVpcCidrBlockAssociationIdError(association_id) - def associate_vpc_cidr_block(self, vpc_id, cidr_block, amazon_provided_ipv6_cidr_block): + def associate_vpc_cidr_block( + self, vpc_id, cidr_block, amazon_provided_ipv6_cidr_block + ): vpc = self.get_vpc(vpc_id) return vpc.associate_vpc_cidr_block(cidr_block, amazon_provided_ipv6_cidr_block) class VPCPeeringConnectionStatus(object): - def __init__(self, code='initiating-request', message=''): + def __init__(self, code="initiating-request", message=""): self.code = code self.message = message def deleted(self): - self.code = 'deleted' - self.message = 'Deleted by {deleter ID}' + self.code = "deleted" + self.message = "Deleted by {deleter ID}" def initiating(self): - self.code = 'initiating-request' - self.message = 'Initiating Request to {accepter ID}' + self.code = "initiating-request" + self.message = "Initiating Request to {accepter ID}" def pending(self): - self.code = 'pending-acceptance' - self.message = 'Pending Acceptance by {accepter ID}' + self.code = "pending-acceptance" + self.message = "Pending Acceptance by {accepter ID}" def accept(self): - self.code = 'active' - self.message = 'Active' + self.code = "active" + self.message = "Active" def reject(self): - self.code = 'rejected' - self.message = 'Inactive' + self.code = "rejected" + self.message = "Inactive" class VPCPeeringConnection(TaggedEC2Resource): @@ -2410,12 +2717,14 @@ class VPCPeeringConnection(TaggedEC2Resource): self._status = VPCPeeringConnectionStatus() @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] - vpc = ec2_backend.get_vpc(properties['VpcId']) - peer_vpc = ec2_backend.get_vpc(properties['PeerVpcId']) + vpc = ec2_backend.get_vpc(properties["VpcId"]) + peer_vpc = ec2_backend.get_vpc(properties["PeerVpcId"]) vpc_pcx = ec2_backend.create_vpc_peering_connection(vpc, peer_vpc) @@ -2474,7 +2783,7 @@ class VPCPeeringConnectionBackend(object): pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: raise OperationNotPermitted2(self.region_name, vpc_pcx.id, pcx_acp_region) - if vpc_pcx._status.code != 'pending-acceptance': + if vpc_pcx._status.code != "pending-acceptance": raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id) vpc_pcx._status.accept() return vpc_pcx @@ -2486,15 +2795,25 @@ class VPCPeeringConnectionBackend(object): pcx_acp_region = vpc_pcx.peer_vpc.ec2_backend.region_name if pcx_req_region != pcx_acp_region and self.region_name == pcx_req_region: raise OperationNotPermitted3(self.region_name, vpc_pcx.id, pcx_acp_region) - if vpc_pcx._status.code != 'pending-acceptance': + if vpc_pcx._status.code != "pending-acceptance": raise InvalidVPCPeeringConnectionStateTransitionError(vpc_pcx.id) vpc_pcx._status.reject() return vpc_pcx class Subnet(TaggedEC2Resource): - def __init__(self, ec2_backend, subnet_id, vpc_id, cidr_block, availability_zone, default_for_az, - map_public_ip_on_launch, owner_id=OWNER_ID, assign_ipv6_address_on_creation=False): + def __init__( + self, + ec2_backend, + subnet_id, + vpc_id, + cidr_block, + availability_zone, + default_for_az, + map_public_ip_on_launch, + owner_id=OWNER_ID, + assign_ipv6_address_on_creation=False, + ): self.ec2_backend = ec2_backend self.id = subnet_id self.vpc_id = vpc_id @@ -2509,22 +2828,24 @@ class Subnet(TaggedEC2Resource): # Theory is we assign ip's as we go (as 16,777,214 usable IPs in a /8) self._subnet_ip_generator = self.cidr.hosts() - self.reserved_ips = [six.next(self._subnet_ip_generator) for _ in range(0, 3)] # Reserved by AWS + self.reserved_ips = [ + six.next(self._subnet_ip_generator) for _ in range(0, 3) + ] # Reserved by AWS self._unused_ips = set() # if instance is destroyed hold IP here for reuse self._subnet_ips = {} # has IP: instance @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - vpc_id = properties['VpcId'] - cidr_block = properties['CidrBlock'] - availability_zone = properties.get('AvailabilityZone') + vpc_id = properties["VpcId"] + cidr_block = properties["CidrBlock"] + availability_zone = properties.get("AvailabilityZone") ec2_backend = ec2_backends[region_name] subnet = ec2_backend.create_subnet( - vpc_id=vpc_id, - cidr_block=cidr_block, - availability_zone=availability_zone, + vpc_id=vpc_id, cidr_block=cidr_block, availability_zone=availability_zone ) for tag in properties.get("Tags", []): tag_key = tag["Key"] @@ -2558,25 +2879,24 @@ class Subnet(TaggedEC2Resource): Taken from: http://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeSubnets.html """ - if filter_name in ('cidr', 'cidrBlock', 'cidr-block'): + if filter_name in ("cidr", "cidrBlock", "cidr-block"): return self.cidr_block - elif filter_name in ('vpc-id', 'vpcId'): + elif filter_name in ("vpc-id", "vpcId"): return self.vpc_id - elif filter_name == 'subnet-id': + elif filter_name == "subnet-id": return self.id - elif filter_name in ('availabilityZone', 'availability-zone'): + elif filter_name in ("availabilityZone", "availability-zone"): return self.availability_zone - elif filter_name in ('defaultForAz', 'default-for-az'): + elif filter_name in ("defaultForAz", "default-for-az"): return self.default_for_az else: - return super(Subnet, self).get_filter_value( - filter_name, 'DescribeSubnets') + return super(Subnet, self).get_filter_value(filter_name, "DescribeSubnets") def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'AvailabilityZone': - raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "AvailabilityZone" ]"') + + if attribute_name == "AvailabilityZone": + raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "AvailabilityZone" ]"') raise UnformattedGetAttTemplateException() def get_available_subnet_ip(self, instance): @@ -2600,10 +2920,12 @@ class Subnet(TaggedEC2Resource): def request_ip(self, ip, instance): if ipaddress.ip_address(ip) not in self.cidr: - raise Exception('IP does not fall in the subnet CIDR of {0}'.format(self.cidr)) + raise Exception( + "IP does not fall in the subnet CIDR of {0}".format(self.cidr) + ) if ip in self._subnet_ips: - raise Exception('IP already in use') + raise Exception("IP already in use") try: self._unused_ips.remove(ip) except KeyError: @@ -2634,17 +2956,25 @@ class SubnetBackend(object): def create_subnet(self, vpc_id, cidr_block, availability_zone, context=None): subnet_id = random_subnet_id() - vpc = self.get_vpc(vpc_id) # Validate VPC exists and the supplied CIDR block is a subnet of the VPC's - vpc_cidr_block = ipaddress.IPv4Network(six.text_type(vpc.cidr_block), strict=False) + vpc = self.get_vpc( + vpc_id + ) # Validate VPC exists and the supplied CIDR block is a subnet of the VPC's + vpc_cidr_block = ipaddress.IPv4Network( + six.text_type(vpc.cidr_block), strict=False + ) try: - subnet_cidr_block = ipaddress.IPv4Network(six.text_type(cidr_block), strict=False) + subnet_cidr_block = ipaddress.IPv4Network( + six.text_type(cidr_block), strict=False + ) except ValueError: raise InvalidCIDRBlockParameterError(cidr_block) - if not (vpc_cidr_block.network_address <= subnet_cidr_block.network_address and - vpc_cidr_block.broadcast_address >= subnet_cidr_block.broadcast_address): + if not ( + vpc_cidr_block.network_address <= subnet_cidr_block.network_address + and vpc_cidr_block.broadcast_address >= subnet_cidr_block.broadcast_address + ): raise InvalidSubnetRangeError(cidr_block) - for subnet in self.get_all_subnets(filters={'vpc-id': vpc_id}): + for subnet in self.get_all_subnets(filters={"vpc-id": vpc_id}): if subnet.cidr.overlaps(subnet_cidr_block): raise InvalidSubnetConflictError(cidr_block) @@ -2653,14 +2983,36 @@ class SubnetBackend(object): default_for_az = str(availability_zone not in self.subnets).lower() map_public_ip_on_launch = default_for_az if availability_zone is None: - availability_zone = 'us-east-1a' + availability_zone = "us-east-1a" try: - availability_zone_data = next(zone for zones in RegionsAndZonesBackend.zones.values() for zone in zones if zone.name == availability_zone) + availability_zone_data = next( + zone + for zones in RegionsAndZonesBackend.zones.values() + for zone in zones + if zone.name == availability_zone + ) except StopIteration: - raise InvalidAvailabilityZoneError(availability_zone, ", ".join([zone.name for zones in RegionsAndZonesBackend.zones.values() for zone in zones])) - subnet = Subnet(self, subnet_id, vpc_id, cidr_block, availability_zone_data, - default_for_az, map_public_ip_on_launch, - owner_id=context.get_current_user() if context else OWNER_ID, assign_ipv6_address_on_creation=False) + raise InvalidAvailabilityZoneError( + availability_zone, + ", ".join( + [ + zone.name + for zones in RegionsAndZonesBackend.zones.values() + for zone in zones + ] + ), + ) + subnet = Subnet( + self, + subnet_id, + vpc_id, + cidr_block, + availability_zone_data, + default_for_az, + map_public_ip_on_launch, + owner_id=context.get_current_user() if context else OWNER_ID, + assign_ipv6_address_on_creation=False, + ) # AWS associates a new subnet with the default Network ACL self.associate_default_network_acl_with_subnet(subnet_id, vpc_id) @@ -2669,11 +3021,9 @@ class SubnetBackend(object): def get_all_subnets(self, subnet_ids=None, filters=None): # Extract a list of all subnets - matches = itertools.chain(*[x.values() - for x in self.subnets.values()]) + matches = itertools.chain(*[x.values() for x in self.subnets.values()]) if subnet_ids: - matches = [sn for sn in matches - if sn.id in subnet_ids] + matches = [sn for sn in matches if sn.id in subnet_ids] if len(subnet_ids) > len(matches): unknown_ids = set(subnet_ids) - set(matches) raise InvalidSubnetIdError(unknown_ids) @@ -2690,7 +3040,7 @@ class SubnetBackend(object): def modify_subnet_attribute(self, subnet_id, attr_name, attr_value): subnet = self.get_subnet(subnet_id) - if attr_name in ('map_public_ip_on_launch', 'assign_ipv6_address_on_creation'): + if attr_name in ("map_public_ip_on_launch", "assign_ipv6_address_on_creation"): setattr(subnet, attr_name, attr_value) else: raise InvalidParameterValueError(attr_name) @@ -2702,16 +3052,17 @@ class SubnetRouteTableAssociation(object): self.subnet_id = subnet_id @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - route_table_id = properties['RouteTableId'] - subnet_id = properties['SubnetId'] + route_table_id = properties["RouteTableId"] + subnet_id = properties["SubnetId"] ec2_backend = ec2_backends[region_name] subnet_association = ec2_backend.create_subnet_association( - route_table_id=route_table_id, - subnet_id=subnet_id, + route_table_id=route_table_id, subnet_id=subnet_id ) return subnet_association @@ -2722,10 +3073,10 @@ class SubnetRouteTableAssociationBackend(object): super(SubnetRouteTableAssociationBackend, self).__init__() def create_subnet_association(self, route_table_id, subnet_id): - subnet_association = SubnetRouteTableAssociation( - route_table_id, subnet_id) - self.subnet_associations["{0}:{1}".format( - route_table_id, subnet_id)] = subnet_association + subnet_association = SubnetRouteTableAssociation(route_table_id, subnet_id) + self.subnet_associations[ + "{0}:{1}".format(route_table_id, subnet_id) + ] = subnet_association return subnet_association @@ -2739,14 +3090,14 @@ class RouteTable(TaggedEC2Resource): self.routes = {} @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - vpc_id = properties['VpcId'] + vpc_id = properties["VpcId"] ec2_backend = ec2_backends[region_name] - route_table = ec2_backend.create_route_table( - vpc_id=vpc_id, - ) + route_table = ec2_backend.create_route_table(vpc_id=vpc_id) return route_table @property @@ -2758,9 +3109,9 @@ class RouteTable(TaggedEC2Resource): # Note: Boto only supports 'true'. # https://github.com/boto/boto/issues/1742 if self.main: - return 'true' + return "true" else: - return 'false' + return "false" elif filter_name == "route-table-id": return self.id elif filter_name == "vpc-id": @@ -2773,7 +3124,8 @@ class RouteTable(TaggedEC2Resource): return self.associations.values() else: return super(RouteTable, self).get_filter_value( - filter_name, 'DescribeRouteTables') + filter_name, "DescribeRouteTables" + ) class RouteTableBackend(object): @@ -2803,10 +3155,16 @@ class RouteTableBackend(object): if route_table_ids: route_tables = [ - route_table for route_table in route_tables if route_table.id in route_table_ids] + route_table + for route_table in route_tables + if route_table.id in route_table_ids + ] if len(route_tables) != len(route_table_ids): - invalid_id = list(set(route_table_ids).difference( - set([route_table.id for route_table in route_tables])))[0] + invalid_id = list( + set(route_table_ids).difference( + set([route_table.id for route_table in route_tables]) + ) + )[0] raise InvalidRouteTableIdError(invalid_id) return generic_filter(filters, route_tables) @@ -2815,7 +3173,9 @@ class RouteTableBackend(object): route_table = self.get_route_table(route_table_id) if route_table.associations: raise DependencyViolationError( - "The routeTable '{0}' has dependencies and cannot be deleted.".format(route_table_id) + "The routeTable '{0}' has dependencies and cannot be deleted.".format( + route_table_id + ) ) self.route_tables.pop(route_table_id) return True @@ -2823,9 +3183,12 @@ class RouteTableBackend(object): def associate_route_table(self, route_table_id, subnet_id): # Idempotent if association already exists. route_tables_by_subnet = self.get_all_route_tables( - filters={'association.subnet-id': [subnet_id]}) + filters={"association.subnet-id": [subnet_id]} + ) if route_tables_by_subnet: - for association_id, check_subnet_id in route_tables_by_subnet[0].associations.items(): + for association_id, check_subnet_id in route_tables_by_subnet[ + 0 + ].associations.items(): if subnet_id == check_subnet_id: return association_id @@ -2850,7 +3213,8 @@ class RouteTableBackend(object): # Find route table which currently has the association, error if none. route_tables_by_association_id = self.get_all_route_tables( - filters={'association.route-table-association-id': [association_id]}) + filters={"association.route-table-association-id": [association_id]} + ) if not route_tables_by_association_id: raise InvalidAssociationIdError(association_id) @@ -2861,8 +3225,16 @@ class RouteTableBackend(object): class Route(object): - def __init__(self, route_table, destination_cidr_block, local=False, - gateway=None, instance=None, interface=None, vpc_pcx=None): + def __init__( + self, + route_table, + destination_cidr_block, + local=False, + gateway=None, + instance=None, + interface=None, + vpc_pcx=None, + ): self.id = generate_route_id(route_table.id, destination_cidr_block) self.route_table = route_table self.destination_cidr_block = destination_cidr_block @@ -2873,19 +3245,21 @@ class Route(object): self.vpc_pcx = vpc_pcx @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - gateway_id = properties.get('GatewayId') - instance_id = properties.get('InstanceId') - interface_id = properties.get('NetworkInterfaceId') - pcx_id = properties.get('VpcPeeringConnectionId') + gateway_id = properties.get("GatewayId") + instance_id = properties.get("InstanceId") + interface_id = properties.get("NetworkInterfaceId") + pcx_id = properties.get("VpcPeeringConnectionId") - route_table_id = properties['RouteTableId'] + route_table_id = properties["RouteTableId"] ec2_backend = ec2_backends[region_name] route_table = ec2_backend.create_route( route_table_id=route_table_id, - destination_cidr_block=properties.get('DestinationCidrBlock'), + destination_cidr_block=properties.get("DestinationCidrBlock"), gateway_id=gateway_id, instance_id=instance_id, interface_id=interface_id, @@ -2898,20 +3272,26 @@ class RouteBackend(object): def __init__(self): super(RouteBackend, self).__init__() - def create_route(self, route_table_id, destination_cidr_block, local=False, - gateway_id=None, instance_id=None, interface_id=None, - vpc_peering_connection_id=None): + def create_route( + self, + route_table_id, + destination_cidr_block, + local=False, + gateway_id=None, + instance_id=None, + interface_id=None, + vpc_peering_connection_id=None, + ): route_table = self.get_route_table(route_table_id) if interface_id: - self.raise_not_implemented_error( - "CreateRoute to NetworkInterfaceId") + self.raise_not_implemented_error("CreateRoute to NetworkInterfaceId") gateway = None if gateway_id: - if EC2_RESOURCE_TO_PREFIX['vpn-gateway'] in gateway_id: + if EC2_RESOURCE_TO_PREFIX["vpn-gateway"] in gateway_id: gateway = self.get_vpn_gateway(gateway_id) - elif EC2_RESOURCE_TO_PREFIX['internet-gateway'] in gateway_id: + elif EC2_RESOURCE_TO_PREFIX["internet-gateway"] in gateway_id: gateway = self.get_internet_gateway(gateway_id) try: @@ -2919,39 +3299,50 @@ class RouteBackend(object): except ValueError: raise InvalidDestinationCIDRBlockParameterError(destination_cidr_block) - route = Route(route_table, destination_cidr_block, local=local, - gateway=gateway, - instance=self.get_instance( - instance_id) if instance_id else None, - interface=None, - vpc_pcx=self.get_vpc_peering_connection( - vpc_peering_connection_id) if vpc_peering_connection_id else None) + route = Route( + route_table, + destination_cidr_block, + local=local, + gateway=gateway, + instance=self.get_instance(instance_id) if instance_id else None, + interface=None, + vpc_pcx=self.get_vpc_peering_connection(vpc_peering_connection_id) + if vpc_peering_connection_id + else None, + ) route_table.routes[route.id] = route return route - def replace_route(self, route_table_id, destination_cidr_block, - gateway_id=None, instance_id=None, interface_id=None, - vpc_peering_connection_id=None): + def replace_route( + self, + route_table_id, + destination_cidr_block, + gateway_id=None, + instance_id=None, + interface_id=None, + vpc_peering_connection_id=None, + ): route_table = self.get_route_table(route_table_id) route_id = generate_route_id(route_table.id, destination_cidr_block) route = route_table.routes[route_id] if interface_id: - self.raise_not_implemented_error( - "ReplaceRoute to NetworkInterfaceId") + self.raise_not_implemented_error("ReplaceRoute to NetworkInterfaceId") route.gateway = None if gateway_id: - if EC2_RESOURCE_TO_PREFIX['vpn-gateway'] in gateway_id: + if EC2_RESOURCE_TO_PREFIX["vpn-gateway"] in gateway_id: route.gateway = self.get_vpn_gateway(gateway_id) - elif EC2_RESOURCE_TO_PREFIX['internet-gateway'] in gateway_id: + elif EC2_RESOURCE_TO_PREFIX["internet-gateway"] in gateway_id: route.gateway = self.get_internet_gateway(gateway_id) - route.instance = self.get_instance( - instance_id) if instance_id else None + route.instance = self.get_instance(instance_id) if instance_id else None route.interface = None - route.vpc_pcx = self.get_vpc_peering_connection( - vpc_peering_connection_id) if vpc_peering_connection_id else None + route.vpc_pcx = ( + self.get_vpc_peering_connection(vpc_peering_connection_id) + if vpc_peering_connection_id + else None + ) route_table.routes[route.id] = route return route @@ -2977,7 +3368,9 @@ class InternetGateway(TaggedEC2Resource): self.vpc = None @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): ec2_backend = ec2_backends[region_name] return ec2_backend.create_internet_gateway() @@ -3052,16 +3445,18 @@ class VPCGatewayAttachment(BaseModel): self.vpc_id = vpc_id @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ec2_backend = ec2_backends[region_name] attachment = ec2_backend.create_vpc_gateway_attachment( - gateway_id=properties['InternetGatewayId'], - vpc_id=properties['VpcId'], + gateway_id=properties["InternetGatewayId"], vpc_id=properties["VpcId"] ) ec2_backend.attach_internet_gateway( - properties['InternetGatewayId'], properties['VpcId']) + properties["InternetGatewayId"], properties["VpcId"] + ) return attachment @property @@ -3081,11 +3476,30 @@ class VPCGatewayAttachmentBackend(object): class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource): - def __init__(self, ec2_backend, spot_request_id, price, image_id, type, - valid_from, valid_until, launch_group, availability_zone_group, - key_name, security_groups, user_data, instance_type, placement, - kernel_id, ramdisk_id, monitoring_enabled, subnet_id, tags, spot_fleet_id, - **kwargs): + def __init__( + self, + ec2_backend, + spot_request_id, + price, + image_id, + type, + valid_from, + valid_until, + launch_group, + availability_zone_group, + key_name, + security_groups, + user_data, + instance_type, + placement, + kernel_id, + ramdisk_id, + monitoring_enabled, + subnet_id, + tags, + spot_fleet_id, + **kwargs + ): super(SpotInstanceRequest, self).__init__(**kwargs) ls = LaunchSpecification() self.ec2_backend = ec2_backend @@ -3112,30 +3526,31 @@ class SpotInstanceRequest(BotoSpotRequest, TaggedEC2Resource): if security_groups: for group_name in security_groups: - group = self.ec2_backend.get_security_group_from_name( - group_name) + group = self.ec2_backend.get_security_group_from_name(group_name) if group: ls.groups.append(group) else: # If not security groups, add the default - default_group = self.ec2_backend.get_security_group_from_name( - "default") + default_group = self.ec2_backend.get_security_group_from_name("default") ls.groups.append(default_group) self.instance = self.launch_instance() def get_filter_value(self, filter_name): - if filter_name == 'state': + if filter_name == "state": return self.state - elif filter_name == 'spot-instance-request-id': + elif filter_name == "spot-instance-request-id": return self.id else: return super(SpotInstanceRequest, self).get_filter_value( - filter_name, 'DescribeSpotInstanceRequests') + filter_name, "DescribeSpotInstanceRequests" + ) def launch_instance(self): reservation = self.ec2_backend.add_instances( - image_id=self.launch_specification.image_id, count=1, user_data=self.user_data, + image_id=self.launch_specification.image_id, + count=1, + user_data=self.user_data, instance_type=self.launch_specification.instance_type, subnet_id=self.launch_specification.subnet_id, key_name=self.launch_specification.key_name, @@ -3154,25 +3569,59 @@ class SpotRequestBackend(object): self.spot_instance_requests = {} super(SpotRequestBackend, self).__init__() - def request_spot_instances(self, price, image_id, count, type, valid_from, - valid_until, launch_group, availability_zone_group, - key_name, security_groups, user_data, - instance_type, placement, kernel_id, ramdisk_id, - monitoring_enabled, subnet_id, tags=None, spot_fleet_id=None): + def request_spot_instances( + self, + price, + image_id, + count, + type, + valid_from, + valid_until, + launch_group, + availability_zone_group, + key_name, + security_groups, + user_data, + instance_type, + placement, + kernel_id, + ramdisk_id, + monitoring_enabled, + subnet_id, + tags=None, + spot_fleet_id=None, + ): requests = [] tags = tags or {} for _ in range(count): spot_request_id = random_spot_request_id() - request = SpotInstanceRequest(self, - spot_request_id, price, image_id, type, valid_from, valid_until, - launch_group, availability_zone_group, key_name, security_groups, - user_data, instance_type, placement, kernel_id, ramdisk_id, - monitoring_enabled, subnet_id, tags, spot_fleet_id) + request = SpotInstanceRequest( + self, + spot_request_id, + price, + image_id, + type, + valid_from, + valid_until, + launch_group, + availability_zone_group, + key_name, + security_groups, + user_data, + instance_type, + placement, + kernel_id, + ramdisk_id, + monitoring_enabled, + subnet_id, + tags, + spot_fleet_id, + ) self.spot_instance_requests[spot_request_id] = request requests.append(request) return requests - @Model.prop('SpotInstanceRequest') + @Model.prop("SpotInstanceRequest") def describe_spot_instance_requests(self, filters=None): requests = self.spot_instance_requests.values() @@ -3186,9 +3635,21 @@ class SpotRequestBackend(object): class SpotFleetLaunchSpec(object): - def __init__(self, ebs_optimized, group_set, iam_instance_profile, image_id, - instance_type, key_name, monitoring, spot_price, subnet_id, tag_specifications, - user_data, weighted_capacity): + def __init__( + self, + ebs_optimized, + group_set, + iam_instance_profile, + image_id, + instance_type, + key_name, + monitoring, + spot_price, + subnet_id, + tag_specifications, + user_data, + weighted_capacity, + ): self.ebs_optimized = ebs_optimized self.group_set = group_set self.iam_instance_profile = iam_instance_profile @@ -3204,8 +3665,16 @@ class SpotFleetLaunchSpec(object): class SpotFleetRequest(TaggedEC2Resource): - def __init__(self, ec2_backend, spot_fleet_request_id, spot_price, - target_capacity, iam_fleet_role, allocation_strategy, launch_specs): + def __init__( + self, + ec2_backend, + spot_fleet_request_id, + spot_price, + target_capacity, + iam_fleet_role, + allocation_strategy, + launch_specs, + ): self.ec2_backend = ec2_backend self.id = spot_fleet_request_id @@ -3218,21 +3687,23 @@ class SpotFleetRequest(TaggedEC2Resource): self.launch_specs = [] for spec in launch_specs: - self.launch_specs.append(SpotFleetLaunchSpec( - ebs_optimized=spec['ebs_optimized'], - group_set=[val for key, val in spec.items( - ) if key.startswith("group_set")], - iam_instance_profile=spec.get('iam_instance_profile._arn'), - image_id=spec['image_id'], - instance_type=spec['instance_type'], - key_name=spec.get('key_name'), - monitoring=spec.get('monitoring._enabled'), - spot_price=spec.get('spot_price', self.spot_price), - subnet_id=spec['subnet_id'], - tag_specifications=self._parse_tag_specifications(spec), - user_data=spec.get('user_data'), - weighted_capacity=spec['weighted_capacity'], - ) + self.launch_specs.append( + SpotFleetLaunchSpec( + ebs_optimized=spec["ebs_optimized"], + group_set=[ + val for key, val in spec.items() if key.startswith("group_set") + ], + iam_instance_profile=spec.get("iam_instance_profile._arn"), + image_id=spec["image_id"], + instance_type=spec["instance_type"], + key_name=spec.get("key_name"), + monitoring=spec.get("monitoring._enabled"), + spot_price=spec.get("spot_price", self.spot_price), + subnet_id=spec["subnet_id"], + tag_specifications=self._parse_tag_specifications(spec), + user_data=spec.get("user_data"), + weighted_capacity=spec["weighted_capacity"], + ) ) self.spot_requests = [] @@ -3243,26 +3714,34 @@ class SpotFleetRequest(TaggedEC2Resource): return self.id @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json[ - 'Properties']['SpotFleetRequestConfigData'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"]["SpotFleetRequestConfigData"] ec2_backend = ec2_backends[region_name] - spot_price = properties.get('SpotPrice') - target_capacity = properties['TargetCapacity'] - iam_fleet_role = properties['IamFleetRole'] - allocation_strategy = properties['AllocationStrategy'] + spot_price = properties.get("SpotPrice") + target_capacity = properties["TargetCapacity"] + iam_fleet_role = properties["IamFleetRole"] + allocation_strategy = properties["AllocationStrategy"] launch_specs = properties["LaunchSpecifications"] launch_specs = [ - dict([(camelcase_to_underscores(key), val) - for key, val in launch_spec.items()]) - for launch_spec - in launch_specs + dict( + [ + (camelcase_to_underscores(key), val) + for key, val in launch_spec.items() + ] + ) + for launch_spec in launch_specs ] - spot_fleet_request = ec2_backend.request_spot_fleet(spot_price, - target_capacity, iam_fleet_role, allocation_strategy, - launch_specs) + spot_fleet_request = ec2_backend.request_spot_fleet( + spot_price, + target_capacity, + iam_fleet_role, + allocation_strategy, + launch_specs, + ) return spot_fleet_request @@ -3270,11 +3749,12 @@ class SpotFleetRequest(TaggedEC2Resource): weight_map = defaultdict(int) weight_so_far = 0 - if self.allocation_strategy == 'diversified': + if self.allocation_strategy == "diversified": launch_spec_index = 0 while True: launch_spec = self.launch_specs[ - launch_spec_index % len(self.launch_specs)] + launch_spec_index % len(self.launch_specs) + ] weight_map[launch_spec] += 1 weight_so_far += launch_spec.weighted_capacity if weight_so_far >= weight_to_add: @@ -3283,10 +3763,15 @@ class SpotFleetRequest(TaggedEC2Resource): else: # lowestPrice cheapest_spec = sorted( # FIXME: change `+inf` to the on demand price scaled to weighted capacity when it's not present - self.launch_specs, key=lambda spec: float(spec.spot_price or '+inf'))[0] - weight_so_far = weight_to_add + (weight_to_add % cheapest_spec.weighted_capacity) + self.launch_specs, + key=lambda spec: float(spec.spot_price or "+inf"), + )[0] + weight_so_far = weight_to_add + ( + weight_to_add % cheapest_spec.weighted_capacity + ) weight_map[cheapest_spec] = int( - weight_so_far // cheapest_spec.weighted_capacity) + weight_so_far // cheapest_spec.weighted_capacity + ) return weight_map, weight_so_far @@ -3324,7 +3809,10 @@ class SpotFleetRequest(TaggedEC2Resource): for req in self.spot_requests: instance = req.instance for spec in self.launch_specs: - if spec.instance_type == instance.instance_type and spec.subnet_id == instance.subnet_id: + if ( + spec.instance_type == instance.instance_type + and spec.subnet_id == instance.subnet_id + ): break if new_fulfilled_capacity - spec.weighted_capacity < self.target_capacity: @@ -3332,25 +3820,48 @@ class SpotFleetRequest(TaggedEC2Resource): new_fulfilled_capacity -= spec.weighted_capacity instance_ids.append(instance.id) - self.spot_requests = [req for req in self.spot_requests if req.instance.id not in instance_ids] + self.spot_requests = [ + req for req in self.spot_requests if req.instance.id not in instance_ids + ] self.ec2_backend.terminate_instances(instance_ids) def _parse_tag_specifications(self, spec): try: - tag_spec_num = max([int(key.split('.')[1]) for key in spec if key.startswith("tag_specification_set")]) + tag_spec_num = max( + [ + int(key.split(".")[1]) + for key in spec + if key.startswith("tag_specification_set") + ] + ) except ValueError: # no tag specifications return {} tag_specifications = {} for si in range(1, tag_spec_num + 1): - resource_type = spec["tag_specification_set.{si}._resource_type".format(si=si)] + resource_type = spec[ + "tag_specification_set.{si}._resource_type".format(si=si) + ] - tags = [key for key in spec if key.startswith("tag_specification_set.{si}._tag".format(si=si))] - tag_num = max([int(key.split('.')[3]) for key in tags]) - tag_specifications[resource_type] = dict(( - spec["tag_specification_set.{si}._tag.{ti}._key".format(si=si, ti=ti)], - spec["tag_specification_set.{si}._tag.{ti}._value".format(si=si, ti=ti)], - ) for ti in range(1, tag_num + 1)) + tags = [ + key + for key in spec + if key.startswith("tag_specification_set.{si}._tag".format(si=si)) + ] + tag_num = max([int(key.split(".")[3]) for key in tags]) + tag_specifications[resource_type] = dict( + ( + spec[ + "tag_specification_set.{si}._tag.{ti}._key".format(si=si, ti=ti) + ], + spec[ + "tag_specification_set.{si}._tag.{ti}._value".format( + si=si, ti=ti + ) + ], + ) + for ti in range(1, tag_num + 1) + ) return tag_specifications @@ -3360,12 +3871,25 @@ class SpotFleetBackend(object): self.spot_fleet_requests = {} super(SpotFleetBackend, self).__init__() - def request_spot_fleet(self, spot_price, target_capacity, iam_fleet_role, - allocation_strategy, launch_specs): + def request_spot_fleet( + self, + spot_price, + target_capacity, + iam_fleet_role, + allocation_strategy, + launch_specs, + ): spot_fleet_request_id = random_spot_fleet_request_id() - request = SpotFleetRequest(self, spot_fleet_request_id, spot_price, - target_capacity, iam_fleet_role, allocation_strategy, launch_specs) + request = SpotFleetRequest( + self, + spot_fleet_request_id, + spot_price, + target_capacity, + iam_fleet_role, + allocation_strategy, + launch_specs, + ) self.spot_fleet_requests[spot_fleet_request_id] = request return request @@ -3381,7 +3905,8 @@ class SpotFleetBackend(object): if spot_fleet_request_ids: requests = [ - request for request in requests if request.id in spot_fleet_request_ids] + request for request in requests if request.id in spot_fleet_request_ids + ] return requests @@ -3396,15 +3921,17 @@ class SpotFleetBackend(object): del self.spot_fleet_requests[spot_fleet_request_id] return spot_requests - def modify_spot_fleet_request(self, spot_fleet_request_id, target_capacity, terminate_instances): + def modify_spot_fleet_request( + self, spot_fleet_request_id, target_capacity, terminate_instances + ): if target_capacity < 0: - raise ValueError('Cannot reduce spot fleet capacity below 0') + raise ValueError("Cannot reduce spot fleet capacity below 0") spot_fleet_request = self.spot_fleet_requests[spot_fleet_request_id] delta = target_capacity - spot_fleet_request.fulfilled_capacity spot_fleet_request.target_capacity = target_capacity if delta > 0: spot_fleet_request.create_spot_requests(delta) - elif delta < 0 and terminate_instances == 'Default': + elif delta < 0 and terminate_instances == "Default": spot_fleet_request.terminate_instances() return True @@ -3422,18 +3949,19 @@ class ElasticAddress(object): self.association_id = None @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): ec2_backend = ec2_backends[region_name] - properties = cloudformation_json.get('Properties') + properties = cloudformation_json.get("Properties") instance_id = None if properties: - domain = properties.get('Domain') - eip = ec2_backend.allocate_address( - domain=domain if domain else 'standard') - instance_id = properties.get('InstanceId') + domain = properties.get("Domain") + eip = ec2_backend.allocate_address(domain=domain if domain else "standard") + instance_id = properties.get("InstanceId") else: - eip = ec2_backend.allocate_address(domain='standard') + eip = ec2_backend.allocate_address(domain="standard") if instance_id: instance = ec2_backend.get_instance_by_id(instance_id) @@ -3447,28 +3975,29 @@ class ElasticAddress(object): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'AllocationId': + + if attribute_name == "AllocationId": return self.allocation_id raise UnformattedGetAttTemplateException() def get_filter_value(self, filter_name): - if filter_name == 'allocation-id': + if filter_name == "allocation-id": return self.allocation_id - elif filter_name == 'association-id': + elif filter_name == "association-id": return self.association_id - elif filter_name == 'domain': + elif filter_name == "domain": return self.domain - elif filter_name == 'instance-id' and self.instance: + elif filter_name == "instance-id" and self.instance: return self.instance.id - elif filter_name == 'network-interface-id' and self.eni: + elif filter_name == "network-interface-id" and self.eni: return self.eni.id - elif filter_name == 'private-ip-address' and self.eni: + elif filter_name == "private-ip-address" and self.eni: return self.eni.private_ip_address - elif filter_name == 'public-ip': + elif filter_name == "public-ip": return self.public_ip else: # TODO: implement network-interface-owner-id - raise FilterNotImplementedError(filter_name, 'DescribeAddresses') + raise FilterNotImplementedError(filter_name, "DescribeAddresses") class ElasticAddressBackend(object): @@ -3477,7 +4006,7 @@ class ElasticAddressBackend(object): super(ElasticAddressBackend, self).__init__() def allocate_address(self, domain, address=None): - if domain not in ['standard', 'vpc']: + if domain not in ["standard", "vpc"]: raise InvalidDomainError(domain) if address: address = ElasticAddress(domain, address) @@ -3487,8 +4016,7 @@ class ElasticAddressBackend(object): return address def address_by_ip(self, ips): - eips = [address for address in self.addresses - if address.public_ip in ips] + eips = [address for address in self.addresses if address.public_ip in ips] # TODO: Trim error message down to specific invalid address. if not eips or len(ips) > len(eips): @@ -3497,8 +4025,11 @@ class ElasticAddressBackend(object): return eips def address_by_allocation(self, allocation_ids): - eips = [address for address in self.addresses - if address.allocation_id in allocation_ids] + eips = [ + address + for address in self.addresses + if address.allocation_id in allocation_ids + ] # TODO: Trim error message down to specific invalid id. if not eips or len(allocation_ids) > len(eips): @@ -3507,8 +4038,11 @@ class ElasticAddressBackend(object): return eips def address_by_association(self, association_ids): - eips = [address for address in self.addresses - if address.association_id in association_ids] + eips = [ + address + for address in self.addresses + if address.association_id in association_ids + ] # TODO: Trim error message down to specific invalid id. if not eips or len(association_ids) > len(eips): @@ -3516,7 +4050,14 @@ class ElasticAddressBackend(object): return eips - def associate_address(self, instance=None, eni=None, address=None, allocation_id=None, reassociate=False): + def associate_address( + self, + instance=None, + eni=None, + address=None, + allocation_id=None, + reassociate=False, + ): eips = [] if address: eips = self.address_by_ip([address]) @@ -3524,10 +4065,10 @@ class ElasticAddressBackend(object): eips = self.address_by_allocation([allocation_id]) eip = eips[0] - new_instance_association = bool(instance and ( - not eip.instance or eip.instance.id == instance.id)) - new_eni_association = bool( - eni and (not eip.eni or eni.id == eip.eni.id)) + new_instance_association = bool( + instance and (not eip.instance or eip.instance.id == instance.id) + ) + new_eni_association = bool(eni and (not eip.eni or eni.id == eip.eni.id)) if new_instance_association or new_eni_association or reassociate: eip.instance = instance @@ -3547,14 +4088,12 @@ class ElasticAddressBackend(object): def describe_addresses(self, allocation_ids=None, public_ips=None, filters=None): matches = self.addresses if allocation_ids: - matches = [addr for addr in matches - if addr.allocation_id in allocation_ids] + matches = [addr for addr in matches if addr.allocation_id in allocation_ids] if len(allocation_ids) > len(matches): unknown_ids = set(allocation_ids) - set(matches) raise InvalidAllocationIdError(unknown_ids) if public_ips: - matches = [addr for addr in matches - if addr.public_ip in public_ips] + matches = [addr for addr in matches if addr.public_ip in public_ips] if len(public_ips) > len(matches): unknown_ips = set(allocation_ids) - set(matches) raise InvalidAddressError(unknown_ips) @@ -3596,9 +4135,15 @@ class ElasticAddressBackend(object): class DHCPOptionsSet(TaggedEC2Resource): - def __init__(self, ec2_backend, domain_name_servers=None, domain_name=None, - ntp_servers=None, netbios_name_servers=None, - netbios_node_type=None): + def __init__( + self, + ec2_backend, + domain_name_servers=None, + domain_name=None, + ntp_servers=None, + netbios_name_servers=None, + netbios_node_type=None, + ): self.ec2_backend = ec2_backend self._options = { "domain-name-servers": domain_name_servers, @@ -3623,16 +4168,17 @@ class DHCPOptionsSet(TaggedEC2Resource): Taken from: http://docs.aws.amazon.com/AWSEC2/latest/APIReference/API_DescribeDhcpOptions.html """ - if filter_name == 'dhcp-options-id': + if filter_name == "dhcp-options-id": return self.id - elif filter_name == 'key': + elif filter_name == "key": return list(self._options.keys()) - elif filter_name == 'value': + elif filter_name == "value": values = [item for item in list(self._options.values()) if item] return itertools.chain(*values) else: return super(DHCPOptionsSet, self).get_filter_value( - filter_name, 'DescribeDhcpOptions') + filter_name, "DescribeDhcpOptions" + ) @property def options(self): @@ -3649,9 +4195,13 @@ class DHCPOptionsSetBackend(object): vpc.dhcp_options = dhcp_options def create_dhcp_options( - self, domain_name_servers=None, domain_name=None, - ntp_servers=None, netbios_name_servers=None, - netbios_node_type=None): + self, + domain_name_servers=None, + domain_name=None, + ntp_servers=None, + netbios_name_servers=None, + netbios_node_type=None, + ): NETBIOS_NODE_TYPES = [1, 2, 4, 8] @@ -3663,8 +4213,12 @@ class DHCPOptionsSetBackend(object): raise InvalidParameterValueError(netbios_node_type) options = DHCPOptionsSet( - self, domain_name_servers, domain_name, ntp_servers, - netbios_name_servers, netbios_node_type + self, + domain_name_servers, + domain_name, + ntp_servers, + netbios_name_servers, + netbios_node_type, ) self.dhcp_options_sets[options.id] = options return options @@ -3679,13 +4233,12 @@ class DHCPOptionsSetBackend(object): return options_sets or self.dhcp_options_sets.values() def delete_dhcp_options_set(self, options_id): - if not (options_id and options_id.startswith('dopt-')): + if not (options_id and options_id.startswith("dopt-")): raise MalformedDHCPOptionsIdError(options_id) if options_id in self.dhcp_options_sets: if self.dhcp_options_sets[options_id].vpc: - raise DependencyViolationError( - "Cannot delete assigned DHCP options.") + raise DependencyViolationError("Cannot delete assigned DHCP options.") self.dhcp_options_sets.pop(options_id) else: raise InvalidDHCPOptionsIdError(options_id) @@ -3696,21 +4249,31 @@ class DHCPOptionsSetBackend(object): if dhcp_options_ids: dhcp_options_sets = [ - dhcp_options_set for dhcp_options_set in dhcp_options_sets if dhcp_options_set.id in dhcp_options_ids] + dhcp_options_set + for dhcp_options_set in dhcp_options_sets + if dhcp_options_set.id in dhcp_options_ids + ] if len(dhcp_options_sets) != len(dhcp_options_ids): - invalid_id = list(set(dhcp_options_ids).difference( - set([dhcp_options_set.id for dhcp_options_set in dhcp_options_sets])))[0] + invalid_id = list( + set(dhcp_options_ids).difference( + set( + [ + dhcp_options_set.id + for dhcp_options_set in dhcp_options_sets + ] + ) + ) + )[0] raise InvalidDHCPOptionsIdError(invalid_id) return generic_filter(filters, dhcp_options_sets) class VPNConnection(TaggedEC2Resource): - def __init__(self, ec2_backend, id, type, - customer_gateway_id, vpn_gateway_id): + def __init__(self, ec2_backend, id, type, customer_gateway_id, vpn_gateway_id): self.ec2_backend = ec2_backend self.id = id - self.state = 'available' + self.state = "available" self.customer_gateway_configuration = {} self.type = type self.customer_gateway_id = customer_gateway_id @@ -3720,8 +4283,9 @@ class VPNConnection(TaggedEC2Resource): self.static_routes = None def get_filter_value(self, filter_name): - return super(VPNConnection, self).get_filter_value( - filter_name, 'DescribeVpnConnections') + return super(VPNConnection, self).get_filter_value( + filter_name, "DescribeVpnConnections" + ) class VPNConnectionBackend(object): @@ -3729,16 +4293,18 @@ class VPNConnectionBackend(object): self.vpn_connections = {} super(VPNConnectionBackend, self).__init__() - def create_vpn_connection(self, type, customer_gateway_id, - vpn_gateway_id, - static_routes_only=None): + def create_vpn_connection( + self, type, customer_gateway_id, vpn_gateway_id, static_routes_only=None + ): vpn_connection_id = random_vpn_connection_id() if static_routes_only: pass vpn_connection = VPNConnection( - self, id=vpn_connection_id, type=type, + self, + id=vpn_connection_id, + type=type, customer_gateway_id=customer_gateway_id, - vpn_gateway_id=vpn_gateway_id + vpn_gateway_id=vpn_gateway_id, ) self.vpn_connections[vpn_connection.id] = vpn_connection return vpn_connection @@ -3764,11 +4330,17 @@ class VPNConnectionBackend(object): vpn_connections = self.vpn_connections.values() if vpn_connection_ids: - vpn_connections = [vpn_connection for vpn_connection in vpn_connections - if vpn_connection.id in vpn_connection_ids] + vpn_connections = [ + vpn_connection + for vpn_connection in vpn_connections + if vpn_connection.id in vpn_connection_ids + ] if len(vpn_connections) != len(vpn_connection_ids): - invalid_id = list(set(vpn_connection_ids).difference( - set([vpn_connection.id for vpn_connection in vpn_connections])))[0] + invalid_id = list( + set(vpn_connection_ids).difference( + set([vpn_connection.id for vpn_connection in vpn_connections]) + ) + )[0] raise InvalidVpnConnectionIdError(invalid_id) return generic_filter(filters, vpn_connections) @@ -3796,25 +4368,40 @@ class NetworkAclBackend(object): def add_default_entries(self, network_acl_id): default_acl_entries = [ - {'rule_number': "100", 'rule_action': 'allow', 'egress': 'true'}, - {'rule_number': "32767", 'rule_action': 'deny', 'egress': 'true'}, - {'rule_number': "100", 'rule_action': 'allow', 'egress': 'false'}, - {'rule_number': "32767", 'rule_action': 'deny', 'egress': 'false'} + {"rule_number": "100", "rule_action": "allow", "egress": "true"}, + {"rule_number": "32767", "rule_action": "deny", "egress": "true"}, + {"rule_number": "100", "rule_action": "allow", "egress": "false"}, + {"rule_number": "32767", "rule_action": "deny", "egress": "false"}, ] for entry in default_acl_entries: - self.create_network_acl_entry(network_acl_id=network_acl_id, rule_number=entry['rule_number'], protocol='-1', - rule_action=entry['rule_action'], egress=entry['egress'], cidr_block='0.0.0.0/0', - icmp_code=None, icmp_type=None, port_range_from=None, port_range_to=None) + self.create_network_acl_entry( + network_acl_id=network_acl_id, + rule_number=entry["rule_number"], + protocol="-1", + rule_action=entry["rule_action"], + egress=entry["egress"], + cidr_block="0.0.0.0/0", + icmp_code=None, + icmp_type=None, + port_range_from=None, + port_range_to=None, + ) def get_all_network_acls(self, network_acl_ids=None, filters=None): network_acls = self.network_acls.values() if network_acl_ids: - network_acls = [network_acl for network_acl in network_acls - if network_acl.id in network_acl_ids] + network_acls = [ + network_acl + for network_acl in network_acls + if network_acl.id in network_acl_ids + ] if len(network_acls) != len(network_acl_ids): - invalid_id = list(set(network_acl_ids).difference( - set([network_acl.id for network_acl in network_acls])))[0] + invalid_id = list( + set(network_acl_ids).difference( + set([network_acl.id for network_acl in network_acls]) + ) + )[0] raise InvalidRouteTableIdError(invalid_id) return generic_filter(filters, network_acls) @@ -3825,46 +4412,91 @@ class NetworkAclBackend(object): raise InvalidNetworkAclIdError(network_acl_id) return deleted - def create_network_acl_entry(self, network_acl_id, rule_number, - protocol, rule_action, egress, cidr_block, - icmp_code, icmp_type, port_range_from, - port_range_to): + def create_network_acl_entry( + self, + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ): network_acl = self.get_network_acl(network_acl_id) - if any(entry.egress == egress and entry.rule_number == rule_number for entry in network_acl.network_acl_entries): + if any( + entry.egress == egress and entry.rule_number == rule_number + for entry in network_acl.network_acl_entries + ): raise NetworkAclEntryAlreadyExistsError(rule_number) - network_acl_entry = NetworkAclEntry(self, network_acl_id, rule_number, - protocol, rule_action, egress, - cidr_block, icmp_code, icmp_type, - port_range_from, port_range_to) + network_acl_entry = NetworkAclEntry( + self, + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ) network_acl.network_acl_entries.append(network_acl_entry) return network_acl_entry def delete_network_acl_entry(self, network_acl_id, rule_number, egress): network_acl = self.get_network_acl(network_acl_id) - entry = next(entry for entry in network_acl.network_acl_entries - if entry.egress == egress and entry.rule_number == rule_number) + entry = next( + entry + for entry in network_acl.network_acl_entries + if entry.egress == egress and entry.rule_number == rule_number + ) if entry is not None: network_acl.network_acl_entries.remove(entry) return entry - def replace_network_acl_entry(self, network_acl_id, rule_number, protocol, rule_action, egress, - cidr_block, icmp_code, icmp_type, port_range_from, port_range_to): + def replace_network_acl_entry( + self, + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ): self.delete_network_acl_entry(network_acl_id, rule_number, egress) - network_acl_entry = self.create_network_acl_entry(network_acl_id, rule_number, - protocol, rule_action, egress, - cidr_block, icmp_code, icmp_type, - port_range_from, port_range_to) + network_acl_entry = self.create_network_acl_entry( + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ) return network_acl_entry - def replace_network_acl_association(self, association_id, - network_acl_id): + def replace_network_acl_association(self, association_id, network_acl_id): # lookup existing association for subnet and delete it - default_acl = next(value for key, value in self.network_acls.items() - if association_id in value.associations.keys()) + default_acl = next( + value + for key, value in self.network_acls.items() + if association_id in value.associations.keys() + ) subnet_id = None for key, value in default_acl.associations.items(): @@ -3874,24 +4506,27 @@ class NetworkAclBackend(object): break new_assoc_id = random_network_acl_subnet_association_id() - association = NetworkAclAssociation(self, - new_assoc_id, - subnet_id, - network_acl_id) + association = NetworkAclAssociation( + self, new_assoc_id, subnet_id, network_acl_id + ) new_acl = self.get_network_acl(network_acl_id) new_acl.associations[new_assoc_id] = association return association def associate_default_network_acl_with_subnet(self, subnet_id, vpc_id): association_id = random_network_acl_subnet_association_id() - acl = next(acl for acl in self.network_acls.values() if acl.default and acl.vpc_id == vpc_id) - acl.associations[association_id] = NetworkAclAssociation(self, association_id, - subnet_id, acl.id) + acl = next( + acl + for acl in self.network_acls.values() + if acl.default and acl.vpc_id == vpc_id + ) + acl.associations[association_id] = NetworkAclAssociation( + self, association_id, subnet_id, acl.id + ) class NetworkAclAssociation(object): - def __init__(self, ec2_backend, new_association_id, - subnet_id, network_acl_id): + def __init__(self, ec2_backend, new_association_id, subnet_id, network_acl_id): self.ec2_backend = ec2_backend self.id = new_association_id self.new_association_id = new_association_id @@ -3907,7 +4542,7 @@ class NetworkAcl(TaggedEC2Resource): self.vpc_id = vpc_id self.network_acl_entries = [] self.associations = {} - self.default = 'true' if default is True else 'false' + self.default = "true" if default is True else "false" def get_filter_value(self, filter_name): if filter_name == "default": @@ -3920,14 +4555,25 @@ class NetworkAcl(TaggedEC2Resource): return [assoc.subnet_id for assoc in self.associations.values()] else: return super(NetworkAcl, self).get_filter_value( - filter_name, 'DescribeNetworkAcls') + filter_name, "DescribeNetworkAcls" + ) class NetworkAclEntry(TaggedEC2Resource): - def __init__(self, ec2_backend, network_acl_id, rule_number, - protocol, rule_action, egress, cidr_block, - icmp_code, icmp_type, port_range_from, - port_range_to): + def __init__( + self, + ec2_backend, + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ): self.ec2_backend = ec2_backend self.network_acl_id = network_acl_id self.rule_number = rule_number @@ -3950,8 +4596,9 @@ class VpnGateway(TaggedEC2Resource): super(VpnGateway, self).__init__() def get_filter_value(self, filter_name): - return super(VpnGateway, self).get_filter_value( - filter_name, 'DescribeVpnGateways') + return super(VpnGateway, self).get_filter_value( + filter_name, "DescribeVpnGateways" + ) class VpnGatewayAttachment(object): @@ -3966,7 +4613,7 @@ class VpnGatewayBackend(object): self.vpn_gateways = {} super(VpnGatewayBackend, self).__init__() - def create_vpn_gateway(self, type='ipsec.1'): + def create_vpn_gateway(self, type="ipsec.1"): vpn_gateway_id = random_vpn_gateway_id() vpn_gateway = VpnGateway(self, vpn_gateway_id, type) self.vpn_gateways[vpn_gateway_id] = vpn_gateway @@ -3985,7 +4632,7 @@ class VpnGatewayBackend(object): def attach_vpn_gateway(self, vpn_gateway_id, vpc_id): vpn_gateway = self.get_vpn_gateway(vpn_gateway_id) self.get_vpc(vpc_id) - attachment = VpnGatewayAttachment(vpc_id, state='attached') + attachment = VpnGatewayAttachment(vpc_id, state="attached") vpn_gateway.attachments[vpc_id] = attachment return attachment @@ -4015,8 +4662,9 @@ class CustomerGateway(TaggedEC2Resource): super(CustomerGateway, self).__init__() def get_filter_value(self, filter_name): - return super(CustomerGateway, self).get_filter_value( - filter_name, 'DescribeCustomerGateways') + return super(CustomerGateway, self).get_filter_value( + filter_name, "DescribeCustomerGateways" + ) class CustomerGatewayBackend(object): @@ -4024,10 +4672,11 @@ class CustomerGatewayBackend(object): self.customer_gateways = {} super(CustomerGatewayBackend, self).__init__() - def create_customer_gateway(self, type='ipsec.1', ip_address=None, bgp_asn=None): + def create_customer_gateway(self, type="ipsec.1", ip_address=None, bgp_asn=None): customer_gateway_id = random_customer_gateway_id() customer_gateway = CustomerGateway( - self, customer_gateway_id, type, ip_address, bgp_asn) + self, customer_gateway_id, type, ip_address, bgp_asn + ) self.customer_gateways[customer_gateway_id] = customer_gateway return customer_gateway @@ -4036,8 +4685,7 @@ class CustomerGatewayBackend(object): return generic_filter(filters, customer_gateways) def get_customer_gateway(self, customer_gateway_id): - customer_gateway = self.customer_gateways.get( - customer_gateway_id, None) + customer_gateway = self.customer_gateways.get(customer_gateway_id, None) if not customer_gateway: raise InvalidCustomerGatewayIdError(customer_gateway_id) return customer_gateway @@ -4055,7 +4703,7 @@ class NatGateway(object): self.id = random_nat_gateway_id() self.subnet_id = subnet_id self.allocation_id = allocation_id - self.state = 'available' + self.state = "available" self.private_ip = random_private_ip() # protected properties @@ -4063,11 +4711,11 @@ class NatGateway(object): self._backend = backend # NOTE: this is the core of NAT Gateways creation self._eni = self._backend.create_network_interface( - backend.get_subnet(self.subnet_id), self.private_ip) + backend.get_subnet(self.subnet_id), self.private_ip + ) # associate allocation with ENI - self._backend.associate_address( - eni=self._eni, allocation_id=self.allocation_id) + self._backend.associate_address(eni=self._eni, allocation_id=self.allocation_id) @property def vpc_id(self): @@ -4088,11 +4736,13 @@ class NatGateway(object): return eips[0].public_ip @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): ec2_backend = ec2_backends[region_name] nat_gateway = ec2_backend.create_nat_gateway( - cloudformation_json['Properties']['SubnetId'], - cloudformation_json['Properties']['AllocationId'], + cloudformation_json["Properties"]["SubnetId"], + cloudformation_json["Properties"]["AllocationId"], ) return nat_gateway @@ -4157,11 +4807,12 @@ class LaunchTemplate(TaggedEC2Resource): return self.latest_version().number def get_filter_value(self, filter_name): - if filter_name == 'launch-template-name': + if filter_name == "launch-template-name": return self.name else: return super(LaunchTemplate, self).get_filter_value( - filter_name, "DescribeLaunchTemplates") + filter_name, "DescribeLaunchTemplates" + ) class LaunchTemplateBackend(object): @@ -4186,7 +4837,9 @@ class LaunchTemplateBackend(object): def get_launch_template_by_name(self, name): return self.get_launch_template(self.launch_template_name_to_ids[name]) - def get_launch_templates(self, template_names=None, template_ids=None, filters=None): + def get_launch_templates( + self, template_names=None, template_ids=None, filters=None + ): if template_names and not template_ids: template_ids = [] for name in template_names: @@ -4200,16 +4853,35 @@ class LaunchTemplateBackend(object): return generic_filter(filters, templates) -class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend, - RegionsAndZonesBackend, SecurityGroupBackend, AmiBackend, - VPCBackend, SubnetBackend, SubnetRouteTableAssociationBackend, - NetworkInterfaceBackend, VPNConnectionBackend, - VPCPeeringConnectionBackend, - RouteTableBackend, RouteBackend, InternetGatewayBackend, - VPCGatewayAttachmentBackend, SpotFleetBackend, - SpotRequestBackend, ElasticAddressBackend, KeyPairBackend, - DHCPOptionsSetBackend, NetworkAclBackend, VpnGatewayBackend, - CustomerGatewayBackend, NatGatewayBackend, LaunchTemplateBackend): +class EC2Backend( + BaseBackend, + InstanceBackend, + TagBackend, + EBSBackend, + RegionsAndZonesBackend, + SecurityGroupBackend, + AmiBackend, + VPCBackend, + SubnetBackend, + SubnetRouteTableAssociationBackend, + NetworkInterfaceBackend, + VPNConnectionBackend, + VPCPeeringConnectionBackend, + RouteTableBackend, + RouteBackend, + InternetGatewayBackend, + VPCGatewayAttachmentBackend, + SpotFleetBackend, + SpotRequestBackend, + ElasticAddressBackend, + KeyPairBackend, + DHCPOptionsSetBackend, + NetworkAclBackend, + VpnGatewayBackend, + CustomerGatewayBackend, + NatGatewayBackend, + LaunchTemplateBackend, +): def __init__(self, region_name): self.region_name = region_name super(EC2Backend, self).__init__() @@ -4220,20 +4892,20 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend, # docs.aws.amazon.com/AmazonVPC/latest/UserGuide/default-vpc.html # if not self.vpcs: - vpc = self.create_vpc('172.31.0.0/16') + vpc = self.create_vpc("172.31.0.0/16") else: # For now this is included for potential # backward-compatibility issues vpc = self.vpcs.values()[0] # Create default subnet for each availability zone - ip, _ = vpc.cidr_block.split('/') - ip = ip.split('.') + ip, _ = vpc.cidr_block.split("/") + ip = ip.split(".") ip[2] = 0 for zone in self.describe_availability_zones(): az_name = zone.name - cidr_block = '.'.join(str(i) for i in ip) + '/20' + cidr_block = ".".join(str(i) for i in ip) + "/20" self.create_subnet(vpc.id, cidr_block, availability_zone=az_name) ip[2] += 16 @@ -4253,49 +4925,51 @@ class EC2Backend(BaseBackend, InstanceBackend, TagBackend, EBSBackend, def do_resources_exist(self, resource_ids): for resource_id in resource_ids: resource_prefix = get_prefix(resource_id) - if resource_prefix == EC2_RESOURCE_TO_PREFIX['customer-gateway']: + if resource_prefix == EC2_RESOURCE_TO_PREFIX["customer-gateway"]: self.get_customer_gateway(customer_gateway_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['dhcp-options']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["dhcp-options"]: self.describe_dhcp_options(options_ids=[resource_id]) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['image']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["image"]: self.describe_images(ami_ids=[resource_id]) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['instance']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["instance"]: self.get_instance_by_id(instance_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['internet-gateway']: - self.describe_internet_gateways( - internet_gateway_ids=[resource_id]) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['launch-template']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["internet-gateway"]: + self.describe_internet_gateways(internet_gateway_ids=[resource_id]) + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["launch-template"]: self.get_launch_template(resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['network-acl']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["network-acl"]: self.get_all_network_acls() - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['network-interface']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["network-interface"]: self.describe_network_interfaces( - filters={'network-interface-id': resource_id}) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['reserved-instance']: - self.raise_not_implemented_error('DescribeReservedInstances') - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['route-table']: + filters={"network-interface-id": resource_id} + ) + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["reserved-instance"]: + self.raise_not_implemented_error("DescribeReservedInstances") + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["route-table"]: self.get_route_table(route_table_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['security-group']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["security-group"]: self.describe_security_groups(group_ids=[resource_id]) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['snapshot']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["snapshot"]: self.get_snapshot(snapshot_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['spot-instance-request']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["spot-instance-request"]: self.describe_spot_instance_requests( - filters={'spot-instance-request-id': resource_id}) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['subnet']: + filters={"spot-instance-request-id": resource_id} + ) + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["subnet"]: self.get_subnet(subnet_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['volume']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["volume"]: self.get_volume(volume_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['vpc']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["vpc"]: self.get_vpc(vpc_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['vpc-peering-connection']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["vpc-peering-connection"]: self.get_vpc_peering_connection(vpc_pcx_id=resource_id) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['vpn-connection']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["vpn-connection"]: self.describe_vpn_connections(vpn_connection_ids=[resource_id]) - elif resource_prefix == EC2_RESOURCE_TO_PREFIX['vpn-gateway']: + elif resource_prefix == EC2_RESOURCE_TO_PREFIX["vpn-gateway"]: self.get_vpn_gateway(vpn_gateway_id=resource_id) return True -ec2_backends = {region.name: EC2Backend(region.name) - for region in RegionsAndZonesBackend.regions} +ec2_backends = { + region.name: EC2Backend(region.name) for region in RegionsAndZonesBackend.regions +} diff --git a/moto/ec2/responses/__init__.py b/moto/ec2/responses/__init__.py index d0648eb50..21cbf8249 100644 --- a/moto/ec2/responses/__init__.py +++ b/moto/ec2/responses/__init__.py @@ -70,10 +70,10 @@ class EC2Response( Windows, NatGateways, ): - @property def ec2_backend(self): from moto.ec2.models import ec2_backends + return ec2_backends[self.region] @property diff --git a/moto/ec2/responses/account_attributes.py b/moto/ec2/responses/account_attributes.py index 8a5b9a4b0..068a7c395 100644 --- a/moto/ec2/responses/account_attributes.py +++ b/moto/ec2/responses/account_attributes.py @@ -3,13 +3,12 @@ from moto.core.responses import BaseResponse class AccountAttributes(BaseResponse): - def describe_account_attributes(self): template = self.response_template(DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT) return template.render() -DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT = u""" +DESCRIBE_ACCOUNT_ATTRIBUTES_RESULT = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE diff --git a/moto/ec2/responses/amazon_dev_pay.py b/moto/ec2/responses/amazon_dev_pay.py index 14df3f004..982b7f4a3 100644 --- a/moto/ec2/responses/amazon_dev_pay.py +++ b/moto/ec2/responses/amazon_dev_pay.py @@ -3,7 +3,7 @@ from moto.core.responses import BaseResponse class AmazonDevPay(BaseResponse): - def confirm_product_instance(self): raise NotImplementedError( - 'AmazonDevPay.confirm_product_instance is not yet implemented') + "AmazonDevPay.confirm_product_instance is not yet implemented" + ) diff --git a/moto/ec2/responses/amis.py b/moto/ec2/responses/amis.py index 17e1e228d..6736a7175 100755 --- a/moto/ec2/responses/amis.py +++ b/moto/ec2/responses/amis.py @@ -4,76 +4,83 @@ from moto.ec2.utils import filters_from_querystring class AmisResponse(BaseResponse): - def create_image(self): - name = self.querystring.get('Name')[0] - description = self._get_param('Description', if_none='') - instance_id = self._get_param('InstanceId') - if self.is_not_dryrun('CreateImage'): + name = self.querystring.get("Name")[0] + description = self._get_param("Description", if_none="") + instance_id = self._get_param("InstanceId") + if self.is_not_dryrun("CreateImage"): image = self.ec2_backend.create_image( - instance_id, name, description, context=self) + instance_id, name, description, context=self + ) template = self.response_template(CREATE_IMAGE_RESPONSE) return template.render(image=image) def copy_image(self): - source_image_id = self._get_param('SourceImageId') - source_region = self._get_param('SourceRegion') - name = self._get_param('Name') - description = self._get_param('Description') - if self.is_not_dryrun('CopyImage'): + source_image_id = self._get_param("SourceImageId") + source_region = self._get_param("SourceRegion") + name = self._get_param("Name") + description = self._get_param("Description") + if self.is_not_dryrun("CopyImage"): image = self.ec2_backend.copy_image( - source_image_id, source_region, name, description) + source_image_id, source_region, name, description + ) template = self.response_template(COPY_IMAGE_RESPONSE) return template.render(image=image) def deregister_image(self): - ami_id = self._get_param('ImageId') - if self.is_not_dryrun('DeregisterImage'): + ami_id = self._get_param("ImageId") + if self.is_not_dryrun("DeregisterImage"): success = self.ec2_backend.deregister_image(ami_id) template = self.response_template(DEREGISTER_IMAGE_RESPONSE) return template.render(success=str(success).lower()) def describe_images(self): - ami_ids = self._get_multi_param('ImageId') + ami_ids = self._get_multi_param("ImageId") filters = filters_from_querystring(self.querystring) - owners = self._get_multi_param('Owner') - exec_users = self._get_multi_param('ExecutableBy') + owners = self._get_multi_param("Owner") + exec_users = self._get_multi_param("ExecutableBy") images = self.ec2_backend.describe_images( - ami_ids=ami_ids, filters=filters, exec_users=exec_users, - owners=owners, context=self) + ami_ids=ami_ids, + filters=filters, + exec_users=exec_users, + owners=owners, + context=self, + ) template = self.response_template(DESCRIBE_IMAGES_RESPONSE) return template.render(images=images) def describe_image_attribute(self): - ami_id = self._get_param('ImageId') + ami_id = self._get_param("ImageId") groups = self.ec2_backend.get_launch_permission_groups(ami_id) users = self.ec2_backend.get_launch_permission_users(ami_id) template = self.response_template(DESCRIBE_IMAGE_ATTRIBUTES_RESPONSE) return template.render(ami_id=ami_id, groups=groups, users=users) def modify_image_attribute(self): - ami_id = self._get_param('ImageId') - operation_type = self._get_param('OperationType') - group = self._get_param('UserGroup.1') - user_ids = self._get_multi_param('UserId') - if self.is_not_dryrun('ModifyImageAttribute'): - if (operation_type == 'add'): + ami_id = self._get_param("ImageId") + operation_type = self._get_param("OperationType") + group = self._get_param("UserGroup.1") + user_ids = self._get_multi_param("UserId") + if self.is_not_dryrun("ModifyImageAttribute"): + if operation_type == "add": self.ec2_backend.add_launch_permission( - ami_id, user_ids=user_ids, group=group) - elif (operation_type == 'remove'): + ami_id, user_ids=user_ids, group=group + ) + elif operation_type == "remove": self.ec2_backend.remove_launch_permission( - ami_id, user_ids=user_ids, group=group) + ami_id, user_ids=user_ids, group=group + ) return MODIFY_IMAGE_ATTRIBUTE_RESPONSE def register_image(self): - if self.is_not_dryrun('RegisterImage'): - raise NotImplementedError( - 'AMIs.register_image is not yet implemented') + if self.is_not_dryrun("RegisterImage"): + raise NotImplementedError("AMIs.register_image is not yet implemented") def reset_image_attribute(self): - if self.is_not_dryrun('ResetImageAttribute'): + if self.is_not_dryrun("ResetImageAttribute"): raise NotImplementedError( - 'AMIs.reset_image_attribute is not yet implemented') + "AMIs.reset_image_attribute is not yet implemented" + ) CREATE_IMAGE_RESPONSE = """ diff --git a/moto/ec2/responses/availability_zones_and_regions.py b/moto/ec2/responses/availability_zones_and_regions.py index a6e35a89c..d63e2f4ad 100644 --- a/moto/ec2/responses/availability_zones_and_regions.py +++ b/moto/ec2/responses/availability_zones_and_regions.py @@ -3,14 +3,13 @@ from moto.core.responses import BaseResponse class AvailabilityZonesAndRegions(BaseResponse): - def describe_availability_zones(self): zones = self.ec2_backend.describe_availability_zones() template = self.response_template(DESCRIBE_ZONES_RESPONSE) return template.render(zones=zones) def describe_regions(self): - region_names = self._get_multi_param('RegionName') + region_names = self._get_multi_param("RegionName") regions = self.ec2_backend.describe_regions(region_names) template = self.response_template(DESCRIBE_REGIONS_RESPONSE) return template.render(regions=regions) diff --git a/moto/ec2/responses/customer_gateways.py b/moto/ec2/responses/customer_gateways.py index 866b93045..65b93cc2e 100644 --- a/moto/ec2/responses/customer_gateways.py +++ b/moto/ec2/responses/customer_gateways.py @@ -4,21 +4,20 @@ from moto.ec2.utils import filters_from_querystring class CustomerGateways(BaseResponse): - def create_customer_gateway(self): # raise NotImplementedError('CustomerGateways(AmazonVPC).create_customer_gateway is not yet implemented') - type = self._get_param('Type') - ip_address = self._get_param('IpAddress') - bgp_asn = self._get_param('BgpAsn') + type = self._get_param("Type") + ip_address = self._get_param("IpAddress") + bgp_asn = self._get_param("BgpAsn") customer_gateway = self.ec2_backend.create_customer_gateway( - type, ip_address=ip_address, bgp_asn=bgp_asn) + type, ip_address=ip_address, bgp_asn=bgp_asn + ) template = self.response_template(CREATE_CUSTOMER_GATEWAY_RESPONSE) return template.render(customer_gateway=customer_gateway) def delete_customer_gateway(self): - customer_gateway_id = self._get_param('CustomerGatewayId') - delete_status = self.ec2_backend.delete_customer_gateway( - customer_gateway_id) + customer_gateway_id = self._get_param("CustomerGatewayId") + delete_status = self.ec2_backend.delete_customer_gateway(customer_gateway_id) template = self.response_template(DELETE_CUSTOMER_GATEWAY_RESPONSE) return template.render(customer_gateway=delete_status) diff --git a/moto/ec2/responses/dhcp_options.py b/moto/ec2/responses/dhcp_options.py index 1f740d14b..868ab85cf 100644 --- a/moto/ec2/responses/dhcp_options.py +++ b/moto/ec2/responses/dhcp_options.py @@ -1,15 +1,12 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse -from moto.ec2.utils import ( - filters_from_querystring, - dhcp_configuration_from_querystring) +from moto.ec2.utils import filters_from_querystring, dhcp_configuration_from_querystring class DHCPOptions(BaseResponse): - def associate_dhcp_options(self): - dhcp_opt_id = self._get_param('DhcpOptionsId') - vpc_id = self._get_param('VpcId') + dhcp_opt_id = self._get_param("DhcpOptionsId") + vpc_id = self._get_param("VpcId") dhcp_opt = self.ec2_backend.describe_dhcp_options([dhcp_opt_id])[0] vpc = self.ec2_backend.get_vpc(vpc_id) @@ -35,14 +32,14 @@ class DHCPOptions(BaseResponse): domain_name=domain_name, ntp_servers=ntp_servers, netbios_name_servers=netbios_name_servers, - netbios_node_type=netbios_node_type + netbios_node_type=netbios_node_type, ) template = self.response_template(CREATE_DHCP_OPTIONS_RESPONSE) return template.render(dhcp_options_set=dhcp_options_set) def delete_dhcp_options(self): - dhcp_opt_id = self._get_param('DhcpOptionsId') + dhcp_opt_id = self._get_param("DhcpOptionsId") delete_status = self.ec2_backend.delete_dhcp_options_set(dhcp_opt_id) template = self.response_template(DELETE_DHCP_OPTIONS_RESPONSE) return template.render(delete_status=delete_status) @@ -50,13 +47,12 @@ class DHCPOptions(BaseResponse): def describe_dhcp_options(self): dhcp_opt_ids = self._get_multi_param("DhcpOptionsId") filters = filters_from_querystring(self.querystring) - dhcp_opts = self.ec2_backend.get_all_dhcp_options( - dhcp_opt_ids, filters) + dhcp_opts = self.ec2_backend.get_all_dhcp_options(dhcp_opt_ids, filters) template = self.response_template(DESCRIBE_DHCP_OPTIONS_RESPONSE) return template.render(dhcp_options=dhcp_opts) -CREATE_DHCP_OPTIONS_RESPONSE = u""" +CREATE_DHCP_OPTIONS_RESPONSE = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE @@ -92,14 +88,14 @@ CREATE_DHCP_OPTIONS_RESPONSE = u""" """ -DELETE_DHCP_OPTIONS_RESPONSE = u""" +DELETE_DHCP_OPTIONS_RESPONSE = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE {{delete_status}} """ -DESCRIBE_DHCP_OPTIONS_RESPONSE = u""" +DESCRIBE_DHCP_OPTIONS_RESPONSE = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE @@ -139,7 +135,7 @@ DESCRIBE_DHCP_OPTIONS_RESPONSE = u""" """ -ASSOCIATE_DHCP_OPTIONS_RESPONSE = u""" +ASSOCIATE_DHCP_OPTIONS_RESPONSE = """ 7a62c49f-347e-4fc4-9331-6e8eEXAMPLE true diff --git a/moto/ec2/responses/elastic_block_store.py b/moto/ec2/responses/elastic_block_store.py index acd37b283..d11470242 100644 --- a/moto/ec2/responses/elastic_block_store.py +++ b/moto/ec2/responses/elastic_block_store.py @@ -4,137 +4,148 @@ from moto.ec2.utils import filters_from_querystring class ElasticBlockStore(BaseResponse): - def attach_volume(self): - volume_id = self._get_param('VolumeId') - instance_id = self._get_param('InstanceId') - device_path = self._get_param('Device') - if self.is_not_dryrun('AttachVolume'): + volume_id = self._get_param("VolumeId") + instance_id = self._get_param("InstanceId") + device_path = self._get_param("Device") + if self.is_not_dryrun("AttachVolume"): attachment = self.ec2_backend.attach_volume( - volume_id, instance_id, device_path) + volume_id, instance_id, device_path + ) template = self.response_template(ATTACHED_VOLUME_RESPONSE) return template.render(attachment=attachment) def copy_snapshot(self): - source_snapshot_id = self._get_param('SourceSnapshotId') - source_region = self._get_param('SourceRegion') - description = self._get_param('Description') - if self.is_not_dryrun('CopySnapshot'): + source_snapshot_id = self._get_param("SourceSnapshotId") + source_region = self._get_param("SourceRegion") + description = self._get_param("Description") + if self.is_not_dryrun("CopySnapshot"): snapshot = self.ec2_backend.copy_snapshot( - source_snapshot_id, source_region, description) + source_snapshot_id, source_region, description + ) template = self.response_template(COPY_SNAPSHOT_RESPONSE) return template.render(snapshot=snapshot) def create_snapshot(self): - volume_id = self._get_param('VolumeId') - description = self._get_param('Description') + volume_id = self._get_param("VolumeId") + description = self._get_param("Description") tags = self._parse_tag_specification("TagSpecification") - snapshot_tags = tags.get('snapshot', {}) - if self.is_not_dryrun('CreateSnapshot'): + snapshot_tags = tags.get("snapshot", {}) + if self.is_not_dryrun("CreateSnapshot"): snapshot = self.ec2_backend.create_snapshot(volume_id, description) snapshot.add_tags(snapshot_tags) template = self.response_template(CREATE_SNAPSHOT_RESPONSE) return template.render(snapshot=snapshot) def create_volume(self): - size = self._get_param('Size') - zone = self._get_param('AvailabilityZone') - snapshot_id = self._get_param('SnapshotId') + size = self._get_param("Size") + zone = self._get_param("AvailabilityZone") + snapshot_id = self._get_param("SnapshotId") tags = self._parse_tag_specification("TagSpecification") - volume_tags = tags.get('volume', {}) - encrypted = self._get_param('Encrypted', if_none=False) - if self.is_not_dryrun('CreateVolume'): - volume = self.ec2_backend.create_volume( - size, zone, snapshot_id, encrypted) + volume_tags = tags.get("volume", {}) + encrypted = self._get_param("Encrypted", if_none=False) + if self.is_not_dryrun("CreateVolume"): + volume = self.ec2_backend.create_volume(size, zone, snapshot_id, encrypted) volume.add_tags(volume_tags) template = self.response_template(CREATE_VOLUME_RESPONSE) return template.render(volume=volume) def delete_snapshot(self): - snapshot_id = self._get_param('SnapshotId') - if self.is_not_dryrun('DeleteSnapshot'): + snapshot_id = self._get_param("SnapshotId") + if self.is_not_dryrun("DeleteSnapshot"): self.ec2_backend.delete_snapshot(snapshot_id) return DELETE_SNAPSHOT_RESPONSE def delete_volume(self): - volume_id = self._get_param('VolumeId') - if self.is_not_dryrun('DeleteVolume'): + volume_id = self._get_param("VolumeId") + if self.is_not_dryrun("DeleteVolume"): self.ec2_backend.delete_volume(volume_id) return DELETE_VOLUME_RESPONSE def describe_snapshots(self): filters = filters_from_querystring(self.querystring) - snapshot_ids = self._get_multi_param('SnapshotId') - snapshots = self.ec2_backend.describe_snapshots(snapshot_ids=snapshot_ids, filters=filters) + snapshot_ids = self._get_multi_param("SnapshotId") + snapshots = self.ec2_backend.describe_snapshots( + snapshot_ids=snapshot_ids, filters=filters + ) template = self.response_template(DESCRIBE_SNAPSHOTS_RESPONSE) return template.render(snapshots=snapshots) def describe_volumes(self): filters = filters_from_querystring(self.querystring) - volume_ids = self._get_multi_param('VolumeId') - volumes = self.ec2_backend.describe_volumes(volume_ids=volume_ids, filters=filters) + volume_ids = self._get_multi_param("VolumeId") + volumes = self.ec2_backend.describe_volumes( + volume_ids=volume_ids, filters=filters + ) template = self.response_template(DESCRIBE_VOLUMES_RESPONSE) return template.render(volumes=volumes) def describe_volume_attribute(self): raise NotImplementedError( - 'ElasticBlockStore.describe_volume_attribute is not yet implemented') + "ElasticBlockStore.describe_volume_attribute is not yet implemented" + ) def describe_volume_status(self): raise NotImplementedError( - 'ElasticBlockStore.describe_volume_status is not yet implemented') + "ElasticBlockStore.describe_volume_status is not yet implemented" + ) def detach_volume(self): - volume_id = self._get_param('VolumeId') - instance_id = self._get_param('InstanceId') - device_path = self._get_param('Device') - if self.is_not_dryrun('DetachVolume'): + volume_id = self._get_param("VolumeId") + instance_id = self._get_param("InstanceId") + device_path = self._get_param("Device") + if self.is_not_dryrun("DetachVolume"): attachment = self.ec2_backend.detach_volume( - volume_id, instance_id, device_path) + volume_id, instance_id, device_path + ) template = self.response_template(DETATCH_VOLUME_RESPONSE) return template.render(attachment=attachment) def enable_volume_io(self): - if self.is_not_dryrun('EnableVolumeIO'): + if self.is_not_dryrun("EnableVolumeIO"): raise NotImplementedError( - 'ElasticBlockStore.enable_volume_io is not yet implemented') + "ElasticBlockStore.enable_volume_io is not yet implemented" + ) def import_volume(self): - if self.is_not_dryrun('ImportVolume'): + if self.is_not_dryrun("ImportVolume"): raise NotImplementedError( - 'ElasticBlockStore.import_volume is not yet implemented') + "ElasticBlockStore.import_volume is not yet implemented" + ) def describe_snapshot_attribute(self): - snapshot_id = self._get_param('SnapshotId') - groups = self.ec2_backend.get_create_volume_permission_groups( - snapshot_id) - template = self.response_template( - DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE) + snapshot_id = self._get_param("SnapshotId") + groups = self.ec2_backend.get_create_volume_permission_groups(snapshot_id) + template = self.response_template(DESCRIBE_SNAPSHOT_ATTRIBUTES_RESPONSE) return template.render(snapshot_id=snapshot_id, groups=groups) def modify_snapshot_attribute(self): - snapshot_id = self._get_param('SnapshotId') - operation_type = self._get_param('OperationType') - group = self._get_param('UserGroup.1') - user_id = self._get_param('UserId.1') - if self.is_not_dryrun('ModifySnapshotAttribute'): - if (operation_type == 'add'): + snapshot_id = self._get_param("SnapshotId") + operation_type = self._get_param("OperationType") + group = self._get_param("UserGroup.1") + user_id = self._get_param("UserId.1") + if self.is_not_dryrun("ModifySnapshotAttribute"): + if operation_type == "add": self.ec2_backend.add_create_volume_permission( - snapshot_id, user_id=user_id, group=group) - elif (operation_type == 'remove'): + snapshot_id, user_id=user_id, group=group + ) + elif operation_type == "remove": self.ec2_backend.remove_create_volume_permission( - snapshot_id, user_id=user_id, group=group) + snapshot_id, user_id=user_id, group=group + ) return MODIFY_SNAPSHOT_ATTRIBUTE_RESPONSE def modify_volume_attribute(self): - if self.is_not_dryrun('ModifyVolumeAttribute'): + if self.is_not_dryrun("ModifyVolumeAttribute"): raise NotImplementedError( - 'ElasticBlockStore.modify_volume_attribute is not yet implemented') + "ElasticBlockStore.modify_volume_attribute is not yet implemented" + ) def reset_snapshot_attribute(self): - if self.is_not_dryrun('ResetSnapshotAttribute'): + if self.is_not_dryrun("ResetSnapshotAttribute"): raise NotImplementedError( - 'ElasticBlockStore.reset_snapshot_attribute is not yet implemented') + "ElasticBlockStore.reset_snapshot_attribute is not yet implemented" + ) CREATE_VOLUME_RESPONSE = """ diff --git a/moto/ec2/responses/elastic_ip_addresses.py b/moto/ec2/responses/elastic_ip_addresses.py index 6e1c9fe38..e25922706 100644 --- a/moto/ec2/responses/elastic_ip_addresses.py +++ b/moto/ec2/responses/elastic_ip_addresses.py @@ -4,14 +4,14 @@ from moto.ec2.utils import filters_from_querystring class ElasticIPAddresses(BaseResponse): - def allocate_address(self): - domain = self._get_param('Domain', if_none='standard') - reallocate_address = self._get_param('Address', if_none=None) - if self.is_not_dryrun('AllocateAddress'): + domain = self._get_param("Domain", if_none="standard") + reallocate_address = self._get_param("Address", if_none=None) + if self.is_not_dryrun("AllocateAddress"): if reallocate_address: address = self.ec2_backend.allocate_address( - domain, address=reallocate_address) + domain, address=reallocate_address + ) else: address = self.ec2_backend.allocate_address(domain) template = self.response_template(ALLOCATE_ADDRESS_RESPONSE) @@ -21,73 +21,92 @@ class ElasticIPAddresses(BaseResponse): instance = eni = None if "InstanceId" in self.querystring: - instance = self.ec2_backend.get_instance( - self._get_param('InstanceId')) + instance = self.ec2_backend.get_instance(self._get_param("InstanceId")) elif "NetworkInterfaceId" in self.querystring: eni = self.ec2_backend.get_network_interface( - self._get_param('NetworkInterfaceId')) + self._get_param("NetworkInterfaceId") + ) else: self.ec2_backend.raise_error( - "MissingParameter", "Invalid request, expect InstanceId/NetworkId parameter.") + "MissingParameter", + "Invalid request, expect InstanceId/NetworkId parameter.", + ) reassociate = False if "AllowReassociation" in self.querystring: - reassociate = self._get_param('AllowReassociation') == "true" + reassociate = self._get_param("AllowReassociation") == "true" - if self.is_not_dryrun('AssociateAddress'): + if self.is_not_dryrun("AssociateAddress"): if instance or eni: if "PublicIp" in self.querystring: eip = self.ec2_backend.associate_address( - instance=instance, eni=eni, - address=self._get_param('PublicIp'), reassociate=reassociate) + instance=instance, + eni=eni, + address=self._get_param("PublicIp"), + reassociate=reassociate, + ) elif "AllocationId" in self.querystring: eip = self.ec2_backend.associate_address( - instance=instance, eni=eni, - allocation_id=self._get_param('AllocationId'), reassociate=reassociate) + instance=instance, + eni=eni, + allocation_id=self._get_param("AllocationId"), + reassociate=reassociate, + ) else: self.ec2_backend.raise_error( - "MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.") + "MissingParameter", + "Invalid request, expect PublicIp/AllocationId parameter.", + ) else: self.ec2_backend.raise_error( - "MissingParameter", "Invalid request, expect either instance or ENI.") + "MissingParameter", + "Invalid request, expect either instance or ENI.", + ) template = self.response_template(ASSOCIATE_ADDRESS_RESPONSE) return template.render(address=eip) def describe_addresses(self): - allocation_ids = self._get_multi_param('AllocationId') - public_ips = self._get_multi_param('PublicIp') + allocation_ids = self._get_multi_param("AllocationId") + public_ips = self._get_multi_param("PublicIp") filters = filters_from_querystring(self.querystring) addresses = self.ec2_backend.describe_addresses( - allocation_ids, public_ips, filters) + allocation_ids, public_ips, filters + ) template = self.response_template(DESCRIBE_ADDRESS_RESPONSE) return template.render(addresses=addresses) def disassociate_address(self): - if self.is_not_dryrun('DisAssociateAddress'): + if self.is_not_dryrun("DisAssociateAddress"): if "PublicIp" in self.querystring: self.ec2_backend.disassociate_address( - address=self._get_param('PublicIp')) + address=self._get_param("PublicIp") + ) elif "AssociationId" in self.querystring: self.ec2_backend.disassociate_address( - association_id=self._get_param('AssociationId')) + association_id=self._get_param("AssociationId") + ) else: self.ec2_backend.raise_error( - "MissingParameter", "Invalid request, expect PublicIp/AssociationId parameter.") + "MissingParameter", + "Invalid request, expect PublicIp/AssociationId parameter.", + ) return self.response_template(DISASSOCIATE_ADDRESS_RESPONSE).render() def release_address(self): - if self.is_not_dryrun('ReleaseAddress'): + if self.is_not_dryrun("ReleaseAddress"): if "PublicIp" in self.querystring: - self.ec2_backend.release_address( - address=self._get_param('PublicIp')) + self.ec2_backend.release_address(address=self._get_param("PublicIp")) elif "AllocationId" in self.querystring: self.ec2_backend.release_address( - allocation_id=self._get_param('AllocationId')) + allocation_id=self._get_param("AllocationId") + ) else: self.ec2_backend.raise_error( - "MissingParameter", "Invalid request, expect PublicIp/AllocationId parameter.") + "MissingParameter", + "Invalid request, expect PublicIp/AllocationId parameter.", + ) return self.response_template(RELEASE_ADDRESS_RESPONSE).render() diff --git a/moto/ec2/responses/elastic_network_interfaces.py b/moto/ec2/responses/elastic_network_interfaces.py index 9c37e70da..fa014b219 100644 --- a/moto/ec2/responses/elastic_network_interfaces.py +++ b/moto/ec2/responses/elastic_network_interfaces.py @@ -4,71 +4,69 @@ from moto.ec2.utils import filters_from_querystring class ElasticNetworkInterfaces(BaseResponse): - def create_network_interface(self): - subnet_id = self._get_param('SubnetId') - private_ip_address = self._get_param('PrivateIpAddress') - groups = self._get_multi_param('SecurityGroupId') + subnet_id = self._get_param("SubnetId") + private_ip_address = self._get_param("PrivateIpAddress") + groups = self._get_multi_param("SecurityGroupId") subnet = self.ec2_backend.get_subnet(subnet_id) - description = self._get_param('Description') - if self.is_not_dryrun('CreateNetworkInterface'): + description = self._get_param("Description") + if self.is_not_dryrun("CreateNetworkInterface"): eni = self.ec2_backend.create_network_interface( - subnet, private_ip_address, groups, description) - template = self.response_template( - CREATE_NETWORK_INTERFACE_RESPONSE) + subnet, private_ip_address, groups, description + ) + template = self.response_template(CREATE_NETWORK_INTERFACE_RESPONSE) return template.render(eni=eni) def delete_network_interface(self): - eni_id = self._get_param('NetworkInterfaceId') - if self.is_not_dryrun('DeleteNetworkInterface'): + eni_id = self._get_param("NetworkInterfaceId") + if self.is_not_dryrun("DeleteNetworkInterface"): self.ec2_backend.delete_network_interface(eni_id) - template = self.response_template( - DELETE_NETWORK_INTERFACE_RESPONSE) + template = self.response_template(DELETE_NETWORK_INTERFACE_RESPONSE) return template.render() def describe_network_interface_attribute(self): raise NotImplementedError( - 'ElasticNetworkInterfaces(AmazonVPC).describe_network_interface_attribute is not yet implemented') + "ElasticNetworkInterfaces(AmazonVPC).describe_network_interface_attribute is not yet implemented" + ) def describe_network_interfaces(self): - eni_ids = self._get_multi_param('NetworkInterfaceId') + eni_ids = self._get_multi_param("NetworkInterfaceId") filters = filters_from_querystring(self.querystring) enis = self.ec2_backend.get_all_network_interfaces(eni_ids, filters) template = self.response_template(DESCRIBE_NETWORK_INTERFACES_RESPONSE) return template.render(enis=enis) def attach_network_interface(self): - eni_id = self._get_param('NetworkInterfaceId') - instance_id = self._get_param('InstanceId') - device_index = self._get_param('DeviceIndex') - if self.is_not_dryrun('AttachNetworkInterface'): + eni_id = self._get_param("NetworkInterfaceId") + instance_id = self._get_param("InstanceId") + device_index = self._get_param("DeviceIndex") + if self.is_not_dryrun("AttachNetworkInterface"): attachment_id = self.ec2_backend.attach_network_interface( - eni_id, instance_id, device_index) - template = self.response_template( - ATTACH_NETWORK_INTERFACE_RESPONSE) + eni_id, instance_id, device_index + ) + template = self.response_template(ATTACH_NETWORK_INTERFACE_RESPONSE) return template.render(attachment_id=attachment_id) def detach_network_interface(self): - attachment_id = self._get_param('AttachmentId') - if self.is_not_dryrun('DetachNetworkInterface'): + attachment_id = self._get_param("AttachmentId") + if self.is_not_dryrun("DetachNetworkInterface"): self.ec2_backend.detach_network_interface(attachment_id) - template = self.response_template( - DETACH_NETWORK_INTERFACE_RESPONSE) + template = self.response_template(DETACH_NETWORK_INTERFACE_RESPONSE) return template.render() def modify_network_interface_attribute(self): # Currently supports modifying one and only one security group - eni_id = self._get_param('NetworkInterfaceId') - group_id = self._get_param('SecurityGroupId.1') - if self.is_not_dryrun('ModifyNetworkInterface'): - self.ec2_backend.modify_network_interface_attribute( - eni_id, group_id) + eni_id = self._get_param("NetworkInterfaceId") + group_id = self._get_param("SecurityGroupId.1") + if self.is_not_dryrun("ModifyNetworkInterface"): + self.ec2_backend.modify_network_interface_attribute(eni_id, group_id) return MODIFY_NETWORK_INTERFACE_ATTRIBUTE_RESPONSE def reset_network_interface_attribute(self): - if self.is_not_dryrun('ResetNetworkInterface'): + if self.is_not_dryrun("ResetNetworkInterface"): raise NotImplementedError( - 'ElasticNetworkInterfaces(AmazonVPC).reset_network_interface_attribute is not yet implemented') + "ElasticNetworkInterfaces(AmazonVPC).reset_network_interface_attribute is not yet implemented" + ) CREATE_NETWORK_INTERFACE_RESPONSE = """ diff --git a/moto/ec2/responses/general.py b/moto/ec2/responses/general.py index 262d9f8ea..5dcd73358 100644 --- a/moto/ec2/responses/general.py +++ b/moto/ec2/responses/general.py @@ -3,20 +3,19 @@ from moto.core.responses import BaseResponse class General(BaseResponse): - def get_console_output(self): - instance_id = self._get_param('InstanceId') + instance_id = self._get_param("InstanceId") if not instance_id: # For compatibility with boto. # See: https://github.com/spulec/moto/pull/1152#issuecomment-332487599 - instance_id = self._get_multi_param('InstanceId')[0] + instance_id = self._get_multi_param("InstanceId")[0] instance = self.ec2_backend.get_instance(instance_id) template = self.response_template(GET_CONSOLE_OUTPUT_RESULT) return template.render(instance=instance) -GET_CONSOLE_OUTPUT_RESULT = ''' +GET_CONSOLE_OUTPUT_RESULT = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE {{ instance.id }} @@ -29,4 +28,4 @@ R0hNRU0gYXZhaWxhYmxlLgo3MjdNQiBMT1dNRU0gYXZhaWxhYmxlLgpOWCAoRXhlY3V0ZSBEaXNh YmxlKSBwcm90ZWN0aW9uOiBhY3RpdmUKSVJRIGxvY2t1cCBkZXRlY3Rpb24gZGlzYWJsZWQKQnVp bHQgMSB6b25lbGlzdHMKS2VybmVsIGNvbW1hbmQgbGluZTogcm9vdD0vZGV2L3NkYTEgcm8gNApF bmFibGluZyBmYXN0IEZQVSBzYXZlIGFuZCByZXN0b3JlLi4uIGRvbmUuCg== -''' +""" diff --git a/moto/ec2/responses/instances.py b/moto/ec2/responses/instances.py index 28123b995..4b7a20a17 100644 --- a/moto/ec2/responses/instances.py +++ b/moto/ec2/responses/instances.py @@ -4,20 +4,19 @@ from boto.ec2.instancetype import InstanceType from moto.autoscaling import autoscaling_backends from moto.core.responses import BaseResponse from moto.core.utils import camelcase_to_underscores -from moto.ec2.utils import filters_from_querystring, \ - dict_from_querystring +from moto.ec2.utils import filters_from_querystring, dict_from_querystring from moto.elbv2 import elbv2_backends class InstanceResponse(BaseResponse): - def describe_instances(self): filter_dict = filters_from_querystring(self.querystring) - instance_ids = self._get_multi_param('InstanceId') + instance_ids = self._get_multi_param("InstanceId") token = self._get_param("NextToken") if instance_ids: reservations = self.ec2_backend.get_reservations_by_instance_ids( - instance_ids, filters=filter_dict) + instance_ids, filters=filter_dict + ) else: reservations = self.ec2_backend.all_reservations(filters=filter_dict) @@ -26,47 +25,66 @@ class InstanceResponse(BaseResponse): start = reservation_ids.index(token) + 1 else: start = 0 - max_results = int(self._get_param('MaxResults', 100)) - reservations_resp = reservations[start:start + max_results] + max_results = int(self._get_param("MaxResults", 100)) + reservations_resp = reservations[start : start + max_results] next_token = None if max_results and len(reservations) > (start + max_results): next_token = reservations_resp[-1].id template = self.response_template(EC2_DESCRIBE_INSTANCES) - return template.render(reservations=reservations_resp, next_token=next_token).replace('True', 'true').replace('False', 'false') + return ( + template.render(reservations=reservations_resp, next_token=next_token) + .replace("True", "true") + .replace("False", "false") + ) def run_instances(self): - min_count = int(self._get_param('MinCount', if_none='1')) - image_id = self._get_param('ImageId') - owner_id = self._get_param('OwnerId') - user_data = self._get_param('UserData') - security_group_names = self._get_multi_param('SecurityGroup') - security_group_ids = self._get_multi_param('SecurityGroupId') + min_count = int(self._get_param("MinCount", if_none="1")) + image_id = self._get_param("ImageId") + owner_id = self._get_param("OwnerId") + user_data = self._get_param("UserData") + security_group_names = self._get_multi_param("SecurityGroup") + security_group_ids = self._get_multi_param("SecurityGroupId") nics = dict_from_querystring("NetworkInterface", self.querystring) - instance_type = self._get_param('InstanceType', if_none='m1.small') - placement = self._get_param('Placement.AvailabilityZone') - subnet_id = self._get_param('SubnetId') - private_ip = self._get_param('PrivateIpAddress') - associate_public_ip = self._get_param('AssociatePublicIpAddress') - key_name = self._get_param('KeyName') - ebs_optimized = self._get_param('EbsOptimized') - instance_initiated_shutdown_behavior = self._get_param("InstanceInitiatedShutdownBehavior") + instance_type = self._get_param("InstanceType", if_none="m1.small") + placement = self._get_param("Placement.AvailabilityZone") + subnet_id = self._get_param("SubnetId") + private_ip = self._get_param("PrivateIpAddress") + associate_public_ip = self._get_param("AssociatePublicIpAddress") + key_name = self._get_param("KeyName") + ebs_optimized = self._get_param("EbsOptimized") + instance_initiated_shutdown_behavior = self._get_param( + "InstanceInitiatedShutdownBehavior" + ) tags = self._parse_tag_specification("TagSpecification") region_name = self.region - if self.is_not_dryrun('RunInstance'): + if self.is_not_dryrun("RunInstance"): new_reservation = self.ec2_backend.add_instances( - image_id, min_count, user_data, security_group_names, - instance_type=instance_type, placement=placement, region_name=region_name, subnet_id=subnet_id, - owner_id=owner_id, key_name=key_name, security_group_ids=security_group_ids, - nics=nics, private_ip=private_ip, associate_public_ip=associate_public_ip, - tags=tags, ebs_optimized=ebs_optimized, instance_initiated_shutdown_behavior=instance_initiated_shutdown_behavior) + image_id, + min_count, + user_data, + security_group_names, + instance_type=instance_type, + placement=placement, + region_name=region_name, + subnet_id=subnet_id, + owner_id=owner_id, + key_name=key_name, + security_group_ids=security_group_ids, + nics=nics, + private_ip=private_ip, + associate_public_ip=associate_public_ip, + tags=tags, + ebs_optimized=ebs_optimized, + instance_initiated_shutdown_behavior=instance_initiated_shutdown_behavior, + ) template = self.response_template(EC2_RUN_INSTANCES) return template.render(reservation=new_reservation) def terminate_instances(self): - instance_ids = self._get_multi_param('InstanceId') - if self.is_not_dryrun('TerminateInstance'): + instance_ids = self._get_multi_param("InstanceId") + if self.is_not_dryrun("TerminateInstance"): instances = self.ec2_backend.terminate_instances(instance_ids) autoscaling_backends[self.region].notify_terminate_instances(instance_ids) elbv2_backends[self.region].notify_terminate_instances(instance_ids) @@ -74,33 +92,32 @@ class InstanceResponse(BaseResponse): return template.render(instances=instances) def reboot_instances(self): - instance_ids = self._get_multi_param('InstanceId') - if self.is_not_dryrun('RebootInstance'): + instance_ids = self._get_multi_param("InstanceId") + if self.is_not_dryrun("RebootInstance"): instances = self.ec2_backend.reboot_instances(instance_ids) template = self.response_template(EC2_REBOOT_INSTANCES) return template.render(instances=instances) def stop_instances(self): - instance_ids = self._get_multi_param('InstanceId') - if self.is_not_dryrun('StopInstance'): + instance_ids = self._get_multi_param("InstanceId") + if self.is_not_dryrun("StopInstance"): instances = self.ec2_backend.stop_instances(instance_ids) template = self.response_template(EC2_STOP_INSTANCES) return template.render(instances=instances) def start_instances(self): - instance_ids = self._get_multi_param('InstanceId') - if self.is_not_dryrun('StartInstance'): + instance_ids = self._get_multi_param("InstanceId") + if self.is_not_dryrun("StartInstance"): instances = self.ec2_backend.start_instances(instance_ids) template = self.response_template(EC2_START_INSTANCES) return template.render(instances=instances) def describe_instance_status(self): - instance_ids = self._get_multi_param('InstanceId') - include_all_instances = self._get_param('IncludeAllInstances') == 'true' + instance_ids = self._get_multi_param("InstanceId") + include_all_instances = self._get_param("IncludeAllInstances") == "true" if instance_ids: - instances = self.ec2_backend.get_multi_instances_by_id( - instance_ids) + instances = self.ec2_backend.get_multi_instances_by_id(instance_ids) elif include_all_instances: instances = self.ec2_backend.all_instances() else: @@ -110,40 +127,45 @@ class InstanceResponse(BaseResponse): return template.render(instances=instances) def describe_instance_types(self): - instance_types = [InstanceType( - name='t1.micro', cores=1, memory=644874240, disk=0)] + instance_types = [ + InstanceType(name="t1.micro", cores=1, memory=644874240, disk=0) + ] template = self.response_template(EC2_DESCRIBE_INSTANCE_TYPES) return template.render(instance_types=instance_types) def describe_instance_attribute(self): # TODO this and modify below should raise IncorrectInstanceState if # instance not in stopped state - attribute = self._get_param('Attribute') - instance_id = self._get_param('InstanceId') + attribute = self._get_param("Attribute") + instance_id = self._get_param("InstanceId") instance, value = self.ec2_backend.describe_instance_attribute( - instance_id, attribute) + instance_id, attribute + ) if attribute == "groupSet": - template = self.response_template( - EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE) + template = self.response_template(EC2_DESCRIBE_INSTANCE_GROUPSET_ATTRIBUTE) else: template = self.response_template(EC2_DESCRIBE_INSTANCE_ATTRIBUTE) return template.render(instance=instance, attribute=attribute, value=value) def modify_instance_attribute(self): - handlers = [self._dot_value_instance_attribute_handler, - self._block_device_mapping_handler, - self._security_grp_instance_attribute_handler] + handlers = [ + self._dot_value_instance_attribute_handler, + self._block_device_mapping_handler, + self._security_grp_instance_attribute_handler, + ] for handler in handlers: success = handler() if success: return success - msg = "This specific call to ModifyInstanceAttribute has not been" \ - " implemented in Moto yet. Feel free to open an issue at" \ - " https://github.com/spulec/moto/issues" + msg = ( + "This specific call to ModifyInstanceAttribute has not been" + " implemented in Moto yet. Feel free to open an issue at" + " https://github.com/spulec/moto/issues" + ) raise NotImplementedError(msg) def _block_device_mapping_handler(self): @@ -166,8 +188,8 @@ class InstanceResponse(BaseResponse): configuration, but it should be trivial to add anything else. """ mapping_counter = 1 - mapping_device_name_fmt = 'BlockDeviceMapping.%s.DeviceName' - mapping_del_on_term_fmt = 'BlockDeviceMapping.%s.Ebs.DeleteOnTermination' + mapping_device_name_fmt = "BlockDeviceMapping.%s.DeviceName" + mapping_del_on_term_fmt = "BlockDeviceMapping.%s.Ebs.DeleteOnTermination" while True: mapping_device_name = mapping_device_name_fmt % mapping_counter if mapping_device_name not in self.querystring.keys(): @@ -175,15 +197,14 @@ class InstanceResponse(BaseResponse): mapping_del_on_term = mapping_del_on_term_fmt % mapping_counter del_on_term_value_str = self.querystring[mapping_del_on_term][0] - del_on_term_value = True if 'true' == del_on_term_value_str else False + del_on_term_value = True if "true" == del_on_term_value_str else False device_name_value = self.querystring[mapping_device_name][0] - instance_id = self._get_param('InstanceId') + instance_id = self._get_param("InstanceId") instance = self.ec2_backend.get_instance(instance_id) - if self.is_not_dryrun('ModifyInstanceAttribute'): - block_device_type = instance.block_device_mapping[ - device_name_value] + if self.is_not_dryrun("ModifyInstanceAttribute"): + block_device_type = instance.block_device_mapping[device_name_value] block_device_type.delete_on_termination = del_on_term_value # +1 for the next device @@ -195,32 +216,33 @@ class InstanceResponse(BaseResponse): def _dot_value_instance_attribute_handler(self): attribute_key = None for key, value in self.querystring.items(): - if '.Value' in key: + if ".Value" in key: attribute_key = key break if not attribute_key: return - if self.is_not_dryrun('Modify' + attribute_key.split(".")[0]): + if self.is_not_dryrun("Modify" + attribute_key.split(".")[0]): value = self.querystring.get(attribute_key)[0] - normalized_attribute = camelcase_to_underscores( - attribute_key.split(".")[0]) - instance_id = self._get_param('InstanceId') + normalized_attribute = camelcase_to_underscores(attribute_key.split(".")[0]) + instance_id = self._get_param("InstanceId") self.ec2_backend.modify_instance_attribute( - instance_id, normalized_attribute, value) + instance_id, normalized_attribute, value + ) return EC2_MODIFY_INSTANCE_ATTRIBUTE def _security_grp_instance_attribute_handler(self): new_security_grp_list = [] for key, value in self.querystring.items(): - if 'GroupId.' in key: + if "GroupId." in key: new_security_grp_list.append(self.querystring.get(key)[0]) - instance_id = self._get_param('InstanceId') - if self.is_not_dryrun('ModifyInstanceSecurityGroups'): + instance_id = self._get_param("InstanceId") + if self.is_not_dryrun("ModifyInstanceSecurityGroups"): self.ec2_backend.modify_instance_security_groups( - instance_id, new_security_grp_list) + instance_id, new_security_grp_list + ) return EC2_MODIFY_INSTANCE_ATTRIBUTE diff --git a/moto/ec2/responses/internet_gateways.py b/moto/ec2/responses/internet_gateways.py index ebea14adf..d232b3b05 100644 --- a/moto/ec2/responses/internet_gateways.py +++ b/moto/ec2/responses/internet_gateways.py @@ -1,29 +1,26 @@ from __future__ import unicode_literals from moto.core.responses import BaseResponse -from moto.ec2.utils import ( - filters_from_querystring, -) +from moto.ec2.utils import filters_from_querystring class InternetGateways(BaseResponse): - def attach_internet_gateway(self): - igw_id = self._get_param('InternetGatewayId') - vpc_id = self._get_param('VpcId') - if self.is_not_dryrun('AttachInternetGateway'): + igw_id = self._get_param("InternetGatewayId") + vpc_id = self._get_param("VpcId") + if self.is_not_dryrun("AttachInternetGateway"): self.ec2_backend.attach_internet_gateway(igw_id, vpc_id) template = self.response_template(ATTACH_INTERNET_GATEWAY_RESPONSE) return template.render() def create_internet_gateway(self): - if self.is_not_dryrun('CreateInternetGateway'): + if self.is_not_dryrun("CreateInternetGateway"): igw = self.ec2_backend.create_internet_gateway() template = self.response_template(CREATE_INTERNET_GATEWAY_RESPONSE) return template.render(internet_gateway=igw) def delete_internet_gateway(self): - igw_id = self._get_param('InternetGatewayId') - if self.is_not_dryrun('DeleteInternetGateway'): + igw_id = self._get_param("InternetGatewayId") + if self.is_not_dryrun("DeleteInternetGateway"): self.ec2_backend.delete_internet_gateway(igw_id) template = self.response_template(DELETE_INTERNET_GATEWAY_RESPONSE) return template.render() @@ -33,10 +30,10 @@ class InternetGateways(BaseResponse): if "InternetGatewayId.1" in self.querystring: igw_ids = self._get_multi_param("InternetGatewayId") igws = self.ec2_backend.describe_internet_gateways( - igw_ids, filters=filter_dict) + igw_ids, filters=filter_dict + ) else: - igws = self.ec2_backend.describe_internet_gateways( - filters=filter_dict) + igws = self.ec2_backend.describe_internet_gateways(filters=filter_dict) template = self.response_template(DESCRIBE_INTERNET_GATEWAYS_RESPONSE) return template.render(internet_gateways=igws) @@ -44,20 +41,20 @@ class InternetGateways(BaseResponse): def detach_internet_gateway(self): # TODO validate no instances with EIPs in VPC before detaching # raise else DependencyViolationError() - igw_id = self._get_param('InternetGatewayId') - vpc_id = self._get_param('VpcId') - if self.is_not_dryrun('DetachInternetGateway'): + igw_id = self._get_param("InternetGatewayId") + vpc_id = self._get_param("VpcId") + if self.is_not_dryrun("DetachInternetGateway"): self.ec2_backend.detach_internet_gateway(igw_id, vpc_id) template = self.response_template(DETACH_INTERNET_GATEWAY_RESPONSE) return template.render() -ATTACH_INTERNET_GATEWAY_RESPONSE = u""" +ATTACH_INTERNET_GATEWAY_RESPONSE = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE true """ -CREATE_INTERNET_GATEWAY_RESPONSE = u""" +CREATE_INTERNET_GATEWAY_RESPONSE = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE {{ internet_gateway.id }} @@ -75,12 +72,12 @@ CREATE_INTERNET_GATEWAY_RESPONSE = u""" +DELETE_INTERNET_GATEWAY_RESPONSE = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE true """ -DESCRIBE_INTERNET_GATEWAYS_RESPONSE = u""" 59dbff89-35bd-4eac-99ed-be587EXAMPLE @@ -112,7 +109,7 @@ DESCRIBE_INTERNET_GATEWAYS_RESPONSE = u""" """ -DETACH_INTERNET_GATEWAY_RESPONSE = u""" +DETACH_INTERNET_GATEWAY_RESPONSE = """ 59dbff89-35bd-4eac-99ed-be587EXAMPLE true """ diff --git a/moto/ec2/responses/ip_addresses.py b/moto/ec2/responses/ip_addresses.py index fab5cbddc..789abfdec 100644 --- a/moto/ec2/responses/ip_addresses.py +++ b/moto/ec2/responses/ip_addresses.py @@ -4,13 +4,14 @@ from moto.core.responses import BaseResponse class IPAddresses(BaseResponse): - def assign_private_ip_addresses(self): - if self.is_not_dryrun('AssignPrivateIPAddress'): + if self.is_not_dryrun("AssignPrivateIPAddress"): raise NotImplementedError( - 'IPAddresses.assign_private_ip_addresses is not yet implemented') + "IPAddresses.assign_private_ip_addresses is not yet implemented" + ) def unassign_private_ip_addresses(self): - if self.is_not_dryrun('UnAssignPrivateIPAddress'): + if self.is_not_dryrun("UnAssignPrivateIPAddress"): raise NotImplementedError( - 'IPAddresses.unassign_private_ip_addresses is not yet implemented') + "IPAddresses.unassign_private_ip_addresses is not yet implemented" + ) diff --git a/moto/ec2/responses/key_pairs.py b/moto/ec2/responses/key_pairs.py index d927bddda..fa2e60904 100644 --- a/moto/ec2/responses/key_pairs.py +++ b/moto/ec2/responses/key_pairs.py @@ -5,32 +5,32 @@ from moto.ec2.utils import filters_from_querystring class KeyPairs(BaseResponse): - def create_key_pair(self): - name = self._get_param('KeyName') - if self.is_not_dryrun('CreateKeyPair'): + name = self._get_param("KeyName") + if self.is_not_dryrun("CreateKeyPair"): keypair = self.ec2_backend.create_key_pair(name) template = self.response_template(CREATE_KEY_PAIR_RESPONSE) return template.render(keypair=keypair) def delete_key_pair(self): - name = self._get_param('KeyName') - if self.is_not_dryrun('DeleteKeyPair'): - success = six.text_type( - self.ec2_backend.delete_key_pair(name)).lower() - return self.response_template(DELETE_KEY_PAIR_RESPONSE).render(success=success) + name = self._get_param("KeyName") + if self.is_not_dryrun("DeleteKeyPair"): + success = six.text_type(self.ec2_backend.delete_key_pair(name)).lower() + return self.response_template(DELETE_KEY_PAIR_RESPONSE).render( + success=success + ) def describe_key_pairs(self): - names = self._get_multi_param('KeyName') + names = self._get_multi_param("KeyName") filters = filters_from_querystring(self.querystring) keypairs = self.ec2_backend.describe_key_pairs(names, filters) template = self.response_template(DESCRIBE_KEY_PAIRS_RESPONSE) return template.render(keypairs=keypairs) def import_key_pair(self): - name = self._get_param('KeyName') - material = self._get_param('PublicKeyMaterial') - if self.is_not_dryrun('ImportKeyPair'): + name = self._get_param("KeyName") + material = self._get_param("PublicKeyMaterial") + if self.is_not_dryrun("ImportKeyPair"): keypair = self.ec2_backend.import_key_pair(name, material) template = self.response_template(IMPORT_KEYPAIR_RESPONSE) return template.render(keypair=keypair) diff --git a/moto/ec2/responses/launch_templates.py b/moto/ec2/responses/launch_templates.py index a8d92a928..22faba539 100644 --- a/moto/ec2/responses/launch_templates.py +++ b/moto/ec2/responses/launch_templates.py @@ -10,9 +10,9 @@ from xml.dom import minidom def xml_root(name): - root = ElementTree.Element(name, { - "xmlns": "http://ec2.amazonaws.com/doc/2016-11-15/" - }) + root = ElementTree.Element( + name, {"xmlns": "http://ec2.amazonaws.com/doc/2016-11-15/"} + ) request_id = str(uuid.uuid4()) + "example" ElementTree.SubElement(root, "requestId").text = request_id @@ -22,10 +22,10 @@ def xml_root(name): def xml_serialize(tree, key, value): name = key[0].lower() + key[1:] if isinstance(value, list): - if name[-1] == 's': + if name[-1] == "s": name = name[:-1] - name = name + 'Set' + name = name + "Set" node = ElementTree.SubElement(tree, name) @@ -36,17 +36,19 @@ def xml_serialize(tree, key, value): xml_serialize(node, dictkey, dictvalue) elif isinstance(value, list): for item in value: - xml_serialize(node, 'item', item) + xml_serialize(node, "item", item) elif value is None: pass else: - raise NotImplementedError("Don't know how to serialize \"{}\" to xml".format(value.__class__)) + raise NotImplementedError( + 'Don\'t know how to serialize "{}" to xml'.format(value.__class__) + ) def pretty_xml(tree): - rough = ElementTree.tostring(tree, 'utf-8') + rough = ElementTree.tostring(tree, "utf-8") parsed = minidom.parseString(rough) - return parsed.toprettyxml(indent=' ') + return parsed.toprettyxml(indent=" ") def parse_object(raw_data): @@ -92,68 +94,87 @@ def parse_lists(data): class LaunchTemplates(BaseResponse): def create_launch_template(self): - name = self._get_param('LaunchTemplateName') - version_description = self._get_param('VersionDescription') + name = self._get_param("LaunchTemplateName") + version_description = self._get_param("VersionDescription") tag_spec = self._parse_tag_specification("TagSpecification") - raw_template_data = self._get_dict_param('LaunchTemplateData.') + raw_template_data = self._get_dict_param("LaunchTemplateData.") parsed_template_data = parse_object(raw_template_data) - if self.is_not_dryrun('CreateLaunchTemplate'): + if self.is_not_dryrun("CreateLaunchTemplate"): if tag_spec: - if 'TagSpecifications' not in parsed_template_data: - parsed_template_data['TagSpecifications'] = [] + if "TagSpecifications" not in parsed_template_data: + parsed_template_data["TagSpecifications"] = [] converted_tag_spec = [] for resource_type, tags in six.iteritems(tag_spec): - converted_tag_spec.append({ - "ResourceType": resource_type, - "Tags": [{"Key": key, "Value": value} for key, value in six.iteritems(tags)], - }) + converted_tag_spec.append( + { + "ResourceType": resource_type, + "Tags": [ + {"Key": key, "Value": value} + for key, value in six.iteritems(tags) + ], + } + ) - parsed_template_data['TagSpecifications'].extend(converted_tag_spec) + parsed_template_data["TagSpecifications"].extend(converted_tag_spec) - template = self.ec2_backend.create_launch_template(name, version_description, parsed_template_data) + template = self.ec2_backend.create_launch_template( + name, version_description, parsed_template_data + ) version = template.default_version() tree = xml_root("CreateLaunchTemplateResponse") - xml_serialize(tree, "launchTemplate", { - "createTime": version.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format(OWNER_ID=OWNER_ID), - "defaultVersionNumber": template.default_version_number, - "latestVersionNumber": version.number, - "launchTemplateId": template.id, - "launchTemplateName": template.name - }) + xml_serialize( + tree, + "launchTemplate", + { + "createTime": version.create_time, + "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( + OWNER_ID=OWNER_ID + ), + "defaultVersionNumber": template.default_version_number, + "latestVersionNumber": version.number, + "launchTemplateId": template.id, + "launchTemplateName": template.name, + }, + ) return pretty_xml(tree) def create_launch_template_version(self): - name = self._get_param('LaunchTemplateName') - tmpl_id = self._get_param('LaunchTemplateId') + name = self._get_param("LaunchTemplateName") + tmpl_id = self._get_param("LaunchTemplateId") if name: template = self.ec2_backend.get_launch_template_by_name(name) if tmpl_id: template = self.ec2_backend.get_launch_template(tmpl_id) - version_description = self._get_param('VersionDescription') + version_description = self._get_param("VersionDescription") - raw_template_data = self._get_dict_param('LaunchTemplateData.') + raw_template_data = self._get_dict_param("LaunchTemplateData.") template_data = parse_object(raw_template_data) - if self.is_not_dryrun('CreateLaunchTemplate'): + if self.is_not_dryrun("CreateLaunchTemplate"): version = template.create_version(template_data, version_description) tree = xml_root("CreateLaunchTemplateVersionResponse") - xml_serialize(tree, "launchTemplateVersion", { - "createTime": version.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format(OWNER_ID=OWNER_ID), - "defaultVersion": template.is_default(version), - "launchTemplateData": version.data, - "launchTemplateId": template.id, - "launchTemplateName": template.name, - "versionDescription": version.description, - "versionNumber": version.number, - }) + xml_serialize( + tree, + "launchTemplateVersion", + { + "createTime": version.create_time, + "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( + OWNER_ID=OWNER_ID + ), + "defaultVersion": template.is_default(version), + "launchTemplateData": version.data, + "launchTemplateId": template.id, + "launchTemplateName": template.name, + "versionDescription": version.description, + "versionNumber": version.number, + }, + ) return pretty_xml(tree) # def delete_launch_template(self): @@ -163,8 +184,8 @@ class LaunchTemplates(BaseResponse): # pass def describe_launch_template_versions(self): - name = self._get_param('LaunchTemplateName') - template_id = self._get_param('LaunchTemplateId') + name = self._get_param("LaunchTemplateName") + template_id = self._get_param("LaunchTemplateId") if name: template = self.ec2_backend.get_launch_template_by_name(name) if template_id: @@ -177,12 +198,15 @@ class LaunchTemplates(BaseResponse): filters = filters_from_querystring(self.querystring) if filters: - raise FilterNotImplementedError("all filters", "DescribeLaunchTemplateVersions") + raise FilterNotImplementedError( + "all filters", "DescribeLaunchTemplateVersions" + ) - if self.is_not_dryrun('DescribeLaunchTemplateVersions'): - tree = ElementTree.Element("DescribeLaunchTemplateVersionsResponse", { - "xmlns": "http://ec2.amazonaws.com/doc/2016-11-15/", - }) + if self.is_not_dryrun("DescribeLaunchTemplateVersions"): + tree = ElementTree.Element( + "DescribeLaunchTemplateVersionsResponse", + {"xmlns": "http://ec2.amazonaws.com/doc/2016-11-15/"}, + ) request_id = ElementTree.SubElement(tree, "requestId") request_id.text = "65cadec1-b364-4354-8ca8-4176dexample" @@ -209,16 +233,22 @@ class LaunchTemplates(BaseResponse): ret_versions = ret_versions[:max_results] for version in ret_versions: - xml_serialize(versions_node, "item", { - "createTime": version.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format(OWNER_ID=OWNER_ID), - "defaultVersion": True, - "launchTemplateData": version.data, - "launchTemplateId": template.id, - "launchTemplateName": template.name, - "versionDescription": version.description, - "versionNumber": version.number, - }) + xml_serialize( + versions_node, + "item", + { + "createTime": version.create_time, + "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( + OWNER_ID=OWNER_ID + ), + "defaultVersion": True, + "launchTemplateData": version.data, + "launchTemplateId": template.id, + "launchTemplateName": template.name, + "versionDescription": version.description, + "versionNumber": version.number, + }, + ) return pretty_xml(tree) @@ -232,19 +262,29 @@ class LaunchTemplates(BaseResponse): tree = ElementTree.Element("DescribeLaunchTemplatesResponse") templates_node = ElementTree.SubElement(tree, "launchTemplates") - templates = self.ec2_backend.get_launch_templates(template_names=template_names, template_ids=template_ids, filters=filters) + templates = self.ec2_backend.get_launch_templates( + template_names=template_names, + template_ids=template_ids, + filters=filters, + ) templates = templates[:max_results] for template in templates: - xml_serialize(templates_node, "item", { - "createTime": template.create_time, - "createdBy": "arn:aws:iam::{OWNER_ID}:root".format(OWNER_ID=OWNER_ID), - "defaultVersionNumber": template.default_version_number, - "latestVersionNumber": template.latest_version_number, - "launchTemplateId": template.id, - "launchTemplateName": template.name, - }) + xml_serialize( + templates_node, + "item", + { + "createTime": template.create_time, + "createdBy": "arn:aws:iam::{OWNER_ID}:root".format( + OWNER_ID=OWNER_ID + ), + "defaultVersionNumber": template.default_version_number, + "latestVersionNumber": template.latest_version_number, + "launchTemplateId": template.id, + "launchTemplateName": template.name, + }, + ) return pretty_xml(tree) diff --git a/moto/ec2/responses/monitoring.py b/moto/ec2/responses/monitoring.py index 2024abe7e..4ef8db087 100644 --- a/moto/ec2/responses/monitoring.py +++ b/moto/ec2/responses/monitoring.py @@ -3,13 +3,14 @@ from moto.core.responses import BaseResponse class Monitoring(BaseResponse): - def monitor_instances(self): - if self.is_not_dryrun('MonitorInstances'): + if self.is_not_dryrun("MonitorInstances"): raise NotImplementedError( - 'Monitoring.monitor_instances is not yet implemented') + "Monitoring.monitor_instances is not yet implemented" + ) def unmonitor_instances(self): - if self.is_not_dryrun('UnMonitorInstances'): + if self.is_not_dryrun("UnMonitorInstances"): raise NotImplementedError( - 'Monitoring.unmonitor_instances is not yet implemented') + "Monitoring.unmonitor_instances is not yet implemented" + ) diff --git a/moto/ec2/responses/nat_gateways.py b/moto/ec2/responses/nat_gateways.py index ce9479e82..efa5c2656 100644 --- a/moto/ec2/responses/nat_gateways.py +++ b/moto/ec2/responses/nat_gateways.py @@ -4,17 +4,17 @@ from moto.ec2.utils import filters_from_querystring class NatGateways(BaseResponse): - def create_nat_gateway(self): - subnet_id = self._get_param('SubnetId') - allocation_id = self._get_param('AllocationId') + subnet_id = self._get_param("SubnetId") + allocation_id = self._get_param("AllocationId") nat_gateway = self.ec2_backend.create_nat_gateway( - subnet_id=subnet_id, allocation_id=allocation_id) + subnet_id=subnet_id, allocation_id=allocation_id + ) template = self.response_template(CREATE_NAT_GATEWAY) return template.render(nat_gateway=nat_gateway) def delete_nat_gateway(self): - nat_gateway_id = self._get_param('NatGatewayId') + nat_gateway_id = self._get_param("NatGatewayId") nat_gateway = self.ec2_backend.delete_nat_gateway(nat_gateway_id) template = self.response_template(DELETE_NAT_GATEWAY_RESPONSE) return template.render(nat_gateway=nat_gateway) diff --git a/moto/ec2/responses/network_acls.py b/moto/ec2/responses/network_acls.py index 97f370306..8d89e6065 100644 --- a/moto/ec2/responses/network_acls.py +++ b/moto/ec2/responses/network_acls.py @@ -4,82 +4,95 @@ from moto.ec2.utils import filters_from_querystring class NetworkACLs(BaseResponse): - def create_network_acl(self): - vpc_id = self._get_param('VpcId') + vpc_id = self._get_param("VpcId") network_acl = self.ec2_backend.create_network_acl(vpc_id) template = self.response_template(CREATE_NETWORK_ACL_RESPONSE) return template.render(network_acl=network_acl) def create_network_acl_entry(self): - network_acl_id = self._get_param('NetworkAclId') - rule_number = self._get_param('RuleNumber') - protocol = self._get_param('Protocol') - rule_action = self._get_param('RuleAction') - egress = self._get_param('Egress') - cidr_block = self._get_param('CidrBlock') - icmp_code = self._get_param('Icmp.Code') - icmp_type = self._get_param('Icmp.Type') - port_range_from = self._get_param('PortRange.From') - port_range_to = self._get_param('PortRange.To') + network_acl_id = self._get_param("NetworkAclId") + rule_number = self._get_param("RuleNumber") + protocol = self._get_param("Protocol") + rule_action = self._get_param("RuleAction") + egress = self._get_param("Egress") + cidr_block = self._get_param("CidrBlock") + icmp_code = self._get_param("Icmp.Code") + icmp_type = self._get_param("Icmp.Type") + port_range_from = self._get_param("PortRange.From") + port_range_to = self._get_param("PortRange.To") network_acl_entry = self.ec2_backend.create_network_acl_entry( - network_acl_id, rule_number, protocol, rule_action, - egress, cidr_block, icmp_code, icmp_type, - port_range_from, port_range_to) + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ) template = self.response_template(CREATE_NETWORK_ACL_ENTRY_RESPONSE) return template.render(network_acl_entry=network_acl_entry) def delete_network_acl(self): - network_acl_id = self._get_param('NetworkAclId') + network_acl_id = self._get_param("NetworkAclId") self.ec2_backend.delete_network_acl(network_acl_id) template = self.response_template(DELETE_NETWORK_ACL_ASSOCIATION) return template.render() def delete_network_acl_entry(self): - network_acl_id = self._get_param('NetworkAclId') - rule_number = self._get_param('RuleNumber') - egress = self._get_param('Egress') + network_acl_id = self._get_param("NetworkAclId") + rule_number = self._get_param("RuleNumber") + egress = self._get_param("Egress") self.ec2_backend.delete_network_acl_entry(network_acl_id, rule_number, egress) template = self.response_template(DELETE_NETWORK_ACL_ENTRY_RESPONSE) return template.render() def replace_network_acl_entry(self): - network_acl_id = self._get_param('NetworkAclId') - rule_number = self._get_param('RuleNumber') - protocol = self._get_param('Protocol') - rule_action = self._get_param('RuleAction') - egress = self._get_param('Egress') - cidr_block = self._get_param('CidrBlock') - icmp_code = self._get_param('Icmp.Code') - icmp_type = self._get_param('Icmp.Type') - port_range_from = self._get_param('PortRange.From') - port_range_to = self._get_param('PortRange.To') + network_acl_id = self._get_param("NetworkAclId") + rule_number = self._get_param("RuleNumber") + protocol = self._get_param("Protocol") + rule_action = self._get_param("RuleAction") + egress = self._get_param("Egress") + cidr_block = self._get_param("CidrBlock") + icmp_code = self._get_param("Icmp.Code") + icmp_type = self._get_param("Icmp.Type") + port_range_from = self._get_param("PortRange.From") + port_range_to = self._get_param("PortRange.To") self.ec2_backend.replace_network_acl_entry( - network_acl_id, rule_number, protocol, rule_action, - egress, cidr_block, icmp_code, icmp_type, - port_range_from, port_range_to) + network_acl_id, + rule_number, + protocol, + rule_action, + egress, + cidr_block, + icmp_code, + icmp_type, + port_range_from, + port_range_to, + ) template = self.response_template(REPLACE_NETWORK_ACL_ENTRY_RESPONSE) return template.render() def describe_network_acls(self): - network_acl_ids = self._get_multi_param('NetworkAclId') + network_acl_ids = self._get_multi_param("NetworkAclId") filters = filters_from_querystring(self.querystring) - network_acls = self.ec2_backend.get_all_network_acls( - network_acl_ids, filters) + network_acls = self.ec2_backend.get_all_network_acls(network_acl_ids, filters) template = self.response_template(DESCRIBE_NETWORK_ACL_RESPONSE) return template.render(network_acls=network_acls) def replace_network_acl_association(self): - association_id = self._get_param('AssociationId') - network_acl_id = self._get_param('NetworkAclId') + association_id = self._get_param("AssociationId") + network_acl_id = self._get_param("NetworkAclId") association = self.ec2_backend.replace_network_acl_association( - association_id, - network_acl_id + association_id, network_acl_id ) template = self.response_template(REPLACE_NETWORK_ACL_ASSOCIATION) return template.render(association=association) diff --git a/moto/ec2/responses/placement_groups.py b/moto/ec2/responses/placement_groups.py index 06930f700..2a7ade653 100644 --- a/moto/ec2/responses/placement_groups.py +++ b/moto/ec2/responses/placement_groups.py @@ -3,17 +3,19 @@ from moto.core.responses import BaseResponse class PlacementGroups(BaseResponse): - def create_placement_group(self): - if self.is_not_dryrun('CreatePlacementGroup'): + if self.is_not_dryrun("CreatePlacementGroup"): raise NotImplementedError( - 'PlacementGroups.create_placement_group is not yet implemented') + "PlacementGroups.create_placement_group is not yet implemented" + ) def delete_placement_group(self): - if self.is_not_dryrun('DeletePlacementGroup'): + if self.is_not_dryrun("DeletePlacementGroup"): raise NotImplementedError( - 'PlacementGroups.delete_placement_group is not yet implemented') + "PlacementGroups.delete_placement_group is not yet implemented" + ) def describe_placement_groups(self): raise NotImplementedError( - 'PlacementGroups.describe_placement_groups is not yet implemented') + "PlacementGroups.describe_placement_groups is not yet implemented" + ) diff --git a/moto/ec2/responses/reserved_instances.py b/moto/ec2/responses/reserved_instances.py index 07bd6661e..23a2b8715 100644 --- a/moto/ec2/responses/reserved_instances.py +++ b/moto/ec2/responses/reserved_instances.py @@ -3,30 +3,35 @@ from moto.core.responses import BaseResponse class ReservedInstances(BaseResponse): - def cancel_reserved_instances_listing(self): - if self.is_not_dryrun('CancelReservedInstances'): + if self.is_not_dryrun("CancelReservedInstances"): raise NotImplementedError( - 'ReservedInstances.cancel_reserved_instances_listing is not yet implemented') + "ReservedInstances.cancel_reserved_instances_listing is not yet implemented" + ) def create_reserved_instances_listing(self): - if self.is_not_dryrun('CreateReservedInstances'): + if self.is_not_dryrun("CreateReservedInstances"): raise NotImplementedError( - 'ReservedInstances.create_reserved_instances_listing is not yet implemented') + "ReservedInstances.create_reserved_instances_listing is not yet implemented" + ) def describe_reserved_instances(self): raise NotImplementedError( - 'ReservedInstances.describe_reserved_instances is not yet implemented') + "ReservedInstances.describe_reserved_instances is not yet implemented" + ) def describe_reserved_instances_listings(self): raise NotImplementedError( - 'ReservedInstances.describe_reserved_instances_listings is not yet implemented') + "ReservedInstances.describe_reserved_instances_listings is not yet implemented" + ) def describe_reserved_instances_offerings(self): raise NotImplementedError( - 'ReservedInstances.describe_reserved_instances_offerings is not yet implemented') + "ReservedInstances.describe_reserved_instances_offerings is not yet implemented" + ) def purchase_reserved_instances_offering(self): - if self.is_not_dryrun('PurchaseReservedInstances'): + if self.is_not_dryrun("PurchaseReservedInstances"): raise NotImplementedError( - 'ReservedInstances.purchase_reserved_instances_offering is not yet implemented') + "ReservedInstances.purchase_reserved_instances_offering is not yet implemented" + ) diff --git a/moto/ec2/responses/route_tables.py b/moto/ec2/responses/route_tables.py index 3878f325d..ef796e401 100644 --- a/moto/ec2/responses/route_tables.py +++ b/moto/ec2/responses/route_tables.py @@ -4,89 +4,94 @@ from moto.ec2.utils import filters_from_querystring class RouteTables(BaseResponse): - def associate_route_table(self): - route_table_id = self._get_param('RouteTableId') - subnet_id = self._get_param('SubnetId') + route_table_id = self._get_param("RouteTableId") + subnet_id = self._get_param("SubnetId") association_id = self.ec2_backend.associate_route_table( - route_table_id, subnet_id) + route_table_id, subnet_id + ) template = self.response_template(ASSOCIATE_ROUTE_TABLE_RESPONSE) return template.render(association_id=association_id) def create_route(self): - route_table_id = self._get_param('RouteTableId') - destination_cidr_block = self._get_param('DestinationCidrBlock') - gateway_id = self._get_param('GatewayId') - instance_id = self._get_param('InstanceId') - interface_id = self._get_param('NetworkInterfaceId') - pcx_id = self._get_param('VpcPeeringConnectionId') + route_table_id = self._get_param("RouteTableId") + destination_cidr_block = self._get_param("DestinationCidrBlock") + gateway_id = self._get_param("GatewayId") + instance_id = self._get_param("InstanceId") + interface_id = self._get_param("NetworkInterfaceId") + pcx_id = self._get_param("VpcPeeringConnectionId") - self.ec2_backend.create_route(route_table_id, destination_cidr_block, - gateway_id=gateway_id, - instance_id=instance_id, - interface_id=interface_id, - vpc_peering_connection_id=pcx_id) + self.ec2_backend.create_route( + route_table_id, + destination_cidr_block, + gateway_id=gateway_id, + instance_id=instance_id, + interface_id=interface_id, + vpc_peering_connection_id=pcx_id, + ) template = self.response_template(CREATE_ROUTE_RESPONSE) return template.render() def create_route_table(self): - vpc_id = self._get_param('VpcId') + vpc_id = self._get_param("VpcId") route_table = self.ec2_backend.create_route_table(vpc_id) template = self.response_template(CREATE_ROUTE_TABLE_RESPONSE) return template.render(route_table=route_table) def delete_route(self): - route_table_id = self._get_param('RouteTableId') - destination_cidr_block = self._get_param('DestinationCidrBlock') + route_table_id = self._get_param("RouteTableId") + destination_cidr_block = self._get_param("DestinationCidrBlock") self.ec2_backend.delete_route(route_table_id, destination_cidr_block) template = self.response_template(DELETE_ROUTE_RESPONSE) return template.render() def delete_route_table(self): - route_table_id = self._get_param('RouteTableId') + route_table_id = self._get_param("RouteTableId") self.ec2_backend.delete_route_table(route_table_id) template = self.response_template(DELETE_ROUTE_TABLE_RESPONSE) return template.render() def describe_route_tables(self): - route_table_ids = self._get_multi_param('RouteTableId') + route_table_ids = self._get_multi_param("RouteTableId") filters = filters_from_querystring(self.querystring) - route_tables = self.ec2_backend.get_all_route_tables( - route_table_ids, filters) + route_tables = self.ec2_backend.get_all_route_tables(route_table_ids, filters) template = self.response_template(DESCRIBE_ROUTE_TABLES_RESPONSE) return template.render(route_tables=route_tables) def disassociate_route_table(self): - association_id = self._get_param('AssociationId') + association_id = self._get_param("AssociationId") self.ec2_backend.disassociate_route_table(association_id) template = self.response_template(DISASSOCIATE_ROUTE_TABLE_RESPONSE) return template.render() def replace_route(self): - route_table_id = self._get_param('RouteTableId') - destination_cidr_block = self._get_param('DestinationCidrBlock') - gateway_id = self._get_param('GatewayId') - instance_id = self._get_param('InstanceId') - interface_id = self._get_param('NetworkInterfaceId') - pcx_id = self._get_param('VpcPeeringConnectionId') + route_table_id = self._get_param("RouteTableId") + destination_cidr_block = self._get_param("DestinationCidrBlock") + gateway_id = self._get_param("GatewayId") + instance_id = self._get_param("InstanceId") + interface_id = self._get_param("NetworkInterfaceId") + pcx_id = self._get_param("VpcPeeringConnectionId") - self.ec2_backend.replace_route(route_table_id, destination_cidr_block, - gateway_id=gateway_id, - instance_id=instance_id, - interface_id=interface_id, - vpc_peering_connection_id=pcx_id) + self.ec2_backend.replace_route( + route_table_id, + destination_cidr_block, + gateway_id=gateway_id, + instance_id=instance_id, + interface_id=interface_id, + vpc_peering_connection_id=pcx_id, + ) template = self.response_template(REPLACE_ROUTE_RESPONSE) return template.render() def replace_route_table_association(self): - route_table_id = self._get_param('RouteTableId') - association_id = self._get_param('AssociationId') + route_table_id = self._get_param("RouteTableId") + association_id = self._get_param("AssociationId") new_association_id = self.ec2_backend.replace_route_table_association( - association_id, route_table_id) - template = self.response_template( - REPLACE_ROUTE_TABLE_ASSOCIATION_RESPONSE) + association_id, route_table_id + ) + template = self.response_template(REPLACE_ROUTE_TABLE_ASSOCIATION_RESPONSE) return template.render(association_id=new_association_id) diff --git a/moto/ec2/responses/security_groups.py b/moto/ec2/responses/security_groups.py index 4aecfcf78..d2cfff977 100644 --- a/moto/ec2/responses/security_groups.py +++ b/moto/ec2/responses/security_groups.py @@ -12,37 +12,35 @@ def try_parse_int(value, default=None): def parse_sg_attributes_from_dict(sg_attributes): - ip_protocol = sg_attributes.get('IpProtocol', [None])[0] - from_port = sg_attributes.get('FromPort', [None])[0] - to_port = sg_attributes.get('ToPort', [None])[0] + ip_protocol = sg_attributes.get("IpProtocol", [None])[0] + from_port = sg_attributes.get("FromPort", [None])[0] + to_port = sg_attributes.get("ToPort", [None])[0] ip_ranges = [] - ip_ranges_tree = sg_attributes.get('IpRanges') or {} + ip_ranges_tree = sg_attributes.get("IpRanges") or {} for ip_range_idx in sorted(ip_ranges_tree.keys()): - ip_ranges.append(ip_ranges_tree[ip_range_idx]['CidrIp'][0]) + ip_ranges.append(ip_ranges_tree[ip_range_idx]["CidrIp"][0]) source_groups = [] source_group_ids = [] - groups_tree = sg_attributes.get('Groups') or {} + groups_tree = sg_attributes.get("Groups") or {} for group_idx in sorted(groups_tree.keys()): group_dict = groups_tree[group_idx] - if 'GroupId' in group_dict: - source_group_ids.append(group_dict['GroupId'][0]) - elif 'GroupName' in group_dict: - source_groups.append(group_dict['GroupName'][0]) + if "GroupId" in group_dict: + source_group_ids.append(group_dict["GroupId"][0]) + elif "GroupName" in group_dict: + source_groups.append(group_dict["GroupName"][0]) return ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids class SecurityGroups(BaseResponse): - def _process_rules_from_querystring(self): - group_name_or_id = (self._get_param('GroupName') or - self._get_param('GroupId')) + group_name_or_id = self._get_param("GroupName") or self._get_param("GroupId") querytree = {} for key, value in self.querystring.items(): - key_splitted = key.split('.') + key_splitted = key.split(".") key_splitted = [try_parse_int(e, e) for e in key_splitted] d = querytree @@ -52,41 +50,70 @@ class SecurityGroups(BaseResponse): d = d[subkey] d[key_splitted[-1]] = value - if 'IpPermissions' not in querytree: + if "IpPermissions" not in querytree: # Handle single rule syntax - ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids = parse_sg_attributes_from_dict(querytree) - yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges, - source_groups, source_group_ids) + ( + ip_protocol, + from_port, + to_port, + ip_ranges, + source_groups, + source_group_ids, + ) = parse_sg_attributes_from_dict(querytree) + yield ( + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_groups, + source_group_ids, + ) - ip_permissions = querytree.get('IpPermissions') or {} + ip_permissions = querytree.get("IpPermissions") or {} for ip_permission_idx in sorted(ip_permissions.keys()): ip_permission = ip_permissions[ip_permission_idx] - ip_protocol, from_port, to_port, ip_ranges, source_groups, source_group_ids = parse_sg_attributes_from_dict(ip_permission) + ( + ip_protocol, + from_port, + to_port, + ip_ranges, + source_groups, + source_group_ids, + ) = parse_sg_attributes_from_dict(ip_permission) - yield (group_name_or_id, ip_protocol, from_port, to_port, ip_ranges, - source_groups, source_group_ids) + yield ( + group_name_or_id, + ip_protocol, + from_port, + to_port, + ip_ranges, + source_groups, + source_group_ids, + ) def authorize_security_group_egress(self): - if self.is_not_dryrun('GrantSecurityGroupEgress'): + if self.is_not_dryrun("GrantSecurityGroupEgress"): for args in self._process_rules_from_querystring(): self.ec2_backend.authorize_security_group_egress(*args) return AUTHORIZE_SECURITY_GROUP_EGRESS_RESPONSE def authorize_security_group_ingress(self): - if self.is_not_dryrun('GrantSecurityGroupIngress'): + if self.is_not_dryrun("GrantSecurityGroupIngress"): for args in self._process_rules_from_querystring(): self.ec2_backend.authorize_security_group_ingress(*args) return AUTHORIZE_SECURITY_GROUP_INGRESS_REPONSE def create_security_group(self): - name = self._get_param('GroupName') - description = self._get_param('GroupDescription') - vpc_id = self._get_param('VpcId') + name = self._get_param("GroupName") + description = self._get_param("GroupDescription") + vpc_id = self._get_param("VpcId") - if self.is_not_dryrun('CreateSecurityGroup'): + if self.is_not_dryrun("CreateSecurityGroup"): group = self.ec2_backend.create_security_group( - name, description, vpc_id=vpc_id) + name, description, vpc_id=vpc_id + ) template = self.response_template(CREATE_SECURITY_GROUP_RESPONSE) return template.render(group=group) @@ -95,10 +122,10 @@ class SecurityGroups(BaseResponse): # See # http://docs.aws.amazon.com/AWSEC2/latest/APIReference/ApiReference-query-DeleteSecurityGroup.html - name = self._get_param('GroupName') - sg_id = self._get_param('GroupId') + name = self._get_param("GroupName") + sg_id = self._get_param("GroupId") - if self.is_not_dryrun('DeleteSecurityGroup'): + if self.is_not_dryrun("DeleteSecurityGroup"): if name: self.ec2_backend.delete_security_group(name) elif sg_id: @@ -112,16 +139,14 @@ class SecurityGroups(BaseResponse): filters = filters_from_querystring(self.querystring) groups = self.ec2_backend.describe_security_groups( - group_ids=group_ids, - groupnames=groupnames, - filters=filters + group_ids=group_ids, groupnames=groupnames, filters=filters ) template = self.response_template(DESCRIBE_SECURITY_GROUPS_RESPONSE) return template.render(groups=groups) def revoke_security_group_egress(self): - if self.is_not_dryrun('RevokeSecurityGroupEgress'): + if self.is_not_dryrun("RevokeSecurityGroupEgress"): for args in self._process_rules_from_querystring(): success = self.ec2_backend.revoke_security_group_egress(*args) if not success: @@ -129,7 +154,7 @@ class SecurityGroups(BaseResponse): return REVOKE_SECURITY_GROUP_EGRESS_RESPONSE def revoke_security_group_ingress(self): - if self.is_not_dryrun('RevokeSecurityGroupIngress'): + if self.is_not_dryrun("RevokeSecurityGroupIngress"): for args in self._process_rules_from_querystring(): self.ec2_backend.revoke_security_group_ingress(*args) return REVOKE_SECURITY_GROUP_INGRESS_REPONSE diff --git a/moto/ec2/responses/spot_fleets.py b/moto/ec2/responses/spot_fleets.py index bb9aeb4ca..b7de85323 100644 --- a/moto/ec2/responses/spot_fleets.py +++ b/moto/ec2/responses/spot_fleets.py @@ -3,12 +3,12 @@ from moto.core.responses import BaseResponse class SpotFleets(BaseResponse): - def cancel_spot_fleet_requests(self): spot_fleet_request_ids = self._get_multi_param("SpotFleetRequestId.") terminate_instances = self._get_param("TerminateInstances") spot_fleets = self.ec2_backend.cancel_spot_fleet_requests( - spot_fleet_request_ids, terminate_instances) + spot_fleet_request_ids, terminate_instances + ) template = self.response_template(CANCEL_SPOT_FLEETS_TEMPLATE) return template.render(spot_fleets=spot_fleets) @@ -16,37 +16,42 @@ class SpotFleets(BaseResponse): spot_fleet_request_id = self._get_param("SpotFleetRequestId") spot_requests = self.ec2_backend.describe_spot_fleet_instances( - spot_fleet_request_id) - template = self.response_template( - DESCRIBE_SPOT_FLEET_INSTANCES_TEMPLATE) - return template.render(spot_request_id=spot_fleet_request_id, spot_requests=spot_requests) + spot_fleet_request_id + ) + template = self.response_template(DESCRIBE_SPOT_FLEET_INSTANCES_TEMPLATE) + return template.render( + spot_request_id=spot_fleet_request_id, spot_requests=spot_requests + ) def describe_spot_fleet_requests(self): spot_fleet_request_ids = self._get_multi_param("SpotFleetRequestId.") - requests = self.ec2_backend.describe_spot_fleet_requests( - spot_fleet_request_ids) + requests = self.ec2_backend.describe_spot_fleet_requests(spot_fleet_request_ids) template = self.response_template(DESCRIBE_SPOT_FLEET_TEMPLATE) return template.render(requests=requests) def modify_spot_fleet_request(self): spot_fleet_request_id = self._get_param("SpotFleetRequestId") target_capacity = self._get_int_param("TargetCapacity") - terminate_instances = self._get_param("ExcessCapacityTerminationPolicy", if_none="Default") + terminate_instances = self._get_param( + "ExcessCapacityTerminationPolicy", if_none="Default" + ) successful = self.ec2_backend.modify_spot_fleet_request( - spot_fleet_request_id, target_capacity, terminate_instances) + spot_fleet_request_id, target_capacity, terminate_instances + ) template = self.response_template(MODIFY_SPOT_FLEET_REQUEST_TEMPLATE) return template.render(successful=successful) def request_spot_fleet(self): spot_config = self._get_dict_param("SpotFleetRequestConfig.") - spot_price = spot_config.get('spot_price') - target_capacity = spot_config['target_capacity'] - iam_fleet_role = spot_config['iam_fleet_role'] - allocation_strategy = spot_config['allocation_strategy'] + spot_price = spot_config.get("spot_price") + target_capacity = spot_config["target_capacity"] + iam_fleet_role = spot_config["iam_fleet_role"] + allocation_strategy = spot_config["allocation_strategy"] launch_specs = self._get_list_prefix( - "SpotFleetRequestConfig.LaunchSpecifications") + "SpotFleetRequestConfig.LaunchSpecifications" + ) request = self.ec2_backend.request_spot_fleet( spot_price=spot_price, diff --git a/moto/ec2/responses/spot_instances.py b/moto/ec2/responses/spot_instances.py index b0e80a320..392ad9524 100644 --- a/moto/ec2/responses/spot_instances.py +++ b/moto/ec2/responses/spot_instances.py @@ -4,64 +4,61 @@ from moto.ec2.utils import filters_from_querystring class SpotInstances(BaseResponse): - def cancel_spot_instance_requests(self): - request_ids = self._get_multi_param('SpotInstanceRequestId') - if self.is_not_dryrun('CancelSpotInstance'): - requests = self.ec2_backend.cancel_spot_instance_requests( - request_ids) + request_ids = self._get_multi_param("SpotInstanceRequestId") + if self.is_not_dryrun("CancelSpotInstance"): + requests = self.ec2_backend.cancel_spot_instance_requests(request_ids) template = self.response_template(CANCEL_SPOT_INSTANCES_TEMPLATE) return template.render(requests=requests) def create_spot_datafeed_subscription(self): - if self.is_not_dryrun('CreateSpotDatafeedSubscription'): + if self.is_not_dryrun("CreateSpotDatafeedSubscription"): raise NotImplementedError( - 'SpotInstances.create_spot_datafeed_subscription is not yet implemented') + "SpotInstances.create_spot_datafeed_subscription is not yet implemented" + ) def delete_spot_datafeed_subscription(self): - if self.is_not_dryrun('DeleteSpotDatafeedSubscription'): + if self.is_not_dryrun("DeleteSpotDatafeedSubscription"): raise NotImplementedError( - 'SpotInstances.delete_spot_datafeed_subscription is not yet implemented') + "SpotInstances.delete_spot_datafeed_subscription is not yet implemented" + ) def describe_spot_datafeed_subscription(self): raise NotImplementedError( - 'SpotInstances.describe_spot_datafeed_subscription is not yet implemented') + "SpotInstances.describe_spot_datafeed_subscription is not yet implemented" + ) def describe_spot_instance_requests(self): filters = filters_from_querystring(self.querystring) - requests = self.ec2_backend.describe_spot_instance_requests( - filters=filters) + requests = self.ec2_backend.describe_spot_instance_requests(filters=filters) template = self.response_template(DESCRIBE_SPOT_INSTANCES_TEMPLATE) return template.render(requests=requests) def describe_spot_price_history(self): raise NotImplementedError( - 'SpotInstances.describe_spot_price_history is not yet implemented') + "SpotInstances.describe_spot_price_history is not yet implemented" + ) def request_spot_instances(self): - price = self._get_param('SpotPrice') - image_id = self._get_param('LaunchSpecification.ImageId') - count = self._get_int_param('InstanceCount', 1) - type = self._get_param('Type', 'one-time') - valid_from = self._get_param('ValidFrom') - valid_until = self._get_param('ValidUntil') - launch_group = self._get_param('LaunchGroup') - availability_zone_group = self._get_param('AvailabilityZoneGroup') - key_name = self._get_param('LaunchSpecification.KeyName') - security_groups = self._get_multi_param( - 'LaunchSpecification.SecurityGroup') - user_data = self._get_param('LaunchSpecification.UserData') - instance_type = self._get_param( - 'LaunchSpecification.InstanceType', 'm1.small') - placement = self._get_param( - 'LaunchSpecification.Placement.AvailabilityZone') - kernel_id = self._get_param('LaunchSpecification.KernelId') - ramdisk_id = self._get_param('LaunchSpecification.RamdiskId') - monitoring_enabled = self._get_param( - 'LaunchSpecification.Monitoring.Enabled') - subnet_id = self._get_param('LaunchSpecification.SubnetId') + price = self._get_param("SpotPrice") + image_id = self._get_param("LaunchSpecification.ImageId") + count = self._get_int_param("InstanceCount", 1) + type = self._get_param("Type", "one-time") + valid_from = self._get_param("ValidFrom") + valid_until = self._get_param("ValidUntil") + launch_group = self._get_param("LaunchGroup") + availability_zone_group = self._get_param("AvailabilityZoneGroup") + key_name = self._get_param("LaunchSpecification.KeyName") + security_groups = self._get_multi_param("LaunchSpecification.SecurityGroup") + user_data = self._get_param("LaunchSpecification.UserData") + instance_type = self._get_param("LaunchSpecification.InstanceType", "m1.small") + placement = self._get_param("LaunchSpecification.Placement.AvailabilityZone") + kernel_id = self._get_param("LaunchSpecification.KernelId") + ramdisk_id = self._get_param("LaunchSpecification.RamdiskId") + monitoring_enabled = self._get_param("LaunchSpecification.Monitoring.Enabled") + subnet_id = self._get_param("LaunchSpecification.SubnetId") - if self.is_not_dryrun('RequestSpotInstance'): + if self.is_not_dryrun("RequestSpotInstance"): requests = self.ec2_backend.request_spot_instances( price=price, image_id=image_id, diff --git a/moto/ec2/responses/subnets.py b/moto/ec2/responses/subnets.py index 0412d9e8b..c42583f23 100644 --- a/moto/ec2/responses/subnets.py +++ b/moto/ec2/responses/subnets.py @@ -6,44 +6,42 @@ from moto.ec2.utils import filters_from_querystring class Subnets(BaseResponse): - def create_subnet(self): - vpc_id = self._get_param('VpcId') - cidr_block = self._get_param('CidrBlock') + vpc_id = self._get_param("VpcId") + cidr_block = self._get_param("CidrBlock") availability_zone = self._get_param( - 'AvailabilityZone', if_none=random.choice( - self.ec2_backend.describe_availability_zones()).name) + "AvailabilityZone", + if_none=random.choice(self.ec2_backend.describe_availability_zones()).name, + ) subnet = self.ec2_backend.create_subnet( - vpc_id, - cidr_block, - availability_zone, - context=self, + vpc_id, cidr_block, availability_zone, context=self ) template = self.response_template(CREATE_SUBNET_RESPONSE) return template.render(subnet=subnet) def delete_subnet(self): - subnet_id = self._get_param('SubnetId') + subnet_id = self._get_param("SubnetId") subnet = self.ec2_backend.delete_subnet(subnet_id) template = self.response_template(DELETE_SUBNET_RESPONSE) return template.render(subnet=subnet) def describe_subnets(self): - subnet_ids = self._get_multi_param('SubnetId') + subnet_ids = self._get_multi_param("SubnetId") filters = filters_from_querystring(self.querystring) subnets = self.ec2_backend.get_all_subnets(subnet_ids, filters) template = self.response_template(DESCRIBE_SUBNETS_RESPONSE) return template.render(subnets=subnets) def modify_subnet_attribute(self): - subnet_id = self._get_param('SubnetId') + subnet_id = self._get_param("SubnetId") - for attribute in ('MapPublicIpOnLaunch', 'AssignIpv6AddressOnCreation'): - if self.querystring.get('%s.Value' % attribute): + for attribute in ("MapPublicIpOnLaunch", "AssignIpv6AddressOnCreation"): + if self.querystring.get("%s.Value" % attribute): attr_name = camelcase_to_underscores(attribute) - attr_value = self.querystring.get('%s.Value' % attribute)[0] + attr_value = self.querystring.get("%s.Value" % attribute)[0] self.ec2_backend.modify_subnet_attribute( - subnet_id, attr_name, attr_value) + subnet_id, attr_name, attr_value + ) return MODIFY_SUBNET_ATTRIBUTE_RESPONSE diff --git a/moto/ec2/responses/tags.py b/moto/ec2/responses/tags.py index 65d3da255..5290b7409 100644 --- a/moto/ec2/responses/tags.py +++ b/moto/ec2/responses/tags.py @@ -6,21 +6,20 @@ from moto.ec2.utils import tags_from_query_string, filters_from_querystring class TagResponse(BaseResponse): - def create_tags(self): - resource_ids = self._get_multi_param('ResourceId') + resource_ids = self._get_multi_param("ResourceId") validate_resource_ids(resource_ids) self.ec2_backend.do_resources_exist(resource_ids) tags = tags_from_query_string(self.querystring) - if self.is_not_dryrun('CreateTags'): + if self.is_not_dryrun("CreateTags"): self.ec2_backend.create_tags(resource_ids, tags) return CREATE_RESPONSE def delete_tags(self): - resource_ids = self._get_multi_param('ResourceId') + resource_ids = self._get_multi_param("ResourceId") validate_resource_ids(resource_ids) tags = tags_from_query_string(self.querystring) - if self.is_not_dryrun('DeleteTags'): + if self.is_not_dryrun("DeleteTags"): self.ec2_backend.delete_tags(resource_ids, tags) return DELETE_RESPONSE diff --git a/moto/ec2/responses/virtual_private_gateways.py b/moto/ec2/responses/virtual_private_gateways.py index 75de31b93..ce30aa9b2 100644 --- a/moto/ec2/responses/virtual_private_gateways.py +++ b/moto/ec2/responses/virtual_private_gateways.py @@ -4,25 +4,21 @@ from moto.ec2.utils import filters_from_querystring class VirtualPrivateGateways(BaseResponse): - def attach_vpn_gateway(self): - vpn_gateway_id = self._get_param('VpnGatewayId') - vpc_id = self._get_param('VpcId') - attachment = self.ec2_backend.attach_vpn_gateway( - vpn_gateway_id, - vpc_id - ) + vpn_gateway_id = self._get_param("VpnGatewayId") + vpc_id = self._get_param("VpcId") + attachment = self.ec2_backend.attach_vpn_gateway(vpn_gateway_id, vpc_id) template = self.response_template(ATTACH_VPN_GATEWAY_RESPONSE) return template.render(attachment=attachment) def create_vpn_gateway(self): - type = self._get_param('Type') + type = self._get_param("Type") vpn_gateway = self.ec2_backend.create_vpn_gateway(type) template = self.response_template(CREATE_VPN_GATEWAY_RESPONSE) return template.render(vpn_gateway=vpn_gateway) def delete_vpn_gateway(self): - vpn_gateway_id = self._get_param('VpnGatewayId') + vpn_gateway_id = self._get_param("VpnGatewayId") vpn_gateway = self.ec2_backend.delete_vpn_gateway(vpn_gateway_id) template = self.response_template(DELETE_VPN_GATEWAY_RESPONSE) return template.render(vpn_gateway=vpn_gateway) @@ -34,12 +30,9 @@ class VirtualPrivateGateways(BaseResponse): return template.render(vpn_gateways=vpn_gateways) def detach_vpn_gateway(self): - vpn_gateway_id = self._get_param('VpnGatewayId') - vpc_id = self._get_param('VpcId') - attachment = self.ec2_backend.detach_vpn_gateway( - vpn_gateway_id, - vpc_id - ) + vpn_gateway_id = self._get_param("VpnGatewayId") + vpc_id = self._get_param("VpcId") + attachment = self.ec2_backend.detach_vpn_gateway(vpn_gateway_id, vpc_id) template = self.response_template(DETACH_VPN_GATEWAY_RESPONSE) return template.render(attachment=attachment) diff --git a/moto/ec2/responses/vm_export.py b/moto/ec2/responses/vm_export.py index 6fdf59ba3..a4c831fcb 100644 --- a/moto/ec2/responses/vm_export.py +++ b/moto/ec2/responses/vm_export.py @@ -3,15 +3,15 @@ from moto.core.responses import BaseResponse class VMExport(BaseResponse): - def cancel_export_task(self): - raise NotImplementedError( - 'VMExport.cancel_export_task is not yet implemented') + raise NotImplementedError("VMExport.cancel_export_task is not yet implemented") def create_instance_export_task(self): raise NotImplementedError( - 'VMExport.create_instance_export_task is not yet implemented') + "VMExport.create_instance_export_task is not yet implemented" + ) def describe_export_tasks(self): raise NotImplementedError( - 'VMExport.describe_export_tasks is not yet implemented') + "VMExport.describe_export_tasks is not yet implemented" + ) diff --git a/moto/ec2/responses/vm_import.py b/moto/ec2/responses/vm_import.py index 8c2ba138c..50f77c66c 100644 --- a/moto/ec2/responses/vm_import.py +++ b/moto/ec2/responses/vm_import.py @@ -3,19 +3,18 @@ from moto.core.responses import BaseResponse class VMImport(BaseResponse): - def cancel_conversion_task(self): raise NotImplementedError( - 'VMImport.cancel_conversion_task is not yet implemented') + "VMImport.cancel_conversion_task is not yet implemented" + ) def describe_conversion_tasks(self): raise NotImplementedError( - 'VMImport.describe_conversion_tasks is not yet implemented') + "VMImport.describe_conversion_tasks is not yet implemented" + ) def import_instance(self): - raise NotImplementedError( - 'VMImport.import_instance is not yet implemented') + raise NotImplementedError("VMImport.import_instance is not yet implemented") def import_volume(self): - raise NotImplementedError( - 'VMImport.import_volume is not yet implemented') + raise NotImplementedError("VMImport.import_volume is not yet implemented") diff --git a/moto/ec2/responses/vpc_peering_connections.py b/moto/ec2/responses/vpc_peering_connections.py index 68bae72da..ff792a6cc 100644 --- a/moto/ec2/responses/vpc_peering_connections.py +++ b/moto/ec2/responses/vpc_peering_connections.py @@ -3,44 +3,40 @@ from moto.core.responses import BaseResponse class VPCPeeringConnections(BaseResponse): - def create_vpc_peering_connection(self): - peer_region = self._get_param('PeerRegion') + peer_region = self._get_param("PeerRegion") if peer_region == self.region or peer_region is None: - peer_vpc = self.ec2_backend.get_vpc(self._get_param('PeerVpcId')) + peer_vpc = self.ec2_backend.get_vpc(self._get_param("PeerVpcId")) else: - peer_vpc = self.ec2_backend.get_cross_vpc(self._get_param('PeerVpcId'), peer_region) - vpc = self.ec2_backend.get_vpc(self._get_param('VpcId')) + peer_vpc = self.ec2_backend.get_cross_vpc( + self._get_param("PeerVpcId"), peer_region + ) + vpc = self.ec2_backend.get_vpc(self._get_param("VpcId")) vpc_pcx = self.ec2_backend.create_vpc_peering_connection(vpc, peer_vpc) - template = self.response_template( - CREATE_VPC_PEERING_CONNECTION_RESPONSE) + template = self.response_template(CREATE_VPC_PEERING_CONNECTION_RESPONSE) return template.render(vpc_pcx=vpc_pcx) def delete_vpc_peering_connection(self): - vpc_pcx_id = self._get_param('VpcPeeringConnectionId') + vpc_pcx_id = self._get_param("VpcPeeringConnectionId") vpc_pcx = self.ec2_backend.delete_vpc_peering_connection(vpc_pcx_id) - template = self.response_template( - DELETE_VPC_PEERING_CONNECTION_RESPONSE) + template = self.response_template(DELETE_VPC_PEERING_CONNECTION_RESPONSE) return template.render(vpc_pcx=vpc_pcx) def describe_vpc_peering_connections(self): vpc_pcxs = self.ec2_backend.get_all_vpc_peering_connections() - template = self.response_template( - DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE) + template = self.response_template(DESCRIBE_VPC_PEERING_CONNECTIONS_RESPONSE) return template.render(vpc_pcxs=vpc_pcxs) def accept_vpc_peering_connection(self): - vpc_pcx_id = self._get_param('VpcPeeringConnectionId') + vpc_pcx_id = self._get_param("VpcPeeringConnectionId") vpc_pcx = self.ec2_backend.accept_vpc_peering_connection(vpc_pcx_id) - template = self.response_template( - ACCEPT_VPC_PEERING_CONNECTION_RESPONSE) + template = self.response_template(ACCEPT_VPC_PEERING_CONNECTION_RESPONSE) return template.render(vpc_pcx=vpc_pcx) def reject_vpc_peering_connection(self): - vpc_pcx_id = self._get_param('VpcPeeringConnectionId') + vpc_pcx_id = self._get_param("VpcPeeringConnectionId") self.ec2_backend.reject_vpc_peering_connection(vpc_pcx_id) - template = self.response_template( - REJECT_VPC_PEERING_CONNECTION_RESPONSE) + template = self.response_template(REJECT_VPC_PEERING_CONNECTION_RESPONSE) return template.render() diff --git a/moto/ec2/responses/vpcs.py b/moto/ec2/responses/vpcs.py index 88673d863..1773e4cc8 100644 --- a/moto/ec2/responses/vpcs.py +++ b/moto/ec2/responses/vpcs.py @@ -5,74 +5,102 @@ from moto.ec2.utils import filters_from_querystring class VPCs(BaseResponse): - def create_vpc(self): - cidr_block = self._get_param('CidrBlock') - instance_tenancy = self._get_param('InstanceTenancy', if_none='default') - amazon_provided_ipv6_cidr_blocks = self._get_param('AmazonProvidedIpv6CidrBlock') - vpc = self.ec2_backend.create_vpc(cidr_block, instance_tenancy, - amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_blocks) - doc_date = '2013-10-15' if 'Boto/' in self.headers.get('user-agent', '') else '2016-11-15' + cidr_block = self._get_param("CidrBlock") + instance_tenancy = self._get_param("InstanceTenancy", if_none="default") + amazon_provided_ipv6_cidr_blocks = self._get_param( + "AmazonProvidedIpv6CidrBlock" + ) + vpc = self.ec2_backend.create_vpc( + cidr_block, + instance_tenancy, + amazon_provided_ipv6_cidr_block=amazon_provided_ipv6_cidr_blocks, + ) + doc_date = ( + "2013-10-15" + if "Boto/" in self.headers.get("user-agent", "") + else "2016-11-15" + ) template = self.response_template(CREATE_VPC_RESPONSE) return template.render(vpc=vpc, doc_date=doc_date) def delete_vpc(self): - vpc_id = self._get_param('VpcId') + vpc_id = self._get_param("VpcId") vpc = self.ec2_backend.delete_vpc(vpc_id) template = self.response_template(DELETE_VPC_RESPONSE) return template.render(vpc=vpc) def describe_vpcs(self): - vpc_ids = self._get_multi_param('VpcId') + vpc_ids = self._get_multi_param("VpcId") filters = filters_from_querystring(self.querystring) vpcs = self.ec2_backend.get_all_vpcs(vpc_ids=vpc_ids, filters=filters) - doc_date = '2013-10-15' if 'Boto/' in self.headers.get('user-agent', '') else '2016-11-15' + doc_date = ( + "2013-10-15" + if "Boto/" in self.headers.get("user-agent", "") + else "2016-11-15" + ) template = self.response_template(DESCRIBE_VPCS_RESPONSE) return template.render(vpcs=vpcs, doc_date=doc_date) def describe_vpc_attribute(self): - vpc_id = self._get_param('VpcId') - attribute = self._get_param('Attribute') + vpc_id = self._get_param("VpcId") + attribute = self._get_param("Attribute") attr_name = camelcase_to_underscores(attribute) value = self.ec2_backend.describe_vpc_attribute(vpc_id, attr_name) template = self.response_template(DESCRIBE_VPC_ATTRIBUTE_RESPONSE) return template.render(vpc_id=vpc_id, attribute=attribute, value=value) def modify_vpc_attribute(self): - vpc_id = self._get_param('VpcId') + vpc_id = self._get_param("VpcId") - for attribute in ('EnableDnsSupport', 'EnableDnsHostnames'): - if self.querystring.get('%s.Value' % attribute): + for attribute in ("EnableDnsSupport", "EnableDnsHostnames"): + if self.querystring.get("%s.Value" % attribute): attr_name = camelcase_to_underscores(attribute) - attr_value = self.querystring.get('%s.Value' % attribute)[0] - self.ec2_backend.modify_vpc_attribute( - vpc_id, attr_name, attr_value) + attr_value = self.querystring.get("%s.Value" % attribute)[0] + self.ec2_backend.modify_vpc_attribute(vpc_id, attr_name, attr_value) return MODIFY_VPC_ATTRIBUTE_RESPONSE def associate_vpc_cidr_block(self): - vpc_id = self._get_param('VpcId') - amazon_provided_ipv6_cidr_blocks = self._get_param('AmazonProvidedIpv6CidrBlock') + vpc_id = self._get_param("VpcId") + amazon_provided_ipv6_cidr_blocks = self._get_param( + "AmazonProvidedIpv6CidrBlock" + ) # todo test on AWS if can create an association for IPV4 and IPV6 in the same call? - cidr_block = self._get_param('CidrBlock') if not amazon_provided_ipv6_cidr_blocks else None - value = self.ec2_backend.associate_vpc_cidr_block(vpc_id, cidr_block, amazon_provided_ipv6_cidr_blocks) + cidr_block = ( + self._get_param("CidrBlock") + if not amazon_provided_ipv6_cidr_blocks + else None + ) + value = self.ec2_backend.associate_vpc_cidr_block( + vpc_id, cidr_block, amazon_provided_ipv6_cidr_blocks + ) if not amazon_provided_ipv6_cidr_blocks: render_template = ASSOCIATE_VPC_CIDR_BLOCK_RESPONSE else: render_template = IPV6_ASSOCIATE_VPC_CIDR_BLOCK_RESPONSE template = self.response_template(render_template) - return template.render(vpc_id=vpc_id, value=value, cidr_block=value['cidr_block'], - association_id=value['association_id'], cidr_block_state='associating') + return template.render( + vpc_id=vpc_id, + value=value, + cidr_block=value["cidr_block"], + association_id=value["association_id"], + cidr_block_state="associating", + ) def disassociate_vpc_cidr_block(self): - association_id = self._get_param('AssociationId') + association_id = self._get_param("AssociationId") value = self.ec2_backend.disassociate_vpc_cidr_block(association_id) - if "::" in value.get('cidr_block', ''): + if "::" in value.get("cidr_block", ""): render_template = IPV6_DISASSOCIATE_VPC_CIDR_BLOCK_RESPONSE else: render_template = DISASSOCIATE_VPC_CIDR_BLOCK_RESPONSE template = self.response_template(render_template) - return template.render(vpc_id=value['vpc_id'], cidr_block=value['cidr_block'], - association_id=value['association_id'], cidr_block_state='disassociating') + return template.render( + vpc_id=value["vpc_id"], + cidr_block=value["cidr_block"], + association_id=value["association_id"], + cidr_block_state="disassociating", + ) CREATE_VPC_RESPONSE = """ diff --git a/moto/ec2/responses/vpn_connections.py b/moto/ec2/responses/vpn_connections.py index 276e3ca99..9ddd4d7d9 100644 --- a/moto/ec2/responses/vpn_connections.py +++ b/moto/ec2/responses/vpn_connections.py @@ -4,29 +4,29 @@ from moto.ec2.utils import filters_from_querystring class VPNConnections(BaseResponse): - def create_vpn_connection(self): - type = self._get_param('Type') - cgw_id = self._get_param('CustomerGatewayId') - vgw_id = self._get_param('VPNGatewayId') - static_routes = self._get_param('StaticRoutesOnly') + type = self._get_param("Type") + cgw_id = self._get_param("CustomerGatewayId") + vgw_id = self._get_param("VPNGatewayId") + static_routes = self._get_param("StaticRoutesOnly") vpn_connection = self.ec2_backend.create_vpn_connection( - type, cgw_id, vgw_id, static_routes_only=static_routes) + type, cgw_id, vgw_id, static_routes_only=static_routes + ) template = self.response_template(CREATE_VPN_CONNECTION_RESPONSE) return template.render(vpn_connection=vpn_connection) def delete_vpn_connection(self): - vpn_connection_id = self._get_param('VpnConnectionId') - vpn_connection = self.ec2_backend.delete_vpn_connection( - vpn_connection_id) + vpn_connection_id = self._get_param("VpnConnectionId") + vpn_connection = self.ec2_backend.delete_vpn_connection(vpn_connection_id) template = self.response_template(DELETE_VPN_CONNECTION_RESPONSE) return template.render(vpn_connection=vpn_connection) def describe_vpn_connections(self): - vpn_connection_ids = self._get_multi_param('VpnConnectionId') + vpn_connection_ids = self._get_multi_param("VpnConnectionId") filters = filters_from_querystring(self.querystring) vpn_connections = self.ec2_backend.get_all_vpn_connections( - vpn_connection_ids=vpn_connection_ids, filters=filters) + vpn_connection_ids=vpn_connection_ids, filters=filters + ) template = self.response_template(DESCRIBE_VPN_CONNECTION_RESPONSE) return template.render(vpn_connections=vpn_connections) diff --git a/moto/ec2/responses/windows.py b/moto/ec2/responses/windows.py index 13dfa9b67..14b2b0666 100644 --- a/moto/ec2/responses/windows.py +++ b/moto/ec2/responses/windows.py @@ -3,19 +3,16 @@ from moto.core.responses import BaseResponse class Windows(BaseResponse): - def bundle_instance(self): - raise NotImplementedError( - 'Windows.bundle_instance is not yet implemented') + raise NotImplementedError("Windows.bundle_instance is not yet implemented") def cancel_bundle_task(self): - raise NotImplementedError( - 'Windows.cancel_bundle_task is not yet implemented') + raise NotImplementedError("Windows.cancel_bundle_task is not yet implemented") def describe_bundle_tasks(self): raise NotImplementedError( - 'Windows.describe_bundle_tasks is not yet implemented') + "Windows.describe_bundle_tasks is not yet implemented" + ) def get_password_data(self): - raise NotImplementedError( - 'Windows.get_password_data is not yet implemented') + raise NotImplementedError("Windows.get_password_data is not yet implemented") diff --git a/moto/ec2/urls.py b/moto/ec2/urls.py index 241ab7133..b83a9e950 100644 --- a/moto/ec2/urls.py +++ b/moto/ec2/urls.py @@ -2,10 +2,6 @@ from __future__ import unicode_literals from .responses import EC2Response -url_bases = [ - "https?://ec2.(.+).amazonaws.com(|.cn)", -] +url_bases = ["https?://ec2.(.+).amazonaws.com(|.cn)"] -url_paths = { - '{0}/': EC2Response.dispatch, -} +url_paths = {"{0}/": EC2Response.dispatch} diff --git a/moto/ec2/utils.py b/moto/ec2/utils.py index a718c7812..2301248c1 100644 --- a/moto/ec2/utils.py +++ b/moto/ec2/utils.py @@ -15,173 +15,171 @@ from sshpubkeys.keys import SSHKey EC2_RESOURCE_TO_PREFIX = { - 'customer-gateway': 'cgw', - 'dhcp-options': 'dopt', - 'image': 'ami', - 'instance': 'i', - 'internet-gateway': 'igw', - 'launch-template': 'lt', - 'nat-gateway': 'nat', - 'network-acl': 'acl', - 'network-acl-subnet-assoc': 'aclassoc', - 'network-interface': 'eni', - 'network-interface-attachment': 'eni-attach', - 'reserved-instance': 'uuid4', - 'route-table': 'rtb', - 'route-table-association': 'rtbassoc', - 'security-group': 'sg', - 'snapshot': 'snap', - 'spot-instance-request': 'sir', - 'spot-fleet-request': 'sfr', - 'subnet': 'subnet', - 'reservation': 'r', - 'volume': 'vol', - 'vpc': 'vpc', - 'vpc-cidr-association-id': 'vpc-cidr-assoc', - 'vpc-elastic-ip': 'eipalloc', - 'vpc-elastic-ip-association': 'eipassoc', - 'vpc-peering-connection': 'pcx', - 'vpn-connection': 'vpn', - 'vpn-gateway': 'vgw'} + "customer-gateway": "cgw", + "dhcp-options": "dopt", + "image": "ami", + "instance": "i", + "internet-gateway": "igw", + "launch-template": "lt", + "nat-gateway": "nat", + "network-acl": "acl", + "network-acl-subnet-assoc": "aclassoc", + "network-interface": "eni", + "network-interface-attachment": "eni-attach", + "reserved-instance": "uuid4", + "route-table": "rtb", + "route-table-association": "rtbassoc", + "security-group": "sg", + "snapshot": "snap", + "spot-instance-request": "sir", + "spot-fleet-request": "sfr", + "subnet": "subnet", + "reservation": "r", + "volume": "vol", + "vpc": "vpc", + "vpc-cidr-association-id": "vpc-cidr-assoc", + "vpc-elastic-ip": "eipalloc", + "vpc-elastic-ip-association": "eipassoc", + "vpc-peering-connection": "pcx", + "vpn-connection": "vpn", + "vpn-gateway": "vgw", +} EC2_PREFIX_TO_RESOURCE = dict((v, k) for (k, v) in EC2_RESOURCE_TO_PREFIX.items()) def random_resource_id(size=8): - chars = list(range(10)) + ['a', 'b', 'c', 'd', 'e', 'f'] - resource_id = ''.join(six.text_type(random.choice(chars)) for _ in range(size)) + chars = list(range(10)) + ["a", "b", "c", "d", "e", "f"] + resource_id = "".join(six.text_type(random.choice(chars)) for _ in range(size)) return resource_id -def random_id(prefix='', size=8): - return '{0}-{1}'.format(prefix, random_resource_id(size)) +def random_id(prefix="", size=8): + return "{0}-{1}".format(prefix, random_resource_id(size)) def random_ami_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['image']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["image"]) def random_instance_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['instance'], size=17) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["instance"], size=17) def random_reservation_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['reservation']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["reservation"]) def random_security_group_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['security-group']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["security-group"]) def random_snapshot_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['snapshot']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["snapshot"]) def random_spot_request_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['spot-instance-request']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["spot-instance-request"]) def random_spot_fleet_request_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['spot-fleet-request']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["spot-fleet-request"]) def random_subnet_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['subnet']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["subnet"]) def random_subnet_association_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['route-table-association']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["route-table-association"]) def random_network_acl_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['network-acl']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-acl"]) def random_network_acl_subnet_association_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['network-acl-subnet-assoc']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-acl-subnet-assoc"]) def random_vpn_gateway_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpn-gateway']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-gateway"]) def random_vpn_connection_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpn-connection']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpn-connection"]) def random_customer_gateway_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['customer-gateway']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["customer-gateway"]) def random_volume_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['volume']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["volume"]) def random_vpc_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpc']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc"]) def random_vpc_cidr_association_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpc-cidr-association-id']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-cidr-association-id"]) def random_vpc_peering_connection_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpc-peering-connection']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-peering-connection"]) def random_eip_association_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpc-elastic-ip-association']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-elastic-ip-association"]) def random_internet_gateway_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['internet-gateway']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["internet-gateway"]) def random_route_table_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['route-table']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["route-table"]) def random_eip_allocation_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['vpc-elastic-ip']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["vpc-elastic-ip"]) def random_dhcp_option_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['dhcp-options']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["dhcp-options"]) def random_eni_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['network-interface']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-interface"]) def random_eni_attach_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['network-interface-attachment']) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["network-interface-attachment"]) def random_nat_gateway_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['nat-gateway'], size=17) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["nat-gateway"], size=17) def random_launch_template_id(): - return random_id(prefix=EC2_RESOURCE_TO_PREFIX['launch-template'], size=17) + return random_id(prefix=EC2_RESOURCE_TO_PREFIX["launch-template"], size=17) def random_public_ip(): - return '54.214.{0}.{1}'.format(random.choice(range(255)), - random.choice(range(255))) + return "54.214.{0}.{1}".format(random.choice(range(255)), random.choice(range(255))) def random_private_ip(): - return '10.{0}.{1}.{2}'.format(random.choice(range(255)), - random.choice(range(255)), - random.choice(range(255))) + return "10.{0}.{1}.{2}".format( + random.choice(range(255)), random.choice(range(255)), random.choice(range(255)) + ) def random_ip(): return "127.{0}.{1}.{2}".format( - random.randint(0, 255), - random.randint(0, 255), - random.randint(0, 255) + random.randint(0, 255), random.randint(0, 255), random.randint(0, 255) ) @@ -194,13 +192,13 @@ def generate_route_id(route_table_id, cidr_block): def split_route_id(route_id): - values = route_id.split('~') + values = route_id.split("~") return values[0], values[1] def tags_from_query_string(querystring_dict): - prefix = 'Tag' - suffix = 'Key' + prefix = "Tag" + suffix = "Key" response_values = {} for key, value in querystring_dict.items(): if key.startswith(prefix) and key.endswith(suffix): @@ -208,14 +206,13 @@ def tags_from_query_string(querystring_dict): tag_key = querystring_dict.get("Tag.{0}.Key".format(tag_index))[0] tag_value_key = "Tag.{0}.Value".format(tag_index) if tag_value_key in querystring_dict: - response_values[tag_key] = querystring_dict.get(tag_value_key)[ - 0] + response_values[tag_key] = querystring_dict.get(tag_value_key)[0] else: response_values[tag_key] = None return response_values -def dhcp_configuration_from_querystring(querystring, option=u'DhcpConfiguration'): +def dhcp_configuration_from_querystring(querystring, option="DhcpConfiguration"): """ turn: {u'AWSAccessKeyId': [u'the_key'], @@ -234,7 +231,7 @@ def dhcp_configuration_from_querystring(querystring, option=u'DhcpConfiguration' {u'domain-name': [u'example.com'], u'domain-name-servers': [u'10.0.0.6', u'10.0.0.7']} """ - key_needle = re.compile(u'{0}.[0-9]+.Key'.format(option), re.UNICODE) + key_needle = re.compile("{0}.[0-9]+.Key".format(option), re.UNICODE) response_values = {} for key, value in querystring.items(): @@ -243,8 +240,7 @@ def dhcp_configuration_from_querystring(querystring, option=u'DhcpConfiguration' key_index = key.split(".")[1] value_index = 1 while True: - value_key = u'{0}.{1}.Value.{2}'.format( - option, key_index, value_index) + value_key = "{0}.{1}.Value.{2}".format(option, key_index, value_index) if value_key in querystring: values.extend(querystring[value_key]) else: @@ -261,8 +257,11 @@ def filters_from_querystring(querystring_dict): if match: filter_index = match.groups()[0] value_prefix = "Filter.{0}.Value".format(filter_index) - filter_values = [filter_value[0] for filter_key, filter_value in querystring_dict.items() if - filter_key.startswith(value_prefix)] + filter_values = [ + filter_value[0] + for filter_key, filter_value in querystring_dict.items() + if filter_key.startswith(value_prefix) + ] response_values[value[0]] = filter_values return response_values @@ -283,7 +282,7 @@ def dict_from_querystring(parameter, querystring_dict): def get_object_value(obj, attr): - keys = attr.split('.') + keys = attr.split(".") val = obj for key in keys: if hasattr(val, key): @@ -301,36 +300,37 @@ def get_object_value(obj, attr): def is_tag_filter(filter_name): - return (filter_name.startswith('tag:') or - filter_name.startswith('tag-value') or - filter_name.startswith('tag-key')) + return ( + filter_name.startswith("tag:") + or filter_name.startswith("tag-value") + or filter_name.startswith("tag-key") + ) def get_obj_tag(obj, filter_name): - tag_name = filter_name.replace('tag:', '', 1) - tags = dict((tag['key'], tag['value']) for tag in obj.get_tags()) + tag_name = filter_name.replace("tag:", "", 1) + tags = dict((tag["key"], tag["value"]) for tag in obj.get_tags()) return tags.get(tag_name) def get_obj_tag_names(obj): - tags = set((tag['key'] for tag in obj.get_tags())) + tags = set((tag["key"] for tag in obj.get_tags())) return tags def get_obj_tag_values(obj): - tags = set((tag['value'] for tag in obj.get_tags())) + tags = set((tag["value"] for tag in obj.get_tags())) return tags def tag_filter_matches(obj, filter_name, filter_values): - regex_filters = [re.compile(simple_aws_filter_to_re(f)) - for f in filter_values] - if filter_name == 'tag-key': + regex_filters = [re.compile(simple_aws_filter_to_re(f)) for f in filter_values] + if filter_name == "tag-key": tag_values = get_obj_tag_names(obj) - elif filter_name == 'tag-value': + elif filter_name == "tag-value": tag_values = get_obj_tag_values(obj) else: - tag_values = [get_obj_tag(obj, filter_name) or ''] + tag_values = [get_obj_tag(obj, filter_name) or ""] for tag_value in tag_values: if any(regex.match(tag_value) for regex in regex_filters): @@ -340,22 +340,22 @@ def tag_filter_matches(obj, filter_name, filter_values): filter_dict_attribute_mapping = { - 'instance-state-name': 'state', - 'instance-id': 'id', - 'state-reason-code': '_state_reason.code', - 'source-dest-check': 'source_dest_check', - 'vpc-id': 'vpc_id', - 'group-id': 'security_groups.id', - 'instance.group-id': 'security_groups.id', - 'instance.group-name': 'security_groups.name', - 'instance-type': 'instance_type', - 'private-ip-address': 'private_ip', - 'ip-address': 'public_ip', - 'availability-zone': 'placement', - 'architecture': 'architecture', - 'image-id': 'image_id', - 'network-interface.private-dns-name': 'private_dns', - 'private-dns-name': 'private_dns' + "instance-state-name": "state", + "instance-id": "id", + "state-reason-code": "_state_reason.code", + "source-dest-check": "source_dest_check", + "vpc-id": "vpc_id", + "group-id": "security_groups.id", + "instance.group-id": "security_groups.id", + "instance.group-name": "security_groups.name", + "instance-type": "instance_type", + "private-ip-address": "private_ip", + "ip-address": "public_ip", + "availability-zone": "placement", + "architecture": "architecture", + "image-id": "image_id", + "network-interface.private-dns-name": "private_dns", + "private-dns-name": "private_dns", } @@ -372,8 +372,9 @@ def passes_filter_dict(instance, filter_dict): return False else: raise NotImplementedError( - "Filter dicts have not been implemented in Moto for '%s' yet. Feel free to open an issue at https://github.com/spulec/moto/issues" % - filter_name) + "Filter dicts have not been implemented in Moto for '%s' yet. Feel free to open an issue at https://github.com/spulec/moto/issues" + % filter_name + ) return True @@ -418,7 +419,8 @@ def passes_igw_filter_dict(igw, filter_dict): else: raise NotImplementedError( "Internet Gateway filter dicts have not been implemented in Moto for '%s' yet. Feel free to open an issue at https://github.com/spulec/moto/issues", - filter_name) + filter_name, + ) return True @@ -445,7 +447,9 @@ def is_filter_matching(obj, filter, filter_value): try: value = set(value) - return (value and value.issubset(filter_value)) or value.issuperset(filter_value) + return (value and value.issubset(filter_value)) or value.issuperset( + filter_value + ) except TypeError: return value in filter_value @@ -453,47 +457,48 @@ def is_filter_matching(obj, filter, filter_value): def generic_filter(filters, objects): if filters: for (_filter, _filter_value) in filters.items(): - objects = [obj for obj in objects if is_filter_matching( - obj, _filter, _filter_value)] + objects = [ + obj + for obj in objects + if is_filter_matching(obj, _filter, _filter_value) + ] return objects def simple_aws_filter_to_re(filter_string): - tmp_filter = filter_string.replace(r'\?', '[?]') - tmp_filter = tmp_filter.replace(r'\*', '[*]') + tmp_filter = filter_string.replace(r"\?", "[?]") + tmp_filter = tmp_filter.replace(r"\*", "[*]") tmp_filter = fnmatch.translate(tmp_filter) return tmp_filter def random_key_pair(): private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048, - backend=default_backend()) + public_exponent=65537, key_size=2048, backend=default_backend() + ) private_key_material = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.TraditionalOpenSSL, - encryption_algorithm=serialization.NoEncryption()) + encryption_algorithm=serialization.NoEncryption(), + ) public_key_fingerprint = rsa_public_key_fingerprint(private_key.public_key()) return { - 'fingerprint': public_key_fingerprint, - 'material': private_key_material.decode('ascii') + "fingerprint": public_key_fingerprint, + "material": private_key_material.decode("ascii"), } def get_prefix(resource_id): - resource_id_prefix, separator, after = resource_id.partition('-') - if resource_id_prefix == EC2_RESOURCE_TO_PREFIX['network-interface']: - if after.startswith('attach'): - resource_id_prefix = EC2_RESOURCE_TO_PREFIX[ - 'network-interface-attachment'] + resource_id_prefix, separator, after = resource_id.partition("-") + if resource_id_prefix == EC2_RESOURCE_TO_PREFIX["network-interface"]: + if after.startswith("attach"): + resource_id_prefix = EC2_RESOURCE_TO_PREFIX["network-interface-attachment"] if resource_id_prefix not in EC2_RESOURCE_TO_PREFIX.values(): - uuid4hex = re.compile( - r'[0-9a-f]{12}4[0-9a-f]{3}[89ab][0-9a-f]{15}\Z', re.I) + uuid4hex = re.compile(r"[0-9a-f]{12}4[0-9a-f]{3}[89ab][0-9a-f]{15}\Z", re.I) if uuid4hex.match(resource_id) is not None: - resource_id_prefix = EC2_RESOURCE_TO_PREFIX['reserved-instance'] + resource_id_prefix = EC2_RESOURCE_TO_PREFIX["reserved-instance"] else: return None return resource_id_prefix @@ -504,13 +509,13 @@ def is_valid_resource_id(resource_id): resource_id_prefix = get_prefix(resource_id) if resource_id_prefix not in valid_prefixes: return False - resource_id_pattern = resource_id_prefix + '-[0-9a-f]{8}' + resource_id_pattern = resource_id_prefix + "-[0-9a-f]{8}" resource_pattern_re = re.compile(resource_id_pattern) return resource_pattern_re.match(resource_id) is not None def is_valid_cidr(cird): - cidr_pattern = r'^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])(\/(\d|[1-2]\d|3[0-2]))$' + cidr_pattern = r"^(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){3}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])(\/(\d|[1-2]\d|3[0-2]))$" cidr_pattern_re = re.compile(cidr_pattern) return cidr_pattern_re.match(cird) is not None @@ -528,20 +533,20 @@ def generate_instance_identity_document(instance): """ document = { - 'devPayProductCodes': None, - 'availabilityZone': instance.placement['AvailabilityZone'], - 'privateIp': instance.private_ip_address, - 'version': '2010-8-31', - 'region': instance.placement['AvailabilityZone'][:-1], - 'instanceId': instance.id, - 'billingProducts': None, - 'instanceType': instance.instance_type, - 'accountId': '012345678910', - 'pendingTime': '2015-11-19T16:32:11Z', - 'imageId': instance.image_id, - 'kernelId': instance.kernel_id, - 'ramdiskId': instance.ramdisk_id, - 'architecture': instance.architecture, + "devPayProductCodes": None, + "availabilityZone": instance.placement["AvailabilityZone"], + "privateIp": instance.private_ip_address, + "version": "2010-8-31", + "region": instance.placement["AvailabilityZone"][:-1], + "instanceId": instance.id, + "billingProducts": None, + "instanceType": instance.instance_type, + "accountId": "012345678910", + "pendingTime": "2015-11-19T16:32:11Z", + "imageId": instance.image_id, + "kernelId": instance.kernel_id, + "ramdiskId": instance.ramdisk_id, + "architecture": instance.architecture, } return document @@ -555,10 +560,10 @@ def rsa_public_key_parse(key_material): decoded_key = base64.b64decode(key_material).decode("ascii") public_key = SSHKey(decoded_key) except (sshpubkeys.exceptions.InvalidKeyException, UnicodeDecodeError): - raise ValueError('bad key') + raise ValueError("bad key") if not public_key.rsa: - raise ValueError('bad key') + raise ValueError("bad key") return public_key.rsa @@ -566,7 +571,8 @@ def rsa_public_key_parse(key_material): def rsa_public_key_fingerprint(rsa_public_key): key_data = rsa_public_key.public_bytes( encoding=serialization.Encoding.DER, - format=serialization.PublicFormat.SubjectPublicKeyInfo) + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) fingerprint_hex = hashlib.md5(key_data).hexdigest() - fingerprint = re.sub(r'([a-f0-9]{2})(?!$)', r'\1:', fingerprint_hex) + fingerprint = re.sub(r"([a-f0-9]{2})(?!$)", r"\1:", fingerprint_hex) return fingerprint diff --git a/moto/ecr/__init__.py b/moto/ecr/__init__.py index 56b2cacbb..e90cd9e4c 100644 --- a/moto/ecr/__init__.py +++ b/moto/ecr/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import ecr_backends from ..core.models import base_decorator, deprecated_base_decorator -ecr_backend = ecr_backends['us-east-1'] +ecr_backend = ecr_backends["us-east-1"] mock_ecr = base_decorator(ecr_backends) mock_ecr_deprecated = deprecated_base_decorator(ecr_backends) diff --git a/moto/ecr/exceptions.py b/moto/ecr/exceptions.py index f7b951b53..9b55f0589 100644 --- a/moto/ecr/exceptions.py +++ b/moto/ecr/exceptions.py @@ -9,7 +9,8 @@ class RepositoryNotFoundException(RESTError): super(RepositoryNotFoundException, self).__init__( error_type="RepositoryNotFoundException", message="The repository with name '{0}' does not exist in the registry " - "with id '{1}'".format(repository_name, registry_id)) + "with id '{1}'".format(repository_name, registry_id), + ) class ImageNotFoundException(RESTError): @@ -19,4 +20,7 @@ class ImageNotFoundException(RESTError): super(ImageNotFoundException, self).__init__( error_type="ImageNotFoundException", message="The image with imageId {0} does not exist within the repository with name '{1}' " - "in the registry with id '{2}'".format(image_id, repository_name, registry_id)) + "in the registry with id '{2}'".format( + image_id, repository_name, registry_id + ), + ) diff --git a/moto/ecr/models.py b/moto/ecr/models.py index b03f25dee..f84df79aa 100644 --- a/moto/ecr/models.py +++ b/moto/ecr/models.py @@ -2,7 +2,6 @@ from __future__ import unicode_literals import hashlib import re -from copy import copy from datetime import datetime from random import random @@ -12,26 +11,26 @@ from moto.core import BaseBackend, BaseModel from moto.ec2 import ec2_backends from moto.ecr.exceptions import ImageNotFoundException, RepositoryNotFoundException -DEFAULT_REGISTRY_ID = '012345678910' +DEFAULT_REGISTRY_ID = "012345678910" class BaseObject(BaseModel): - def camelCase(self, key): words = [] - for i, word in enumerate(key.split('_')): + for i, word in enumerate(key.split("_")): if i > 0: words.append(word.title()) else: words.append(word) - return ''.join(words) + return "".join(words) def gen_response_object(self): - response_object = copy(self.__dict__) - for key, value in response_object.items(): - if '_' in key: + response_object = dict() + for key, value in self.__dict__.items(): + if "_" in key: response_object[self.camelCase(key)] = value - del response_object[key] + else: + response_object[key] = value return response_object @property @@ -40,15 +39,16 @@ class BaseObject(BaseModel): class Repository(BaseObject): - def __init__(self, repository_name): self.registry_id = DEFAULT_REGISTRY_ID - self.arn = 'arn:aws:ecr:us-east-1:{0}:repository/{1}'.format( - self.registry_id, repository_name) + self.arn = "arn:aws:ecr:us-east-1:{0}:repository/{1}".format( + self.registry_id, repository_name + ) self.name = repository_name # self.created = datetime.utcnow() - self.uri = '{0}.dkr.ecr.us-east-1.amazonaws.com/{1}'.format( - self.registry_id, repository_name) + self.uri = "{0}.dkr.ecr.us-east-1.amazonaws.com/{1}".format( + self.registry_id, repository_name + ) self.images = [] @property @@ -59,38 +59,45 @@ class Repository(BaseObject): def response_object(self): response_object = self.gen_response_object() - response_object['registryId'] = self.registry_id - response_object['repositoryArn'] = self.arn - response_object['repositoryName'] = self.name - response_object['repositoryUri'] = self.uri + response_object["registryId"] = self.registry_id + response_object["repositoryArn"] = self.arn + response_object["repositoryName"] = self.name + response_object["repositoryUri"] = self.uri # response_object['createdAt'] = self.created - del response_object['arn'], response_object['name'], response_object['images'] + del response_object["arn"], response_object["name"], response_object["images"] return response_object @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] ecr_backend = ecr_backends[region_name] return ecr_backend.create_repository( # RepositoryName is optional in CloudFormation, thus create a random # name if necessary repository_name=properties.get( - 'RepositoryName', 'ecrrepository{0}'.format(int(random() * 10 ** 6))), + "RepositoryName", "ecrrepository{0}".format(int(random() * 10 ** 6)) + ) ) @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - if original_resource.name != properties['RepositoryName']: + if original_resource.name != properties["RepositoryName"]: ecr_backend = ecr_backends[region_name] ecr_backend.delete_cluster(original_resource.arn) return ecr_backend.create_repository( # RepositoryName is optional in CloudFormation, thus create a # random name if necessary repository_name=properties.get( - 'RepositoryName', 'RepositoryName{0}'.format(int(random() * 10 ** 6))), + "RepositoryName", + "RepositoryName{0}".format(int(random() * 10 ** 6)), + ) ) else: # no-op when nothing changed between old and new resources @@ -98,8 +105,9 @@ class Repository(BaseObject): class Image(BaseObject): - - def __init__(self, tag, manifest, repository, digest=None, registry_id=DEFAULT_REGISTRY_ID): + def __init__( + self, tag, manifest, repository, digest=None, registry_id=DEFAULT_REGISTRY_ID + ): self.image_tag = tag self.image_tags = [tag] if tag is not None else [] self.image_manifest = manifest @@ -110,8 +118,10 @@ class Image(BaseObject): self.image_pushed_at = str(datetime.utcnow().isoformat()) def _create_digest(self): - image_contents = 'docker_image{0}'.format(int(random() * 10 ** 6)) - self.image_digest = "sha256:%s" % hashlib.sha256(image_contents.encode('utf-8')).hexdigest() + image_contents = "docker_image{0}".format(int(random() * 10 ** 6)) + self.image_digest = ( + "sha256:%s" % hashlib.sha256(image_contents.encode("utf-8")).hexdigest() + ) def get_image_digest(self): if not self.image_digest: @@ -135,54 +145,61 @@ class Image(BaseObject): @property def response_object(self): response_object = self.gen_response_object() - response_object['imageId'] = {} - response_object['imageId']['imageTag'] = self.image_tag - response_object['imageId']['imageDigest'] = self.get_image_digest() - response_object['imageManifest'] = self.image_manifest - response_object['repositoryName'] = self.repository - response_object['registryId'] = self.registry_id - return {k: v for k, v in response_object.items() if v is not None and v != [None]} + response_object["imageId"] = {} + response_object["imageId"]["imageTag"] = self.image_tag + response_object["imageId"]["imageDigest"] = self.get_image_digest() + response_object["imageManifest"] = self.image_manifest + response_object["repositoryName"] = self.repository + response_object["registryId"] = self.registry_id + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } @property def response_list_object(self): response_object = self.gen_response_object() - response_object['imageTag'] = self.image_tag - response_object['imageDigest'] = "i don't know" - return {k: v for k, v in response_object.items() if v is not None and v != [None]} + response_object["imageTag"] = self.image_tag + response_object["imageDigest"] = "i don't know" + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } @property def response_describe_object(self): response_object = self.gen_response_object() - response_object['imageTags'] = self.image_tags - response_object['imageDigest'] = self.get_image_digest() - response_object['imageManifest'] = self.image_manifest - response_object['repositoryName'] = self.repository - response_object['registryId'] = self.registry_id - response_object['imageSizeInBytes'] = self.image_size_in_bytes - response_object['imagePushedAt'] = self.image_pushed_at + response_object["imageTags"] = self.image_tags + response_object["imageDigest"] = self.get_image_digest() + response_object["imageManifest"] = self.image_manifest + response_object["repositoryName"] = self.repository + response_object["registryId"] = self.registry_id + response_object["imageSizeInBytes"] = self.image_size_in_bytes + response_object["imagePushedAt"] = self.image_pushed_at return {k: v for k, v in response_object.items() if v is not None and v != []} @property def response_batch_get_image(self): response_object = {} - response_object['imageId'] = {} - response_object['imageId']['imageTag'] = self.image_tag - response_object['imageId']['imageDigest'] = self.get_image_digest() - response_object['imageManifest'] = self.image_manifest - response_object['repositoryName'] = self.repository - response_object['registryId'] = self.registry_id - return {k: v for k, v in response_object.items() if v is not None and v != [None]} + response_object["imageId"] = {} + response_object["imageId"]["imageTag"] = self.image_tag + response_object["imageId"]["imageDigest"] = self.get_image_digest() + response_object["imageManifest"] = self.image_manifest + response_object["repositoryName"] = self.repository + response_object["registryId"] = self.registry_id + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } @property def response_batch_delete_image(self): response_object = {} - response_object['imageDigest'] = self.get_image_digest() - response_object['imageTag'] = self.image_tag - return {k: v for k, v in response_object.items() if v is not None and v != [None]} + response_object["imageDigest"] = self.get_image_digest() + response_object["imageTag"] = self.image_tag + return { + k: v for k, v in response_object.items() if v is not None and v != [None] + } class ECRBackend(BaseBackend): - def __init__(self): self.repositories = {} @@ -193,7 +210,9 @@ class ECRBackend(BaseBackend): if repository_names: for repository_name in repository_names: if repository_name not in self.repositories: - raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID) + raise RepositoryNotFoundException( + repository_name, registry_id or DEFAULT_REGISTRY_ID + ) repositories = [] for repository in self.repositories.values(): @@ -218,7 +237,9 @@ class ECRBackend(BaseBackend): if repository_name in self.repositories: return self.repositories.pop(repository_name) else: - raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID) + raise RepositoryNotFoundException( + repository_name, registry_id or DEFAULT_REGISTRY_ID + ) def list_images(self, repository_name, registry_id=None): """ @@ -235,7 +256,9 @@ class ECRBackend(BaseBackend): found = True if not found: - raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID) + raise RepositoryNotFoundException( + repository_name, registry_id or DEFAULT_REGISTRY_ID + ) images = [] for image in repository.images: @@ -247,26 +270,34 @@ class ECRBackend(BaseBackend): if repository_name in self.repositories: repository = self.repositories[repository_name] else: - raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID) + raise RepositoryNotFoundException( + repository_name, registry_id or DEFAULT_REGISTRY_ID + ) if image_ids: response = set() for image_id in image_ids: found = False for image in repository.images: - if (('imageDigest' in image_id and image.get_image_digest() == image_id['imageDigest']) or - ('imageTag' in image_id and image_id['imageTag'] in image.image_tags)): + if ( + "imageDigest" in image_id + and image.get_image_digest() == image_id["imageDigest"] + ) or ( + "imageTag" in image_id + and image_id["imageTag"] in image.image_tags + ): found = True response.add(image) if not found: image_id_representation = "{imageDigest:'%s', imageTag:'%s'}" % ( - image_id.get('imageDigest', 'null'), - image_id.get('imageTag', 'null'), + image_id.get("imageDigest", "null"), + image_id.get("imageTag", "null"), ) raise ImageNotFoundException( image_id=image_id_representation, repository_name=repository_name, - registry_id=registry_id or DEFAULT_REGISTRY_ID) + registry_id=registry_id or DEFAULT_REGISTRY_ID, + ) else: response = [] @@ -281,7 +312,12 @@ class ECRBackend(BaseBackend): else: raise Exception("{0} is not a repository".format(repository_name)) - existing_images = list(filter(lambda x: x.response_object['imageManifest'] == image_manifest, repository.images)) + existing_images = list( + filter( + lambda x: x.response_object["imageManifest"] == image_manifest, + repository.images, + ) + ) if not existing_images: # this image is not in ECR yet image = Image(image_tag, image_manifest, repository_name) @@ -292,36 +328,47 @@ class ECRBackend(BaseBackend): existing_images[0].update_tag(image_tag) return existing_images[0] - def batch_get_image(self, repository_name, registry_id=None, image_ids=None, accepted_media_types=None): + def batch_get_image( + self, + repository_name, + registry_id=None, + image_ids=None, + accepted_media_types=None, + ): if repository_name in self.repositories: repository = self.repositories[repository_name] else: - raise RepositoryNotFoundException(repository_name, registry_id or DEFAULT_REGISTRY_ID) + raise RepositoryNotFoundException( + repository_name, registry_id or DEFAULT_REGISTRY_ID + ) if not image_ids: - raise ParamValidationError(msg='Missing required parameter in input: "imageIds"') + raise ParamValidationError( + msg='Missing required parameter in input: "imageIds"' + ) - response = { - 'images': [], - 'failures': [], - } + response = {"images": [], "failures": []} for image_id in image_ids: found = False for image in repository.images: - if (('imageDigest' in image_id and image.get_image_digest() == image_id['imageDigest']) or - ('imageTag' in image_id and image.image_tag == image_id['imageTag'])): + if ( + "imageDigest" in image_id + and image.get_image_digest() == image_id["imageDigest"] + ) or ( + "imageTag" in image_id and image.image_tag == image_id["imageTag"] + ): found = True - response['images'].append(image.response_batch_get_image) + response["images"].append(image.response_batch_get_image) if not found: - response['failures'].append({ - 'imageId': { - 'imageTag': image_id.get('imageTag', 'null') - }, - 'failureCode': 'ImageNotFound', - 'failureReason': 'Requested image not found' - }) + response["failures"].append( + { + "imageId": {"imageTag": image_id.get("imageTag", "null")}, + "failureCode": "ImageNotFound", + "failureReason": "Requested image not found", + } + ) return response @@ -338,10 +385,7 @@ class ECRBackend(BaseBackend): msg='Missing required parameter in input: "imageIds"' ) - response = { - "imageIds": [], - "failures": [] - } + response = {"imageIds": [], "failures": []} for image_id in image_ids: image_found = False @@ -377,8 +421,8 @@ class ECRBackend(BaseBackend): # Search by matching both digest and tag if "imageDigest" in image_id and "imageTag" in image_id: if ( - image_id["imageDigest"] == image.get_image_digest() and - image_id["imageTag"] in image.image_tags + image_id["imageDigest"] == image.get_image_digest() + and image_id["imageTag"] in image.image_tags ): image_found = True for image_tag in reversed(image.image_tags): @@ -390,7 +434,10 @@ class ECRBackend(BaseBackend): del repository.images[num] # Search by matching digest - elif "imageDigest" in image_id and image.get_image_digest() == image_id["imageDigest"]: + elif ( + "imageDigest" in image_id + and image.get_image_digest() == image_id["imageDigest"] + ): image_found = True for image_tag in reversed(image.image_tags): repository.images[num].image_tag = image_tag @@ -399,7 +446,9 @@ class ECRBackend(BaseBackend): del repository.images[num] # Search by matching tag - elif "imageTag" in image_id and image_id["imageTag"] in image.image_tags: + elif ( + "imageTag" in image_id and image_id["imageTag"] in image.image_tags + ): image_found = True repository.images[num].image_tag = image_id["imageTag"] response["imageIds"].append(image.response_batch_delete_image) @@ -416,10 +465,14 @@ class ECRBackend(BaseBackend): } if "imageDigest" in image_id: - failure_response["imageId"]["imageDigest"] = image_id.get("imageDigest", "null") + failure_response["imageId"]["imageDigest"] = image_id.get( + "imageDigest", "null" + ) if "imageTag" in image_id: - failure_response["imageId"]["imageTag"] = image_id.get("imageTag", "null") + failure_response["imageId"]["imageTag"] = image_id.get( + "imageTag", "null" + ) response["failures"].append(failure_response) diff --git a/moto/ecr/responses.py b/moto/ecr/responses.py index f758176ad..37078b878 100644 --- a/moto/ecr/responses.py +++ b/moto/ecr/responses.py @@ -24,148 +24,154 @@ class ECRResponse(BaseResponse): return self.request_params.get(param, None) def create_repository(self): - repository_name = self._get_param('repositoryName') + repository_name = self._get_param("repositoryName") if repository_name is None: - repository_name = 'default' + repository_name = "default" repository = self.ecr_backend.create_repository(repository_name) - return json.dumps({ - 'repository': repository.response_object - }) + return json.dumps({"repository": repository.response_object}) def describe_repositories(self): - describe_repositories_name = self._get_param('repositoryNames') - registry_id = self._get_param('registryId') + describe_repositories_name = self._get_param("repositoryNames") + registry_id = self._get_param("registryId") repositories = self.ecr_backend.describe_repositories( - repository_names=describe_repositories_name, registry_id=registry_id) - return json.dumps({ - 'repositories': repositories, - 'failures': [] - }) + repository_names=describe_repositories_name, registry_id=registry_id + ) + return json.dumps({"repositories": repositories, "failures": []}) def delete_repository(self): - repository_str = self._get_param('repositoryName') - registry_id = self._get_param('registryId') + repository_str = self._get_param("repositoryName") + registry_id = self._get_param("registryId") repository = self.ecr_backend.delete_repository(repository_str, registry_id) - return json.dumps({ - 'repository': repository.response_object - }) + return json.dumps({"repository": repository.response_object}) def put_image(self): - repository_str = self._get_param('repositoryName') - image_manifest = self._get_param('imageManifest') - image_tag = self._get_param('imageTag') + repository_str = self._get_param("repositoryName") + image_manifest = self._get_param("imageManifest") + image_tag = self._get_param("imageTag") image = self.ecr_backend.put_image(repository_str, image_manifest, image_tag) - return json.dumps({ - 'image': image.response_object - }) + return json.dumps({"image": image.response_object}) def list_images(self): - repository_str = self._get_param('repositoryName') - registry_id = self._get_param('registryId') + repository_str = self._get_param("repositoryName") + registry_id = self._get_param("registryId") images = self.ecr_backend.list_images(repository_str, registry_id) - return json.dumps({ - 'imageIds': [image.response_list_object for image in images], - }) + return json.dumps( + {"imageIds": [image.response_list_object for image in images]} + ) def describe_images(self): - repository_str = self._get_param('repositoryName') - registry_id = self._get_param('registryId') - image_ids = self._get_param('imageIds') - images = self.ecr_backend.describe_images(repository_str, registry_id, image_ids) - return json.dumps({ - 'imageDetails': [image.response_describe_object for image in images], - }) + repository_str = self._get_param("repositoryName") + registry_id = self._get_param("registryId") + image_ids = self._get_param("imageIds") + images = self.ecr_backend.describe_images( + repository_str, registry_id, image_ids + ) + return json.dumps( + {"imageDetails": [image.response_describe_object for image in images]} + ) def batch_check_layer_availability(self): - if self.is_not_dryrun('BatchCheckLayerAvailability'): + if self.is_not_dryrun("BatchCheckLayerAvailability"): raise NotImplementedError( - 'ECR.batch_check_layer_availability is not yet implemented') + "ECR.batch_check_layer_availability is not yet implemented" + ) def batch_delete_image(self): - repository_str = self._get_param('repositoryName') - registry_id = self._get_param('registryId') - image_ids = self._get_param('imageIds') + repository_str = self._get_param("repositoryName") + registry_id = self._get_param("registryId") + image_ids = self._get_param("imageIds") - response = self.ecr_backend.batch_delete_image(repository_str, registry_id, image_ids) + response = self.ecr_backend.batch_delete_image( + repository_str, registry_id, image_ids + ) return json.dumps(response) def batch_get_image(self): - repository_str = self._get_param('repositoryName') - registry_id = self._get_param('registryId') - image_ids = self._get_param('imageIds') - accepted_media_types = self._get_param('acceptedMediaTypes') + repository_str = self._get_param("repositoryName") + registry_id = self._get_param("registryId") + image_ids = self._get_param("imageIds") + accepted_media_types = self._get_param("acceptedMediaTypes") - response = self.ecr_backend.batch_get_image(repository_str, registry_id, image_ids, accepted_media_types) + response = self.ecr_backend.batch_get_image( + repository_str, registry_id, image_ids, accepted_media_types + ) return json.dumps(response) def can_paginate(self): - if self.is_not_dryrun('CanPaginate'): - raise NotImplementedError( - 'ECR.can_paginate is not yet implemented') + if self.is_not_dryrun("CanPaginate"): + raise NotImplementedError("ECR.can_paginate is not yet implemented") def complete_layer_upload(self): - if self.is_not_dryrun('CompleteLayerUpload'): + if self.is_not_dryrun("CompleteLayerUpload"): raise NotImplementedError( - 'ECR.complete_layer_upload is not yet implemented') + "ECR.complete_layer_upload is not yet implemented" + ) def delete_repository_policy(self): - if self.is_not_dryrun('DeleteRepositoryPolicy'): + if self.is_not_dryrun("DeleteRepositoryPolicy"): raise NotImplementedError( - 'ECR.delete_repository_policy is not yet implemented') + "ECR.delete_repository_policy is not yet implemented" + ) def generate_presigned_url(self): - if self.is_not_dryrun('GeneratePresignedUrl'): + if self.is_not_dryrun("GeneratePresignedUrl"): raise NotImplementedError( - 'ECR.generate_presigned_url is not yet implemented') + "ECR.generate_presigned_url is not yet implemented" + ) def get_authorization_token(self): - registry_ids = self._get_param('registryIds') + registry_ids = self._get_param("registryIds") if not registry_ids: registry_ids = [DEFAULT_REGISTRY_ID] auth_data = [] for registry_id in registry_ids: - password = '{}-auth-token'.format(registry_id) - auth_token = b64encode("AWS:{}".format(password).encode('ascii')).decode() - auth_data.append({ - 'authorizationToken': auth_token, - 'expiresAt': time.mktime(datetime(2015, 1, 1).timetuple()), - 'proxyEndpoint': 'https://{}.dkr.ecr.{}.amazonaws.com'.format(registry_id, self.region) - }) - return json.dumps({'authorizationData': auth_data}) + password = "{}-auth-token".format(registry_id) + auth_token = b64encode("AWS:{}".format(password).encode("ascii")).decode() + auth_data.append( + { + "authorizationToken": auth_token, + "expiresAt": time.mktime(datetime(2015, 1, 1).timetuple()), + "proxyEndpoint": "https://{}.dkr.ecr.{}.amazonaws.com".format( + registry_id, self.region + ), + } + ) + return json.dumps({"authorizationData": auth_data}) def get_download_url_for_layer(self): - if self.is_not_dryrun('GetDownloadUrlForLayer'): + if self.is_not_dryrun("GetDownloadUrlForLayer"): raise NotImplementedError( - 'ECR.get_download_url_for_layer is not yet implemented') + "ECR.get_download_url_for_layer is not yet implemented" + ) def get_paginator(self): - if self.is_not_dryrun('GetPaginator'): - raise NotImplementedError( - 'ECR.get_paginator is not yet implemented') + if self.is_not_dryrun("GetPaginator"): + raise NotImplementedError("ECR.get_paginator is not yet implemented") def get_repository_policy(self): - if self.is_not_dryrun('GetRepositoryPolicy'): + if self.is_not_dryrun("GetRepositoryPolicy"): raise NotImplementedError( - 'ECR.get_repository_policy is not yet implemented') + "ECR.get_repository_policy is not yet implemented" + ) def get_waiter(self): - if self.is_not_dryrun('GetWaiter'): - raise NotImplementedError( - 'ECR.get_waiter is not yet implemented') + if self.is_not_dryrun("GetWaiter"): + raise NotImplementedError("ECR.get_waiter is not yet implemented") def initiate_layer_upload(self): - if self.is_not_dryrun('InitiateLayerUpload'): + if self.is_not_dryrun("InitiateLayerUpload"): raise NotImplementedError( - 'ECR.initiate_layer_upload is not yet implemented') + "ECR.initiate_layer_upload is not yet implemented" + ) def set_repository_policy(self): - if self.is_not_dryrun('SetRepositoryPolicy'): + if self.is_not_dryrun("SetRepositoryPolicy"): raise NotImplementedError( - 'ECR.set_repository_policy is not yet implemented') + "ECR.set_repository_policy is not yet implemented" + ) def upload_layer_part(self): - if self.is_not_dryrun('UploadLayerPart'): - raise NotImplementedError( - 'ECR.upload_layer_part is not yet implemented') + if self.is_not_dryrun("UploadLayerPart"): + raise NotImplementedError("ECR.upload_layer_part is not yet implemented") diff --git a/moto/ecr/urls.py b/moto/ecr/urls.py index 5b12cd843..a25874e43 100644 --- a/moto/ecr/urls.py +++ b/moto/ecr/urls.py @@ -1,11 +1,6 @@ from __future__ import unicode_literals from .responses import ECRResponse -url_bases = [ - "https?://ecr.(.+).amazonaws.com", - "https?://api.ecr.(.+).amazonaws.com", -] +url_bases = ["https?://ecr.(.+).amazonaws.com", "https?://api.ecr.(.+).amazonaws.com"] -url_paths = { - '{0}/$': ECRResponse.dispatch, -} +url_paths = {"{0}/$": ECRResponse.dispatch} diff --git a/moto/ecs/__init__.py b/moto/ecs/__init__.py index 8fb3dd41e..3048838be 100644 --- a/moto/ecs/__init__.py +++ b/moto/ecs/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import ecs_backends from ..core.models import base_decorator, deprecated_base_decorator -ecs_backend = ecs_backends['us-east-1'] +ecs_backend = ecs_backends["us-east-1"] mock_ecs = base_decorator(ecs_backends) mock_ecs_deprecated = deprecated_base_decorator(ecs_backends) diff --git a/moto/ecs/exceptions.py b/moto/ecs/exceptions.py index 6e329f227..d08066192 100644 --- a/moto/ecs/exceptions.py +++ b/moto/ecs/exceptions.py @@ -9,7 +9,7 @@ class ServiceNotFoundException(RESTError): super(ServiceNotFoundException, self).__init__( error_type="ServiceNotFoundException", message="The service {0} does not exist".format(service_name), - template='error_json', + template="error_json", ) diff --git a/moto/ecs/models.py b/moto/ecs/models.py index 5aa9ae2cb..87578202c 100644 --- a/moto/ecs/models.py +++ b/moto/ecs/models.py @@ -12,27 +12,23 @@ from moto.core.utils import unix_time from moto.ec2 import ec2_backends from copy import copy -from .exceptions import ( - ServiceNotFoundException, - TaskDefinitionNotFoundException -) +from .exceptions import ServiceNotFoundException, TaskDefinitionNotFoundException class BaseObject(BaseModel): - def camelCase(self, key): words = [] - for i, word in enumerate(key.split('_')): + for i, word in enumerate(key.split("_")): if i > 0: words.append(word.title()) else: words.append(word) - return ''.join(words) + return "".join(words) def gen_response_object(self): response_object = copy(self.__dict__) for key, value in self.__dict__.items(): - if '_' in key: + if "_" in key: response_object[self.camelCase(key)] = value del response_object[key] return response_object @@ -43,17 +39,16 @@ class BaseObject(BaseModel): class Cluster(BaseObject): - def __init__(self, cluster_name, region_name): self.active_services_count = 0 - self.arn = 'arn:aws:ecs:{0}:012345678910:cluster/{1}'.format( - region_name, - cluster_name) + self.arn = "arn:aws:ecs:{0}:012345678910:cluster/{1}".format( + region_name, cluster_name + ) self.name = cluster_name self.pending_tasks_count = 0 self.registered_container_instances_count = 0 self.running_tasks_count = 0 - self.status = 'ACTIVE' + self.status = "ACTIVE" self.region_name = region_name @property @@ -63,16 +58,18 @@ class Cluster(BaseObject): @property def response_object(self): response_object = self.gen_response_object() - response_object['clusterArn'] = self.arn - response_object['clusterName'] = self.name - del response_object['arn'], response_object['name'] + response_object["clusterArn"] = self.arn + response_object["clusterName"] = self.name + del response_object["arn"], response_object["name"] return response_object @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): # if properties is not provided, cloudformation will use the default values for all properties - if 'Properties' in cloudformation_json: - properties = cloudformation_json['Properties'] + if "Properties" in cloudformation_json: + properties = cloudformation_json["Properties"] else: properties = {} @@ -81,21 +78,25 @@ class Cluster(BaseObject): # ClusterName is optional in CloudFormation, thus create a random # name if necessary cluster_name=properties.get( - 'ClusterName', 'ecscluster{0}'.format(int(random() * 10 ** 6))), + "ClusterName", "ecscluster{0}".format(int(random() * 10 ** 6)) + ) ) @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - if original_resource.name != properties['ClusterName']: + if original_resource.name != properties["ClusterName"]: ecs_backend = ecs_backends[region_name] ecs_backend.delete_cluster(original_resource.arn) return ecs_backend.create_cluster( # ClusterName is optional in CloudFormation, thus create a # random name if necessary cluster_name=properties.get( - 'ClusterName', 'ecscluster{0}'.format(int(random() * 10 ** 6))), + "ClusterName", "ecscluster{0}".format(int(random() * 10 ** 6)) + ) ) else: # no-op when nothing changed between old and new resources @@ -103,18 +104,27 @@ class Cluster(BaseObject): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": return self.arn raise UnformattedGetAttTemplateException() class TaskDefinition(BaseObject): - - def __init__(self, family, revision, container_definitions, region_name, volumes=None, tags=None): + def __init__( + self, + family, + revision, + container_definitions, + region_name, + volumes=None, + tags=None, + ): self.family = family self.revision = revision - self.arn = 'arn:aws:ecs:{0}:012345678910:task-definition/{1}:{2}'.format( - region_name, family, revision) + self.arn = "arn:aws:ecs:{0}:012345678910:task-definition/{1}:{2}".format( + region_name, family, revision + ) self.container_definitions = container_definitions self.tags = tags if tags is not None else [] if volumes is None: @@ -125,9 +135,9 @@ class TaskDefinition(BaseObject): @property def response_object(self): response_object = self.gen_response_object() - response_object['taskDefinitionArn'] = response_object['arn'] - del response_object['arn'] - del response_object['tags'] + response_object["taskDefinitionArn"] = response_object["arn"] + del response_object["arn"] + del response_object["tags"] return response_object @property @@ -135,56 +145,74 @@ class TaskDefinition(BaseObject): return self.arn @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] family = properties.get( - 'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6))) - container_definitions = properties['ContainerDefinitions'] - volumes = properties.get('Volumes') + "Family", "task-definition-{0}".format(int(random() * 10 ** 6)) + ) + container_definitions = properties["ContainerDefinitions"] + volumes = properties.get("Volumes") ecs_backend = ecs_backends[region_name] return ecs_backend.register_task_definition( - family=family, container_definitions=container_definitions, volumes=volumes) + family=family, container_definitions=container_definitions, volumes=volumes + ) @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] family = properties.get( - 'Family', 'task-definition-{0}'.format(int(random() * 10 ** 6))) - container_definitions = properties['ContainerDefinitions'] - volumes = properties.get('Volumes') - if (original_resource.family != family or - original_resource.container_definitions != container_definitions or - original_resource.volumes != volumes): - # currently TaskRoleArn isn't stored at TaskDefinition - # instances + "Family", "task-definition-{0}".format(int(random() * 10 ** 6)) + ) + container_definitions = properties["ContainerDefinitions"] + volumes = properties.get("Volumes") + if ( + original_resource.family != family + or original_resource.container_definitions != container_definitions + or original_resource.volumes != volumes + ): + # currently TaskRoleArn isn't stored at TaskDefinition + # instances ecs_backend = ecs_backends[region_name] ecs_backend.deregister_task_definition(original_resource.arn) return ecs_backend.register_task_definition( - family=family, container_definitions=container_definitions, volumes=volumes) + family=family, + container_definitions=container_definitions, + volumes=volumes, + ) else: # no-op when nothing changed between old and new resources return original_resource class Task(BaseObject): - - def __init__(self, cluster, task_definition, container_instance_arn, - resource_requirements, overrides={}, started_by=''): + def __init__( + self, + cluster, + task_definition, + container_instance_arn, + resource_requirements, + overrides={}, + started_by="", + ): self.cluster_arn = cluster.arn - self.task_arn = 'arn:aws:ecs:{0}:012345678910:task/{1}'.format( - cluster.region_name, - str(uuid.uuid4())) + self.task_arn = "arn:aws:ecs:{0}:012345678910:task/{1}".format( + cluster.region_name, str(uuid.uuid4()) + ) self.container_instance_arn = container_instance_arn - self.last_status = 'RUNNING' - self.desired_status = 'RUNNING' + self.last_status = "RUNNING" + self.desired_status = "RUNNING" self.task_definition_arn = task_definition.arn self.overrides = overrides self.containers = [] self.started_by = started_by - self.stopped_reason = '' + self.stopped_reason = "" self.resource_requirements = resource_requirements @property @@ -194,32 +222,42 @@ class Task(BaseObject): class Service(BaseObject): - - def __init__(self, cluster, service_name, task_definition, desired_count, load_balancers=None, scheduling_strategy=None, tags=None): + def __init__( + self, + cluster, + service_name, + task_definition, + desired_count, + load_balancers=None, + scheduling_strategy=None, + tags=None, + ): self.cluster_arn = cluster.arn - self.arn = 'arn:aws:ecs:{0}:012345678910:service/{1}'.format( - cluster.region_name, - service_name) + self.arn = "arn:aws:ecs:{0}:012345678910:service/{1}".format( + cluster.region_name, service_name + ) self.name = service_name - self.status = 'ACTIVE' + self.status = "ACTIVE" self.running_count = 0 self.task_definition = task_definition.arn self.desired_count = desired_count self.events = [] self.deployments = [ { - 'createdAt': datetime.now(pytz.utc), - 'desiredCount': self.desired_count, - 'id': 'ecs-svc/{}'.format(randint(0, 32**12)), - 'pendingCount': self.desired_count, - 'runningCount': 0, - 'status': 'PRIMARY', - 'taskDefinition': task_definition.arn, - 'updatedAt': datetime.now(pytz.utc), + "createdAt": datetime.now(pytz.utc), + "desiredCount": self.desired_count, + "id": "ecs-svc/{}".format(randint(0, 32 ** 12)), + "pendingCount": self.desired_count, + "runningCount": 0, + "status": "PRIMARY", + "taskDefinition": task_definition.arn, + "updatedAt": datetime.now(pytz.utc), } ] self.load_balancers = load_balancers if load_balancers is not None else [] - self.scheduling_strategy = scheduling_strategy if scheduling_strategy is not None else 'REPLICA' + self.scheduling_strategy = ( + scheduling_strategy if scheduling_strategy is not None else "REPLICA" + ) self.tags = tags if tags is not None else [] self.pending_count = 0 @@ -230,51 +268,60 @@ class Service(BaseObject): @property def response_object(self): response_object = self.gen_response_object() - del response_object['name'], response_object['arn'], response_object['tags'] - response_object['serviceName'] = self.name - response_object['serviceArn'] = self.arn - response_object['schedulingStrategy'] = self.scheduling_strategy + del response_object["name"], response_object["arn"], response_object["tags"] + response_object["serviceName"] = self.name + response_object["serviceArn"] = self.arn + response_object["schedulingStrategy"] = self.scheduling_strategy - for deployment in response_object['deployments']: - if isinstance(deployment['createdAt'], datetime): - deployment['createdAt'] = unix_time(deployment['createdAt'].replace(tzinfo=None)) - if isinstance(deployment['updatedAt'], datetime): - deployment['updatedAt'] = unix_time(deployment['updatedAt'].replace(tzinfo=None)) + for deployment in response_object["deployments"]: + if isinstance(deployment["createdAt"], datetime): + deployment["createdAt"] = unix_time( + deployment["createdAt"].replace(tzinfo=None) + ) + if isinstance(deployment["updatedAt"], datetime): + deployment["updatedAt"] = unix_time( + deployment["updatedAt"].replace(tzinfo=None) + ) return response_object @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - if isinstance(properties['Cluster'], Cluster): - cluster = properties['Cluster'].name + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + if isinstance(properties["Cluster"], Cluster): + cluster = properties["Cluster"].name else: - cluster = properties['Cluster'] - if isinstance(properties['TaskDefinition'], TaskDefinition): - task_definition = properties['TaskDefinition'].family + cluster = properties["Cluster"] + if isinstance(properties["TaskDefinition"], TaskDefinition): + task_definition = properties["TaskDefinition"].family else: - task_definition = properties['TaskDefinition'] - service_name = '{0}Service{1}'.format(cluster, int(random() * 10 ** 6)) - desired_count = properties['DesiredCount'] + task_definition = properties["TaskDefinition"] + service_name = "{0}Service{1}".format(cluster, int(random() * 10 ** 6)) + desired_count = properties["DesiredCount"] # TODO: LoadBalancers # TODO: Role ecs_backend = ecs_backends[region_name] return ecs_backend.create_service( - cluster, service_name, task_definition, desired_count) + cluster, service_name, task_definition, desired_count + ) @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - if isinstance(properties['Cluster'], Cluster): - cluster_name = properties['Cluster'].name + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + if isinstance(properties["Cluster"], Cluster): + cluster_name = properties["Cluster"].name else: - cluster_name = properties['Cluster'] - if isinstance(properties['TaskDefinition'], TaskDefinition): - task_definition = properties['TaskDefinition'].family + cluster_name = properties["Cluster"] + if isinstance(properties["TaskDefinition"], TaskDefinition): + task_definition = properties["TaskDefinition"].family else: - task_definition = properties['TaskDefinition'] - desired_count = properties['DesiredCount'] + task_definition = properties["TaskDefinition"] + desired_count = properties["DesiredCount"] ecs_backend = ecs_backends[region_name] service_name = original_resource.name @@ -282,104 +329,128 @@ class Service(BaseObject): # TODO: LoadBalancers # TODO: Role ecs_backend.delete_service(cluster_name, service_name) - new_service_name = '{0}Service{1}'.format( - cluster_name, int(random() * 10 ** 6)) + new_service_name = "{0}Service{1}".format( + cluster_name, int(random() * 10 ** 6) + ) return ecs_backend.create_service( - cluster_name, new_service_name, task_definition, desired_count) + cluster_name, new_service_name, task_definition, desired_count + ) else: - return ecs_backend.update_service(cluster_name, service_name, task_definition, desired_count) + return ecs_backend.update_service( + cluster_name, service_name, task_definition, desired_count + ) def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Name': + + if attribute_name == "Name": return self.name raise UnformattedGetAttTemplateException() class ContainerInstance(BaseObject): - def __init__(self, ec2_instance_id, region_name): self.ec2_instance_id = ec2_instance_id self.agent_connected = True - self.status = 'ACTIVE' + self.status = "ACTIVE" self.registered_resources = [ - {'doubleValue': 0.0, - 'integerValue': 4096, - 'longValue': 0, - 'name': 'CPU', - 'type': 'INTEGER'}, - {'doubleValue': 0.0, - 'integerValue': 7482, - 'longValue': 0, - 'name': 'MEMORY', - 'type': 'INTEGER'}, - {'doubleValue': 0.0, - 'integerValue': 0, - 'longValue': 0, - 'name': 'PORTS', - 'stringSetValue': ['22', '2376', '2375', '51678', '51679'], - 'type': 'STRINGSET'}, - {'doubleValue': 0.0, - 'integerValue': 0, - 'longValue': 0, - 'name': 'PORTS_UDP', - 'stringSetValue': [], - 'type': 'STRINGSET'}] + { + "doubleValue": 0.0, + "integerValue": 4096, + "longValue": 0, + "name": "CPU", + "type": "INTEGER", + }, + { + "doubleValue": 0.0, + "integerValue": 7482, + "longValue": 0, + "name": "MEMORY", + "type": "INTEGER", + }, + { + "doubleValue": 0.0, + "integerValue": 0, + "longValue": 0, + "name": "PORTS", + "stringSetValue": ["22", "2376", "2375", "51678", "51679"], + "type": "STRINGSET", + }, + { + "doubleValue": 0.0, + "integerValue": 0, + "longValue": 0, + "name": "PORTS_UDP", + "stringSetValue": [], + "type": "STRINGSET", + }, + ] self.container_instance_arn = "arn:aws:ecs:{0}:012345678910:container-instance/{1}".format( - region_name, - str(uuid.uuid4())) + region_name, str(uuid.uuid4()) + ) self.pending_tasks_count = 0 self.remaining_resources = [ - {'doubleValue': 0.0, - 'integerValue': 4096, - 'longValue': 0, - 'name': 'CPU', - 'type': 'INTEGER'}, - {'doubleValue': 0.0, - 'integerValue': 7482, - 'longValue': 0, - 'name': 'MEMORY', - 'type': 'INTEGER'}, - {'doubleValue': 0.0, - 'integerValue': 0, - 'longValue': 0, - 'name': 'PORTS', - 'stringSetValue': ['22', '2376', '2375', '51678', '51679'], - 'type': 'STRINGSET'}, - {'doubleValue': 0.0, - 'integerValue': 0, - 'longValue': 0, - 'name': 'PORTS_UDP', - 'stringSetValue': [], - 'type': 'STRINGSET'} + { + "doubleValue": 0.0, + "integerValue": 4096, + "longValue": 0, + "name": "CPU", + "type": "INTEGER", + }, + { + "doubleValue": 0.0, + "integerValue": 7482, + "longValue": 0, + "name": "MEMORY", + "type": "INTEGER", + }, + { + "doubleValue": 0.0, + "integerValue": 0, + "longValue": 0, + "name": "PORTS", + "stringSetValue": ["22", "2376", "2375", "51678", "51679"], + "type": "STRINGSET", + }, + { + "doubleValue": 0.0, + "integerValue": 0, + "longValue": 0, + "name": "PORTS_UDP", + "stringSetValue": [], + "type": "STRINGSET", + }, ] self.running_tasks_count = 0 self.version_info = { - 'agentVersion': "1.0.0", - 'agentHash': '4023248', - 'dockerVersion': 'DockerVersion: 1.5.0' + "agentVersion": "1.0.0", + "agentHash": "4023248", + "dockerVersion": "DockerVersion: 1.5.0", } ec2_backend = ec2_backends[region_name] ec2_instance = ec2_backend.get_instance(ec2_instance_id) self.attributes = { - 'ecs.ami-id': ec2_instance.image_id, - 'ecs.availability-zone': ec2_instance.placement, - 'ecs.instance-type': ec2_instance.instance_type, - 'ecs.os-type': ec2_instance.platform if ec2_instance.platform == 'windows' else 'linux' # options are windows and linux, linux is default + "ecs.ami-id": ec2_instance.image_id, + "ecs.availability-zone": ec2_instance.placement, + "ecs.instance-type": ec2_instance.instance_type, + "ecs.os-type": ec2_instance.platform + if ec2_instance.platform == "windows" + else "linux", # options are windows and linux, linux is default } @property def response_object(self): response_object = self.gen_response_object() - response_object['attributes'] = [self._format_attribute(name, value) for name, value in response_object['attributes'].items()] + response_object["attributes"] = [ + self._format_attribute(name, value) + for name, value in response_object["attributes"].items() + ] return response_object def _format_attribute(self, name, value): - formatted_attr = { - 'name': name, - } + formatted_attr = {"name": name} if value is not None: - formatted_attr['value'] = value + formatted_attr["value"] = value return formatted_attr @@ -387,35 +458,33 @@ class ClusterFailure(BaseObject): def __init__(self, reason, cluster_name, region_name): self.reason = reason self.arn = "arn:aws:ecs:{0}:012345678910:cluster/{1}".format( - region_name, - cluster_name) + region_name, cluster_name + ) @property def response_object(self): response_object = self.gen_response_object() - response_object['reason'] = self.reason - response_object['arn'] = self.arn + response_object["reason"] = self.reason + response_object["arn"] = self.arn return response_object class ContainerInstanceFailure(BaseObject): - def __init__(self, reason, container_instance_id, region_name): self.reason = reason self.arn = "arn:aws:ecs:{0}:012345678910:container-instance/{1}".format( - region_name, - container_instance_id) + region_name, container_instance_id + ) @property def response_object(self): response_object = self.gen_response_object() - response_object['reason'] = self.reason - response_object['arn'] = self.arn + response_object["reason"] = self.reason + response_object["arn"] = self.arn return response_object class EC2ContainerServiceBackend(BaseBackend): - def __init__(self, region_name): super(EC2ContainerServiceBackend, self).__init__() self.clusters = {} @@ -431,19 +500,21 @@ class EC2ContainerServiceBackend(BaseBackend): self.__init__(region_name) def describe_task_definition(self, task_definition_str): - task_definition_name = task_definition_str.split('/')[-1] - if ':' in task_definition_name: - family, revision = task_definition_name.split(':') + task_definition_name = task_definition_str.split("/")[-1] + if ":" in task_definition_name: + family, revision = task_definition_name.split(":") revision = int(revision) else: family = task_definition_name revision = self._get_last_task_definition_revision_id(family) - if family in self.task_definitions and revision in self.task_definitions[family]: + if ( + family in self.task_definitions + and revision in self.task_definitions[family] + ): return self.task_definitions[family][revision] else: - raise Exception( - "{0} is not a task_definition".format(task_definition_name)) + raise Exception("{0} is not a task_definition".format(task_definition_name)) def create_cluster(self, cluster_name): cluster = Cluster(cluster_name, self.region_name) @@ -460,26 +531,29 @@ class EC2ContainerServiceBackend(BaseBackend): list_clusters = [] failures = [] if list_clusters_name is None: - if 'default' in self.clusters: - list_clusters.append(self.clusters['default'].response_object) + if "default" in self.clusters: + list_clusters.append(self.clusters["default"].response_object) else: for cluster in list_clusters_name: - cluster_name = cluster.split('/')[-1] + cluster_name = cluster.split("/")[-1] if cluster_name in self.clusters: - list_clusters.append( - self.clusters[cluster_name].response_object) + list_clusters.append(self.clusters[cluster_name].response_object) else: - failures.append(ClusterFailure('MISSING', cluster_name, self.region_name)) + failures.append( + ClusterFailure("MISSING", cluster_name, self.region_name) + ) return list_clusters, failures def delete_cluster(self, cluster_str): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name in self.clusters: return self.clusters.pop(cluster_name) else: raise Exception("{0} is not a cluster".format(cluster_name)) - def register_task_definition(self, family, container_definitions, volumes, tags=None): + def register_task_definition( + self, family, container_definitions, volumes, tags=None + ): if family in self.task_definitions: last_id = self._get_last_task_definition_revision_id(family) revision = (last_id or 0) + 1 @@ -487,7 +561,8 @@ class EC2ContainerServiceBackend(BaseBackend): self.task_definitions[family] = {} revision = 1 task_definition = TaskDefinition( - family, revision, container_definitions, self.region_name, volumes, tags) + family, revision, container_definitions, self.region_name, volumes, tags + ) self.task_definitions[family][revision] = task_definition return task_definition @@ -498,24 +573,28 @@ class EC2ContainerServiceBackend(BaseBackend): """ task_arns = [] for task_definition_list in self.task_definitions.values(): - task_arns.extend([ - task_definition.arn - for task_definition in task_definition_list.values() - ]) + task_arns.extend( + [ + task_definition.arn + for task_definition in task_definition_list.values() + ] + ) return task_arns def deregister_task_definition(self, task_definition_str): - task_definition_name = task_definition_str.split('/')[-1] - family, revision = task_definition_name.split(':') + task_definition_name = task_definition_str.split("/")[-1] + family, revision = task_definition_name.split(":") revision = int(revision) - if family in self.task_definitions and revision in self.task_definitions[family]: + if ( + family in self.task_definitions + and revision in self.task_definitions[family] + ): return self.task_definitions[family].pop(revision) else: - raise Exception( - "{0} is not a task_definition".format(task_definition_name)) + raise Exception("{0} is not a task_definition".format(task_definition_name)) def run_task(self, cluster_str, task_definition_str, count, overrides, started_by): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name in self.clusters: cluster = self.clusters[cluster_name] else: @@ -525,24 +604,42 @@ class EC2ContainerServiceBackend(BaseBackend): self.tasks[cluster_name] = {} tasks = [] container_instances = list( - self.container_instances.get(cluster_name, {}).keys()) + self.container_instances.get(cluster_name, {}).keys() + ) if not container_instances: raise Exception("No instances found in cluster {}".format(cluster_name)) - active_container_instances = [x for x in container_instances if - self.container_instances[cluster_name][x].status == 'ACTIVE'] - resource_requirements = self._calculate_task_resource_requirements(task_definition) + active_container_instances = [ + x + for x in container_instances + if self.container_instances[cluster_name][x].status == "ACTIVE" + ] + resource_requirements = self._calculate_task_resource_requirements( + task_definition + ) # TODO: return event about unable to place task if not able to place enough tasks to meet count placed_count = 0 for container_instance in active_container_instances: - container_instance = self.container_instances[cluster_name][container_instance] + container_instance = self.container_instances[cluster_name][ + container_instance + ] container_instance_arn = container_instance.container_instance_arn try_to_place = True while try_to_place: - can_be_placed, message = self._can_be_placed(container_instance, resource_requirements) + can_be_placed, message = self._can_be_placed( + container_instance, resource_requirements + ) if can_be_placed: - task = Task(cluster, task_definition, container_instance_arn, - resource_requirements, overrides or {}, started_by or '') - self.update_container_instance_resources(container_instance, resource_requirements) + task = Task( + cluster, + task_definition, + container_instance_arn, + resource_requirements, + overrides or {}, + started_by or "", + ) + self.update_container_instance_resources( + container_instance, resource_requirements + ) tasks.append(task) self.tasks[cluster_name][task.task_arn] = task placed_count += 1 @@ -559,23 +656,33 @@ class EC2ContainerServiceBackend(BaseBackend): # cloudformation uses capitalized properties, while boto uses all lower case # CPU is optional - resource_requirements["CPU"] += container_definition.get('cpu', - container_definition.get('Cpu', 0)) + resource_requirements["CPU"] += container_definition.get( + "cpu", container_definition.get("Cpu", 0) + ) # either memory or memory reservation must be provided - if 'Memory' in container_definition or 'MemoryReservation' in container_definition: + if ( + "Memory" in container_definition + or "MemoryReservation" in container_definition + ): resource_requirements["MEMORY"] += container_definition.get( - "Memory", container_definition.get('MemoryReservation')) + "Memory", container_definition.get("MemoryReservation") + ) else: resource_requirements["MEMORY"] += container_definition.get( - "memory", container_definition.get('memoryReservation')) + "memory", container_definition.get("memoryReservation") + ) - port_mapping_key = 'PortMappings' if 'PortMappings' in container_definition else 'portMappings' + port_mapping_key = ( + "PortMappings" + if "PortMappings" in container_definition + else "portMappings" + ) for port_mapping in container_definition.get(port_mapping_key, []): - if 'hostPort' in port_mapping: - resource_requirements["PORTS"].append(port_mapping.get('hostPort')) - elif 'HostPort' in port_mapping: - resource_requirements["PORTS"].append(port_mapping.get('HostPort')) + if "hostPort" in port_mapping: + resource_requirements["PORTS"].append(port_mapping.get("hostPort")) + elif "HostPort" in port_mapping: + resource_requirements["PORTS"].append(port_mapping.get("HostPort")) return resource_requirements @@ -610,8 +717,15 @@ class EC2ContainerServiceBackend(BaseBackend): return False, "Port clash" return True, "Can be placed" - def start_task(self, cluster_str, task_definition_str, container_instances, overrides, started_by): - cluster_name = cluster_str.split('/')[-1] + def start_task( + self, + cluster_str, + task_definition_str, + container_instances, + overrides, + started_by, + ): + cluster_name = cluster_str.split("/")[-1] if cluster_name in self.clusters: cluster = self.clusters[cluster_name] else: @@ -623,22 +737,31 @@ class EC2ContainerServiceBackend(BaseBackend): if not container_instances: raise Exception("No container instance list provided") - container_instance_ids = [x.split('/')[-1] - for x in container_instances] - resource_requirements = self._calculate_task_resource_requirements(task_definition) + container_instance_ids = [x.split("/")[-1] for x in container_instances] + resource_requirements = self._calculate_task_resource_requirements( + task_definition + ) for container_instance_id in container_instance_ids: container_instance = self.container_instances[cluster_name][ container_instance_id ] - task = Task(cluster, task_definition, container_instance.container_instance_arn, - resource_requirements, overrides or {}, started_by or '') + task = Task( + cluster, + task_definition, + container_instance.container_instance_arn, + resource_requirements, + overrides or {}, + started_by or "", + ) tasks.append(task) - self.update_container_instance_resources(container_instance, resource_requirements) + self.update_container_instance_resources( + container_instance, resource_requirements + ) self.tasks[cluster_name][task.task_arn] = task return tasks def describe_tasks(self, cluster_str, tasks): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name in self.clusters: cluster = self.clusters[cluster_name] else: @@ -649,58 +772,88 @@ class EC2ContainerServiceBackend(BaseBackend): for cluster, cluster_tasks in self.tasks.items(): for task_arn, task in cluster_tasks.items(): task_id = task_arn.split("/")[-1] - if task_arn in tasks or task.task_arn in tasks or any(task_id in task for task in tasks): + if ( + task_arn in tasks + or task.task_arn in tasks + or any(task_id in task for task in tasks) + ): response.append(task) return response - def list_tasks(self, cluster_str, container_instance, family, started_by, service_name, desiredStatus): + def list_tasks( + self, + cluster_str, + container_instance, + family, + started_by, + service_name, + desiredStatus, + ): filtered_tasks = [] for cluster, tasks in self.tasks.items(): for arn, task in tasks.items(): filtered_tasks.append(task) if cluster_str: - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) filtered_tasks = list( - filter(lambda t: cluster_name in t.cluster_arn, filtered_tasks)) + filter(lambda t: cluster_name in t.cluster_arn, filtered_tasks) + ) if container_instance: - filtered_tasks = list(filter( - lambda t: container_instance in t.container_instance_arn, filtered_tasks)) + filtered_tasks = list( + filter( + lambda t: container_instance in t.container_instance_arn, + filtered_tasks, + ) + ) if started_by: filtered_tasks = list( - filter(lambda t: started_by == t.started_by, filtered_tasks)) + filter(lambda t: started_by == t.started_by, filtered_tasks) + ) return [t.task_arn for t in filtered_tasks] def stop_task(self, cluster_str, task_str, reason): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) if not task_str: raise Exception("A task ID or ARN is required") - task_id = task_str.split('/')[-1] + task_id = task_str.split("/")[-1] tasks = self.tasks.get(cluster_name, None) if not tasks: - raise Exception( - "Cluster {} has no registered tasks".format(cluster_name)) + raise Exception("Cluster {} has no registered tasks".format(cluster_name)) for task in tasks.keys(): if task.endswith(task_id): container_instance_arn = tasks[task].container_instance_arn - container_instance = self.container_instances[cluster_name][container_instance_arn.split('/')[-1]] - self.update_container_instance_resources(container_instance, tasks[task].resource_requirements, - removing=True) - tasks[task].last_status = 'STOPPED' - tasks[task].desired_status = 'STOPPED' + container_instance = self.container_instances[cluster_name][ + container_instance_arn.split("/")[-1] + ] + self.update_container_instance_resources( + container_instance, tasks[task].resource_requirements, removing=True + ) + tasks[task].last_status = "STOPPED" + tasks[task].desired_status = "STOPPED" tasks[task].stopped_reason = reason return tasks[task] - raise Exception("Could not find task {} on cluster {}".format( - task_str, cluster_name)) + raise Exception( + "Could not find task {} on cluster {}".format(task_str, cluster_name) + ) - def create_service(self, cluster_str, service_name, task_definition_str, desired_count, load_balancers=None, scheduling_strategy=None, tags=None): - cluster_name = cluster_str.split('/')[-1] + def create_service( + self, + cluster_str, + service_name, + task_definition_str, + desired_count, + load_balancers=None, + scheduling_strategy=None, + tags=None, + ): + cluster_name = cluster_str.split("/")[-1] if cluster_name in self.clusters: cluster = self.clusters[cluster_name] else: @@ -708,52 +861,70 @@ class EC2ContainerServiceBackend(BaseBackend): task_definition = self.describe_task_definition(task_definition_str) desired_count = desired_count if desired_count is not None else 0 - service = Service(cluster, service_name, - task_definition, desired_count, load_balancers, scheduling_strategy, tags) - cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name) + service = Service( + cluster, + service_name, + task_definition, + desired_count, + load_balancers, + scheduling_strategy, + tags, + ) + cluster_service_pair = "{0}:{1}".format(cluster_name, service_name) self.services[cluster_service_pair] = service return service def list_services(self, cluster_str, scheduling_strategy=None): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] service_arns = [] for key, value in self.services.items(): - if cluster_name + ':' in key: + if cluster_name + ":" in key: service = self.services[key] - if scheduling_strategy is None or service.scheduling_strategy == scheduling_strategy: + if ( + scheduling_strategy is None + or service.scheduling_strategy == scheduling_strategy + ): service_arns.append(service.arn) return sorted(service_arns) def describe_services(self, cluster_str, service_names_or_arns): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] result = [] - for existing_service_name, existing_service_obj in sorted(self.services.items()): + for existing_service_name, existing_service_obj in sorted( + self.services.items() + ): for requested_name_or_arn in service_names_or_arns: - cluster_service_pair = '{0}:{1}'.format( - cluster_name, requested_name_or_arn) - if cluster_service_pair == existing_service_name or existing_service_obj.arn == requested_name_or_arn: + cluster_service_pair = "{0}:{1}".format( + cluster_name, requested_name_or_arn + ) + if ( + cluster_service_pair == existing_service_name + or existing_service_obj.arn == requested_name_or_arn + ): result.append(existing_service_obj) return result - def update_service(self, cluster_str, service_name, task_definition_str, desired_count): - cluster_name = cluster_str.split('/')[-1] - cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name) + def update_service( + self, cluster_str, service_name, task_definition_str, desired_count + ): + cluster_name = cluster_str.split("/")[-1] + cluster_service_pair = "{0}:{1}".format(cluster_name, service_name) if cluster_service_pair in self.services: if task_definition_str is not None: self.describe_task_definition(task_definition_str) self.services[ - cluster_service_pair].task_definition = task_definition_str + cluster_service_pair + ].task_definition = task_definition_str if desired_count is not None: - self.services[ - cluster_service_pair].desired_count = desired_count + self.services[cluster_service_pair].desired_count = desired_count return self.services[cluster_service_pair] else: raise ServiceNotFoundException(service_name) def delete_service(self, cluster_name, service_name): - cluster_service_pair = '{0}:{1}'.format(cluster_name, service_name) + cluster_service_pair = "{0}:{1}".format(cluster_name, service_name) if cluster_service_pair in self.services: service = self.services[cluster_service_pair] if service.desired_count > 0: @@ -761,82 +932,110 @@ class EC2ContainerServiceBackend(BaseBackend): else: return self.services.pop(cluster_service_pair) else: - raise Exception("cluster {0} or service {1} does not exist".format( - cluster_name, service_name)) + raise Exception( + "cluster {0} or service {1} does not exist".format( + cluster_name, service_name + ) + ) def register_container_instance(self, cluster_str, ec2_instance_id): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) container_instance = ContainerInstance(ec2_instance_id, self.region_name) if not self.container_instances.get(cluster_name): self.container_instances[cluster_name] = {} - container_instance_id = container_instance.container_instance_arn.split( - '/')[-1] + container_instance_id = container_instance.container_instance_arn.split("/")[-1] self.container_instances[cluster_name][ - container_instance_id] = container_instance + container_instance_id + ] = container_instance self.clusters[cluster_name].registered_container_instances_count += 1 return container_instance def list_container_instances(self, cluster_str): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] container_instances_values = self.container_instances.get( - cluster_name, {}).values() + cluster_name, {} + ).values() container_instances = [ - ci.container_instance_arn for ci in container_instances_values] + ci.container_instance_arn for ci in container_instances_values + ] return sorted(container_instances) def describe_container_instances(self, cluster_str, list_container_instance_ids): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) if not list_container_instance_ids: - raise JsonRESTError('InvalidParameterException', 'Container instance cannot be empty') + raise JsonRESTError( + "InvalidParameterException", "Container instance cannot be empty" + ) failures = [] container_instance_objects = [] for container_instance_id in list_container_instance_ids: - container_instance_id = container_instance_id.split('/')[-1] - container_instance = self.container_instances[ - cluster_name].get(container_instance_id, None) + container_instance_id = container_instance_id.split("/")[-1] + container_instance = self.container_instances[cluster_name].get( + container_instance_id, None + ) if container_instance is not None: container_instance_objects.append(container_instance) else: - failures.append(ContainerInstanceFailure( - 'MISSING', container_instance_id, self.region_name)) + failures.append( + ContainerInstanceFailure( + "MISSING", container_instance_id, self.region_name + ) + ) return container_instance_objects, failures - def update_container_instances_state(self, cluster_str, list_container_instance_ids, status): - cluster_name = cluster_str.split('/')[-1] + def update_container_instances_state( + self, cluster_str, list_container_instance_ids, status + ): + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) status = status.upper() - if status not in ['ACTIVE', 'DRAINING']: - raise Exception("An error occurred (InvalidParameterException) when calling the UpdateContainerInstancesState operation: \ - Container instances status should be one of [ACTIVE,DRAINING]") + if status not in ["ACTIVE", "DRAINING"]: + raise Exception( + "An error occurred (InvalidParameterException) when calling the UpdateContainerInstancesState operation: \ + Container instances status should be one of [ACTIVE,DRAINING]" + ) failures = [] container_instance_objects = [] - list_container_instance_ids = [x.split('/')[-1] - for x in list_container_instance_ids] + list_container_instance_ids = [ + x.split("/")[-1] for x in list_container_instance_ids + ] for container_instance_id in list_container_instance_ids: - container_instance = self.container_instances[cluster_name].get(container_instance_id, None) + container_instance = self.container_instances[cluster_name].get( + container_instance_id, None + ) if container_instance is not None: container_instance.status = status container_instance_objects.append(container_instance) else: - failures.append(ContainerInstanceFailure('MISSING', container_instance_id, self.region_name)) + failures.append( + ContainerInstanceFailure( + "MISSING", container_instance_id, self.region_name + ) + ) return container_instance_objects, failures - def update_container_instance_resources(self, container_instance, task_resources, removing=False): + def update_container_instance_resources( + self, container_instance, task_resources, removing=False + ): resource_multiplier = 1 if removing: resource_multiplier = -1 for resource in container_instance.remaining_resources: if resource.get("name") == "CPU": - resource["integerValue"] -= task_resources.get('CPU') * resource_multiplier + resource["integerValue"] -= ( + task_resources.get("CPU") * resource_multiplier + ) elif resource.get("name") == "MEMORY": - resource["integerValue"] -= task_resources.get('MEMORY') * resource_multiplier + resource["integerValue"] -= ( + task_resources.get("MEMORY") * resource_multiplier + ) elif resource.get("name") == "PORTS": for port in task_resources.get("PORTS"): if removing: @@ -847,11 +1046,13 @@ class EC2ContainerServiceBackend(BaseBackend): def deregister_container_instance(self, cluster_str, container_instance_str, force): failures = [] - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) - container_instance_id = container_instance_str.split('/')[-1] - container_instance = self.container_instances[cluster_name].get(container_instance_id) + container_instance_id = container_instance_str.split("/")[-1] + container_instance = self.container_instances[cluster_name].get( + container_instance_id + ) if container_instance is None: raise Exception("{0} is not a container id in the cluster") if not force and container_instance.running_tasks_count > 0: @@ -859,53 +1060,86 @@ class EC2ContainerServiceBackend(BaseBackend): # Currently assume that people might want to do something based around deregistered instances # with tasks left running on them - but nothing if no tasks were running already elif force and container_instance.running_tasks_count > 0: - if not self.container_instances.get('orphaned'): - self.container_instances['orphaned'] = {} - self.container_instances['orphaned'][container_instance_id] = container_instance - del(self.container_instances[cluster_name][container_instance_id]) + if not self.container_instances.get("orphaned"): + self.container_instances["orphaned"] = {} + self.container_instances["orphaned"][ + container_instance_id + ] = container_instance + del self.container_instances[cluster_name][container_instance_id] self._respond_to_cluster_state_update(cluster_str) return container_instance, failures def _respond_to_cluster_state_update(self, cluster_str): - cluster_name = cluster_str.split('/')[-1] + cluster_name = cluster_str.split("/")[-1] if cluster_name not in self.clusters: raise Exception("{0} is not a cluster".format(cluster_name)) pass def put_attributes(self, cluster_name, attributes=None): if cluster_name is None or cluster_name not in self.clusters: - raise JsonRESTError('ClusterNotFoundException', 'Cluster not found', status=400) + raise JsonRESTError( + "ClusterNotFoundException", "Cluster not found", status=400 + ) if attributes is None: - raise JsonRESTError('InvalidParameterException', 'attributes value is required') + raise JsonRESTError( + "InvalidParameterException", "attributes value is required" + ) for attr in attributes: - self._put_attribute(cluster_name, attr['name'], attr.get('value'), attr.get('targetId'), attr.get('targetType')) + self._put_attribute( + cluster_name, + attr["name"], + attr.get("value"), + attr.get("targetId"), + attr.get("targetType"), + ) - def _put_attribute(self, cluster_name, name, value=None, target_id=None, target_type=None): + def _put_attribute( + self, cluster_name, name, value=None, target_id=None, target_type=None + ): if target_id is None and target_type is None: for instance in self.container_instances[cluster_name].values(): instance.attributes[name] = value elif target_type is None: # targetId is full container instance arn try: - arn = target_id.rsplit('/', 1)[-1] + arn = target_id.rsplit("/", 1)[-1] self.container_instances[cluster_name][arn].attributes[name] = value except KeyError: - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + raise JsonRESTError( + "TargetNotFoundException", "Could not find {0}".format(target_id) + ) else: # targetId is container uuid, targetType must be container-instance try: - if target_type != 'container-instance': - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + if target_type != "container-instance": + raise JsonRESTError( + "TargetNotFoundException", + "Could not find {0}".format(target_id), + ) - self.container_instances[cluster_name][target_id].attributes[name] = value + self.container_instances[cluster_name][target_id].attributes[ + name + ] = value except KeyError: - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + raise JsonRESTError( + "TargetNotFoundException", "Could not find {0}".format(target_id) + ) - def list_attributes(self, target_type, cluster_name=None, attr_name=None, attr_value=None, max_results=None, next_token=None): - if target_type != 'container-instance': - raise JsonRESTError('InvalidParameterException', 'targetType must be container-instance') + def list_attributes( + self, + target_type, + cluster_name=None, + attr_name=None, + attr_value=None, + max_results=None, + next_token=None, + ): + if target_type != "container-instance": + raise JsonRESTError( + "InvalidParameterException", "targetType must be container-instance" + ) filters = [lambda x: True] @@ -921,21 +1155,40 @@ class EC2ContainerServiceBackend(BaseBackend): for cluster_name, cobj in self.container_instances.items(): for container_instance in cobj.values(): for key, value in container_instance.attributes.items(): - all_attrs.append((cluster_name, container_instance.container_instance_arn, key, value)) + all_attrs.append( + ( + cluster_name, + container_instance.container_instance_arn, + key, + value, + ) + ) return filter(lambda x: all(f(x) for f in filters), all_attrs) def delete_attributes(self, cluster_name, attributes=None): if cluster_name is None or cluster_name not in self.clusters: - raise JsonRESTError('ClusterNotFoundException', 'Cluster not found', status=400) + raise JsonRESTError( + "ClusterNotFoundException", "Cluster not found", status=400 + ) if attributes is None: - raise JsonRESTError('InvalidParameterException', 'attributes value is required') + raise JsonRESTError( + "InvalidParameterException", "attributes value is required" + ) for attr in attributes: - self._delete_attribute(cluster_name, attr['name'], attr.get('value'), attr.get('targetId'), attr.get('targetType')) + self._delete_attribute( + cluster_name, + attr["name"], + attr.get("value"), + attr.get("targetId"), + attr.get("targetType"), + ) - def _delete_attribute(self, cluster_name, name, value=None, target_id=None, target_type=None): + def _delete_attribute( + self, cluster_name, name, value=None, target_id=None, target_type=None + ): if target_id is None and target_type is None: for instance in self.container_instances[cluster_name].values(): if name in instance.attributes and instance.attributes[name] == value: @@ -943,25 +1196,34 @@ class EC2ContainerServiceBackend(BaseBackend): elif target_type is None: # targetId is full container instance arn try: - arn = target_id.rsplit('/', 1)[-1] + arn = target_id.rsplit("/", 1)[-1] instance = self.container_instances[cluster_name][arn] if name in instance.attributes and instance.attributes[name] == value: del instance.attributes[name] except KeyError: - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + raise JsonRESTError( + "TargetNotFoundException", "Could not find {0}".format(target_id) + ) else: # targetId is container uuid, targetType must be container-instance try: - if target_type != 'container-instance': - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + if target_type != "container-instance": + raise JsonRESTError( + "TargetNotFoundException", + "Could not find {0}".format(target_id), + ) instance = self.container_instances[cluster_name][target_id] if name in instance.attributes and instance.attributes[name] == value: del instance.attributes[name] except KeyError: - raise JsonRESTError('TargetNotFoundException', 'Could not find {0}'.format(target_id)) + raise JsonRESTError( + "TargetNotFoundException", "Could not find {0}".format(target_id) + ) - def list_task_definition_families(self, family_prefix=None, status=None, max_results=None, next_token=None): + def list_task_definition_families( + self, family_prefix=None, status=None, max_results=None, next_token=None + ): for task_fam in self.task_definitions: if family_prefix is not None and not task_fam.startswith(family_prefix): continue @@ -972,9 +1234,12 @@ class EC2ContainerServiceBackend(BaseBackend): def _parse_resource_arn(resource_arn): match = re.match( "^arn:aws:ecs:(?P[^:]+):(?P[^:]+):(?P[^:]+)/(?P.*)$", - resource_arn) + resource_arn, + ) if not match: - raise JsonRESTError('InvalidParameterException', 'The ARN provided is invalid.') + raise JsonRESTError( + "InvalidParameterException", "The ARN provided is invalid." + ) return match.groupdict() def list_tags_for_resource(self, resource_arn): @@ -1022,7 +1287,7 @@ class EC2ContainerServiceBackend(BaseBackend): @staticmethod def _get_keys(tags): - return [tag['key'] for tag in tags] + return [tag["key"] for tag in tags] def untag_resource(self, resource_arn, tag_keys): """Currently implemented only for services""" @@ -1030,7 +1295,9 @@ class EC2ContainerServiceBackend(BaseBackend): if parsed_arn["service"] == "service": for service in self.services.values(): if service.arn == resource_arn: - service.tags = [tag for tag in service.tags if tag["key"] not in tag_keys] + service.tags = [ + tag for tag in service.tags if tag["key"] not in tag_keys + ] return {} else: raise ServiceNotFoundException(service_name=parsed_arn["id"]) @@ -1038,4 +1305,6 @@ class EC2ContainerServiceBackend(BaseBackend): available_regions = boto3.session.Session().get_available_regions("ecs") -ecs_backends = {region: EC2ContainerServiceBackend(region) for region in available_regions} +ecs_backends = { + region: EC2ContainerServiceBackend(region) for region in available_regions +} diff --git a/moto/ecs/responses.py b/moto/ecs/responses.py index 053b079b9..5116b0472 100644 --- a/moto/ecs/responses.py +++ b/moto/ecs/responses.py @@ -6,7 +6,6 @@ from .models import ecs_backends class EC2ContainerServiceResponse(BaseResponse): - @property def ecs_backend(self): """ @@ -28,307 +27,315 @@ class EC2ContainerServiceResponse(BaseResponse): return self.request_params.get(param, if_none) def create_cluster(self): - cluster_name = self._get_param('clusterName') + cluster_name = self._get_param("clusterName") if cluster_name is None: - cluster_name = 'default' + cluster_name = "default" cluster = self.ecs_backend.create_cluster(cluster_name) - return json.dumps({ - 'cluster': cluster.response_object - }) + return json.dumps({"cluster": cluster.response_object}) def list_clusters(self): cluster_arns = self.ecs_backend.list_clusters() - return json.dumps({ - 'clusterArns': cluster_arns - # 'nextToken': str(uuid.uuid4()) - }) + return json.dumps( + { + "clusterArns": cluster_arns + # 'nextToken': str(uuid.uuid4()) + } + ) def describe_clusters(self): - list_clusters_name = self._get_param('clusters') + list_clusters_name = self._get_param("clusters") clusters, failures = self.ecs_backend.describe_clusters(list_clusters_name) - return json.dumps({ - 'clusters': clusters, - 'failures': [cluster.response_object for cluster in failures] - }) + return json.dumps( + { + "clusters": clusters, + "failures": [cluster.response_object for cluster in failures], + } + ) def delete_cluster(self): - cluster_str = self._get_param('cluster') + cluster_str = self._get_param("cluster") cluster = self.ecs_backend.delete_cluster(cluster_str) - return json.dumps({ - 'cluster': cluster.response_object - }) + return json.dumps({"cluster": cluster.response_object}) def register_task_definition(self): - family = self._get_param('family') - container_definitions = self._get_param('containerDefinitions') - volumes = self._get_param('volumes') - tags = self._get_param('tags') + family = self._get_param("family") + container_definitions = self._get_param("containerDefinitions") + volumes = self._get_param("volumes") + tags = self._get_param("tags") task_definition = self.ecs_backend.register_task_definition( - family, container_definitions, volumes, tags) - return json.dumps({ - 'taskDefinition': task_definition.response_object - }) + family, container_definitions, volumes, tags + ) + return json.dumps({"taskDefinition": task_definition.response_object}) def list_task_definitions(self): task_definition_arns = self.ecs_backend.list_task_definitions() - return json.dumps({ - 'taskDefinitionArns': task_definition_arns - # 'nextToken': str(uuid.uuid4()) - }) + return json.dumps( + { + "taskDefinitionArns": task_definition_arns + # 'nextToken': str(uuid.uuid4()) + } + ) def describe_task_definition(self): - task_definition_str = self._get_param('taskDefinition') + task_definition_str = self._get_param("taskDefinition") data = self.ecs_backend.describe_task_definition(task_definition_str) - return json.dumps({ - 'taskDefinition': data.response_object, - 'failures': [] - }) + return json.dumps({"taskDefinition": data.response_object, "failures": []}) def deregister_task_definition(self): - task_definition_str = self._get_param('taskDefinition') + task_definition_str = self._get_param("taskDefinition") task_definition = self.ecs_backend.deregister_task_definition( - task_definition_str) - return json.dumps({ - 'taskDefinition': task_definition.response_object - }) + task_definition_str + ) + return json.dumps({"taskDefinition": task_definition.response_object}) def run_task(self): - cluster_str = self._get_param('cluster') - overrides = self._get_param('overrides') - task_definition_str = self._get_param('taskDefinition') - count = self._get_int_param('count') - started_by = self._get_param('startedBy') + cluster_str = self._get_param("cluster") + overrides = self._get_param("overrides") + task_definition_str = self._get_param("taskDefinition") + count = self._get_int_param("count") + started_by = self._get_param("startedBy") tasks = self.ecs_backend.run_task( - cluster_str, task_definition_str, count, overrides, started_by) - return json.dumps({ - 'tasks': [task.response_object for task in tasks], - 'failures': [] - }) + cluster_str, task_definition_str, count, overrides, started_by + ) + return json.dumps( + {"tasks": [task.response_object for task in tasks], "failures": []} + ) def describe_tasks(self): - cluster = self._get_param('cluster') - tasks = self._get_param('tasks') + cluster = self._get_param("cluster") + tasks = self._get_param("tasks") data = self.ecs_backend.describe_tasks(cluster, tasks) - return json.dumps({ - 'tasks': [task.response_object for task in data], - 'failures': [] - }) + return json.dumps( + {"tasks": [task.response_object for task in data], "failures": []} + ) def start_task(self): - cluster_str = self._get_param('cluster') - overrides = self._get_param('overrides') - task_definition_str = self._get_param('taskDefinition') - container_instances = self._get_param('containerInstances') - started_by = self._get_param('startedBy') + cluster_str = self._get_param("cluster") + overrides = self._get_param("overrides") + task_definition_str = self._get_param("taskDefinition") + container_instances = self._get_param("containerInstances") + started_by = self._get_param("startedBy") tasks = self.ecs_backend.start_task( - cluster_str, task_definition_str, container_instances, overrides, started_by) - return json.dumps({ - 'tasks': [task.response_object for task in tasks], - 'failures': [] - }) + cluster_str, task_definition_str, container_instances, overrides, started_by + ) + return json.dumps( + {"tasks": [task.response_object for task in tasks], "failures": []} + ) def list_tasks(self): - cluster_str = self._get_param('cluster') - container_instance = self._get_param('containerInstance') - family = self._get_param('family') - started_by = self._get_param('startedBy') - service_name = self._get_param('serviceName') - desiredStatus = self._get_param('desiredStatus') + cluster_str = self._get_param("cluster") + container_instance = self._get_param("containerInstance") + family = self._get_param("family") + started_by = self._get_param("startedBy") + service_name = self._get_param("serviceName") + desiredStatus = self._get_param("desiredStatus") task_arns = self.ecs_backend.list_tasks( - cluster_str, container_instance, family, started_by, service_name, desiredStatus) - return json.dumps({ - 'taskArns': task_arns - }) + cluster_str, + container_instance, + family, + started_by, + service_name, + desiredStatus, + ) + return json.dumps({"taskArns": task_arns}) def stop_task(self): - cluster_str = self._get_param('cluster') - task = self._get_param('task') - reason = self._get_param('reason') + cluster_str = self._get_param("cluster") + task = self._get_param("task") + reason = self._get_param("reason") task = self.ecs_backend.stop_task(cluster_str, task, reason) - return json.dumps({ - 'task': task.response_object - }) + return json.dumps({"task": task.response_object}) def create_service(self): - cluster_str = self._get_param('cluster') - service_name = self._get_param('serviceName') - task_definition_str = self._get_param('taskDefinition') - desired_count = self._get_int_param('desiredCount') - load_balancers = self._get_param('loadBalancers') - scheduling_strategy = self._get_param('schedulingStrategy') - tags = self._get_param('tags') + cluster_str = self._get_param("cluster") + service_name = self._get_param("serviceName") + task_definition_str = self._get_param("taskDefinition") + desired_count = self._get_int_param("desiredCount") + load_balancers = self._get_param("loadBalancers") + scheduling_strategy = self._get_param("schedulingStrategy") + tags = self._get_param("tags") service = self.ecs_backend.create_service( - cluster_str, service_name, task_definition_str, desired_count, load_balancers, scheduling_strategy, tags) - return json.dumps({ - 'service': service.response_object - }) + cluster_str, + service_name, + task_definition_str, + desired_count, + load_balancers, + scheduling_strategy, + tags, + ) + return json.dumps({"service": service.response_object}) def list_services(self): - cluster_str = self._get_param('cluster') - scheduling_strategy = self._get_param('schedulingStrategy') + cluster_str = self._get_param("cluster") + scheduling_strategy = self._get_param("schedulingStrategy") service_arns = self.ecs_backend.list_services(cluster_str, scheduling_strategy) - return json.dumps({ - 'serviceArns': service_arns - # , - # 'nextToken': str(uuid.uuid4()) - }) + return json.dumps( + { + "serviceArns": service_arns + # , + # 'nextToken': str(uuid.uuid4()) + } + ) def describe_services(self): - cluster_str = self._get_param('cluster') - service_names = self._get_param('services') - services = self.ecs_backend.describe_services( - cluster_str, service_names) - return json.dumps({ - 'services': [service.response_object for service in services], - 'failures': [] - }) + cluster_str = self._get_param("cluster") + service_names = self._get_param("services") + services = self.ecs_backend.describe_services(cluster_str, service_names) + return json.dumps( + { + "services": [service.response_object for service in services], + "failures": [], + } + ) def update_service(self): - cluster_str = self._get_param('cluster') - service_name = self._get_param('service') - task_definition = self._get_param('taskDefinition') - desired_count = self._get_int_param('desiredCount') + cluster_str = self._get_param("cluster") + service_name = self._get_param("service") + task_definition = self._get_param("taskDefinition") + desired_count = self._get_int_param("desiredCount") service = self.ecs_backend.update_service( - cluster_str, service_name, task_definition, desired_count) - return json.dumps({ - 'service': service.response_object - }) + cluster_str, service_name, task_definition, desired_count + ) + return json.dumps({"service": service.response_object}) def delete_service(self): - service_name = self._get_param('service') - cluster_name = self._get_param('cluster') + service_name = self._get_param("service") + cluster_name = self._get_param("cluster") service = self.ecs_backend.delete_service(cluster_name, service_name) - return json.dumps({ - 'service': service.response_object - }) + return json.dumps({"service": service.response_object}) def register_container_instance(self): - cluster_str = self._get_param('cluster') - instance_identity_document_str = self._get_param( - 'instanceIdentityDocument') + cluster_str = self._get_param("cluster") + instance_identity_document_str = self._get_param("instanceIdentityDocument") instance_identity_document = json.loads(instance_identity_document_str) ec2_instance_id = instance_identity_document["instanceId"] container_instance = self.ecs_backend.register_container_instance( - cluster_str, ec2_instance_id) - return json.dumps({ - 'containerInstance': container_instance.response_object - }) + cluster_str, ec2_instance_id + ) + return json.dumps({"containerInstance": container_instance.response_object}) def deregister_container_instance(self): - cluster_str = self._get_param('cluster') + cluster_str = self._get_param("cluster") if not cluster_str: - cluster_str = 'default' - container_instance_str = self._get_param('containerInstance') - force = self._get_param('force') + cluster_str = "default" + container_instance_str = self._get_param("containerInstance") + force = self._get_param("force") container_instance, failures = self.ecs_backend.deregister_container_instance( cluster_str, container_instance_str, force ) - return json.dumps({ - 'containerInstance': container_instance.response_object - }) + return json.dumps({"containerInstance": container_instance.response_object}) def list_container_instances(self): - cluster_str = self._get_param('cluster') - container_instance_arns = self.ecs_backend.list_container_instances( - cluster_str) - return json.dumps({ - 'containerInstanceArns': container_instance_arns - }) + cluster_str = self._get_param("cluster") + container_instance_arns = self.ecs_backend.list_container_instances(cluster_str) + return json.dumps({"containerInstanceArns": container_instance_arns}) def describe_container_instances(self): - cluster_str = self._get_param('cluster') - list_container_instance_arns = self._get_param('containerInstances') + cluster_str = self._get_param("cluster") + list_container_instance_arns = self._get_param("containerInstances") container_instances, failures = self.ecs_backend.describe_container_instances( - cluster_str, list_container_instance_arns) - return json.dumps({ - 'failures': [ci.response_object for ci in failures], - 'containerInstances': [ci.response_object for ci in container_instances] - }) + cluster_str, list_container_instance_arns + ) + return json.dumps( + { + "failures": [ci.response_object for ci in failures], + "containerInstances": [ + ci.response_object for ci in container_instances + ], + } + ) def update_container_instances_state(self): - cluster_str = self._get_param('cluster') - list_container_instance_arns = self._get_param('containerInstances') - status_str = self._get_param('status') - container_instances, failures = self.ecs_backend.update_container_instances_state(cluster_str, - list_container_instance_arns, - status_str) - return json.dumps({ - 'failures': [ci.response_object for ci in failures], - 'containerInstances': [ci.response_object for ci in container_instances] - }) + cluster_str = self._get_param("cluster") + list_container_instance_arns = self._get_param("containerInstances") + status_str = self._get_param("status") + ( + container_instances, + failures, + ) = self.ecs_backend.update_container_instances_state( + cluster_str, list_container_instance_arns, status_str + ) + return json.dumps( + { + "failures": [ci.response_object for ci in failures], + "containerInstances": [ + ci.response_object for ci in container_instances + ], + } + ) def put_attributes(self): - cluster_name = self._get_param('cluster') - attributes = self._get_param('attributes') + cluster_name = self._get_param("cluster") + attributes = self._get_param("attributes") self.ecs_backend.put_attributes(cluster_name, attributes) - return json.dumps({'attributes': attributes}) + return json.dumps({"attributes": attributes}) def list_attributes(self): - cluster_name = self._get_param('cluster') - attr_name = self._get_param('attributeName') - attr_value = self._get_param('attributeValue') - target_type = self._get_param('targetType') - max_results = self._get_param('maxResults') - next_token = self._get_param('nextToken') + cluster_name = self._get_param("cluster") + attr_name = self._get_param("attributeName") + attr_value = self._get_param("attributeValue") + target_type = self._get_param("targetType") + max_results = self._get_param("maxResults") + next_token = self._get_param("nextToken") - results = self.ecs_backend.list_attributes(target_type, cluster_name, attr_name, attr_value, max_results, next_token) + results = self.ecs_backend.list_attributes( + target_type, cluster_name, attr_name, attr_value, max_results, next_token + ) # Result will be [item will be {0 cluster_name, 1 arn, 2 name, 3 value}] formatted_results = [] for _, arn, name, value in results: - tmp_result = { - 'name': name, - 'targetId': arn - } + tmp_result = {"name": name, "targetId": arn} if value is not None: - tmp_result['value'] = value + tmp_result["value"] = value formatted_results.append(tmp_result) - return json.dumps({'attributes': formatted_results}) + return json.dumps({"attributes": formatted_results}) def delete_attributes(self): - cluster_name = self._get_param('cluster') - attributes = self._get_param('attributes') + cluster_name = self._get_param("cluster") + attributes = self._get_param("attributes") self.ecs_backend.delete_attributes(cluster_name, attributes) - return json.dumps({'attributes': attributes}) + return json.dumps({"attributes": attributes}) def discover_poll_endpoint(self): # Here are the arguments, this api is used by the ecs client so obviously no decent # documentation. Hence I've responded with valid but useless data # cluster_name = self._get_param('cluster') # instance = self._get_param('containerInstance') - return json.dumps({ - 'endpoint': 'http://localhost', - 'telemetryEndpoint': 'http://localhost' - }) + return json.dumps( + {"endpoint": "http://localhost", "telemetryEndpoint": "http://localhost"} + ) def list_task_definition_families(self): - family_prefix = self._get_param('familyPrefix') - status = self._get_param('status') - max_results = self._get_param('maxResults') - next_token = self._get_param('nextToken') + family_prefix = self._get_param("familyPrefix") + status = self._get_param("status") + max_results = self._get_param("maxResults") + next_token = self._get_param("nextToken") - results = self.ecs_backend.list_task_definition_families(family_prefix, status, max_results, next_token) + results = self.ecs_backend.list_task_definition_families( + family_prefix, status, max_results, next_token + ) - return json.dumps({'families': list(results)}) + return json.dumps({"families": list(results)}) def list_tags_for_resource(self): - resource_arn = self._get_param('resourceArn') + resource_arn = self._get_param("resourceArn") tags = self.ecs_backend.list_tags_for_resource(resource_arn) - return json.dumps({'tags': tags}) + return json.dumps({"tags": tags}) def tag_resource(self): - resource_arn = self._get_param('resourceArn') - tags = self._get_param('tags') + resource_arn = self._get_param("resourceArn") + tags = self._get_param("tags") results = self.ecs_backend.tag_resource(resource_arn, tags) return json.dumps(results) def untag_resource(self): - resource_arn = self._get_param('resourceArn') - tag_keys = self._get_param('tagKeys') + resource_arn = self._get_param("resourceArn") + tag_keys = self._get_param("tagKeys") results = self.ecs_backend.untag_resource(resource_arn, tag_keys) return json.dumps(results) diff --git a/moto/ecs/urls.py b/moto/ecs/urls.py index 1e0d5fbf9..a5adc5923 100644 --- a/moto/ecs/urls.py +++ b/moto/ecs/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import EC2ContainerServiceResponse -url_bases = [ - "https?://ecs.(.+).amazonaws.com", -] +url_bases = ["https?://ecs.(.+).amazonaws.com"] -url_paths = { - '{0}/$': EC2ContainerServiceResponse.dispatch, -} +url_paths = {"{0}/$": EC2ContainerServiceResponse.dispatch} diff --git a/moto/elb/__init__.py b/moto/elb/__init__.py index e25f2d486..d3627ed6d 100644 --- a/moto/elb/__init__.py +++ b/moto/elb/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import elb_backends from ..core.models import base_decorator, deprecated_base_decorator -elb_backend = elb_backends['us-east-1'] +elb_backend = elb_backends["us-east-1"] mock_elb = base_decorator(elb_backends) mock_elb_deprecated = deprecated_base_decorator(elb_backends) diff --git a/moto/elb/exceptions.py b/moto/elb/exceptions.py index 3ea6a1642..d41a66e3f 100644 --- a/moto/elb/exceptions.py +++ b/moto/elb/exceptions.py @@ -7,68 +7,66 @@ class ELBClientError(RESTError): class DuplicateTagKeysError(ELBClientError): - def __init__(self, cidr): super(DuplicateTagKeysError, self).__init__( - "DuplicateTagKeys", - "Tag key was specified more than once: {0}" - .format(cidr)) + "DuplicateTagKeys", "Tag key was specified more than once: {0}".format(cidr) + ) class LoadBalancerNotFoundError(ELBClientError): - def __init__(self, cidr): super(LoadBalancerNotFoundError, self).__init__( "LoadBalancerNotFound", - "The specified load balancer does not exist: {0}" - .format(cidr)) + "The specified load balancer does not exist: {0}".format(cidr), + ) class TooManyTagsError(ELBClientError): - def __init__(self): super(TooManyTagsError, self).__init__( "LoadBalancerNotFound", - "The quota for the number of tags that can be assigned to a load balancer has been reached") + "The quota for the number of tags that can be assigned to a load balancer has been reached", + ) class BadHealthCheckDefinition(ELBClientError): - def __init__(self): super(BadHealthCheckDefinition, self).__init__( "ValidationError", - "HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL") + "HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL", + ) class DuplicateListenerError(ELBClientError): - def __init__(self, name, port): super(DuplicateListenerError, self).__init__( "DuplicateListener", - "A listener already exists for {0} with LoadBalancerPort {1}, but with a different InstancePort, Protocol, or SSLCertificateId" - .format(name, port)) + "A listener already exists for {0} with LoadBalancerPort {1}, but with a different InstancePort, Protocol, or SSLCertificateId".format( + name, port + ), + ) class DuplicateLoadBalancerName(ELBClientError): - def __init__(self, name): super(DuplicateLoadBalancerName, self).__init__( "DuplicateLoadBalancerName", - "The specified load balancer name already exists for this account: {0}" - .format(name)) + "The specified load balancer name already exists for this account: {0}".format( + name + ), + ) class EmptyListenersError(ELBClientError): - def __init__(self): super(EmptyListenersError, self).__init__( - "ValidationError", - "Listeners cannot be empty") + "ValidationError", "Listeners cannot be empty" + ) class InvalidSecurityGroupError(ELBClientError): - def __init__(self): super(InvalidSecurityGroupError, self).__init__( "ValidationError", - "One or more of the specified security groups do not exist.") + "One or more of the specified security groups do not exist.", + ) diff --git a/moto/elb/models.py b/moto/elb/models.py index 8781620f1..f77811623 100644 --- a/moto/elb/models.py +++ b/moto/elb/models.py @@ -8,10 +8,7 @@ from boto.ec2.elb.attributes import ( AccessLogAttribute, CrossZoneLoadBalancingAttribute, ) -from boto.ec2.elb.policies import ( - Policies, - OtherPolicy, -) +from boto.ec2.elb.policies import Policies, OtherPolicy from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.ec2.models import ec2_backends @@ -27,20 +24,19 @@ from .exceptions import ( class FakeHealthCheck(BaseModel): - - def __init__(self, timeout, healthy_threshold, unhealthy_threshold, - interval, target): + def __init__( + self, timeout, healthy_threshold, unhealthy_threshold, interval, target + ): self.timeout = timeout self.healthy_threshold = healthy_threshold self.unhealthy_threshold = unhealthy_threshold self.interval = interval self.target = target - if not target.startswith(('HTTP', 'TCP', 'HTTPS', 'SSL')): + if not target.startswith(("HTTP", "TCP", "HTTPS", "SSL")): raise BadHealthCheckDefinition class FakeListener(BaseModel): - def __init__(self, load_balancer_port, instance_port, protocol, ssl_certificate_id): self.load_balancer_port = load_balancer_port self.instance_port = instance_port @@ -49,22 +45,38 @@ class FakeListener(BaseModel): self.policy_names = [] def __repr__(self): - return "FakeListener(lbp: %s, inp: %s, pro: %s, cid: %s, policies: %s)" % (self.load_balancer_port, self.instance_port, self.protocol, self.ssl_certificate_id, self.policy_names) + return "FakeListener(lbp: %s, inp: %s, pro: %s, cid: %s, policies: %s)" % ( + self.load_balancer_port, + self.instance_port, + self.protocol, + self.ssl_certificate_id, + self.policy_names, + ) class FakeBackend(BaseModel): - def __init__(self, instance_port): self.instance_port = instance_port self.policy_names = [] def __repr__(self): - return "FakeBackend(inp: %s, policies: %s)" % (self.instance_port, self.policy_names) + return "FakeBackend(inp: %s, policies: %s)" % ( + self.instance_port, + self.policy_names, + ) class FakeLoadBalancer(BaseModel): - - def __init__(self, name, zones, ports, scheme='internet-facing', vpc_id=None, subnets=None, security_groups=None): + def __init__( + self, + name, + zones, + ports, + scheme="internet-facing", + vpc_id=None, + subnets=None, + security_groups=None, + ): self.name = name self.health_check = None self.instance_ids = [] @@ -80,47 +92,49 @@ class FakeLoadBalancer(BaseModel): self.policies.lb_cookie_stickiness_policies = [] self.security_groups = security_groups or [] self.subnets = subnets or [] - self.vpc_id = vpc_id or 'vpc-56e10e3d' + self.vpc_id = vpc_id or "vpc-56e10e3d" self.tags = {} self.dns_name = "%s.us-east-1.elb.amazonaws.com" % (name) for port in ports: listener = FakeListener( - protocol=(port.get('protocol') or port['Protocol']), + protocol=(port.get("protocol") or port["Protocol"]), load_balancer_port=( - port.get('load_balancer_port') or port['LoadBalancerPort']), - instance_port=( - port.get('instance_port') or port['InstancePort']), + port.get("load_balancer_port") or port["LoadBalancerPort"] + ), + instance_port=(port.get("instance_port") or port["InstancePort"]), ssl_certificate_id=port.get( - 'ssl_certificate_id', port.get('SSLCertificateId')), + "ssl_certificate_id", port.get("SSLCertificateId") + ), ) self.listeners.append(listener) # it is unclear per the AWS documentation as to when or how backend # information gets set, so let's guess and set it here *shrug* backend = FakeBackend( - instance_port=( - port.get('instance_port') or port['InstancePort']), + instance_port=(port.get("instance_port") or port["InstancePort"]) ) self.backends.append(backend) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] elb_backend = elb_backends[region_name] new_elb = elb_backend.create_load_balancer( - name=properties.get('LoadBalancerName', resource_name), - zones=properties.get('AvailabilityZones', []), - ports=properties['Listeners'], - scheme=properties.get('Scheme', 'internet-facing'), + name=properties.get("LoadBalancerName", resource_name), + zones=properties.get("AvailabilityZones", []), + ports=properties["Listeners"], + scheme=properties.get("Scheme", "internet-facing"), ) - instance_ids = properties.get('Instances', []) + instance_ids = properties.get("Instances", []) for instance_id in instance_ids: elb_backend.register_instances(new_elb.name, [instance_id]) - policies = properties.get('Policies', []) + policies = properties.get("Policies", []) port_policies = {} for policy in policies: policy_name = policy["PolicyName"] @@ -134,29 +148,37 @@ class FakeLoadBalancer(BaseModel): for port, policies in port_policies.items(): elb_backend.set_load_balancer_policies_of_backend_server( - new_elb.name, port, list(policies)) + new_elb.name, port, list(policies) + ) - health_check = properties.get('HealthCheck') + health_check = properties.get("HealthCheck") if health_check: elb_backend.configure_health_check( load_balancer_name=new_elb.name, - timeout=health_check['Timeout'], - healthy_threshold=health_check['HealthyThreshold'], - unhealthy_threshold=health_check['UnhealthyThreshold'], - interval=health_check['Interval'], - target=health_check['Target'], + timeout=health_check["Timeout"], + healthy_threshold=health_check["HealthyThreshold"], + unhealthy_threshold=health_check["UnhealthyThreshold"], + interval=health_check["Interval"], + target=health_check["Target"], ) return new_elb @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name) - return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) + original_resource.name, cloudformation_json, region_name + ) + return cls.create_from_cloudformation_json( + new_resource_name, cloudformation_json, region_name + ) @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): elb_backend = elb_backends[region_name] try: elb_backend.delete_load_balancer(resource_name) @@ -169,20 +191,25 @@ class FakeLoadBalancer(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'CanonicalHostedZoneName': + + if attribute_name == "CanonicalHostedZoneName": raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneName" ]"') - elif attribute_name == 'CanonicalHostedZoneNameID': + '"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneName" ]"' + ) + elif attribute_name == "CanonicalHostedZoneNameID": raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneNameID" ]"') - elif attribute_name == 'DNSName': + '"Fn::GetAtt" : [ "{0}" , "CanonicalHostedZoneNameID" ]"' + ) + elif attribute_name == "DNSName": return self.dns_name - elif attribute_name == 'SourceSecurityGroup.GroupName': + elif attribute_name == "SourceSecurityGroup.GroupName": raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.GroupName" ]"') - elif attribute_name == 'SourceSecurityGroup.OwnerAlias': + '"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.GroupName" ]"' + ) + elif attribute_name == "SourceSecurityGroup.OwnerAlias": raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.OwnerAlias" ]"') + '"Fn::GetAtt" : [ "{0}" , "SourceSecurityGroup.OwnerAlias" ]"' + ) raise UnformattedGetAttTemplateException() @classmethod @@ -220,12 +247,11 @@ class FakeLoadBalancer(BaseModel): del self.tags[key] def delete(self, region): - ''' Not exposed as part of the ELB API - used for CloudFormation. ''' + """ Not exposed as part of the ELB API - used for CloudFormation. """ elb_backends[region].delete_load_balancer(self.name) class ELBBackend(BaseBackend): - def __init__(self, region_name=None): self.region_name = region_name self.load_balancers = OrderedDict() @@ -235,7 +261,15 @@ class ELBBackend(BaseBackend): self.__dict__ = {} self.__init__(region_name) - def create_load_balancer(self, name, zones, ports, scheme='internet-facing', subnets=None, security_groups=None): + def create_load_balancer( + self, + name, + zones, + ports, + scheme="internet-facing", + subnets=None, + security_groups=None, + ): vpc_id = None ec2_backend = ec2_backends[self.region_name] if subnets: @@ -257,7 +291,8 @@ class ELBBackend(BaseBackend): scheme=scheme, subnets=subnets, security_groups=security_groups, - vpc_id=vpc_id) + vpc_id=vpc_id, + ) self.load_balancers[name] = new_load_balancer return new_load_balancer @@ -265,10 +300,10 @@ class ELBBackend(BaseBackend): balancer = self.load_balancers.get(name, None) if balancer: for port in ports: - protocol = port['protocol'] - instance_port = port['instance_port'] - lb_port = port['load_balancer_port'] - ssl_certificate_id = port.get('ssl_certificate_id') + protocol = port["protocol"] + instance_port = port["instance_port"] + lb_port = port["load_balancer_port"] + ssl_certificate_id = port.get("ssl_certificate_id") for listener in balancer.listeners: if lb_port == listener.load_balancer_port: if protocol != listener.protocol: @@ -279,8 +314,11 @@ class ELBBackend(BaseBackend): raise DuplicateListenerError(name, lb_port) break else: - balancer.listeners.append(FakeListener( - lb_port, instance_port, protocol, ssl_certificate_id)) + balancer.listeners.append( + FakeListener( + lb_port, instance_port, protocol, ssl_certificate_id + ) + ) return balancer @@ -288,7 +326,8 @@ class ELBBackend(BaseBackend): balancers = self.load_balancers.values() if names: matched_balancers = [ - balancer for balancer in balancers if balancer.name in names] + balancer for balancer in balancers if balancer.name in names + ] if len(names) != len(matched_balancers): missing_elb = list(set(names) - set(matched_balancers))[0] raise LoadBalancerNotFoundError(missing_elb) @@ -315,7 +354,9 @@ class ELBBackend(BaseBackend): def get_load_balancer(self, load_balancer_name): return self.load_balancers.get(load_balancer_name) - def apply_security_groups_to_load_balancer(self, load_balancer_name, security_group_ids): + def apply_security_groups_to_load_balancer( + self, load_balancer_name, security_group_ids + ): load_balancer = self.load_balancers.get(load_balancer_name) ec2_backend = ec2_backends[self.region_name] for security_group_id in security_group_ids: @@ -323,22 +364,30 @@ class ELBBackend(BaseBackend): raise InvalidSecurityGroupError() load_balancer.security_groups = security_group_ids - def configure_health_check(self, load_balancer_name, timeout, - healthy_threshold, unhealthy_threshold, interval, - target): - check = FakeHealthCheck(timeout, healthy_threshold, unhealthy_threshold, - interval, target) + def configure_health_check( + self, + load_balancer_name, + timeout, + healthy_threshold, + unhealthy_threshold, + interval, + target, + ): + check = FakeHealthCheck( + timeout, healthy_threshold, unhealthy_threshold, interval, target + ) load_balancer = self.get_load_balancer(load_balancer_name) load_balancer.health_check = check return check - def set_load_balancer_listener_sslcertificate(self, name, lb_port, ssl_certificate_id): + def set_load_balancer_listener_sslcertificate( + self, name, lb_port, ssl_certificate_id + ): balancer = self.load_balancers.get(name, None) if balancer: for idx, listener in enumerate(balancer.listeners): if lb_port == listener.load_balancer_port: - balancer.listeners[ - idx].ssl_certificate_id = ssl_certificate_id + balancer.listeners[idx].ssl_certificate_id = ssl_certificate_id return balancer @@ -350,7 +399,10 @@ class ELBBackend(BaseBackend): def deregister_instances(self, load_balancer_name, instance_ids): load_balancer = self.get_load_balancer(load_balancer_name) new_instance_ids = [ - instance_id for instance_id in load_balancer.instance_ids if instance_id not in instance_ids] + instance_id + for instance_id in load_balancer.instance_ids + if instance_id not in instance_ids + ] load_balancer.instance_ids = new_instance_ids return load_balancer @@ -376,7 +428,9 @@ class ELBBackend(BaseBackend): def create_lb_other_policy(self, load_balancer_name, other_policy): load_balancer = self.get_load_balancer(load_balancer_name) - if other_policy.policy_name not in [p.policy_name for p in load_balancer.policies.other_policies]: + if other_policy.policy_name not in [ + p.policy_name for p in load_balancer.policies.other_policies + ]: load_balancer.policies.other_policies.append(other_policy) return load_balancer @@ -391,19 +445,27 @@ class ELBBackend(BaseBackend): load_balancer.policies.lb_cookie_stickiness_policies.append(policy) return load_balancer - def set_load_balancer_policies_of_backend_server(self, load_balancer_name, instance_port, policies): + def set_load_balancer_policies_of_backend_server( + self, load_balancer_name, instance_port, policies + ): load_balancer = self.get_load_balancer(load_balancer_name) - backend = [b for b in load_balancer.backends if int( - b.instance_port) == instance_port][0] + backend = [ + b for b in load_balancer.backends if int(b.instance_port) == instance_port + ][0] backend_idx = load_balancer.backends.index(backend) backend.policy_names = policies load_balancer.backends[backend_idx] = backend return load_balancer - def set_load_balancer_policies_of_listener(self, load_balancer_name, load_balancer_port, policies): + def set_load_balancer_policies_of_listener( + self, load_balancer_name, load_balancer_port, policies + ): load_balancer = self.get_load_balancer(load_balancer_name) - listener = [l for l in load_balancer.listeners if int( - l.load_balancer_port) == load_balancer_port][0] + listener = [ + l + for l in load_balancer.listeners + if int(l.load_balancer_port) == load_balancer_port + ][0] listener_idx = load_balancer.listeners.index(listener) listener.policy_names = policies load_balancer.listeners[listener_idx] = listener diff --git a/moto/elb/responses.py b/moto/elb/responses.py index b512f56e9..de21f23e7 100644 --- a/moto/elb/responses.py +++ b/moto/elb/responses.py @@ -5,10 +5,7 @@ from boto.ec2.elb.attributes import ( AccessLogAttribute, CrossZoneLoadBalancingAttribute, ) -from boto.ec2.elb.policies import ( - AppCookieStickinessPolicy, - OtherPolicy, -) +from boto.ec2.elb.policies import AppCookieStickinessPolicy, OtherPolicy from moto.core.responses import BaseResponse from .models import elb_backends @@ -16,16 +13,15 @@ from .exceptions import DuplicateTagKeysError, LoadBalancerNotFoundError class ELBResponse(BaseResponse): - @property def elb_backend(self): return elb_backends[self.region] def create_load_balancer(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") availability_zones = self._get_multi_param("AvailabilityZones.member") ports = self._get_list_prefix("Listeners.member") - scheme = self._get_param('Scheme') + scheme = self._get_param("Scheme") subnets = self._get_multi_param("Subnets.member") security_groups = self._get_multi_param("SecurityGroups.member") @@ -42,27 +38,29 @@ class ELBResponse(BaseResponse): return template.render(load_balancer=load_balancer) def create_load_balancer_listeners(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") ports = self._get_list_prefix("Listeners.member") self.elb_backend.create_load_balancer_listeners( - name=load_balancer_name, ports=ports) + name=load_balancer_name, ports=ports + ) - template = self.response_template( - CREATE_LOAD_BALANCER_LISTENERS_TEMPLATE) + template = self.response_template(CREATE_LOAD_BALANCER_LISTENERS_TEMPLATE) return template.render() def describe_load_balancers(self): names = self._get_multi_param("LoadBalancerNames.member") all_load_balancers = list(self.elb_backend.describe_load_balancers(names)) - marker = self._get_param('Marker') + marker = self._get_param("Marker") all_names = [balancer.name for balancer in all_load_balancers] if marker: start = all_names.index(marker) + 1 else: start = 0 - page_size = self._get_int_param('PageSize', 50) # the default is 400, but using 50 to make testing easier - load_balancers_resp = all_load_balancers[start:start + page_size] + page_size = self._get_int_param( + "PageSize", 50 + ) # the default is 400, but using 50 to make testing easier + load_balancers_resp = all_load_balancers[start : start + page_size] next_marker = None if len(all_load_balancers) > start + page_size: next_marker = load_balancers_resp[-1].name @@ -71,143 +69,158 @@ class ELBResponse(BaseResponse): return template.render(load_balancers=load_balancers_resp, marker=next_marker) def delete_load_balancer_listeners(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") ports = self._get_multi_param("LoadBalancerPorts.member") ports = [int(port) for port in ports] - self.elb_backend.delete_load_balancer_listeners( - load_balancer_name, ports) + self.elb_backend.delete_load_balancer_listeners(load_balancer_name, ports) template = self.response_template(DELETE_LOAD_BALANCER_LISTENERS) return template.render() def delete_load_balancer(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") self.elb_backend.delete_load_balancer(load_balancer_name) template = self.response_template(DELETE_LOAD_BALANCER_TEMPLATE) return template.render() def apply_security_groups_to_load_balancer(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") security_group_ids = self._get_multi_param("SecurityGroups.member") - self.elb_backend.apply_security_groups_to_load_balancer(load_balancer_name, security_group_ids) + self.elb_backend.apply_security_groups_to_load_balancer( + load_balancer_name, security_group_ids + ) template = self.response_template(APPLY_SECURITY_GROUPS_TEMPLATE) return template.render(security_group_ids=security_group_ids) def configure_health_check(self): check = self.elb_backend.configure_health_check( - load_balancer_name=self._get_param('LoadBalancerName'), - timeout=self._get_param('HealthCheck.Timeout'), - healthy_threshold=self._get_param('HealthCheck.HealthyThreshold'), - unhealthy_threshold=self._get_param( - 'HealthCheck.UnhealthyThreshold'), - interval=self._get_param('HealthCheck.Interval'), - target=self._get_param('HealthCheck.Target'), + load_balancer_name=self._get_param("LoadBalancerName"), + timeout=self._get_param("HealthCheck.Timeout"), + healthy_threshold=self._get_param("HealthCheck.HealthyThreshold"), + unhealthy_threshold=self._get_param("HealthCheck.UnhealthyThreshold"), + interval=self._get_param("HealthCheck.Interval"), + target=self._get_param("HealthCheck.Target"), ) template = self.response_template(CONFIGURE_HEALTH_CHECK_TEMPLATE) return template.render(check=check) def register_instances_with_load_balancer(self): - load_balancer_name = self._get_param('LoadBalancerName') - instance_ids = [list(param.values())[0] for param in self._get_list_prefix('Instances.member')] + load_balancer_name = self._get_param("LoadBalancerName") + instance_ids = [ + list(param.values())[0] + for param in self._get_list_prefix("Instances.member") + ] template = self.response_template(REGISTER_INSTANCES_TEMPLATE) load_balancer = self.elb_backend.register_instances( - load_balancer_name, instance_ids) + load_balancer_name, instance_ids + ) return template.render(load_balancer=load_balancer) def set_load_balancer_listener_ssl_certificate(self): - load_balancer_name = self._get_param('LoadBalancerName') - ssl_certificate_id = self.querystring['SSLCertificateId'][0] - lb_port = self.querystring['LoadBalancerPort'][0] + load_balancer_name = self._get_param("LoadBalancerName") + ssl_certificate_id = self.querystring["SSLCertificateId"][0] + lb_port = self.querystring["LoadBalancerPort"][0] self.elb_backend.set_load_balancer_listener_sslcertificate( - load_balancer_name, lb_port, ssl_certificate_id) + load_balancer_name, lb_port, ssl_certificate_id + ) template = self.response_template(SET_LOAD_BALANCER_SSL_CERTIFICATE) return template.render() def deregister_instances_from_load_balancer(self): - load_balancer_name = self._get_param('LoadBalancerName') - instance_ids = [list(param.values())[0] for param in self._get_list_prefix('Instances.member')] + load_balancer_name = self._get_param("LoadBalancerName") + instance_ids = [ + list(param.values())[0] + for param in self._get_list_prefix("Instances.member") + ] template = self.response_template(DEREGISTER_INSTANCES_TEMPLATE) load_balancer = self.elb_backend.deregister_instances( - load_balancer_name, instance_ids) + load_balancer_name, instance_ids + ) return template.render(load_balancer=load_balancer) def describe_load_balancer_attributes(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") load_balancer = self.elb_backend.get_load_balancer(load_balancer_name) template = self.response_template(DESCRIBE_ATTRIBUTES_TEMPLATE) return template.render(attributes=load_balancer.attributes) def modify_load_balancer_attributes(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") load_balancer = self.elb_backend.get_load_balancer(load_balancer_name) cross_zone = self._get_dict_param( - "LoadBalancerAttributes.CrossZoneLoadBalancing.") + "LoadBalancerAttributes.CrossZoneLoadBalancing." + ) if cross_zone: attribute = CrossZoneLoadBalancingAttribute() attribute.enabled = cross_zone["enabled"] == "true" self.elb_backend.set_cross_zone_load_balancing_attribute( - load_balancer_name, attribute) + load_balancer_name, attribute + ) access_log = self._get_dict_param("LoadBalancerAttributes.AccessLog.") if access_log: attribute = AccessLogAttribute() attribute.enabled = access_log["enabled"] == "true" - attribute.s3_bucket_name = access_log['s3_bucket_name'] - attribute.s3_bucket_prefix = access_log['s3_bucket_prefix'] + attribute.s3_bucket_name = access_log["s3_bucket_name"] + attribute.s3_bucket_prefix = access_log["s3_bucket_prefix"] attribute.emit_interval = access_log["emit_interval"] - self.elb_backend.set_access_log_attribute( - load_balancer_name, attribute) + self.elb_backend.set_access_log_attribute(load_balancer_name, attribute) connection_draining = self._get_dict_param( - "LoadBalancerAttributes.ConnectionDraining.") + "LoadBalancerAttributes.ConnectionDraining." + ) if connection_draining: attribute = ConnectionDrainingAttribute() attribute.enabled = connection_draining["enabled"] == "true" attribute.timeout = connection_draining.get("timeout", 300) - self.elb_backend.set_connection_draining_attribute(load_balancer_name, attribute) + self.elb_backend.set_connection_draining_attribute( + load_balancer_name, attribute + ) connection_settings = self._get_dict_param( - "LoadBalancerAttributes.ConnectionSettings.") + "LoadBalancerAttributes.ConnectionSettings." + ) if connection_settings: attribute = ConnectionSettingAttribute() attribute.idle_timeout = connection_settings["idle_timeout"] self.elb_backend.set_connection_settings_attribute( - load_balancer_name, attribute) + load_balancer_name, attribute + ) template = self.response_template(MODIFY_ATTRIBUTES_TEMPLATE) - return template.render(load_balancer=load_balancer, attributes=load_balancer.attributes) + return template.render( + load_balancer=load_balancer, attributes=load_balancer.attributes + ) def create_load_balancer_policy(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") other_policy = OtherPolicy() policy_name = self._get_param("PolicyName") other_policy.policy_name = policy_name - self.elb_backend.create_lb_other_policy( - load_balancer_name, other_policy) + self.elb_backend.create_lb_other_policy(load_balancer_name, other_policy) template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE) return template.render() def create_app_cookie_stickiness_policy(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") policy = AppCookieStickinessPolicy() policy.policy_name = self._get_param("PolicyName") policy.cookie_name = self._get_param("CookieName") - self.elb_backend.create_app_cookie_stickiness_policy( - load_balancer_name, policy) + self.elb_backend.create_app_cookie_stickiness_policy(load_balancer_name, policy) template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE) return template.render() def create_lb_cookie_stickiness_policy(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") policy = AppCookieStickinessPolicy() policy.policy_name = self._get_param("PolicyName") @@ -217,62 +230,68 @@ class ELBResponse(BaseResponse): else: policy.cookie_expiration_period = None - self.elb_backend.create_lb_cookie_stickiness_policy( - load_balancer_name, policy) + self.elb_backend.create_lb_cookie_stickiness_policy(load_balancer_name, policy) template = self.response_template(CREATE_LOAD_BALANCER_POLICY_TEMPLATE) return template.render() def set_load_balancer_policies_of_listener(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") load_balancer = self.elb_backend.get_load_balancer(load_balancer_name) - load_balancer_port = int(self._get_param('LoadBalancerPort')) + load_balancer_port = int(self._get_param("LoadBalancerPort")) - mb_listener = [l for l in load_balancer.listeners if int( - l.load_balancer_port) == load_balancer_port] + mb_listener = [ + l + for l in load_balancer.listeners + if int(l.load_balancer_port) == load_balancer_port + ] if mb_listener: policies = self._get_multi_param("PolicyNames.member") self.elb_backend.set_load_balancer_policies_of_listener( - load_balancer_name, load_balancer_port, policies) + load_balancer_name, load_balancer_port, policies + ) # else: explode? template = self.response_template( - SET_LOAD_BALANCER_POLICIES_OF_LISTENER_TEMPLATE) + SET_LOAD_BALANCER_POLICIES_OF_LISTENER_TEMPLATE + ) return template.render() def set_load_balancer_policies_for_backend_server(self): - load_balancer_name = self.querystring.get('LoadBalancerName')[0] + load_balancer_name = self.querystring.get("LoadBalancerName")[0] load_balancer = self.elb_backend.get_load_balancer(load_balancer_name) - instance_port = int(self.querystring.get('InstancePort')[0]) + instance_port = int(self.querystring.get("InstancePort")[0]) - mb_backend = [b for b in load_balancer.backends if int( - b.instance_port) == instance_port] + mb_backend = [ + b for b in load_balancer.backends if int(b.instance_port) == instance_port + ] if mb_backend: - policies = self._get_multi_param('PolicyNames.member') + policies = self._get_multi_param("PolicyNames.member") self.elb_backend.set_load_balancer_policies_of_backend_server( - load_balancer_name, instance_port, policies) + load_balancer_name, instance_port, policies + ) # else: explode? template = self.response_template( - SET_LOAD_BALANCER_POLICIES_FOR_BACKEND_SERVER_TEMPLATE) + SET_LOAD_BALANCER_POLICIES_FOR_BACKEND_SERVER_TEMPLATE + ) return template.render() def describe_instance_health(self): - load_balancer_name = self._get_param('LoadBalancerName') + load_balancer_name = self._get_param("LoadBalancerName") provided_instance_ids = [ list(param.values())[0] - for param in self._get_list_prefix('Instances.member') + for param in self._get_list_prefix("Instances.member") ] registered_instances_id = self.elb_backend.get_load_balancer( - load_balancer_name).instance_ids + load_balancer_name + ).instance_ids if len(provided_instance_ids) == 0: provided_instance_ids = registered_instances_id template = self.response_template(DESCRIBE_INSTANCE_HEALTH_TEMPLATE) instances = [] for instance_id in provided_instance_ids: - state = "InService" \ - if instance_id in registered_instances_id\ - else "Unknown" + state = "InService" if instance_id in registered_instances_id else "Unknown" instances.append({"InstanceId": instance_id, "State": state}) return template.render(instances=instances) @@ -293,17 +312,18 @@ class ELBResponse(BaseResponse): def remove_tags(self): for key, value in self.querystring.items(): if "LoadBalancerNames.member" in key: - number = key.split('.')[2] + number = key.split(".")[2] load_balancer_name = self._get_param( - 'LoadBalancerNames.member.{0}'.format(number)) + "LoadBalancerNames.member.{0}".format(number) + ) elb = self.elb_backend.get_load_balancer(load_balancer_name) if not elb: raise LoadBalancerNotFoundError(load_balancer_name) - key = 'Tag.member.{0}.Key'.format(number) + key = "Tag.member.{0}.Key".format(number) for t_key, t_val in self.querystring.items(): - if t_key.startswith('Tags.member.'): - if t_key.split('.')[3] == 'Key': + if t_key.startswith("Tags.member."): + if t_key.split(".")[3] == "Key": elb.remove_tag(t_val[0]) template = self.response_template(REMOVE_TAGS_TEMPLATE) @@ -313,9 +333,10 @@ class ELBResponse(BaseResponse): elbs = [] for key, value in self.querystring.items(): if "LoadBalancerNames.member" in key: - number = key.split('.')[2] + number = key.split(".")[2] load_balancer_name = self._get_param( - 'LoadBalancerNames.member.{0}'.format(number)) + "LoadBalancerNames.member.{0}".format(number) + ) elb = self.elb_backend.get_load_balancer(load_balancer_name) if not elb: raise LoadBalancerNotFoundError(load_balancer_name) @@ -329,10 +350,10 @@ class ELBResponse(BaseResponse): tag_keys = [] for t_key, t_val in sorted(self.querystring.items()): - if t_key.startswith('Tags.member.'): - if t_key.split('.')[3] == 'Key': + if t_key.startswith("Tags.member."): + if t_key.split(".")[3] == "Key": tag_keys.extend(t_val) - elif t_key.split('.')[3] == 'Value': + elif t_key.split(".")[3] == "Value": tag_values.extend(t_val) counts = {} diff --git a/moto/elb/urls.py b/moto/elb/urls.py index 3d96e1892..bb7f1c7bf 100644 --- a/moto/elb/urls.py +++ b/moto/elb/urls.py @@ -16,29 +16,25 @@ def api_version_elb_backend(*args, **kwargs): """ request = args[0] - if hasattr(request, 'values'): + if hasattr(request, "values"): # boto3 - version = request.values.get('Version') + version = request.values.get("Version") elif isinstance(request, AWSPreparedRequest): # boto in-memory - version = parse_qs(request.body).get('Version')[0] + version = parse_qs(request.body).get("Version")[0] else: # boto in server mode request.parse_request() - version = request.querystring.get('Version')[0] + version = request.querystring.get("Version")[0] - if '2012-06-01' == version: + if "2012-06-01" == version: return ELBResponse.dispatch(*args, **kwargs) - elif '2015-12-01' == version: + elif "2015-12-01" == version: return ELBV2Response.dispatch(*args, **kwargs) else: raise Exception("Unknown ELB API version: {}".format(version)) -url_bases = [ - "https?://elasticloadbalancing.(.+).amazonaws.com", -] +url_bases = ["https?://elasticloadbalancing.(.+).amazonaws.com"] -url_paths = { - '{0}/$': api_version_elb_backend, -} +url_paths = {"{0}/$": api_version_elb_backend} diff --git a/moto/elbv2/__init__.py b/moto/elbv2/__init__.py index 21a6d06c6..61c4a37ff 100644 --- a/moto/elbv2/__init__.py +++ b/moto/elbv2/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import elbv2_backends from ..core.models import base_decorator -elb_backend = elbv2_backends['us-east-1'] +elb_backend = elbv2_backends["us-east-1"] mock_elbv2 = base_decorator(elbv2_backends) diff --git a/moto/elbv2/exceptions.py b/moto/elbv2/exceptions.py index ccbfd06dd..8ea509d0d 100644 --- a/moto/elbv2/exceptions.py +++ b/moto/elbv2/exceptions.py @@ -7,200 +7,174 @@ class ELBClientError(RESTError): class DuplicateTagKeysError(ELBClientError): - def __init__(self, cidr): super(DuplicateTagKeysError, self).__init__( - "DuplicateTagKeys", - "Tag key was specified more than once: {0}" - .format(cidr)) + "DuplicateTagKeys", "Tag key was specified more than once: {0}".format(cidr) + ) class LoadBalancerNotFoundError(ELBClientError): - def __init__(self): super(LoadBalancerNotFoundError, self).__init__( - "LoadBalancerNotFound", - "The specified load balancer does not exist.") + "LoadBalancerNotFound", "The specified load balancer does not exist." + ) class ListenerNotFoundError(ELBClientError): - def __init__(self): super(ListenerNotFoundError, self).__init__( - "ListenerNotFound", - "The specified listener does not exist.") + "ListenerNotFound", "The specified listener does not exist." + ) class SubnetNotFoundError(ELBClientError): - def __init__(self): super(SubnetNotFoundError, self).__init__( - "SubnetNotFound", - "The specified subnet does not exist.") + "SubnetNotFound", "The specified subnet does not exist." + ) class TargetGroupNotFoundError(ELBClientError): - def __init__(self): super(TargetGroupNotFoundError, self).__init__( - "TargetGroupNotFound", - "The specified target group does not exist.") + "TargetGroupNotFound", "The specified target group does not exist." + ) class TooManyTagsError(ELBClientError): - def __init__(self): super(TooManyTagsError, self).__init__( "TooManyTagsError", - "The quota for the number of tags that can be assigned to a load balancer has been reached") + "The quota for the number of tags that can be assigned to a load balancer has been reached", + ) class BadHealthCheckDefinition(ELBClientError): - def __init__(self): super(BadHealthCheckDefinition, self).__init__( "ValidationError", - "HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL") + "HealthCheck Target must begin with one of HTTP, TCP, HTTPS, SSL", + ) class DuplicateListenerError(ELBClientError): - def __init__(self): super(DuplicateListenerError, self).__init__( - "DuplicateListener", - "A listener with the specified port already exists.") + "DuplicateListener", "A listener with the specified port already exists." + ) class DuplicateLoadBalancerName(ELBClientError): - def __init__(self): super(DuplicateLoadBalancerName, self).__init__( "DuplicateLoadBalancerName", - "A load balancer with the specified name already exists.") + "A load balancer with the specified name already exists.", + ) class DuplicateTargetGroupName(ELBClientError): - def __init__(self): super(DuplicateTargetGroupName, self).__init__( "DuplicateTargetGroupName", - "A target group with the specified name already exists.") + "A target group with the specified name already exists.", + ) class InvalidTargetError(ELBClientError): - def __init__(self): super(InvalidTargetError, self).__init__( "InvalidTarget", - "The specified target does not exist or is not in the same VPC as the target group.") + "The specified target does not exist or is not in the same VPC as the target group.", + ) class EmptyListenersError(ELBClientError): - def __init__(self): super(EmptyListenersError, self).__init__( - "ValidationError", - "Listeners cannot be empty") + "ValidationError", "Listeners cannot be empty" + ) class PriorityInUseError(ELBClientError): - def __init__(self): super(PriorityInUseError, self).__init__( - "PriorityInUse", - "The specified priority is in use.") + "PriorityInUse", "The specified priority is in use." + ) class InvalidConditionFieldError(ELBClientError): - def __init__(self, invalid_name): super(InvalidConditionFieldError, self).__init__( "ValidationError", - "Condition field '%s' must be one of '[path-pattern, host-header]" % (invalid_name)) + "Condition field '%s' must be one of '[path-pattern, host-header]" + % (invalid_name), + ) class InvalidConditionValueError(ELBClientError): - def __init__(self, msg): - super(InvalidConditionValueError, self).__init__( - "ValidationError", msg) + super(InvalidConditionValueError, self).__init__("ValidationError", msg) class InvalidActionTypeError(ELBClientError): - def __init__(self, invalid_name, index): super(InvalidActionTypeError, self).__init__( "ValidationError", - "1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect, fixed-response]" % (invalid_name, index) + "1 validation error detected: Value '%s' at 'actions.%s.member.type' failed to satisfy constraint: Member must satisfy enum value set: [forward, redirect, fixed-response]" + % (invalid_name, index), ) class ActionTargetGroupNotFoundError(ELBClientError): - def __init__(self, arn): super(ActionTargetGroupNotFoundError, self).__init__( - "TargetGroupNotFound", - "Target group '%s' not found" % arn + "TargetGroupNotFound", "Target group '%s' not found" % arn ) class InvalidDescribeRulesRequest(ELBClientError): - def __init__(self, msg): - super(InvalidDescribeRulesRequest, self).__init__( - "ValidationError", msg - ) + super(InvalidDescribeRulesRequest, self).__init__("ValidationError", msg) class ResourceInUseError(ELBClientError): - def __init__(self, msg="A specified resource is in use"): - super(ResourceInUseError, self).__init__( - "ResourceInUse", msg) + super(ResourceInUseError, self).__init__("ResourceInUse", msg) class RuleNotFoundError(ELBClientError): - def __init__(self): super(RuleNotFoundError, self).__init__( - "RuleNotFound", - "The specified rule does not exist.") - - -class DuplicatePriorityError(ELBClientError): - - def __init__(self, invalid_value): - super(DuplicatePriorityError, self).__init__( - "ValidationError", - "Priority '%s' was provided multiple times" % invalid_value) - - -class InvalidTargetGroupNameError(ELBClientError): - - def __init__(self, msg): - super(InvalidTargetGroupNameError, self).__init__( - "ValidationError", msg + "RuleNotFound", "The specified rule does not exist." ) -class InvalidModifyRuleArgumentsError(ELBClientError): +class DuplicatePriorityError(ELBClientError): + def __init__(self, invalid_value): + super(DuplicatePriorityError, self).__init__( + "ValidationError", + "Priority '%s' was provided multiple times" % invalid_value, + ) + +class InvalidTargetGroupNameError(ELBClientError): + def __init__(self, msg): + super(InvalidTargetGroupNameError, self).__init__("ValidationError", msg) + + +class InvalidModifyRuleArgumentsError(ELBClientError): def __init__(self): super(InvalidModifyRuleArgumentsError, self).__init__( - "ValidationError", - "Either conditions or actions must be specified" + "ValidationError", "Either conditions or actions must be specified" ) class InvalidStatusCodeActionTypeError(ELBClientError): def __init__(self, msg): - super(InvalidStatusCodeActionTypeError, self).__init__( - "ValidationError", msg - ) + super(InvalidStatusCodeActionTypeError, self).__init__("ValidationError", msg) class InvalidLoadBalancerActionException(ELBClientError): - def __init__(self, msg): super(InvalidLoadBalancerActionException, self).__init__( "InvalidLoadBalancerAction", msg diff --git a/moto/elbv2/models.py b/moto/elbv2/models.py index 636cc56a1..fdce9a8c2 100644 --- a/moto/elbv2/models.py +++ b/moto/elbv2/models.py @@ -33,12 +33,15 @@ from .exceptions import ( DuplicatePriorityError, InvalidTargetGroupNameError, InvalidModifyRuleArgumentsError, - InvalidStatusCodeActionTypeError, InvalidLoadBalancerActionException) + InvalidStatusCodeActionTypeError, + InvalidLoadBalancerActionException, +) class FakeHealthStatus(BaseModel): - - def __init__(self, instance_id, port, health_port, status, reason=None, description=None): + def __init__( + self, instance_id, port, health_port, status, reason=None, description=None + ): self.instance_id = instance_id self.port = port self.health_port = health_port @@ -48,23 +51,25 @@ class FakeHealthStatus(BaseModel): class FakeTargetGroup(BaseModel): - HTTP_CODE_REGEX = re.compile(r'(?:(?:\d+-\d+|\d+),?)+') + HTTP_CODE_REGEX = re.compile(r"(?:(?:\d+-\d+|\d+),?)+") - def __init__(self, - name, - arn, - vpc_id, - protocol, - port, - healthcheck_protocol=None, - healthcheck_port=None, - healthcheck_path=None, - healthcheck_interval_seconds=None, - healthcheck_timeout_seconds=None, - healthy_threshold_count=None, - unhealthy_threshold_count=None, - matcher=None, - target_type=None): + def __init__( + self, + name, + arn, + vpc_id, + protocol, + port, + healthcheck_protocol=None, + healthcheck_port=None, + healthcheck_path=None, + healthcheck_interval_seconds=None, + healthcheck_timeout_seconds=None, + healthy_threshold_count=None, + unhealthy_threshold_count=None, + matcher=None, + target_type=None, + ): # TODO: default values differs when you add Network Load balancer self.name = name @@ -72,9 +77,9 @@ class FakeTargetGroup(BaseModel): self.vpc_id = vpc_id self.protocol = protocol self.port = port - self.healthcheck_protocol = healthcheck_protocol or 'HTTP' + self.healthcheck_protocol = healthcheck_protocol or "HTTP" self.healthcheck_port = healthcheck_port or str(self.port) - self.healthcheck_path = healthcheck_path or '/' + self.healthcheck_path = healthcheck_path or "/" self.healthcheck_interval_seconds = healthcheck_interval_seconds or 30 self.healthcheck_timeout_seconds = healthcheck_timeout_seconds or 5 self.healthy_threshold_count = healthy_threshold_count or 5 @@ -82,14 +87,14 @@ class FakeTargetGroup(BaseModel): self.load_balancer_arns = [] self.tags = {} if matcher is None: - self.matcher = {'HttpCode': '200'} + self.matcher = {"HttpCode": "200"} else: self.matcher = matcher self.target_type = target_type self.attributes = { - 'deregistration_delay.timeout_seconds': 300, - 'stickiness.enabled': 'false', + "deregistration_delay.timeout_seconds": 300, + "stickiness.enabled": "false", } self.targets = OrderedDict() @@ -100,14 +105,14 @@ class FakeTargetGroup(BaseModel): def register(self, targets): for target in targets: - self.targets[target['id']] = { - 'id': target['id'], - 'port': target.get('port', self.port), + self.targets[target["id"]] = { + "id": target["id"], + "port": target.get("port", self.port), } def deregister(self, targets): for target in targets: - t = self.targets.pop(target['id'], None) + t = self.targets.pop(target["id"], None) if not t: raise InvalidTargetError() @@ -122,24 +127,33 @@ class FakeTargetGroup(BaseModel): self.tags[key] = value def health_for(self, target, ec2_backend): - t = self.targets.get(target['id']) + t = self.targets.get(target["id"]) if t is None: raise InvalidTargetError() - if t['id'].startswith("i-"): # EC2 instance ID - instance = ec2_backend.get_instance_by_id(t['id']) + if t["id"].startswith("i-"): # EC2 instance ID + instance = ec2_backend.get_instance_by_id(t["id"]) if instance.state == "stopped": - return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'unused', 'Target.InvalidState', 'Target is in the stopped state') - return FakeHealthStatus(t['id'], t['port'], self.healthcheck_port, 'healthy') + return FakeHealthStatus( + t["id"], + t["port"], + self.healthcheck_port, + "unused", + "Target.InvalidState", + "Target is in the stopped state", + ) + return FakeHealthStatus(t["id"], t["port"], self.healthcheck_port, "healthy") @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] elbv2_backend = elbv2_backends[region_name] - name = properties.get('Name') + name = properties.get("Name") vpc_id = properties.get("VpcId") - protocol = properties.get('Protocol') + protocol = properties.get("Protocol") port = properties.get("Port") healthcheck_protocol = properties.get("HealthCheckProtocol") healthcheck_port = properties.get("HealthCheckPort") @@ -170,8 +184,16 @@ class FakeTargetGroup(BaseModel): class FakeListener(BaseModel): - - def __init__(self, load_balancer_arn, arn, protocol, port, ssl_policy, certificate, default_actions): + def __init__( + self, + load_balancer_arn, + arn, + protocol, + port, + ssl_policy, + certificate, + default_actions, + ): self.load_balancer_arn = load_balancer_arn self.arn = arn self.protocol = protocol.upper() @@ -184,9 +206,9 @@ class FakeListener(BaseModel): self._default_rule = FakeRule( listener_arn=self.arn, conditions=[], - priority='default', + priority="default", actions=default_actions, - is_default=True + is_default=True, ) @property @@ -202,11 +224,15 @@ class FakeListener(BaseModel): def register(self, rule): self._non_default_rules.append(rule) - self._non_default_rules = sorted(self._non_default_rules, key=lambda x: x.priority) + self._non_default_rules = sorted( + self._non_default_rules, key=lambda x: x.priority + ) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] elbv2_backend = elbv2_backends[region_name] load_balancer_arn = properties.get("LoadBalancerArn") @@ -217,16 +243,36 @@ class FakeListener(BaseModel): # transform default actions to confirm with the rest of the code and XML templates if "DefaultActions" in properties: default_actions = [] - for i, action in enumerate(properties['DefaultActions']): - action_type = action['Type'] - if action_type == 'forward': - default_actions.append({'type': action_type, 'target_group_arn': action['TargetGroupArn']}) - elif action_type in ['redirect', 'authenticate-cognito', 'fixed-response']: - redirect_action = {'type': action_type} - key = underscores_to_camelcase(action_type.capitalize().replace('-', '_')) + 'Config' - for redirect_config_key, redirect_config_value in action[key].items(): + for i, action in enumerate(properties["DefaultActions"]): + action_type = action["Type"] + if action_type == "forward": + default_actions.append( + { + "type": action_type, + "target_group_arn": action["TargetGroupArn"], + } + ) + elif action_type in [ + "redirect", + "authenticate-cognito", + "fixed-response", + ]: + redirect_action = {"type": action_type} + key = ( + underscores_to_camelcase( + action_type.capitalize().replace("-", "_") + ) + + "Config" + ) + for redirect_config_key, redirect_config_value in action[ + key + ].items(): # need to match the output of _get_list_prefix - redirect_action[camelcase_to_underscores(key) + '._' + camelcase_to_underscores(redirect_config_key)] = redirect_config_value + redirect_action[ + camelcase_to_underscores(key) + + "._" + + camelcase_to_underscores(redirect_config_key) + ] = redirect_config_value default_actions.append(redirect_action) else: raise InvalidActionTypeError(action_type, i + 1) @@ -234,7 +280,8 @@ class FakeListener(BaseModel): default_actions = None listener = elbv2_backend.create_listener( - load_balancer_arn, protocol, port, ssl_policy, certificates, default_actions) + load_balancer_arn, protocol, port, ssl_policy, certificates, default_actions + ) return listener @@ -244,7 +291,8 @@ class FakeAction(BaseModel): self.type = data.get("type") def to_xml(self): - template = Template("""{{ action.type }} + template = Template( + """{{ action.type }} {% if action.type == "forward" %} {{ action.data["target_group_arn"] }} {% elif action.type == "redirect" %} @@ -266,15 +314,17 @@ class FakeAction(BaseModel): {{ action.data["fixed_response_config._status_code"] }} {% endif %} - """) + """ + ) return template.render(action=self) class FakeRule(BaseModel): - def __init__(self, listener_arn, conditions, priority, actions, is_default): self.listener_arn = listener_arn - self.arn = listener_arn.replace(':listener/', ':listener-rule/') + "/%s" % (id(self)) + self.arn = listener_arn.replace(":listener/", ":listener-rule/") + "/%s" % ( + id(self) + ) self.conditions = conditions self.priority = priority # int or 'default' self.actions = actions @@ -282,20 +332,36 @@ class FakeRule(BaseModel): class FakeBackend(BaseModel): - def __init__(self, instance_port): self.instance_port = instance_port self.policy_names = [] def __repr__(self): - return "FakeBackend(inp: %s, policies: %s)" % (self.instance_port, self.policy_names) + return "FakeBackend(inp: %s, policies: %s)" % ( + self.instance_port, + self.policy_names, + ) class FakeLoadBalancer(BaseModel): - VALID_ATTRS = {'access_logs.s3.enabled', 'access_logs.s3.bucket', 'access_logs.s3.prefix', - 'deletion_protection.enabled', 'idle_timeout.timeout_seconds'} + VALID_ATTRS = { + "access_logs.s3.enabled", + "access_logs.s3.bucket", + "access_logs.s3.prefix", + "deletion_protection.enabled", + "idle_timeout.timeout_seconds", + } - def __init__(self, name, security_groups, subnets, vpc_id, arn, dns_name, scheme='internet-facing'): + def __init__( + self, + name, + security_groups, + subnets, + vpc_id, + arn, + dns_name, + scheme="internet-facing", + ): self.name = name self.created_time = datetime.datetime.now() self.scheme = scheme @@ -307,13 +373,13 @@ class FakeLoadBalancer(BaseModel): self.arn = arn self.dns_name = dns_name - self.stack = 'ipv4' + self.stack = "ipv4" self.attrs = { - 'access_logs.s3.enabled': 'false', - 'access_logs.s3.bucket': None, - 'access_logs.s3.prefix': None, - 'deletion_protection.enabled': 'false', - 'idle_timeout.timeout_seconds': '60' + "access_logs.s3.enabled": "false", + "access_logs.s3.bucket": None, + "access_logs.s3.prefix": None, + "deletion_protection.enabled": "false", + "idle_timeout.timeout_seconds": "60", } @property @@ -333,25 +399,29 @@ class FakeLoadBalancer(BaseModel): del self.tags[key] def delete(self, region): - ''' Not exposed as part of the ELB API - used for CloudFormation. ''' + """ Not exposed as part of the ELB API - used for CloudFormation. """ elbv2_backends[region].delete_load_balancer(self.arn) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] elbv2_backend = elbv2_backends[region_name] - name = properties.get('Name', resource_name) + name = properties.get("Name", resource_name) security_groups = properties.get("SecurityGroups") - subnet_ids = properties.get('Subnets') - scheme = properties.get('Scheme', 'internet-facing') + subnet_ids = properties.get("Subnets") + scheme = properties.get("Scheme", "internet-facing") - load_balancer = elbv2_backend.create_load_balancer(name, security_groups, subnet_ids, scheme=scheme) + load_balancer = elbv2_backend.create_load_balancer( + name, security_groups, subnet_ids, scheme=scheme + ) return load_balancer def get_cfn_attribute(self, attribute_name): - ''' + """ Implemented attributes: * DNSName * LoadBalancerName @@ -362,25 +432,27 @@ class FakeLoadBalancer(BaseModel): * SecurityGroups This method is similar to models.py:FakeLoadBalancer.get_cfn_attribute() - ''' + """ from moto.cloudformation.exceptions import UnformattedGetAttTemplateException + not_implemented_yet = [ - 'CanonicalHostedZoneID', - 'LoadBalancerFullName', - 'SecurityGroups', + "CanonicalHostedZoneID", + "LoadBalancerFullName", + "SecurityGroups", ] - if attribute_name == 'DNSName': + if attribute_name == "DNSName": return self.dns_name - elif attribute_name == 'LoadBalancerName': + elif attribute_name == "LoadBalancerName": return self.name elif attribute_name in not_implemented_yet: - raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "%s" ]"' % attribute_name) + raise NotImplementedError( + '"Fn::GetAtt" : [ "{0}" , "%s" ]"' % attribute_name + ) else: raise UnformattedGetAttTemplateException() class ELBv2Backend(BaseBackend): - def __init__(self, region_name=None): self.region_name = region_name self.target_groups = OrderedDict() @@ -411,7 +483,9 @@ class ELBv2Backend(BaseBackend): self.__dict__ = {} self.__init__(region_name) - def create_load_balancer(self, name, security_groups, subnet_ids, scheme='internet-facing'): + def create_load_balancer( + self, name, security_groups, subnet_ids, scheme="internet-facing" + ): vpc_id = None subnets = [] if not subnet_ids: @@ -423,7 +497,9 @@ class ELBv2Backend(BaseBackend): subnets.append(subnet) vpc_id = subnets[0].vpc_id - arn = make_arn_for_load_balancer(account_id=1, name=name, region_name=self.region_name) + arn = make_arn_for_load_balancer( + account_id=1, name=name, region_name=self.region_name + ) dns_name = "%s-1.%s.elb.amazonaws.com" % (name, self.region_name) if arn in self.load_balancers: @@ -436,7 +512,8 @@ class ELBv2Backend(BaseBackend): scheme=scheme, subnets=subnets, vpc_id=vpc_id, - dns_name=dns_name) + dns_name=dns_name, + ) self.load_balancers[arn] = new_load_balancer return new_load_balancer @@ -449,13 +526,13 @@ class ELBv2Backend(BaseBackend): # validate conditions for condition in conditions: - field = condition['field'] - if field not in ['path-pattern', 'host-header']: + field = condition["field"] + if field not in ["path-pattern", "host-header"]: raise InvalidConditionFieldError(field) - values = condition['values'] + values = condition["values"] if len(values) == 0: - raise InvalidConditionValueError('A condition value must be specified') + raise InvalidConditionValueError("A condition value must be specified") if len(values) > 1: raise InvalidConditionValueError( "The '%s' field contains too many values; the limit is '1'" % field @@ -481,34 +558,44 @@ class ELBv2Backend(BaseBackend): def _validate_actions(self, actions): # validate Actions - target_group_arns = [target_group.arn for target_group in self.target_groups.values()] + target_group_arns = [ + target_group.arn for target_group in self.target_groups.values() + ] for i, action in enumerate(actions): index = i + 1 action_type = action.type - if action_type == 'forward': - action_target_group_arn = action.data['target_group_arn'] + if action_type == "forward": + action_target_group_arn = action.data["target_group_arn"] if action_target_group_arn not in target_group_arns: raise ActionTargetGroupNotFoundError(action_target_group_arn) - elif action_type == 'fixed-response': + elif action_type == "fixed-response": self._validate_fixed_response_action(action, i, index) - elif action_type in ['redirect', 'authenticate-cognito']: + elif action_type in ["redirect", "authenticate-cognito"]: pass else: raise InvalidActionTypeError(action_type, index) def _validate_fixed_response_action(self, action, i, index): - status_code = action.data.get('fixed_response_config._status_code') + status_code = action.data.get("fixed_response_config._status_code") if status_code is None: raise ParamValidationError( - report='Missing required parameter in Actions[%s].FixedResponseConfig: "StatusCode"' % i) - if not re.match(r'^(2|4|5)\d\d$', status_code): + report='Missing required parameter in Actions[%s].FixedResponseConfig: "StatusCode"' + % i + ) + if not re.match(r"^(2|4|5)\d\d$", status_code): raise InvalidStatusCodeActionTypeError( "1 validation error detected: Value '%s' at 'actions.%s.member.fixedResponseConfig.statusCode' failed to satisfy constraint: \ -Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, index) +Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" + % (status_code, index) ) - content_type = action.data['fixed_response_config._content_type'] - if content_type and content_type not in ['text/plain', 'text/css', 'text/html', 'application/javascript', - 'application/json']: + content_type = action.data["fixed_response_config._content_type"] + if content_type and content_type not in [ + "text/plain", + "text/css", + "text/html", + "application/javascript", + "application/json", + ]: raise InvalidLoadBalancerActionException( "The ContentType must be one of:'text/html', 'application/json', 'application/javascript', 'text/css', 'text/plain'" ) @@ -518,18 +605,20 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i raise InvalidTargetGroupNameError( "Target group name '%s' cannot be longer than '32' characters" % name ) - if not re.match('^[a-zA-Z0-9\-]+$', name): + if not re.match("^[a-zA-Z0-9\-]+$", name): raise InvalidTargetGroupNameError( - "Target group name '%s' can only contain characters that are alphanumeric characters or hyphens(-)" % name + "Target group name '%s' can only contain characters that are alphanumeric characters or hyphens(-)" + % name ) # undocumented validation - if not re.match('(?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$', name): + if not re.match("(?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$", name): raise InvalidTargetGroupNameError( - "1 validation error detected: Value '%s' at 'targetGroup.targetGroupArn.targetGroupName' failed to satisfy constraint: Member must satisfy regular expression pattern: (?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$" % name + "1 validation error detected: Value '%s' at 'targetGroup.targetGroupArn.targetGroupName' failed to satisfy constraint: Member must satisfy regular expression pattern: (?!.*--)(?!^-)(?!.*-$)^[A-Za-z0-9-]+$" + % name ) - if name.startswith('-') or name.endswith('-'): + if name.startswith("-") or name.endswith("-"): raise InvalidTargetGroupNameError( "Target group name '%s' cannot begin or end with '-'" % name ) @@ -537,25 +626,51 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i if target_group.name == name: raise DuplicateTargetGroupName() - valid_protocols = ['HTTPS', 'HTTP', 'TCP'] - if kwargs.get('healthcheck_protocol') and kwargs['healthcheck_protocol'] not in valid_protocols: + valid_protocols = ["HTTPS", "HTTP", "TCP"] + if ( + kwargs.get("healthcheck_protocol") + and kwargs["healthcheck_protocol"] not in valid_protocols + ): raise InvalidConditionValueError( "Value {} at 'healthCheckProtocol' failed to satisfy constraint: " - "Member must satisfy enum value set: {}".format(kwargs['healthcheck_protocol'], valid_protocols)) - if kwargs.get('protocol') and kwargs['protocol'] not in valid_protocols: + "Member must satisfy enum value set: {}".format( + kwargs["healthcheck_protocol"], valid_protocols + ) + ) + if kwargs.get("protocol") and kwargs["protocol"] not in valid_protocols: raise InvalidConditionValueError( "Value {} at 'protocol' failed to satisfy constraint: " - "Member must satisfy enum value set: {}".format(kwargs['protocol'], valid_protocols)) + "Member must satisfy enum value set: {}".format( + kwargs["protocol"], valid_protocols + ) + ) - if kwargs.get('matcher') and FakeTargetGroup.HTTP_CODE_REGEX.match(kwargs['matcher']['HttpCode']) is None: - raise RESTError('InvalidParameterValue', 'HttpCode must be like 200 | 200-399 | 200,201 ...') + if ( + kwargs.get("matcher") + and FakeTargetGroup.HTTP_CODE_REGEX.match(kwargs["matcher"]["HttpCode"]) + is None + ): + raise RESTError( + "InvalidParameterValue", + "HttpCode must be like 200 | 200-399 | 200,201 ...", + ) - arn = make_arn_for_target_group(account_id=1, name=name, region_name=self.region_name) + arn = make_arn_for_target_group( + account_id=1, name=name, region_name=self.region_name + ) target_group = FakeTargetGroup(name, arn, **kwargs) self.target_groups[target_group.arn] = target_group return target_group - def create_listener(self, load_balancer_arn, protocol, port, ssl_policy, certificate, default_actions): + def create_listener( + self, + load_balancer_arn, + protocol, + port, + ssl_policy, + certificate, + default_actions, + ): default_actions = [FakeAction(action) for action in default_actions] balancer = self.load_balancers.get(load_balancer_arn) if balancer is None: @@ -565,12 +680,23 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i self._validate_actions(default_actions) - arn = load_balancer_arn.replace(':loadbalancer/', ':listener/') + "/%s%s" % (port, id(self)) - listener = FakeListener(load_balancer_arn, arn, protocol, port, ssl_policy, certificate, default_actions) + arn = load_balancer_arn.replace(":loadbalancer/", ":listener/") + "/%s%s" % ( + port, + id(self), + ) + listener = FakeListener( + load_balancer_arn, + arn, + protocol, + port, + ssl_policy, + certificate, + default_actions, + ) balancer.listeners[listener.arn] = listener for action in default_actions: - if action.type == 'forward': - target_group = self.target_groups[action.data['target_group_arn']] + if action.type == "forward": + target_group = self.target_groups[action.data["target_group_arn"]] target_group.load_balancer_arns.append(load_balancer_arn) return listener @@ -612,7 +738,7 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i ) if listener_arn is not None and rule_arns is not None: raise InvalidDescribeRulesRequest( - 'Listener rule ARNs and a listener ARN cannot be specified at the same time' + "Listener rule ARNs and a listener ARN cannot be specified at the same time" ) if listener_arn: listener = self.describe_listeners(None, [listener_arn])[0] @@ -632,8 +758,11 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i if load_balancer_arn: if load_balancer_arn not in self.load_balancers: raise LoadBalancerNotFoundError() - return [tg for tg in self.target_groups.values() - if load_balancer_arn in tg.load_balancer_arns] + return [ + tg + for tg in self.target_groups.values() + if load_balancer_arn in tg.load_balancer_arns + ] if target_group_arns: try: @@ -693,7 +822,9 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i if self._any_listener_using(target_group_arn): raise ResourceInUseError( "The target group '{}' is currently in use by a listener or a rule".format( - target_group_arn)) + target_group_arn + ) + ) del self.target_groups[target_group_arn] return target_group @@ -716,16 +847,19 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i if conditions: for condition in conditions: - field = condition['field'] - if field not in ['path-pattern', 'host-header']: + field = condition["field"] + if field not in ["path-pattern", "host-header"]: raise InvalidConditionFieldError(field) - values = condition['values'] + values = condition["values"] if len(values) == 0: - raise InvalidConditionValueError('A condition value must be specified') + raise InvalidConditionValueError( + "A condition value must be specified" + ) if len(values) > 1: raise InvalidConditionValueError( - "The '%s' field contains too many values; the limit is '1'" % field + "The '%s' field contains too many values; the limit is '1'" + % field ) # TODO: check pattern of value for 'host-header' # TODO: check pattern of value for 'path-pattern' @@ -766,16 +900,18 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i def set_rule_priorities(self, rule_priorities): # validate - priorities = [rule_priority['priority'] for rule_priority in rule_priorities] + priorities = [rule_priority["priority"] for rule_priority in rule_priorities] for priority in set(priorities): if priorities.count(priority) > 1: raise DuplicatePriorityError(priority) # validate for rule_priority in rule_priorities: - given_rule_arn = rule_priority['rule_arn'] - priority = rule_priority['priority'] - _given_rules = self.describe_rules(listener_arn=None, rule_arns=[given_rule_arn]) + given_rule_arn = rule_priority["rule_arn"] + priority = rule_priority["priority"] + _given_rules = self.describe_rules( + listener_arn=None, rule_arns=[given_rule_arn] + ) if not _given_rules: raise RuleNotFoundError() given_rule = _given_rules[0] @@ -787,9 +923,11 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i # modify modified_rules = [] for rule_priority in rule_priorities: - given_rule_arn = rule_priority['rule_arn'] - priority = rule_priority['priority'] - _given_rules = self.describe_rules(listener_arn=None, rule_arns=[given_rule_arn]) + given_rule_arn = rule_priority["rule_arn"] + priority = rule_priority["priority"] + _given_rules = self.describe_rules( + listener_arn=None, rule_arns=[given_rule_arn] + ) if not _given_rules: raise RuleNotFoundError() given_rule = _given_rules[0] @@ -798,15 +936,21 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i return modified_rules def set_ip_address_type(self, arn, ip_type): - if ip_type not in ('internal', 'dualstack'): - raise RESTError('InvalidParameterValue', 'IpAddressType must be either internal | dualstack') + if ip_type not in ("internal", "dualstack"): + raise RESTError( + "InvalidParameterValue", + "IpAddressType must be either internal | dualstack", + ) balancer = self.load_balancers.get(arn) if balancer is None: raise LoadBalancerNotFoundError() - if ip_type == 'dualstack' and balancer.scheme == 'internal': - raise RESTError('InvalidConfigurationRequest', 'Internal load balancers cannot be dualstack') + if ip_type == "dualstack" and balancer.scheme == "internal": + raise RESTError( + "InvalidConfigurationRequest", + "Internal load balancers cannot be dualstack", + ) balancer.stack = ip_type @@ -818,7 +962,10 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i # Check all security groups exist for sec_group_id in sec_groups: if self.ec2_backend.get_security_group_from_id(sec_group_id) is None: - raise RESTError('InvalidSecurityGroup', 'Security group {0} does not exist'.format(sec_group_id)) + raise RESTError( + "InvalidSecurityGroup", + "Security group {0} does not exist".format(sec_group_id), + ) balancer.security_groups = sec_groups @@ -834,7 +981,10 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i subnet = self.ec2_backend.get_subnet(subnet) if subnet.availability_zone in sub_zone_list: - raise RESTError('InvalidConfigurationRequest', 'More than 1 subnet cannot be specified for 1 availability zone') + raise RESTError( + "InvalidConfigurationRequest", + "More than 1 subnet cannot be specified for 1 availability zone", + ) sub_zone_list[subnet.availability_zone] = subnet.id subnet_objects.append(subnet) @@ -842,7 +992,10 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i raise SubnetNotFoundError() if len(sub_zone_list) < 2: - raise RESTError('InvalidConfigurationRequest', 'More than 1 availability zone must be specified') + raise RESTError( + "InvalidConfigurationRequest", + "More than 1 availability zone must be specified", + ) balancer.subnets = subnet_objects @@ -855,7 +1008,9 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i for key in attrs: if key not in FakeLoadBalancer.VALID_ATTRS: - raise RESTError('InvalidConfigurationRequest', 'Key {0} not valid'.format(key)) + raise RESTError( + "InvalidConfigurationRequest", "Key {0} not valid".format(key) + ) balancer.attrs.update(attrs) return balancer.attrs @@ -867,17 +1022,33 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i return balancer.attrs - def modify_target_group(self, arn, health_check_proto=None, health_check_port=None, health_check_path=None, health_check_interval=None, - health_check_timeout=None, healthy_threshold_count=None, unhealthy_threshold_count=None, http_codes=None): + def modify_target_group( + self, + arn, + health_check_proto=None, + health_check_port=None, + health_check_path=None, + health_check_interval=None, + health_check_timeout=None, + healthy_threshold_count=None, + unhealthy_threshold_count=None, + http_codes=None, + ): target_group = self.target_groups.get(arn) if target_group is None: raise TargetGroupNotFoundError() - if http_codes is not None and FakeTargetGroup.HTTP_CODE_REGEX.match(http_codes) is None: - raise RESTError('InvalidParameterValue', 'HttpCode must be like 200 | 200-399 | 200,201 ...') + if ( + http_codes is not None + and FakeTargetGroup.HTTP_CODE_REGEX.match(http_codes) is None + ): + raise RESTError( + "InvalidParameterValue", + "HttpCode must be like 200 | 200-399 | 200,201 ...", + ) if http_codes is not None: - target_group.matcher['HttpCode'] = http_codes + target_group.matcher["HttpCode"] = http_codes if health_check_interval is not None: target_group.healthcheck_interval_seconds = health_check_interval if health_check_path is not None: @@ -895,7 +1066,15 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i return target_group - def modify_listener(self, arn, port=None, protocol=None, ssl_policy=None, certificates=None, default_actions=None): + def modify_listener( + self, + arn, + port=None, + protocol=None, + ssl_policy=None, + certificates=None, + default_actions=None, + ): default_actions = [FakeAction(action) for action in default_actions] for load_balancer in self.load_balancers.values(): if arn in load_balancer.listeners: @@ -915,33 +1094,46 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i listener.port = port if protocol is not None: - if protocol not in ('HTTP', 'HTTPS', 'TCP'): - raise RESTError('UnsupportedProtocol', 'Protocol {0} is not supported'.format(protocol)) + if protocol not in ("HTTP", "HTTPS", "TCP"): + raise RESTError( + "UnsupportedProtocol", + "Protocol {0} is not supported".format(protocol), + ) # HTTPS checks - if protocol == 'HTTPS': + if protocol == "HTTPS": # HTTPS # Might already be HTTPS so may not provide certs - if certificates is None and listener.protocol != 'HTTPS': - raise RESTError('InvalidConfigurationRequest', 'Certificates must be provided for HTTPS') + if certificates is None and listener.protocol != "HTTPS": + raise RESTError( + "InvalidConfigurationRequest", + "Certificates must be provided for HTTPS", + ) # Check certificates exist if certificates is not None: default_cert = None all_certs = set() # for SNI for cert in certificates: - if cert['is_default'] == 'true': - default_cert = cert['certificate_arn'] + if cert["is_default"] == "true": + default_cert = cert["certificate_arn"] try: - self.acm_backend.get_certificate(cert['certificate_arn']) + self.acm_backend.get_certificate(cert["certificate_arn"]) except Exception: - raise RESTError('CertificateNotFound', 'Certificate {0} not found'.format(cert['certificate_arn'])) + raise RESTError( + "CertificateNotFound", + "Certificate {0} not found".format( + cert["certificate_arn"] + ), + ) - all_certs.add(cert['certificate_arn']) + all_certs.add(cert["certificate_arn"]) if default_cert is None: - raise RESTError('InvalidConfigurationRequest', 'No default certificate') + raise RESTError( + "InvalidConfigurationRequest", "No default certificate" + ) listener.certificate = default_cert listener.certificates = list(all_certs) @@ -963,7 +1155,7 @@ Member must satisfy regular expression pattern: ^(2|4|5)\d\d$" % (status_code, i for listener in load_balancer.listeners.values(): for rule in listener.rules: for action in rule.actions: - if action.data.get('target_group_arn') == target_group_arn: + if action.data.get("target_group_arn") == target_group_arn: return True return False diff --git a/moto/elbv2/responses.py b/moto/elbv2/responses.py index 25c23bb17..922de96d4 100644 --- a/moto/elbv2/responses.py +++ b/moto/elbv2/responses.py @@ -10,120 +10,120 @@ from .exceptions import TargetGroupNotFoundError SSL_POLICIES = [ { - 'name': 'ELBSecurityPolicy-2016-08', - 'ssl_protocols': ['TLSv1', 'TLSv1.1', 'TLSv1.2'], - 'ciphers': [ - {'name': 'ECDHE-ECDSA-AES128-GCM-SHA256', 'priority': 1}, - {'name': 'ECDHE-RSA-AES128-GCM-SHA256', 'priority': 2}, - {'name': 'ECDHE-ECDSA-AES128-SHA256', 'priority': 3}, - {'name': 'ECDHE-RSA-AES128-SHA256', 'priority': 4}, - {'name': 'ECDHE-ECDSA-AES128-SHA', 'priority': 5}, - {'name': 'ECDHE-RSA-AES128-SHA', 'priority': 6}, - {'name': 'ECDHE-ECDSA-AES256-GCM-SHA384', 'priority': 7}, - {'name': 'ECDHE-RSA-AES256-GCM-SHA384', 'priority': 8}, - {'name': 'ECDHE-ECDSA-AES256-SHA384', 'priority': 9}, - {'name': 'ECDHE-RSA-AES256-SHA384', 'priority': 10}, - {'name': 'ECDHE-RSA-AES256-SHA', 'priority': 11}, - {'name': 'ECDHE-ECDSA-AES256-SHA', 'priority': 12}, - {'name': 'AES128-GCM-SHA256', 'priority': 13}, - {'name': 'AES128-SHA256', 'priority': 14}, - {'name': 'AES128-SHA', 'priority': 15}, - {'name': 'AES256-GCM-SHA384', 'priority': 16}, - {'name': 'AES256-SHA256', 'priority': 17}, - {'name': 'AES256-SHA', 'priority': 18} + "name": "ELBSecurityPolicy-2016-08", + "ssl_protocols": ["TLSv1", "TLSv1.1", "TLSv1.2"], + "ciphers": [ + {"name": "ECDHE-ECDSA-AES128-GCM-SHA256", "priority": 1}, + {"name": "ECDHE-RSA-AES128-GCM-SHA256", "priority": 2}, + {"name": "ECDHE-ECDSA-AES128-SHA256", "priority": 3}, + {"name": "ECDHE-RSA-AES128-SHA256", "priority": 4}, + {"name": "ECDHE-ECDSA-AES128-SHA", "priority": 5}, + {"name": "ECDHE-RSA-AES128-SHA", "priority": 6}, + {"name": "ECDHE-ECDSA-AES256-GCM-SHA384", "priority": 7}, + {"name": "ECDHE-RSA-AES256-GCM-SHA384", "priority": 8}, + {"name": "ECDHE-ECDSA-AES256-SHA384", "priority": 9}, + {"name": "ECDHE-RSA-AES256-SHA384", "priority": 10}, + {"name": "ECDHE-RSA-AES256-SHA", "priority": 11}, + {"name": "ECDHE-ECDSA-AES256-SHA", "priority": 12}, + {"name": "AES128-GCM-SHA256", "priority": 13}, + {"name": "AES128-SHA256", "priority": 14}, + {"name": "AES128-SHA", "priority": 15}, + {"name": "AES256-GCM-SHA384", "priority": 16}, + {"name": "AES256-SHA256", "priority": 17}, + {"name": "AES256-SHA", "priority": 18}, ], }, { - 'name': 'ELBSecurityPolicy-TLS-1-2-2017-01', - 'ssl_protocols': ['TLSv1.2'], - 'ciphers': [ - {'name': 'ECDHE-ECDSA-AES128-GCM-SHA256', 'priority': 1}, - {'name': 'ECDHE-RSA-AES128-GCM-SHA256', 'priority': 2}, - {'name': 'ECDHE-ECDSA-AES128-SHA256', 'priority': 3}, - {'name': 'ECDHE-RSA-AES128-SHA256', 'priority': 4}, - {'name': 'ECDHE-ECDSA-AES256-GCM-SHA384', 'priority': 5}, - {'name': 'ECDHE-RSA-AES256-GCM-SHA384', 'priority': 6}, - {'name': 'ECDHE-ECDSA-AES256-SHA384', 'priority': 7}, - {'name': 'ECDHE-RSA-AES256-SHA384', 'priority': 8}, - {'name': 'AES128-GCM-SHA256', 'priority': 9}, - {'name': 'AES128-SHA256', 'priority': 10}, - {'name': 'AES256-GCM-SHA384', 'priority': 11}, - {'name': 'AES256-SHA256', 'priority': 12} - ] + "name": "ELBSecurityPolicy-TLS-1-2-2017-01", + "ssl_protocols": ["TLSv1.2"], + "ciphers": [ + {"name": "ECDHE-ECDSA-AES128-GCM-SHA256", "priority": 1}, + {"name": "ECDHE-RSA-AES128-GCM-SHA256", "priority": 2}, + {"name": "ECDHE-ECDSA-AES128-SHA256", "priority": 3}, + {"name": "ECDHE-RSA-AES128-SHA256", "priority": 4}, + {"name": "ECDHE-ECDSA-AES256-GCM-SHA384", "priority": 5}, + {"name": "ECDHE-RSA-AES256-GCM-SHA384", "priority": 6}, + {"name": "ECDHE-ECDSA-AES256-SHA384", "priority": 7}, + {"name": "ECDHE-RSA-AES256-SHA384", "priority": 8}, + {"name": "AES128-GCM-SHA256", "priority": 9}, + {"name": "AES128-SHA256", "priority": 10}, + {"name": "AES256-GCM-SHA384", "priority": 11}, + {"name": "AES256-SHA256", "priority": 12}, + ], }, { - 'name': 'ELBSecurityPolicy-TLS-1-1-2017-01', - 'ssl_protocols': ['TLSv1.1', 'TLSv1.2'], - 'ciphers': [ - {'name': 'ECDHE-ECDSA-AES128-GCM-SHA256', 'priority': 1}, - {'name': 'ECDHE-RSA-AES128-GCM-SHA256', 'priority': 2}, - {'name': 'ECDHE-ECDSA-AES128-SHA256', 'priority': 3}, - {'name': 'ECDHE-RSA-AES128-SHA256', 'priority': 4}, - {'name': 'ECDHE-ECDSA-AES128-SHA', 'priority': 5}, - {'name': 'ECDHE-RSA-AES128-SHA', 'priority': 6}, - {'name': 'ECDHE-ECDSA-AES256-GCM-SHA384', 'priority': 7}, - {'name': 'ECDHE-RSA-AES256-GCM-SHA384', 'priority': 8}, - {'name': 'ECDHE-ECDSA-AES256-SHA384', 'priority': 9}, - {'name': 'ECDHE-RSA-AES256-SHA384', 'priority': 10}, - {'name': 'ECDHE-RSA-AES256-SHA', 'priority': 11}, - {'name': 'ECDHE-ECDSA-AES256-SHA', 'priority': 12}, - {'name': 'AES128-GCM-SHA256', 'priority': 13}, - {'name': 'AES128-SHA256', 'priority': 14}, - {'name': 'AES128-SHA', 'priority': 15}, - {'name': 'AES256-GCM-SHA384', 'priority': 16}, - {'name': 'AES256-SHA256', 'priority': 17}, - {'name': 'AES256-SHA', 'priority': 18} - ] + "name": "ELBSecurityPolicy-TLS-1-1-2017-01", + "ssl_protocols": ["TLSv1.1", "TLSv1.2"], + "ciphers": [ + {"name": "ECDHE-ECDSA-AES128-GCM-SHA256", "priority": 1}, + {"name": "ECDHE-RSA-AES128-GCM-SHA256", "priority": 2}, + {"name": "ECDHE-ECDSA-AES128-SHA256", "priority": 3}, + {"name": "ECDHE-RSA-AES128-SHA256", "priority": 4}, + {"name": "ECDHE-ECDSA-AES128-SHA", "priority": 5}, + {"name": "ECDHE-RSA-AES128-SHA", "priority": 6}, + {"name": "ECDHE-ECDSA-AES256-GCM-SHA384", "priority": 7}, + {"name": "ECDHE-RSA-AES256-GCM-SHA384", "priority": 8}, + {"name": "ECDHE-ECDSA-AES256-SHA384", "priority": 9}, + {"name": "ECDHE-RSA-AES256-SHA384", "priority": 10}, + {"name": "ECDHE-RSA-AES256-SHA", "priority": 11}, + {"name": "ECDHE-ECDSA-AES256-SHA", "priority": 12}, + {"name": "AES128-GCM-SHA256", "priority": 13}, + {"name": "AES128-SHA256", "priority": 14}, + {"name": "AES128-SHA", "priority": 15}, + {"name": "AES256-GCM-SHA384", "priority": 16}, + {"name": "AES256-SHA256", "priority": 17}, + {"name": "AES256-SHA", "priority": 18}, + ], }, { - 'name': 'ELBSecurityPolicy-2015-05', - 'ssl_protocols': ['TLSv1', 'TLSv1.1', 'TLSv1.2'], - 'ciphers': [ - {'name': 'ECDHE-ECDSA-AES128-GCM-SHA256', 'priority': 1}, - {'name': 'ECDHE-RSA-AES128-GCM-SHA256', 'priority': 2}, - {'name': 'ECDHE-ECDSA-AES128-SHA256', 'priority': 3}, - {'name': 'ECDHE-RSA-AES128-SHA256', 'priority': 4}, - {'name': 'ECDHE-ECDSA-AES128-SHA', 'priority': 5}, - {'name': 'ECDHE-RSA-AES128-SHA', 'priority': 6}, - {'name': 'ECDHE-ECDSA-AES256-GCM-SHA384', 'priority': 7}, - {'name': 'ECDHE-RSA-AES256-GCM-SHA384', 'priority': 8}, - {'name': 'ECDHE-ECDSA-AES256-SHA384', 'priority': 9}, - {'name': 'ECDHE-RSA-AES256-SHA384', 'priority': 10}, - {'name': 'ECDHE-RSA-AES256-SHA', 'priority': 11}, - {'name': 'ECDHE-ECDSA-AES256-SHA', 'priority': 12}, - {'name': 'AES128-GCM-SHA256', 'priority': 13}, - {'name': 'AES128-SHA256', 'priority': 14}, - {'name': 'AES128-SHA', 'priority': 15}, - {'name': 'AES256-GCM-SHA384', 'priority': 16}, - {'name': 'AES256-SHA256', 'priority': 17}, - {'name': 'AES256-SHA', 'priority': 18} - ] + "name": "ELBSecurityPolicy-2015-05", + "ssl_protocols": ["TLSv1", "TLSv1.1", "TLSv1.2"], + "ciphers": [ + {"name": "ECDHE-ECDSA-AES128-GCM-SHA256", "priority": 1}, + {"name": "ECDHE-RSA-AES128-GCM-SHA256", "priority": 2}, + {"name": "ECDHE-ECDSA-AES128-SHA256", "priority": 3}, + {"name": "ECDHE-RSA-AES128-SHA256", "priority": 4}, + {"name": "ECDHE-ECDSA-AES128-SHA", "priority": 5}, + {"name": "ECDHE-RSA-AES128-SHA", "priority": 6}, + {"name": "ECDHE-ECDSA-AES256-GCM-SHA384", "priority": 7}, + {"name": "ECDHE-RSA-AES256-GCM-SHA384", "priority": 8}, + {"name": "ECDHE-ECDSA-AES256-SHA384", "priority": 9}, + {"name": "ECDHE-RSA-AES256-SHA384", "priority": 10}, + {"name": "ECDHE-RSA-AES256-SHA", "priority": 11}, + {"name": "ECDHE-ECDSA-AES256-SHA", "priority": 12}, + {"name": "AES128-GCM-SHA256", "priority": 13}, + {"name": "AES128-SHA256", "priority": 14}, + {"name": "AES128-SHA", "priority": 15}, + {"name": "AES256-GCM-SHA384", "priority": 16}, + {"name": "AES256-SHA256", "priority": 17}, + {"name": "AES256-SHA", "priority": 18}, + ], }, { - 'name': 'ELBSecurityPolicy-TLS-1-0-2015-04', - 'ssl_protocols': ['TLSv1', 'TLSv1.1', 'TLSv1.2'], - 'ciphers': [ - {'name': 'ECDHE-ECDSA-AES128-GCM-SHA256', 'priority': 1}, - {'name': 'ECDHE-RSA-AES128-GCM-SHA256', 'priority': 2}, - {'name': 'ECDHE-ECDSA-AES128-SHA256', 'priority': 3}, - {'name': 'ECDHE-RSA-AES128-SHA256', 'priority': 4}, - {'name': 'ECDHE-ECDSA-AES128-SHA', 'priority': 5}, - {'name': 'ECDHE-RSA-AES128-SHA', 'priority': 6}, - {'name': 'ECDHE-ECDSA-AES256-GCM-SHA384', 'priority': 7}, - {'name': 'ECDHE-RSA-AES256-GCM-SHA384', 'priority': 8}, - {'name': 'ECDHE-ECDSA-AES256-SHA384', 'priority': 9}, - {'name': 'ECDHE-RSA-AES256-SHA384', 'priority': 10}, - {'name': 'ECDHE-RSA-AES256-SHA', 'priority': 11}, - {'name': 'ECDHE-ECDSA-AES256-SHA', 'priority': 12}, - {'name': 'AES128-GCM-SHA256', 'priority': 13}, - {'name': 'AES128-SHA256', 'priority': 14}, - {'name': 'AES128-SHA', 'priority': 15}, - {'name': 'AES256-GCM-SHA384', 'priority': 16}, - {'name': 'AES256-SHA256', 'priority': 17}, - {'name': 'AES256-SHA', 'priority': 18}, - {'name': 'DES-CBC3-SHA', 'priority': 19} - ] - } + "name": "ELBSecurityPolicy-TLS-1-0-2015-04", + "ssl_protocols": ["TLSv1", "TLSv1.1", "TLSv1.2"], + "ciphers": [ + {"name": "ECDHE-ECDSA-AES128-GCM-SHA256", "priority": 1}, + {"name": "ECDHE-RSA-AES128-GCM-SHA256", "priority": 2}, + {"name": "ECDHE-ECDSA-AES128-SHA256", "priority": 3}, + {"name": "ECDHE-RSA-AES128-SHA256", "priority": 4}, + {"name": "ECDHE-ECDSA-AES128-SHA", "priority": 5}, + {"name": "ECDHE-RSA-AES128-SHA", "priority": 6}, + {"name": "ECDHE-ECDSA-AES256-GCM-SHA384", "priority": 7}, + {"name": "ECDHE-RSA-AES256-GCM-SHA384", "priority": 8}, + {"name": "ECDHE-ECDSA-AES256-SHA384", "priority": 9}, + {"name": "ECDHE-RSA-AES256-SHA384", "priority": 10}, + {"name": "ECDHE-RSA-AES256-SHA", "priority": 11}, + {"name": "ECDHE-ECDSA-AES256-SHA", "priority": 12}, + {"name": "AES128-GCM-SHA256", "priority": 13}, + {"name": "AES128-SHA256", "priority": 14}, + {"name": "AES128-SHA", "priority": 15}, + {"name": "AES256-GCM-SHA384", "priority": 16}, + {"name": "AES256-SHA256", "priority": 17}, + {"name": "AES256-SHA", "priority": 18}, + {"name": "DES-CBC3-SHA", "priority": 19}, + ], + }, ] @@ -134,10 +134,10 @@ class ELBV2Response(BaseResponse): @amzn_request_id def create_load_balancer(self): - load_balancer_name = self._get_param('Name') + load_balancer_name = self._get_param("Name") subnet_ids = self._get_multi_param("Subnets.member") security_groups = self._get_multi_param("SecurityGroups.member") - scheme = self._get_param('Scheme') + scheme = self._get_param("Scheme") load_balancer = self.elbv2_backend.create_load_balancer( name=load_balancer_name, @@ -151,43 +151,43 @@ class ELBV2Response(BaseResponse): @amzn_request_id def create_rule(self): - lister_arn = self._get_param('ListenerArn') - _conditions = self._get_list_prefix('Conditions.member') + lister_arn = self._get_param("ListenerArn") + _conditions = self._get_list_prefix("Conditions.member") conditions = [] for _condition in _conditions: condition = {} - condition['field'] = _condition['field'] + condition["field"] = _condition["field"] values = sorted( - [e for e in _condition.items() if e[0].startswith('values.member')], - key=lambda x: x[0] + [e for e in _condition.items() if e[0].startswith("values.member")], + key=lambda x: x[0], ) - condition['values'] = [e[1] for e in values] + condition["values"] = [e[1] for e in values] conditions.append(condition) - priority = self._get_int_param('Priority') - actions = self._get_list_prefix('Actions.member') + priority = self._get_int_param("Priority") + actions = self._get_list_prefix("Actions.member") rules = self.elbv2_backend.create_rule( listener_arn=lister_arn, conditions=conditions, priority=priority, - actions=actions + actions=actions, ) template = self.response_template(CREATE_RULE_TEMPLATE) return template.render(rules=rules) @amzn_request_id def create_target_group(self): - name = self._get_param('Name') - vpc_id = self._get_param('VpcId') - protocol = self._get_param('Protocol') - port = self._get_param('Port') - healthcheck_protocol = self._get_param('HealthCheckProtocol') - healthcheck_port = self._get_param('HealthCheckPort') - healthcheck_path = self._get_param('HealthCheckPath') - healthcheck_interval_seconds = self._get_param('HealthCheckIntervalSeconds') - healthcheck_timeout_seconds = self._get_param('HealthCheckTimeoutSeconds') - healthy_threshold_count = self._get_param('HealthyThresholdCount') - unhealthy_threshold_count = self._get_param('UnhealthyThresholdCount') - matcher = self._get_param('Matcher') + name = self._get_param("Name") + vpc_id = self._get_param("VpcId") + protocol = self._get_param("Protocol") + port = self._get_param("Port") + healthcheck_protocol = self._get_param("HealthCheckProtocol") + healthcheck_port = self._get_param("HealthCheckPort") + healthcheck_path = self._get_param("HealthCheckPath") + healthcheck_interval_seconds = self._get_param("HealthCheckIntervalSeconds") + healthcheck_timeout_seconds = self._get_param("HealthCheckTimeoutSeconds") + healthy_threshold_count = self._get_param("HealthyThresholdCount") + unhealthy_threshold_count = self._get_param("UnhealthyThresholdCount") + matcher = self._get_param("Matcher") target_group = self.elbv2_backend.create_target_group( name, @@ -209,16 +209,16 @@ class ELBV2Response(BaseResponse): @amzn_request_id def create_listener(self): - load_balancer_arn = self._get_param('LoadBalancerArn') - protocol = self._get_param('Protocol') - port = self._get_param('Port') - ssl_policy = self._get_param('SslPolicy', 'ELBSecurityPolicy-2016-08') - certificates = self._get_list_prefix('Certificates.member') + load_balancer_arn = self._get_param("LoadBalancerArn") + protocol = self._get_param("Protocol") + port = self._get_param("Port") + ssl_policy = self._get_param("SslPolicy", "ELBSecurityPolicy-2016-08") + certificates = self._get_list_prefix("Certificates.member") if certificates: - certificate = certificates[0].get('certificate_arn') + certificate = certificates[0].get("certificate_arn") else: certificate = None - default_actions = self._get_list_prefix('DefaultActions.member') + default_actions = self._get_list_prefix("DefaultActions.member") listener = self.elbv2_backend.create_listener( load_balancer_arn=load_balancer_arn, @@ -226,7 +226,8 @@ class ELBV2Response(BaseResponse): port=port, ssl_policy=ssl_policy, certificate=certificate, - default_actions=default_actions) + default_actions=default_actions, + ) template = self.response_template(CREATE_LISTENER_TEMPLATE) return template.render(listener=listener) @@ -235,15 +236,19 @@ class ELBV2Response(BaseResponse): def describe_load_balancers(self): arns = self._get_multi_param("LoadBalancerArns.member") names = self._get_multi_param("Names.member") - all_load_balancers = list(self.elbv2_backend.describe_load_balancers(arns, names)) - marker = self._get_param('Marker') + all_load_balancers = list( + self.elbv2_backend.describe_load_balancers(arns, names) + ) + marker = self._get_param("Marker") all_names = [balancer.name for balancer in all_load_balancers] if marker: start = all_names.index(marker) + 1 else: start = 0 - page_size = self._get_int_param('PageSize', 50) # the default is 400, but using 50 to make testing easier - load_balancers_resp = all_load_balancers[start:start + page_size] + page_size = self._get_int_param( + "PageSize", 50 + ) # the default is 400, but using 50 to make testing easier + load_balancers_resp = all_load_balancers[start : start + page_size] next_marker = None if len(all_load_balancers) > start + page_size: next_marker = load_balancers_resp[-1].name @@ -253,18 +258,26 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_rules(self): - listener_arn = self._get_param('ListenerArn') - rule_arns = self._get_multi_param('RuleArns.member') if any(k for k in list(self.querystring.keys()) if k.startswith('RuleArns.member')) else None + listener_arn = self._get_param("ListenerArn") + rule_arns = ( + self._get_multi_param("RuleArns.member") + if any( + k + for k in list(self.querystring.keys()) + if k.startswith("RuleArns.member") + ) + else None + ) all_rules = self.elbv2_backend.describe_rules(listener_arn, rule_arns) all_arns = [rule.arn for rule in all_rules] - page_size = self._get_int_param('PageSize', 50) # set 50 for temporary + page_size = self._get_int_param("PageSize", 50) # set 50 for temporary - marker = self._get_param('Marker') + marker = self._get_param("Marker") if marker: start = all_arns.index(marker) + 1 else: start = 0 - rules_resp = all_rules[start:start + page_size] + rules_resp = all_rules[start : start + page_size] next_marker = None if len(all_rules) > start + page_size: @@ -274,17 +287,19 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_target_groups(self): - load_balancer_arn = self._get_param('LoadBalancerArn') - target_group_arns = self._get_multi_param('TargetGroupArns.member') - names = self._get_multi_param('Names.member') + load_balancer_arn = self._get_param("LoadBalancerArn") + target_group_arns = self._get_multi_param("TargetGroupArns.member") + names = self._get_multi_param("Names.member") - target_groups = self.elbv2_backend.describe_target_groups(load_balancer_arn, target_group_arns, names) + target_groups = self.elbv2_backend.describe_target_groups( + load_balancer_arn, target_group_arns, names + ) template = self.response_template(DESCRIBE_TARGET_GROUPS_TEMPLATE) return template.render(target_groups=target_groups) @amzn_request_id def describe_target_group_attributes(self): - target_group_arn = self._get_param('TargetGroupArn') + target_group_arn = self._get_param("TargetGroupArn") target_group = self.elbv2_backend.target_groups.get(target_group_arn) if not target_group: raise TargetGroupNotFoundError() @@ -293,73 +308,73 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_listeners(self): - load_balancer_arn = self._get_param('LoadBalancerArn') - listener_arns = self._get_multi_param('ListenerArns.member') + load_balancer_arn = self._get_param("LoadBalancerArn") + listener_arns = self._get_multi_param("ListenerArns.member") if not load_balancer_arn and not listener_arns: raise LoadBalancerNotFoundError() - listeners = self.elbv2_backend.describe_listeners(load_balancer_arn, listener_arns) + listeners = self.elbv2_backend.describe_listeners( + load_balancer_arn, listener_arns + ) template = self.response_template(DESCRIBE_LISTENERS_TEMPLATE) return template.render(listeners=listeners) @amzn_request_id def delete_load_balancer(self): - arn = self._get_param('LoadBalancerArn') + arn = self._get_param("LoadBalancerArn") self.elbv2_backend.delete_load_balancer(arn) template = self.response_template(DELETE_LOAD_BALANCER_TEMPLATE) return template.render() @amzn_request_id def delete_rule(self): - arn = self._get_param('RuleArn') + arn = self._get_param("RuleArn") self.elbv2_backend.delete_rule(arn) template = self.response_template(DELETE_RULE_TEMPLATE) return template.render() @amzn_request_id def delete_target_group(self): - arn = self._get_param('TargetGroupArn') + arn = self._get_param("TargetGroupArn") self.elbv2_backend.delete_target_group(arn) template = self.response_template(DELETE_TARGET_GROUP_TEMPLATE) return template.render() @amzn_request_id def delete_listener(self): - arn = self._get_param('ListenerArn') + arn = self._get_param("ListenerArn") self.elbv2_backend.delete_listener(arn) template = self.response_template(DELETE_LISTENER_TEMPLATE) return template.render() @amzn_request_id def modify_rule(self): - rule_arn = self._get_param('RuleArn') - _conditions = self._get_list_prefix('Conditions.member') + rule_arn = self._get_param("RuleArn") + _conditions = self._get_list_prefix("Conditions.member") conditions = [] for _condition in _conditions: condition = {} - condition['field'] = _condition['field'] + condition["field"] = _condition["field"] values = sorted( - [e for e in _condition.items() if e[0].startswith('values.member')], - key=lambda x: x[0] + [e for e in _condition.items() if e[0].startswith("values.member")], + key=lambda x: x[0], ) - condition['values'] = [e[1] for e in values] + condition["values"] = [e[1] for e in values] conditions.append(condition) - actions = self._get_list_prefix('Actions.member') + actions = self._get_list_prefix("Actions.member") rules = self.elbv2_backend.modify_rule( - rule_arn=rule_arn, - conditions=conditions, - actions=actions + rule_arn=rule_arn, conditions=conditions, actions=actions ) template = self.response_template(MODIFY_RULE_TEMPLATE) return template.render(rules=rules) @amzn_request_id def modify_target_group_attributes(self): - target_group_arn = self._get_param('TargetGroupArn') + target_group_arn = self._get_param("TargetGroupArn") target_group = self.elbv2_backend.target_groups.get(target_group_arn) attributes = { - attr['key']: attr['value'] - for attr in self._get_list_prefix('Attributes.member') + attr["key"]: attr["value"] + for attr in self._get_list_prefix("Attributes.member") } target_group.attributes.update(attributes) if not target_group: @@ -369,8 +384,8 @@ class ELBV2Response(BaseResponse): @amzn_request_id def register_targets(self): - target_group_arn = self._get_param('TargetGroupArn') - targets = self._get_list_prefix('Targets.member') + target_group_arn = self._get_param("TargetGroupArn") + targets = self._get_list_prefix("Targets.member") self.elbv2_backend.register_targets(target_group_arn, targets) template = self.response_template(REGISTER_TARGETS_TEMPLATE) @@ -378,8 +393,8 @@ class ELBV2Response(BaseResponse): @amzn_request_id def deregister_targets(self): - target_group_arn = self._get_param('TargetGroupArn') - targets = self._get_list_prefix('Targets.member') + target_group_arn = self._get_param("TargetGroupArn") + targets = self._get_list_prefix("Targets.member") self.elbv2_backend.deregister_targets(target_group_arn, targets) template = self.response_template(DEREGISTER_TARGETS_TEMPLATE) @@ -387,32 +402,34 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_target_health(self): - target_group_arn = self._get_param('TargetGroupArn') - targets = self._get_list_prefix('Targets.member') - target_health_descriptions = self.elbv2_backend.describe_target_health(target_group_arn, targets) + target_group_arn = self._get_param("TargetGroupArn") + targets = self._get_list_prefix("Targets.member") + target_health_descriptions = self.elbv2_backend.describe_target_health( + target_group_arn, targets + ) template = self.response_template(DESCRIBE_TARGET_HEALTH_TEMPLATE) return template.render(target_health_descriptions=target_health_descriptions) @amzn_request_id def set_rule_priorities(self): - rule_priorities = self._get_list_prefix('RulePriorities.member') + rule_priorities = self._get_list_prefix("RulePriorities.member") for rule_priority in rule_priorities: - rule_priority['priority'] = int(rule_priority['priority']) + rule_priority["priority"] = int(rule_priority["priority"]) rules = self.elbv2_backend.set_rule_priorities(rule_priorities) template = self.response_template(SET_RULE_PRIORITIES_TEMPLATE) return template.render(rules=rules) @amzn_request_id def add_tags(self): - resource_arns = self._get_multi_param('ResourceArns.member') + resource_arns = self._get_multi_param("ResourceArns.member") for arn in resource_arns: - if ':targetgroup' in arn: + if ":targetgroup" in arn: resource = self.elbv2_backend.target_groups.get(arn) if not resource: raise TargetGroupNotFoundError() - elif ':loadbalancer' in arn: + elif ":loadbalancer" in arn: resource = self.elbv2_backend.load_balancers.get(arn) if not resource: raise LoadBalancerNotFoundError() @@ -425,15 +442,15 @@ class ELBV2Response(BaseResponse): @amzn_request_id def remove_tags(self): - resource_arns = self._get_multi_param('ResourceArns.member') - tag_keys = self._get_multi_param('TagKeys.member') + resource_arns = self._get_multi_param("ResourceArns.member") + tag_keys = self._get_multi_param("TagKeys.member") for arn in resource_arns: - if ':targetgroup' in arn: + if ":targetgroup" in arn: resource = self.elbv2_backend.target_groups.get(arn) if not resource: raise TargetGroupNotFoundError() - elif ':loadbalancer' in arn: + elif ":loadbalancer" in arn: resource = self.elbv2_backend.load_balancers.get(arn) if not resource: raise LoadBalancerNotFoundError() @@ -446,14 +463,14 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_tags(self): - resource_arns = self._get_multi_param('ResourceArns.member') + resource_arns = self._get_multi_param("ResourceArns.member") resources = [] for arn in resource_arns: - if ':targetgroup' in arn: + if ":targetgroup" in arn: resource = self.elbv2_backend.target_groups.get(arn) if not resource: raise TargetGroupNotFoundError() - elif ':loadbalancer' in arn: + elif ":loadbalancer" in arn: resource = self.elbv2_backend.load_balancers.get(arn) if not resource: raise LoadBalancerNotFoundError() @@ -471,14 +488,14 @@ class ELBV2Response(BaseResponse): # page_size = self._get_int_param('PageSize') limits = { - 'application-load-balancers': 20, - 'target-groups': 3000, - 'targets-per-application-load-balancer': 30, - 'listeners-per-application-load-balancer': 50, - 'rules-per-application-load-balancer': 100, - 'network-load-balancers': 20, - 'targets-per-network-load-balancer': 200, - 'listeners-per-network-load-balancer': 50 + "application-load-balancers": 20, + "target-groups": 3000, + "targets-per-application-load-balancer": 30, + "listeners-per-application-load-balancer": 50, + "rules-per-application-load-balancer": 100, + "network-load-balancers": 20, + "targets-per-network-load-balancer": 200, + "listeners-per-network-load-balancer": 50, } template = self.response_template(DESCRIBE_LIMITS_TEMPLATE) @@ -486,22 +503,22 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_ssl_policies(self): - names = self._get_multi_param('Names.member.') + names = self._get_multi_param("Names.member.") # Supports paging but not worth implementing yet # marker = self._get_param('Marker') # page_size = self._get_int_param('PageSize') policies = SSL_POLICIES if names: - policies = filter(lambda policy: policy['name'] in names, policies) + policies = filter(lambda policy: policy["name"] in names, policies) template = self.response_template(DESCRIBE_SSL_POLICIES_TEMPLATE) return template.render(policies=policies) @amzn_request_id def set_ip_address_type(self): - arn = self._get_param('LoadBalancerArn') - ip_type = self._get_param('IpAddressType') + arn = self._get_param("LoadBalancerArn") + ip_type = self._get_param("IpAddressType") self.elbv2_backend.set_ip_address_type(arn, ip_type) @@ -510,8 +527,8 @@ class ELBV2Response(BaseResponse): @amzn_request_id def set_security_groups(self): - arn = self._get_param('LoadBalancerArn') - sec_groups = self._get_multi_param('SecurityGroups.member.') + arn = self._get_param("LoadBalancerArn") + sec_groups = self._get_multi_param("SecurityGroups.member.") self.elbv2_backend.set_security_groups(arn, sec_groups) @@ -520,8 +537,8 @@ class ELBV2Response(BaseResponse): @amzn_request_id def set_subnets(self): - arn = self._get_param('LoadBalancerArn') - subnets = self._get_multi_param('Subnets.member.') + arn = self._get_param("LoadBalancerArn") + subnets = self._get_multi_param("Subnets.member.") subnet_zone_list = self.elbv2_backend.set_subnets(arn, subnets) @@ -530,8 +547,10 @@ class ELBV2Response(BaseResponse): @amzn_request_id def modify_load_balancer_attributes(self): - arn = self._get_param('LoadBalancerArn') - attrs = self._get_map_prefix('Attributes.member', key_end='Key', value_end='Value') + arn = self._get_param("LoadBalancerArn") + attrs = self._get_map_prefix( + "Attributes.member", key_end="Key", value_end="Value" + ) all_attrs = self.elbv2_backend.modify_load_balancer_attributes(arn, attrs) @@ -540,7 +559,7 @@ class ELBV2Response(BaseResponse): @amzn_request_id def describe_load_balancer_attributes(self): - arn = self._get_param('LoadBalancerArn') + arn = self._get_param("LoadBalancerArn") attrs = self.elbv2_backend.describe_load_balancer_attributes(arn) template = self.response_template(DESCRIBE_LOADBALANCER_ATTRS_TEMPLATE) @@ -548,37 +567,54 @@ class ELBV2Response(BaseResponse): @amzn_request_id def modify_target_group(self): - arn = self._get_param('TargetGroupArn') + arn = self._get_param("TargetGroupArn") - health_check_proto = self._get_param('HealthCheckProtocol') # 'HTTP' | 'HTTPS' | 'TCP', - health_check_port = self._get_param('HealthCheckPort') - health_check_path = self._get_param('HealthCheckPath') - health_check_interval = self._get_param('HealthCheckIntervalSeconds') - health_check_timeout = self._get_param('HealthCheckTimeoutSeconds') - healthy_threshold_count = self._get_param('HealthyThresholdCount') - unhealthy_threshold_count = self._get_param('UnhealthyThresholdCount') - http_codes = self._get_param('Matcher.HttpCode') + health_check_proto = self._get_param( + "HealthCheckProtocol" + ) # 'HTTP' | 'HTTPS' | 'TCP', + health_check_port = self._get_param("HealthCheckPort") + health_check_path = self._get_param("HealthCheckPath") + health_check_interval = self._get_param("HealthCheckIntervalSeconds") + health_check_timeout = self._get_param("HealthCheckTimeoutSeconds") + healthy_threshold_count = self._get_param("HealthyThresholdCount") + unhealthy_threshold_count = self._get_param("UnhealthyThresholdCount") + http_codes = self._get_param("Matcher.HttpCode") - target_group = self.elbv2_backend.modify_target_group(arn, health_check_proto, health_check_port, health_check_path, health_check_interval, - health_check_timeout, healthy_threshold_count, unhealthy_threshold_count, http_codes) + target_group = self.elbv2_backend.modify_target_group( + arn, + health_check_proto, + health_check_port, + health_check_path, + health_check_interval, + health_check_timeout, + healthy_threshold_count, + unhealthy_threshold_count, + http_codes, + ) template = self.response_template(MODIFY_TARGET_GROUP_TEMPLATE) return template.render(target_group=target_group) @amzn_request_id def modify_listener(self): - arn = self._get_param('ListenerArn') - port = self._get_param('Port') - protocol = self._get_param('Protocol') - ssl_policy = self._get_param('SslPolicy') - certificates = self._get_list_prefix('Certificates.member') - default_actions = self._get_list_prefix('DefaultActions.member') + arn = self._get_param("ListenerArn") + port = self._get_param("Port") + protocol = self._get_param("Protocol") + ssl_policy = self._get_param("SslPolicy") + certificates = self._get_list_prefix("Certificates.member") + default_actions = self._get_list_prefix("DefaultActions.member") # Should really move SSL Policies to models - if ssl_policy is not None and ssl_policy not in [item['name'] for item in SSL_POLICIES]: - raise RESTError('SSLPolicyNotFound', 'Policy {0} not found'.format(ssl_policy)) + if ssl_policy is not None and ssl_policy not in [ + item["name"] for item in SSL_POLICIES + ]: + raise RESTError( + "SSLPolicyNotFound", "Policy {0} not found".format(ssl_policy) + ) - listener = self.elbv2_backend.modify_listener(arn, port, protocol, ssl_policy, certificates, default_actions) + listener = self.elbv2_backend.modify_listener( + arn, port, protocol, ssl_policy, certificates, default_actions + ) template = self.response_template(MODIFY_LISTENER_TEMPLATE) return template.render(listener=listener) @@ -588,10 +624,10 @@ class ELBV2Response(BaseResponse): tag_keys = [] for t_key, t_val in sorted(self.querystring.items()): - if t_key.startswith('Tags.member.'): - if t_key.split('.')[3] == 'Key': + if t_key.startswith("Tags.member."): + if t_key.split(".")[3] == "Key": tag_keys.extend(t_val) - elif t_key.split('.')[3] == 'Value': + elif t_key.split(".")[3] == "Value": tag_values.extend(t_val) counts = {} diff --git a/moto/elbv2/urls.py b/moto/elbv2/urls.py index af51f7d3a..06b8f107e 100644 --- a/moto/elbv2/urls.py +++ b/moto/elbv2/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from ..elb.urls import api_version_elb_backend -url_bases = [ - "https?://elasticloadbalancing.(.+).amazonaws.com", -] +url_bases = ["https?://elasticloadbalancing.(.+).amazonaws.com"] -url_paths = { - '{0}/$': api_version_elb_backend, -} +url_paths = {"{0}/$": api_version_elb_backend} diff --git a/moto/elbv2/utils.py b/moto/elbv2/utils.py index 47a3e66d5..017878e2f 100644 --- a/moto/elbv2/utils.py +++ b/moto/elbv2/utils.py @@ -1,8 +1,10 @@ def make_arn_for_load_balancer(account_id, name, region_name): return "arn:aws:elasticloadbalancing:{}:{}:loadbalancer/{}/50dc6c495c0c9188".format( - region_name, account_id, name) + region_name, account_id, name + ) def make_arn_for_target_group(account_id, name, region_name): return "arn:aws:elasticloadbalancing:{}:{}:targetgroup/{}/50dc6c495c0c9188".format( - region_name, account_id, name) + region_name, account_id, name + ) diff --git a/moto/emr/__init__.py b/moto/emr/__init__.py index b4223f2cb..d35506271 100644 --- a/moto/emr/__init__.py +++ b/moto/emr/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import emr_backends from ..core.models import base_decorator, deprecated_base_decorator -emr_backend = emr_backends['us-east-1'] +emr_backend = emr_backends["us-east-1"] mock_emr = base_decorator(emr_backends) mock_emr_deprecated = deprecated_base_decorator(emr_backends) diff --git a/moto/emr/models.py b/moto/emr/models.py index 4b591acb1..b62ce7932 100644 --- a/moto/emr/models.py +++ b/moto/emr/models.py @@ -11,7 +11,6 @@ from .utils import random_instance_group_id, random_cluster_id, random_step_id class FakeApplication(BaseModel): - def __init__(self, name, version, args=None, additional_info=None): self.additional_info = additional_info or {} self.args = args or [] @@ -20,7 +19,6 @@ class FakeApplication(BaseModel): class FakeBootstrapAction(BaseModel): - def __init__(self, args, name, script_path): self.args = args or [] self.name = name @@ -28,20 +26,27 @@ class FakeBootstrapAction(BaseModel): class FakeInstanceGroup(BaseModel): - - def __init__(self, instance_count, instance_role, instance_type, - market='ON_DEMAND', name=None, id=None, bid_price=None): + def __init__( + self, + instance_count, + instance_role, + instance_type, + market="ON_DEMAND", + name=None, + id=None, + bid_price=None, + ): self.id = id or random_instance_group_id() self.bid_price = bid_price self.market = market if name is None: - if instance_role == 'MASTER': - name = 'master' - elif instance_role == 'CORE': - name = 'slave' + if instance_role == "MASTER": + name = "master" + elif instance_role == "CORE": + name = "slave" else: - name = 'Task instance group' + name = "Task instance group" self.name = name self.num_instances = instance_count self.role = instance_role @@ -51,21 +56,22 @@ class FakeInstanceGroup(BaseModel): self.start_datetime = datetime.now(pytz.utc) self.ready_datetime = datetime.now(pytz.utc) self.end_datetime = None - self.state = 'RUNNING' + self.state = "RUNNING" def set_instance_count(self, instance_count): self.num_instances = instance_count class FakeStep(BaseModel): - - def __init__(self, - state, - name='', - jar='', - args=None, - properties=None, - action_on_failure='TERMINATE_CLUSTER'): + def __init__( + self, + state, + name="", + jar="", + args=None, + properties=None, + action_on_failure="TERMINATE_CLUSTER", + ): self.id = random_step_id() self.action_on_failure = action_on_failure @@ -82,23 +88,24 @@ class FakeStep(BaseModel): class FakeCluster(BaseModel): - - def __init__(self, - emr_backend, - name, - log_uri, - job_flow_role, - service_role, - steps, - instance_attrs, - bootstrap_actions=None, - configurations=None, - cluster_id=None, - visible_to_all_users='false', - release_label=None, - requested_ami_version=None, - running_ami_version=None, - custom_ami_id=None): + def __init__( + self, + emr_backend, + name, + log_uri, + job_flow_role, + service_role, + steps, + instance_attrs, + bootstrap_actions=None, + configurations=None, + cluster_id=None, + visible_to_all_users="false", + release_label=None, + requested_ami_version=None, + running_ami_version=None, + custom_ami_id=None, + ): self.id = cluster_id or random_cluster_id() emr_backend.clusters[self.id] = self self.emr_backend = emr_backend @@ -106,7 +113,7 @@ class FakeCluster(BaseModel): self.applications = [] self.bootstrap_actions = [] - for bootstrap_action in (bootstrap_actions or []): + for bootstrap_action in bootstrap_actions or []: self.add_bootstrap_action(bootstrap_action) self.configurations = configurations or [] @@ -125,47 +132,68 @@ class FakeCluster(BaseModel): self.instance_group_ids = [] self.master_instance_group_id = None self.core_instance_group_id = None - if 'master_instance_type' in instance_attrs and instance_attrs['master_instance_type']: + if ( + "master_instance_type" in instance_attrs + and instance_attrs["master_instance_type"] + ): self.emr_backend.add_instance_groups( self.id, - [{'instance_count': 1, - 'instance_role': 'MASTER', - 'instance_type': instance_attrs['master_instance_type'], - 'market': 'ON_DEMAND', - 'name': 'master'}]) - if 'slave_instance_type' in instance_attrs and instance_attrs['slave_instance_type']: + [ + { + "instance_count": 1, + "instance_role": "MASTER", + "instance_type": instance_attrs["master_instance_type"], + "market": "ON_DEMAND", + "name": "master", + } + ], + ) + if ( + "slave_instance_type" in instance_attrs + and instance_attrs["slave_instance_type"] + ): self.emr_backend.add_instance_groups( self.id, - [{'instance_count': instance_attrs['instance_count'] - 1, - 'instance_role': 'CORE', - 'instance_type': instance_attrs['slave_instance_type'], - 'market': 'ON_DEMAND', - 'name': 'slave'}]) + [ + { + "instance_count": instance_attrs["instance_count"] - 1, + "instance_role": "CORE", + "instance_type": instance_attrs["slave_instance_type"], + "market": "ON_DEMAND", + "name": "slave", + } + ], + ) self.additional_master_security_groups = instance_attrs.get( - 'additional_master_security_groups') + "additional_master_security_groups" + ) self.additional_slave_security_groups = instance_attrs.get( - 'additional_slave_security_groups') - self.availability_zone = instance_attrs.get('availability_zone') - self.ec2_key_name = instance_attrs.get('ec2_key_name') - self.ec2_subnet_id = instance_attrs.get('ec2_subnet_id') - self.hadoop_version = instance_attrs.get('hadoop_version') + "additional_slave_security_groups" + ) + self.availability_zone = instance_attrs.get("availability_zone") + self.ec2_key_name = instance_attrs.get("ec2_key_name") + self.ec2_subnet_id = instance_attrs.get("ec2_subnet_id") + self.hadoop_version = instance_attrs.get("hadoop_version") self.keep_job_flow_alive_when_no_steps = instance_attrs.get( - 'keep_job_flow_alive_when_no_steps') + "keep_job_flow_alive_when_no_steps" + ) self.master_security_group = instance_attrs.get( - 'emr_managed_master_security_group') + "emr_managed_master_security_group" + ) self.service_access_security_group = instance_attrs.get( - 'service_access_security_group') + "service_access_security_group" + ) self.slave_security_group = instance_attrs.get( - 'emr_managed_slave_security_group') - self.termination_protected = instance_attrs.get( - 'termination_protected') + "emr_managed_slave_security_group" + ) + self.termination_protected = instance_attrs.get("termination_protected") self.release_label = release_label self.requested_ami_version = requested_ami_version self.running_ami_version = running_ami_version self.custom_ami_id = custom_ami_id - self.role = job_flow_role or 'EMRJobflowDefault' + self.role = job_flow_role or "EMRJobflowDefault" self.service_role = service_role self.creation_datetime = datetime.now(pytz.utc) @@ -194,42 +222,46 @@ class FakeCluster(BaseModel): return sum(group.num_instances for group in self.instance_groups) def start_cluster(self): - self.state = 'STARTING' + self.state = "STARTING" self.start_datetime = datetime.now(pytz.utc) def run_bootstrap_actions(self): - self.state = 'BOOTSTRAPPING' + self.state = "BOOTSTRAPPING" self.ready_datetime = datetime.now(pytz.utc) - self.state = 'WAITING' + self.state = "WAITING" if not self.steps: if not self.keep_job_flow_alive_when_no_steps: self.terminate() def terminate(self): - self.state = 'TERMINATING' + self.state = "TERMINATING" self.end_datetime = datetime.now(pytz.utc) - self.state = 'TERMINATED' + self.state = "TERMINATED" def add_applications(self, applications): - self.applications.extend([ - FakeApplication( - name=app.get('name', ''), - version=app.get('version', ''), - args=app.get('args', []), - additional_info=app.get('additiona_info', {})) - for app in applications]) + self.applications.extend( + [ + FakeApplication( + name=app.get("name", ""), + version=app.get("version", ""), + args=app.get("args", []), + additional_info=app.get("additiona_info", {}), + ) + for app in applications + ] + ) def add_bootstrap_action(self, bootstrap_action): self.bootstrap_actions.append(FakeBootstrapAction(**bootstrap_action)) def add_instance_group(self, instance_group): - if instance_group.role == 'MASTER': + if instance_group.role == "MASTER": if self.master_instance_group_id: - raise Exception('Cannot add another master instance group') + raise Exception("Cannot add another master instance group") self.master_instance_group_id = instance_group.id - if instance_group.role == 'CORE': + if instance_group.role == "CORE": if self.core_instance_group_id: - raise Exception('Cannot add another core instance group') + raise Exception("Cannot add another core instance group") self.core_instance_group_id = instance_group.id self.instance_group_ids.append(instance_group.id) @@ -238,12 +270,12 @@ class FakeCluster(BaseModel): for step in steps: if self.steps: # If we already have other steps, this one is pending - fake = FakeStep(state='PENDING', **step) + fake = FakeStep(state="PENDING", **step) else: - fake = FakeStep(state='STARTING', **step) + fake = FakeStep(state="STARTING", **step) self.steps.append(fake) added_steps.append(fake) - self.state = 'RUNNING' + self.state = "RUNNING" return added_steps def add_tags(self, tags): @@ -261,7 +293,6 @@ class FakeCluster(BaseModel): class ElasticMapReduceBackend(BaseBackend): - def __init__(self, region_name): super(ElasticMapReduceBackend, self).__init__() self.region_name = region_name @@ -296,12 +327,17 @@ class ElasticMapReduceBackend(BaseBackend): cluster = self.get_cluster(cluster_id) cluster.add_tags(tags) - def describe_job_flows(self, job_flow_ids=None, job_flow_states=None, created_after=None, created_before=None): + def describe_job_flows( + self, + job_flow_ids=None, + job_flow_states=None, + created_after=None, + created_before=None, + ): clusters = self.clusters.values() within_two_month = datetime.now(pytz.utc) - timedelta(days=60) - clusters = [ - c for c in clusters if c.creation_datetime >= within_two_month] + clusters = [c for c in clusters if c.creation_datetime >= within_two_month] if job_flow_ids: clusters = [c for c in clusters if c.id in job_flow_ids] @@ -309,12 +345,10 @@ class ElasticMapReduceBackend(BaseBackend): clusters = [c for c in clusters if c.state in job_flow_states] if created_after: created_after = dtparse(created_after) - clusters = [ - c for c in clusters if c.creation_datetime > created_after] + clusters = [c for c in clusters if c.creation_datetime > created_after] if created_before: created_before = dtparse(created_before) - clusters = [ - c for c in clusters if c.creation_datetime < created_before] + clusters = [c for c in clusters if c.creation_datetime < created_before] # Amazon EMR can return a maximum of 512 job flow descriptions return sorted(clusters, key=lambda x: x.id)[:512] @@ -328,12 +362,12 @@ class ElasticMapReduceBackend(BaseBackend): def get_cluster(self, cluster_id): if cluster_id in self.clusters: return self.clusters[cluster_id] - raise EmrError('ResourceNotFoundException', '', 'error_json') + raise EmrError("ResourceNotFoundException", "", "error_json") def get_instance_groups(self, instance_group_ids): return [ - group for group_id, group - in self.instance_groups.items() + group + for group_id, group in self.instance_groups.items() if group_id in instance_group_ids ] @@ -341,38 +375,43 @@ class ElasticMapReduceBackend(BaseBackend): max_items = 50 actions = self.clusters[cluster_id].bootstrap_actions start_idx = 0 if marker is None else int(marker) - marker = None if len(actions) <= start_idx + \ - max_items else str(start_idx + max_items) - return actions[start_idx:start_idx + max_items], marker + marker = ( + None + if len(actions) <= start_idx + max_items + else str(start_idx + max_items) + ) + return actions[start_idx : start_idx + max_items], marker - def list_clusters(self, cluster_states=None, created_after=None, - created_before=None, marker=None): + def list_clusters( + self, cluster_states=None, created_after=None, created_before=None, marker=None + ): max_items = 50 clusters = self.clusters.values() if cluster_states: clusters = [c for c in clusters if c.state in cluster_states] if created_after: created_after = dtparse(created_after) - clusters = [ - c for c in clusters if c.creation_datetime > created_after] + clusters = [c for c in clusters if c.creation_datetime > created_after] if created_before: created_before = dtparse(created_before) - clusters = [ - c for c in clusters if c.creation_datetime < created_before] + clusters = [c for c in clusters if c.creation_datetime < created_before] clusters = sorted(clusters, key=lambda x: x.id) start_idx = 0 if marker is None else int(marker) - marker = None if len(clusters) <= start_idx + \ - max_items else str(start_idx + max_items) - return clusters[start_idx:start_idx + max_items], marker + marker = ( + None + if len(clusters) <= start_idx + max_items + else str(start_idx + max_items) + ) + return clusters[start_idx : start_idx + max_items], marker def list_instance_groups(self, cluster_id, marker=None): max_items = 50 - groups = sorted(self.clusters[cluster_id].instance_groups, - key=lambda x: x.id) + groups = sorted(self.clusters[cluster_id].instance_groups, key=lambda x: x.id) start_idx = 0 if marker is None else int(marker) - marker = None if len(groups) <= start_idx + \ - max_items else str(start_idx + max_items) - return groups[start_idx:start_idx + max_items], marker + marker = ( + None if len(groups) <= start_idx + max_items else str(start_idx + max_items) + ) + return groups[start_idx : start_idx + max_items], marker def list_steps(self, cluster_id, marker=None, step_ids=None, step_states=None): max_items = 50 @@ -382,15 +421,16 @@ class ElasticMapReduceBackend(BaseBackend): if step_states: steps = [s for s in steps if s.state in step_states] start_idx = 0 if marker is None else int(marker) - marker = None if len(steps) <= start_idx + \ - max_items else str(start_idx + max_items) - return steps[start_idx:start_idx + max_items], marker + marker = ( + None if len(steps) <= start_idx + max_items else str(start_idx + max_items) + ) + return steps[start_idx : start_idx + max_items], marker def modify_instance_groups(self, instance_groups): result_groups = [] for instance_group in instance_groups: - group = self.instance_groups[instance_group['instance_group_id']] - group.set_instance_count(int(instance_group['instance_count'])) + group = self.instance_groups[instance_group["instance_group_id"]] + group.set_instance_count(int(instance_group["instance_count"])) return result_groups def remove_tags(self, cluster_id, tag_keys): diff --git a/moto/emr/responses.py b/moto/emr/responses.py index c807b5f54..94847ec8b 100644 --- a/moto/emr/responses.py +++ b/moto/emr/responses.py @@ -20,20 +20,27 @@ def generate_boto3_response(operation): determined to be from boto3. Pass the API action as a parameter. """ + def _boto3_request(method): @wraps(method) def f(self, *args, **kwargs): rendered = method(self, *args, **kwargs) - if 'json' in self.headers.get('Content-Type', []): + if "json" in self.headers.get("Content-Type", []): self.response_headers.update( - {'x-amzn-requestid': '2690d7eb-ed86-11dd-9877-6fad448a8419', - 'date': datetime.now(pytz.utc).strftime('%a, %d %b %Y %H:%M:%S %Z'), - 'content-type': 'application/x-amz-json-1.1'}) - resp = xml_to_json_response( - self.aws_service_spec, operation, rendered) - return '' if resp is None else json.dumps(resp) + { + "x-amzn-requestid": "2690d7eb-ed86-11dd-9877-6fad448a8419", + "date": datetime.now(pytz.utc).strftime( + "%a, %d %b %Y %H:%M:%S %Z" + ), + "content-type": "application/x-amz-json-1.1", + } + ) + resp = xml_to_json_response(self.aws_service_spec, operation, rendered) + return "" if resp is None else json.dumps(resp) return rendered + return f + return _boto3_request @@ -41,10 +48,12 @@ class ElasticMapReduceResponse(BaseResponse): # EMR end points are inconsistent in the placement of region name # in the URL, so parsing it out needs to be handled differently - region_regex = [re.compile(r'elasticmapreduce\.(.+?)\.amazonaws\.com'), - re.compile(r'(.+?)\.elasticmapreduce\.amazonaws\.com')] + region_regex = [ + re.compile(r"elasticmapreduce\.(.+?)\.amazonaws\.com"), + re.compile(r"(.+?)\.elasticmapreduce\.amazonaws\.com"), + ] - aws_service_spec = AWSServiceSpec('data/emr/2009-03-31/service-2.json') + aws_service_spec = AWSServiceSpec("data/emr/2009-03-31/service-2.json") def get_region_from_url(self, request, full_url): parsed = urlparse(full_url) @@ -58,28 +67,28 @@ class ElasticMapReduceResponse(BaseResponse): def backend(self): return emr_backends[self.region] - @generate_boto3_response('AddInstanceGroups') + @generate_boto3_response("AddInstanceGroups") def add_instance_groups(self): - jobflow_id = self._get_param('JobFlowId') - instance_groups = self._get_list_prefix('InstanceGroups.member') + jobflow_id = self._get_param("JobFlowId") + instance_groups = self._get_list_prefix("InstanceGroups.member") for item in instance_groups: - item['instance_count'] = int(item['instance_count']) - instance_groups = self.backend.add_instance_groups( - jobflow_id, instance_groups) + item["instance_count"] = int(item["instance_count"]) + instance_groups = self.backend.add_instance_groups(jobflow_id, instance_groups) template = self.response_template(ADD_INSTANCE_GROUPS_TEMPLATE) return template.render(instance_groups=instance_groups) - @generate_boto3_response('AddJobFlowSteps') + @generate_boto3_response("AddJobFlowSteps") def add_job_flow_steps(self): - job_flow_id = self._get_param('JobFlowId') + job_flow_id = self._get_param("JobFlowId") steps = self.backend.add_job_flow_steps( - job_flow_id, steps_from_query_string(self._get_list_prefix('Steps.member'))) + job_flow_id, steps_from_query_string(self._get_list_prefix("Steps.member")) + ) template = self.response_template(ADD_JOB_FLOW_STEPS_TEMPLATE) return template.render(steps=steps) - @generate_boto3_response('AddTags') + @generate_boto3_response("AddTags") def add_tags(self): - cluster_id = self._get_param('ResourceId') + cluster_id = self._get_param("ResourceId") tags = tags_from_query_string(self.querystring) self.backend.add_tags(cluster_id, tags) template = self.response_template(ADD_TAGS_TEMPLATE) @@ -94,235 +103,257 @@ class ElasticMapReduceResponse(BaseResponse): def delete_security_configuration(self): raise NotImplementedError - @generate_boto3_response('DescribeCluster') + @generate_boto3_response("DescribeCluster") def describe_cluster(self): - cluster_id = self._get_param('ClusterId') + cluster_id = self._get_param("ClusterId") cluster = self.backend.get_cluster(cluster_id) template = self.response_template(DESCRIBE_CLUSTER_TEMPLATE) return template.render(cluster=cluster) - @generate_boto3_response('DescribeJobFlows') + @generate_boto3_response("DescribeJobFlows") def describe_job_flows(self): - created_after = self._get_param('CreatedAfter') - created_before = self._get_param('CreatedBefore') + created_after = self._get_param("CreatedAfter") + created_before = self._get_param("CreatedBefore") job_flow_ids = self._get_multi_param("JobFlowIds.member") - job_flow_states = self._get_multi_param('JobFlowStates.member') + job_flow_states = self._get_multi_param("JobFlowStates.member") clusters = self.backend.describe_job_flows( - job_flow_ids, job_flow_states, created_after, created_before) + job_flow_ids, job_flow_states, created_after, created_before + ) template = self.response_template(DESCRIBE_JOB_FLOWS_TEMPLATE) return template.render(clusters=clusters) def describe_security_configuration(self): raise NotImplementedError - @generate_boto3_response('DescribeStep') + @generate_boto3_response("DescribeStep") def describe_step(self): - cluster_id = self._get_param('ClusterId') - step_id = self._get_param('StepId') + cluster_id = self._get_param("ClusterId") + step_id = self._get_param("StepId") step = self.backend.describe_step(cluster_id, step_id) template = self.response_template(DESCRIBE_STEP_TEMPLATE) return template.render(step=step) - @generate_boto3_response('ListBootstrapActions') + @generate_boto3_response("ListBootstrapActions") def list_bootstrap_actions(self): - cluster_id = self._get_param('ClusterId') - marker = self._get_param('Marker') + cluster_id = self._get_param("ClusterId") + marker = self._get_param("Marker") bootstrap_actions, marker = self.backend.list_bootstrap_actions( - cluster_id, marker) + cluster_id, marker + ) template = self.response_template(LIST_BOOTSTRAP_ACTIONS_TEMPLATE) return template.render(bootstrap_actions=bootstrap_actions, marker=marker) - @generate_boto3_response('ListClusters') + @generate_boto3_response("ListClusters") def list_clusters(self): - cluster_states = self._get_multi_param('ClusterStates.member') - created_after = self._get_param('CreatedAfter') - created_before = self._get_param('CreatedBefore') - marker = self._get_param('Marker') + cluster_states = self._get_multi_param("ClusterStates.member") + created_after = self._get_param("CreatedAfter") + created_before = self._get_param("CreatedBefore") + marker = self._get_param("Marker") clusters, marker = self.backend.list_clusters( - cluster_states, created_after, created_before, marker) + cluster_states, created_after, created_before, marker + ) template = self.response_template(LIST_CLUSTERS_TEMPLATE) return template.render(clusters=clusters, marker=marker) - @generate_boto3_response('ListInstanceGroups') + @generate_boto3_response("ListInstanceGroups") def list_instance_groups(self): - cluster_id = self._get_param('ClusterId') - marker = self._get_param('Marker') + cluster_id = self._get_param("ClusterId") + marker = self._get_param("Marker") instance_groups, marker = self.backend.list_instance_groups( - cluster_id, marker=marker) + cluster_id, marker=marker + ) template = self.response_template(LIST_INSTANCE_GROUPS_TEMPLATE) return template.render(instance_groups=instance_groups, marker=marker) def list_instances(self): raise NotImplementedError - @generate_boto3_response('ListSteps') + @generate_boto3_response("ListSteps") def list_steps(self): - cluster_id = self._get_param('ClusterId') - marker = self._get_param('Marker') - step_ids = self._get_multi_param('StepIds.member') - step_states = self._get_multi_param('StepStates.member') + cluster_id = self._get_param("ClusterId") + marker = self._get_param("Marker") + step_ids = self._get_multi_param("StepIds.member") + step_states = self._get_multi_param("StepStates.member") steps, marker = self.backend.list_steps( - cluster_id, marker=marker, step_ids=step_ids, step_states=step_states) + cluster_id, marker=marker, step_ids=step_ids, step_states=step_states + ) template = self.response_template(LIST_STEPS_TEMPLATE) return template.render(steps=steps, marker=marker) - @generate_boto3_response('ModifyInstanceGroups') + @generate_boto3_response("ModifyInstanceGroups") def modify_instance_groups(self): - instance_groups = self._get_list_prefix('InstanceGroups.member') + instance_groups = self._get_list_prefix("InstanceGroups.member") for item in instance_groups: - item['instance_count'] = int(item['instance_count']) + item["instance_count"] = int(item["instance_count"]) instance_groups = self.backend.modify_instance_groups(instance_groups) template = self.response_template(MODIFY_INSTANCE_GROUPS_TEMPLATE) return template.render(instance_groups=instance_groups) - @generate_boto3_response('RemoveTags') + @generate_boto3_response("RemoveTags") def remove_tags(self): - cluster_id = self._get_param('ResourceId') - tag_keys = self._get_multi_param('TagKeys.member') + cluster_id = self._get_param("ResourceId") + tag_keys = self._get_multi_param("TagKeys.member") self.backend.remove_tags(cluster_id, tag_keys) template = self.response_template(REMOVE_TAGS_TEMPLATE) return template.render() - @generate_boto3_response('RunJobFlow') + @generate_boto3_response("RunJobFlow") def run_job_flow(self): instance_attrs = dict( - master_instance_type=self._get_param( - 'Instances.MasterInstanceType'), - slave_instance_type=self._get_param('Instances.SlaveInstanceType'), - instance_count=self._get_int_param('Instances.InstanceCount', 1), - ec2_key_name=self._get_param('Instances.Ec2KeyName'), - ec2_subnet_id=self._get_param('Instances.Ec2SubnetId'), - hadoop_version=self._get_param('Instances.HadoopVersion'), + master_instance_type=self._get_param("Instances.MasterInstanceType"), + slave_instance_type=self._get_param("Instances.SlaveInstanceType"), + instance_count=self._get_int_param("Instances.InstanceCount", 1), + ec2_key_name=self._get_param("Instances.Ec2KeyName"), + ec2_subnet_id=self._get_param("Instances.Ec2SubnetId"), + hadoop_version=self._get_param("Instances.HadoopVersion"), availability_zone=self._get_param( - 'Instances.Placement.AvailabilityZone', self.backend.region_name + 'a'), + "Instances.Placement.AvailabilityZone", self.backend.region_name + "a" + ), keep_job_flow_alive_when_no_steps=self._get_bool_param( - 'Instances.KeepJobFlowAliveWhenNoSteps', False), + "Instances.KeepJobFlowAliveWhenNoSteps", False + ), termination_protected=self._get_bool_param( - 'Instances.TerminationProtected', False), + "Instances.TerminationProtected", False + ), emr_managed_master_security_group=self._get_param( - 'Instances.EmrManagedMasterSecurityGroup'), + "Instances.EmrManagedMasterSecurityGroup" + ), emr_managed_slave_security_group=self._get_param( - 'Instances.EmrManagedSlaveSecurityGroup'), + "Instances.EmrManagedSlaveSecurityGroup" + ), service_access_security_group=self._get_param( - 'Instances.ServiceAccessSecurityGroup'), + "Instances.ServiceAccessSecurityGroup" + ), additional_master_security_groups=self._get_multi_param( - 'Instances.AdditionalMasterSecurityGroups.member.'), - additional_slave_security_groups=self._get_multi_param('Instances.AdditionalSlaveSecurityGroups.member.')) + "Instances.AdditionalMasterSecurityGroups.member." + ), + additional_slave_security_groups=self._get_multi_param( + "Instances.AdditionalSlaveSecurityGroups.member." + ), + ) kwargs = dict( - name=self._get_param('Name'), - log_uri=self._get_param('LogUri'), - job_flow_role=self._get_param('JobFlowRole'), - service_role=self._get_param('ServiceRole'), - steps=steps_from_query_string( - self._get_list_prefix('Steps.member')), - visible_to_all_users=self._get_bool_param( - 'VisibleToAllUsers', False), + name=self._get_param("Name"), + log_uri=self._get_param("LogUri"), + job_flow_role=self._get_param("JobFlowRole"), + service_role=self._get_param("ServiceRole"), + steps=steps_from_query_string(self._get_list_prefix("Steps.member")), + visible_to_all_users=self._get_bool_param("VisibleToAllUsers", False), instance_attrs=instance_attrs, ) - bootstrap_actions = self._get_list_prefix('BootstrapActions.member') + bootstrap_actions = self._get_list_prefix("BootstrapActions.member") if bootstrap_actions: for ba in bootstrap_actions: args = [] idx = 1 - keyfmt = 'script_bootstrap_action._args.member.{0}' + keyfmt = "script_bootstrap_action._args.member.{0}" key = keyfmt.format(idx) while key in ba: args.append(ba.pop(key)) idx += 1 key = keyfmt.format(idx) - ba['args'] = args - ba['script_path'] = ba.pop('script_bootstrap_action._path') - kwargs['bootstrap_actions'] = bootstrap_actions + ba["args"] = args + ba["script_path"] = ba.pop("script_bootstrap_action._path") + kwargs["bootstrap_actions"] = bootstrap_actions - configurations = self._get_list_prefix('Configurations.member') + configurations = self._get_list_prefix("Configurations.member") if configurations: for idx, config in enumerate(configurations, 1): for key in list(config.keys()): - if key.startswith('properties.'): + if key.startswith("properties."): config.pop(key) - config['properties'] = {} + config["properties"] = {} map_items = self._get_map_prefix( - 'Configurations.member.{0}.Properties.entry'.format(idx)) - config['properties'] = map_items + "Configurations.member.{0}.Properties.entry".format(idx) + ) + config["properties"] = map_items - kwargs['configurations'] = configurations + kwargs["configurations"] = configurations - release_label = self._get_param('ReleaseLabel') - ami_version = self._get_param('AmiVersion') + release_label = self._get_param("ReleaseLabel") + ami_version = self._get_param("AmiVersion") if release_label: - kwargs['release_label'] = release_label + kwargs["release_label"] = release_label if ami_version: message = ( - 'Only one AMI version and release label may be specified. ' - 'Provided AMI: {0}, release label: {1}.').format( - ami_version, release_label) - raise EmrError(error_type="ValidationException", - message=message, template='error_json') + "Only one AMI version and release label may be specified. " + "Provided AMI: {0}, release label: {1}." + ).format(ami_version, release_label) + raise EmrError( + error_type="ValidationException", + message=message, + template="error_json", + ) else: if ami_version: - kwargs['requested_ami_version'] = ami_version - kwargs['running_ami_version'] = ami_version + kwargs["requested_ami_version"] = ami_version + kwargs["running_ami_version"] = ami_version else: - kwargs['running_ami_version'] = '1.0.0' + kwargs["running_ami_version"] = "1.0.0" - custom_ami_id = self._get_param('CustomAmiId') + custom_ami_id = self._get_param("CustomAmiId") if custom_ami_id: - kwargs['custom_ami_id'] = custom_ami_id - if release_label and release_label < 'emr-5.7.0': - message = 'Custom AMI is not allowed' - raise EmrError(error_type='ValidationException', - message=message, template='error_json') + kwargs["custom_ami_id"] = custom_ami_id + if release_label and release_label < "emr-5.7.0": + message = "Custom AMI is not allowed" + raise EmrError( + error_type="ValidationException", + message=message, + template="error_json", + ) elif ami_version: - message = 'Custom AMI is not supported in this version of EMR' - raise EmrError(error_type='ValidationException', - message=message, template='error_json') + message = "Custom AMI is not supported in this version of EMR" + raise EmrError( + error_type="ValidationException", + message=message, + template="error_json", + ) cluster = self.backend.run_job_flow(**kwargs) - applications = self._get_list_prefix('Applications.member') + applications = self._get_list_prefix("Applications.member") if applications: self.backend.add_applications(cluster.id, applications) else: self.backend.add_applications( - cluster.id, [{'Name': 'Hadoop', 'Version': '0.18'}]) + cluster.id, [{"Name": "Hadoop", "Version": "0.18"}] + ) - instance_groups = self._get_list_prefix( - 'Instances.InstanceGroups.member') + instance_groups = self._get_list_prefix("Instances.InstanceGroups.member") if instance_groups: for ig in instance_groups: - ig['instance_count'] = int(ig['instance_count']) + ig["instance_count"] = int(ig["instance_count"]) self.backend.add_instance_groups(cluster.id, instance_groups) - tags = self._get_list_prefix('Tags.member') + tags = self._get_list_prefix("Tags.member") if tags: self.backend.add_tags( - cluster.id, dict((d['key'], d['value']) for d in tags)) + cluster.id, dict((d["key"], d["value"]) for d in tags) + ) template = self.response_template(RUN_JOB_FLOW_TEMPLATE) return template.render(cluster=cluster) - @generate_boto3_response('SetTerminationProtection') + @generate_boto3_response("SetTerminationProtection") def set_termination_protection(self): - termination_protection = self._get_param('TerminationProtected') - job_ids = self._get_multi_param('JobFlowIds.member') - self.backend.set_termination_protection( - job_ids, termination_protection) + termination_protection = self._get_param("TerminationProtected") + job_ids = self._get_multi_param("JobFlowIds.member") + self.backend.set_termination_protection(job_ids, termination_protection) template = self.response_template(SET_TERMINATION_PROTECTION_TEMPLATE) return template.render() - @generate_boto3_response('SetVisibleToAllUsers') + @generate_boto3_response("SetVisibleToAllUsers") def set_visible_to_all_users(self): - visible_to_all_users = self._get_param('VisibleToAllUsers') - job_ids = self._get_multi_param('JobFlowIds.member') + visible_to_all_users = self._get_param("VisibleToAllUsers") + job_ids = self._get_multi_param("JobFlowIds.member") self.backend.set_visible_to_all_users(job_ids, visible_to_all_users) template = self.response_template(SET_VISIBLE_TO_ALL_USERS_TEMPLATE) return template.render() - @generate_boto3_response('TerminateJobFlows') + @generate_boto3_response("TerminateJobFlows") def terminate_job_flows(self): - job_ids = self._get_multi_param('JobFlowIds.member.') + job_ids = self._get_multi_param("JobFlowIds.member.") self.backend.terminate_job_flows(job_ids) template = self.response_template(TERMINATE_JOB_FLOWS_TEMPLATE) return template.render() diff --git a/moto/emr/urls.py b/moto/emr/urls.py index 870eaf9d7..81275135d 100644 --- a/moto/emr/urls.py +++ b/moto/emr/urls.py @@ -6,6 +6,4 @@ url_bases = [ "https?://elasticmapreduce.(.+).amazonaws.com", ] -url_paths = { - '{0}/$': ElasticMapReduceResponse.dispatch, -} +url_paths = {"{0}/$": ElasticMapReduceResponse.dispatch} diff --git a/moto/emr/utils.py b/moto/emr/utils.py index 4f12522cf..0f75995b8 100644 --- a/moto/emr/utils.py +++ b/moto/emr/utils.py @@ -7,24 +7,24 @@ import six def random_id(size=13): chars = list(range(10)) + list(string.ascii_uppercase) - return ''.join(six.text_type(random.choice(chars)) for x in range(size)) + return "".join(six.text_type(random.choice(chars)) for x in range(size)) def random_cluster_id(size=13): - return 'j-{0}'.format(random_id()) + return "j-{0}".format(random_id()) def random_step_id(size=13): - return 's-{0}'.format(random_id()) + return "s-{0}".format(random_id()) def random_instance_group_id(size=13): - return 'i-{0}'.format(random_id()) + return "i-{0}".format(random_id()) def tags_from_query_string(querystring_dict): - prefix = 'Tags' - suffix = 'Key' + prefix = "Tags" + suffix = "Key" response_values = {} for key, value in querystring_dict.items(): if key.startswith(prefix) and key.endswith(suffix): @@ -32,8 +32,7 @@ def tags_from_query_string(querystring_dict): tag_key = querystring_dict.get("Tags.{0}.Key".format(tag_index))[0] tag_value_key = "Tags.{0}.Value".format(tag_index) if tag_value_key in querystring_dict: - response_values[tag_key] = querystring_dict.get(tag_value_key)[ - 0] + response_values[tag_key] = querystring_dict.get(tag_value_key)[0] else: response_values[tag_key] = None return response_values @@ -42,14 +41,15 @@ def tags_from_query_string(querystring_dict): def steps_from_query_string(querystring_dict): steps = [] for step in querystring_dict: - step['jar'] = step.pop('hadoop_jar_step._jar') - step['properties'] = dict((o['Key'], o['Value']) - for o in step.get('properties', [])) - step['args'] = [] + step["jar"] = step.pop("hadoop_jar_step._jar") + step["properties"] = dict( + (o["Key"], o["Value"]) for o in step.get("properties", []) + ) + step["args"] = [] idx = 1 - keyfmt = 'hadoop_jar_step._args.member.{0}' + keyfmt = "hadoop_jar_step._args.member.{0}" while keyfmt.format(idx) in step: - step['args'].append(step.pop(keyfmt.format(idx))) + step["args"].append(step.pop(keyfmt.format(idx))) idx += 1 steps.append(step) return steps diff --git a/moto/events/__init__.py b/moto/events/__init__.py index 8f2730c84..8fd414325 100644 --- a/moto/events/__init__.py +++ b/moto/events/__init__.py @@ -3,5 +3,5 @@ from __future__ import unicode_literals from .models import events_backends from ..core.models import base_decorator -events_backend = events_backends['us-east-1'] +events_backend = events_backends["us-east-1"] mock_events = base_decorator(events_backends) diff --git a/moto/events/models.py b/moto/events/models.py index 7871bae7b..e69062b2c 100644 --- a/moto/events/models.py +++ b/moto/events/models.py @@ -8,42 +8,40 @@ from moto.core import BaseBackend, BaseModel class Rule(BaseModel): - def _generate_arn(self, name): - return 'arn:aws:events:{region_name}:111111111111:rule/{name}'.format( - region_name=self.region_name, - name=name + return "arn:aws:events:{region_name}:111111111111:rule/{name}".format( + region_name=self.region_name, name=name ) def __init__(self, name, region_name, **kwargs): self.name = name self.region_name = region_name - self.arn = kwargs.get('Arn') or self._generate_arn(name) - self.event_pattern = kwargs.get('EventPattern') - self.schedule_exp = kwargs.get('ScheduleExpression') - self.state = kwargs.get('State') or 'ENABLED' - self.description = kwargs.get('Description') - self.role_arn = kwargs.get('RoleArn') + self.arn = kwargs.get("Arn") or self._generate_arn(name) + self.event_pattern = kwargs.get("EventPattern") + self.schedule_exp = kwargs.get("ScheduleExpression") + self.state = kwargs.get("State") or "ENABLED" + self.description = kwargs.get("Description") + self.role_arn = kwargs.get("RoleArn") self.targets = [] def enable(self): - self.state = 'ENABLED' + self.state = "ENABLED" def disable(self): - self.state = 'DISABLED' + self.state = "DISABLED" # This song and dance for targets is because we need order for Limits and NextTokens, but can't use OrderedDicts # with Python 2.6, so tracking it with an array it is. def _check_target_exists(self, target_id): for i in range(0, len(self.targets)): - if target_id == self.targets[i]['Id']: + if target_id == self.targets[i]["Id"]: return i return None def put_targets(self, targets): # Not testing for valid ARNs. for target in targets: - index = self._check_target_exists(target['Id']) + index = self._check_target_exists(target["Id"]) if index is not None: self.targets[index] = target else: @@ -57,8 +55,8 @@ class Rule(BaseModel): class EventsBackend(BaseBackend): - ACCOUNT_ID = re.compile(r'^(\d{1,12}|\*)$') - STATEMENT_ID = re.compile(r'^[a-zA-Z0-9-_]{1,64}$') + ACCOUNT_ID = re.compile(r"^(\d{1,12}|\*)$") + STATEMENT_ID = re.compile(r"^[a-zA-Z0-9-_]{1,64}$") def __init__(self, region_name): self.rules = {} @@ -78,7 +76,7 @@ class EventsBackend(BaseBackend): return self.rules.get(self.rules_order[i]) def _gen_next_token(self, index): - token = os.urandom(128).encode('base64') + token = os.urandom(128).encode("base64") self.next_tokens[token] = index return token @@ -124,24 +122,25 @@ class EventsBackend(BaseBackend): return_obj = {} start_index, end_index, new_next_token = self._process_token_and_limits( - len(self.rules), next_token, limit) + len(self.rules), next_token, limit + ) for i in range(start_index, end_index): rule = self._get_rule_by_index(i) for target in rule.targets: - if target['Arn'] == target_arn: + if target["Arn"] == target_arn: matching_rules.append(rule.name) - return_obj['RuleNames'] = matching_rules + return_obj["RuleNames"] = matching_rules if new_next_token is not None: - return_obj['NextToken'] = new_next_token + return_obj["NextToken"] = new_next_token return return_obj def list_rules(self, prefix=None, next_token=None, limit=None): - match_string = '.*' + match_string = ".*" if prefix is not None: - match_string = '^' + prefix + match_string + match_string = "^" + prefix + match_string match_regex = re.compile(match_string) @@ -149,16 +148,17 @@ class EventsBackend(BaseBackend): return_obj = {} start_index, end_index, new_next_token = self._process_token_and_limits( - len(self.rules), next_token, limit) + len(self.rules), next_token, limit + ) for i in range(start_index, end_index): rule = self._get_rule_by_index(i) if match_regex.match(rule.name): matching_rules.append(rule) - return_obj['Rules'] = matching_rules + return_obj["Rules"] = matching_rules if new_next_token is not None: - return_obj['NextToken'] = new_next_token + return_obj["NextToken"] = new_next_token return return_obj @@ -168,7 +168,8 @@ class EventsBackend(BaseBackend): rule = self.rules[rule] start_index, end_index, new_next_token = self._process_token_and_limits( - len(rule.targets), next_token, limit) + len(rule.targets), next_token, limit + ) returned_targets = [] return_obj = {} @@ -176,9 +177,9 @@ class EventsBackend(BaseBackend): for i in range(start_index, end_index): returned_targets.append(rule.targets[i]) - return_obj['Targets'] = returned_targets + return_obj["Targets"] = returned_targets if new_next_token is not None: - return_obj['NextToken'] = new_next_token + return_obj["NextToken"] = new_next_token return return_obj @@ -201,9 +202,9 @@ class EventsBackend(BaseBackend): num_events = len(events) if num_events < 1: - raise JsonRESTError('ValidationError', 'Need at least 1 event') + raise JsonRESTError("ValidationError", "Need at least 1 event") elif num_events > 10: - raise JsonRESTError('ValidationError', 'Can only submit 10 events at once') + raise JsonRESTError("ValidationError", "Can only submit 10 events at once") # We dont really need to store the events yet return [] @@ -221,41 +222,47 @@ class EventsBackend(BaseBackend): raise NotImplementedError() def put_permission(self, action, principal, statement_id): - if action is None or action != 'events:PutEvents': - raise JsonRESTError('InvalidParameterValue', 'Action must be PutEvents') + if action is None or action != "events:PutEvents": + raise JsonRESTError("InvalidParameterValue", "Action must be PutEvents") if principal is None or self.ACCOUNT_ID.match(principal) is None: - raise JsonRESTError('InvalidParameterValue', 'Principal must match ^(\d{1,12}|\*)$') + raise JsonRESTError( + "InvalidParameterValue", "Principal must match ^(\d{1,12}|\*)$" + ) if statement_id is None or self.STATEMENT_ID.match(statement_id) is None: - raise JsonRESTError('InvalidParameterValue', 'StatementId must match ^[a-zA-Z0-9-_]{1,64}$') + raise JsonRESTError( + "InvalidParameterValue", "StatementId must match ^[a-zA-Z0-9-_]{1,64}$" + ) - self.permissions[statement_id] = {'action': action, 'principal': principal} + self.permissions[statement_id] = {"action": action, "principal": principal} def remove_permission(self, statement_id): try: del self.permissions[statement_id] except KeyError: - raise JsonRESTError('ResourceNotFoundException', 'StatementId not found') + raise JsonRESTError("ResourceNotFoundException", "StatementId not found") def describe_event_bus(self): - arn = "arn:aws:events:{0}:000000000000:event-bus/default".format(self.region_name) + arn = "arn:aws:events:{0}:000000000000:event-bus/default".format( + self.region_name + ) statements = [] for statement_id, data in self.permissions.items(): - statements.append({ - 'Sid': statement_id, - 'Effect': 'Allow', - 'Principal': {'AWS': 'arn:aws:iam::{0}:root'.format(data['principal'])}, - 'Action': data['action'], - 'Resource': arn - }) - policy = {'Version': '2012-10-17', 'Statement': statements} + statements.append( + { + "Sid": statement_id, + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::{0}:root".format(data["principal"]) + }, + "Action": data["action"], + "Resource": arn, + } + ) + policy = {"Version": "2012-10-17", "Statement": statements} policy_json = json.dumps(policy) - return { - 'Policy': policy_json, - 'Name': 'default', - 'Arn': arn - } + return {"Policy": policy_json, "Name": "default", "Arn": arn} available_regions = boto3.session.Session().get_available_regions("events") diff --git a/moto/events/responses.py b/moto/events/responses.py index 2eb72d342..39c5c75dc 100644 --- a/moto/events/responses.py +++ b/moto/events/responses.py @@ -6,7 +6,6 @@ from moto.events import events_backends class EventsHandler(BaseResponse): - @property def events_backend(self): """ @@ -19,18 +18,18 @@ class EventsHandler(BaseResponse): def _generate_rule_dict(self, rule): return { - 'Name': rule.name, - 'Arn': rule.arn, - 'EventPattern': rule.event_pattern, - 'State': rule.state, - 'Description': rule.description, - 'ScheduleExpression': rule.schedule_exp, - 'RoleArn': rule.role_arn + "Name": rule.name, + "Arn": rule.arn, + "EventPattern": rule.event_pattern, + "State": rule.state, + "Description": rule.description, + "ScheduleExpression": rule.schedule_exp, + "RoleArn": rule.role_arn, } @property def request_params(self): - if not hasattr(self, '_json_body'): + if not hasattr(self, "_json_body"): try: self._json_body = json.loads(self.body) except ValueError: @@ -40,127 +39,134 @@ class EventsHandler(BaseResponse): def _get_param(self, param, if_none=None): return self.request_params.get(param, if_none) - def error(self, type_, message='', status=400): + def error(self, type_, message="", status=400): headers = self.response_headers - headers['status'] = status - return json.dumps({'__type': type_, 'message': message}), headers, + headers["status"] = status + return json.dumps({"__type": type_, "message": message}), headers def delete_rule(self): - name = self._get_param('Name') + name = self._get_param("Name") if not name: - return self.error('ValidationException', 'Parameter Name is required.') + return self.error("ValidationException", "Parameter Name is required.") self.events_backend.delete_rule(name) - return '', self.response_headers + return "", self.response_headers def describe_rule(self): - name = self._get_param('Name') + name = self._get_param("Name") if not name: - return self.error('ValidationException', 'Parameter Name is required.') + return self.error("ValidationException", "Parameter Name is required.") rule = self.events_backend.describe_rule(name) if not rule: - return self.error('ResourceNotFoundException', 'Rule test does not exist.') + return self.error("ResourceNotFoundException", "Rule test does not exist.") rule_dict = self._generate_rule_dict(rule) return json.dumps(rule_dict), self.response_headers def disable_rule(self): - name = self._get_param('Name') + name = self._get_param("Name") if not name: - return self.error('ValidationException', 'Parameter Name is required.') + return self.error("ValidationException", "Parameter Name is required.") if not self.events_backend.disable_rule(name): - return self.error('ResourceNotFoundException', 'Rule ' + name + ' does not exist.') + return self.error( + "ResourceNotFoundException", "Rule " + name + " does not exist." + ) - return '', self.response_headers + return "", self.response_headers def enable_rule(self): - name = self._get_param('Name') + name = self._get_param("Name") if not name: - return self.error('ValidationException', 'Parameter Name is required.') + return self.error("ValidationException", "Parameter Name is required.") if not self.events_backend.enable_rule(name): - return self.error('ResourceNotFoundException', 'Rule ' + name + ' does not exist.') + return self.error( + "ResourceNotFoundException", "Rule " + name + " does not exist." + ) - return '', self.response_headers + return "", self.response_headers def generate_presigned_url(self): pass def list_rule_names_by_target(self): - target_arn = self._get_param('TargetArn') - next_token = self._get_param('NextToken') - limit = self._get_param('Limit') + target_arn = self._get_param("TargetArn") + next_token = self._get_param("NextToken") + limit = self._get_param("Limit") if not target_arn: - return self.error('ValidationException', 'Parameter TargetArn is required.') + return self.error("ValidationException", "Parameter TargetArn is required.") rule_names = self.events_backend.list_rule_names_by_target( - target_arn, next_token, limit) + target_arn, next_token, limit + ) return json.dumps(rule_names), self.response_headers def list_rules(self): - prefix = self._get_param('NamePrefix') - next_token = self._get_param('NextToken') - limit = self._get_param('Limit') + prefix = self._get_param("NamePrefix") + next_token = self._get_param("NextToken") + limit = self._get_param("Limit") rules = self.events_backend.list_rules(prefix, next_token, limit) - rules_obj = {'Rules': []} + rules_obj = {"Rules": []} - for rule in rules['Rules']: - rules_obj['Rules'].append(self._generate_rule_dict(rule)) + for rule in rules["Rules"]: + rules_obj["Rules"].append(self._generate_rule_dict(rule)) - if rules.get('NextToken'): - rules_obj['NextToken'] = rules['NextToken'] + if rules.get("NextToken"): + rules_obj["NextToken"] = rules["NextToken"] return json.dumps(rules_obj), self.response_headers def list_targets_by_rule(self): - rule_name = self._get_param('Rule') - next_token = self._get_param('NextToken') - limit = self._get_param('Limit') + rule_name = self._get_param("Rule") + next_token = self._get_param("NextToken") + limit = self._get_param("Limit") if not rule_name: - return self.error('ValidationException', 'Parameter Rule is required.') + return self.error("ValidationException", "Parameter Rule is required.") try: targets = self.events_backend.list_targets_by_rule( - rule_name, next_token, limit) + rule_name, next_token, limit + ) except KeyError: - return self.error('ResourceNotFoundException', 'Rule ' + rule_name + ' does not exist.') + return self.error( + "ResourceNotFoundException", "Rule " + rule_name + " does not exist." + ) return json.dumps(targets), self.response_headers def put_events(self): - events = self._get_param('Entries') + events = self._get_param("Entries") failed_entries = self.events_backend.put_events(events) if failed_entries: - return json.dumps({ - 'FailedEntryCount': len(failed_entries), - 'Entries': failed_entries - }) + return json.dumps( + {"FailedEntryCount": len(failed_entries), "Entries": failed_entries} + ) - return '', self.response_headers + return "", self.response_headers def put_rule(self): - name = self._get_param('Name') - event_pattern = self._get_param('EventPattern') - sched_exp = self._get_param('ScheduleExpression') - state = self._get_param('State') - desc = self._get_param('Description') - role_arn = self._get_param('RoleArn') + name = self._get_param("Name") + event_pattern = self._get_param("EventPattern") + sched_exp = self._get_param("ScheduleExpression") + state = self._get_param("State") + desc = self._get_param("Description") + role_arn = self._get_param("RoleArn") if not name: - return self.error('ValidationException', 'Parameter Name is required.') + return self.error("ValidationException", "Parameter Name is required.") if event_pattern: try: @@ -168,12 +174,20 @@ class EventsHandler(BaseResponse): except ValueError: # Not quite as informative as the real error, but it'll work # for now. - return self.error('InvalidEventPatternException', 'Event pattern is not valid.') + return self.error( + "InvalidEventPatternException", "Event pattern is not valid." + ) if sched_exp: - if not (re.match('^cron\(.*\)', sched_exp) or - re.match('^rate\(\d*\s(minute|minutes|hour|hours|day|days)\)', sched_exp)): - return self.error('ValidationException', 'Parameter ScheduleExpression is not valid.') + if not ( + re.match("^cron\(.*\)", sched_exp) + or re.match( + "^rate\(\d*\s(minute|minutes|hour|hours|day|days)\)", sched_exp + ) + ): + return self.error( + "ValidationException", "Parameter ScheduleExpression is not valid." + ) rule_arn = self.events_backend.put_rule( name, @@ -181,59 +195,63 @@ class EventsHandler(BaseResponse): EventPattern=event_pattern, State=state, Description=desc, - RoleArn=role_arn + RoleArn=role_arn, ) - return json.dumps({'RuleArn': rule_arn}), self.response_headers + return json.dumps({"RuleArn": rule_arn}), self.response_headers def put_targets(self): - rule_name = self._get_param('Rule') - targets = self._get_param('Targets') + rule_name = self._get_param("Rule") + targets = self._get_param("Targets") if not rule_name: - return self.error('ValidationException', 'Parameter Rule is required.') + return self.error("ValidationException", "Parameter Rule is required.") if not targets: - return self.error('ValidationException', 'Parameter Targets is required.') + return self.error("ValidationException", "Parameter Targets is required.") if not self.events_backend.put_targets(rule_name, targets): - return self.error('ResourceNotFoundException', 'Rule ' + rule_name + ' does not exist.') + return self.error( + "ResourceNotFoundException", "Rule " + rule_name + " does not exist." + ) - return '', self.response_headers + return "", self.response_headers def remove_targets(self): - rule_name = self._get_param('Rule') - ids = self._get_param('Ids') + rule_name = self._get_param("Rule") + ids = self._get_param("Ids") if not rule_name: - return self.error('ValidationException', 'Parameter Rule is required.') + return self.error("ValidationException", "Parameter Rule is required.") if not ids: - return self.error('ValidationException', 'Parameter Ids is required.') + return self.error("ValidationException", "Parameter Ids is required.") if not self.events_backend.remove_targets(rule_name, ids): - return self.error('ResourceNotFoundException', 'Rule ' + rule_name + ' does not exist.') + return self.error( + "ResourceNotFoundException", "Rule " + rule_name + " does not exist." + ) - return '', self.response_headers + return "", self.response_headers def test_event_pattern(self): pass def put_permission(self): - action = self._get_param('Action') - principal = self._get_param('Principal') - statement_id = self._get_param('StatementId') + action = self._get_param("Action") + principal = self._get_param("Principal") + statement_id = self._get_param("StatementId") self.events_backend.put_permission(action, principal, statement_id) - return '' + return "" def remove_permission(self): - statement_id = self._get_param('StatementId') + statement_id = self._get_param("StatementId") self.events_backend.remove_permission(statement_id) - return '' + return "" def describe_event_bus(self): return json.dumps(self.events_backend.describe_event_bus()) diff --git a/moto/events/urls.py b/moto/events/urls.py index a6e533b08..39e6a3462 100644 --- a/moto/events/urls.py +++ b/moto/events/urls.py @@ -2,10 +2,6 @@ from __future__ import unicode_literals from .responses import EventsHandler -url_bases = [ - "https?://events.(.+).amazonaws.com" -] +url_bases = ["https?://events.(.+).amazonaws.com"] -url_paths = { - "{0}/": EventsHandler.dispatch, -} +url_paths = {"{0}/": EventsHandler.dispatch} diff --git a/moto/glacier/__init__.py b/moto/glacier/__init__.py index 1570fa7d4..270d580f5 100644 --- a/moto/glacier/__init__.py +++ b/moto/glacier/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import glacier_backends from ..core.models import base_decorator, deprecated_base_decorator -glacier_backend = glacier_backends['us-east-1'] +glacier_backend = glacier_backends["us-east-1"] mock_glacier = base_decorator(glacier_backends) mock_glacier_deprecated = deprecated_base_decorator(glacier_backends) diff --git a/moto/glacier/models.py b/moto/glacier/models.py index 2c16bc97d..6a3fc074d 100644 --- a/moto/glacier/models.py +++ b/moto/glacier/models.py @@ -25,7 +25,6 @@ class Job(BaseModel): class ArchiveJob(Job): - def __init__(self, job_id, tier, arn, archive_id): self.job_id = job_id self.tier = tier @@ -50,7 +49,7 @@ class ArchiveJob(Job): "StatusCode": "InProgress", "StatusMessage": None, "VaultARN": self.arn, - "Tier": self.tier + "Tier": self.tier, } if datetime.datetime.now() > self.et: d["Completed"] = True @@ -61,7 +60,6 @@ class ArchiveJob(Job): class InventoryJob(Job): - def __init__(self, job_id, tier, arn): self.job_id = job_id self.tier = tier @@ -83,7 +81,7 @@ class InventoryJob(Job): "StatusCode": "InProgress", "StatusMessage": None, "VaultARN": self.arn, - "Tier": self.tier + "Tier": self.tier, } if datetime.datetime.now() > self.et: d["Completed"] = True @@ -94,7 +92,6 @@ class InventoryJob(Job): class Vault(BaseModel): - def __init__(self, vault_name, region): self.st = datetime.datetime.now() self.vault_name = vault_name @@ -104,7 +101,9 @@ class Vault(BaseModel): @property def arn(self): - return "arn:aws:glacier:{0}:012345678901:vaults/{1}".format(self.region, self.vault_name) + return "arn:aws:glacier:{0}:012345678901:vaults/{1}".format( + self.region, self.vault_name + ) def to_dict(self): archives_size = 0 @@ -126,7 +125,9 @@ class Vault(BaseModel): self.archives[archive_id]["body"] = body self.archives[archive_id]["size"] = len(body) self.archives[archive_id]["sha256"] = hashlib.sha256(body).hexdigest() - self.archives[archive_id]["creation_date"] = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S.000Z") + self.archives[archive_id]["creation_date"] = datetime.datetime.now().strftime( + "%Y-%m-%dT%H:%M:%S.000Z" + ) self.archives[archive_id]["description"] = description return archive_id @@ -142,7 +143,7 @@ class Vault(BaseModel): "ArchiveDescription": archive["description"], "CreationDate": archive["creation_date"], "Size": archive["size"], - "SHA256TreeHash": archive["sha256"] + "SHA256TreeHash": archive["sha256"], } archive_list.append(aobj) return archive_list @@ -180,7 +181,7 @@ class Vault(BaseModel): return { "VaultARN": self.arn, "InventoryDate": jobj["CompletionDate"], - "ArchiveList": archives + "ArchiveList": archives, } else: archive_body = self.get_archive_body(job.archive_id) @@ -188,7 +189,6 @@ class Vault(BaseModel): class GlacierBackend(BaseBackend): - def __init__(self, region_name): self.vaults = {} self.region_name = region_name diff --git a/moto/glacier/responses.py b/moto/glacier/responses.py index abdf83e4f..5a82be479 100644 --- a/moto/glacier/responses.py +++ b/moto/glacier/responses.py @@ -9,7 +9,6 @@ from .utils import region_from_glacier_url, vault_from_glacier_url class GlacierResponse(_TemplateEnvironmentMixin): - def __init__(self, backend): super(GlacierResponse, self).__init__() self.backend = backend @@ -22,14 +21,11 @@ class GlacierResponse(_TemplateEnvironmentMixin): def _all_vault_response(self, request, full_url, headers): vaults = self.backend.list_vaules() - response = json.dumps({ - "Marker": None, - "VaultList": [ - vault.to_dict() for vault in vaults - ] - }) + response = json.dumps( + {"Marker": None, "VaultList": [vault.to_dict() for vault in vaults]} + ) - headers['content-type'] = 'application/json' + headers["content-type"] = "application/json" return 200, headers, response @classmethod @@ -44,16 +40,16 @@ class GlacierResponse(_TemplateEnvironmentMixin): querystring = parse_qs(parsed_url.query, keep_blank_values=True) vault_name = vault_from_glacier_url(full_url) - if method == 'GET': + if method == "GET": return self._vault_response_get(vault_name, querystring, headers) - elif method == 'PUT': + elif method == "PUT": return self._vault_response_put(vault_name, querystring, headers) - elif method == 'DELETE': + elif method == "DELETE": return self._vault_response_delete(vault_name, querystring, headers) def _vault_response_get(self, vault_name, querystring, headers): vault = self.backend.get_vault(vault_name) - headers['content-type'] = 'application/json' + headers["content-type"] = "application/json" return 200, headers, json.dumps(vault.to_dict()) def _vault_response_put(self, vault_name, querystring, headers): @@ -72,40 +68,46 @@ class GlacierResponse(_TemplateEnvironmentMixin): def _vault_archive_response(self, request, full_url, headers): method = request.method - if hasattr(request, 'body'): + if hasattr(request, "body"): body = request.body else: body = request.data description = "" - if 'x-amz-archive-description' in request.headers: - description = request.headers['x-amz-archive-description'] + if "x-amz-archive-description" in request.headers: + description = request.headers["x-amz-archive-description"] parsed_url = urlparse(full_url) querystring = parse_qs(parsed_url.query, keep_blank_values=True) vault_name = full_url.split("/")[-2] - if method == 'POST': - return self._vault_archive_response_post(vault_name, body, description, querystring, headers) + if method == "POST": + return self._vault_archive_response_post( + vault_name, body, description, querystring, headers + ) else: return 400, headers, "400 Bad Request" - def _vault_archive_response_post(self, vault_name, body, description, querystring, headers): + def _vault_archive_response_post( + self, vault_name, body, description, querystring, headers + ): vault = self.backend.get_vault(vault_name) vault_id = vault.create_archive(body, description) - headers['x-amz-archive-id'] = vault_id + headers["x-amz-archive-id"] = vault_id return 201, headers, "" @classmethod def vault_archive_individual_response(clazz, request, full_url, headers): region_name = region_from_glacier_url(full_url) response_instance = GlacierResponse(glacier_backends[region_name]) - return response_instance._vault_archive_individual_response(request, full_url, headers) + return response_instance._vault_archive_individual_response( + request, full_url, headers + ) def _vault_archive_individual_response(self, request, full_url, headers): method = request.method vault_name = full_url.split("/")[-3] archive_id = full_url.split("/")[-1] - if method == 'DELETE': + if method == "DELETE": vault = self.backend.get_vault(vault_name) vault.delete_archive(archive_id) return 204, headers, "" @@ -118,42 +120,47 @@ class GlacierResponse(_TemplateEnvironmentMixin): def _vault_jobs_response(self, request, full_url, headers): method = request.method - if hasattr(request, 'body'): + if hasattr(request, "body"): body = request.body else: body = request.data account_id = full_url.split("/")[1] vault_name = full_url.split("/")[-2] - if method == 'GET': + if method == "GET": jobs = self.backend.list_jobs(vault_name) - headers['content-type'] = 'application/json' - return 200, headers, json.dumps({ - "JobList": [ - job.to_dict() for job in jobs - ], - "Marker": None, - }) - elif method == 'POST': + headers["content-type"] = "application/json" + return ( + 200, + headers, + json.dumps( + {"JobList": [job.to_dict() for job in jobs], "Marker": None} + ), + ) + elif method == "POST": json_body = json.loads(body.decode("utf-8")) - job_type = json_body['Type'] + job_type = json_body["Type"] archive_id = None - if 'ArchiveId' in json_body: - archive_id = json_body['ArchiveId'] - if 'Tier' in json_body: + if "ArchiveId" in json_body: + archive_id = json_body["ArchiveId"] + if "Tier" in json_body: tier = json_body["Tier"] else: tier = "Standard" job_id = self.backend.initiate_job(vault_name, job_type, tier, archive_id) - headers['x-amz-job-id'] = job_id - headers['Location'] = "/{0}/vaults/{1}/jobs/{2}".format(account_id, vault_name, job_id) + headers["x-amz-job-id"] = job_id + headers["Location"] = "/{0}/vaults/{1}/jobs/{2}".format( + account_id, vault_name, job_id + ) return 202, headers, "" @classmethod def vault_jobs_individual_response(clazz, request, full_url, headers): region_name = region_from_glacier_url(full_url) response_instance = GlacierResponse(glacier_backends[region_name]) - return response_instance._vault_jobs_individual_response(request, full_url, headers) + return response_instance._vault_jobs_individual_response( + request, full_url, headers + ) def _vault_jobs_individual_response(self, request, full_url, headers): vault_name = full_url.split("/")[-3] @@ -176,10 +183,10 @@ class GlacierResponse(_TemplateEnvironmentMixin): if vault.job_ready(job_id): output = vault.get_job_output(job_id) if isinstance(output, dict): - headers['content-type'] = 'application/json' + headers["content-type"] = "application/json" return 200, headers, json.dumps(output) else: - headers['content-type'] = 'application/octet-stream' + headers["content-type"] = "application/octet-stream" return 200, headers, output else: return 404, headers, "404 Not Found" diff --git a/moto/glacier/urls.py b/moto/glacier/urls.py index 6038c2bb4..480b125af 100644 --- a/moto/glacier/urls.py +++ b/moto/glacier/urls.py @@ -1,16 +1,14 @@ from __future__ import unicode_literals from .responses import GlacierResponse -url_bases = [ - "https?://glacier.(.+).amazonaws.com", -] +url_bases = ["https?://glacier.(.+).amazonaws.com"] url_paths = { - '{0}/(?P.+)/vaults$': GlacierResponse.all_vault_response, - '{0}/(?P.+)/vaults/(?P[^/.]+)$': GlacierResponse.vault_response, - '{0}/(?P.+)/vaults/(?P.+)/archives$': GlacierResponse.vault_archive_response, - '{0}/(?P.+)/vaults/(?P.+)/archives/(?P.+)$': GlacierResponse.vault_archive_individual_response, - '{0}/(?P.+)/vaults/(?P.+)/jobs$': GlacierResponse.vault_jobs_response, - '{0}/(?P.+)/vaults/(?P.+)/jobs/(?P[^/.]+)$': GlacierResponse.vault_jobs_individual_response, - '{0}/(?P.+)/vaults/(?P.+)/jobs/(?P.+)/output$': GlacierResponse.vault_jobs_output_response, + "{0}/(?P.+)/vaults$": GlacierResponse.all_vault_response, + "{0}/(?P.+)/vaults/(?P[^/.]+)$": GlacierResponse.vault_response, + "{0}/(?P.+)/vaults/(?P.+)/archives$": GlacierResponse.vault_archive_response, + "{0}/(?P.+)/vaults/(?P.+)/archives/(?P.+)$": GlacierResponse.vault_archive_individual_response, + "{0}/(?P.+)/vaults/(?P.+)/jobs$": GlacierResponse.vault_jobs_response, + "{0}/(?P.+)/vaults/(?P.+)/jobs/(?P[^/.]+)$": GlacierResponse.vault_jobs_individual_response, + "{0}/(?P.+)/vaults/(?P.+)/jobs/(?P.+)/output$": GlacierResponse.vault_jobs_output_response, } diff --git a/moto/glacier/utils.py b/moto/glacier/utils.py index f4a869bf3..d6dd7c656 100644 --- a/moto/glacier/utils.py +++ b/moto/glacier/utils.py @@ -7,10 +7,10 @@ from six.moves.urllib.parse import urlparse def region_from_glacier_url(url): domain = urlparse(url).netloc - if '.' in domain: + if "." in domain: return domain.split(".")[1] else: - return 'us-east-1' + return "us-east-1" def vault_from_glacier_url(full_url): @@ -18,4 +18,6 @@ def vault_from_glacier_url(full_url): def get_job_id(): - return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(92)) + return "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(92) + ) diff --git a/moto/glue/exceptions.py b/moto/glue/exceptions.py index 8972adb35..c4b7048db 100644 --- a/moto/glue/exceptions.py +++ b/moto/glue/exceptions.py @@ -9,46 +9,38 @@ class GlueClientError(JsonRESTError): class AlreadyExistsException(GlueClientError): def __init__(self, typ): super(GlueClientError, self).__init__( - 'AlreadyExistsException', - '%s already exists.' % (typ), + "AlreadyExistsException", "%s already exists." % (typ) ) class DatabaseAlreadyExistsException(AlreadyExistsException): def __init__(self): - super(DatabaseAlreadyExistsException, self).__init__('Database') + super(DatabaseAlreadyExistsException, self).__init__("Database") class TableAlreadyExistsException(AlreadyExistsException): def __init__(self): - super(TableAlreadyExistsException, self).__init__('Table') + super(TableAlreadyExistsException, self).__init__("Table") class PartitionAlreadyExistsException(AlreadyExistsException): def __init__(self): - super(PartitionAlreadyExistsException, self).__init__('Partition') + super(PartitionAlreadyExistsException, self).__init__("Partition") class EntityNotFoundException(GlueClientError): def __init__(self, msg): - super(GlueClientError, self).__init__( - 'EntityNotFoundException', - msg, - ) + super(GlueClientError, self).__init__("EntityNotFoundException", msg) class DatabaseNotFoundException(EntityNotFoundException): def __init__(self, db): - super(DatabaseNotFoundException, self).__init__( - 'Database %s not found.' % db, - ) + super(DatabaseNotFoundException, self).__init__("Database %s not found." % db) class TableNotFoundException(EntityNotFoundException): def __init__(self, tbl): - super(TableNotFoundException, self).__init__( - 'Table %s not found.' % tbl, - ) + super(TableNotFoundException, self).__init__("Table %s not found." % tbl) class PartitionNotFoundException(EntityNotFoundException): diff --git a/moto/glue/models.py b/moto/glue/models.py index 0989e0e9b..8f3396d9a 100644 --- a/moto/glue/models.py +++ b/moto/glue/models.py @@ -4,7 +4,7 @@ import time from moto.core import BaseBackend, BaseModel from moto.compat import OrderedDict -from.exceptions import ( +from .exceptions import ( JsonRESTError, DatabaseAlreadyExistsException, DatabaseNotFoundException, @@ -17,7 +17,6 @@ from.exceptions import ( class GlueBackend(BaseBackend): - def __init__(self): self.databases = OrderedDict() @@ -66,14 +65,12 @@ class GlueBackend(BaseBackend): class FakeDatabase(BaseModel): - def __init__(self, database_name): self.name = database_name self.tables = OrderedDict() class FakeTable(BaseModel): - def __init__(self, database_name, table_name, table_input): self.database_name = database_name self.name = table_name @@ -98,10 +95,7 @@ class FakeTable(BaseModel): raise VersionNotFoundException() def as_dict(self, version=-1): - obj = { - 'DatabaseName': self.database_name, - 'Name': self.name, - } + obj = {"DatabaseName": self.database_name, "Name": self.name} obj.update(self.get_version(version)) return obj @@ -124,7 +118,7 @@ class FakeTable(BaseModel): def update_partition(self, old_values, partiton_input): partition = FakePartition(self.database_name, self.name, partiton_input) key = str(partition.values) - if old_values == partiton_input['Values']: + if old_values == partiton_input["Values"]: # Altering a partition in place. Don't remove it so the order of # returned partitions doesn't change if key not in self.partitions: @@ -151,13 +145,13 @@ class FakePartition(BaseModel): self.database_name = database_name self.table_name = table_name self.partition_input = partiton_input - self.values = self.partition_input.get('Values', []) + self.values = self.partition_input.get("Values", []) def as_dict(self): obj = { - 'DatabaseName': self.database_name, - 'TableName': self.table_name, - 'CreationTime': self.creation_time, + "DatabaseName": self.database_name, + "TableName": self.table_name, + "CreationTime": self.creation_time, } obj.update(self.partition_input) return obj diff --git a/moto/glue/responses.py b/moto/glue/responses.py index 875513e7f..bf7b5776b 100644 --- a/moto/glue/responses.py +++ b/moto/glue/responses.py @@ -7,12 +7,11 @@ from .models import glue_backend from .exceptions import ( PartitionAlreadyExistsException, PartitionNotFoundException, - TableNotFoundException + TableNotFoundException, ) class GlueResponse(BaseResponse): - @property def glue_backend(self): return glue_backend @@ -22,94 +21,94 @@ class GlueResponse(BaseResponse): return json.loads(self.body) def create_database(self): - database_name = self.parameters['DatabaseInput']['Name'] + database_name = self.parameters["DatabaseInput"]["Name"] self.glue_backend.create_database(database_name) return "" def get_database(self): - database_name = self.parameters.get('Name') + database_name = self.parameters.get("Name") database = self.glue_backend.get_database(database_name) - return json.dumps({'Database': {'Name': database.name}}) + return json.dumps({"Database": {"Name": database.name}}) def create_table(self): - database_name = self.parameters.get('DatabaseName') - table_input = self.parameters.get('TableInput') - table_name = table_input.get('Name') + database_name = self.parameters.get("DatabaseName") + table_input = self.parameters.get("TableInput") + table_name = table_input.get("Name") self.glue_backend.create_table(database_name, table_name, table_input) return "" def get_table(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('Name') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("Name") table = self.glue_backend.get_table(database_name, table_name) - return json.dumps({'Table': table.as_dict()}) + return json.dumps({"Table": table.as_dict()}) def update_table(self): - database_name = self.parameters.get('DatabaseName') - table_input = self.parameters.get('TableInput') - table_name = table_input.get('Name') + database_name = self.parameters.get("DatabaseName") + table_input = self.parameters.get("TableInput") + table_name = table_input.get("Name") table = self.glue_backend.get_table(database_name, table_name) table.update(table_input) return "" def get_table_versions(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") table = self.glue_backend.get_table(database_name, table_name) - return json.dumps({ - "TableVersions": [ - { - "Table": table.as_dict(version=n), - "VersionId": str(n + 1), - } for n in range(len(table.versions)) - ], - }) + return json.dumps( + { + "TableVersions": [ + {"Table": table.as_dict(version=n), "VersionId": str(n + 1)} + for n in range(len(table.versions)) + ] + } + ) def get_table_version(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") table = self.glue_backend.get_table(database_name, table_name) - ver_id = self.parameters.get('VersionId') + ver_id = self.parameters.get("VersionId") - return json.dumps({ - "TableVersion": { - "Table": table.as_dict(version=ver_id), - "VersionId": ver_id, - }, - }) + return json.dumps( + { + "TableVersion": { + "Table": table.as_dict(version=ver_id), + "VersionId": ver_id, + } + } + ) def get_tables(self): - database_name = self.parameters.get('DatabaseName') + database_name = self.parameters.get("DatabaseName") tables = self.glue_backend.get_tables(database_name) - return json.dumps({ - 'TableList': [ - table.as_dict() for table in tables - ] - }) + return json.dumps({"TableList": [table.as_dict() for table in tables]}) def delete_table(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('Name') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("Name") resp = self.glue_backend.delete_table(database_name, table_name) return json.dumps(resp) def batch_delete_table(self): - database_name = self.parameters.get('DatabaseName') + database_name = self.parameters.get("DatabaseName") errors = [] - for table_name in self.parameters.get('TablesToDelete'): + for table_name in self.parameters.get("TablesToDelete"): try: self.glue_backend.delete_table(database_name, table_name) except TableNotFoundException: - errors.append({ - "TableName": table_name, - "ErrorDetail": { - "ErrorCode": "EntityNotFoundException", - "ErrorMessage": "Table not found" + errors.append( + { + "TableName": table_name, + "ErrorDetail": { + "ErrorCode": "EntityNotFoundException", + "ErrorMessage": "Table not found", + }, } - }) + ) out = {} if errors: @@ -118,33 +117,31 @@ class GlueResponse(BaseResponse): return json.dumps(out) def get_partitions(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - if 'Expression' in self.parameters: - raise NotImplementedError("Expression filtering in get_partitions is not implemented in moto") + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + if "Expression" in self.parameters: + raise NotImplementedError( + "Expression filtering in get_partitions is not implemented in moto" + ) table = self.glue_backend.get_table(database_name, table_name) - return json.dumps({ - 'Partitions': [ - p.as_dict() for p in table.get_partitions() - ] - }) + return json.dumps({"Partitions": [p.as_dict() for p in table.get_partitions()]}) def get_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - values = self.parameters.get('PartitionValues') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + values = self.parameters.get("PartitionValues") table = self.glue_backend.get_table(database_name, table_name) p = table.get_partition(values) - return json.dumps({'Partition': p.as_dict()}) + return json.dumps({"Partition": p.as_dict()}) def batch_get_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - partitions_to_get = self.parameters.get('PartitionsToGet') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + partitions_to_get = self.parameters.get("PartitionsToGet") table = self.glue_backend.get_table(database_name, table_name) @@ -156,12 +153,12 @@ class GlueResponse(BaseResponse): except PartitionNotFoundException: continue - return json.dumps({'Partitions': partitions}) + return json.dumps({"Partitions": partitions}) def create_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - part_input = self.parameters.get('PartitionInput') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + part_input = self.parameters.get("PartitionInput") table = self.glue_backend.get_table(database_name, table_name) table.create_partition(part_input) @@ -169,22 +166,24 @@ class GlueResponse(BaseResponse): return "" def batch_create_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") table = self.glue_backend.get_table(database_name, table_name) errors_output = [] - for part_input in self.parameters.get('PartitionInputList'): + for part_input in self.parameters.get("PartitionInputList"): try: table.create_partition(part_input) except PartitionAlreadyExistsException: - errors_output.append({ - 'PartitionValues': part_input['Values'], - 'ErrorDetail': { - 'ErrorCode': 'AlreadyExistsException', - 'ErrorMessage': 'Partition already exists.' + errors_output.append( + { + "PartitionValues": part_input["Values"], + "ErrorDetail": { + "ErrorCode": "AlreadyExistsException", + "ErrorMessage": "Partition already exists.", + }, } - }) + ) out = {} if errors_output: @@ -193,10 +192,10 @@ class GlueResponse(BaseResponse): return json.dumps(out) def update_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - part_input = self.parameters.get('PartitionInput') - part_to_update = self.parameters.get('PartitionValueList') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + part_input = self.parameters.get("PartitionInput") + part_to_update = self.parameters.get("PartitionValueList") table = self.glue_backend.get_table(database_name, table_name) table.update_partition(part_to_update, part_input) @@ -204,9 +203,9 @@ class GlueResponse(BaseResponse): return "" def delete_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') - part_to_delete = self.parameters.get('PartitionValues') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") + part_to_delete = self.parameters.get("PartitionValues") table = self.glue_backend.get_table(database_name, table_name) table.delete_partition(part_to_delete) @@ -214,26 +213,28 @@ class GlueResponse(BaseResponse): return "" def batch_delete_partition(self): - database_name = self.parameters.get('DatabaseName') - table_name = self.parameters.get('TableName') + database_name = self.parameters.get("DatabaseName") + table_name = self.parameters.get("TableName") table = self.glue_backend.get_table(database_name, table_name) errors_output = [] - for part_input in self.parameters.get('PartitionsToDelete'): - values = part_input.get('Values') + for part_input in self.parameters.get("PartitionsToDelete"): + values = part_input.get("Values") try: table.delete_partition(values) except PartitionNotFoundException: - errors_output.append({ - 'PartitionValues': values, - 'ErrorDetail': { - 'ErrorCode': 'EntityNotFoundException', - 'ErrorMessage': 'Partition not found', + errors_output.append( + { + "PartitionValues": values, + "ErrorDetail": { + "ErrorCode": "EntityNotFoundException", + "ErrorMessage": "Partition not found", + }, } - }) + ) out = {} if errors_output: - out['Errors'] = errors_output + out["Errors"] = errors_output return json.dumps(out) diff --git a/moto/glue/urls.py b/moto/glue/urls.py index f3eaa9cad..2c7854732 100644 --- a/moto/glue/urls.py +++ b/moto/glue/urls.py @@ -2,10 +2,6 @@ from __future__ import unicode_literals from .responses import GlueResponse -url_bases = [ - "https?://glue(.*).amazonaws.com" -] +url_bases = ["https?://glue(.*).amazonaws.com"] -url_paths = { - '{0}/$': GlueResponse.dispatch -} +url_paths = {"{0}/$": GlueResponse.dispatch} diff --git a/moto/iam/exceptions.py b/moto/iam/exceptions.py index afd1373a3..b9b0176e0 100644 --- a/moto/iam/exceptions.py +++ b/moto/iam/exceptions.py @@ -6,32 +6,28 @@ class IAMNotFoundException(RESTError): code = 404 def __init__(self, message): - super(IAMNotFoundException, self).__init__( - "NoSuchEntity", message) + super(IAMNotFoundException, self).__init__("NoSuchEntity", message) class IAMConflictException(RESTError): code = 409 - def __init__(self, code='Conflict', message=""): - super(IAMConflictException, self).__init__( - code, message) + def __init__(self, code="Conflict", message=""): + super(IAMConflictException, self).__init__(code, message) class IAMReportNotPresentException(RESTError): code = 410 def __init__(self, message): - super(IAMReportNotPresentException, self).__init__( - "ReportNotPresent", message) + super(IAMReportNotPresentException, self).__init__("ReportNotPresent", message) class IAMLimitExceededException(RESTError): code = 400 def __init__(self, message): - super(IAMLimitExceededException, self).__init__( - "LimitExceeded", message) + super(IAMLimitExceededException, self).__init__("LimitExceeded", message) class MalformedCertificate(RESTError): @@ -39,7 +35,8 @@ class MalformedCertificate(RESTError): def __init__(self, cert): super(MalformedCertificate, self).__init__( - 'MalformedCertificate', 'Certificate {cert} is malformed'.format(cert=cert)) + "MalformedCertificate", "Certificate {cert} is malformed".format(cert=cert) + ) class MalformedPolicyDocument(RESTError): @@ -47,7 +44,8 @@ class MalformedPolicyDocument(RESTError): def __init__(self, message=""): super(MalformedPolicyDocument, self).__init__( - 'MalformedPolicyDocument', message) + "MalformedPolicyDocument", message + ) class DuplicateTags(RESTError): @@ -55,16 +53,22 @@ class DuplicateTags(RESTError): def __init__(self): super(DuplicateTags, self).__init__( - 'InvalidInput', 'Duplicate tag keys found. Please note that Tag keys are case insensitive.') + "InvalidInput", + "Duplicate tag keys found. Please note that Tag keys are case insensitive.", + ) class TagKeyTooBig(RESTError): code = 400 - def __init__(self, tag, param='tags.X.member.key'): + def __init__(self, tag, param="tags.X.member.key"): super(TagKeyTooBig, self).__init__( - 'ValidationError', "1 validation error detected: Value '{}' at '{}' failed to satisfy " - "constraint: Member must have length less than or equal to 128.".format(tag, param)) + "ValidationError", + "1 validation error detected: Value '{}' at '{}' failed to satisfy " + "constraint: Member must have length less than or equal to 128.".format( + tag, param + ), + ) class TagValueTooBig(RESTError): @@ -72,48 +76,55 @@ class TagValueTooBig(RESTError): def __init__(self, tag): super(TagValueTooBig, self).__init__( - 'ValidationError', "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy " - "constraint: Member must have length less than or equal to 256.".format(tag)) + "ValidationError", + "1 validation error detected: Value '{}' at 'tags.X.member.value' failed to satisfy " + "constraint: Member must have length less than or equal to 256.".format( + tag + ), + ) class InvalidTagCharacters(RESTError): code = 400 - def __init__(self, tag, param='tags.X.member.key'): - message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format(tag, param) + def __init__(self, tag, param="tags.X.member.key"): + message = "1 validation error detected: Value '{}' at '{}' failed to satisfy ".format( + tag, param + ) message += "constraint: Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+" - super(InvalidTagCharacters, self).__init__('ValidationError', message) + super(InvalidTagCharacters, self).__init__("ValidationError", message) class TooManyTags(RESTError): code = 400 - def __init__(self, tags, param='tags'): + def __init__(self, tags, param="tags"): super(TooManyTags, self).__init__( - 'ValidationError', "1 validation error detected: Value '{}' at '{}' failed to satisfy " - "constraint: Member must have length less than or equal to 50.".format(tags, param)) + "ValidationError", + "1 validation error detected: Value '{}' at '{}' failed to satisfy " + "constraint: Member must have length less than or equal to 50.".format( + tags, param + ), + ) class EntityAlreadyExists(RESTError): code = 409 def __init__(self, message): - super(EntityAlreadyExists, self).__init__( - 'EntityAlreadyExists', message) + super(EntityAlreadyExists, self).__init__("EntityAlreadyExists", message) class ValidationError(RESTError): code = 400 def __init__(self, message): - super(ValidationError, self).__init__( - 'ValidationError', message) + super(ValidationError, self).__init__("ValidationError", message) class InvalidInput(RESTError): code = 400 def __init__(self, message): - super(InvalidInput, self).__init__( - 'InvalidInput', message) + super(InvalidInput, self).__init__("InvalidInput", message) diff --git a/moto/iam/models.py b/moto/iam/models.py index 741b71142..4a115999c 100644 --- a/moto/iam/models.py +++ b/moto/iam/models.py @@ -14,14 +14,34 @@ from six.moves.urllib.parse import urlparse from moto.core.exceptions import RESTError from moto.core import BaseBackend, BaseModel -from moto.core.utils import iso_8601_datetime_without_milliseconds, iso_8601_datetime_with_milliseconds +from moto.core.utils import ( + iso_8601_datetime_without_milliseconds, + iso_8601_datetime_with_milliseconds, +) from moto.iam.policy_validation import IAMPolicyDocumentValidator from .aws_managed_policies import aws_managed_policies_data -from .exceptions import (IAMNotFoundException, IAMConflictException, IAMReportNotPresentException, IAMLimitExceededException, - MalformedCertificate, DuplicateTags, TagKeyTooBig, InvalidTagCharacters, TooManyTags, TagValueTooBig, - EntityAlreadyExists, ValidationError, InvalidInput) -from .utils import random_access_key, random_alphanumeric, random_resource_id, random_policy_id +from .exceptions import ( + IAMNotFoundException, + IAMConflictException, + IAMReportNotPresentException, + IAMLimitExceededException, + MalformedCertificate, + DuplicateTags, + TagKeyTooBig, + InvalidTagCharacters, + TooManyTags, + TagValueTooBig, + EntityAlreadyExists, + ValidationError, + InvalidInput, +) +from .utils import ( + random_access_key, + random_alphanumeric, + random_resource_id, + random_policy_id, +) ACCOUNT_ID = 123456789012 @@ -29,10 +49,7 @@ ACCOUNT_ID = 123456789012 class MFADevice(object): """MFA Device class.""" - def __init__(self, - serial_number, - authentication_code_1, - authentication_code_2): + def __init__(self, serial_number, authentication_code_1, authentication_code_2): self.enable_date = datetime.utcnow() self.serial_number = serial_number self.authentication_code_1 = authentication_code_1 @@ -45,11 +62,17 @@ class MFADevice(object): class VirtualMfaDevice(object): def __init__(self, device_name): - self.serial_number = 'arn:aws:iam::{0}:mfa{1}'.format(ACCOUNT_ID, device_name) + self.serial_number = "arn:aws:iam::{0}:mfa{1}".format(ACCOUNT_ID, device_name) - random_base32_string = ''.join(random.choice(string.ascii_uppercase + '234567') for _ in range(64)) - self.base32_string_seed = base64.b64encode(random_base32_string.encode('ascii')).decode('ascii') - self.qr_code_png = base64.b64encode(os.urandom(64)) # this would be a generated PNG + random_base32_string = "".join( + random.choice(string.ascii_uppercase + "234567") for _ in range(64) + ) + self.base32_string_seed = base64.b64encode( + random_base32_string.encode("ascii") + ).decode("ascii") + self.qr_code_png = base64.b64encode( + os.urandom(64) + ) # this would be a generated PNG self.enable_date = None self.user_attribute = None @@ -63,28 +86,34 @@ class VirtualMfaDevice(object): class Policy(BaseModel): is_attachable = False - def __init__(self, - name, - default_version_id=None, - description=None, - document=None, - path=None, - create_date=None, - update_date=None): + def __init__( + self, + name, + default_version_id=None, + description=None, + document=None, + path=None, + create_date=None, + update_date=None, + ): self.name = name self.attachment_count = 0 - self.description = description or '' + self.description = description or "" self.id = random_policy_id() - self.path = path or '/' + self.path = path or "/" if default_version_id: self.default_version_id = default_version_id - self.next_version_num = int(default_version_id.lstrip('v')) + 1 + self.next_version_num = int(default_version_id.lstrip("v")) + 1 else: - self.default_version_id = 'v1' + self.default_version_id = "v1" self.next_version_num = 2 - self.versions = [PolicyVersion(self.arn, document, True, self.default_version_id, update_date)] + self.versions = [ + PolicyVersion( + self.arn, document, True, self.default_version_id, update_date + ) + ] self.create_date = create_date if create_date is not None else datetime.utcnow() self.update_date = update_date if update_date is not None else datetime.utcnow() @@ -128,7 +157,7 @@ class OpenIDConnectProvider(BaseModel): @property def arn(self): - return 'arn:aws:iam::{0}:oidc-provider/{1}'.format(ACCOUNT_ID, self.url) + return "arn:aws:iam::{0}:oidc-provider/{1}".format(ACCOUNT_ID, self.url) @property def created_iso_8601(self): @@ -136,47 +165,53 @@ class OpenIDConnectProvider(BaseModel): def _validate(self, url, thumbprint_list, client_id_list): if any(len(client_id) > 255 for client_id in client_id_list): - self._errors.append(self._format_error( - key='clientIDList', - value=client_id_list, - constraint='Member must satisfy constraint: ' - '[Member must have length less than or equal to 255, ' - 'Member must have length greater than or equal to 1]', - )) + self._errors.append( + self._format_error( + key="clientIDList", + value=client_id_list, + constraint="Member must satisfy constraint: " + "[Member must have length less than or equal to 255, " + "Member must have length greater than or equal to 1]", + ) + ) if any(len(thumbprint) > 40 for thumbprint in thumbprint_list): - self._errors.append(self._format_error( - key='thumbprintList', - value=thumbprint_list, - constraint='Member must satisfy constraint: ' - '[Member must have length less than or equal to 40, ' - 'Member must have length greater than or equal to 40]', - )) + self._errors.append( + self._format_error( + key="thumbprintList", + value=thumbprint_list, + constraint="Member must satisfy constraint: " + "[Member must have length less than or equal to 40, " + "Member must have length greater than or equal to 40]", + ) + ) if len(url) > 255: - self._errors.append(self._format_error( - key='url', - value=url, - constraint='Member must have length less than or equal to 255', - )) + self._errors.append( + self._format_error( + key="url", + value=url, + constraint="Member must have length less than or equal to 255", + ) + ) self._raise_errors() parsed_url = urlparse(url) if not parsed_url.scheme or not parsed_url.netloc: - raise ValidationError('Invalid Open ID Connect Provider URL') + raise ValidationError("Invalid Open ID Connect Provider URL") if len(thumbprint_list) > 5: - raise InvalidInput('Thumbprint list must contain fewer than 5 entries.') + raise InvalidInput("Thumbprint list must contain fewer than 5 entries.") if len(client_id_list) > 100: - raise IAMLimitExceededException('Cannot exceed quota for ClientIdsPerOpenIdConnectProvider: 100') + raise IAMLimitExceededException( + "Cannot exceed quota for ClientIdsPerOpenIdConnectProvider: 100" + ) def _format_error(self, key, value, constraint): return 'Value "{value}" at "{key}" failed to satisfy constraint: {constraint}'.format( - constraint=constraint, - key=key, - value=value, + constraint=constraint, key=key, value=value ) def _raise_errors(self): @@ -186,19 +221,17 @@ class OpenIDConnectProvider(BaseModel): errors = "; ".join(self._errors) self._errors = [] # reset collected errors - raise ValidationError('{count} validation error{plural} detected: {errors}'.format( - count=count, plural=plural, errors=errors, - )) + raise ValidationError( + "{count} validation error{plural} detected: {errors}".format( + count=count, plural=plural, errors=errors + ) + ) class PolicyVersion(object): - - def __init__(self, - policy_arn, - document, - is_default=False, - version_id='v1', - create_date=None): + def __init__( + self, policy_arn, document, is_default=False, version_id="v1", create_date=None + ): self.policy_arn = policy_arn self.document = document or {} self.is_default = is_default @@ -234,23 +267,30 @@ class AWSManagedPolicy(ManagedPolicy): @classmethod def from_data(cls, name, data): - return cls(name, - default_version_id=data.get('DefaultVersionId'), - path=data.get('Path'), - document=json.dumps(data.get('Document')), - create_date=datetime.strptime(data.get('CreateDate'), "%Y-%m-%dT%H:%M:%S+00:00"), - update_date=datetime.strptime(data.get('UpdateDate'), "%Y-%m-%dT%H:%M:%S+00:00")) + return cls( + name, + default_version_id=data.get("DefaultVersionId"), + path=data.get("Path"), + document=json.dumps(data.get("Document")), + create_date=datetime.strptime( + data.get("CreateDate"), "%Y-%m-%dT%H:%M:%S+00:00" + ), + update_date=datetime.strptime( + data.get("UpdateDate"), "%Y-%m-%dT%H:%M:%S+00:00" + ), + ) @property def arn(self): - return 'arn:aws:iam::aws:policy{0}{1}'.format(self.path, self.name) + return "arn:aws:iam::aws:policy{0}{1}".format(self.path, self.name) # AWS defines some of its own managed policies and we periodically # import them via `make aws_managed_policies` aws_managed_policies = [ - AWSManagedPolicy.from_data(name, d) for name, d - in json.loads(aws_managed_policies_data).items()] + AWSManagedPolicy.from_data(name, d) + for name, d in json.loads(aws_managed_policies_data).items() +] class InlinePolicy(Policy): @@ -258,12 +298,20 @@ class InlinePolicy(Policy): class Role(BaseModel): - - def __init__(self, role_id, name, assume_role_policy_document, path, permissions_boundary, description, tags): + def __init__( + self, + role_id, + name, + assume_role_policy_document, + path, + permissions_boundary, + description, + tags, + ): self.id = role_id self.name = name self.assume_role_policy_document = assume_role_policy_document - self.path = path or '/' + self.path = path or "/" self.policies = {} self.managed_policies = {} self.create_date = datetime.utcnow() @@ -276,22 +324,24 @@ class Role(BaseModel): return iso_8601_datetime_with_milliseconds(self.create_date) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] role = iam_backend.create_role( role_name=resource_name, - assume_role_policy_document=properties['AssumeRolePolicyDocument'], - path=properties.get('Path', '/'), - permissions_boundary=properties.get('PermissionsBoundary', ''), - description=properties.get('Description', ''), - tags=properties.get('Tags', {}) + assume_role_policy_document=properties["AssumeRolePolicyDocument"], + path=properties.get("Path", "/"), + permissions_boundary=properties.get("PermissionsBoundary", ""), + description=properties.get("Description", ""), + tags=properties.get("Tags", {}), ) - policies = properties.get('Policies', []) + policies = properties.get("Policies", []) for policy in policies: - policy_name = policy['PolicyName'] - policy_json = policy['PolicyDocument'] + policy_name = policy["PolicyName"] + policy_json = policy["PolicyDocument"] role.put_policy(policy_name, policy_json) return role @@ -308,7 +358,8 @@ class Role(BaseModel): del self.policies[policy_name] except KeyError: raise IAMNotFoundException( - "The role policy with name {0} cannot be found.".format(policy_name)) + "The role policy with name {0} cannot be found.".format(policy_name) + ) @property def physical_resource_id(self): @@ -316,7 +367,8 @@ class Role(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "Arn" ]"') raise UnformattedGetAttTemplateException() @@ -325,11 +377,10 @@ class Role(BaseModel): class InstanceProfile(BaseModel): - def __init__(self, instance_profile_id, name, path, roles): self.id = instance_profile_id self.name = name - self.path = path or '/' + self.path = path or "/" self.roles = roles if roles else [] self.create_date = datetime.utcnow() @@ -338,19 +389,21 @@ class InstanceProfile(BaseModel): return iso_8601_datetime_with_milliseconds(self.create_date) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - role_ids = properties['Roles'] + role_ids = properties["Roles"] return iam_backend.create_instance_profile( - name=resource_name, - path=properties.get('Path', '/'), - role_ids=role_ids, + name=resource_name, path=properties.get("Path", "/"), role_ids=role_ids ) @property def arn(self): - return "arn:aws:iam::{0}:instance-profile{1}{2}".format(ACCOUNT_ID, self.path, self.name) + return "arn:aws:iam::{0}:instance-profile{1}{2}".format( + ACCOUNT_ID, self.path, self.name + ) @property def physical_resource_id(self): @@ -358,13 +411,13 @@ class InstanceProfile(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": return self.arn raise UnformattedGetAttTemplateException() class Certificate(BaseModel): - def __init__(self, cert_name, cert_body, private_key, cert_chain=None, path=None): self.cert_name = cert_name self.cert_body = cert_body @@ -378,17 +431,18 @@ class Certificate(BaseModel): @property def arn(self): - return "arn:aws:iam::{0}:server-certificate{1}{2}".format(ACCOUNT_ID, self.path, self.cert_name) + return "arn:aws:iam::{0}:server-certificate{1}{2}".format( + ACCOUNT_ID, self.path, self.cert_name + ) class SigningCertificate(BaseModel): - def __init__(self, id, user_name, body): self.id = id self.user_name = user_name self.body = body self.upload_date = datetime.utcnow() - self.status = 'Active' + self.status = "Active" @property def uploaded_iso_8601(self): @@ -396,12 +450,11 @@ class SigningCertificate(BaseModel): class AccessKey(BaseModel): - def __init__(self, user_name): self.user_name = user_name self.access_key_id = "AKIA" + random_access_key() self.secret_access_key = random_alphanumeric(40) - self.status = 'Active' + self.status = "Active" self.create_date = datetime.utcnow() self.last_used = datetime.utcnow() @@ -415,14 +468,14 @@ class AccessKey(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'SecretAccessKey': + + if attribute_name == "SecretAccessKey": return self.secret_access_key raise UnformattedGetAttTemplateException() class Group(BaseModel): - - def __init__(self, name, path='/'): + def __init__(self, name, path="/"): self.name = name self.id = random_resource_id() self.path = path @@ -438,17 +491,20 @@ class Group(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "Arn" ]"') raise UnformattedGetAttTemplateException() @property def arn(self): - if self.path == '/': + if self.path == "/": return "arn:aws:iam::{0}:group/{1}".format(ACCOUNT_ID, self.name) else: - return "arn:aws:iam::{0}:group/{1}/{2}".format(ACCOUNT_ID, self.path, self.name) + return "arn:aws:iam::{0}:group/{1}/{2}".format( + ACCOUNT_ID, self.path, self.name + ) def get_policy(self, policy_name): try: @@ -457,9 +513,9 @@ class Group(BaseModel): raise IAMNotFoundException("Policy {0} not found".format(policy_name)) return { - 'policy_name': policy_name, - 'policy_document': policy_json, - 'group_name': self.name, + "policy_name": policy_name, + "policy_document": policy_json, + "group_name": self.name, } def put_policy(self, policy_name, policy_json): @@ -470,7 +526,6 @@ class Group(BaseModel): class User(BaseModel): - def __init__(self, name, path=None): self.name = name self.id = random_resource_id() @@ -497,13 +552,12 @@ class User(BaseModel): try: policy_json = self.policies[policy_name] except KeyError: - raise IAMNotFoundException( - "Policy {0} not found".format(policy_name)) + raise IAMNotFoundException("Policy {0} not found".format(policy_name)) return { - 'policy_name': policy_name, - 'policy_document': policy_json, - 'user_name': self.name, + "policy_name": policy_name, + "policy_document": policy_json, + "user_name": self.name, } def put_policy(self, policy_name, policy_json): @@ -514,8 +568,7 @@ class User(BaseModel): def delete_policy(self, policy_name): if policy_name not in self.policies: - raise IAMNotFoundException( - "Policy {0} not found".format(policy_name)) + raise IAMNotFoundException("Policy {0} not found".format(policy_name)) del self.policies[policy_name] @@ -524,14 +577,11 @@ class User(BaseModel): self.access_keys.append(access_key) return access_key - def enable_mfa_device(self, - serial_number, - authentication_code_1, - authentication_code_2): + def enable_mfa_device( + self, serial_number, authentication_code_1, authentication_code_2 + ): self.mfa_devices[serial_number] = MFADevice( - serial_number, - authentication_code_1, - authentication_code_2 + serial_number, authentication_code_1, authentication_code_2 ) def get_all_access_keys(self): @@ -551,58 +601,58 @@ class User(BaseModel): return key else: raise IAMNotFoundException( - "The Access Key with id {0} cannot be found".format(access_key_id)) + "The Access Key with id {0} cannot be found".format(access_key_id) + ) def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "Arn" ]"') raise UnformattedGetAttTemplateException() def to_csv(self): - date_format = '%Y-%m-%dT%H:%M:%S+00:00' + date_format = "%Y-%m-%dT%H:%M:%S+00:00" date_created = self.create_date # aagrawal,arn:aws:iam::509284790694:user/aagrawal,2014-09-01T22:28:48+00:00,true,2014-11-12T23:36:49+00:00,2014-09-03T18:59:00+00:00,N/A,false,true,2014-09-01T22:28:48+00:00,false,N/A,false,N/A,false,N/A if not self.password: - password_enabled = 'false' - password_last_used = 'not_supported' + password_enabled = "false" + password_last_used = "not_supported" else: - password_enabled = 'true' - password_last_used = 'no_information' + password_enabled = "true" + password_last_used = "no_information" if len(self.access_keys) == 0: - access_key_1_active = 'false' - access_key_1_last_rotated = 'N/A' - access_key_2_active = 'false' - access_key_2_last_rotated = 'N/A' + access_key_1_active = "false" + access_key_1_last_rotated = "N/A" + access_key_2_active = "false" + access_key_2_last_rotated = "N/A" elif len(self.access_keys) == 1: - access_key_1_active = 'true' + access_key_1_active = "true" access_key_1_last_rotated = date_created.strftime(date_format) - access_key_2_active = 'false' - access_key_2_last_rotated = 'N/A' + access_key_2_active = "false" + access_key_2_last_rotated = "N/A" else: - access_key_1_active = 'true' + access_key_1_active = "true" access_key_1_last_rotated = date_created.strftime(date_format) - access_key_2_active = 'true' + access_key_2_active = "true" access_key_2_last_rotated = date_created.strftime(date_format) - return '{0},{1},{2},{3},{4},{5},not_supported,false,{6},{7},{8},{9},false,N/A,false,N/A'.format(self.name, - self.arn, - date_created.strftime( - date_format), - password_enabled, - password_last_used, - date_created.strftime( - date_format), - access_key_1_active, - access_key_1_last_rotated, - access_key_2_active, - access_key_2_last_rotated - ) + return "{0},{1},{2},{3},{4},{5},not_supported,false,{6},{7},{8},{9},false,N/A,false,N/A".format( + self.name, + self.arn, + date_created.strftime(date_format), + password_enabled, + password_last_used, + date_created.strftime(date_format), + access_key_1_active, + access_key_1_last_rotated, + access_key_2_active, + access_key_2_last_rotated, + ) class IAMBackend(BaseBackend): - def __init__(self): self.instance_profiles = {} self.roles = {} @@ -614,8 +664,7 @@ class IAMBackend(BaseBackend): self.account_aliases = [] self.saml_providers = {} self.open_id_providers = {} - self.policy_arn_regex = re.compile( - r'^arn:aws:iam::[0-9]*:policy/.*$') + self.policy_arn_regex = re.compile(r"^arn:aws:iam::[0-9]*:policy/.*$") self.virtual_mfa_devices = {} super(IAMBackend, self).__init__() @@ -682,10 +731,7 @@ class IAMBackend(BaseBackend): iam_policy_document_validator.validate() policy = ManagedPolicy( - policy_name, - description=description, - document=policy_document, - path=path, + policy_name, description=description, document=policy_document, path=path ) self.managed_policies[policy.arn] = policy return policy @@ -695,15 +741,21 @@ class IAMBackend(BaseBackend): raise IAMNotFoundException("Policy {0} not found".format(policy_arn)) return self.managed_policies.get(policy_arn) - def list_attached_role_policies(self, role_name, marker=None, max_items=100, path_prefix='/'): + def list_attached_role_policies( + self, role_name, marker=None, max_items=100, path_prefix="/" + ): policies = self.get_role(role_name).managed_policies.values() return self._filter_attached_policies(policies, marker, max_items, path_prefix) - def list_attached_group_policies(self, group_name, marker=None, max_items=100, path_prefix='/'): + def list_attached_group_policies( + self, group_name, marker=None, max_items=100, path_prefix="/" + ): policies = self.get_group(group_name).managed_policies.values() return self._filter_attached_policies(policies, marker, max_items, path_prefix) - def list_attached_user_policies(self, user_name, marker=None, max_items=100, path_prefix='/'): + def list_attached_user_policies( + self, user_name, marker=None, max_items=100, path_prefix="/" + ): policies = self.get_user(user_name).managed_policies.values() return self._filter_attached_policies(policies, marker, max_items, path_prefix) @@ -713,11 +765,10 @@ class IAMBackend(BaseBackend): if only_attached: policies = [p for p in policies if p.attachment_count > 0] - if scope == 'AWS': + if scope == "AWS": policies = [p for p in policies if isinstance(p, AWSManagedPolicy)] - elif scope == 'Local': - policies = [p for p in policies if not isinstance( - p, AWSManagedPolicy)] + elif scope == "Local": + policies = [p for p in policies if not isinstance(p, AWSManagedPolicy)] return self._filter_attached_policies(policies, marker, max_items, path_prefix) @@ -728,7 +779,7 @@ class IAMBackend(BaseBackend): policies = sorted(policies, key=lambda policy: policy.name) start_idx = int(marker) if marker else 0 - policies = policies[start_idx:start_idx + max_items] + policies = policies[start_idx : start_idx + max_items] if len(policies) < max_items: marker = None @@ -737,13 +788,36 @@ class IAMBackend(BaseBackend): return policies, marker - def create_role(self, role_name, assume_role_policy_document, path, permissions_boundary, description, tags): + def create_role( + self, + role_name, + assume_role_policy_document, + path, + permissions_boundary, + description, + tags, + ): role_id = random_resource_id() - if permissions_boundary and not self.policy_arn_regex.match(permissions_boundary): - raise RESTError('InvalidParameterValue', 'Value ({}) for parameter PermissionsBoundary is invalid.'.format(permissions_boundary)) + if permissions_boundary and not self.policy_arn_regex.match( + permissions_boundary + ): + raise RESTError( + "InvalidParameterValue", + "Value ({}) for parameter PermissionsBoundary is invalid.".format( + permissions_boundary + ), + ) clean_tags = self._tag_verification(tags) - role = Role(role_id, role_name, assume_role_policy_document, path, permissions_boundary, description, clean_tags) + role = Role( + role_id, + role_name, + assume_role_policy_document, + path, + permissions_boundary, + description, + clean_tags, + ) self.roles[role_id] = role return role @@ -769,17 +843,17 @@ class IAMBackend(BaseBackend): if role.name == role_name: raise IAMConflictException( code="DeleteConflict", - message="Cannot delete entity, must remove roles from instance profile first." + message="Cannot delete entity, must remove roles from instance profile first.", ) if role.managed_policies: raise IAMConflictException( code="DeleteConflict", - message="Cannot delete entity, must detach all policies first." + message="Cannot delete entity, must detach all policies first.", ) if role.policies: raise IAMConflictException( code="DeleteConflict", - message="Cannot delete entity, must delete policies first." + message="Cannot delete entity, must delete policies first.", ) del self.roles[role.id] @@ -802,7 +876,11 @@ class IAMBackend(BaseBackend): for p, d in role.policies.items(): if p == policy_name: return p, d - raise IAMNotFoundException("Policy Document {0} not attached to role {1}".format(policy_name, role_name)) + raise IAMNotFoundException( + "Policy Document {0} not attached to role {1}".format( + policy_name, role_name + ) + ) def list_role_policies(self, role_name): role = self.get_role(role_name) @@ -815,17 +893,17 @@ class IAMBackend(BaseBackend): tag_keys = {} for tag in tags: # Need to index by the lowercase tag key since the keys are case insensitive, but their case is retained. - ref_key = tag['Key'].lower() + ref_key = tag["Key"].lower() self._check_tag_duplicate(tag_keys, ref_key) - self._validate_tag_key(tag['Key']) - if len(tag['Value']) > 256: - raise TagValueTooBig(tag['Value']) + self._validate_tag_key(tag["Key"]) + if len(tag["Value"]) > 256: + raise TagValueTooBig(tag["Value"]) tag_keys[ref_key] = tag return tag_keys - def _validate_tag_key(self, tag_key, exception_param='tags.X.member.key'): + def _validate_tag_key(self, tag_key, exception_param="tags.X.member.key"): """Validates the tag key. :param tag_key: The tag key to check against. @@ -839,7 +917,7 @@ class IAMBackend(BaseBackend): # Validate that the tag key fits the proper Regex: # [\w\s_.:/=+\-@]+ SHOULD be the same as the Java regex on the AWS documentation: [\p{L}\p{Z}\p{N}_.:/=+\-@]+ - match = re.findall(r'[\w\s_.:/=+\-@]+', tag_key) + match = re.findall(r"[\w\s_.:/=+\-@]+", tag_key) # Kudos if you can come up with a better way of doing a global search :) if not len(match) or len(match[0]) < len(tag_key): raise InvalidTagCharacters(tag_key, param=exception_param) @@ -861,7 +939,7 @@ class IAMBackend(BaseBackend): tag_index = sorted(role.tags) start_idx = int(marker) if marker else 0 - tag_index = tag_index[start_idx:start_idx + max_items] + tag_index = tag_index[start_idx : start_idx + max_items] if len(role.tags) <= (start_idx + max_items): marker = None @@ -880,13 +958,13 @@ class IAMBackend(BaseBackend): def untag_role(self, role_name, tag_keys): if len(tag_keys) > 50: - raise TooManyTags(tag_keys, param='tagKeys') + raise TooManyTags(tag_keys, param="tagKeys") role = self.get_role(role_name) for key in tag_keys: ref_key = key.lower() - self._validate_tag_key(key, exception_param='tagKeys') + self._validate_tag_key(key, exception_param="tagKeys") role.tags.pop(ref_key, None) @@ -898,11 +976,13 @@ class IAMBackend(BaseBackend): if not policy: raise IAMNotFoundException("Policy not found") if len(policy.versions) >= 5: - raise IAMLimitExceededException("A managed policy can have up to 5 versions. Before you create a new version, you must delete an existing version.") - set_as_default = (set_as_default == "true") # convert it to python bool + raise IAMLimitExceededException( + "A managed policy can have up to 5 versions. Before you create a new version, you must delete an existing version." + ) + set_as_default = set_as_default == "true" # convert it to python bool version = PolicyVersion(policy_arn, policy_document, set_as_default) policy.versions.append(version) - version.version_id = 'v{0}'.format(policy.next_version_num) + version.version_id = "v{0}".format(policy.next_version_num) policy.next_version_num += 1 if set_as_default: policy.update_default_version(version.version_id) @@ -928,8 +1008,10 @@ class IAMBackend(BaseBackend): if not policy: raise IAMNotFoundException("Policy not found") if version_id == policy.default_version_id: - raise IAMConflictException(code="DeleteConflict", - message="Cannot delete the default version of a policy.") + raise IAMConflictException( + code="DeleteConflict", + message="Cannot delete the default version of a policy.", + ) for i, v in enumerate(policy.versions): if v.version_id == version_id: del policy.versions[i] @@ -940,8 +1022,7 @@ class IAMBackend(BaseBackend): instance_profile_id = random_resource_id() roles = [iam_backend.get_role_by_id(role_id) for role_id in role_ids] - instance_profile = InstanceProfile( - instance_profile_id, name, path, roles) + instance_profile = InstanceProfile(instance_profile_id, name, path, roles) self.instance_profiles[instance_profile_id] = instance_profile return instance_profile @@ -951,7 +1032,8 @@ class IAMBackend(BaseBackend): return profile raise IAMNotFoundException( - "Instance profile {0} not found".format(profile_name)) + "Instance profile {0} not found".format(profile_name) + ) def get_instance_profiles(self): return self.instance_profiles.values() @@ -979,7 +1061,9 @@ class IAMBackend(BaseBackend): def get_all_server_certs(self, marker=None): return self.certificates.values() - def upload_server_cert(self, cert_name, cert_body, private_key, cert_chain=None, path=None): + def upload_server_cert( + self, cert_name, cert_body, private_key, cert_chain=None, path=None + ): certificate_id = random_resource_id() cert = Certificate(cert_name, cert_body, private_key, cert_chain, path) self.certificates[certificate_id] = cert @@ -991,8 +1075,8 @@ class IAMBackend(BaseBackend): return cert raise IAMNotFoundException( - "The Server Certificate with name {0} cannot be " - "found.".format(name)) + "The Server Certificate with name {0} cannot be " "found.".format(name) + ) def delete_server_certificate(self, name): cert_id = None @@ -1003,15 +1087,14 @@ class IAMBackend(BaseBackend): if cert_id is None: raise IAMNotFoundException( - "The Server Certificate with name {0} cannot be " - "found.".format(name)) + "The Server Certificate with name {0} cannot be " "found.".format(name) + ) self.certificates.pop(cert_id, None) - def create_group(self, group_name, path='/'): + def create_group(self, group_name, path="/"): if group_name in self.groups: - raise IAMConflictException( - "Group {0} already exists".format(group_name)) + raise IAMConflictException("Group {0} already exists".format(group_name)) group = Group(group_name, path) self.groups[group_name] = group @@ -1022,8 +1105,7 @@ class IAMBackend(BaseBackend): try: group = self.groups[group_name] except KeyError: - raise IAMNotFoundException( - "Group {0} not found".format(group_name)) + raise IAMNotFoundException("Group {0} not found".format(group_name)) return group @@ -1054,10 +1136,11 @@ class IAMBackend(BaseBackend): group = self.get_group(group_name) return group.get_policy(policy_name) - def create_user(self, user_name, path='/'): + def create_user(self, user_name, path="/"): if user_name in self.users: raise IAMConflictException( - "EntityAlreadyExists", "User {0} already exists".format(user_name)) + "EntityAlreadyExists", "User {0} already exists".format(user_name) + ) user = User(user_name, path) self.users[user_name] = user @@ -1078,7 +1161,8 @@ class IAMBackend(BaseBackend): users = self.users.values() except KeyError: raise IAMNotFoundException( - "Users {0}, {1}, {2} not found".format(path_prefix, marker, max_items)) + "Users {0}, {1}, {2} not found".format(path_prefix, marker, max_items) + ) return users @@ -1100,7 +1184,8 @@ class IAMBackend(BaseBackend): roles = self.roles.values() except KeyError: raise IAMNotFoundException( - "Users {0}, {1}, {2} not found".format(path_prefix, marker, max_items)) + "Users {0}, {1}, {2} not found".format(path_prefix, marker, max_items) + ) return roles @@ -1113,14 +1198,16 @@ class IAMBackend(BaseBackend): if sys.version_info < (3, 0): data = bytes(body) else: - data = bytes(body, 'utf8') + data = bytes(body, "utf8") x509.load_pem_x509_certificate(data, default_backend()) except Exception: raise MalformedCertificate(body) - user.signing_certificates[cert_id] = SigningCertificate(cert_id, user_name, body) + user.signing_certificates[cert_id] = SigningCertificate( + cert_id, user_name, body + ) return user.signing_certificates[cert_id] @@ -1130,7 +1217,9 @@ class IAMBackend(BaseBackend): try: del user.signing_certificates[cert_id] except KeyError: - raise IAMNotFoundException("The Certificate with id {id} cannot be found.".format(id=cert_id)) + raise IAMNotFoundException( + "The Certificate with id {id} cannot be found.".format(id=cert_id) + ) def list_signing_certificates(self, user_name): user = self.get_user(user_name) @@ -1144,14 +1233,17 @@ class IAMBackend(BaseBackend): user.signing_certificates[cert_id].status = status except KeyError: - raise IAMNotFoundException("The Certificate with id {id} cannot be found.".format(id=cert_id)) + raise IAMNotFoundException( + "The Certificate with id {id} cannot be found.".format(id=cert_id) + ) def create_login_profile(self, user_name, password): # This does not currently deal with PasswordPolicyViolation. user = self.get_user(user_name) if user.password: raise IAMConflictException( - "User {0} already has password".format(user_name)) + "User {0} already has password".format(user_name) + ) user.password = password return user @@ -1159,7 +1251,8 @@ class IAMBackend(BaseBackend): user = self.get_user(user_name) if not user.password: raise IAMNotFoundException( - "Login profile for {0} not found".format(user_name)) + "Login profile for {0} not found".format(user_name) + ) return user def update_login_profile(self, user_name, password, password_reset_required): @@ -1167,7 +1260,8 @@ class IAMBackend(BaseBackend): user = self.get_user(user_name) if not user.password: raise IAMNotFoundException( - "Login profile for {0} not found".format(user_name)) + "Login profile for {0} not found".format(user_name) + ) user.password = password user.password_reset_required = password_reset_required return user @@ -1176,7 +1270,8 @@ class IAMBackend(BaseBackend): user = self.get_user(user_name) if not user.password: raise IAMNotFoundException( - "Login profile for {0} not found".format(user_name)) + "Login profile for {0} not found".format(user_name) + ) user.password = None def add_user_to_group(self, group_name, user_name): @@ -1191,7 +1286,8 @@ class IAMBackend(BaseBackend): group.users.remove(user) except ValueError: raise IAMNotFoundException( - "User {0} not in group {1}".format(user_name, group_name)) + "User {0} not in group {1}".format(user_name, group_name) + ) def get_user_policy(self, user_name, policy_name): user = self.get_user(user_name) @@ -1229,13 +1325,11 @@ class IAMBackend(BaseBackend): access_keys_list = self.get_all_access_keys_for_all_users() for key in access_keys_list: if key.access_key_id == access_key_id: - return { - 'user_name': key.user_name, - 'last_used': key.last_used_iso_8601, - } + return {"user_name": key.user_name, "last_used": key.last_used_iso_8601} else: raise IAMNotFoundException( - "The Access Key with id {0} cannot be found".format(access_key_id)) + "The Access Key with id {0} cannot be found".format(access_key_id) + ) def get_all_access_keys_for_all_users(self): access_keys_list = [] @@ -1252,17 +1346,14 @@ class IAMBackend(BaseBackend): user = self.get_user(user_name) user.delete_access_key(access_key_id) - def enable_mfa_device(self, - user_name, - serial_number, - authentication_code_1, - authentication_code_2): + def enable_mfa_device( + self, user_name, serial_number, authentication_code_1, authentication_code_2 + ): """Enable MFA Device for user.""" user = self.get_user(user_name) if serial_number in user.mfa_devices: raise IAMConflictException( - "EntityAlreadyExists", - "Device {0} already exists".format(serial_number) + "EntityAlreadyExists", "Device {0} already exists".format(serial_number) ) device = self.virtual_mfa_devices.get(serial_number, None) @@ -1270,29 +1361,25 @@ class IAMBackend(BaseBackend): device.enable_date = datetime.utcnow() device.user = user device.user_attribute = { - 'Path': user.path, - 'UserName': user.name, - 'UserId': user.id, - 'Arn': user.arn, - 'CreateDate': user.created_iso_8601, - 'PasswordLastUsed': None, # not supported - 'PermissionsBoundary': {}, # ToDo: add put_user_permissions_boundary() functionality - 'Tags': {} # ToDo: add tag_user() functionality + "Path": user.path, + "UserName": user.name, + "UserId": user.id, + "Arn": user.arn, + "CreateDate": user.created_iso_8601, + "PasswordLastUsed": None, # not supported + "PermissionsBoundary": {}, # ToDo: add put_user_permissions_boundary() functionality + "Tags": {}, # ToDo: add tag_user() functionality } user.enable_mfa_device( - serial_number, - authentication_code_1, - authentication_code_2 + serial_number, authentication_code_1, authentication_code_2 ) def deactivate_mfa_device(self, user_name, serial_number): """Deactivate and detach MFA Device from user if device exists.""" user = self.get_user(user_name) if serial_number not in user.mfa_devices: - raise IAMNotFoundException( - "Device {0} not found".format(serial_number) - ) + raise IAMNotFoundException("Device {0} not found".format(serial_number)) device = self.virtual_mfa_devices.get(serial_number, None) if device: @@ -1308,25 +1395,33 @@ class IAMBackend(BaseBackend): def create_virtual_mfa_device(self, device_name, path): if not path: - path = '/' + path = "/" - if not path.startswith('/') and not path.endswith('/'): - raise ValidationError('The specified value for path is invalid. ' - 'It must begin and end with / and contain only alphanumeric characters and/or / characters.') + if not path.startswith("/") and not path.endswith("/"): + raise ValidationError( + "The specified value for path is invalid. " + "It must begin and end with / and contain only alphanumeric characters and/or / characters." + ) - if any(not len(part) for part in path.split('/')[1:-1]): - raise ValidationError('The specified value for path is invalid. ' - 'It must begin and end with / and contain only alphanumeric characters and/or / characters.') + if any(not len(part) for part in path.split("/")[1:-1]): + raise ValidationError( + "The specified value for path is invalid. " + "It must begin and end with / and contain only alphanumeric characters and/or / characters." + ) if len(path) > 512: - raise ValidationError('1 validation error detected: ' - 'Value "{}" at "path" failed to satisfy constraint: ' - 'Member must have length less than or equal to 512') + raise ValidationError( + "1 validation error detected: " + 'Value "{}" at "path" failed to satisfy constraint: ' + "Member must have length less than or equal to 512" + ) device = VirtualMfaDevice(path + device_name) if device.serial_number in self.virtual_mfa_devices: - raise EntityAlreadyExists('MFADevice entity at the same path and name already exists.') + raise EntityAlreadyExists( + "MFADevice entity at the same path and name already exists." + ) self.virtual_mfa_devices[device.serial_number] = device return device @@ -1335,15 +1430,19 @@ class IAMBackend(BaseBackend): device = self.virtual_mfa_devices.pop(serial_number, None) if not device: - raise IAMNotFoundException('VirtualMFADevice with serial number {0} doesn\'t exist.'.format(serial_number)) + raise IAMNotFoundException( + "VirtualMFADevice with serial number {0} doesn't exist.".format( + serial_number + ) + ) def list_virtual_mfa_devices(self, assignment_status, marker, max_items): devices = list(self.virtual_mfa_devices.values()) - if assignment_status == 'Assigned': + if assignment_status == "Assigned": devices = [device for device in devices if device.enable_date] - if assignment_status == 'Unassigned': + if assignment_status == "Unassigned": devices = [device for device in devices if not device.enable_date] sorted(devices, key=lambda device: device.serial_number) @@ -1351,9 +1450,9 @@ class IAMBackend(BaseBackend): start_idx = int(marker) if marker else 0 if start_idx > len(devices): - raise ValidationError('Invalid Marker.') + raise ValidationError("Invalid Marker.") - devices = devices[start_idx:start_idx + max_items] + devices = devices[start_idx : start_idx + max_items] if len(devices) < max_items: marker = None @@ -1367,12 +1466,12 @@ class IAMBackend(BaseBackend): if user.managed_policies: raise IAMConflictException( code="DeleteConflict", - message="Cannot delete entity, must detach all policies first." + message="Cannot delete entity, must detach all policies first.", ) if user.policies: raise IAMConflictException( code="DeleteConflict", - message="Cannot delete entity, must delete policies first." + message="Cannot delete entity, must delete policies first.", ) del self.users[user_name] @@ -1385,10 +1484,10 @@ class IAMBackend(BaseBackend): def get_credential_report(self): if not self.credential_report: raise IAMReportNotPresentException("Credential report not present") - report = 'user,arn,user_creation_time,password_enabled,password_last_used,password_last_changed,password_next_rotation,mfa_active,access_key_1_active,access_key_1_last_rotated,access_key_2_active,access_key_2_last_rotated,cert_1_active,cert_1_last_rotated,cert_2_active,cert_2_last_rotated\n' + report = "user,arn,user_creation_time,password_enabled,password_last_used,password_last_changed,password_next_rotation,mfa_active,access_key_1_active,access_key_1_last_rotated,access_key_2_active,access_key_2_last_rotated,cert_1_active,cert_1_last_rotated,cert_2_active,cert_2_last_rotated\n" for user in self.users: report += self.users[user].to_csv() - return base64.b64encode(report.encode('ascii')).decode('ascii') + return base64.b64encode(report.encode("ascii")).decode("ascii") def list_account_aliases(self): return self.account_aliases @@ -1407,24 +1506,24 @@ class IAMBackend(BaseBackend): if len(filter) == 0: return { - 'instance_profiles': self.instance_profiles.values(), - 'roles': self.roles.values(), - 'groups': self.groups.values(), - 'users': self.users.values(), - 'managed_policies': self.managed_policies.values() + "instance_profiles": self.instance_profiles.values(), + "roles": self.roles.values(), + "groups": self.groups.values(), + "users": self.users.values(), + "managed_policies": self.managed_policies.values(), } - if 'AWSManagedPolicy' in filter: + if "AWSManagedPolicy" in filter: returned_policies = aws_managed_policies - if 'LocalManagedPolicy' in filter: + if "LocalManagedPolicy" in filter: returned_policies = returned_policies + list(local_policies) return { - 'instance_profiles': self.instance_profiles.values(), - 'roles': self.roles.values() if 'Role' in filter else [], - 'groups': self.groups.values() if 'Group' in filter else [], - 'users': self.users.values() if 'User' in filter else [], - 'managed_policies': returned_policies + "instance_profiles": self.instance_profiles.values(), + "roles": self.roles.values() if "Role" in filter else [], + "groups": self.groups.values() if "Group" in filter else [], + "users": self.users.values() if "User" in filter else [], + "managed_policies": returned_policies, } def create_saml_provider(self, name, saml_metadata_document): @@ -1444,7 +1543,8 @@ class IAMBackend(BaseBackend): del self.saml_providers[saml_provider.name] except KeyError: raise IAMNotFoundException( - "SAMLProvider {0} not found".format(saml_provider_arn)) + "SAMLProvider {0} not found".format(saml_provider_arn) + ) def list_saml_providers(self): return self.saml_providers.values() @@ -1453,7 +1553,9 @@ class IAMBackend(BaseBackend): for saml_provider in self.list_saml_providers(): if saml_provider.arn == saml_provider_arn: return saml_provider - raise IAMNotFoundException("SamlProvider {0} not found".format(saml_provider_arn)) + raise IAMNotFoundException( + "SamlProvider {0} not found".format(saml_provider_arn) + ) def get_user_from_access_key_id(self, access_key_id): for user_name, user in self.users.items(): @@ -1467,7 +1569,7 @@ class IAMBackend(BaseBackend): open_id_provider = OpenIDConnectProvider(url, thumbprint_list, client_id_list) if open_id_provider.arn in self.open_id_providers: - raise EntityAlreadyExists('Unknown') + raise EntityAlreadyExists("Unknown") self.open_id_providers[open_id_provider.arn] = open_id_provider return open_id_provider @@ -1479,7 +1581,9 @@ class IAMBackend(BaseBackend): open_id_provider = self.open_id_providers.get(arn) if not open_id_provider: - raise IAMNotFoundException('OpenIDConnect Provider not found for arn {}'.format(arn)) + raise IAMNotFoundException( + "OpenIDConnect Provider not found for arn {}".format(arn) + ) return open_id_provider diff --git a/moto/iam/policy_validation.py b/moto/iam/policy_validation.py index d9a4b0282..95610ac4d 100644 --- a/moto/iam/policy_validation.py +++ b/moto/iam/policy_validation.py @@ -6,17 +6,9 @@ from six import string_types from moto.iam.exceptions import MalformedPolicyDocument -VALID_TOP_ELEMENTS = [ - "Version", - "Id", - "Statement", - "Conditions" -] +VALID_TOP_ELEMENTS = ["Version", "Id", "Statement", "Conditions"] -VALID_VERSIONS = [ - "2008-10-17", - "2012-10-17" -] +VALID_VERSIONS = ["2008-10-17", "2012-10-17"] VALID_STATEMENT_ELEMENTS = [ "Sid", @@ -25,13 +17,10 @@ VALID_STATEMENT_ELEMENTS = [ "Resource", "NotResource", "Effect", - "Condition" + "Condition", ] -VALID_EFFECTS = [ - "Allow", - "Deny" -] +VALID_EFFECTS = ["Allow", "Deny"] VALID_CONDITIONS = [ "StringEquals", @@ -60,34 +49,41 @@ VALID_CONDITIONS = [ "ArnLike", "ArnNotEquals", "ArnNotLike", - "Null" + "Null", ] -VALID_CONDITION_PREFIXES = [ - "ForAnyValue:", - "ForAllValues:" -] +VALID_CONDITION_PREFIXES = ["ForAnyValue:", "ForAllValues:"] -VALID_CONDITION_POSTFIXES = [ - "IfExists" -] +VALID_CONDITION_POSTFIXES = ["IfExists"] SERVICE_TYPE_REGION_INFORMATION_ERROR_ASSOCIATIONS = { - "iam": 'IAM resource {resource} cannot contain region information.', - "s3": 'Resource {resource} can not contain region information.' + "iam": "IAM resource {resource} cannot contain region information.", + "s3": "Resource {resource} can not contain region information.", } VALID_RESOURCE_PATH_STARTING_VALUES = { "iam": { - "values": ["user/", "federated-user/", "role/", "group/", "instance-profile/", "mfa/", "server-certificate/", - "policy/", "sms-mfa/", "saml-provider/", "oidc-provider/", "report/", "access-report/"], - "error_message": 'IAM resource path must either be "*" or start with {values}.' + "values": [ + "user/", + "federated-user/", + "role/", + "group/", + "instance-profile/", + "mfa/", + "server-certificate/", + "policy/", + "sms-mfa/", + "saml-provider/", + "oidc-provider/", + "report/", + "access-report/", + ], + "error_message": 'IAM resource path must either be "*" or start with {values}.', } } class IAMPolicyDocumentValidator: - def __init__(self, policy_document): self._policy_document = policy_document self._policy_json = {} @@ -102,7 +98,9 @@ class IAMPolicyDocumentValidator: try: self._validate_version() except Exception: - raise MalformedPolicyDocument("Policy document must be version 2012-10-17 or greater.") + raise MalformedPolicyDocument( + "Policy document must be version 2012-10-17 or greater." + ) try: self._perform_first_legacy_parsing() self._validate_resources_for_formats() @@ -112,7 +110,9 @@ class IAMPolicyDocumentValidator: try: self._validate_sid_uniqueness() except Exception: - raise MalformedPolicyDocument("Statement IDs (SID) in a single policy must be unique.") + raise MalformedPolicyDocument( + "Statement IDs (SID) in a single policy must be unique." + ) try: self._validate_action_like_exist() except Exception: @@ -176,8 +176,8 @@ class IAMPolicyDocumentValidator: for statement_element in statement.keys(): assert statement_element in VALID_STATEMENT_ELEMENTS - assert ("Resource" not in statement or "NotResource" not in statement) - assert ("Action" not in statement or "NotAction" not in statement) + assert "Resource" not in statement or "NotResource" not in statement + assert "Action" not in statement or "NotAction" not in statement IAMPolicyDocumentValidator._validate_effect_syntax(statement) IAMPolicyDocumentValidator._validate_action_syntax(statement) @@ -191,23 +191,33 @@ class IAMPolicyDocumentValidator: def _validate_effect_syntax(statement): assert "Effect" in statement assert isinstance(statement["Effect"], string_types) - assert statement["Effect"].lower() in [allowed_effect.lower() for allowed_effect in VALID_EFFECTS] + assert statement["Effect"].lower() in [ + allowed_effect.lower() for allowed_effect in VALID_EFFECTS + ] @staticmethod def _validate_action_syntax(statement): - IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax(statement, "Action") + IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( + statement, "Action" + ) @staticmethod def _validate_not_action_syntax(statement): - IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax(statement, "NotAction") + IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( + statement, "NotAction" + ) @staticmethod def _validate_resource_syntax(statement): - IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax(statement, "Resource") + IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( + statement, "Resource" + ) @staticmethod def _validate_not_resource_syntax(statement): - IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax(statement, "NotResource") + IAMPolicyDocumentValidator._validate_string_or_list_of_strings_syntax( + statement, "NotResource" + ) @staticmethod def _validate_string_or_list_of_strings_syntax(statement, key): @@ -223,22 +233,28 @@ class IAMPolicyDocumentValidator: assert isinstance(statement["Condition"], dict) for condition_key, condition_value in statement["Condition"].items(): assert isinstance(condition_value, dict) - for condition_element_key, condition_element_value in condition_value.items(): + for ( + condition_element_key, + condition_element_value, + ) in condition_value.items(): assert isinstance(condition_element_value, (list, string_types)) - if IAMPolicyDocumentValidator._strip_condition_key(condition_key) not in VALID_CONDITIONS: + if ( + IAMPolicyDocumentValidator._strip_condition_key(condition_key) + not in VALID_CONDITIONS + ): assert not condition_value # empty dict @staticmethod def _strip_condition_key(condition_key): for valid_prefix in VALID_CONDITION_PREFIXES: if condition_key.startswith(valid_prefix): - condition_key = condition_key[len(valid_prefix):] + condition_key = condition_key[len(valid_prefix) :] break # strip only the first match for valid_postfix in VALID_CONDITION_POSTFIXES: if condition_key.endswith(valid_postfix): - condition_key = condition_key[:-len(valid_postfix)] + condition_key = condition_key[: -len(valid_postfix)] break # strip only the first match return condition_key @@ -254,15 +270,17 @@ class IAMPolicyDocumentValidator: def _validate_resource_exist(self): for statement in self._statements: - assert ("Resource" in statement or "NotResource" in statement) + assert "Resource" in statement or "NotResource" in statement if "Resource" in statement and isinstance(statement["Resource"], list): assert statement["Resource"] - elif "NotResource" in statement and isinstance(statement["NotResource"], list): + elif "NotResource" in statement and isinstance( + statement["NotResource"], list + ): assert statement["NotResource"] def _validate_action_like_exist(self): for statement in self._statements: - assert ("Action" in statement or "NotAction" in statement) + assert "Action" in statement or "NotAction" in statement if "Action" in statement and isinstance(statement["Action"], list): assert statement["Action"] elif "NotAction" in statement and isinstance(statement["NotAction"], list): @@ -287,13 +305,19 @@ class IAMPolicyDocumentValidator: def _validate_action_prefix(action): action_parts = action.split(":") if len(action_parts) == 1 and action_parts[0] != "*": - raise MalformedPolicyDocument("Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc.") + raise MalformedPolicyDocument( + "Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc." + ) elif len(action_parts) > 2: - raise MalformedPolicyDocument("Actions/Condition can contain only one colon.") + raise MalformedPolicyDocument( + "Actions/Condition can contain only one colon." + ) - vendor_pattern = re.compile(r'[^a-zA-Z0-9\-.]') + vendor_pattern = re.compile(r"[^a-zA-Z0-9\-.]") if action_parts[0] != "*" and vendor_pattern.search(action_parts[0]): - raise MalformedPolicyDocument("Vendor {vendor} is not valid".format(vendor=action_parts[0])) + raise MalformedPolicyDocument( + "Vendor {vendor} is not valid".format(vendor=action_parts[0]) + ) def _validate_resources_for_formats(self): self._validate_resource_like_for_formats("Resource") @@ -310,30 +334,51 @@ class IAMPolicyDocumentValidator: for resource in sorted(statement[key], reverse=True): self._validate_resource_format(resource) if self._resource_error == "": - IAMPolicyDocumentValidator._legacy_parse_resource_like(statement, key) + IAMPolicyDocumentValidator._legacy_parse_resource_like( + statement, key + ) def _validate_resource_format(self, resource): if resource != "*": resource_partitions = resource.partition(":") if resource_partitions[1] == "": - self._resource_error = 'Resource {resource} must be in ARN format or "*".'.format(resource=resource) + self._resource_error = 'Resource {resource} must be in ARN format or "*".'.format( + resource=resource + ) return resource_partitions = resource_partitions[2].partition(":") if resource_partitions[0] != "aws": remaining_resource_parts = resource_partitions[2].split(":") - arn1 = remaining_resource_parts[0] if remaining_resource_parts[0] != "" or len(remaining_resource_parts) > 1 else "*" - arn2 = remaining_resource_parts[1] if len(remaining_resource_parts) > 1 else "*" - arn3 = remaining_resource_parts[2] if len(remaining_resource_parts) > 2 else "*" - arn4 = ":".join(remaining_resource_parts[3:]) if len(remaining_resource_parts) > 3 else "*" + arn1 = ( + remaining_resource_parts[0] + if remaining_resource_parts[0] != "" + or len(remaining_resource_parts) > 1 + else "*" + ) + arn2 = ( + remaining_resource_parts[1] + if len(remaining_resource_parts) > 1 + else "*" + ) + arn3 = ( + remaining_resource_parts[2] + if len(remaining_resource_parts) > 2 + else "*" + ) + arn4 = ( + ":".join(remaining_resource_parts[3:]) + if len(remaining_resource_parts) > 3 + else "*" + ) self._resource_error = 'Partition "{partition}" is not valid for resource "arn:{partition}:{arn1}:{arn2}:{arn3}:{arn4}".'.format( partition=resource_partitions[0], arn1=arn1, arn2=arn2, arn3=arn3, - arn4=arn4 + arn4=arn4, ) return @@ -345,8 +390,16 @@ class IAMPolicyDocumentValidator: service = resource_partitions[0] - if service in SERVICE_TYPE_REGION_INFORMATION_ERROR_ASSOCIATIONS.keys() and not resource_partitions[2].startswith(":"): - self._resource_error = SERVICE_TYPE_REGION_INFORMATION_ERROR_ASSOCIATIONS[service].format(resource=resource) + if service in SERVICE_TYPE_REGION_INFORMATION_ERROR_ASSOCIATIONS.keys() and not resource_partitions[ + 2 + ].startswith( + ":" + ): + self._resource_error = SERVICE_TYPE_REGION_INFORMATION_ERROR_ASSOCIATIONS[ + service + ].format( + resource=resource + ) return resource_partitions = resource_partitions[2].partition(":") @@ -354,13 +407,19 @@ class IAMPolicyDocumentValidator: if service in VALID_RESOURCE_PATH_STARTING_VALUES.keys(): valid_start = False - for valid_starting_value in VALID_RESOURCE_PATH_STARTING_VALUES[service]["values"]: + for valid_starting_value in VALID_RESOURCE_PATH_STARTING_VALUES[ + service + ]["values"]: if resource_partitions[2].startswith(valid_starting_value): valid_start = True break if not valid_start: - self._resource_error = VALID_RESOURCE_PATH_STARTING_VALUES[service]["error_message"].format( - values=", ".join(VALID_RESOURCE_PATH_STARTING_VALUES[service]["values"]) + self._resource_error = VALID_RESOURCE_PATH_STARTING_VALUES[service][ + "error_message" + ].format( + values=", ".join( + VALID_RESOURCE_PATH_STARTING_VALUES[service]["values"] + ) ) def _perform_first_legacy_parsing(self): @@ -373,7 +432,9 @@ class IAMPolicyDocumentValidator: assert statement["Effect"] in VALID_EFFECTS # case-sensitive matching if "Condition" in statement: for condition_key, condition_value in statement["Condition"].items(): - IAMPolicyDocumentValidator._legacy_parse_condition(condition_key, condition_value) + IAMPolicyDocumentValidator._legacy_parse_condition( + condition_key, condition_value + ) @staticmethod def _legacy_parse_resource_like(statement, key): @@ -389,20 +450,31 @@ class IAMPolicyDocumentValidator: @staticmethod def _legacy_parse_condition(condition_key, condition_value): - stripped_condition_key = IAMPolicyDocumentValidator._strip_condition_key(condition_key) + stripped_condition_key = IAMPolicyDocumentValidator._strip_condition_key( + condition_key + ) if stripped_condition_key.startswith("Date"): - for condition_element_key, condition_element_value in condition_value.items(): + for ( + condition_element_key, + condition_element_value, + ) in condition_value.items(): if isinstance(condition_element_value, string_types): - IAMPolicyDocumentValidator._legacy_parse_date_condition_value(condition_element_value) + IAMPolicyDocumentValidator._legacy_parse_date_condition_value( + condition_element_value + ) else: # it has to be a list for date_condition_value in condition_element_value: - IAMPolicyDocumentValidator._legacy_parse_date_condition_value(date_condition_value) + IAMPolicyDocumentValidator._legacy_parse_date_condition_value( + date_condition_value + ) @staticmethod def _legacy_parse_date_condition_value(date_condition_value): if "t" in date_condition_value.lower() or "-" in date_condition_value: - IAMPolicyDocumentValidator._validate_iso_8601_datetime(date_condition_value.lower()) + IAMPolicyDocumentValidator._validate_iso_8601_datetime( + date_condition_value.lower() + ) else: # timestamp assert 0 <= int(date_condition_value) <= 9223372036854775807 @@ -410,7 +482,11 @@ class IAMPolicyDocumentValidator: def _validate_iso_8601_datetime(datetime): datetime_parts = datetime.partition("t") negative_year = datetime_parts[0].startswith("-") - date_parts = datetime_parts[0][1:].split("-") if negative_year else datetime_parts[0].split("-") + date_parts = ( + datetime_parts[0][1:].split("-") + if negative_year + else datetime_parts[0].split("-") + ) year = "-" + date_parts[0] if negative_year else date_parts[0] assert -292275054 <= int(year) <= 292278993 if len(date_parts) > 1: @@ -444,7 +520,9 @@ class IAMPolicyDocumentValidator: assert 0 <= int(time_zone_minutes) <= 59 else: seconds_with_decimal_fraction = time_parts[2] - seconds_with_decimal_fraction_partition = seconds_with_decimal_fraction.partition(".") + seconds_with_decimal_fraction_partition = seconds_with_decimal_fraction.partition( + "." + ) seconds = seconds_with_decimal_fraction_partition[0] assert 0 <= int(seconds) <= 59 if seconds_with_decimal_fraction_partition[1] == ".": diff --git a/moto/iam/responses.py b/moto/iam/responses.py index 01cbeb712..d18fac88d 100644 --- a/moto/iam/responses.py +++ b/moto/iam/responses.py @@ -6,123 +6,125 @@ from .models import iam_backend, User class IamResponse(BaseResponse): - def attach_role_policy(self): - policy_arn = self._get_param('PolicyArn') - role_name = self._get_param('RoleName') + policy_arn = self._get_param("PolicyArn") + role_name = self._get_param("RoleName") iam_backend.attach_role_policy(policy_arn, role_name) template = self.response_template(ATTACH_ROLE_POLICY_TEMPLATE) return template.render() def detach_role_policy(self): - role_name = self._get_param('RoleName') - policy_arn = self._get_param('PolicyArn') + role_name = self._get_param("RoleName") + policy_arn = self._get_param("PolicyArn") iam_backend.detach_role_policy(policy_arn, role_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DetachRolePolicyResponse") def attach_group_policy(self): - policy_arn = self._get_param('PolicyArn') - group_name = self._get_param('GroupName') + policy_arn = self._get_param("PolicyArn") + group_name = self._get_param("GroupName") iam_backend.attach_group_policy(policy_arn, group_name) template = self.response_template(ATTACH_GROUP_POLICY_TEMPLATE) return template.render() def detach_group_policy(self): - policy_arn = self._get_param('PolicyArn') - group_name = self._get_param('GroupName') + policy_arn = self._get_param("PolicyArn") + group_name = self._get_param("GroupName") iam_backend.detach_group_policy(policy_arn, group_name) template = self.response_template(DETACH_GROUP_POLICY_TEMPLATE) return template.render() def attach_user_policy(self): - policy_arn = self._get_param('PolicyArn') - user_name = self._get_param('UserName') + policy_arn = self._get_param("PolicyArn") + user_name = self._get_param("UserName") iam_backend.attach_user_policy(policy_arn, user_name) template = self.response_template(ATTACH_USER_POLICY_TEMPLATE) return template.render() def detach_user_policy(self): - policy_arn = self._get_param('PolicyArn') - user_name = self._get_param('UserName') + policy_arn = self._get_param("PolicyArn") + user_name = self._get_param("UserName") iam_backend.detach_user_policy(policy_arn, user_name) template = self.response_template(DETACH_USER_POLICY_TEMPLATE) return template.render() def create_policy(self): - description = self._get_param('Description') - path = self._get_param('Path') - policy_document = self._get_param('PolicyDocument') - policy_name = self._get_param('PolicyName') + description = self._get_param("Description") + path = self._get_param("Path") + policy_document = self._get_param("PolicyDocument") + policy_name = self._get_param("PolicyName") policy = iam_backend.create_policy( - description, path, policy_document, policy_name) + description, path, policy_document, policy_name + ) template = self.response_template(CREATE_POLICY_TEMPLATE) return template.render(policy=policy) def get_policy(self): - policy_arn = self._get_param('PolicyArn') + policy_arn = self._get_param("PolicyArn") policy = iam_backend.get_policy(policy_arn) template = self.response_template(GET_POLICY_TEMPLATE) return template.render(policy=policy) def list_attached_role_policies(self): - marker = self._get_param('Marker') - max_items = self._get_int_param('MaxItems', 100) - path_prefix = self._get_param('PathPrefix', '/') - role_name = self._get_param('RoleName') + marker = self._get_param("Marker") + max_items = self._get_int_param("MaxItems", 100) + path_prefix = self._get_param("PathPrefix", "/") + role_name = self._get_param("RoleName") policies, marker = iam_backend.list_attached_role_policies( - role_name, marker=marker, max_items=max_items, path_prefix=path_prefix) + role_name, marker=marker, max_items=max_items, path_prefix=path_prefix + ) template = self.response_template(LIST_ATTACHED_ROLE_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) def list_attached_group_policies(self): - marker = self._get_param('Marker') - max_items = self._get_int_param('MaxItems', 100) - path_prefix = self._get_param('PathPrefix', '/') - group_name = self._get_param('GroupName') + marker = self._get_param("Marker") + max_items = self._get_int_param("MaxItems", 100) + path_prefix = self._get_param("PathPrefix", "/") + group_name = self._get_param("GroupName") policies, marker = iam_backend.list_attached_group_policies( - group_name, marker=marker, max_items=max_items, - path_prefix=path_prefix) + group_name, marker=marker, max_items=max_items, path_prefix=path_prefix + ) template = self.response_template(LIST_ATTACHED_GROUP_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) def list_attached_user_policies(self): - marker = self._get_param('Marker') - max_items = self._get_int_param('MaxItems', 100) - path_prefix = self._get_param('PathPrefix', '/') - user_name = self._get_param('UserName') + marker = self._get_param("Marker") + max_items = self._get_int_param("MaxItems", 100) + path_prefix = self._get_param("PathPrefix", "/") + user_name = self._get_param("UserName") policies, marker = iam_backend.list_attached_user_policies( - user_name, marker=marker, max_items=max_items, - path_prefix=path_prefix) + user_name, marker=marker, max_items=max_items, path_prefix=path_prefix + ) template = self.response_template(LIST_ATTACHED_USER_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) def list_policies(self): - marker = self._get_param('Marker') - max_items = self._get_int_param('MaxItems', 100) - only_attached = self._get_bool_param('OnlyAttached', False) - path_prefix = self._get_param('PathPrefix', '/') - scope = self._get_param('Scope', 'All') + marker = self._get_param("Marker") + max_items = self._get_int_param("MaxItems", 100) + only_attached = self._get_bool_param("OnlyAttached", False) + path_prefix = self._get_param("PathPrefix", "/") + scope = self._get_param("Scope", "All") policies, marker = iam_backend.list_policies( - marker, max_items, only_attached, path_prefix, scope) + marker, max_items, only_attached, path_prefix, scope + ) template = self.response_template(LIST_POLICIES_TEMPLATE) return template.render(policies=policies, marker=marker) def list_entities_for_policy(self): - policy_arn = self._get_param('PolicyArn') + policy_arn = self._get_param("PolicyArn") # Options 'User'|'Role'|'Group'|'LocalManagedPolicy'|'AWSManagedPolicy - entity = self._get_param('EntityFilter') - path_prefix = self._get_param('PathPrefix') + entity = self._get_param("EntityFilter") + path_prefix = self._get_param("PathPrefix") # policy_usage_filter = self._get_param('PolicyUsageFilter') - marker = self._get_param('Marker') - max_items = self._get_param('MaxItems') + marker = self._get_param("Marker") + max_items = self._get_param("MaxItems") entity_roles = [] entity_groups = [] entity_users = [] - if entity == 'User': + if entity == "User": users = iam_backend.list_users(path_prefix, marker, max_items) if users: for user in users: @@ -130,7 +132,7 @@ class IamResponse(BaseResponse): if p == policy_arn: entity_users.append(user.name) - elif entity == 'Role': + elif entity == "Role": roles = iam_backend.list_roles(path_prefix, marker, max_items) if roles: for role in roles: @@ -138,7 +140,7 @@ class IamResponse(BaseResponse): if p == policy_arn: entity_roles.append(role.name) - elif entity == 'Group': + elif entity == "Group": groups = iam_backend.list_groups() if groups: for group in groups: @@ -146,7 +148,7 @@ class IamResponse(BaseResponse): if p == policy_arn: entity_groups.append(group.name) - elif entity == 'LocalManagedPolicy' or entity == 'AWSManagedPolicy': + elif entity == "LocalManagedPolicy" or entity == "AWSManagedPolicy": users = iam_backend.list_users(path_prefix, marker, max_items) if users: for user in users: @@ -169,150 +171,158 @@ class IamResponse(BaseResponse): entity_groups.append(group.name) template = self.response_template(LIST_ENTITIES_FOR_POLICY_TEMPLATE) - return template.render(roles=entity_roles, users=entity_users, groups=entity_groups) + return template.render( + roles=entity_roles, users=entity_users, groups=entity_groups + ) def create_role(self): - role_name = self._get_param('RoleName') - path = self._get_param('Path') - assume_role_policy_document = self._get_param( - 'AssumeRolePolicyDocument') - permissions_boundary = self._get_param( - 'PermissionsBoundary') - description = self._get_param('Description') - tags = self._get_multi_param('Tags.member') + role_name = self._get_param("RoleName") + path = self._get_param("Path") + assume_role_policy_document = self._get_param("AssumeRolePolicyDocument") + permissions_boundary = self._get_param("PermissionsBoundary") + description = self._get_param("Description") + tags = self._get_multi_param("Tags.member") role = iam_backend.create_role( - role_name, assume_role_policy_document, path, permissions_boundary, description, tags) + role_name, + assume_role_policy_document, + path, + permissions_boundary, + description, + tags, + ) template = self.response_template(CREATE_ROLE_TEMPLATE) return template.render(role=role) def get_role(self): - role_name = self._get_param('RoleName') + role_name = self._get_param("RoleName") role = iam_backend.get_role(role_name) template = self.response_template(GET_ROLE_TEMPLATE) return template.render(role=role) def delete_role(self): - role_name = self._get_param('RoleName') + role_name = self._get_param("RoleName") iam_backend.delete_role(role_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteRoleResponse") def list_role_policies(self): - role_name = self._get_param('RoleName') + role_name = self._get_param("RoleName") role_policies_names = iam_backend.list_role_policies(role_name) template = self.response_template(LIST_ROLE_POLICIES) return template.render(role_policies=role_policies_names) def put_role_policy(self): - role_name = self._get_param('RoleName') - policy_name = self._get_param('PolicyName') - policy_document = self._get_param('PolicyDocument') + role_name = self._get_param("RoleName") + policy_name = self._get_param("PolicyName") + policy_document = self._get_param("PolicyDocument") iam_backend.put_role_policy(role_name, policy_name, policy_document) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutRolePolicyResponse") def delete_role_policy(self): - role_name = self._get_param('RoleName') - policy_name = self._get_param('PolicyName') + role_name = self._get_param("RoleName") + policy_name = self._get_param("PolicyName") iam_backend.delete_role_policy(role_name, policy_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteRolePolicyResponse") def get_role_policy(self): - role_name = self._get_param('RoleName') - policy_name = self._get_param('PolicyName') + role_name = self._get_param("RoleName") + policy_name = self._get_param("PolicyName") policy_name, policy_document = iam_backend.get_role_policy( - role_name, policy_name) + role_name, policy_name + ) template = self.response_template(GET_ROLE_POLICY_TEMPLATE) - return template.render(role_name=role_name, - policy_name=policy_name, - policy_document=policy_document) + return template.render( + role_name=role_name, + policy_name=policy_name, + policy_document=policy_document, + ) def update_assume_role_policy(self): - role_name = self._get_param('RoleName') + role_name = self._get_param("RoleName") role = iam_backend.get_role(role_name) - role.assume_role_policy_document = self._get_param('PolicyDocument') + role.assume_role_policy_document = self._get_param("PolicyDocument") template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="UpdateAssumeRolePolicyResponse") def update_role_description(self): - role_name = self._get_param('RoleName') - description = self._get_param('Description') + role_name = self._get_param("RoleName") + description = self._get_param("Description") role = iam_backend.update_role_description(role_name, description) template = self.response_template(UPDATE_ROLE_DESCRIPTION_TEMPLATE) return template.render(role=role) def update_role(self): - role_name = self._get_param('RoleName') - description = self._get_param('Description') + role_name = self._get_param("RoleName") + description = self._get_param("Description") role = iam_backend.update_role(role_name, description) template = self.response_template(UPDATE_ROLE_TEMPLATE) return template.render(role=role) def create_policy_version(self): - policy_arn = self._get_param('PolicyArn') - policy_document = self._get_param('PolicyDocument') - set_as_default = self._get_param('SetAsDefault') - policy_version = iam_backend.create_policy_version(policy_arn, policy_document, set_as_default) + policy_arn = self._get_param("PolicyArn") + policy_document = self._get_param("PolicyDocument") + set_as_default = self._get_param("SetAsDefault") + policy_version = iam_backend.create_policy_version( + policy_arn, policy_document, set_as_default + ) template = self.response_template(CREATE_POLICY_VERSION_TEMPLATE) return template.render(policy_version=policy_version) def get_policy_version(self): - policy_arn = self._get_param('PolicyArn') - version_id = self._get_param('VersionId') + policy_arn = self._get_param("PolicyArn") + version_id = self._get_param("VersionId") policy_version = iam_backend.get_policy_version(policy_arn, version_id) template = self.response_template(GET_POLICY_VERSION_TEMPLATE) return template.render(policy_version=policy_version) def list_policy_versions(self): - policy_arn = self._get_param('PolicyArn') + policy_arn = self._get_param("PolicyArn") policy_versions = iam_backend.list_policy_versions(policy_arn) template = self.response_template(LIST_POLICY_VERSIONS_TEMPLATE) return template.render(policy_versions=policy_versions) def delete_policy_version(self): - policy_arn = self._get_param('PolicyArn') - version_id = self._get_param('VersionId') + policy_arn = self._get_param("PolicyArn") + version_id = self._get_param("VersionId") iam_backend.delete_policy_version(policy_arn, version_id) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeletePolicyVersion') + return template.render(name="DeletePolicyVersion") def create_instance_profile(self): - profile_name = self._get_param('InstanceProfileName') - path = self._get_param('Path', '/') + profile_name = self._get_param("InstanceProfileName") + path = self._get_param("Path", "/") - profile = iam_backend.create_instance_profile( - profile_name, path, role_ids=[]) + profile = iam_backend.create_instance_profile(profile_name, path, role_ids=[]) template = self.response_template(CREATE_INSTANCE_PROFILE_TEMPLATE) return template.render(profile=profile) def get_instance_profile(self): - profile_name = self._get_param('InstanceProfileName') + profile_name = self._get_param("InstanceProfileName") profile = iam_backend.get_instance_profile(profile_name) template = self.response_template(GET_INSTANCE_PROFILE_TEMPLATE) return template.render(profile=profile) def add_role_to_instance_profile(self): - profile_name = self._get_param('InstanceProfileName') - role_name = self._get_param('RoleName') + profile_name = self._get_param("InstanceProfileName") + role_name = self._get_param("RoleName") iam_backend.add_role_to_instance_profile(profile_name, role_name) - template = self.response_template( - ADD_ROLE_TO_INSTANCE_PROFILE_TEMPLATE) + template = self.response_template(ADD_ROLE_TO_INSTANCE_PROFILE_TEMPLATE) return template.render() def remove_role_from_instance_profile(self): - profile_name = self._get_param('InstanceProfileName') - role_name = self._get_param('RoleName') + profile_name = self._get_param("InstanceProfileName") + role_name = self._get_param("RoleName") iam_backend.remove_role_from_instance_profile(profile_name, role_name) - template = self.response_template( - REMOVE_ROLE_FROM_INSTANCE_PROFILE_TEMPLATE) + template = self.response_template(REMOVE_ROLE_FROM_INSTANCE_PROFILE_TEMPLATE) return template.render() def list_roles(self): @@ -328,23 +338,22 @@ class IamResponse(BaseResponse): return template.render(instance_profiles=profiles) def list_instance_profiles_for_role(self): - role_name = self._get_param('RoleName') - profiles = iam_backend.get_instance_profiles_for_role( - role_name=role_name) + role_name = self._get_param("RoleName") + profiles = iam_backend.get_instance_profiles_for_role(role_name=role_name) - template = self.response_template( - LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE) + template = self.response_template(LIST_INSTANCE_PROFILES_FOR_ROLE_TEMPLATE) return template.render(instance_profiles=profiles) def upload_server_certificate(self): - cert_name = self._get_param('ServerCertificateName') - cert_body = self._get_param('CertificateBody') - path = self._get_param('Path') - private_key = self._get_param('PrivateKey') - cert_chain = self._get_param('CertificateName') + cert_name = self._get_param("ServerCertificateName") + cert_body = self._get_param("CertificateBody") + path = self._get_param("Path") + private_key = self._get_param("PrivateKey") + cert_chain = self._get_param("CertificateName") cert = iam_backend.upload_server_cert( - cert_name, cert_body, private_key, cert_chain=cert_chain, path=path) + cert_name, cert_body, private_key, cert_chain=cert_chain, path=path + ) template = self.response_template(UPLOAD_CERT_TEMPLATE) return template.render(certificate=cert) @@ -354,27 +363,27 @@ class IamResponse(BaseResponse): return template.render(server_certificates=certs) def get_server_certificate(self): - cert_name = self._get_param('ServerCertificateName') + cert_name = self._get_param("ServerCertificateName") cert = iam_backend.get_server_certificate(cert_name) template = self.response_template(GET_SERVER_CERTIFICATE_TEMPLATE) return template.render(certificate=cert) def delete_server_certificate(self): - cert_name = self._get_param('ServerCertificateName') + cert_name = self._get_param("ServerCertificateName") iam_backend.delete_server_certificate(cert_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="DeleteServerCertificate") def create_group(self): - group_name = self._get_param('GroupName') - path = self._get_param('Path', '/') + group_name = self._get_param("GroupName") + path = self._get_param("Path", "/") group = iam_backend.create_group(group_name, path) template = self.response_template(CREATE_GROUP_TEMPLATE) return template.render(group=group) def get_group(self): - group_name = self._get_param('GroupName') + group_name = self._get_param("GroupName") group = iam_backend.get_group(group_name) template = self.response_template(GET_GROUP_TEMPLATE) @@ -386,48 +395,49 @@ class IamResponse(BaseResponse): return template.render(groups=groups) def list_groups_for_user(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") groups = iam_backend.get_groups_for_user(user_name) template = self.response_template(LIST_GROUPS_FOR_USER_TEMPLATE) return template.render(groups=groups) def put_group_policy(self): - group_name = self._get_param('GroupName') - policy_name = self._get_param('PolicyName') - policy_document = self._get_param('PolicyDocument') + group_name = self._get_param("GroupName") + policy_name = self._get_param("PolicyName") + policy_document = self._get_param("PolicyDocument") iam_backend.put_group_policy(group_name, policy_name, policy_document) template = self.response_template(GENERIC_EMPTY_TEMPLATE) return template.render(name="PutGroupPolicyResponse") def list_group_policies(self): - group_name = self._get_param('GroupName') - marker = self._get_param('Marker') - max_items = self._get_param('MaxItems') - policies = iam_backend.list_group_policies(group_name, - marker=marker, max_items=max_items) + group_name = self._get_param("GroupName") + marker = self._get_param("Marker") + max_items = self._get_param("MaxItems") + policies = iam_backend.list_group_policies( + group_name, marker=marker, max_items=max_items + ) template = self.response_template(LIST_GROUP_POLICIES_TEMPLATE) - return template.render(name="ListGroupPoliciesResponse", - policies=policies, - marker=marker) + return template.render( + name="ListGroupPoliciesResponse", policies=policies, marker=marker + ) def get_group_policy(self): - group_name = self._get_param('GroupName') - policy_name = self._get_param('PolicyName') + group_name = self._get_param("GroupName") + policy_name = self._get_param("PolicyName") policy_result = iam_backend.get_group_policy(group_name, policy_name) template = self.response_template(GET_GROUP_POLICY_TEMPLATE) return template.render(name="GetGroupPolicyResponse", **policy_result) def create_user(self): - user_name = self._get_param('UserName') - path = self._get_param('Path') + user_name = self._get_param("UserName") + path = self._get_param("Path") user = iam_backend.create_user(user_name, path) template = self.response_template(USER_TEMPLATE) - return template.render(action='Create', user=user) + return template.render(action="Create", user=user) def get_user(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") if not user_name: access_key_id = self.get_current_user() user = iam_backend.get_user_from_access_key_id(access_key_id) @@ -437,178 +447,182 @@ class IamResponse(BaseResponse): user = iam_backend.get_user(user_name) template = self.response_template(USER_TEMPLATE) - return template.render(action='Get', user=user) + return template.render(action="Get", user=user) def list_users(self): - path_prefix = self._get_param('PathPrefix') - marker = self._get_param('Marker') - max_items = self._get_param('MaxItems') + path_prefix = self._get_param("PathPrefix") + marker = self._get_param("Marker") + max_items = self._get_param("MaxItems") users = iam_backend.list_users(path_prefix, marker, max_items) template = self.response_template(LIST_USERS_TEMPLATE) - return template.render(action='List', users=users) + return template.render(action="List", users=users) def update_user(self): - user_name = self._get_param('UserName') - new_path = self._get_param('NewPath') - new_user_name = self._get_param('NewUserName') + user_name = self._get_param("UserName") + new_path = self._get_param("NewPath") + new_user_name = self._get_param("NewUserName") iam_backend.update_user(user_name, new_path, new_user_name) if new_user_name: user = iam_backend.get_user(new_user_name) else: user = iam_backend.get_user(user_name) template = self.response_template(USER_TEMPLATE) - return template.render(action='Update', user=user) + return template.render(action="Update", user=user) def create_login_profile(self): - user_name = self._get_param('UserName') - password = self._get_param('Password') + user_name = self._get_param("UserName") + password = self._get_param("Password") user = iam_backend.create_login_profile(user_name, password) template = self.response_template(CREATE_LOGIN_PROFILE_TEMPLATE) return template.render(user=user) def get_login_profile(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") user = iam_backend.get_login_profile(user_name) template = self.response_template(GET_LOGIN_PROFILE_TEMPLATE) return template.render(user=user) def update_login_profile(self): - user_name = self._get_param('UserName') - password = self._get_param('Password') - password_reset_required = self._get_param('PasswordResetRequired') - user = iam_backend.update_login_profile(user_name, password, password_reset_required) + user_name = self._get_param("UserName") + password = self._get_param("Password") + password_reset_required = self._get_param("PasswordResetRequired") + user = iam_backend.update_login_profile( + user_name, password, password_reset_required + ) template = self.response_template(UPDATE_LOGIN_PROFILE_TEMPLATE) return template.render(user=user) def add_user_to_group(self): - group_name = self._get_param('GroupName') - user_name = self._get_param('UserName') + group_name = self._get_param("GroupName") + user_name = self._get_param("UserName") iam_backend.add_user_to_group(group_name, user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='AddUserToGroup') + return template.render(name="AddUserToGroup") def remove_user_from_group(self): - group_name = self._get_param('GroupName') - user_name = self._get_param('UserName') + group_name = self._get_param("GroupName") + user_name = self._get_param("UserName") iam_backend.remove_user_from_group(group_name, user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='RemoveUserFromGroup') + return template.render(name="RemoveUserFromGroup") def get_user_policy(self): - user_name = self._get_param('UserName') - policy_name = self._get_param('PolicyName') + user_name = self._get_param("UserName") + policy_name = self._get_param("PolicyName") policy_document = iam_backend.get_user_policy(user_name, policy_name) template = self.response_template(GET_USER_POLICY_TEMPLATE) return template.render( user_name=user_name, policy_name=policy_name, - policy_document=policy_document.get('policy_document') + policy_document=policy_document.get("policy_document"), ) def list_user_policies(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") policies = iam_backend.list_user_policies(user_name) template = self.response_template(LIST_USER_POLICIES_TEMPLATE) return template.render(policies=policies) def put_user_policy(self): - user_name = self._get_param('UserName') - policy_name = self._get_param('PolicyName') - policy_document = self._get_param('PolicyDocument') + user_name = self._get_param("UserName") + policy_name = self._get_param("PolicyName") + policy_document = self._get_param("PolicyDocument") iam_backend.put_user_policy(user_name, policy_name, policy_document) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='PutUserPolicy') + return template.render(name="PutUserPolicy") def delete_user_policy(self): - user_name = self._get_param('UserName') - policy_name = self._get_param('PolicyName') + user_name = self._get_param("UserName") + policy_name = self._get_param("PolicyName") iam_backend.delete_user_policy(user_name, policy_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeleteUserPolicy') + return template.render(name="DeleteUserPolicy") def create_access_key(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") key = iam_backend.create_access_key(user_name) template = self.response_template(CREATE_ACCESS_KEY_TEMPLATE) return template.render(key=key) def update_access_key(self): - user_name = self._get_param('UserName') - access_key_id = self._get_param('AccessKeyId') - status = self._get_param('Status') + user_name = self._get_param("UserName") + access_key_id = self._get_param("AccessKeyId") + status = self._get_param("Status") iam_backend.update_access_key(user_name, access_key_id, status) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='UpdateAccessKey') + return template.render(name="UpdateAccessKey") def get_access_key_last_used(self): - access_key_id = self._get_param('AccessKeyId') + access_key_id = self._get_param("AccessKeyId") last_used_response = iam_backend.get_access_key_last_used(access_key_id) template = self.response_template(GET_ACCESS_KEY_LAST_USED_TEMPLATE) - return template.render(user_name=last_used_response["user_name"], last_used=last_used_response["last_used"]) + return template.render( + user_name=last_used_response["user_name"], + last_used=last_used_response["last_used"], + ) def list_access_keys(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") keys = iam_backend.get_all_access_keys(user_name) template = self.response_template(LIST_ACCESS_KEYS_TEMPLATE) return template.render(user_name=user_name, keys=keys) def delete_access_key(self): - user_name = self._get_param('UserName') - access_key_id = self._get_param('AccessKeyId') + user_name = self._get_param("UserName") + access_key_id = self._get_param("AccessKeyId") iam_backend.delete_access_key(access_key_id, user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeleteAccessKey') + return template.render(name="DeleteAccessKey") def deactivate_mfa_device(self): - user_name = self._get_param('UserName') - serial_number = self._get_param('SerialNumber') + user_name = self._get_param("UserName") + serial_number = self._get_param("SerialNumber") iam_backend.deactivate_mfa_device(user_name, serial_number) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeactivateMFADevice') + return template.render(name="DeactivateMFADevice") def enable_mfa_device(self): - user_name = self._get_param('UserName') - serial_number = self._get_param('SerialNumber') - authentication_code_1 = self._get_param('AuthenticationCode1') - authentication_code_2 = self._get_param('AuthenticationCode2') + user_name = self._get_param("UserName") + serial_number = self._get_param("SerialNumber") + authentication_code_1 = self._get_param("AuthenticationCode1") + authentication_code_2 = self._get_param("AuthenticationCode2") iam_backend.enable_mfa_device( - user_name, - serial_number, - authentication_code_1, - authentication_code_2 + user_name, serial_number, authentication_code_1, authentication_code_2 ) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='EnableMFADevice') + return template.render(name="EnableMFADevice") def list_mfa_devices(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") devices = iam_backend.list_mfa_devices(user_name) template = self.response_template(LIST_MFA_DEVICES_TEMPLATE) return template.render(user_name=user_name, devices=devices) def create_virtual_mfa_device(self): - path = self._get_param('Path') - virtual_mfa_device_name = self._get_param('VirtualMFADeviceName') + path = self._get_param("Path") + virtual_mfa_device_name = self._get_param("VirtualMFADeviceName") - virtual_mfa_device = iam_backend.create_virtual_mfa_device(virtual_mfa_device_name, path) + virtual_mfa_device = iam_backend.create_virtual_mfa_device( + virtual_mfa_device_name, path + ) template = self.response_template(CREATE_VIRTUAL_MFA_DEVICE_TEMPLATE) return template.render(device=virtual_mfa_device) def delete_virtual_mfa_device(self): - serial_number = self._get_param('SerialNumber') + serial_number = self._get_param("SerialNumber") iam_backend.delete_virtual_mfa_device(serial_number) @@ -616,32 +630,34 @@ class IamResponse(BaseResponse): return template.render() def list_virtual_mfa_devices(self): - assignment_status = self._get_param('AssignmentStatus', 'Any') - marker = self._get_param('Marker') - max_items = self._get_param('MaxItems', 100) + assignment_status = self._get_param("AssignmentStatus", "Any") + marker = self._get_param("Marker") + max_items = self._get_param("MaxItems", 100) - devices, marker = iam_backend.list_virtual_mfa_devices(assignment_status, marker, max_items) + devices, marker = iam_backend.list_virtual_mfa_devices( + assignment_status, marker, max_items + ) template = self.response_template(LIST_VIRTUAL_MFA_DEVICES_TEMPLATE) return template.render(devices=devices, marker=marker) def delete_user(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") iam_backend.delete_user(user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeleteUser') + return template.render(name="DeleteUser") def delete_policy(self): - policy_arn = self._get_param('PolicyArn') + policy_arn = self._get_param("PolicyArn") iam_backend.delete_policy(policy_arn) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeletePolicy') + return template.render(name="DeletePolicy") def delete_login_profile(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") iam_backend.delete_login_profile(user_name) template = self.response_template(GENERIC_EMPTY_TEMPLATE) - return template.render(name='DeleteLoginProfile') + return template.render(name="DeleteLoginProfile") def generate_credential_report(self): if iam_backend.report_generated(): @@ -662,48 +678,52 @@ class IamResponse(BaseResponse): return template.render(aliases=aliases) def create_account_alias(self): - alias = self._get_param('AccountAlias') + alias = self._get_param("AccountAlias") iam_backend.create_account_alias(alias) template = self.response_template(CREATE_ACCOUNT_ALIAS_TEMPLATE) return template.render() def delete_account_alias(self): - alias = self._get_param('AccountAlias') + alias = self._get_param("AccountAlias") iam_backend.delete_account_alias(alias) template = self.response_template(DELETE_ACCOUNT_ALIAS_TEMPLATE) return template.render() def get_account_authorization_details(self): - filter_param = self._get_multi_param('Filter.member') + filter_param = self._get_multi_param("Filter.member") account_details = iam_backend.get_account_authorization_details(filter_param) template = self.response_template(GET_ACCOUNT_AUTHORIZATION_DETAILS_TEMPLATE) return template.render( - instance_profiles=account_details['instance_profiles'], - policies=account_details['managed_policies'], - users=account_details['users'], - groups=account_details['groups'], - roles=account_details['roles'], - get_groups_for_user=iam_backend.get_groups_for_user + instance_profiles=account_details["instance_profiles"], + policies=account_details["managed_policies"], + users=account_details["users"], + groups=account_details["groups"], + roles=account_details["roles"], + get_groups_for_user=iam_backend.get_groups_for_user, ) def create_saml_provider(self): - saml_provider_name = self._get_param('Name') - saml_metadata_document = self._get_param('SAMLMetadataDocument') - saml_provider = iam_backend.create_saml_provider(saml_provider_name, saml_metadata_document) + saml_provider_name = self._get_param("Name") + saml_metadata_document = self._get_param("SAMLMetadataDocument") + saml_provider = iam_backend.create_saml_provider( + saml_provider_name, saml_metadata_document + ) template = self.response_template(CREATE_SAML_PROVIDER_TEMPLATE) return template.render(saml_provider=saml_provider) def update_saml_provider(self): - saml_provider_arn = self._get_param('SAMLProviderArn') - saml_metadata_document = self._get_param('SAMLMetadataDocument') - saml_provider = iam_backend.update_saml_provider(saml_provider_arn, saml_metadata_document) + saml_provider_arn = self._get_param("SAMLProviderArn") + saml_metadata_document = self._get_param("SAMLMetadataDocument") + saml_provider = iam_backend.update_saml_provider( + saml_provider_arn, saml_metadata_document + ) template = self.response_template(UPDATE_SAML_PROVIDER_TEMPLATE) return template.render(saml_provider=saml_provider) def delete_saml_provider(self): - saml_provider_arn = self._get_param('SAMLProviderArn') + saml_provider_arn = self._get_param("SAMLProviderArn") iam_backend.delete_saml_provider(saml_provider_arn) template = self.response_template(DELETE_SAML_PROVIDER_TEMPLATE) @@ -716,48 +736,48 @@ class IamResponse(BaseResponse): return template.render(saml_providers=saml_providers) def get_saml_provider(self): - saml_provider_arn = self._get_param('SAMLProviderArn') + saml_provider_arn = self._get_param("SAMLProviderArn") saml_provider = iam_backend.get_saml_provider(saml_provider_arn) template = self.response_template(GET_SAML_PROVIDER_TEMPLATE) return template.render(saml_provider=saml_provider) def upload_signing_certificate(self): - user_name = self._get_param('UserName') - cert_body = self._get_param('CertificateBody') + user_name = self._get_param("UserName") + cert_body = self._get_param("CertificateBody") cert = iam_backend.upload_signing_certificate(user_name, cert_body) template = self.response_template(UPLOAD_SIGNING_CERTIFICATE_TEMPLATE) return template.render(cert=cert) def update_signing_certificate(self): - user_name = self._get_param('UserName') - cert_id = self._get_param('CertificateId') - status = self._get_param('Status') + user_name = self._get_param("UserName") + cert_id = self._get_param("CertificateId") + status = self._get_param("Status") iam_backend.update_signing_certificate(user_name, cert_id, status) template = self.response_template(UPDATE_SIGNING_CERTIFICATE_TEMPLATE) return template.render() def delete_signing_certificate(self): - user_name = self._get_param('UserName') - cert_id = self._get_param('CertificateId') + user_name = self._get_param("UserName") + cert_id = self._get_param("CertificateId") iam_backend.delete_signing_certificate(user_name, cert_id) template = self.response_template(DELETE_SIGNING_CERTIFICATE_TEMPLATE) return template.render() def list_signing_certificates(self): - user_name = self._get_param('UserName') + user_name = self._get_param("UserName") certs = iam_backend.list_signing_certificates(user_name) template = self.response_template(LIST_SIGNING_CERTIFICATES_TEMPLATE) return template.render(user_name=user_name, certificates=certs) def list_role_tags(self): - role_name = self._get_param('RoleName') - marker = self._get_param('Marker') - max_items = self._get_param('MaxItems', 100) + role_name = self._get_param("RoleName") + marker = self._get_param("Marker") + max_items = self._get_param("MaxItems", 100) tags, marker = iam_backend.list_role_tags(role_name, marker, max_items) @@ -765,8 +785,8 @@ class IamResponse(BaseResponse): return template.render(tags=tags, marker=marker) def tag_role(self): - role_name = self._get_param('RoleName') - tags = self._get_multi_param('Tags.member') + role_name = self._get_param("RoleName") + tags = self._get_multi_param("Tags.member") iam_backend.tag_role(role_name, tags) @@ -774,8 +794,8 @@ class IamResponse(BaseResponse): return template.render() def untag_role(self): - role_name = self._get_param('RoleName') - tag_keys = self._get_multi_param('TagKeys.member') + role_name = self._get_param("RoleName") + tag_keys = self._get_multi_param("TagKeys.member") iam_backend.untag_role(role_name, tag_keys) @@ -783,17 +803,19 @@ class IamResponse(BaseResponse): return template.render() def create_open_id_connect_provider(self): - open_id_provider_url = self._get_param('Url') - thumbprint_list = self._get_multi_param('ThumbprintList.member') - client_id_list = self._get_multi_param('ClientIDList.member') + open_id_provider_url = self._get_param("Url") + thumbprint_list = self._get_multi_param("ThumbprintList.member") + client_id_list = self._get_multi_param("ClientIDList.member") - open_id_provider = iam_backend.create_open_id_connect_provider(open_id_provider_url, thumbprint_list, client_id_list) + open_id_provider = iam_backend.create_open_id_connect_provider( + open_id_provider_url, thumbprint_list, client_id_list + ) template = self.response_template(CREATE_OPEN_ID_CONNECT_PROVIDER_TEMPLATE) return template.render(open_id_provider=open_id_provider) def delete_open_id_connect_provider(self): - open_id_provider_arn = self._get_param('OpenIDConnectProviderArn') + open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") iam_backend.delete_open_id_connect_provider(open_id_provider_arn) @@ -801,9 +823,11 @@ class IamResponse(BaseResponse): return template.render() def get_open_id_connect_provider(self): - open_id_provider_arn = self._get_param('OpenIDConnectProviderArn') + open_id_provider_arn = self._get_param("OpenIDConnectProviderArn") - open_id_provider = iam_backend.get_open_id_connect_provider(open_id_provider_arn) + open_id_provider = iam_backend.get_open_id_connect_provider( + open_id_provider_arn + ) template = self.response_template(GET_OPEN_ID_CONNECT_PROVIDER_TEMPLATE) return template.render(open_id_provider=open_id_provider) diff --git a/moto/iam/urls.py b/moto/iam/urls.py index 46db41e46..c4ce1d81f 100644 --- a/moto/iam/urls.py +++ b/moto/iam/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import IamResponse -url_bases = [ - "https?://iam(.*).amazonaws.com", -] +url_bases = ["https?://iam(.*).amazonaws.com"] -url_paths = { - '{0}/$': IamResponse.dispatch, -} +url_paths = {"{0}/$": IamResponse.dispatch} diff --git a/moto/iam/utils.py b/moto/iam/utils.py index 2bd6448f9..391f54dbd 100644 --- a/moto/iam/utils.py +++ b/moto/iam/utils.py @@ -5,29 +5,26 @@ import six def random_alphanumeric(length): - return ''.join(six.text_type( - random.choice( - string.ascii_letters + string.digits + "+" + "/" - )) for _ in range(length) + return "".join( + six.text_type(random.choice(string.ascii_letters + string.digits + "+" + "/")) + for _ in range(length) ) def random_resource_id(size=20): chars = list(range(10)) + list(string.ascii_lowercase) - return ''.join(six.text_type(random.choice(chars)) for x in range(size)) + return "".join(six.text_type(random.choice(chars)) for x in range(size)) def random_access_key(): - return ''.join(six.text_type( - random.choice( - string.ascii_uppercase + string.digits - )) for _ in range(16) + return "".join( + six.text_type(random.choice(string.ascii_uppercase + string.digits)) + for _ in range(16) ) def random_policy_id(): - return 'A' + ''.join( - random.choice(string.ascii_uppercase + string.digits) - for _ in range(20) + return "A" + "".join( + random.choice(string.ascii_uppercase + string.digits) for _ in range(20) ) diff --git a/moto/instance_metadata/responses.py b/moto/instance_metadata/responses.py index 460e65aca..81dfd8b59 100644 --- a/moto/instance_metadata/responses.py +++ b/moto/instance_metadata/responses.py @@ -7,7 +7,6 @@ from moto.core.responses import BaseResponse class InstanceMetadataResponse(BaseResponse): - def metadata_response(self, request, full_url, headers): """ Mock response for localhost metadata @@ -21,7 +20,7 @@ class InstanceMetadataResponse(BaseResponse): AccessKeyId="test-key", SecretAccessKey="test-secret-key", Token="test-session-token", - Expiration=tomorrow.strftime("%Y-%m-%dT%H:%M:%SZ") + Expiration=tomorrow.strftime("%Y-%m-%dT%H:%M:%SZ"), ) path = parsed_url.path @@ -29,21 +28,18 @@ class InstanceMetadataResponse(BaseResponse): meta_data_prefix = "/latest/meta-data/" # Strip prefix if it is there if path.startswith(meta_data_prefix): - path = path[len(meta_data_prefix):] + path = path[len(meta_data_prefix) :] - if path == '': - result = 'iam' - elif path == 'iam': - result = json.dumps({ - 'security-credentials': { - 'default-role': credentials - } - }) - elif path == 'iam/security-credentials/': - result = 'default-role' - elif path == 'iam/security-credentials/default-role': + if path == "": + result = "iam" + elif path == "iam": + result = json.dumps({"security-credentials": {"default-role": credentials}}) + elif path == "iam/security-credentials/": + result = "default-role" + elif path == "iam/security-credentials/default-role": result = json.dumps(credentials) else: raise NotImplementedError( - "The {0} metadata path has not been implemented".format(path)) + "The {0} metadata path has not been implemented".format(path) + ) return 200, headers, result diff --git a/moto/instance_metadata/urls.py b/moto/instance_metadata/urls.py index 7776b364a..b77935473 100644 --- a/moto/instance_metadata/urls.py +++ b/moto/instance_metadata/urls.py @@ -1,12 +1,8 @@ from __future__ import unicode_literals from .responses import InstanceMetadataResponse -url_bases = [ - "http://169.254.169.254" -] +url_bases = ["http://169.254.169.254"] instance_metadata = InstanceMetadataResponse() -url_paths = { - '{0}/(?P.+)': instance_metadata.metadata_response, -} +url_paths = {"{0}/(?P.+)": instance_metadata.metadata_response} diff --git a/moto/iot/__init__.py b/moto/iot/__init__.py index 199b8aeae..97d36fbcc 100644 --- a/moto/iot/__init__.py +++ b/moto/iot/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import iot_backends from ..core.models import base_decorator -iot_backend = iot_backends['us-east-1'] +iot_backend = iot_backends["us-east-1"] mock_iot = base_decorator(iot_backends) diff --git a/moto/iot/exceptions.py b/moto/iot/exceptions.py index 3af3751d9..14d577389 100644 --- a/moto/iot/exceptions.py +++ b/moto/iot/exceptions.py @@ -10,8 +10,7 @@ class ResourceNotFoundException(IoTClientError): def __init__(self): self.code = 404 super(ResourceNotFoundException, self).__init__( - "ResourceNotFoundException", - "The specified resource does not exist" + "ResourceNotFoundException", "The specified resource does not exist" ) @@ -19,8 +18,7 @@ class InvalidRequestException(IoTClientError): def __init__(self, msg=None): self.code = 400 super(InvalidRequestException, self).__init__( - "InvalidRequestException", - msg or "The request is not valid." + "InvalidRequestException", msg or "The request is not valid." ) @@ -28,8 +26,8 @@ class VersionConflictException(IoTClientError): def __init__(self, name): self.code = 409 super(VersionConflictException, self).__init__( - 'VersionConflictException', - 'The version for thing %s does not match the expected version.' % name + "VersionConflictException", + "The version for thing %s does not match the expected version." % name, ) @@ -37,14 +35,11 @@ class CertificateStateException(IoTClientError): def __init__(self, msg, cert_id): self.code = 406 super(CertificateStateException, self).__init__( - 'CertificateStateException', - '%s Id: %s' % (msg, cert_id) + "CertificateStateException", "%s Id: %s" % (msg, cert_id) ) class DeleteConflictException(IoTClientError): def __init__(self, msg): self.code = 409 - super(DeleteConflictException, self).__init__( - 'DeleteConflictException', msg - ) + super(DeleteConflictException, self).__init__("DeleteConflictException", msg) diff --git a/moto/iot/models.py b/moto/iot/models.py index a828a05f8..9e520b0fd 100644 --- a/moto/iot/models.py +++ b/moto/iot/models.py @@ -17,7 +17,7 @@ from .exceptions import ( DeleteConflictException, ResourceNotFoundException, InvalidRequestException, - VersionConflictException + VersionConflictException, ) @@ -27,7 +27,7 @@ class FakeThing(BaseModel): self.thing_name = thing_name self.thing_type = thing_type self.attributes = attributes - self.arn = 'arn:aws:iot:%s:1:thing/%s' % (self.region_name, thing_name) + self.arn = "arn:aws:iot:%s:1:thing/%s" % (self.region_name, thing_name) self.version = 1 # TODO: we need to handle 'version'? @@ -36,15 +36,15 @@ class FakeThing(BaseModel): def to_dict(self, include_default_client_id=False): obj = { - 'thingName': self.thing_name, - 'thingArn': self.arn, - 'attributes': self.attributes, - 'version': self.version + "thingName": self.thing_name, + "thingArn": self.arn, + "attributes": self.attributes, + "version": self.version, } if self.thing_type: - obj['thingTypeName'] = self.thing_type.thing_type_name + obj["thingTypeName"] = self.thing_type.thing_type_name if include_default_client_id: - obj['defaultClientId'] = self.thing_name + obj["defaultClientId"] = self.thing_name return obj @@ -55,23 +55,22 @@ class FakeThingType(BaseModel): self.thing_type_properties = thing_type_properties self.thing_type_id = str(uuid.uuid4()) # I don't know the rule of id t = time.time() - self.metadata = { - 'deprecated': False, - 'creationData': int(t * 1000) / 1000.0 - } - self.arn = 'arn:aws:iot:%s:1:thingtype/%s' % (self.region_name, thing_type_name) + self.metadata = {"deprecated": False, "creationData": int(t * 1000) / 1000.0} + self.arn = "arn:aws:iot:%s:1:thingtype/%s" % (self.region_name, thing_type_name) def to_dict(self): return { - 'thingTypeName': self.thing_type_name, - 'thingTypeId': self.thing_type_id, - 'thingTypeProperties': self.thing_type_properties, - 'thingTypeMetadata': self.metadata + "thingTypeName": self.thing_type_name, + "thingTypeId": self.thing_type_id, + "thingTypeProperties": self.thing_type_properties, + "thingTypeMetadata": self.metadata, } class FakeThingGroup(BaseModel): - def __init__(self, thing_group_name, parent_group_name, thing_group_properties, region_name): + def __init__( + self, thing_group_name, parent_group_name, thing_group_properties, region_name + ): self.region_name = region_name self.thing_group_name = thing_group_name self.thing_group_id = str(uuid.uuid4()) # I don't know the rule of id @@ -79,33 +78,34 @@ class FakeThingGroup(BaseModel): self.parent_group_name = parent_group_name self.thing_group_properties = thing_group_properties or {} t = time.time() - self.metadata = { - 'creationData': int(t * 1000) / 1000.0 - } - self.arn = 'arn:aws:iot:%s:1:thinggroup/%s' % (self.region_name, thing_group_name) + self.metadata = {"creationData": int(t * 1000) / 1000.0} + self.arn = "arn:aws:iot:%s:1:thinggroup/%s" % ( + self.region_name, + thing_group_name, + ) self.things = OrderedDict() def to_dict(self): return { - 'thingGroupName': self.thing_group_name, - 'thingGroupId': self.thing_group_id, - 'version': self.version, - 'thingGroupProperties': self.thing_group_properties, - 'thingGroupMetadata': self.metadata + "thingGroupName": self.thing_group_name, + "thingGroupId": self.thing_group_id, + "version": self.version, + "thingGroupProperties": self.thing_group_properties, + "thingGroupMetadata": self.metadata, } class FakeCertificate(BaseModel): def __init__(self, certificate_pem, status, region_name, ca_certificate_pem=None): m = hashlib.sha256() - m.update(str(uuid.uuid4()).encode('utf-8')) + m.update(str(uuid.uuid4()).encode("utf-8")) self.certificate_id = m.hexdigest() - self.arn = 'arn:aws:iot:%s:1:cert/%s' % (region_name, self.certificate_id) + self.arn = "arn:aws:iot:%s:1:cert/%s" % (region_name, self.certificate_id) self.certificate_pem = certificate_pem self.status = status # TODO: must adjust - self.owner = '1' + self.owner = "1" self.transfer_data = {} self.creation_date = time.time() self.last_modified_date = self.creation_date @@ -113,16 +113,16 @@ class FakeCertificate(BaseModel): self.ca_certificate_id = None self.ca_certificate_pem = ca_certificate_pem if ca_certificate_pem: - m.update(str(uuid.uuid4()).encode('utf-8')) + m.update(str(uuid.uuid4()).encode("utf-8")) self.ca_certificate_id = m.hexdigest() def to_dict(self): return { - 'certificateArn': self.arn, - 'certificateId': self.certificate_id, - 'caCertificateId': self.ca_certificate_id, - 'status': self.status, - 'creationDate': self.creation_date + "certificateArn": self.arn, + "certificateId": self.certificate_id, + "caCertificateId": self.ca_certificate_id, + "status": self.status, + "creationDate": self.creation_date, } def to_description_dict(self): @@ -132,14 +132,14 @@ class FakeCertificate(BaseModel): - previousOwnedBy """ return { - 'certificateArn': self.arn, - 'certificateId': self.certificate_id, - 'status': self.status, - 'certificatePem': self.certificate_pem, - 'ownedBy': self.owner, - 'creationDate': self.creation_date, - 'lastModifiedDate': self.last_modified_date, - 'transferData': self.transfer_data + "certificateArn": self.arn, + "certificateId": self.certificate_id, + "status": self.status, + "certificatePem": self.certificate_pem, + "ownedBy": self.owner, + "creationDate": self.creation_date, + "lastModifiedDate": self.last_modified_date, + "transferData": self.transfer_data, } @@ -147,44 +147,52 @@ class FakePolicy(BaseModel): def __init__(self, name, document, region_name): self.name = name self.document = document - self.arn = 'arn:aws:iot:%s:1:policy/%s' % (region_name, name) - self.version = '1' # TODO: handle version + self.arn = "arn:aws:iot:%s:1:policy/%s" % (region_name, name) + self.version = "1" # TODO: handle version def to_get_dict(self): return { - 'policyName': self.name, - 'policyArn': self.arn, - 'policyDocument': self.document, - 'defaultVersionId': self.version + "policyName": self.name, + "policyArn": self.arn, + "policyDocument": self.document, + "defaultVersionId": self.version, } def to_dict_at_creation(self): return { - 'policyName': self.name, - 'policyArn': self.arn, - 'policyDocument': self.document, - 'policyVersionId': self.version + "policyName": self.name, + "policyArn": self.arn, + "policyDocument": self.document, + "policyVersionId": self.version, } def to_dict(self): - return { - 'policyName': self.name, - 'policyArn': self.arn, - } + return {"policyName": self.name, "policyArn": self.arn} class FakeJob(BaseModel): JOB_ID_REGEX_PATTERN = "[a-zA-Z0-9_-]" JOB_ID_REGEX = re.compile(JOB_ID_REGEX_PATTERN) - def __init__(self, job_id, targets, document_source, document, description, presigned_url_config, target_selection, - job_executions_rollout_config, document_parameters, region_name): + def __init__( + self, + job_id, + targets, + document_source, + document, + description, + presigned_url_config, + target_selection, + job_executions_rollout_config, + document_parameters, + region_name, + ): if not self._job_id_matcher(self.JOB_ID_REGEX, job_id): raise InvalidRequestException() self.region_name = region_name self.job_id = job_id - self.job_arn = 'arn:aws:iot:%s:1:job/%s' % (self.region_name, job_id) + self.job_arn = "arn:aws:iot:%s:1:job/%s" % (self.region_name, job_id) self.targets = targets self.document_source = document_source self.document = document @@ -198,35 +206,35 @@ class FakeJob(BaseModel): self.last_updated_at = time.mktime(datetime(2015, 1, 1).timetuple()) self.completed_at = None self.job_process_details = { - 'processingTargets': targets, - 'numberOfQueuedThings': 1, - 'numberOfCanceledThings': 0, - 'numberOfSucceededThings': 0, - 'numberOfFailedThings': 0, - 'numberOfRejectedThings': 0, - 'numberOfInProgressThings': 0, - 'numberOfRemovedThings': 0 + "processingTargets": targets, + "numberOfQueuedThings": 1, + "numberOfCanceledThings": 0, + "numberOfSucceededThings": 0, + "numberOfFailedThings": 0, + "numberOfRejectedThings": 0, + "numberOfInProgressThings": 0, + "numberOfRemovedThings": 0, } self.document_parameters = document_parameters def to_dict(self): obj = { - 'jobArn': self.job_arn, - 'jobId': self.job_id, - 'targets': self.targets, - 'description': self.description, - 'presignedUrlConfig': self.presigned_url_config, - 'targetSelection': self.target_selection, - 'jobExecutionsRolloutConfig': self.job_executions_rollout_config, - 'status': self.status, - 'comment': self.comment, - 'createdAt': self.created_at, - 'lastUpdatedAt': self.last_updated_at, - 'completedAt': self.completedAt, - 'jobProcessDetails': self.job_process_details, - 'documentParameters': self.document_parameters, - 'document': self.document, - 'documentSource': self.document_source + "jobArn": self.job_arn, + "jobId": self.job_id, + "targets": self.targets, + "description": self.description, + "presignedUrlConfig": self.presigned_url_config, + "targetSelection": self.target_selection, + "jobExecutionsRolloutConfig": self.job_executions_rollout_config, + "status": self.status, + "comment": self.comment, + "createdAt": self.created_at, + "lastUpdatedAt": self.last_updated_at, + "completedAt": self.completedAt, + "jobProcessDetails": self.job_process_details, + "documentParameters": self.document_parameters, + "document": self.document, + "documentSource": self.document_source, } return obj @@ -259,16 +267,18 @@ class IoTBackend(BaseBackend): thing_types = self.list_thing_types() thing_type = None if thing_type_name: - filtered_thing_types = [_ for _ in thing_types if _.thing_type_name == thing_type_name] + filtered_thing_types = [ + _ for _ in thing_types if _.thing_type_name == thing_type_name + ] if len(filtered_thing_types) == 0: raise ResourceNotFoundException() thing_type = filtered_thing_types[0] if attribute_payload is None: attributes = {} - elif 'attributes' not in attribute_payload: + elif "attributes" not in attribute_payload: attributes = {} else: - attributes = attribute_payload['attributes'] + attributes = attribute_payload["attributes"] thing = FakeThing(thing_name, thing_type, attributes, self.region_name) self.things[thing.arn] = thing return thing.thing_name, thing.arn @@ -276,41 +286,68 @@ class IoTBackend(BaseBackend): def create_thing_type(self, thing_type_name, thing_type_properties): if thing_type_properties is None: thing_type_properties = {} - thing_type = FakeThingType(thing_type_name, thing_type_properties, self.region_name) + thing_type = FakeThingType( + thing_type_name, thing_type_properties, self.region_name + ) self.thing_types[thing_type.arn] = thing_type return thing_type.thing_type_name, thing_type.arn def list_thing_types(self, thing_type_name=None): if thing_type_name: # It's weird but thing_type_name is filtered by forward match, not complete match - return [_ for _ in self.thing_types.values() if _.thing_type_name.startswith(thing_type_name)] + return [ + _ + for _ in self.thing_types.values() + if _.thing_type_name.startswith(thing_type_name) + ] return self.thing_types.values() - def list_things(self, attribute_name, attribute_value, thing_type_name, max_results, token): + def list_things( + self, attribute_name, attribute_value, thing_type_name, max_results, token + ): all_things = [_.to_dict() for _ in self.things.values()] if attribute_name is not None and thing_type_name is not None: - filtered_things = list(filter(lambda elem: - attribute_name in elem["attributes"] and - elem["attributes"][attribute_name] == attribute_value and - "thingTypeName" in elem and - elem["thingTypeName"] == thing_type_name, all_things)) + filtered_things = list( + filter( + lambda elem: attribute_name in elem["attributes"] + and elem["attributes"][attribute_name] == attribute_value + and "thingTypeName" in elem + and elem["thingTypeName"] == thing_type_name, + all_things, + ) + ) elif attribute_name is not None and thing_type_name is None: - filtered_things = list(filter(lambda elem: - attribute_name in elem["attributes"] and - elem["attributes"][attribute_name] == attribute_value, all_things)) + filtered_things = list( + filter( + lambda elem: attribute_name in elem["attributes"] + and elem["attributes"][attribute_name] == attribute_value, + all_things, + ) + ) elif attribute_name is None and thing_type_name is not None: filtered_things = list( - filter(lambda elem: "thingTypeName" in elem and elem["thingTypeName"] == thing_type_name, all_things)) + filter( + lambda elem: "thingTypeName" in elem + and elem["thingTypeName"] == thing_type_name, + all_things, + ) + ) else: filtered_things = all_things if token is None: things = filtered_things[0:max_results] - next_token = str(max_results) if len(filtered_things) > max_results else None + next_token = ( + str(max_results) if len(filtered_things) > max_results else None + ) else: token = int(token) - things = filtered_things[token:token + max_results] - next_token = str(token + max_results) if len(filtered_things) > token + max_results else None + things = filtered_things[token : token + max_results] + next_token = ( + str(token + max_results) + if len(filtered_things) > token + max_results + else None + ) return things, next_token @@ -321,7 +358,9 @@ class IoTBackend(BaseBackend): return things[0] def describe_thing_type(self, thing_type_name): - thing_types = [_ for _ in self.thing_types.values() if _.thing_type_name == thing_type_name] + thing_types = [ + _ for _ in self.thing_types.values() if _.thing_type_name == thing_type_name + ] if len(thing_types) == 0: raise ResourceNotFoundException() return thing_types[0] @@ -344,7 +383,14 @@ class IoTBackend(BaseBackend): thing_type = self.describe_thing_type(thing_type_name) del self.thing_types[thing_type.arn] - def update_thing(self, thing_name, thing_type_name, attribute_payload, expected_version, remove_thing_type): + def update_thing( + self, + thing_name, + thing_type_name, + attribute_payload, + expected_version, + remove_thing_type, + ): # if attributes payload = {}, nothing thing = self.describe_thing(thing_name) thing_type = None @@ -355,7 +401,9 @@ class IoTBackend(BaseBackend): # thing_type if thing_type_name: thing_types = self.list_thing_types() - filtered_thing_types = [_ for _ in thing_types if _.thing_type_name == thing_type_name] + filtered_thing_types = [ + _ for _ in thing_types if _.thing_type_name == thing_type_name + ] if len(filtered_thing_types) == 0: raise ResourceNotFoundException() thing_type = filtered_thing_types[0] @@ -365,9 +413,9 @@ class IoTBackend(BaseBackend): thing.thing_type = None # attribute - if attribute_payload is not None and 'attributes' in attribute_payload: - do_merge = attribute_payload.get('merge', False) - attributes = attribute_payload['attributes'] + if attribute_payload is not None and "attributes" in attribute_payload: + do_merge = attribute_payload.get("merge", False) + attributes = attribute_payload["attributes"] if not do_merge: thing.attributes = attributes else: @@ -375,46 +423,59 @@ class IoTBackend(BaseBackend): def _random_string(self): n = 20 - random_str = ''.join([random.choice(string.ascii_letters + string.digits) for i in range(n)]) + random_str = "".join( + [random.choice(string.ascii_letters + string.digits) for i in range(n)] + ) return random_str def create_keys_and_certificate(self, set_as_active): # implement here # caCertificate can be blank key_pair = { - 'PublicKey': self._random_string(), - 'PrivateKey': self._random_string() + "PublicKey": self._random_string(), + "PrivateKey": self._random_string(), } certificate_pem = self._random_string() - status = 'ACTIVE' if set_as_active else 'INACTIVE' + status = "ACTIVE" if set_as_active else "INACTIVE" certificate = FakeCertificate(certificate_pem, status, self.region_name) self.certificates[certificate.certificate_id] = certificate return certificate, key_pair def delete_certificate(self, certificate_id): cert = self.describe_certificate(certificate_id) - if cert.status == 'ACTIVE': + if cert.status == "ACTIVE": raise CertificateStateException( - 'Certificate must be deactivated (not ACTIVE) before deletion.', certificate_id) - - certs = [k[0] for k, v in self.principal_things.items() - if self._get_principal(k[0]).certificate_id == certificate_id] - if len(certs) > 0: - raise DeleteConflictException( - 'Things must be detached before deletion (arn: %s)' % certs[0] + "Certificate must be deactivated (not ACTIVE) before deletion.", + certificate_id, ) - certs = [k[0] for k, v in self.principal_policies.items() - if self._get_principal(k[0]).certificate_id == certificate_id] + certs = [ + k[0] + for k, v in self.principal_things.items() + if self._get_principal(k[0]).certificate_id == certificate_id + ] if len(certs) > 0: raise DeleteConflictException( - 'Certificate policies must be detached before deletion (arn: %s)' % certs[0] + "Things must be detached before deletion (arn: %s)" % certs[0] + ) + + certs = [ + k[0] + for k, v in self.principal_policies.items() + if self._get_principal(k[0]).certificate_id == certificate_id + ] + if len(certs) > 0: + raise DeleteConflictException( + "Certificate policies must be detached before deletion (arn: %s)" + % certs[0] ) del self.certificates[certificate_id] def describe_certificate(self, certificate_id): - certs = [_ for _ in self.certificates.values() if _.certificate_id == certificate_id] + certs = [ + _ for _ in self.certificates.values() if _.certificate_id == certificate_id + ] if len(certs) == 0: raise ResourceNotFoundException() return certs[0] @@ -422,9 +483,15 @@ class IoTBackend(BaseBackend): def list_certificates(self): return self.certificates.values() - def register_certificate(self, certificate_pem, ca_certificate_pem, set_as_active, status): - certificate = FakeCertificate(certificate_pem, 'ACTIVE' if set_as_active else status, - self.region_name, ca_certificate_pem) + def register_certificate( + self, certificate_pem, ca_certificate_pem, set_as_active, status + ): + certificate = FakeCertificate( + certificate_pem, + "ACTIVE" if set_as_active else status, + self.region_name, + ca_certificate_pem, + ) self.certificates[certificate.certificate_id] = certificate return certificate @@ -450,10 +517,12 @@ class IoTBackend(BaseBackend): def delete_policy(self, policy_name): - policies = [k[1] for k, v in self.principal_policies.items() if k[1] == policy_name] + policies = [ + k[1] for k, v in self.principal_policies.items() if k[1] == policy_name + ] if len(policies) > 0: raise DeleteConflictException( - 'The policy cannot be deleted as the policy is attached to one or more principals (name=%s)' + "The policy cannot be deleted as the policy is attached to one or more principals (name=%s)" % policy_name ) @@ -464,7 +533,7 @@ class IoTBackend(BaseBackend): """ raise ResourceNotFoundException """ - if ':cert/' in principal_arn: + if ":cert/" in principal_arn: certs = [_ for _ in self.certificates.values() if _.arn == principal_arn] if len(certs) == 0: raise ResourceNotFoundException() @@ -511,11 +580,15 @@ class IoTBackend(BaseBackend): del self.principal_policies[k] def list_principal_policies(self, principal_arn): - policies = [v[1] for k, v in self.principal_policies.items() if k[0] == principal_arn] + policies = [ + v[1] for k, v in self.principal_policies.items() if k[0] == principal_arn + ] return policies def list_policy_principals(self, policy_name): - principals = [k[0] for k, v in self.principal_policies.items() if k[1] == policy_name] + principals = [ + k[0] for k, v in self.principal_policies.items() if k[1] == policy_name + ] return principals def attach_thing_principal(self, thing_name, principal_arn): @@ -537,21 +610,36 @@ class IoTBackend(BaseBackend): del self.principal_things[k] def list_principal_things(self, principal_arn): - thing_names = [k[0] for k, v in self.principal_things.items() if k[0] == principal_arn] + thing_names = [ + k[0] for k, v in self.principal_things.items() if k[0] == principal_arn + ] return thing_names def list_thing_principals(self, thing_name): - principals = [k[0] for k, v in self.principal_things.items() if k[1] == thing_name] + principals = [ + k[0] for k, v in self.principal_things.items() if k[1] == thing_name + ] return principals def describe_thing_group(self, thing_group_name): - thing_groups = [_ for _ in self.thing_groups.values() if _.thing_group_name == thing_group_name] + thing_groups = [ + _ + for _ in self.thing_groups.values() + if _.thing_group_name == thing_group_name + ] if len(thing_groups) == 0: raise ResourceNotFoundException() return thing_groups[0] - def create_thing_group(self, thing_group_name, parent_group_name, thing_group_properties): - thing_group = FakeThingGroup(thing_group_name, parent_group_name, thing_group_properties, self.region_name) + def create_thing_group( + self, thing_group_name, parent_group_name, thing_group_properties + ): + thing_group = FakeThingGroup( + thing_group_name, + parent_group_name, + thing_group_properties, + self.region_name, + ) self.thing_groups[thing_group.arn] = thing_group return thing_group.thing_group_name, thing_group.arn, thing_group.thing_group_id @@ -563,19 +651,25 @@ class IoTBackend(BaseBackend): thing_groups = self.thing_groups.values() return thing_groups - def update_thing_group(self, thing_group_name, thing_group_properties, expected_version): + def update_thing_group( + self, thing_group_name, thing_group_properties, expected_version + ): thing_group = self.describe_thing_group(thing_group_name) if expected_version and expected_version != thing_group.version: raise VersionConflictException(thing_group_name) - attribute_payload = thing_group_properties.get('attributePayload', None) - if attribute_payload is not None and 'attributes' in attribute_payload: - do_merge = attribute_payload.get('merge', False) - attributes = attribute_payload['attributes'] + attribute_payload = thing_group_properties.get("attributePayload", None) + if attribute_payload is not None and "attributes" in attribute_payload: + do_merge = attribute_payload.get("merge", False) + attributes = attribute_payload["attributes"] if not do_merge: - thing_group.thing_group_properties['attributePayload']['attributes'] = attributes + thing_group.thing_group_properties["attributePayload"][ + "attributes" + ] = attributes else: - thing_group.thing_group_properties['attributePayload']['attributes'].update(attributes) - elif attribute_payload is not None and 'attributes' not in attribute_payload: + thing_group.thing_group_properties["attributePayload"][ + "attributes" + ].update(attributes) + elif attribute_payload is not None and "attributes" not in attribute_payload: thing_group.attributes = {} thing_group.version = thing_group.version + 1 return thing_group.version @@ -584,13 +678,13 @@ class IoTBackend(BaseBackend): # identify thing group if thing_group_name is None and thing_group_arn is None: raise InvalidRequestException( - ' Both thingGroupArn and thingGroupName are empty. Need to specify at least one of them' + " Both thingGroupArn and thingGroupName are empty. Need to specify at least one of them" ) if thing_group_name is not None: thing_group = self.describe_thing_group(thing_group_name) if thing_group_arn and thing_group.arn != thing_group_arn: raise InvalidRequestException( - 'ThingGroupName thingGroupArn does not match specified thingGroupName in request' + "ThingGroupName thingGroupArn does not match specified thingGroupName in request" ) elif thing_group_arn is not None: if thing_group_arn not in self.thing_groups: @@ -602,13 +696,13 @@ class IoTBackend(BaseBackend): # identify thing if thing_name is None and thing_arn is None: raise InvalidRequestException( - 'Both thingArn and thingName are empty. Need to specify at least one of them' + "Both thingArn and thingName are empty. Need to specify at least one of them" ) if thing_name is not None: thing = self.describe_thing(thing_name) if thing_arn and thing.arn != thing_arn: raise InvalidRequestException( - 'ThingName thingArn does not match specified thingName in request' + "ThingName thingArn does not match specified thingName in request" ) elif thing_arn is not None: if thing_arn not in self.things: @@ -616,7 +710,9 @@ class IoTBackend(BaseBackend): thing = self.things[thing_arn] return thing - def add_thing_to_thing_group(self, thing_group_name, thing_group_arn, thing_name, thing_arn): + def add_thing_to_thing_group( + self, thing_group_name, thing_group_arn, thing_name, thing_arn + ): thing_group = self._identify_thing_group(thing_group_name, thing_group_arn) thing = self._identify_thing(thing_name, thing_arn) if thing.arn in thing_group.things: @@ -624,7 +720,9 @@ class IoTBackend(BaseBackend): return thing_group.things[thing.arn] = thing - def remove_thing_from_thing_group(self, thing_group_name, thing_group_arn, thing_name, thing_arn): + def remove_thing_from_thing_group( + self, thing_group_name, thing_group_arn, thing_name, thing_arn + ): thing_group = self._identify_thing_group(thing_group_name, thing_group_arn) thing = self._identify_thing(thing_name, thing_arn) if thing.arn not in thing_group.things: @@ -642,31 +740,53 @@ class IoTBackend(BaseBackend): ret = [] for thing_group in all_thing_groups: if thing.arn in thing_group.things: - ret.append({ - 'groupName': thing_group.thing_group_name, - 'groupArn': thing_group.arn - }) + ret.append( + { + "groupName": thing_group.thing_group_name, + "groupArn": thing_group.arn, + } + ) return ret - def update_thing_groups_for_thing(self, thing_name, thing_groups_to_add, thing_groups_to_remove): + def update_thing_groups_for_thing( + self, thing_name, thing_groups_to_add, thing_groups_to_remove + ): thing = self.describe_thing(thing_name) for thing_group_name in thing_groups_to_add: thing_group = self.describe_thing_group(thing_group_name) self.add_thing_to_thing_group( - thing_group.thing_group_name, None, - thing.thing_name, None + thing_group.thing_group_name, None, thing.thing_name, None ) for thing_group_name in thing_groups_to_remove: thing_group = self.describe_thing_group(thing_group_name) self.remove_thing_from_thing_group( - thing_group.thing_group_name, None, - thing.thing_name, None + thing_group.thing_group_name, None, thing.thing_name, None ) - def create_job(self, job_id, targets, document_source, document, description, presigned_url_config, - target_selection, job_executions_rollout_config, document_parameters): - job = FakeJob(job_id, targets, document_source, document, description, presigned_url_config, target_selection, - job_executions_rollout_config, document_parameters, self.region_name) + def create_job( + self, + job_id, + targets, + document_source, + document, + description, + presigned_url_config, + target_selection, + job_executions_rollout_config, + document_parameters, + ): + job = FakeJob( + job_id, + targets, + document_source, + document, + description, + presigned_url_config, + target_selection, + job_executions_rollout_config, + document_parameters, + self.region_name, + ) self.jobs[job_id] = job return job.job_arn, job_id, description diff --git a/moto/iot/responses.py b/moto/iot/responses.py index 3821c1c79..5981eaa37 100644 --- a/moto/iot/responses.py +++ b/moto/iot/responses.py @@ -7,7 +7,7 @@ from .models import iot_backends class IoTResponse(BaseResponse): - SERVICE_NAME = 'iot' + SERVICE_NAME = "iot" @property def iot_backend(self): @@ -28,18 +28,19 @@ class IoTResponse(BaseResponse): thing_type_name = self._get_param("thingTypeName") thing_type_properties = self._get_param("thingTypeProperties") thing_type_name, thing_type_arn = self.iot_backend.create_thing_type( - thing_type_name=thing_type_name, - thing_type_properties=thing_type_properties, + thing_type_name=thing_type_name, thing_type_properties=thing_type_properties + ) + return json.dumps( + dict(thingTypeName=thing_type_name, thingTypeArn=thing_type_arn) ) - return json.dumps(dict(thingTypeName=thing_type_name, thingTypeArn=thing_type_arn)) def list_thing_types(self): previous_next_token = self._get_param("nextToken") - max_results = self._get_int_param("maxResults", 50) # not the default, but makes testing easier + max_results = self._get_int_param( + "maxResults", 50 + ) # not the default, but makes testing easier thing_type_name = self._get_param("thingTypeName") - thing_types = self.iot_backend.list_thing_types( - thing_type_name=thing_type_name - ) + thing_types = self.iot_backend.list_thing_types(thing_type_name=thing_type_name) thing_types = [_.to_dict() for _ in thing_types] if previous_next_token is None: @@ -47,14 +48,20 @@ class IoTResponse(BaseResponse): next_token = str(max_results) if len(thing_types) > max_results else None else: token = int(previous_next_token) - result = thing_types[token:token + max_results] - next_token = str(token + max_results) if len(thing_types) > token + max_results else None + result = thing_types[token : token + max_results] + next_token = ( + str(token + max_results) + if len(thing_types) > token + max_results + else None + ) return json.dumps(dict(thingTypes=result, nextToken=next_token)) def list_things(self): previous_next_token = self._get_param("nextToken") - max_results = self._get_int_param("maxResults", 50) # not the default, but makes testing easier + max_results = self._get_int_param( + "maxResults", 50 + ) # not the default, but makes testing easier attribute_name = self._get_param("attributeName") attribute_value = self._get_param("attributeValue") thing_type_name = self._get_param("thingTypeName") @@ -63,22 +70,20 @@ class IoTResponse(BaseResponse): attribute_value=attribute_value, thing_type_name=thing_type_name, max_results=max_results, - token=previous_next_token + token=previous_next_token, ) return json.dumps(dict(things=things, nextToken=next_token)) def describe_thing(self): thing_name = self._get_param("thingName") - thing = self.iot_backend.describe_thing( - thing_name=thing_name, - ) + thing = self.iot_backend.describe_thing(thing_name=thing_name) return json.dumps(thing.to_dict(include_default_client_id=True)) def describe_thing_type(self): thing_type_name = self._get_param("thingTypeName") thing_type = self.iot_backend.describe_thing_type( - thing_type_name=thing_type_name, + thing_type_name=thing_type_name ) return json.dumps(thing_type.to_dict()) @@ -86,16 +91,13 @@ class IoTResponse(BaseResponse): thing_name = self._get_param("thingName") expected_version = self._get_param("expectedVersion") self.iot_backend.delete_thing( - thing_name=thing_name, - expected_version=expected_version, + thing_name=thing_name, expected_version=expected_version ) return json.dumps(dict()) def delete_thing_type(self): thing_type_name = self._get_param("thingTypeName") - self.iot_backend.delete_thing_type( - thing_type_name=thing_type_name, - ) + self.iot_backend.delete_thing_type(thing_type_name=thing_type_name) return json.dumps(dict()) def update_thing(self): @@ -123,57 +125,62 @@ class IoTResponse(BaseResponse): presigned_url_config=self._get_param("presignedUrlConfig"), target_selection=self._get_param("targetSelection"), job_executions_rollout_config=self._get_param("jobExecutionsRolloutConfig"), - document_parameters=self._get_param("documentParameters") + document_parameters=self._get_param("documentParameters"), ) return json.dumps(dict(jobArn=job_arn, jobId=job_id, description=description)) def describe_job(self): job = self.iot_backend.describe_job(job_id=self._get_param("jobId")) - return json.dumps(dict( - documentSource=job.document_source, - job=dict( - comment=job.comment, - completedAt=job.completed_at, - createdAt=job.created_at, - description=job.description, - documentParameters=job.document_parameters, - jobArn=job.job_arn, - jobExecutionsRolloutConfig=job.job_executions_rollout_config, - jobId=job.job_id, - jobProcessDetails=job.job_process_details, - lastUpdatedAt=job.last_updated_at, - presignedUrlConfig=job.presigned_url_config, - status=job.status, - targets=job.targets, - targetSelection=job.target_selection - ))) + return json.dumps( + dict( + documentSource=job.document_source, + job=dict( + comment=job.comment, + completedAt=job.completed_at, + createdAt=job.created_at, + description=job.description, + documentParameters=job.document_parameters, + jobArn=job.job_arn, + jobExecutionsRolloutConfig=job.job_executions_rollout_config, + jobId=job.job_id, + jobProcessDetails=job.job_process_details, + lastUpdatedAt=job.last_updated_at, + presignedUrlConfig=job.presigned_url_config, + status=job.status, + targets=job.targets, + targetSelection=job.target_selection, + ), + ) + ) def create_keys_and_certificate(self): set_as_active = self._get_bool_param("setAsActive") cert, key_pair = self.iot_backend.create_keys_and_certificate( - set_as_active=set_as_active, + set_as_active=set_as_active + ) + return json.dumps( + dict( + certificateArn=cert.arn, + certificateId=cert.certificate_id, + certificatePem=cert.certificate_pem, + keyPair=key_pair, + ) ) - return json.dumps(dict( - certificateArn=cert.arn, - certificateId=cert.certificate_id, - certificatePem=cert.certificate_pem, - keyPair=key_pair - )) def delete_certificate(self): certificate_id = self._get_param("certificateId") - self.iot_backend.delete_certificate( - certificate_id=certificate_id, - ) + self.iot_backend.delete_certificate(certificate_id=certificate_id) return json.dumps(dict()) def describe_certificate(self): certificate_id = self._get_param("certificateId") certificate = self.iot_backend.describe_certificate( - certificate_id=certificate_id, + certificate_id=certificate_id + ) + return json.dumps( + dict(certificateDescription=certificate.to_description_dict()) ) - return json.dumps(dict(certificateDescription=certificate.to_description_dict())) def list_certificates(self): # page_size = self._get_int_param("pageSize") @@ -193,16 +200,17 @@ class IoTResponse(BaseResponse): certificate_pem=certificate_pem, ca_certificate_pem=ca_certificate_pem, set_as_active=set_as_active, - status=status + status=status, + ) + return json.dumps( + dict(certificateId=cert.certificate_id, certificateArn=cert.arn) ) - return json.dumps(dict(certificateId=cert.certificate_id, certificateArn=cert.arn)) def update_certificate(self): certificate_id = self._get_param("certificateId") new_status = self._get_param("newStatus") self.iot_backend.update_certificate( - certificate_id=certificate_id, - new_status=new_status, + certificate_id=certificate_id, new_status=new_status ) return json.dumps(dict()) @@ -210,8 +218,7 @@ class IoTResponse(BaseResponse): policy_name = self._get_param("policyName") policy_document = self._get_param("policyDocument") policy = self.iot_backend.create_policy( - policy_name=policy_name, - policy_document=policy_document, + policy_name=policy_name, policy_document=policy_document ) return json.dumps(policy.to_dict_at_creation()) @@ -226,118 +233,98 @@ class IoTResponse(BaseResponse): def get_policy(self): policy_name = self._get_param("policyName") - policy = self.iot_backend.get_policy( - policy_name=policy_name, - ) + policy = self.iot_backend.get_policy(policy_name=policy_name) return json.dumps(policy.to_get_dict()) def delete_policy(self): policy_name = self._get_param("policyName") - self.iot_backend.delete_policy( - policy_name=policy_name, - ) + self.iot_backend.delete_policy(policy_name=policy_name) return json.dumps(dict()) def attach_policy(self): policy_name = self._get_param("policyName") - target = self._get_param('target') - self.iot_backend.attach_policy( - policy_name=policy_name, - target=target, - ) + target = self._get_param("target") + self.iot_backend.attach_policy(policy_name=policy_name, target=target) return json.dumps(dict()) def attach_principal_policy(self): policy_name = self._get_param("policyName") - principal = self.headers.get('x-amzn-iot-principal') + principal = self.headers.get("x-amzn-iot-principal") self.iot_backend.attach_principal_policy( - policy_name=policy_name, - principal_arn=principal, + policy_name=policy_name, principal_arn=principal ) return json.dumps(dict()) def detach_policy(self): policy_name = self._get_param("policyName") - target = self._get_param('target') - self.iot_backend.detach_policy( - policy_name=policy_name, - target=target, - ) + target = self._get_param("target") + self.iot_backend.detach_policy(policy_name=policy_name, target=target) return json.dumps(dict()) def detach_principal_policy(self): policy_name = self._get_param("policyName") - principal = self.headers.get('x-amzn-iot-principal') + principal = self.headers.get("x-amzn-iot-principal") self.iot_backend.detach_principal_policy( - policy_name=policy_name, - principal_arn=principal, + policy_name=policy_name, principal_arn=principal ) return json.dumps(dict()) def list_principal_policies(self): - principal = self.headers.get('x-amzn-iot-principal') + principal = self.headers.get("x-amzn-iot-principal") # marker = self._get_param("marker") # page_size = self._get_int_param("pageSize") # ascending_order = self._get_param("ascendingOrder") - policies = self.iot_backend.list_principal_policies( - principal_arn=principal - ) + policies = self.iot_backend.list_principal_policies(principal_arn=principal) # TODO: implement pagination in the future next_marker = None - return json.dumps(dict(policies=[_.to_dict() for _ in policies], nextMarker=next_marker)) + return json.dumps( + dict(policies=[_.to_dict() for _ in policies], nextMarker=next_marker) + ) def list_policy_principals(self): - policy_name = self.headers.get('x-amzn-iot-policy') + policy_name = self.headers.get("x-amzn-iot-policy") # marker = self._get_param("marker") # page_size = self._get_int_param("pageSize") # ascending_order = self._get_param("ascendingOrder") - principals = self.iot_backend.list_policy_principals( - policy_name=policy_name, - ) + principals = self.iot_backend.list_policy_principals(policy_name=policy_name) # TODO: implement pagination in the future next_marker = None return json.dumps(dict(principals=principals, nextMarker=next_marker)) def attach_thing_principal(self): thing_name = self._get_param("thingName") - principal = self.headers.get('x-amzn-principal') + principal = self.headers.get("x-amzn-principal") self.iot_backend.attach_thing_principal( - thing_name=thing_name, - principal_arn=principal, + thing_name=thing_name, principal_arn=principal ) return json.dumps(dict()) def detach_thing_principal(self): thing_name = self._get_param("thingName") - principal = self.headers.get('x-amzn-principal') + principal = self.headers.get("x-amzn-principal") self.iot_backend.detach_thing_principal( - thing_name=thing_name, - principal_arn=principal, + thing_name=thing_name, principal_arn=principal ) return json.dumps(dict()) def list_principal_things(self): next_token = self._get_param("nextToken") # max_results = self._get_int_param("maxResults") - principal = self.headers.get('x-amzn-principal') - things = self.iot_backend.list_principal_things( - principal_arn=principal, - ) + principal = self.headers.get("x-amzn-principal") + things = self.iot_backend.list_principal_things(principal_arn=principal) # TODO: implement pagination in the future next_token = None return json.dumps(dict(things=things, nextToken=next_token)) def list_thing_principals(self): thing_name = self._get_param("thingName") - principals = self.iot_backend.list_thing_principals( - thing_name=thing_name, - ) + principals = self.iot_backend.list_thing_principals(thing_name=thing_name) return json.dumps(dict(principals=principals)) def describe_thing_group(self): thing_group_name = self._get_param("thingGroupName") thing_group = self.iot_backend.describe_thing_group( - thing_group_name=thing_group_name, + thing_group_name=thing_group_name ) return json.dumps(thing_group.to_dict()) @@ -345,23 +332,28 @@ class IoTResponse(BaseResponse): thing_group_name = self._get_param("thingGroupName") parent_group_name = self._get_param("parentGroupName") thing_group_properties = self._get_param("thingGroupProperties") - thing_group_name, thing_group_arn, thing_group_id = self.iot_backend.create_thing_group( + ( + thing_group_name, + thing_group_arn, + thing_group_id, + ) = self.iot_backend.create_thing_group( thing_group_name=thing_group_name, parent_group_name=parent_group_name, thing_group_properties=thing_group_properties, ) - return json.dumps(dict( - thingGroupName=thing_group_name, - thingGroupArn=thing_group_arn, - thingGroupId=thing_group_id) + return json.dumps( + dict( + thingGroupName=thing_group_name, + thingGroupArn=thing_group_arn, + thingGroupId=thing_group_id, + ) ) def delete_thing_group(self): thing_group_name = self._get_param("thingGroupName") expected_version = self._get_param("expectedVersion") self.iot_backend.delete_thing_group( - thing_group_name=thing_group_name, - expected_version=expected_version, + thing_group_name=thing_group_name, expected_version=expected_version ) return json.dumps(dict()) @@ -377,7 +369,9 @@ class IoTResponse(BaseResponse): recursive=recursive, ) next_token = None - rets = [{'groupName': _.thing_group_name, 'groupArn': _.arn} for _ in thing_groups] + rets = [ + {"groupName": _.thing_group_name, "groupArn": _.arn} for _ in thing_groups + ] # TODO: implement pagination in the future return json.dumps(dict(thingGroups=rets, nextToken=next_token)) @@ -424,8 +418,7 @@ class IoTResponse(BaseResponse): # next_token = self._get_param("nextToken") # max_results = self._get_int_param("maxResults") things = self.iot_backend.list_things_in_thing_group( - thing_group_name=thing_group_name, - recursive=recursive, + thing_group_name=thing_group_name, recursive=recursive ) next_token = None thing_names = [_.thing_name for _ in things] diff --git a/moto/iot/urls.py b/moto/iot/urls.py index 6d11c15a5..2ad908714 100644 --- a/moto/iot/urls.py +++ b/moto/iot/urls.py @@ -1,14 +1,10 @@ from __future__ import unicode_literals from .responses import IoTResponse -url_bases = [ - "https?://iot.(.+).amazonaws.com", -] +url_bases = ["https?://iot.(.+).amazonaws.com"] response = IoTResponse() -url_paths = { - '{0}/.*$': response.dispatch, -} +url_paths = {"{0}/.*$": response.dispatch} diff --git a/moto/iotdata/__init__.py b/moto/iotdata/__init__.py index 214f2e575..016fef5fb 100644 --- a/moto/iotdata/__init__.py +++ b/moto/iotdata/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import iotdata_backends from ..core.models import base_decorator -iotdata_backend = iotdata_backends['us-east-1'] +iotdata_backend = iotdata_backends["us-east-1"] mock_iotdata = base_decorator(iotdata_backends) diff --git a/moto/iotdata/exceptions.py b/moto/iotdata/exceptions.py index f2c209eed..30998ffc3 100644 --- a/moto/iotdata/exceptions.py +++ b/moto/iotdata/exceptions.py @@ -10,8 +10,7 @@ class ResourceNotFoundException(IoTDataPlaneClientError): def __init__(self): self.code = 404 super(ResourceNotFoundException, self).__init__( - "ResourceNotFoundException", - "The specified resource does not exist" + "ResourceNotFoundException", "The specified resource does not exist" ) @@ -26,6 +25,4 @@ class InvalidRequestException(IoTDataPlaneClientError): class ConflictException(IoTDataPlaneClientError): def __init__(self, message): self.code = 409 - super(ConflictException, self).__init__( - "ConflictException", message - ) + super(ConflictException, self).__init__("ConflictException", message) diff --git a/moto/iotdata/models.py b/moto/iotdata/models.py index fec066f07..e534e1d1f 100644 --- a/moto/iotdata/models.py +++ b/moto/iotdata/models.py @@ -8,7 +8,7 @@ from moto.iot import iot_backends from .exceptions import ( ConflictException, ResourceNotFoundException, - InvalidRequestException + InvalidRequestException, ) @@ -16,6 +16,7 @@ class FakeShadow(BaseModel): """See the specification: http://docs.aws.amazon.com/iot/latest/developerguide/thing-shadow-document-syntax.html """ + def __init__(self, desired, reported, requested_payload, version, deleted=False): self.desired = desired self.reported = reported @@ -24,15 +25,23 @@ class FakeShadow(BaseModel): self.timestamp = int(time.time()) self.deleted = deleted - self.metadata_desired = self._create_metadata_from_state(self.desired, self.timestamp) - self.metadata_reported = self._create_metadata_from_state(self.reported, self.timestamp) + self.metadata_desired = self._create_metadata_from_state( + self.desired, self.timestamp + ) + self.metadata_reported = self._create_metadata_from_state( + self.reported, self.timestamp + ) @classmethod def create_from_previous_version(cls, previous_shadow, payload): """ set None to payload when you want to delete shadow """ - version, previous_payload = (previous_shadow.version + 1, previous_shadow.to_dict(include_delta=False)) if previous_shadow else (1, {}) + version, previous_payload = ( + (previous_shadow.version + 1, previous_shadow.to_dict(include_delta=False)) + if previous_shadow + else (1, {}) + ) if payload is None: # if given payload is None, delete existing payload @@ -41,13 +50,11 @@ class FakeShadow(BaseModel): return shadow # we can make sure that payload has 'state' key - desired = payload['state'].get( - 'desired', - previous_payload.get('state', {}).get('desired', None) + desired = payload["state"].get( + "desired", previous_payload.get("state", {}).get("desired", None) ) - reported = payload['state'].get( - 'reported', - previous_payload.get('state', {}).get('reported', None) + reported = payload["state"].get( + "reported", previous_payload.get("state", {}).get("reported", None) ) shadow = FakeShadow(desired, reported, payload, version) return shadow @@ -76,58 +83,60 @@ class FakeShadow(BaseModel): if isinstance(elem, list): return [_f(_, ts) for _ in elem] return {"timestamp": ts} + return _f(state, ts) def to_response_dict(self): - desired = self.requested_payload['state'].get('desired', None) - reported = self.requested_payload['state'].get('reported', None) + desired = self.requested_payload["state"].get("desired", None) + reported = self.requested_payload["state"].get("reported", None) payload = {} if desired is not None: - payload['desired'] = desired + payload["desired"] = desired if reported is not None: - payload['reported'] = reported + payload["reported"] = reported metadata = {} if desired is not None: - metadata['desired'] = self._create_metadata_from_state(desired, self.timestamp) + metadata["desired"] = self._create_metadata_from_state( + desired, self.timestamp + ) if reported is not None: - metadata['reported'] = self._create_metadata_from_state(reported, self.timestamp) + metadata["reported"] = self._create_metadata_from_state( + reported, self.timestamp + ) return { - 'state': payload, - 'metadata': metadata, - 'timestamp': self.timestamp, - 'version': self.version + "state": payload, + "metadata": metadata, + "timestamp": self.timestamp, + "version": self.version, } def to_dict(self, include_delta=True): """returning nothing except for just top-level keys for now. """ if self.deleted: - return { - 'timestamp': self.timestamp, - 'version': self.version - } + return {"timestamp": self.timestamp, "version": self.version} delta = self.parse_payload(self.desired, self.reported) payload = {} if self.desired is not None: - payload['desired'] = self.desired + payload["desired"] = self.desired if self.reported is not None: - payload['reported'] = self.reported + payload["reported"] = self.reported if include_delta and (delta is not None and len(delta.keys()) != 0): - payload['delta'] = delta + payload["delta"] = delta metadata = {} if self.metadata_desired is not None: - metadata['desired'] = self.metadata_desired + metadata["desired"] = self.metadata_desired if self.metadata_reported is not None: - metadata['reported'] = self.metadata_reported + metadata["reported"] = self.metadata_reported return { - 'state': payload, - 'metadata': metadata, - 'timestamp': self.timestamp, - 'version': self.version + "state": payload, + "metadata": metadata, + "timestamp": self.timestamp, + "version": self.version, } @@ -154,17 +163,19 @@ class IoTDataPlaneBackend(BaseBackend): try: payload = json.loads(payload) except ValueError: - raise InvalidRequestException('invalid json') - if 'state' not in payload: - raise InvalidRequestException('need node `state`') - if not isinstance(payload['state'], dict): - raise InvalidRequestException('state node must be an Object') - if any(_ for _ in payload['state'].keys() if _ not in ['desired', 'reported']): - raise InvalidRequestException('State contains an invalid node') + raise InvalidRequestException("invalid json") + if "state" not in payload: + raise InvalidRequestException("need node `state`") + if not isinstance(payload["state"], dict): + raise InvalidRequestException("state node must be an Object") + if any(_ for _ in payload["state"].keys() if _ not in ["desired", "reported"]): + raise InvalidRequestException("State contains an invalid node") - if 'version' in payload and thing.thing_shadow.version != payload['version']: - raise ConflictException('Version conflict') - new_shadow = FakeShadow.create_from_previous_version(thing.thing_shadow, payload) + if "version" in payload and thing.thing_shadow.version != payload["version"]: + raise ConflictException("Version conflict") + new_shadow = FakeShadow.create_from_previous_version( + thing.thing_shadow, payload + ) thing.thing_shadow = new_shadow return thing.thing_shadow @@ -183,7 +194,9 @@ class IoTDataPlaneBackend(BaseBackend): if thing.thing_shadow is None: raise ResourceNotFoundException() payload = None - new_shadow = FakeShadow.create_from_previous_version(thing.thing_shadow, payload) + new_shadow = FakeShadow.create_from_previous_version( + thing.thing_shadow, payload + ) thing.thing_shadow = new_shadow return thing.thing_shadow diff --git a/moto/iotdata/responses.py b/moto/iotdata/responses.py index 8ab724ed1..045ed5e59 100644 --- a/moto/iotdata/responses.py +++ b/moto/iotdata/responses.py @@ -5,7 +5,7 @@ import json class IoTDataPlaneResponse(BaseResponse): - SERVICE_NAME = 'iot-data' + SERVICE_NAME = "iot-data" @property def iotdata_backend(self): @@ -15,32 +15,23 @@ class IoTDataPlaneResponse(BaseResponse): thing_name = self._get_param("thingName") payload = self.body payload = self.iotdata_backend.update_thing_shadow( - thing_name=thing_name, - payload=payload, + thing_name=thing_name, payload=payload ) return json.dumps(payload.to_response_dict()) def get_thing_shadow(self): thing_name = self._get_param("thingName") - payload = self.iotdata_backend.get_thing_shadow( - thing_name=thing_name, - ) + payload = self.iotdata_backend.get_thing_shadow(thing_name=thing_name) return json.dumps(payload.to_dict()) def delete_thing_shadow(self): thing_name = self._get_param("thingName") - payload = self.iotdata_backend.delete_thing_shadow( - thing_name=thing_name, - ) + payload = self.iotdata_backend.delete_thing_shadow(thing_name=thing_name) return json.dumps(payload.to_dict()) def publish(self): topic = self._get_param("topic") qos = self._get_int_param("qos") payload = self._get_param("payload") - self.iotdata_backend.publish( - topic=topic, - qos=qos, - payload=payload, - ) + self.iotdata_backend.publish(topic=topic, qos=qos, payload=payload) return json.dumps(dict()) diff --git a/moto/iotdata/urls.py b/moto/iotdata/urls.py index a3bcb0a52..b3baa66cc 100644 --- a/moto/iotdata/urls.py +++ b/moto/iotdata/urls.py @@ -1,14 +1,10 @@ from __future__ import unicode_literals from .responses import IoTDataPlaneResponse -url_bases = [ - "https?://data.iot.(.+).amazonaws.com", -] +url_bases = ["https?://data.iot.(.+).amazonaws.com"] response = IoTDataPlaneResponse() -url_paths = { - '{0}/.*$': response.dispatch, -} +url_paths = {"{0}/.*$": response.dispatch} diff --git a/moto/kinesis/__init__.py b/moto/kinesis/__init__.py index 7d9767a9f..823379cd5 100644 --- a/moto/kinesis/__init__.py +++ b/moto/kinesis/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import kinesis_backends from ..core.models import base_decorator, deprecated_base_decorator -kinesis_backend = kinesis_backends['us-east-1'] +kinesis_backend = kinesis_backends["us-east-1"] mock_kinesis = base_decorator(kinesis_backends) mock_kinesis_deprecated = deprecated_base_decorator(kinesis_backends) diff --git a/moto/kinesis/exceptions.py b/moto/kinesis/exceptions.py index 82f796ecc..8c950c355 100644 --- a/moto/kinesis/exceptions.py +++ b/moto/kinesis/exceptions.py @@ -5,44 +5,38 @@ from werkzeug.exceptions import BadRequest class ResourceNotFoundError(BadRequest): - def __init__(self, message): super(ResourceNotFoundError, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'ResourceNotFoundException', - }) + self.description = json.dumps( + {"message": message, "__type": "ResourceNotFoundException"} + ) class ResourceInUseError(BadRequest): - def __init__(self, message): super(ResourceInUseError, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'ResourceInUseException', - }) + self.description = json.dumps( + {"message": message, "__type": "ResourceInUseException"} + ) class StreamNotFoundError(ResourceNotFoundError): - def __init__(self, stream_name): super(StreamNotFoundError, self).__init__( - 'Stream {0} under account 123456789012 not found.'.format(stream_name)) + "Stream {0} under account 123456789012 not found.".format(stream_name) + ) class ShardNotFoundError(ResourceNotFoundError): - def __init__(self, shard_id): super(ShardNotFoundError, self).__init__( - 'Shard {0} under account 123456789012 not found.'.format(shard_id)) + "Shard {0} under account 123456789012 not found.".format(shard_id) + ) class InvalidArgumentError(BadRequest): - def __init__(self, message): super(InvalidArgumentError, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'InvalidArgumentException', - }) + self.description = json.dumps( + {"message": message, "__type": "InvalidArgumentException"} + ) diff --git a/moto/kinesis/models.py b/moto/kinesis/models.py index 965f3367a..38a622841 100644 --- a/moto/kinesis/models.py +++ b/moto/kinesis/models.py @@ -13,9 +13,18 @@ from hashlib import md5 from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel from moto.core.utils import unix_time -from .exceptions import StreamNotFoundError, ShardNotFoundError, ResourceInUseError, \ - ResourceNotFoundError, InvalidArgumentError -from .utils import compose_shard_iterator, compose_new_shard_iterator, decompose_shard_iterator +from .exceptions import ( + StreamNotFoundError, + ShardNotFoundError, + ResourceInUseError, + ResourceNotFoundError, + InvalidArgumentError, +) +from .utils import ( + compose_shard_iterator, + compose_new_shard_iterator, + decompose_shard_iterator, +) class Record(BaseModel): @@ -32,12 +41,11 @@ class Record(BaseModel): "Data": self.data, "PartitionKey": self.partition_key, "SequenceNumber": str(self.sequence_number), - "ApproximateArrivalTimestamp": self.created_at_datetime.isoformat() + "ApproximateArrivalTimestamp": self.created_at_datetime.isoformat(), } class Shard(BaseModel): - def __init__(self, shard_id, starting_hash, ending_hash): self._shard_id = shard_id self.starting_hash = starting_hash @@ -75,7 +83,8 @@ class Shard(BaseModel): last_sequence_number = 0 sequence_number = last_sequence_number + 1 self.records[sequence_number] = Record( - partition_key, data, sequence_number, explicit_hash_key) + partition_key, data, sequence_number, explicit_hash_key + ) return sequence_number def get_min_sequence_number(self): @@ -94,25 +103,31 @@ class Shard(BaseModel): else: # find the last item in the list that was created before # at_timestamp - r = next((r for r in reversed(self.records.values()) if r.created_at < at_timestamp), None) + r = next( + ( + r + for r in reversed(self.records.values()) + if r.created_at < at_timestamp + ), + None, + ) return r.sequence_number def to_json(self): return { "HashKeyRange": { "EndingHashKey": str(self.ending_hash), - "StartingHashKey": str(self.starting_hash) + "StartingHashKey": str(self.starting_hash), }, "SequenceNumberRange": { "EndingSequenceNumber": self.get_max_sequence_number(), "StartingSequenceNumber": self.get_min_sequence_number(), }, - "ShardId": self.shard_id + "ShardId": self.shard_id, } class Stream(BaseModel): - def __init__(self, stream_name, shard_count, region): self.stream_name = stream_name self.shard_count = shard_count @@ -123,10 +138,11 @@ class Stream(BaseModel): self.tags = {} self.status = "ACTIVE" - step = 2**128 // shard_count - hash_ranges = itertools.chain(map(lambda i: (i, i * step, (i + 1) * step), - range(shard_count - 1)), - [(shard_count - 1, (shard_count - 1) * step, 2**128)]) + step = 2 ** 128 // shard_count + hash_ranges = itertools.chain( + map(lambda i: (i, i * step, (i + 1) * step), range(shard_count - 1)), + [(shard_count - 1, (shard_count - 1) * step, 2 ** 128)], + ) for index, start, end in hash_ranges: shard = Shard(index, start, end) @@ -137,7 +153,7 @@ class Stream(BaseModel): return "arn:aws:kinesis:{region}:{account_number}:{stream_name}".format( region=self.region, account_number=self.account_number, - stream_name=self.stream_name + stream_name=self.stream_name, ) def get_shard(self, shard_id): @@ -158,21 +174,22 @@ class Stream(BaseModel): key = int(explicit_hash_key) - if key >= 2**128: + if key >= 2 ** 128: raise InvalidArgumentError("explicit_hash_key") else: - key = int(md5(partition_key.encode('utf-8')).hexdigest(), 16) + key = int(md5(partition_key.encode("utf-8")).hexdigest(), 16) for shard in self.shards.values(): if shard.starting_hash <= key < shard.ending_hash: return shard - def put_record(self, partition_key, explicit_hash_key, sequence_number_for_ordering, data): + def put_record( + self, partition_key, explicit_hash_key, sequence_number_for_ordering, data + ): shard = self.get_shard_for_key(partition_key, explicit_hash_key) - sequence_number = shard.put_record( - partition_key, data, explicit_hash_key) + sequence_number = shard.put_record(partition_key, data, explicit_hash_key) return sequence_number, shard.shard_id def to_json(self): @@ -198,64 +215,69 @@ class Stream(BaseModel): } @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - region = properties.get('Region', 'us-east-1') - shard_count = properties.get('ShardCount', 1) - return Stream(properties['Name'], shard_count, region) + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + region = properties.get("Region", "us-east-1") + shard_count = properties.get("ShardCount", 1) + return Stream(properties["Name"], shard_count, region) class FirehoseRecord(BaseModel): - def __init__(self, record_data): self.record_id = 12345678 self.record_data = record_data class DeliveryStream(BaseModel): - def __init__(self, stream_name, **stream_kwargs): self.name = stream_name - self.redshift_username = stream_kwargs.get('redshift_username') - self.redshift_password = stream_kwargs.get('redshift_password') - self.redshift_jdbc_url = stream_kwargs.get('redshift_jdbc_url') - self.redshift_role_arn = stream_kwargs.get('redshift_role_arn') - self.redshift_copy_command = stream_kwargs.get('redshift_copy_command') + self.redshift_username = stream_kwargs.get("redshift_username") + self.redshift_password = stream_kwargs.get("redshift_password") + self.redshift_jdbc_url = stream_kwargs.get("redshift_jdbc_url") + self.redshift_role_arn = stream_kwargs.get("redshift_role_arn") + self.redshift_copy_command = stream_kwargs.get("redshift_copy_command") - self.s3_config = stream_kwargs.get('s3_config') - self.extended_s3_config = stream_kwargs.get('extended_s3_config') + self.s3_config = stream_kwargs.get("s3_config") + self.extended_s3_config = stream_kwargs.get("extended_s3_config") - self.redshift_s3_role_arn = stream_kwargs.get('redshift_s3_role_arn') - self.redshift_s3_bucket_arn = stream_kwargs.get( - 'redshift_s3_bucket_arn') - self.redshift_s3_prefix = stream_kwargs.get('redshift_s3_prefix') + self.redshift_s3_role_arn = stream_kwargs.get("redshift_s3_role_arn") + self.redshift_s3_bucket_arn = stream_kwargs.get("redshift_s3_bucket_arn") + self.redshift_s3_prefix = stream_kwargs.get("redshift_s3_prefix") self.redshift_s3_compression_format = stream_kwargs.get( - 'redshift_s3_compression_format', 'UNCOMPRESSED') + "redshift_s3_compression_format", "UNCOMPRESSED" + ) self.redshift_s3_buffering_hints = stream_kwargs.get( - 'redshift_s3_buffering_hints') + "redshift_s3_buffering_hints" + ) self.records = [] - self.status = 'ACTIVE' + self.status = "ACTIVE" self.created_at = datetime.datetime.utcnow() self.last_updated = datetime.datetime.utcnow() @property def arn(self): - return 'arn:aws:firehose:us-east-1:123456789012:deliverystream/{0}'.format(self.name) + return "arn:aws:firehose:us-east-1:123456789012:deliverystream/{0}".format( + self.name + ) def destinations_to_dict(self): if self.s3_config: - return [{ - 'DestinationId': 'string', - 'S3DestinationDescription': self.s3_config, - }] + return [ + {"DestinationId": "string", "S3DestinationDescription": self.s3_config} + ] elif self.extended_s3_config: - return [{ - 'DestinationId': 'string', - 'ExtendedS3DestinationDescription': self.extended_s3_config, - }] + return [ + { + "DestinationId": "string", + "ExtendedS3DestinationDescription": self.extended_s3_config, + } + ] else: - return [{ + return [ + { "DestinationId": "string", "RedshiftDestinationDescription": { "ClusterJDBCURL": self.redshift_jdbc_url, @@ -266,12 +288,12 @@ class DeliveryStream(BaseModel): "BufferingHints": self.redshift_s3_buffering_hints, "CompressionFormat": self.redshift_s3_compression_format, "Prefix": self.redshift_s3_prefix, - "RoleARN": self.redshift_s3_role_arn + "RoleARN": self.redshift_s3_role_arn, }, "Username": self.redshift_username, }, - } - ] + } + ] def to_dict(self): return { @@ -294,7 +316,6 @@ class DeliveryStream(BaseModel): class KinesisBackend(BaseBackend): - def __init__(self): self.streams = OrderedDict() self.delivery_streams = {} @@ -323,14 +344,24 @@ class KinesisBackend(BaseBackend): return self.streams.pop(stream_name) raise StreamNotFoundError(stream_name) - def get_shard_iterator(self, stream_name, shard_id, shard_iterator_type, starting_sequence_number, - at_timestamp): + def get_shard_iterator( + self, + stream_name, + shard_id, + shard_iterator_type, + starting_sequence_number, + at_timestamp, + ): # Validate params stream = self.describe_stream(stream_name) shard = stream.get_shard(shard_id) shard_iterator = compose_new_shard_iterator( - stream_name, shard, shard_iterator_type, starting_sequence_number, at_timestamp + stream_name, + shard, + shard_iterator_type, + starting_sequence_number, + at_timestamp, ) return shard_iterator @@ -341,14 +372,24 @@ class KinesisBackend(BaseBackend): stream = self.describe_stream(stream_name) shard = stream.get_shard(shard_id) - records, last_sequence_id, millis_behind_latest = shard.get_records(last_sequence_id, limit) + records, last_sequence_id, millis_behind_latest = shard.get_records( + last_sequence_id, limit + ) next_shard_iterator = compose_shard_iterator( - stream_name, shard, last_sequence_id) + stream_name, shard, last_sequence_id + ) return next_shard_iterator, records, millis_behind_latest - def put_record(self, stream_name, partition_key, explicit_hash_key, sequence_number_for_ordering, data): + def put_record( + self, + stream_name, + partition_key, + explicit_hash_key, + sequence_number_for_ordering, + data, + ): stream = self.describe_stream(stream_name) sequence_number, shard_id = stream.put_record( @@ -360,10 +401,7 @@ class KinesisBackend(BaseBackend): def put_records(self, stream_name, records): stream = self.describe_stream(stream_name) - response = { - "FailedRecordCount": 0, - "Records": [] - } + response = {"FailedRecordCount": 0, "Records": []} for record in records: partition_key = record.get("PartitionKey") @@ -373,10 +411,9 @@ class KinesisBackend(BaseBackend): sequence_number, shard_id = stream.put_record( partition_key, explicit_hash_key, None, data ) - response['Records'].append({ - "SequenceNumber": sequence_number, - "ShardId": shard_id - }) + response["Records"].append( + {"SequenceNumber": sequence_number, "ShardId": shard_id} + ) return response @@ -386,18 +423,18 @@ class KinesisBackend(BaseBackend): if shard_to_split not in stream.shards: raise ResourceNotFoundError(shard_to_split) - if not re.match(r'0|([1-9]\d{0,38})', new_starting_hash_key): + if not re.match(r"0|([1-9]\d{0,38})", new_starting_hash_key): raise InvalidArgumentError(new_starting_hash_key) new_starting_hash_key = int(new_starting_hash_key) shard = stream.shards[shard_to_split] - last_id = sorted(stream.shards.values(), - key=attrgetter('_shard_id'))[-1]._shard_id + last_id = sorted(stream.shards.values(), key=attrgetter("_shard_id"))[ + -1 + ]._shard_id if shard.starting_hash < new_starting_hash_key < shard.ending_hash: - new_shard = Shard( - last_id + 1, new_starting_hash_key, shard.ending_hash) + new_shard = Shard(last_id + 1, new_starting_hash_key, shard.ending_hash) shard.ending_hash = new_starting_hash_key stream.shards[new_shard.shard_id] = new_shard else: @@ -434,10 +471,11 @@ class KinesisBackend(BaseBackend): del stream.shards[shard2.shard_id] for index in shard2.records: record = shard2.records[index] - shard1.put_record(record.partition_key, - record.data, record.explicit_hash_key) + shard1.put_record( + record.partition_key, record.data, record.explicit_hash_key + ) - ''' Firehose ''' + """ Firehose """ def create_delivery_stream(self, stream_name, **stream_kwargs): stream = DeliveryStream(stream_name, **stream_kwargs) @@ -461,25 +499,21 @@ class KinesisBackend(BaseBackend): record = stream.put_record(record_data) return record - def list_tags_for_stream(self, stream_name, exclusive_start_tag_key=None, limit=None): + def list_tags_for_stream( + self, stream_name, exclusive_start_tag_key=None, limit=None + ): stream = self.describe_stream(stream_name) tags = [] - result = { - 'HasMoreTags': False, - 'Tags': tags - } + result = {"HasMoreTags": False, "Tags": tags} for key, val in sorted(stream.tags.items(), key=lambda x: x[0]): if limit and len(tags) >= limit: - result['HasMoreTags'] = True + result["HasMoreTags"] = True break if exclusive_start_tag_key and key < exclusive_start_tag_key: continue - tags.append({ - 'Key': key, - 'Value': val - }) + tags.append({"Key": key, "Value": val}) return result diff --git a/moto/kinesis/responses.py b/moto/kinesis/responses.py index aa2b8c225..500f7855d 100644 --- a/moto/kinesis/responses.py +++ b/moto/kinesis/responses.py @@ -7,7 +7,6 @@ from .models import kinesis_backends class KinesisResponse(BaseResponse): - @property def parameters(self): return json.loads(self.body) @@ -18,47 +17,47 @@ class KinesisResponse(BaseResponse): @property def is_firehose(self): - host = self.headers.get('host') or self.headers['Host'] - return host.startswith('firehose') or 'firehose' in self.headers.get('Authorization', '') + host = self.headers.get("host") or self.headers["Host"] + return host.startswith("firehose") or "firehose" in self.headers.get( + "Authorization", "" + ) def create_stream(self): - stream_name = self.parameters.get('StreamName') - shard_count = self.parameters.get('ShardCount') - self.kinesis_backend.create_stream( - stream_name, shard_count, self.region) + stream_name = self.parameters.get("StreamName") + shard_count = self.parameters.get("ShardCount") + self.kinesis_backend.create_stream(stream_name, shard_count, self.region) return "" def describe_stream(self): - stream_name = self.parameters.get('StreamName') + stream_name = self.parameters.get("StreamName") stream = self.kinesis_backend.describe_stream(stream_name) return json.dumps(stream.to_json()) def describe_stream_summary(self): - stream_name = self.parameters.get('StreamName') + stream_name = self.parameters.get("StreamName") stream = self.kinesis_backend.describe_stream_summary(stream_name) return json.dumps(stream.to_json_summary()) def list_streams(self): streams = self.kinesis_backend.list_streams() stream_names = [stream.stream_name for stream in streams] - max_streams = self._get_param('Limit', 10) + max_streams = self._get_param("Limit", 10) try: - token = self.parameters.get('ExclusiveStartStreamName') + token = self.parameters.get("ExclusiveStartStreamName") except ValueError: - token = self._get_param('ExclusiveStartStreamName') + token = self._get_param("ExclusiveStartStreamName") if token: start = stream_names.index(token) + 1 else: start = 0 - streams_resp = stream_names[start:start + max_streams] + streams_resp = stream_names[start : start + max_streams] has_more_streams = False if start + max_streams < len(stream_names): has_more_streams = True - return json.dumps({ - "HasMoreStreams": has_more_streams, - "StreamNames": streams_resp - }) + return json.dumps( + {"HasMoreStreams": has_more_streams, "StreamNames": streams_resp} + ) def delete_stream(self): stream_name = self.parameters.get("StreamName") @@ -69,30 +68,36 @@ class KinesisResponse(BaseResponse): stream_name = self.parameters.get("StreamName") shard_id = self.parameters.get("ShardId") shard_iterator_type = self.parameters.get("ShardIteratorType") - starting_sequence_number = self.parameters.get( - "StartingSequenceNumber") + starting_sequence_number = self.parameters.get("StartingSequenceNumber") at_timestamp = self.parameters.get("Timestamp") shard_iterator = self.kinesis_backend.get_shard_iterator( - stream_name, shard_id, shard_iterator_type, starting_sequence_number, at_timestamp + stream_name, + shard_id, + shard_iterator_type, + starting_sequence_number, + at_timestamp, ) - return json.dumps({ - "ShardIterator": shard_iterator - }) + return json.dumps({"ShardIterator": shard_iterator}) def get_records(self): shard_iterator = self.parameters.get("ShardIterator") limit = self.parameters.get("Limit") - next_shard_iterator, records, millis_behind_latest = self.kinesis_backend.get_records( - shard_iterator, limit) + ( + next_shard_iterator, + records, + millis_behind_latest, + ) = self.kinesis_backend.get_records(shard_iterator, limit) - return json.dumps({ - "NextShardIterator": next_shard_iterator, - "Records": [record.to_json() for record in records], - 'MillisBehindLatest': millis_behind_latest - }) + return json.dumps( + { + "NextShardIterator": next_shard_iterator, + "Records": [record.to_json() for record in records], + "MillisBehindLatest": millis_behind_latest, + } + ) def put_record(self): if self.is_firehose: @@ -100,18 +105,18 @@ class KinesisResponse(BaseResponse): stream_name = self.parameters.get("StreamName") partition_key = self.parameters.get("PartitionKey") explicit_hash_key = self.parameters.get("ExplicitHashKey") - sequence_number_for_ordering = self.parameters.get( - "SequenceNumberForOrdering") + sequence_number_for_ordering = self.parameters.get("SequenceNumberForOrdering") data = self.parameters.get("Data") sequence_number, shard_id = self.kinesis_backend.put_record( - stream_name, partition_key, explicit_hash_key, sequence_number_for_ordering, data + stream_name, + partition_key, + explicit_hash_key, + sequence_number_for_ordering, + data, ) - return json.dumps({ - "SequenceNumber": sequence_number, - "ShardId": shard_id, - }) + return json.dumps({"SequenceNumber": sequence_number, "ShardId": shard_id}) def put_records(self): if self.is_firehose: @@ -119,9 +124,7 @@ class KinesisResponse(BaseResponse): stream_name = self.parameters.get("StreamName") records = self.parameters.get("Records") - response = self.kinesis_backend.put_records( - stream_name, records - ) + response = self.kinesis_backend.put_records(stream_name, records) return json.dumps(response) @@ -143,42 +146,39 @@ class KinesisResponse(BaseResponse): ) return "" - ''' Firehose ''' + """ Firehose """ def create_delivery_stream(self): - stream_name = self.parameters['DeliveryStreamName'] - redshift_config = self.parameters.get( - 'RedshiftDestinationConfiguration') - s3_config = self.parameters.get( - 'S3DestinationConfiguration') - extended_s3_config = self.parameters.get( - 'ExtendedS3DestinationConfiguration') + stream_name = self.parameters["DeliveryStreamName"] + redshift_config = self.parameters.get("RedshiftDestinationConfiguration") + s3_config = self.parameters.get("S3DestinationConfiguration") + extended_s3_config = self.parameters.get("ExtendedS3DestinationConfiguration") if redshift_config: - redshift_s3_config = redshift_config['S3Configuration'] + redshift_s3_config = redshift_config["S3Configuration"] stream_kwargs = { - 'redshift_username': redshift_config['Username'], - 'redshift_password': redshift_config['Password'], - 'redshift_jdbc_url': redshift_config['ClusterJDBCURL'], - 'redshift_role_arn': redshift_config['RoleARN'], - 'redshift_copy_command': redshift_config['CopyCommand'], - - 'redshift_s3_role_arn': redshift_s3_config['RoleARN'], - 'redshift_s3_bucket_arn': redshift_s3_config['BucketARN'], - 'redshift_s3_prefix': redshift_s3_config['Prefix'], - 'redshift_s3_compression_format': redshift_s3_config.get('CompressionFormat'), - 'redshift_s3_buffering_hints': redshift_s3_config['BufferingHints'], + "redshift_username": redshift_config["Username"], + "redshift_password": redshift_config["Password"], + "redshift_jdbc_url": redshift_config["ClusterJDBCURL"], + "redshift_role_arn": redshift_config["RoleARN"], + "redshift_copy_command": redshift_config["CopyCommand"], + "redshift_s3_role_arn": redshift_s3_config["RoleARN"], + "redshift_s3_bucket_arn": redshift_s3_config["BucketARN"], + "redshift_s3_prefix": redshift_s3_config["Prefix"], + "redshift_s3_compression_format": redshift_s3_config.get( + "CompressionFormat" + ), + "redshift_s3_buffering_hints": redshift_s3_config["BufferingHints"], } elif s3_config: - stream_kwargs = {'s3_config': s3_config} + stream_kwargs = {"s3_config": s3_config} elif extended_s3_config: - stream_kwargs = {'extended_s3_config': extended_s3_config} + stream_kwargs = {"extended_s3_config": extended_s3_config} stream = self.kinesis_backend.create_delivery_stream( - stream_name, **stream_kwargs) - return json.dumps({ - 'DeliveryStreamARN': stream.arn - }) + stream_name, **stream_kwargs + ) + return json.dumps({"DeliveryStreamARN": stream.arn}) def describe_delivery_stream(self): stream_name = self.parameters["DeliveryStreamName"] @@ -187,60 +187,54 @@ class KinesisResponse(BaseResponse): def list_delivery_streams(self): streams = self.kinesis_backend.list_delivery_streams() - return json.dumps({ - "DeliveryStreamNames": [ - stream.name for stream in streams - ], - "HasMoreDeliveryStreams": False - }) + return json.dumps( + { + "DeliveryStreamNames": [stream.name for stream in streams], + "HasMoreDeliveryStreams": False, + } + ) def delete_delivery_stream(self): - stream_name = self.parameters['DeliveryStreamName'] + stream_name = self.parameters["DeliveryStreamName"] self.kinesis_backend.delete_delivery_stream(stream_name) return json.dumps({}) def firehose_put_record(self): - stream_name = self.parameters['DeliveryStreamName'] - record_data = self.parameters['Record']['Data'] + stream_name = self.parameters["DeliveryStreamName"] + record_data = self.parameters["Record"]["Data"] - record = self.kinesis_backend.put_firehose_record( - stream_name, record_data) - return json.dumps({ - "RecordId": record.record_id, - }) + record = self.kinesis_backend.put_firehose_record(stream_name, record_data) + return json.dumps({"RecordId": record.record_id}) def put_record_batch(self): - stream_name = self.parameters['DeliveryStreamName'] - records = self.parameters['Records'] + stream_name = self.parameters["DeliveryStreamName"] + records = self.parameters["Records"] request_responses = [] for record in records: record_response = self.kinesis_backend.put_firehose_record( - stream_name, record['Data']) - request_responses.append({ - "RecordId": record_response.record_id - }) - return json.dumps({ - "FailedPutCount": 0, - "RequestResponses": request_responses, - }) + stream_name, record["Data"] + ) + request_responses.append({"RecordId": record_response.record_id}) + return json.dumps({"FailedPutCount": 0, "RequestResponses": request_responses}) def add_tags_to_stream(self): - stream_name = self.parameters.get('StreamName') - tags = self.parameters.get('Tags') + stream_name = self.parameters.get("StreamName") + tags = self.parameters.get("Tags") self.kinesis_backend.add_tags_to_stream(stream_name, tags) return json.dumps({}) def list_tags_for_stream(self): - stream_name = self.parameters.get('StreamName') - exclusive_start_tag_key = self.parameters.get('ExclusiveStartTagKey') - limit = self.parameters.get('Limit') + stream_name = self.parameters.get("StreamName") + exclusive_start_tag_key = self.parameters.get("ExclusiveStartTagKey") + limit = self.parameters.get("Limit") response = self.kinesis_backend.list_tags_for_stream( - stream_name, exclusive_start_tag_key, limit) + stream_name, exclusive_start_tag_key, limit + ) return json.dumps(response) def remove_tags_from_stream(self): - stream_name = self.parameters.get('StreamName') - tag_keys = self.parameters.get('TagKeys') + stream_name = self.parameters.get("StreamName") + tag_keys = self.parameters.get("TagKeys") self.kinesis_backend.remove_tags_from_stream(stream_name, tag_keys) return json.dumps({}) diff --git a/moto/kinesis/urls.py b/moto/kinesis/urls.py index a8d15eecd..c95f03190 100644 --- a/moto/kinesis/urls.py +++ b/moto/kinesis/urls.py @@ -6,6 +6,4 @@ url_bases = [ "https?://firehose.(.+).amazonaws.com", ] -url_paths = { - '{0}/$': KinesisResponse.dispatch, -} +url_paths = {"{0}/$": KinesisResponse.dispatch} diff --git a/moto/kinesis/utils.py b/moto/kinesis/utils.py index 0c3edbb5a..b455cb7ba 100644 --- a/moto/kinesis/utils.py +++ b/moto/kinesis/utils.py @@ -14,8 +14,9 @@ else: raise Exception("Python version is not supported") -def compose_new_shard_iterator(stream_name, shard, shard_iterator_type, starting_sequence_number, - at_timestamp): +def compose_new_shard_iterator( + stream_name, shard, shard_iterator_type, starting_sequence_number, at_timestamp +): if shard_iterator_type == "AT_SEQUENCE_NUMBER": last_sequence_id = int(starting_sequence_number) - 1 elif shard_iterator_type == "AFTER_SEQUENCE_NUMBER": @@ -28,17 +29,16 @@ def compose_new_shard_iterator(stream_name, shard, shard_iterator_type, starting last_sequence_id = shard.get_sequence_number_at(at_timestamp) else: raise InvalidArgumentError( - "Invalid ShardIteratorType: {0}".format(shard_iterator_type)) + "Invalid ShardIteratorType: {0}".format(shard_iterator_type) + ) return compose_shard_iterator(stream_name, shard, last_sequence_id) def compose_shard_iterator(stream_name, shard, last_sequence_id): return encode_method( - "{0}:{1}:{2}".format( - stream_name, - shard.shard_id, - last_sequence_id, - ).encode("utf-8") + "{0}:{1}:{2}".format(stream_name, shard.shard_id, last_sequence_id).encode( + "utf-8" + ) ).decode("utf-8") diff --git a/moto/kms/__init__.py b/moto/kms/__init__.py index b4bb0b639..ecedb8bfd 100644 --- a/moto/kms/__init__.py +++ b/moto/kms/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import kms_backends from ..core.models import base_decorator, deprecated_base_decorator -kms_backend = kms_backends['us-east-1'] +kms_backend = kms_backends["us-east-1"] mock_kms = base_decorator(kms_backends) mock_kms_deprecated = deprecated_base_decorator(kms_backends) diff --git a/moto/kms/exceptions.py b/moto/kms/exceptions.py index c9094e8f8..4ddfd279f 100644 --- a/moto/kms/exceptions.py +++ b/moto/kms/exceptions.py @@ -6,32 +6,28 @@ class NotFoundException(JsonRESTError): code = 400 def __init__(self, message): - super(NotFoundException, self).__init__( - "NotFoundException", message) + super(NotFoundException, self).__init__("NotFoundException", message) class ValidationException(JsonRESTError): code = 400 def __init__(self, message): - super(ValidationException, self).__init__( - "ValidationException", message) + super(ValidationException, self).__init__("ValidationException", message) class AlreadyExistsException(JsonRESTError): code = 400 def __init__(self, message): - super(AlreadyExistsException, self).__init__( - "AlreadyExistsException", message) + super(AlreadyExistsException, self).__init__("AlreadyExistsException", message) class NotAuthorizedException(JsonRESTError): code = 400 def __init__(self): - super(NotAuthorizedException, self).__init__( - "NotAuthorizedException", None) + super(NotAuthorizedException, self).__init__("NotAuthorizedException", None) self.description = '{"__type":"NotAuthorizedException"}' @@ -40,8 +36,7 @@ class AccessDeniedException(JsonRESTError): code = 400 def __init__(self, message): - super(AccessDeniedException, self).__init__( - "AccessDeniedException", message) + super(AccessDeniedException, self).__init__("AccessDeniedException", message) self.description = '{"__type":"AccessDeniedException"}' @@ -51,6 +46,7 @@ class InvalidCiphertextException(JsonRESTError): def __init__(self): super(InvalidCiphertextException, self).__init__( - "InvalidCiphertextException", None) + "InvalidCiphertextException", None + ) self.description = '{"__type":"InvalidCiphertextException"}' diff --git a/moto/kms/models.py b/moto/kms/models.py index 9e1b08bf9..9d7739779 100644 --- a/moto/kms/models.py +++ b/moto/kms/models.py @@ -33,7 +33,9 @@ class Key(BaseModel): @property def arn(self): - return "arn:aws:kms:{0}:{1}:key/{2}".format(self.region, self.account_id, self.id) + return "arn:aws:kms:{0}:{1}:key/{2}".format( + self.region, self.account_id, self.id + ) def to_dict(self): key_dict = { @@ -49,14 +51,18 @@ class Key(BaseModel): } } if self.key_state == "PendingDeletion": - key_dict["KeyMetadata"]["DeletionDate"] = iso_8601_datetime_without_milliseconds(self.deletion_date) + key_dict["KeyMetadata"][ + "DeletionDate" + ] = iso_8601_datetime_without_milliseconds(self.deletion_date) return key_dict def delete(self, region_name): kms_backends[region_name].delete_key(self.id) @classmethod - def create_from_cloudformation_json(self, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + self, resource_name, cloudformation_json, region_name + ): kms_backend = kms_backends[region_name] properties = cloudformation_json["Properties"] @@ -206,39 +212,57 @@ class KmsBackend(BaseBackend): if 7 <= pending_window_in_days <= 30: self.keys[key_id].enabled = False self.keys[key_id].key_state = "PendingDeletion" - self.keys[key_id].deletion_date = datetime.now() + timedelta(days=pending_window_in_days) - return iso_8601_datetime_without_milliseconds(self.keys[key_id].deletion_date) + self.keys[key_id].deletion_date = datetime.now() + timedelta( + days=pending_window_in_days + ) + return iso_8601_datetime_without_milliseconds( + self.keys[key_id].deletion_date + ) def encrypt(self, key_id, plaintext, encryption_context): key_id = self.any_id_to_key_id(key_id) ciphertext_blob = encrypt( - master_keys=self.keys, key_id=key_id, plaintext=plaintext, encryption_context=encryption_context + master_keys=self.keys, + key_id=key_id, + plaintext=plaintext, + encryption_context=encryption_context, ) arn = self.keys[key_id].arn return ciphertext_blob, arn def decrypt(self, ciphertext_blob, encryption_context): plaintext, key_id = decrypt( - master_keys=self.keys, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context + master_keys=self.keys, + ciphertext_blob=ciphertext_blob, + encryption_context=encryption_context, ) arn = self.keys[key_id].arn return plaintext, arn def re_encrypt( - self, ciphertext_blob, source_encryption_context, destination_key_id, destination_encryption_context + self, + ciphertext_blob, + source_encryption_context, + destination_key_id, + destination_encryption_context, ): destination_key_id = self.any_id_to_key_id(destination_key_id) plaintext, decrypting_arn = self.decrypt( - ciphertext_blob=ciphertext_blob, encryption_context=source_encryption_context + ciphertext_blob=ciphertext_blob, + encryption_context=source_encryption_context, ) new_ciphertext_blob, encrypting_arn = self.encrypt( - key_id=destination_key_id, plaintext=plaintext, encryption_context=destination_encryption_context + key_id=destination_key_id, + plaintext=plaintext, + encryption_context=destination_encryption_context, ) return new_ciphertext_blob, decrypting_arn, encrypting_arn - def generate_data_key(self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens): + def generate_data_key( + self, key_id, encryption_context, number_of_bytes, key_spec, grant_tokens + ): key_id = self.any_id_to_key_id(key_id) if key_spec: @@ -252,7 +276,9 @@ class KmsBackend(BaseBackend): plaintext = os.urandom(plaintext_len) - ciphertext_blob, arn = self.encrypt(key_id=key_id, plaintext=plaintext, encryption_context=encryption_context) + ciphertext_blob, arn = self.encrypt( + key_id=key_id, plaintext=plaintext, encryption_context=encryption_context + ) return plaintext, ciphertext_blob, arn diff --git a/moto/kms/responses.py b/moto/kms/responses.py index 998d5cc4b..d3a9726e1 100644 --- a/moto/kms/responses.py +++ b/moto/kms/responses.py @@ -9,19 +9,23 @@ import six from moto.core.responses import BaseResponse from .models import kms_backends -from .exceptions import NotFoundException, ValidationException, AlreadyExistsException, NotAuthorizedException +from .exceptions import ( + NotFoundException, + ValidationException, + AlreadyExistsException, + NotAuthorizedException, +) ACCOUNT_ID = "012345678912" reserved_aliases = [ - 'alias/aws/ebs', - 'alias/aws/s3', - 'alias/aws/redshift', - 'alias/aws/rds', + "alias/aws/ebs", + "alias/aws/s3", + "alias/aws/redshift", + "alias/aws/rds", ] class KmsResponse(BaseResponse): - @property def parameters(self): params = json.loads(self.body) @@ -56,7 +60,11 @@ class KmsResponse(BaseResponse): - key ARN """ is_arn = key_id.startswith("arn:") and ":key/" in key_id - is_raw_key_id = re.match(r"^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$", key_id, re.IGNORECASE) + is_raw_key_id = re.match( + r"^[A-F0-9]{8}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{4}-[A-F0-9]{12}$", + key_id, + re.IGNORECASE, + ) if not is_arn and not is_raw_key_id: raise NotFoundException("Invalid keyId {key_id}".format(key_id=key_id)) @@ -64,7 +72,9 @@ class KmsResponse(BaseResponse): cmk_id = self.kms_backend.get_key_id(key_id) if cmk_id not in self.kms_backend.keys: - raise NotFoundException("Key '{key_id}' does not exist".format(key_id=self._display_arn(key_id))) + raise NotFoundException( + "Key '{key_id}' does not exist".format(key_id=self._display_arn(key_id)) + ) def _validate_alias(self, key_id): """Determine whether an alias exists. @@ -72,7 +82,9 @@ class KmsResponse(BaseResponse): - alias name - alias ARN """ - error = NotFoundException("Alias {key_id} is not found.".format(key_id=self._display_arn(key_id))) + error = NotFoundException( + "Alias {key_id} is not found.".format(key_id=self._display_arn(key_id)) + ) is_arn = key_id.startswith("arn:") and ":alias/" in key_id is_name = key_id.startswith("alias/") @@ -104,19 +116,20 @@ class KmsResponse(BaseResponse): def create_key(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateKey.html""" - policy = self.parameters.get('Policy') - key_usage = self.parameters.get('KeyUsage') - description = self.parameters.get('Description') - tags = self.parameters.get('Tags') + policy = self.parameters.get("Policy") + key_usage = self.parameters.get("KeyUsage") + description = self.parameters.get("Description") + tags = self.parameters.get("Tags") key = self.kms_backend.create_key( - policy, key_usage, description, tags, self.region) + policy, key_usage, description, tags, self.region + ) return json.dumps(key.to_dict()) def update_key_description(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_UpdateKeyDescription.html""" - key_id = self.parameters.get('KeyId') - description = self.parameters.get('Description') + key_id = self.parameters.get("KeyId") + description = self.parameters.get("Description") self._validate_cmk_id(key_id) @@ -125,8 +138,8 @@ class KmsResponse(BaseResponse): def tag_resource(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_TagResource.html""" - key_id = self.parameters.get('KeyId') - tags = self.parameters.get('Tags') + key_id = self.parameters.get("KeyId") + tags = self.parameters.get("Tags") self._validate_cmk_id(key_id) @@ -135,26 +148,20 @@ class KmsResponse(BaseResponse): def list_resource_tags(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListResourceTags.html""" - key_id = self.parameters.get('KeyId') + key_id = self.parameters.get("KeyId") self._validate_cmk_id(key_id) tags = self.kms_backend.list_resource_tags(key_id) - return json.dumps({ - "Tags": tags, - "NextMarker": None, - "Truncated": False, - }) + return json.dumps({"Tags": tags, "NextMarker": None, "Truncated": False}) def describe_key(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_DescribeKey.html""" - key_id = self.parameters.get('KeyId') + key_id = self.parameters.get("KeyId") self._validate_key_id(key_id) - key = self.kms_backend.describe_key( - self.kms_backend.get_key_id(key_id) - ) + key = self.kms_backend.describe_key(self.kms_backend.get_key_id(key_id)) return json.dumps(key.to_dict()) @@ -162,43 +169,47 @@ class KmsResponse(BaseResponse): """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeys.html""" keys = self.kms_backend.list_keys() - return json.dumps({ - "Keys": [ - { - "KeyArn": key.arn, - "KeyId": key.id, - } for key in keys - ], - "NextMarker": None, - "Truncated": False, - }) + return json.dumps( + { + "Keys": [{"KeyArn": key.arn, "KeyId": key.id} for key in keys], + "NextMarker": None, + "Truncated": False, + } + ) def create_alias(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_CreateAlias.html""" - alias_name = self.parameters['AliasName'] - target_key_id = self.parameters['TargetKeyId'] + alias_name = self.parameters["AliasName"] + target_key_id = self.parameters["TargetKeyId"] - if not alias_name.startswith('alias/'): - raise ValidationException('Invalid identifier') + if not alias_name.startswith("alias/"): + raise ValidationException("Invalid identifier") if alias_name in reserved_aliases: raise NotAuthorizedException() - if ':' in alias_name: - raise ValidationException('{alias_name} contains invalid characters for an alias'.format(alias_name=alias_name)) + if ":" in alias_name: + raise ValidationException( + "{alias_name} contains invalid characters for an alias".format( + alias_name=alias_name + ) + ) - if not re.match(r'^[a-zA-Z0-9:/_-]+$', alias_name): - raise ValidationException("1 validation error detected: Value '{alias_name}' at 'aliasName' " - "failed to satisfy constraint: Member must satisfy regular " - "expression pattern: ^[a-zA-Z0-9:/_-]+$" - .format(alias_name=alias_name)) + if not re.match(r"^[a-zA-Z0-9:/_-]+$", alias_name): + raise ValidationException( + "1 validation error detected: Value '{alias_name}' at 'aliasName' " + "failed to satisfy constraint: Member must satisfy regular " + "expression pattern: ^[a-zA-Z0-9:/_-]+$".format(alias_name=alias_name) + ) if self.kms_backend.alias_exists(target_key_id): - raise ValidationException('Aliases must refer to keys. Not aliases') + raise ValidationException("Aliases must refer to keys. Not aliases") if self.kms_backend.alias_exists(alias_name): - raise AlreadyExistsException('An alias with the name arn:aws:kms:{region}:012345678912:{alias_name} ' - 'already exists'.format(region=self.region, alias_name=alias_name)) + raise AlreadyExistsException( + "An alias with the name arn:aws:kms:{region}:012345678912:{alias_name} " + "already exists".format(region=self.region, alias_name=alias_name) + ) self._validate_cmk_id(target_key_id) @@ -208,10 +219,10 @@ class KmsResponse(BaseResponse): def delete_alias(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_DeleteAlias.html""" - alias_name = self.parameters['AliasName'] + alias_name = self.parameters["AliasName"] - if not alias_name.startswith('alias/'): - raise ValidationException('Invalid identifier') + if not alias_name.startswith("alias/"): + raise ValidationException("Invalid identifier") self._validate_alias(alias_name) @@ -227,30 +238,32 @@ class KmsResponse(BaseResponse): response_aliases = [ { - 'AliasArn': u'arn:aws:kms:{region}:012345678912:{reserved_alias}'.format(region=region, - reserved_alias=reserved_alias), - 'AliasName': reserved_alias - } for reserved_alias in reserved_aliases + "AliasArn": "arn:aws:kms:{region}:012345678912:{reserved_alias}".format( + region=region, reserved_alias=reserved_alias + ), + "AliasName": reserved_alias, + } + for reserved_alias in reserved_aliases ] backend_aliases = self.kms_backend.get_all_aliases() for target_key_id, aliases in backend_aliases.items(): for alias_name in aliases: - response_aliases.append({ - 'AliasArn': u'arn:aws:kms:{region}:012345678912:{alias_name}'.format(region=region, - alias_name=alias_name), - 'AliasName': alias_name, - 'TargetKeyId': target_key_id, - }) + response_aliases.append( + { + "AliasArn": "arn:aws:kms:{region}:012345678912:{alias_name}".format( + region=region, alias_name=alias_name + ), + "AliasName": alias_name, + "TargetKeyId": target_key_id, + } + ) - return json.dumps({ - 'Truncated': False, - 'Aliases': response_aliases, - }) + return json.dumps({"Truncated": False, "Aliases": response_aliases}) def enable_key_rotation(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" - key_id = self.parameters.get('KeyId') + key_id = self.parameters.get("KeyId") self._validate_cmk_id(key_id) @@ -260,7 +273,7 @@ class KmsResponse(BaseResponse): def disable_key_rotation(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKeyRotation.html""" - key_id = self.parameters.get('KeyId') + key_id = self.parameters.get("KeyId") self._validate_cmk_id(key_id) @@ -270,19 +283,19 @@ class KmsResponse(BaseResponse): def get_key_rotation_status(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyRotationStatus.html""" - key_id = self.parameters.get('KeyId') + key_id = self.parameters.get("KeyId") self._validate_cmk_id(key_id) rotation_enabled = self.kms_backend.get_key_rotation_status(key_id) - return json.dumps({'KeyRotationEnabled': rotation_enabled}) + return json.dumps({"KeyRotationEnabled": rotation_enabled}) def put_key_policy(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_PutKeyPolicy.html""" - key_id = self.parameters.get('KeyId') - policy_name = self.parameters.get('PolicyName') - policy = self.parameters.get('Policy') + key_id = self.parameters.get("KeyId") + policy_name = self.parameters.get("PolicyName") + policy = self.parameters.get("Policy") _assert_default_policy(policy_name) self._validate_cmk_id(key_id) @@ -293,39 +306,37 @@ class KmsResponse(BaseResponse): def get_key_policy(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_GetKeyPolicy.html""" - key_id = self.parameters.get('KeyId') - policy_name = self.parameters.get('PolicyName') + key_id = self.parameters.get("KeyId") + policy_name = self.parameters.get("PolicyName") _assert_default_policy(policy_name) self._validate_cmk_id(key_id) - return json.dumps({'Policy': self.kms_backend.get_key_policy(key_id)}) + return json.dumps({"Policy": self.kms_backend.get_key_policy(key_id)}) def list_key_policies(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_ListKeyPolicies.html""" - key_id = self.parameters.get('KeyId') + key_id = self.parameters.get("KeyId") self._validate_cmk_id(key_id) self.kms_backend.describe_key(key_id) - return json.dumps({'Truncated': False, 'PolicyNames': ['default']}) + return json.dumps({"Truncated": False, "PolicyNames": ["default"]}) def encrypt(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_Encrypt.html""" key_id = self.parameters.get("KeyId") - encryption_context = self.parameters.get('EncryptionContext', {}) + encryption_context = self.parameters.get("EncryptionContext", {}) plaintext = self.parameters.get("Plaintext") self._validate_key_id(key_id) if isinstance(plaintext, six.text_type): - plaintext = plaintext.encode('utf-8') + plaintext = plaintext.encode("utf-8") ciphertext_blob, arn = self.kms_backend.encrypt( - key_id=key_id, - plaintext=plaintext, - encryption_context=encryption_context, + key_id=key_id, plaintext=plaintext, encryption_context=encryption_context ) ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8") @@ -334,27 +345,32 @@ class KmsResponse(BaseResponse): def decrypt(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_Decrypt.html""" ciphertext_blob = self.parameters.get("CiphertextBlob") - encryption_context = self.parameters.get('EncryptionContext', {}) + encryption_context = self.parameters.get("EncryptionContext", {}) plaintext, arn = self.kms_backend.decrypt( - ciphertext_blob=ciphertext_blob, - encryption_context=encryption_context, + ciphertext_blob=ciphertext_blob, encryption_context=encryption_context ) plaintext_response = base64.b64encode(plaintext).decode("utf-8") - return json.dumps({"Plaintext": plaintext_response, 'KeyId': arn}) + return json.dumps({"Plaintext": plaintext_response, "KeyId": arn}) def re_encrypt(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_ReEncrypt.html""" ciphertext_blob = self.parameters.get("CiphertextBlob") source_encryption_context = self.parameters.get("SourceEncryptionContext", {}) destination_key_id = self.parameters.get("DestinationKeyId") - destination_encryption_context = self.parameters.get("DestinationEncryptionContext", {}) + destination_encryption_context = self.parameters.get( + "DestinationEncryptionContext", {} + ) self._validate_cmk_id(destination_key_id) - new_ciphertext_blob, decrypting_arn, encrypting_arn = self.kms_backend.re_encrypt( + ( + new_ciphertext_blob, + decrypting_arn, + encrypting_arn, + ) = self.kms_backend.re_encrypt( ciphertext_blob=ciphertext_blob, source_encryption_context=source_encryption_context, destination_key_id=destination_key_id, @@ -364,12 +380,16 @@ class KmsResponse(BaseResponse): response_ciphertext_blob = base64.b64encode(new_ciphertext_blob).decode("utf-8") return json.dumps( - {"CiphertextBlob": response_ciphertext_blob, "KeyId": encrypting_arn, "SourceKeyId": decrypting_arn} + { + "CiphertextBlob": response_ciphertext_blob, + "KeyId": encrypting_arn, + "SourceKeyId": decrypting_arn, + } ) def disable_key(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_DisableKey.html""" - key_id = self.parameters.get('KeyId') + key_id = self.parameters.get("KeyId") self._validate_cmk_id(key_id) @@ -379,7 +399,7 @@ class KmsResponse(BaseResponse): def enable_key(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_EnableKey.html""" - key_id = self.parameters.get('KeyId') + key_id = self.parameters.get("KeyId") self._validate_cmk_id(key_id) @@ -389,80 +409,94 @@ class KmsResponse(BaseResponse): def cancel_key_deletion(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_CancelKeyDeletion.html""" - key_id = self.parameters.get('KeyId') + key_id = self.parameters.get("KeyId") self._validate_cmk_id(key_id) self.kms_backend.cancel_key_deletion(key_id) - return json.dumps({'KeyId': key_id}) + return json.dumps({"KeyId": key_id}) def schedule_key_deletion(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_ScheduleKeyDeletion.html""" - key_id = self.parameters.get('KeyId') - if self.parameters.get('PendingWindowInDays') is None: + key_id = self.parameters.get("KeyId") + if self.parameters.get("PendingWindowInDays") is None: pending_window_in_days = 30 else: - pending_window_in_days = self.parameters.get('PendingWindowInDays') + pending_window_in_days = self.parameters.get("PendingWindowInDays") self._validate_cmk_id(key_id) - return json.dumps({ - 'KeyId': key_id, - 'DeletionDate': self.kms_backend.schedule_key_deletion(key_id, pending_window_in_days) - }) + return json.dumps( + { + "KeyId": key_id, + "DeletionDate": self.kms_backend.schedule_key_deletion( + key_id, pending_window_in_days + ), + } + ) def generate_data_key(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKey.html""" - key_id = self.parameters.get('KeyId') - encryption_context = self.parameters.get('EncryptionContext', {}) - number_of_bytes = self.parameters.get('NumberOfBytes') - key_spec = self.parameters.get('KeySpec') - grant_tokens = self.parameters.get('GrantTokens') + key_id = self.parameters.get("KeyId") + encryption_context = self.parameters.get("EncryptionContext", {}) + number_of_bytes = self.parameters.get("NumberOfBytes") + key_spec = self.parameters.get("KeySpec") + grant_tokens = self.parameters.get("GrantTokens") # Param validation self._validate_key_id(key_id) if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1): - raise ValidationException(( - "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " - "to satisfy constraint: Member must have value less than or " - "equal to 1024" - ).format(number_of_bytes=number_of_bytes)) + raise ValidationException( + ( + "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " + "to satisfy constraint: Member must have value less than or " + "equal to 1024" + ).format(number_of_bytes=number_of_bytes) + ) - if key_spec and key_spec not in ('AES_256', 'AES_128'): - raise ValidationException(( - "1 validation error detected: Value '{key_spec}' at 'keySpec' failed " - "to satisfy constraint: Member must satisfy enum value set: " - "[AES_256, AES_128]" - ).format(key_spec=key_spec)) + if key_spec and key_spec not in ("AES_256", "AES_128"): + raise ValidationException( + ( + "1 validation error detected: Value '{key_spec}' at 'keySpec' failed " + "to satisfy constraint: Member must satisfy enum value set: " + "[AES_256, AES_128]" + ).format(key_spec=key_spec) + ) if not key_spec and not number_of_bytes: - raise ValidationException("Please specify either number of bytes or key spec.") + raise ValidationException( + "Please specify either number of bytes or key spec." + ) if key_spec and number_of_bytes: - raise ValidationException("Please specify either number of bytes or key spec.") + raise ValidationException( + "Please specify either number of bytes or key spec." + ) plaintext, ciphertext_blob, key_arn = self.kms_backend.generate_data_key( key_id=key_id, encryption_context=encryption_context, number_of_bytes=number_of_bytes, key_spec=key_spec, - grant_tokens=grant_tokens + grant_tokens=grant_tokens, ) plaintext_response = base64.b64encode(plaintext).decode("utf-8") ciphertext_blob_response = base64.b64encode(ciphertext_blob).decode("utf-8") - return json.dumps({ - 'CiphertextBlob': ciphertext_blob_response, - 'Plaintext': plaintext_response, - 'KeyId': key_arn # not alias - }) + return json.dumps( + { + "CiphertextBlob": ciphertext_blob_response, + "Plaintext": plaintext_response, + "KeyId": key_arn, # not alias + } + ) def generate_data_key_without_plaintext(self): """https://docs.aws.amazon.com/kms/latest/APIReference/API_GenerateDataKeyWithoutPlaintext.html""" result = json.loads(self.generate_data_key()) - del result['Plaintext'] + del result["Plaintext"] return json.dumps(result) @@ -471,11 +505,13 @@ class KmsResponse(BaseResponse): number_of_bytes = self.parameters.get("NumberOfBytes") if number_of_bytes and (number_of_bytes > 1024 or number_of_bytes < 1): - raise ValidationException(( - "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " - "to satisfy constraint: Member must have value less than or " - "equal to 1024" - ).format(number_of_bytes=number_of_bytes)) + raise ValidationException( + ( + "1 validation error detected: Value '{number_of_bytes:d}' at 'numberOfBytes' failed " + "to satisfy constraint: Member must have value less than or " + "equal to 1024" + ).format(number_of_bytes=number_of_bytes) + ) entropy = os.urandom(number_of_bytes) @@ -485,5 +521,5 @@ class KmsResponse(BaseResponse): def _assert_default_policy(policy_name): - if policy_name != 'default': + if policy_name != "default": raise NotFoundException("No such policy exists") diff --git a/moto/kms/urls.py b/moto/kms/urls.py index 5b0b48969..97e1a3720 100644 --- a/moto/kms/urls.py +++ b/moto/kms/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import KmsResponse -url_bases = [ - "https?://kms.(.+).amazonaws.com", -] +url_bases = ["https?://kms.(.+).amazonaws.com"] -url_paths = { - '{0}/$': KmsResponse.dispatch, -} +url_paths = {"{0}/$": KmsResponse.dispatch} diff --git a/moto/kms/utils.py b/moto/kms/utils.py index 96d3f25cc..4eacba1a6 100644 --- a/moto/kms/utils.py +++ b/moto/kms/utils.py @@ -9,7 +9,11 @@ import uuid from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import algorithms, Cipher, modes -from .exceptions import InvalidCiphertextException, AccessDeniedException, NotFoundException +from .exceptions import ( + InvalidCiphertextException, + AccessDeniedException, + NotFoundException, +) MASTER_KEY_LEN = 32 @@ -43,7 +47,12 @@ def _serialize_ciphertext_blob(ciphertext): NOTE: This is just a simple binary format. It is not what KMS actually does. """ - header = struct.pack(CIPHERTEXT_HEADER_FORMAT, ciphertext.key_id.encode("utf-8"), ciphertext.iv, ciphertext.tag) + header = struct.pack( + CIPHERTEXT_HEADER_FORMAT, + ciphertext.key_id.encode("utf-8"), + ciphertext.iv, + ciphertext.tag, + ) return header + ciphertext.ciphertext @@ -55,7 +64,9 @@ def _deserialize_ciphertext_blob(ciphertext_blob): header = ciphertext_blob[:HEADER_LEN] ciphertext = ciphertext_blob[HEADER_LEN:] key_id, iv, tag = struct.unpack(CIPHERTEXT_HEADER_FORMAT, header) - return Ciphertext(key_id=key_id.decode("utf-8"), iv=iv, ciphertext=ciphertext, tag=tag) + return Ciphertext( + key_id=key_id.decode("utf-8"), iv=iv, ciphertext=ciphertext, tag=tag + ) def _serialize_encryption_context(encryption_context): @@ -88,17 +99,23 @@ def encrypt(master_keys, key_id, plaintext, encryption_context): except KeyError: is_alias = key_id.startswith("alias/") or ":alias/" in key_id raise NotFoundException( - "{id_type} {key_id} is not found.".format(id_type="Alias" if is_alias else "keyId", key_id=key_id) + "{id_type} {key_id} is not found.".format( + id_type="Alias" if is_alias else "keyId", key_id=key_id + ) ) iv = os.urandom(IV_LEN) aad = _serialize_encryption_context(encryption_context=encryption_context) - encryptor = Cipher(algorithms.AES(key.key_material), modes.GCM(iv), backend=default_backend()).encryptor() + encryptor = Cipher( + algorithms.AES(key.key_material), modes.GCM(iv), backend=default_backend() + ).encryptor() encryptor.authenticate_additional_data(aad) ciphertext = encryptor.update(plaintext) + encryptor.finalize() return _serialize_ciphertext_blob( - ciphertext=Ciphertext(key_id=key_id, iv=iv, ciphertext=ciphertext, tag=encryptor.tag) + ciphertext=Ciphertext( + key_id=key_id, iv=iv, ciphertext=ciphertext, tag=encryptor.tag + ) ) @@ -132,7 +149,9 @@ def decrypt(master_keys, ciphertext_blob, encryption_context): try: decryptor = Cipher( - algorithms.AES(key.key_material), modes.GCM(ciphertext.iv, ciphertext.tag), backend=default_backend() + algorithms.AES(key.key_material), + modes.GCM(ciphertext.iv, ciphertext.tag), + backend=default_backend(), ).decryptor() decryptor.authenticate_additional_data(aad) plaintext = decryptor.update(ciphertext.ciphertext) + decryptor.finalize() diff --git a/moto/logs/exceptions.py b/moto/logs/exceptions.py index bb02eced3..9f6628b0f 100644 --- a/moto/logs/exceptions.py +++ b/moto/logs/exceptions.py @@ -10,8 +10,7 @@ class ResourceNotFoundException(LogsClientError): def __init__(self): self.code = 400 super(ResourceNotFoundException, self).__init__( - "ResourceNotFoundException", - "The specified resource does not exist" + "ResourceNotFoundException", "The specified resource does not exist" ) @@ -19,8 +18,7 @@ class InvalidParameterException(LogsClientError): def __init__(self, msg=None): self.code = 400 super(InvalidParameterException, self).__init__( - "InvalidParameterException", - msg or "A parameter is specified incorrectly." + "InvalidParameterException", msg or "A parameter is specified incorrectly." ) @@ -28,6 +26,5 @@ class ResourceAlreadyExistsException(LogsClientError): def __init__(self): self.code = 400 super(ResourceAlreadyExistsException, self).__init__( - 'ResourceAlreadyExistsException', - 'The specified log group already exists' + "ResourceAlreadyExistsException", "The specified log group already exists" ) diff --git a/moto/logs/models.py b/moto/logs/models.py index 448d3dec1..965b1e19a 100644 --- a/moto/logs/models.py +++ b/moto/logs/models.py @@ -1,10 +1,7 @@ from moto.core import BaseBackend import boto.logs from moto.core.utils import unix_time_millis -from .exceptions import ( - ResourceNotFoundException, - ResourceAlreadyExistsException -) +from .exceptions import ResourceNotFoundException, ResourceAlreadyExistsException class LogEvent: @@ -13,7 +10,7 @@ class LogEvent: def __init__(self, ingestion_time, log_event): self.ingestionTime = ingestion_time self.timestamp = log_event["timestamp"] - self.message = log_event['message'] + self.message = log_event["message"] self.eventId = self.__class__._event_id self.__class__._event_id += 1 @@ -23,14 +20,14 @@ class LogEvent: "ingestionTime": self.ingestionTime, # "logStreamName": "message": self.message, - "timestamp": self.timestamp + "timestamp": self.timestamp, } def to_response_dict(self): return { "ingestionTime": self.ingestionTime, "message": self.message, - "timestamp": self.timestamp + "timestamp": self.timestamp, } @@ -40,22 +37,32 @@ class LogStream: def __init__(self, region, log_group, name): self.region = region self.arn = "arn:aws:logs:{region}:{id}:log-group:{log_group}:log-stream:{log_stream}".format( - region=region, id=self.__class__._log_ids, log_group=log_group, log_stream=name) + region=region, + id=self.__class__._log_ids, + log_group=log_group, + log_stream=name, + ) self.creationTime = int(unix_time_millis()) self.firstEventTimestamp = None self.lastEventTimestamp = None self.lastIngestionTime = None self.logStreamName = name self.storedBytes = 0 - self.uploadSequenceToken = 0 # I'm guessing this is token needed for sequenceToken by put_events + self.uploadSequenceToken = ( + 0 # I'm guessing this is token needed for sequenceToken by put_events + ) self.events = [] self.__class__._log_ids += 1 def _update(self): # events can be empty when stream is described soon after creation - self.firstEventTimestamp = min([x.timestamp for x in self.events]) if self.events else None - self.lastEventTimestamp = max([x.timestamp for x in self.events]) if self.events else None + self.firstEventTimestamp = ( + min([x.timestamp for x in self.events]) if self.events else None + ) + self.lastEventTimestamp = ( + max([x.timestamp for x in self.events]) if self.events else None + ) def to_describe_dict(self): # Compute start and end times @@ -77,18 +84,31 @@ class LogStream: res.update(rest) return res - def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token): + def put_log_events( + self, log_group_name, log_stream_name, log_events, sequence_token + ): # TODO: ensure sequence_token # TODO: to be thread safe this would need a lock self.lastIngestionTime = int(unix_time_millis()) # TODO: make this match AWS if possible self.storedBytes += sum([len(log_event["message"]) for log_event in log_events]) - self.events += [LogEvent(self.lastIngestionTime, log_event) for log_event in log_events] + self.events += [ + LogEvent(self.lastIngestionTime, log_event) for log_event in log_events + ] self.uploadSequenceToken += 1 - return '{:056d}'.format(self.uploadSequenceToken) + return "{:056d}".format(self.uploadSequenceToken) - def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head): + def get_log_events( + self, + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ): def filter_func(event): if start_time and event.timestamp < start_time: return False @@ -108,11 +128,18 @@ class LogStream: return int(token[2:]) return 0 - events = sorted(filter(filter_func, self.events), key=lambda event: event.timestamp, reverse=start_from_head) + events = sorted( + filter(filter_func, self.events), + key=lambda event: event.timestamp, + reverse=start_from_head, + ) next_index = get_index_from_paging_token(next_token) back_index = next_index - events_page = [event.to_response_dict() for event in events[next_index: next_index + limit]] + events_page = [ + event.to_response_dict() + for event in events[next_index : next_index + limit] + ] if next_index + limit < len(self.events): next_index += limit else: @@ -122,11 +149,25 @@ class LogStream: if back_index <= 0: back_index = 0 - return events_page, get_paging_token_from_index(back_index, True), get_paging_token_from_index(next_index) + return ( + events_page, + get_paging_token_from_index(back_index, True), + get_paging_token_from_index(next_index), + ) - def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved): + def filter_log_events( + self, + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ): if filter_pattern: - raise NotImplementedError('filter_pattern is not yet implemented') + raise NotImplementedError("filter_pattern is not yet implemented") def filter_func(event): if start_time and event.timestamp < start_time: @@ -138,9 +179,11 @@ class LogStream: return True events = [] - for event in sorted(filter(filter_func, self.events), key=lambda x: x.timestamp): + for event in sorted( + filter(filter_func, self.events), key=lambda x: x.timestamp + ): event_obj = event.to_filter_dict() - event_obj['logStreamName'] = self.logStreamName + event_obj["logStreamName"] = self.logStreamName events.append(event_obj) return events @@ -150,72 +193,140 @@ class LogGroup: self.name = name self.region = region self.arn = "arn:aws:logs:{region}:1:log-group:{log_group}".format( - region=region, log_group=name) + region=region, log_group=name + ) self.creationTime = int(unix_time_millis()) self.tags = tags self.streams = dict() # {name: LogStream} - self.retentionInDays = None # AWS defaults to Never Expire for log group retention + self.retentionInDays = ( + None # AWS defaults to Never Expire for log group retention + ) def create_log_stream(self, log_stream_name): if log_stream_name in self.streams: raise ResourceAlreadyExistsException() - self.streams[log_stream_name] = LogStream(self.region, self.name, log_stream_name) + self.streams[log_stream_name] = LogStream( + self.region, self.name, log_stream_name + ) def delete_log_stream(self, log_stream_name): if log_stream_name not in self.streams: raise ResourceNotFoundException() del self.streams[log_stream_name] - def describe_log_streams(self, descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by): + def describe_log_streams( + self, + descending, + limit, + log_group_name, + log_stream_name_prefix, + next_token, + order_by, + ): # responses only logStreamName, creationTime, arn, storedBytes when no events are stored. - log_streams = [(name, stream.to_describe_dict()) for name, stream in self.streams.items() if name.startswith(log_stream_name_prefix)] + log_streams = [ + (name, stream.to_describe_dict()) + for name, stream in self.streams.items() + if name.startswith(log_stream_name_prefix) + ] def sorter(item): - return item[0] if order_by == 'logStreamName' else item[1].get('lastEventTimestamp', 0) + return ( + item[0] + if order_by == "logStreamName" + else item[1].get("lastEventTimestamp", 0) + ) if next_token is None: next_token = 0 log_streams = sorted(log_streams, key=sorter, reverse=descending) new_token = next_token + limit - log_streams_page = [x[1] for x in log_streams[next_token: new_token]] + log_streams_page = [x[1] for x in log_streams[next_token:new_token]] if new_token >= len(log_streams): new_token = None return log_streams_page, new_token - def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token): + def put_log_events( + self, log_group_name, log_stream_name, log_events, sequence_token + ): if log_stream_name not in self.streams: raise ResourceNotFoundException() stream = self.streams[log_stream_name] - return stream.put_log_events(log_group_name, log_stream_name, log_events, sequence_token) + return stream.put_log_events( + log_group_name, log_stream_name, log_events, sequence_token + ) - def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head): + def get_log_events( + self, + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ): if log_stream_name not in self.streams: raise ResourceNotFoundException() stream = self.streams[log_stream_name] - return stream.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head) + return stream.get_log_events( + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ) - def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved): - streams = [stream for name, stream in self.streams.items() if not log_stream_names or name in log_stream_names] + def filter_log_events( + self, + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ): + streams = [ + stream + for name, stream in self.streams.items() + if not log_stream_names or name in log_stream_names + ] events = [] for stream in streams: - events += stream.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved) + events += stream.filter_log_events( + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ) if interleaved: - events = sorted(events, key=lambda event: event['timestamp']) + events = sorted(events, key=lambda event: event["timestamp"]) if next_token is None: next_token = 0 - events_page = events[next_token: next_token + limit] + events_page = events[next_token : next_token + limit] next_token += limit if next_token >= len(events): next_token = None - searched_streams = [{"logStreamName": stream.logStreamName, "searchedCompletely": True} for stream in streams] + searched_streams = [ + {"logStreamName": stream.logStreamName, "searchedCompletely": True} + for stream in streams + ] return events_page, next_token, searched_streams def to_describe_dict(self): @@ -245,7 +356,9 @@ class LogGroup: def untag(self, tags_to_remove): if self.tags: - self.tags = {k: v for (k, v) in self.tags.items() if k not in tags_to_remove} + self.tags = { + k: v for (k, v) in self.tags.items() if k not in tags_to_remove + } class LogsBackend(BaseBackend): @@ -275,13 +388,17 @@ class LogsBackend(BaseBackend): def describe_log_groups(self, limit, log_group_name_prefix, next_token): if log_group_name_prefix is None: - log_group_name_prefix = '' + log_group_name_prefix = "" if next_token is None: next_token = 0 - groups = [group.to_describe_dict() for name, group in self.groups.items() if name.startswith(log_group_name_prefix)] - groups = sorted(groups, key=lambda x: x['creationTime'], reverse=True) - groups_page = groups[next_token:next_token + limit] + groups = [ + group.to_describe_dict() + for name, group in self.groups.items() + if name.startswith(log_group_name_prefix) + ] + groups = sorted(groups, key=lambda x: x["creationTime"], reverse=True) + groups_page = groups[next_token : next_token + limit] next_token += limit if next_token >= len(groups): @@ -301,30 +418,85 @@ class LogsBackend(BaseBackend): log_group = self.groups[log_group_name] return log_group.delete_log_stream(log_stream_name) - def describe_log_streams(self, descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by): + def describe_log_streams( + self, + descending, + limit, + log_group_name, + log_stream_name_prefix, + next_token, + order_by, + ): if log_group_name not in self.groups: raise ResourceNotFoundException() log_group = self.groups[log_group_name] - return log_group.describe_log_streams(descending, limit, log_group_name, log_stream_name_prefix, next_token, order_by) + return log_group.describe_log_streams( + descending, + limit, + log_group_name, + log_stream_name_prefix, + next_token, + order_by, + ) - def put_log_events(self, log_group_name, log_stream_name, log_events, sequence_token): + def put_log_events( + self, log_group_name, log_stream_name, log_events, sequence_token + ): # TODO: add support for sequence_tokens if log_group_name not in self.groups: raise ResourceNotFoundException() log_group = self.groups[log_group_name] - return log_group.put_log_events(log_group_name, log_stream_name, log_events, sequence_token) + return log_group.put_log_events( + log_group_name, log_stream_name, log_events, sequence_token + ) - def get_log_events(self, log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head): + def get_log_events( + self, + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ): if log_group_name not in self.groups: raise ResourceNotFoundException() log_group = self.groups[log_group_name] - return log_group.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head) + return log_group.get_log_events( + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ) - def filter_log_events(self, log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved): + def filter_log_events( + self, + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ): if log_group_name not in self.groups: raise ResourceNotFoundException() log_group = self.groups[log_group_name] - return log_group.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved) + return log_group.filter_log_events( + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ) def put_retention_policy(self, log_group_name, retention_in_days): if log_group_name not in self.groups: @@ -357,4 +529,6 @@ class LogsBackend(BaseBackend): log_group.untag(tags) -logs_backends = {region.name: LogsBackend(region.name) for region in boto.logs.regions()} +logs_backends = { + region.name: LogsBackend(region.name) for region in boto.logs.regions() +} diff --git a/moto/logs/responses.py b/moto/logs/responses.py index b91662cf8..072c76b71 100644 --- a/moto/logs/responses.py +++ b/moto/logs/responses.py @@ -5,6 +5,7 @@ import json # See http://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/Welcome.html + class LogsResponse(BaseResponse): @property def logs_backend(self): @@ -21,135 +22,159 @@ class LogsResponse(BaseResponse): return self.request_params.get(param, if_none) def create_log_group(self): - log_group_name = self._get_param('logGroupName') - tags = self._get_param('tags') + log_group_name = self._get_param("logGroupName") + tags = self._get_param("tags") assert 1 <= len(log_group_name) <= 512 # TODO: assert pattern self.logs_backend.create_log_group(log_group_name, tags) - return '' + return "" def delete_log_group(self): - log_group_name = self._get_param('logGroupName') + log_group_name = self._get_param("logGroupName") self.logs_backend.delete_log_group(log_group_name) - return '' + return "" def describe_log_groups(self): - log_group_name_prefix = self._get_param('logGroupNamePrefix') - next_token = self._get_param('nextToken') - limit = self._get_param('limit', 50) + log_group_name_prefix = self._get_param("logGroupNamePrefix") + next_token = self._get_param("nextToken") + limit = self._get_param("limit", 50) assert limit <= 50 groups, next_token = self.logs_backend.describe_log_groups( - limit, log_group_name_prefix, next_token) - return json.dumps({ - "logGroups": groups, - "nextToken": next_token - }) + limit, log_group_name_prefix, next_token + ) + return json.dumps({"logGroups": groups, "nextToken": next_token}) def create_log_stream(self): - log_group_name = self._get_param('logGroupName') - log_stream_name = self._get_param('logStreamName') + log_group_name = self._get_param("logGroupName") + log_stream_name = self._get_param("logStreamName") self.logs_backend.create_log_stream(log_group_name, log_stream_name) - return '' + return "" def delete_log_stream(self): - log_group_name = self._get_param('logGroupName') - log_stream_name = self._get_param('logStreamName') + log_group_name = self._get_param("logGroupName") + log_stream_name = self._get_param("logStreamName") self.logs_backend.delete_log_stream(log_group_name, log_stream_name) - return '' + return "" def describe_log_streams(self): - log_group_name = self._get_param('logGroupName') - log_stream_name_prefix = self._get_param('logStreamNamePrefix', '') - descending = self._get_param('descending', False) - limit = self._get_param('limit', 50) + log_group_name = self._get_param("logGroupName") + log_stream_name_prefix = self._get_param("logStreamNamePrefix", "") + descending = self._get_param("descending", False) + limit = self._get_param("limit", 50) assert limit <= 50 - next_token = self._get_param('nextToken') - order_by = self._get_param('orderBy', 'LogStreamName') - assert order_by in {'LogStreamName', 'LastEventTime'} + next_token = self._get_param("nextToken") + order_by = self._get_param("orderBy", "LogStreamName") + assert order_by in {"LogStreamName", "LastEventTime"} - if order_by == 'LastEventTime': + if order_by == "LastEventTime": assert not log_stream_name_prefix streams, next_token = self.logs_backend.describe_log_streams( - descending, limit, log_group_name, log_stream_name_prefix, - next_token, order_by) - return json.dumps({ - "logStreams": streams, - "nextToken": next_token - }) + descending, + limit, + log_group_name, + log_stream_name_prefix, + next_token, + order_by, + ) + return json.dumps({"logStreams": streams, "nextToken": next_token}) def put_log_events(self): - log_group_name = self._get_param('logGroupName') - log_stream_name = self._get_param('logStreamName') - log_events = self._get_param('logEvents') - sequence_token = self._get_param('sequenceToken') + log_group_name = self._get_param("logGroupName") + log_stream_name = self._get_param("logStreamName") + log_events = self._get_param("logEvents") + sequence_token = self._get_param("sequenceToken") - next_sequence_token = self.logs_backend.put_log_events(log_group_name, log_stream_name, log_events, sequence_token) - return json.dumps({'nextSequenceToken': next_sequence_token}) + next_sequence_token = self.logs_backend.put_log_events( + log_group_name, log_stream_name, log_events, sequence_token + ) + return json.dumps({"nextSequenceToken": next_sequence_token}) def get_log_events(self): - log_group_name = self._get_param('logGroupName') - log_stream_name = self._get_param('logStreamName') - start_time = self._get_param('startTime') + log_group_name = self._get_param("logGroupName") + log_stream_name = self._get_param("logStreamName") + start_time = self._get_param("startTime") end_time = self._get_param("endTime") - limit = self._get_param('limit', 10000) + limit = self._get_param("limit", 10000) assert limit <= 10000 - next_token = self._get_param('nextToken') - start_from_head = self._get_param('startFromHead', False) + next_token = self._get_param("nextToken") + start_from_head = self._get_param("startFromHead", False) - events, next_backward_token, next_foward_token = \ - self.logs_backend.get_log_events(log_group_name, log_stream_name, start_time, end_time, limit, next_token, start_from_head) - return json.dumps({ - "events": events, - "nextBackwardToken": next_backward_token, - "nextForwardToken": next_foward_token - }) + ( + events, + next_backward_token, + next_foward_token, + ) = self.logs_backend.get_log_events( + log_group_name, + log_stream_name, + start_time, + end_time, + limit, + next_token, + start_from_head, + ) + return json.dumps( + { + "events": events, + "nextBackwardToken": next_backward_token, + "nextForwardToken": next_foward_token, + } + ) def filter_log_events(self): - log_group_name = self._get_param('logGroupName') - log_stream_names = self._get_param('logStreamNames', []) - start_time = self._get_param('startTime') + log_group_name = self._get_param("logGroupName") + log_stream_names = self._get_param("logStreamNames", []) + start_time = self._get_param("startTime") # impl, see: http://docs.aws.amazon.com/AmazonCloudWatch/latest/logs/FilterAndPatternSyntax.html - filter_pattern = self._get_param('filterPattern') - interleaved = self._get_param('interleaved', False) + filter_pattern = self._get_param("filterPattern") + interleaved = self._get_param("interleaved", False) end_time = self._get_param("endTime") - limit = self._get_param('limit', 10000) + limit = self._get_param("limit", 10000) assert limit <= 10000 - next_token = self._get_param('nextToken') + next_token = self._get_param("nextToken") - events, next_token, searched_streams = self.logs_backend.filter_log_events(log_group_name, log_stream_names, start_time, end_time, limit, next_token, filter_pattern, interleaved) - return json.dumps({ - "events": events, - "nextToken": next_token, - "searchedLogStreams": searched_streams - }) + events, next_token, searched_streams = self.logs_backend.filter_log_events( + log_group_name, + log_stream_names, + start_time, + end_time, + limit, + next_token, + filter_pattern, + interleaved, + ) + return json.dumps( + { + "events": events, + "nextToken": next_token, + "searchedLogStreams": searched_streams, + } + ) def put_retention_policy(self): - log_group_name = self._get_param('logGroupName') - retention_in_days = self._get_param('retentionInDays') + log_group_name = self._get_param("logGroupName") + retention_in_days = self._get_param("retentionInDays") self.logs_backend.put_retention_policy(log_group_name, retention_in_days) - return '' + return "" def delete_retention_policy(self): - log_group_name = self._get_param('logGroupName') + log_group_name = self._get_param("logGroupName") self.logs_backend.delete_retention_policy(log_group_name) - return '' + return "" def list_tags_log_group(self): - log_group_name = self._get_param('logGroupName') + log_group_name = self._get_param("logGroupName") tags = self.logs_backend.list_tags_log_group(log_group_name) - return json.dumps({ - 'tags': tags - }) + return json.dumps({"tags": tags}) def tag_log_group(self): - log_group_name = self._get_param('logGroupName') - tags = self._get_param('tags') + log_group_name = self._get_param("logGroupName") + tags = self._get_param("tags") self.logs_backend.tag_log_group(log_group_name, tags) - return '' + return "" def untag_log_group(self): - log_group_name = self._get_param('logGroupName') - tags = self._get_param('tags') + log_group_name = self._get_param("logGroupName") + tags = self._get_param("tags") self.logs_backend.untag_log_group(log_group_name, tags) - return '' + return "" diff --git a/moto/logs/urls.py b/moto/logs/urls.py index b7910e675..e4e1f5a88 100644 --- a/moto/logs/urls.py +++ b/moto/logs/urls.py @@ -1,9 +1,5 @@ from .responses import LogsResponse -url_bases = [ - "https?://logs.(.+).amazonaws.com", -] +url_bases = ["https?://logs.(.+).amazonaws.com"] -url_paths = { - '{0}/$': LogsResponse.dispatch, -} +url_paths = {"{0}/$": LogsResponse.dispatch} diff --git a/moto/opsworks/__init__.py b/moto/opsworks/__init__.py index b492b6a53..e0e6b88d0 100644 --- a/moto/opsworks/__init__.py +++ b/moto/opsworks/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import opsworks_backends from ..core.models import base_decorator, deprecated_base_decorator -opsworks_backend = opsworks_backends['us-east-1'] +opsworks_backend = opsworks_backends["us-east-1"] mock_opsworks = base_decorator(opsworks_backends) mock_opsworks_deprecated = deprecated_base_decorator(opsworks_backends) diff --git a/moto/opsworks/exceptions.py b/moto/opsworks/exceptions.py index 00bdffbc5..3867b3b90 100644 --- a/moto/opsworks/exceptions.py +++ b/moto/opsworks/exceptions.py @@ -5,20 +5,16 @@ from werkzeug.exceptions import BadRequest class ResourceNotFoundException(BadRequest): - def __init__(self, message): super(ResourceNotFoundException, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'ResourceNotFoundException', - }) + self.description = json.dumps( + {"message": message, "__type": "ResourceNotFoundException"} + ) class ValidationException(BadRequest): - def __init__(self, message): super(ValidationException, self).__init__() - self.description = json.dumps({ - "message": message, - '__type': 'ResourceNotFoundException', - }) + self.description = json.dumps( + {"message": message, "__type": "ResourceNotFoundException"} + ) diff --git a/moto/opsworks/models.py b/moto/opsworks/models.py index 4fe428c65..336bbde14 100644 --- a/moto/opsworks/models.py +++ b/moto/opsworks/models.py @@ -15,24 +15,30 @@ class OpsworkInstance(BaseModel): used to populate a reservation request when "start" is called """ - def __init__(self, stack_id, layer_ids, instance_type, ec2_backend, - auto_scale_type=None, - hostname=None, - os=None, - ami_id="ami-08111162", - ssh_keyname=None, - availability_zone=None, - virtualization_type="hvm", - subnet_id=None, - architecture="x86_64", - root_device_type="ebs", - block_device_mappings=None, - install_updates_on_boot=True, - ebs_optimized=False, - agent_version="INHERIT", - instance_profile_arn=None, - associate_public_ip=None, - security_group_ids=None): + def __init__( + self, + stack_id, + layer_ids, + instance_type, + ec2_backend, + auto_scale_type=None, + hostname=None, + os=None, + ami_id="ami-08111162", + ssh_keyname=None, + availability_zone=None, + virtualization_type="hvm", + subnet_id=None, + architecture="x86_64", + root_device_type="ebs", + block_device_mappings=None, + install_updates_on_boot=True, + ebs_optimized=False, + agent_version="INHERIT", + instance_profile_arn=None, + associate_public_ip=None, + security_group_ids=None, + ): self.ec2_backend = ec2_backend @@ -55,13 +61,12 @@ class OpsworkInstance(BaseModel): # formatting in to_dict() self.block_device_mappings = block_device_mappings if self.block_device_mappings is None: - self.block_device_mappings = [{ - 'DeviceName': 'ROOT_DEVICE', - 'Ebs': { - 'VolumeSize': 8, - 'VolumeType': 'gp2' + self.block_device_mappings = [ + { + "DeviceName": "ROOT_DEVICE", + "Ebs": {"VolumeSize": 8, "VolumeType": "gp2"}, } - }] + ] self.security_group_ids = security_group_ids if self.security_group_ids is None: self.security_group_ids = [] @@ -102,9 +107,9 @@ class OpsworkInstance(BaseModel): ) self.instance = reservation.instances[0] self.reported_os = { - 'Family': 'rhel (fixed)', - 'Name': 'amazon (fixed)', - 'Version': '2016.03 (fixed)' + "Family": "rhel (fixed)", + "Name": "amazon (fixed)", + "Version": "2016.03 (fixed)", } self.platform = self.instance.platform self.security_group_ids = self.instance.security_groups @@ -156,32 +161,43 @@ class OpsworkInstance(BaseModel): d.update({"RootDeviceVolumeId": "vol-a20e450a (fixed)"}) if self.ssh_keyname is not None: d.update( - {"SshHostDsaKeyFingerprint": "24:36:32:fe:d8:5f:9c:18:b1:ad:37:e9:eb:e8:69:58 (fixed)"}) + { + "SshHostDsaKeyFingerprint": "24:36:32:fe:d8:5f:9c:18:b1:ad:37:e9:eb:e8:69:58 (fixed)" + } + ) d.update( - {"SshHostRsaKeyFingerprint": "3c:bd:37:52:d7:ca:67:e1:6e:4b:ac:31:86:79:f5:6c (fixed)"}) + { + "SshHostRsaKeyFingerprint": "3c:bd:37:52:d7:ca:67:e1:6e:4b:ac:31:86:79:f5:6c (fixed)" + } + ) d.update({"PrivateDns": self.instance.private_dns}) d.update({"PrivateIp": self.instance.private_ip}) - d.update({"PublicDns": getattr(self.instance, 'public_dns', None)}) - d.update({"PublicIp": getattr(self.instance, 'public_ip', None)}) + d.update({"PublicDns": getattr(self.instance, "public_dns", None)}) + d.update({"PublicIp": getattr(self.instance, "public_ip", None)}) return d class Layer(BaseModel): - - def __init__(self, stack_id, type, name, shortname, - attributes=None, - custom_instance_profile_arn=None, - custom_json=None, - custom_security_group_ids=None, - packages=None, - volume_configurations=None, - enable_autohealing=None, - auto_assign_elastic_ips=None, - auto_assign_public_ips=None, - custom_recipes=None, - install_updates_on_boot=None, - use_ebs_optimized_instances=None, - lifecycle_event_configuration=None): + def __init__( + self, + stack_id, + type, + name, + shortname, + attributes=None, + custom_instance_profile_arn=None, + custom_json=None, + custom_security_group_ids=None, + packages=None, + volume_configurations=None, + enable_autohealing=None, + auto_assign_elastic_ips=None, + auto_assign_public_ips=None, + custom_recipes=None, + install_updates_on_boot=None, + use_ebs_optimized_instances=None, + lifecycle_event_configuration=None, + ): self.stack_id = stack_id self.type = type self.name = name @@ -190,31 +206,31 @@ class Layer(BaseModel): self.attributes = attributes if attributes is None: self.attributes = { - 'BundlerVersion': None, - 'EcsClusterArn': None, - 'EnableHaproxyStats': None, - 'GangliaPassword': None, - 'GangliaUrl': None, - 'GangliaUser': None, - 'HaproxyHealthCheckMethod': None, - 'HaproxyHealthCheckUrl': None, - 'HaproxyStatsPassword': None, - 'HaproxyStatsUrl': None, - 'HaproxyStatsUser': None, - 'JavaAppServer': None, - 'JavaAppServerVersion': None, - 'Jvm': None, - 'JvmOptions': None, - 'JvmVersion': None, - 'ManageBundler': None, - 'MemcachedMemory': None, - 'MysqlRootPassword': None, - 'MysqlRootPasswordUbiquitous': None, - 'NodejsVersion': None, - 'PassengerVersion': None, - 'RailsStack': None, - 'RubyVersion': None, - 'RubygemsVersion': None + "BundlerVersion": None, + "EcsClusterArn": None, + "EnableHaproxyStats": None, + "GangliaPassword": None, + "GangliaUrl": None, + "GangliaUser": None, + "HaproxyHealthCheckMethod": None, + "HaproxyHealthCheckUrl": None, + "HaproxyStatsPassword": None, + "HaproxyStatsUrl": None, + "HaproxyStatsUser": None, + "JavaAppServer": None, + "JavaAppServerVersion": None, + "Jvm": None, + "JvmOptions": None, + "JvmVersion": None, + "ManageBundler": None, + "MemcachedMemory": None, + "MysqlRootPassword": None, + "MysqlRootPasswordUbiquitous": None, + "NodejsVersion": None, + "PassengerVersion": None, + "RailsStack": None, + "RubyVersion": None, + "RubygemsVersion": None, } # May not be accurate self.packages = packages @@ -224,11 +240,11 @@ class Layer(BaseModel): self.custom_recipes = custom_recipes if custom_recipes is None: self.custom_recipes = { - 'Configure': [], - 'Deploy': [], - 'Setup': [], - 'Shutdown': [], - 'Undeploy': [], + "Configure": [], + "Deploy": [], + "Setup": [], + "Shutdown": [], + "Undeploy": [], } self.custom_security_group_ids = custom_security_group_ids @@ -271,9 +287,9 @@ class Layer(BaseModel): "Configure": [], "Setup": [], "Shutdown": [], - "Undeploy": [] + "Undeploy": [], }, # May not be accurate - "DefaultSecurityGroupNames": ['AWS-OpsWorks-Custom-Server'], + "DefaultSecurityGroupNames": ["AWS-OpsWorks-Custom-Server"], "EnableAutoHealing": self.enable_autohealing, "LayerId": self.id, "LifecycleEventConfiguration": self.lifecycle_event_configuration, @@ -287,29 +303,33 @@ class Layer(BaseModel): if self.custom_json is not None: d.update({"CustomJson": self.custom_json}) if self.custom_instance_profile_arn is not None: - d.update( - {"CustomInstanceProfileArn": self.custom_instance_profile_arn}) + d.update({"CustomInstanceProfileArn": self.custom_instance_profile_arn}) return d class Stack(BaseModel): - - def __init__(self, name, region, service_role_arn, default_instance_profile_arn, - vpcid="vpc-1f99bf7a", - attributes=None, - default_os='Ubuntu 12.04 LTS', - hostname_theme='Layer_Dependent', - default_availability_zone='us-east-1a', - default_subnet_id='subnet-73981004', - custom_json=None, - configuration_manager=None, - chef_configuration=None, - use_custom_cookbooks=False, - use_opsworks_security_groups=True, - custom_cookbooks_source=None, - default_ssh_keyname=None, - default_root_device_type='instance-store', - agent_version='LATEST'): + def __init__( + self, + name, + region, + service_role_arn, + default_instance_profile_arn, + vpcid="vpc-1f99bf7a", + attributes=None, + default_os="Ubuntu 12.04 LTS", + hostname_theme="Layer_Dependent", + default_availability_zone="us-east-1a", + default_subnet_id="subnet-73981004", + custom_json=None, + configuration_manager=None, + chef_configuration=None, + use_custom_cookbooks=False, + use_opsworks_security_groups=True, + custom_cookbooks_source=None, + default_ssh_keyname=None, + default_root_device_type="instance-store", + agent_version="LATEST", + ): self.name = name self.region = region @@ -319,11 +339,11 @@ class Stack(BaseModel): self.vpcid = vpcid self.attributes = attributes if attributes is None: - self.attributes = {'Color': None} + self.attributes = {"Color": None} self.configuration_manager = configuration_manager if configuration_manager is None: - self.configuration_manager = {'Name': 'Chef', 'Version': '11.4'} + self.configuration_manager = {"Name": "Chef", "Version": "11.4"} self.chef_configuration = chef_configuration if chef_configuration is None: @@ -356,15 +376,13 @@ class Stack(BaseModel): def generate_hostname(self): # this doesn't match amazon's implementation return "{theme}-{rand}-(moto)".format( - theme=self.hostname_theme, - rand=[choice("abcdefghijhk") for _ in range(4)]) + theme=self.hostname_theme, rand=[choice("abcdefghijhk") for _ in range(4)] + ) @property def arn(self): return "arn:aws:opsworks:{region}:{account_number}:stack/{id}".format( - region=self.region, - account_number=self.account_number, - id=self.id + region=self.region, account_number=self.account_number, id=self.id ) def to_dict(self): @@ -389,7 +407,7 @@ class Stack(BaseModel): "StackId": self.id, "UseCustomCookbooks": self.use_custom_cookbooks, "UseOpsworksSecurityGroups": self.use_opsworks_security_groups, - "VpcId": self.vpcid + "VpcId": self.vpcid, } if self.custom_json is not None: response.update({"CustomJson": self.custom_json}) @@ -399,17 +417,21 @@ class Stack(BaseModel): class App(BaseModel): - - def __init__(self, stack_id, name, type, - shortname=None, - description=None, - datasources=None, - app_source=None, - domains=None, - enable_ssl=False, - ssl_configuration=None, - attributes=None, - environment=None): + def __init__( + self, + stack_id, + name, + type, + shortname=None, + description=None, + datasources=None, + app_source=None, + domains=None, + enable_ssl=False, + ssl_configuration=None, + attributes=None, + environment=None, + ): self.stack_id = stack_id self.name = name self.type = type @@ -463,13 +485,12 @@ class App(BaseModel): "Shortname": self.shortname, "SslConfiguration": self.ssl_configuration, "StackId": self.stack_id, - "Type": self.type + "Type": self.type, } return d class OpsWorksBackend(BaseBackend): - def __init__(self, ec2_backend): self.stacks = {} self.layers = {} @@ -488,55 +509,59 @@ class OpsWorksBackend(BaseBackend): return stack def create_layer(self, **kwargs): - name = kwargs['name'] - shortname = kwargs['shortname'] - stackid = kwargs['stack_id'] + name = kwargs["name"] + shortname = kwargs["shortname"] + stackid = kwargs["stack_id"] if stackid not in self.stacks: raise ResourceNotFoundException(stackid) if name in [l.name for l in self.stacks[stackid].layers]: raise ValidationException( - 'There is already a layer named "{0}" ' - 'for this stack'.format(name)) + 'There is already a layer named "{0}" ' "for this stack".format(name) + ) if shortname in [l.shortname for l in self.stacks[stackid].layers]: raise ValidationException( 'There is already a layer with shortname "{0}" ' - 'for this stack'.format(shortname)) + "for this stack".format(shortname) + ) layer = Layer(**kwargs) self.layers[layer.id] = layer self.stacks[stackid].layers.append(layer) return layer def create_app(self, **kwargs): - name = kwargs['name'] - stackid = kwargs['stack_id'] + name = kwargs["name"] + stackid = kwargs["stack_id"] if stackid not in self.stacks: raise ResourceNotFoundException(stackid) if name in [a.name for a in self.stacks[stackid].apps]: raise ValidationException( - 'There is already an app named "{0}" ' - 'for this stack'.format(name)) + 'There is already an app named "{0}" ' "for this stack".format(name) + ) app = App(**kwargs) self.apps[app.id] = app self.stacks[stackid].apps.append(app) return app def create_instance(self, **kwargs): - stack_id = kwargs['stack_id'] - layer_ids = kwargs['layer_ids'] + stack_id = kwargs["stack_id"] + layer_ids = kwargs["layer_ids"] if stack_id not in self.stacks: raise ResourceNotFoundException( - "Unable to find stack with ID {0}".format(stack_id)) + "Unable to find stack with ID {0}".format(stack_id) + ) unknown_layers = set(layer_ids) - set(self.layers.keys()) if unknown_layers: raise ResourceNotFoundException(", ".join(unknown_layers)) layers = [self.layers[id] for id in layer_ids] - if len(set([layer.stack_id for layer in layers])) != 1 or \ - any([layer.stack_id != stack_id for layer in layers]): + if len(set([layer.stack_id for layer in layers])) != 1 or any( + [layer.stack_id != stack_id for layer in layers] + ): raise ValidationException( - "Please only provide layer IDs from the same stack") + "Please only provide layer IDs from the same stack" + ) stack = self.stacks[stack_id] # pick the first to set default instance_profile_arn and @@ -549,12 +574,9 @@ class OpsWorksBackend(BaseBackend): kwargs.setdefault("subnet_id", stack.default_subnet_id) kwargs.setdefault("root_device_type", stack.default_root_device_type) if layer.custom_instance_profile_arn: - kwargs.setdefault("instance_profile_arn", - layer.custom_instance_profile_arn) - kwargs.setdefault("instance_profile_arn", - stack.default_instance_profile_arn) - kwargs.setdefault("security_group_ids", - layer.custom_security_group_ids) + kwargs.setdefault("instance_profile_arn", layer.custom_instance_profile_arn) + kwargs.setdefault("instance_profile_arn", stack.default_instance_profile_arn) + kwargs.setdefault("security_group_ids", layer.custom_security_group_ids) kwargs.setdefault("associate_public_ip", layer.auto_assign_public_ips) kwargs.setdefault("ebs_optimized", layer.use_ebs_optimized_instances) kwargs.update({"ec2_backend": self.ec2_backend}) @@ -579,7 +601,8 @@ class OpsWorksBackend(BaseBackend): if stack_id is not None: if stack_id not in self.stacks: raise ResourceNotFoundException( - "Unable to find stack with ID {0}".format(stack_id)) + "Unable to find stack with ID {0}".format(stack_id) + ) return [layer.to_dict() for layer in self.stacks[stack_id].layers] unknown_layers = set(layer_ids) - set(self.layers.keys()) @@ -595,7 +618,8 @@ class OpsWorksBackend(BaseBackend): if stack_id is not None: if stack_id not in self.stacks: raise ResourceNotFoundException( - "Unable to find stack with ID {0}".format(stack_id)) + "Unable to find stack with ID {0}".format(stack_id) + ) return [app.to_dict() for app in self.stacks[stack_id].apps] unknown_apps = set(app_ids) - set(self.apps.keys()) @@ -605,9 +629,11 @@ class OpsWorksBackend(BaseBackend): def describe_instances(self, instance_ids, layer_id, stack_id): if len(list(filter(None, (instance_ids, layer_id, stack_id)))) != 1: - raise ValidationException("Please provide either one or more " - "instance IDs or one stack ID or one " - "layer ID") + raise ValidationException( + "Please provide either one or more " + "instance IDs or one stack ID or one " + "layer ID" + ) if instance_ids: unknown_instances = set(instance_ids) - set(self.instances.keys()) if unknown_instances: @@ -617,23 +643,28 @@ class OpsWorksBackend(BaseBackend): if layer_id: if layer_id not in self.layers: raise ResourceNotFoundException( - "Unable to find layer with ID {0}".format(layer_id)) - instances = [i.to_dict() for i in self.instances.values() - if layer_id in i.layer_ids] + "Unable to find layer with ID {0}".format(layer_id) + ) + instances = [ + i.to_dict() for i in self.instances.values() if layer_id in i.layer_ids + ] return instances if stack_id: if stack_id not in self.stacks: raise ResourceNotFoundException( - "Unable to find stack with ID {0}".format(stack_id)) - instances = [i.to_dict() for i in self.instances.values() - if stack_id == i.stack_id] + "Unable to find stack with ID {0}".format(stack_id) + ) + instances = [ + i.to_dict() for i in self.instances.values() if stack_id == i.stack_id + ] return instances def start_instance(self, instance_id): if instance_id not in self.instances: raise ResourceNotFoundException( - "Unable to find instance with ID {0}".format(instance_id)) + "Unable to find instance with ID {0}".format(instance_id) + ) self.instances[instance_id].start() diff --git a/moto/opsworks/responses.py b/moto/opsworks/responses.py index c9f8fe125..870b75244 100644 --- a/moto/opsworks/responses.py +++ b/moto/opsworks/responses.py @@ -7,7 +7,6 @@ from .models import opsworks_backends class OpsWorksResponse(BaseResponse): - @property def parameters(self): return json.loads(self.body) @@ -23,23 +22,22 @@ class OpsWorksResponse(BaseResponse): vpcid=self.parameters.get("VpcId"), attributes=self.parameters.get("Attributes"), default_instance_profile_arn=self.parameters.get( - "DefaultInstanceProfileArn"), + "DefaultInstanceProfileArn" + ), default_os=self.parameters.get("DefaultOs"), hostname_theme=self.parameters.get("HostnameTheme"), - default_availability_zone=self.parameters.get( - "DefaultAvailabilityZone"), + default_availability_zone=self.parameters.get("DefaultAvailabilityZone"), default_subnet_id=self.parameters.get("DefaultInstanceProfileArn"), custom_json=self.parameters.get("CustomJson"), configuration_manager=self.parameters.get("ConfigurationManager"), chef_configuration=self.parameters.get("ChefConfiguration"), use_custom_cookbooks=self.parameters.get("UseCustomCookbooks"), use_opsworks_security_groups=self.parameters.get( - "UseOpsworksSecurityGroups"), - custom_cookbooks_source=self.parameters.get( - "CustomCookbooksSource"), + "UseOpsworksSecurityGroups" + ), + custom_cookbooks_source=self.parameters.get("CustomCookbooksSource"), default_ssh_keyname=self.parameters.get("DefaultSshKeyName"), - default_root_device_type=self.parameters.get( - "DefaultRootDeviceType"), + default_root_device_type=self.parameters.get("DefaultRootDeviceType"), service_role_arn=self.parameters.get("ServiceRoleArn"), agent_version=self.parameters.get("AgentVersion"), ) @@ -48,47 +46,43 @@ class OpsWorksResponse(BaseResponse): def create_layer(self): kwargs = dict( - stack_id=self.parameters.get('StackId'), - type=self.parameters.get('Type'), - name=self.parameters.get('Name'), - shortname=self.parameters.get('Shortname'), - attributes=self.parameters.get('Attributes'), - custom_instance_profile_arn=self.parameters.get( - "CustomInstanceProfileArn"), + stack_id=self.parameters.get("StackId"), + type=self.parameters.get("Type"), + name=self.parameters.get("Name"), + shortname=self.parameters.get("Shortname"), + attributes=self.parameters.get("Attributes"), + custom_instance_profile_arn=self.parameters.get("CustomInstanceProfileArn"), custom_json=self.parameters.get("CustomJson"), - custom_security_group_ids=self.parameters.get( - 'CustomSecurityGroupIds'), - packages=self.parameters.get('Packages'), + custom_security_group_ids=self.parameters.get("CustomSecurityGroupIds"), + packages=self.parameters.get("Packages"), volume_configurations=self.parameters.get("VolumeConfigurations"), enable_autohealing=self.parameters.get("EnableAutoHealing"), - auto_assign_elastic_ips=self.parameters.get( - "AutoAssignElasticIps"), + auto_assign_elastic_ips=self.parameters.get("AutoAssignElasticIps"), auto_assign_public_ips=self.parameters.get("AutoAssignPublicIps"), custom_recipes=self.parameters.get("CustomRecipes"), - install_updates_on_boot=self.parameters.get( - "InstallUpdatesOnBoot"), - use_ebs_optimized_instances=self.parameters.get( - "UseEbsOptimizedInstances"), + install_updates_on_boot=self.parameters.get("InstallUpdatesOnBoot"), + use_ebs_optimized_instances=self.parameters.get("UseEbsOptimizedInstances"), lifecycle_event_configuration=self.parameters.get( - "LifecycleEventConfiguration") + "LifecycleEventConfiguration" + ), ) layer = self.opsworks_backend.create_layer(**kwargs) return json.dumps({"LayerId": layer.id}, indent=1) def create_app(self): kwargs = dict( - stack_id=self.parameters.get('StackId'), - name=self.parameters.get('Name'), - type=self.parameters.get('Type'), - shortname=self.parameters.get('Shortname'), - description=self.parameters.get('Description'), - datasources=self.parameters.get('DataSources'), - app_source=self.parameters.get('AppSource'), - domains=self.parameters.get('Domains'), - enable_ssl=self.parameters.get('EnableSsl'), - ssl_configuration=self.parameters.get('SslConfiguration'), - attributes=self.parameters.get('Attributes'), - environment=self.parameters.get('Environment') + stack_id=self.parameters.get("StackId"), + name=self.parameters.get("Name"), + type=self.parameters.get("Type"), + shortname=self.parameters.get("Shortname"), + description=self.parameters.get("Description"), + datasources=self.parameters.get("DataSources"), + app_source=self.parameters.get("AppSource"), + domains=self.parameters.get("Domains"), + enable_ssl=self.parameters.get("EnableSsl"), + ssl_configuration=self.parameters.get("SslConfiguration"), + attributes=self.parameters.get("Attributes"), + environment=self.parameters.get("Environment"), ) app = self.opsworks_backend.create_app(**kwargs) return json.dumps({"AppId": app.id}, indent=1) @@ -109,8 +103,7 @@ class OpsWorksResponse(BaseResponse): architecture=self.parameters.get("Architecture"), root_device_type=self.parameters.get("RootDeviceType"), block_device_mappings=self.parameters.get("BlockDeviceMappings"), - install_updates_on_boot=self.parameters.get( - "InstallUpdatesOnBoot"), + install_updates_on_boot=self.parameters.get("InstallUpdatesOnBoot"), ebs_optimized=self.parameters.get("EbsOptimized"), agent_version=self.parameters.get("AgentVersion"), ) @@ -139,7 +132,8 @@ class OpsWorksResponse(BaseResponse): layer_id = self.parameters.get("LayerId") stack_id = self.parameters.get("StackId") instances = self.opsworks_backend.describe_instances( - instance_ids, layer_id, stack_id) + instance_ids, layer_id, stack_id + ) return json.dumps({"Instances": instances}, indent=1) def start_instance(self): diff --git a/moto/opsworks/urls.py b/moto/opsworks/urls.py index 3d72bb0dd..1e5246e59 100644 --- a/moto/opsworks/urls.py +++ b/moto/opsworks/urls.py @@ -3,10 +3,6 @@ from .responses import OpsWorksResponse # AWS OpsWorks has a single endpoint: opsworks.us-east-1.amazonaws.com # and only supports HTTPS requests. -url_bases = [ - "https?://opsworks.us-east-1.amazonaws.com" -] +url_bases = ["https?://opsworks.us-east-1.amazonaws.com"] -url_paths = { - '{0}/$': OpsWorksResponse.dispatch, -} +url_paths = {"{0}/$": OpsWorksResponse.dispatch} diff --git a/moto/organizations/models.py b/moto/organizations/models.py index 561c6c3a8..37f8bdeb9 100644 --- a/moto/organizations/models.py +++ b/moto/organizations/models.py @@ -11,17 +11,15 @@ from moto.organizations import utils class FakeOrganization(BaseModel): - def __init__(self, feature_set): self.id = utils.make_random_org_id() self.root_id = utils.make_random_root_id() self.feature_set = feature_set self.master_account_id = utils.MASTER_ACCOUNT_ID self.master_account_email = utils.MASTER_ACCOUNT_EMAIL - self.available_policy_types = [{ - 'Type': 'SERVICE_CONTROL_POLICY', - 'Status': 'ENABLED' - }] + self.available_policy_types = [ + {"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"} + ] @property def arn(self): @@ -33,129 +31,114 @@ class FakeOrganization(BaseModel): def describe(self): return { - 'Organization': { - 'Id': self.id, - 'Arn': self.arn, - 'FeatureSet': self.feature_set, - 'MasterAccountArn': self.master_account_arn, - 'MasterAccountId': self.master_account_id, - 'MasterAccountEmail': self.master_account_email, - 'AvailablePolicyTypes': self.available_policy_types, + "Organization": { + "Id": self.id, + "Arn": self.arn, + "FeatureSet": self.feature_set, + "MasterAccountArn": self.master_account_arn, + "MasterAccountId": self.master_account_id, + "MasterAccountEmail": self.master_account_email, + "AvailablePolicyTypes": self.available_policy_types, } } class FakeAccount(BaseModel): - def __init__(self, organization, **kwargs): - self.type = 'ACCOUNT' + self.type = "ACCOUNT" self.organization_id = organization.id self.master_account_id = organization.master_account_id self.create_account_status_id = utils.make_random_create_account_status_id() self.id = utils.make_random_account_id() - self.name = kwargs['AccountName'] - self.email = kwargs['Email'] + self.name = kwargs["AccountName"] + self.email = kwargs["Email"] self.create_time = datetime.datetime.utcnow() - self.status = 'ACTIVE' - self.joined_method = 'CREATED' + self.status = "ACTIVE" + self.joined_method = "CREATED" self.parent_id = organization.root_id self.attached_policies = [] @property def arn(self): return utils.ACCOUNT_ARN_FORMAT.format( - self.master_account_id, - self.organization_id, - self.id + self.master_account_id, self.organization_id, self.id ) @property def create_account_status(self): return { - 'CreateAccountStatus': { - 'Id': self.create_account_status_id, - 'AccountName': self.name, - 'State': 'SUCCEEDED', - 'RequestedTimestamp': unix_time(self.create_time), - 'CompletedTimestamp': unix_time(self.create_time), - 'AccountId': self.id, + "CreateAccountStatus": { + "Id": self.create_account_status_id, + "AccountName": self.name, + "State": "SUCCEEDED", + "RequestedTimestamp": unix_time(self.create_time), + "CompletedTimestamp": unix_time(self.create_time), + "AccountId": self.id, } } def describe(self): return { - 'Account': { - 'Id': self.id, - 'Arn': self.arn, - 'Email': self.email, - 'Name': self.name, - 'Status': self.status, - 'JoinedMethod': self.joined_method, - 'JoinedTimestamp': unix_time(self.create_time), + "Account": { + "Id": self.id, + "Arn": self.arn, + "Email": self.email, + "Name": self.name, + "Status": self.status, + "JoinedMethod": self.joined_method, + "JoinedTimestamp": unix_time(self.create_time), } } class FakeOrganizationalUnit(BaseModel): - def __init__(self, organization, **kwargs): - self.type = 'ORGANIZATIONAL_UNIT' + self.type = "ORGANIZATIONAL_UNIT" self.organization_id = organization.id self.master_account_id = organization.master_account_id self.id = utils.make_random_ou_id(organization.root_id) - self.name = kwargs.get('Name') - self.parent_id = kwargs.get('ParentId') + self.name = kwargs.get("Name") + self.parent_id = kwargs.get("ParentId") self._arn_format = utils.OU_ARN_FORMAT self.attached_policies = [] @property def arn(self): return self._arn_format.format( - self.master_account_id, - self.organization_id, - self.id + self.master_account_id, self.organization_id, self.id ) def describe(self): return { - 'OrganizationalUnit': { - 'Id': self.id, - 'Arn': self.arn, - 'Name': self.name, - } + "OrganizationalUnit": {"Id": self.id, "Arn": self.arn, "Name": self.name} } class FakeRoot(FakeOrganizationalUnit): - def __init__(self, organization, **kwargs): super(FakeRoot, self).__init__(organization, **kwargs) - self.type = 'ROOT' + self.type = "ROOT" self.id = organization.root_id - self.name = 'Root' - self.policy_types = [{ - 'Type': 'SERVICE_CONTROL_POLICY', - 'Status': 'ENABLED' - }] + self.name = "Root" + self.policy_types = [{"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"}] self._arn_format = utils.ROOT_ARN_FORMAT self.attached_policies = [] def describe(self): return { - 'Id': self.id, - 'Arn': self.arn, - 'Name': self.name, - 'PolicyTypes': self.policy_types + "Id": self.id, + "Arn": self.arn, + "Name": self.name, + "PolicyTypes": self.policy_types, } class FakeServiceControlPolicy(BaseModel): - def __init__(self, organization, **kwargs): - self.content = kwargs.get('Content') - self.description = kwargs.get('Description') - self.name = kwargs.get('Name') - self.type = kwargs.get('Type') + self.content = kwargs.get("Content") + self.description = kwargs.get("Description") + self.name = kwargs.get("Name") + self.type = kwargs.get("Type") self.id = utils.make_random_service_control_policy_id() self.aws_managed = False self.organization_id = organization.id @@ -166,29 +149,26 @@ class FakeServiceControlPolicy(BaseModel): @property def arn(self): return self._arn_format.format( - self.master_account_id, - self.organization_id, - self.id + self.master_account_id, self.organization_id, self.id ) def describe(self): return { - 'Policy': { - 'PolicySummary': { - 'Id': self.id, - 'Arn': self.arn, - 'Name': self.name, - 'Description': self.description, - 'Type': self.type, - 'AwsManaged': self.aws_managed, + "Policy": { + "PolicySummary": { + "Id": self.id, + "Arn": self.arn, + "Name": self.name, + "Description": self.description, + "Type": self.type, + "AwsManaged": self.aws_managed, }, - 'Content': self.content + "Content": self.content, } } class OrganizationsBackend(BaseBackend): - def __init__(self): self.org = None self.accounts = [] @@ -196,33 +176,25 @@ class OrganizationsBackend(BaseBackend): self.policies = [] def create_organization(self, **kwargs): - self.org = FakeOrganization(kwargs['FeatureSet']) + self.org = FakeOrganization(kwargs["FeatureSet"]) root_ou = FakeRoot(self.org) self.ou.append(root_ou) master_account = FakeAccount( - self.org, - AccountName='master', - Email=self.org.master_account_email, + self.org, AccountName="master", Email=self.org.master_account_email ) master_account.id = self.org.master_account_id self.accounts.append(master_account) default_policy = FakeServiceControlPolicy( self.org, - Name='FullAWSAccess', - Description='Allows access to every operation', - Type='SERVICE_CONTROL_POLICY', + Name="FullAWSAccess", + Description="Allows access to every operation", + Type="SERVICE_CONTROL_POLICY", Content=json.dumps( { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "*", "Resource": "*"}], } - ) + ), ) default_policy.id = utils.DEFAULT_POLICY_ID default_policy.aws_managed = True @@ -234,15 +206,13 @@ class OrganizationsBackend(BaseBackend): def describe_organization(self): if not self.org: raise RESTError( - 'AWSOrganizationsNotInUseException', - "Your account is not a member of an organization." + "AWSOrganizationsNotInUseException", + "Your account is not a member of an organization.", ) return self.org.describe() def list_roots(self): - return dict( - Roots=[ou.describe() for ou in self.ou if isinstance(ou, FakeRoot)] - ) + return dict(Roots=[ou.describe() for ou in self.ou if isinstance(ou, FakeRoot)]) def create_organizational_unit(self, **kwargs): new_ou = FakeOrganizationalUnit(self.org, **kwargs) @@ -254,8 +224,8 @@ class OrganizationsBackend(BaseBackend): ou = next((ou for ou in self.ou if ou.id == ou_id), None) if ou is None: raise RESTError( - 'OrganizationalUnitNotFoundException', - "You specified an organizational unit that doesn't exist." + "OrganizationalUnitNotFoundException", + "You specified an organizational unit that doesn't exist.", ) return ou @@ -264,24 +234,19 @@ class OrganizationsBackend(BaseBackend): self.get_organizational_unit_by_id(parent_id) except RESTError: raise RESTError( - 'ParentNotFoundException', - "You specified parent that doesn't exist." + "ParentNotFoundException", "You specified parent that doesn't exist." ) return parent_id def describe_organizational_unit(self, **kwargs): - ou = self.get_organizational_unit_by_id(kwargs['OrganizationalUnitId']) + ou = self.get_organizational_unit_by_id(kwargs["OrganizationalUnitId"]) return ou.describe() def list_organizational_units_for_parent(self, **kwargs): - parent_id = self.validate_parent_id(kwargs['ParentId']) + parent_id = self.validate_parent_id(kwargs["ParentId"]) return dict( OrganizationalUnits=[ - { - 'Id': ou.id, - 'Arn': ou.arn, - 'Name': ou.name, - } + {"Id": ou.id, "Arn": ou.arn, "Name": ou.name} for ou in self.ou if ou.parent_id == parent_id ] @@ -294,76 +259,66 @@ class OrganizationsBackend(BaseBackend): return new_account.create_account_status def get_account_by_id(self, account_id): - account = next(( - account for account in self.accounts - if account.id == account_id - ), None) + account = next( + (account for account in self.accounts if account.id == account_id), None + ) if account is None: raise RESTError( - 'AccountNotFoundException', - "You specified an account that doesn't exist." + "AccountNotFoundException", + "You specified an account that doesn't exist.", ) return account def describe_account(self, **kwargs): - account = self.get_account_by_id(kwargs['AccountId']) + account = self.get_account_by_id(kwargs["AccountId"]) return account.describe() def list_accounts(self): return dict( - Accounts=[account.describe()['Account'] for account in self.accounts] + Accounts=[account.describe()["Account"] for account in self.accounts] ) def list_accounts_for_parent(self, **kwargs): - parent_id = self.validate_parent_id(kwargs['ParentId']) + parent_id = self.validate_parent_id(kwargs["ParentId"]) return dict( Accounts=[ - account.describe()['Account'] + account.describe()["Account"] for account in self.accounts if account.parent_id == parent_id ] ) def move_account(self, **kwargs): - new_parent_id = self.validate_parent_id(kwargs['DestinationParentId']) - self.validate_parent_id(kwargs['SourceParentId']) - account = self.get_account_by_id(kwargs['AccountId']) + new_parent_id = self.validate_parent_id(kwargs["DestinationParentId"]) + self.validate_parent_id(kwargs["SourceParentId"]) + account = self.get_account_by_id(kwargs["AccountId"]) index = self.accounts.index(account) self.accounts[index].parent_id = new_parent_id def list_parents(self, **kwargs): - if re.compile(r'[0-9]{12}').match(kwargs['ChildId']): - child_object = self.get_account_by_id(kwargs['ChildId']) + if re.compile(r"[0-9]{12}").match(kwargs["ChildId"]): + child_object = self.get_account_by_id(kwargs["ChildId"]) else: - child_object = self.get_organizational_unit_by_id(kwargs['ChildId']) + child_object = self.get_organizational_unit_by_id(kwargs["ChildId"]) return dict( Parents=[ - { - 'Id': ou.id, - 'Type': ou.type, - } + {"Id": ou.id, "Type": ou.type} for ou in self.ou if ou.id == child_object.parent_id ] ) def list_children(self, **kwargs): - parent_id = self.validate_parent_id(kwargs['ParentId']) - if kwargs['ChildType'] == 'ACCOUNT': + parent_id = self.validate_parent_id(kwargs["ParentId"]) + if kwargs["ChildType"] == "ACCOUNT": obj_list = self.accounts - elif kwargs['ChildType'] == 'ORGANIZATIONAL_UNIT': + elif kwargs["ChildType"] == "ORGANIZATIONAL_UNIT": obj_list = self.ou else: - raise RESTError( - 'InvalidInputException', - 'You specified an invalid value.' - ) + raise RESTError("InvalidInputException", "You specified an invalid value.") return dict( Children=[ - { - 'Id': obj.id, - 'Type': kwargs['ChildType'], - } + {"Id": obj.id, "Type": kwargs["ChildType"]} for obj in obj_list if obj.parent_id == parent_id ] @@ -375,99 +330,93 @@ class OrganizationsBackend(BaseBackend): return new_policy.describe() def describe_policy(self, **kwargs): - if re.compile(utils.SCP_ID_REGEX).match(kwargs['PolicyId']): - policy = next((p for p in self.policies if p.id == kwargs['PolicyId']), None) + if re.compile(utils.SCP_ID_REGEX).match(kwargs["PolicyId"]): + policy = next( + (p for p in self.policies if p.id == kwargs["PolicyId"]), None + ) if policy is None: raise RESTError( - 'PolicyNotFoundException', - "You specified a policy that doesn't exist." + "PolicyNotFoundException", + "You specified a policy that doesn't exist.", ) else: - raise RESTError( - 'InvalidInputException', - 'You specified an invalid value.' - ) + raise RESTError("InvalidInputException", "You specified an invalid value.") return policy.describe() def attach_policy(self, **kwargs): - policy = next((p for p in self.policies if p.id == kwargs['PolicyId']), None) - if (re.compile(utils.ROOT_ID_REGEX).match(kwargs['TargetId']) or re.compile(utils.OU_ID_REGEX).match(kwargs['TargetId'])): - ou = next((ou for ou in self.ou if ou.id == kwargs['TargetId']), None) + policy = next((p for p in self.policies if p.id == kwargs["PolicyId"]), None) + if re.compile(utils.ROOT_ID_REGEX).match(kwargs["TargetId"]) or re.compile( + utils.OU_ID_REGEX + ).match(kwargs["TargetId"]): + ou = next((ou for ou in self.ou if ou.id == kwargs["TargetId"]), None) if ou is not None: if ou not in ou.attached_policies: ou.attached_policies.append(policy) policy.attachments.append(ou) else: raise RESTError( - 'OrganizationalUnitNotFoundException', - "You specified an organizational unit that doesn't exist." + "OrganizationalUnitNotFoundException", + "You specified an organizational unit that doesn't exist.", ) - elif re.compile(utils.ACCOUNT_ID_REGEX).match(kwargs['TargetId']): - account = next((a for a in self.accounts if a.id == kwargs['TargetId']), None) + elif re.compile(utils.ACCOUNT_ID_REGEX).match(kwargs["TargetId"]): + account = next( + (a for a in self.accounts if a.id == kwargs["TargetId"]), None + ) if account is not None: if account not in account.attached_policies: account.attached_policies.append(policy) policy.attachments.append(account) else: raise RESTError( - 'AccountNotFoundException', - "You specified an account that doesn't exist." + "AccountNotFoundException", + "You specified an account that doesn't exist.", ) else: - raise RESTError( - 'InvalidInputException', - 'You specified an invalid value.' - ) + raise RESTError("InvalidInputException", "You specified an invalid value.") def list_policies(self, **kwargs): - return dict(Policies=[ - p.describe()['Policy']['PolicySummary'] for p in self.policies - ]) + return dict( + Policies=[p.describe()["Policy"]["PolicySummary"] for p in self.policies] + ) def list_policies_for_target(self, **kwargs): - if re.compile(utils.OU_ID_REGEX).match(kwargs['TargetId']): - obj = next((ou for ou in self.ou if ou.id == kwargs['TargetId']), None) + if re.compile(utils.OU_ID_REGEX).match(kwargs["TargetId"]): + obj = next((ou for ou in self.ou if ou.id == kwargs["TargetId"]), None) if obj is None: raise RESTError( - 'OrganizationalUnitNotFoundException', - "You specified an organizational unit that doesn't exist." + "OrganizationalUnitNotFoundException", + "You specified an organizational unit that doesn't exist.", ) - elif re.compile(utils.ACCOUNT_ID_REGEX).match(kwargs['TargetId']): - obj = next((a for a in self.accounts if a.id == kwargs['TargetId']), None) + elif re.compile(utils.ACCOUNT_ID_REGEX).match(kwargs["TargetId"]): + obj = next((a for a in self.accounts if a.id == kwargs["TargetId"]), None) if obj is None: raise RESTError( - 'AccountNotFoundException', - "You specified an account that doesn't exist." + "AccountNotFoundException", + "You specified an account that doesn't exist.", ) else: - raise RESTError( - 'InvalidInputException', - 'You specified an invalid value.' - ) - return dict(Policies=[ - p.describe()['Policy']['PolicySummary'] for p in obj.attached_policies - ]) + raise RESTError("InvalidInputException", "You specified an invalid value.") + return dict( + Policies=[ + p.describe()["Policy"]["PolicySummary"] for p in obj.attached_policies + ] + ) def list_targets_for_policy(self, **kwargs): - if re.compile(utils.SCP_ID_REGEX).match(kwargs['PolicyId']): - policy = next((p for p in self.policies if p.id == kwargs['PolicyId']), None) + if re.compile(utils.SCP_ID_REGEX).match(kwargs["PolicyId"]): + policy = next( + (p for p in self.policies if p.id == kwargs["PolicyId"]), None + ) if policy is None: raise RESTError( - 'PolicyNotFoundException', - "You specified a policy that doesn't exist." + "PolicyNotFoundException", + "You specified a policy that doesn't exist.", ) else: - raise RESTError( - 'InvalidInputException', - 'You specified an invalid value.' - ) + raise RESTError("InvalidInputException", "You specified an invalid value.") objects = [ - { - 'TargetId': obj.id, - 'Arn': obj.arn, - 'Name': obj.name, - 'Type': obj.type, - } for obj in policy.attachments + {"TargetId": obj.id, "Arn": obj.arn, "Name": obj.name, "Type": obj.type} + for obj in policy.attachments ] return dict(Targets=objects) diff --git a/moto/organizations/responses.py b/moto/organizations/responses.py index 814f30bad..673bf5adb 100644 --- a/moto/organizations/responses.py +++ b/moto/organizations/responses.py @@ -6,7 +6,6 @@ from .models import organizations_backend class OrganizationsResponse(BaseResponse): - @property def organizations_backend(self): return organizations_backend @@ -27,14 +26,10 @@ class OrganizationsResponse(BaseResponse): ) def describe_organization(self): - return json.dumps( - self.organizations_backend.describe_organization() - ) + return json.dumps(self.organizations_backend.describe_organization()) def list_roots(self): - return json.dumps( - self.organizations_backend.list_roots() - ) + return json.dumps(self.organizations_backend.list_roots()) def create_organizational_unit(self): return json.dumps( @@ -43,12 +38,16 @@ class OrganizationsResponse(BaseResponse): def describe_organizational_unit(self): return json.dumps( - self.organizations_backend.describe_organizational_unit(**self.request_params) + self.organizations_backend.describe_organizational_unit( + **self.request_params + ) ) def list_organizational_units_for_parent(self): return json.dumps( - self.organizations_backend.list_organizational_units_for_parent(**self.request_params) + self.organizations_backend.list_organizational_units_for_parent( + **self.request_params + ) ) def list_parents(self): @@ -67,9 +66,7 @@ class OrganizationsResponse(BaseResponse): ) def list_accounts(self): - return json.dumps( - self.organizations_backend.list_accounts() - ) + return json.dumps(self.organizations_backend.list_accounts()) def list_accounts_for_parent(self): return json.dumps( diff --git a/moto/organizations/urls.py b/moto/organizations/urls.py index 7911f5b53..d0909bbef 100644 --- a/moto/organizations/urls.py +++ b/moto/organizations/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import OrganizationsResponse -url_bases = [ - "https?://organizations.(.+).amazonaws.com", -] +url_bases = ["https?://organizations.(.+).amazonaws.com"] -url_paths = { - '{0}/$': OrganizationsResponse.dispatch, -} +url_paths = {"{0}/$": OrganizationsResponse.dispatch} diff --git a/moto/organizations/utils.py b/moto/organizations/utils.py index 5cbe59ada..dacd58502 100644 --- a/moto/organizations/utils.py +++ b/moto/organizations/utils.py @@ -3,15 +3,15 @@ from __future__ import unicode_literals import random import string -MASTER_ACCOUNT_ID = '123456789012' -MASTER_ACCOUNT_EMAIL = 'master@example.com' -DEFAULT_POLICY_ID = 'p-FullAWSAccess' -ORGANIZATION_ARN_FORMAT = 'arn:aws:organizations::{0}:organization/{1}' -MASTER_ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{0}' -ACCOUNT_ARN_FORMAT = 'arn:aws:organizations::{0}:account/{1}/{2}' -ROOT_ARN_FORMAT = 'arn:aws:organizations::{0}:root/{1}/{2}' -OU_ARN_FORMAT = 'arn:aws:organizations::{0}:ou/{1}/{2}' -SCP_ARN_FORMAT = 'arn:aws:organizations::{0}:policy/{1}/service_control_policy/{2}' +MASTER_ACCOUNT_ID = "123456789012" +MASTER_ACCOUNT_EMAIL = "master@example.com" +DEFAULT_POLICY_ID = "p-FullAWSAccess" +ORGANIZATION_ARN_FORMAT = "arn:aws:organizations::{0}:organization/{1}" +MASTER_ACCOUNT_ARN_FORMAT = "arn:aws:organizations::{0}:account/{1}/{0}" +ACCOUNT_ARN_FORMAT = "arn:aws:organizations::{0}:account/{1}/{2}" +ROOT_ARN_FORMAT = "arn:aws:organizations::{0}:root/{1}/{2}" +OU_ARN_FORMAT = "arn:aws:organizations::{0}:ou/{1}/{2}" +SCP_ARN_FORMAT = "arn:aws:organizations::{0}:policy/{1}/service_control_policy/{2}" CHARSET = string.ascii_lowercase + string.digits ORG_ID_SIZE = 10 @@ -22,26 +22,26 @@ CREATE_ACCOUNT_STATUS_ID_SIZE = 8 SCP_ID_SIZE = 8 EMAIL_REGEX = "^.+@[a-zA-Z0-9-.]+.[a-zA-Z]{2,3}|[0-9]{1,3}$" -ORG_ID_REGEX = r'o-[a-z0-9]{%s}' % ORG_ID_SIZE -ROOT_ID_REGEX = r'r-[a-z0-9]{%s}' % ROOT_ID_SIZE -OU_ID_REGEX = r'ou-[a-z0-9]{%s}-[a-z0-9]{%s}' % (ROOT_ID_SIZE, OU_ID_SUFFIX_SIZE) -ACCOUNT_ID_REGEX = r'[0-9]{%s}' % ACCOUNT_ID_SIZE -CREATE_ACCOUNT_STATUS_ID_REGEX = r'car-[a-z0-9]{%s}' % CREATE_ACCOUNT_STATUS_ID_SIZE -SCP_ID_REGEX = r'%s|p-[a-z0-9]{%s}' % (DEFAULT_POLICY_ID, SCP_ID_SIZE) +ORG_ID_REGEX = r"o-[a-z0-9]{%s}" % ORG_ID_SIZE +ROOT_ID_REGEX = r"r-[a-z0-9]{%s}" % ROOT_ID_SIZE +OU_ID_REGEX = r"ou-[a-z0-9]{%s}-[a-z0-9]{%s}" % (ROOT_ID_SIZE, OU_ID_SUFFIX_SIZE) +ACCOUNT_ID_REGEX = r"[0-9]{%s}" % ACCOUNT_ID_SIZE +CREATE_ACCOUNT_STATUS_ID_REGEX = r"car-[a-z0-9]{%s}" % CREATE_ACCOUNT_STATUS_ID_SIZE +SCP_ID_REGEX = r"%s|p-[a-z0-9]{%s}" % (DEFAULT_POLICY_ID, SCP_ID_SIZE) def make_random_org_id(): # The regex pattern for an organization ID string requires "o-" # followed by from 10 to 32 lower-case letters or digits. # e.g. 'o-vipjnq5z86' - return 'o-' + ''.join(random.choice(CHARSET) for x in range(ORG_ID_SIZE)) + return "o-" + "".join(random.choice(CHARSET) for x in range(ORG_ID_SIZE)) def make_random_root_id(): # The regex pattern for a root ID string requires "r-" followed by # from 4 to 32 lower-case letters or digits. # e.g. 'r-3zwx' - return 'r-' + ''.join(random.choice(CHARSET) for x in range(ROOT_ID_SIZE)) + return "r-" + "".join(random.choice(CHARSET) for x in range(ROOT_ID_SIZE)) def make_random_ou_id(root_id): @@ -50,28 +50,32 @@ def make_random_ou_id(root_id): # that contains the OU) followed by a second "-" dash and from 8 to 32 # additional lower-case letters or digits. # e.g. ou-g8sd-5oe3bjaw - return '-'.join([ - 'ou', - root_id.partition('-')[2], - ''.join(random.choice(CHARSET) for x in range(OU_ID_SUFFIX_SIZE)), - ]) + return "-".join( + [ + "ou", + root_id.partition("-")[2], + "".join(random.choice(CHARSET) for x in range(OU_ID_SUFFIX_SIZE)), + ] + ) def make_random_account_id(): # The regex pattern for an account ID string requires exactly 12 digits. # e.g. '488633172133' - return ''.join([random.choice(string.digits) for n in range(ACCOUNT_ID_SIZE)]) + return "".join([random.choice(string.digits) for n in range(ACCOUNT_ID_SIZE)]) def make_random_create_account_status_id(): # The regex pattern for an create account request ID string requires # "car-" followed by from 8 to 32 lower-case letters or digits. # e.g. 'car-35gxzwrp' - return 'car-' + ''.join(random.choice(CHARSET) for x in range(CREATE_ACCOUNT_STATUS_ID_SIZE)) + return "car-" + "".join( + random.choice(CHARSET) for x in range(CREATE_ACCOUNT_STATUS_ID_SIZE) + ) def make_random_service_control_policy_id(): # The regex pattern for a policy ID string requires "p-" followed by # from 8 to 128 lower-case letters or digits. # e.g. 'p-k2av4a8a' - return 'p-' + ''.join(random.choice(CHARSET) for x in range(SCP_ID_SIZE)) + return "p-" + "".join(random.choice(CHARSET) for x in range(SCP_ID_SIZE)) diff --git a/moto/packages/httpretty/__init__.py b/moto/packages/httpretty/__init__.py index 679294a4b..c6a78526f 100644 --- a/moto/packages/httpretty/__init__.py +++ b/moto/packages/httpretty/__init__.py @@ -25,7 +25,7 @@ # OTHER DEALINGS IN THE SOFTWARE. from __future__ import unicode_literals -__version__ = version = '0.8.10' +__version__ = version = "0.8.10" from .core import httpretty, httprettified, EmptyRequestHeaders from .errors import HTTPrettyError, UnmockedError diff --git a/moto/packages/httpretty/compat.py b/moto/packages/httpretty/compat.py index b9e215b13..c452dec0e 100644 --- a/moto/packages/httpretty/compat.py +++ b/moto/packages/httpretty/compat.py @@ -34,33 +34,36 @@ if PY3: # pragma: no cover text_type = str byte_type = bytes import io + StringIO = io.BytesIO basestring = (str, bytes) class BaseClass(object): - def __repr__(self): return self.__str__() + + else: # pragma: no cover text_type = unicode byte_type = str import StringIO + StringIO = StringIO.StringIO basestring = basestring class BaseClass(object): - def __repr__(self): ret = self.__str__() if PY3: # pragma: no cover return ret else: - return ret.encode('utf-8') + return ret.encode("utf-8") try: # pragma: no cover from urllib.parse import urlsplit, urlunsplit, parse_qs, quote, quote_plus, unquote + unquote_utf8 = unquote except ImportError: # pragma: no cover from urlparse import urlsplit, urlunsplit, parse_qs, unquote @@ -68,7 +71,7 @@ except ImportError: # pragma: no cover def unquote_utf8(qs): if isinstance(qs, text_type): - qs = qs.encode('utf-8') + qs = qs.encode("utf-8") s = unquote(qs) if isinstance(s, byte_type): return s.decode("utf-8") @@ -88,16 +91,16 @@ if not PY3: # pragma: no cover __all__ = [ - 'PY3', - 'StringIO', - 'text_type', - 'byte_type', - 'BaseClass', - 'BaseHTTPRequestHandler', - 'quote', - 'quote_plus', - 'urlunsplit', - 'urlsplit', - 'parse_qs', - 'ClassTypes', + "PY3", + "StringIO", + "text_type", + "byte_type", + "BaseClass", + "BaseHTTPRequestHandler", + "quote", + "quote_plus", + "urlunsplit", + "urlsplit", + "parse_qs", + "ClassTypes", ] diff --git a/moto/packages/httpretty/core.py b/moto/packages/httpretty/core.py index f94723017..0c9635e79 100644 --- a/moto/packages/httpretty/core.py +++ b/moto/packages/httpretty/core.py @@ -52,19 +52,11 @@ from .compat import ( unquote, unquote_utf8, ClassTypes, - basestring -) -from .http import ( - STATUSES, - HttpBaseClass, - parse_requestline, - last_requestline, + basestring, ) +from .http import STATUSES, HttpBaseClass, parse_requestline, last_requestline -from .utils import ( - utf8, - decode_utf8, -) +from .utils import utf8, decode_utf8 from .errors import HTTPrettyError, UnmockedError @@ -91,12 +83,14 @@ if PY3: # pragma: no cover basestring = (bytes, str) try: # pragma: no cover import socks + old_socksocket = socks.socksocket except ImportError: socks = None try: # pragma: no cover import ssl + old_ssl_wrap_socket = ssl.wrap_socket if not PY3: old_sslwrap_simple = ssl.sslwrap_simple @@ -109,7 +103,11 @@ except ImportError: # pragma: no cover ssl = None try: # pragma: no cover - from requests.packages.urllib3.contrib.pyopenssl import inject_into_urllib3, extract_from_urllib3 + from requests.packages.urllib3.contrib.pyopenssl import ( + inject_into_urllib3, + extract_from_urllib3, + ) + pyopenssl_override = True except: pyopenssl_override = False @@ -154,7 +152,7 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, BaseClass): 'application/x-www-form-urlencoded' """ - def __init__(self, headers, body=''): + def __init__(self, headers, body=""): # first of all, lets make sure that if headers or body are # unicode strings, it must be converted into a utf-8 encoded # byte string @@ -163,7 +161,7 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, BaseClass): # Now let's concatenate the headers with the body, and create # `rfile` based on it - self.rfile = StringIO(b'\r\n\r\n'.join([self.raw_headers, self.body])) + self.rfile = StringIO(b"\r\n\r\n".join([self.raw_headers, self.body])) self.wfile = StringIO() # Creating `wfile` as an empty # StringIO, just to avoid any real # I/O calls @@ -186,7 +184,7 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, BaseClass): # `querystring` holds a dictionary with the parsed query string try: - self.path = self.path.encode('iso-8859-1') + self.path = self.path.encode("iso-8859-1") except UnicodeDecodeError: pass @@ -201,9 +199,7 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, BaseClass): def __str__(self): return ''.format( - self.headers.get('content-type', ''), - len(self.headers), - len(self.body), + self.headers.get("content-type", ""), len(self.headers), len(self.body) ) def parse_querystring(self, qs): @@ -219,13 +215,13 @@ class HTTPrettyRequest(BaseHTTPRequestHandler, BaseClass): """ Attempt to parse the post based on the content-type passed. Return the regular body if not """ PARSING_FUNCTIONS = { - 'application/json': json.loads, - 'text/json': json.loads, - 'application/x-www-form-urlencoded': self.parse_querystring, + "application/json": json.loads, + "text/json": json.loads, + "application/x-www-form-urlencoded": self.parse_querystring, } FALLBACK_FUNCTION = lambda x: x - content_type = self.headers.get('content-type', '') + content_type = self.headers.get("content-type", "") do_parse = PARSING_FUNCTIONS.get(content_type, FALLBACK_FUNCTION) try: @@ -240,19 +236,17 @@ class EmptyRequestHeaders(dict): class HTTPrettyRequestEmpty(object): - body = '' + body = "" headers = EmptyRequestHeaders() class FakeSockFile(StringIO): - def close(self): self.socket.close() StringIO.close(self) class FakeSSLSocket(object): - def __init__(self, sock, *args, **kw): self._httpretty_sock = sock @@ -261,14 +255,19 @@ class FakeSSLSocket(object): class fakesock(object): - class socket(object): _entry = None debuglevel = 0 _sent_data = [] - def __init__(self, family=socket.AF_INET, type=socket.SOCK_STREAM, - proto=0, fileno=None, _sock=None): + def __init__( + self, + family=socket.AF_INET, + type=socket.SOCK_STREAM, + proto=0, + fileno=None, + _sock=None, + ): """ Matches both the Python 2 API: def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None): @@ -300,23 +299,16 @@ class fakesock(object): now = datetime.now() shift = now + timedelta(days=30 * 12) return { - 'notAfter': shift.strftime('%b %d %H:%M:%S GMT'), - 'subjectAltName': ( - ('DNS', '*.%s' % self._host), - ('DNS', self._host), - ('DNS', '*'), + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", "*.%s" % self._host), + ("DNS", self._host), + ("DNS", "*"), ), - 'subject': ( - ( - ('organizationName', '*.%s' % self._host), - ), - ( - ('organizationalUnitName', - 'Domain Control Validated'), - ), - ( - ('commonName', '*.%s' % self._host), - ), + "subject": ( + (("organizationName", "*.%s" % self._host),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", "*.%s" % self._host),), ), } @@ -339,7 +331,9 @@ class fakesock(object): # See issue #206 self.is_http = False else: - self.is_http = self._port in POTENTIAL_HTTP_PORTS | POTENTIAL_HTTPS_PORTS + self.is_http = ( + self._port in POTENTIAL_HTTP_PORTS | POTENTIAL_HTTPS_PORTS + ) if not self.is_http: if self.truesock: @@ -353,7 +347,7 @@ class fakesock(object): self.truesock.close() self._closed = True - def makefile(self, mode='r', bufsize=-1): + def makefile(self, mode="r", bufsize=-1): """Returns this fake socket's own StringIO buffer. If there is an entry associated with the socket, the file @@ -408,9 +402,8 @@ class fakesock(object): self.fd = FakeSockFile() self.fd.socket = self try: - requestline, _ = data.split(b'\r\n', 1) - method, path, version = parse_requestline( - decode_utf8(requestline)) + requestline, _ = data.split(b"\r\n", 1) + method, path, version = parse_requestline(decode_utf8(requestline)) is_parsing_headers = True except ValueError: is_parsing_headers = False @@ -427,8 +420,12 @@ class fakesock(object): headers = utf8(last_requestline(self._sent_data)) meta = self._entry.request.headers body = utf8(self._sent_data[-1]) - if meta.get('transfer-encoding', '') == 'chunked': - if not body.isdigit() and body != b'\r\n' and body != b'0\r\n\r\n': + if meta.get("transfer-encoding", "") == "chunked": + if ( + not body.isdigit() + and body != b"\r\n" + and body != b"0\r\n\r\n" + ): self._entry.request.body += body else: self._entry.request.body += body @@ -439,14 +436,17 @@ class fakesock(object): # path might come with s = urlsplit(path) POTENTIAL_HTTP_PORTS.add(int(s.port or 80)) - headers, body = list(map(utf8, data.split(b'\r\n\r\n', 1))) + headers, body = list(map(utf8, data.split(b"\r\n\r\n", 1))) request = httpretty.historify_request(headers, body) - info = URIInfo(hostname=self._host, port=self._port, - path=s.path, - query=s.query, - last_request=request) + info = URIInfo( + hostname=self._host, + port=self._port, + path=s.path, + query=s.query, + last_request=request, + ) matcher, entries = httpretty.match_uriinfo(info) @@ -464,8 +464,10 @@ class fakesock(object): message = [ "HTTPretty intercepted and unexpected socket method call.", - ("Please open an issue at " - "'https://github.com/gabrielfalcao/HTTPretty/issues'"), + ( + "Please open an issue at " + "'https://github.com/gabrielfalcao/HTTPretty/issues'" + ), "And paste the following traceback:\n", "".join(decode_utf8(lines)), ] @@ -478,22 +480,22 @@ class fakesock(object): self.timeout = new_timeout def send(self, *args, **kwargs): - return self.debug('send', *args, **kwargs) + return self.debug("send", *args, **kwargs) def sendto(self, *args, **kwargs): - return self.debug('sendto', *args, **kwargs) + return self.debug("sendto", *args, **kwargs) def recvfrom_into(self, *args, **kwargs): - return self.debug('recvfrom_into', *args, **kwargs) + return self.debug("recvfrom_into", *args, **kwargs) def recv_into(self, *args, **kwargs): - return self.debug('recv_into', *args, **kwargs) + return self.debug("recv_into", *args, **kwargs) def recvfrom(self, *args, **kwargs): - return self.debug('recvfrom', *args, **kwargs) + return self.debug("recvfrom", *args, **kwargs) def recv(self, *args, **kwargs): - return self.debug('recv', *args, **kwargs) + return self.debug("recv", *args, **kwargs) def __getattr__(self, name): if not self.truesock: @@ -505,7 +507,9 @@ def fake_wrap_socket(s, *args, **kw): return s -def create_fake_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None): +def create_fake_connection( + address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, source_address=None +): s = fakesock.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) if timeout is not socket._GLOBAL_DEFAULT_TIMEOUT: s.settimeout(timeout) @@ -516,26 +520,29 @@ def create_fake_connection(address, timeout=socket._GLOBAL_DEFAULT_TIMEOUT, sour def fake_gethostbyname(host): - return '127.0.0.1' + return "127.0.0.1" def fake_gethostname(): - return 'localhost' + return "localhost" -def fake_getaddrinfo( - host, port, family=None, socktype=None, proto=None, flags=None): - return [(2, 1, 6, '', (host, port))] +def fake_getaddrinfo(host, port, family=None, socktype=None, proto=None, flags=None): + return [(2, 1, 6, "", (host, port))] class Entry(BaseClass): - - def __init__(self, method, uri, body, - adding_headers=None, - forcing_headers=None, - status=200, - streaming=False, - **headers): + def __init__( + self, + method, + uri, + body, + adding_headers=None, + forcing_headers=None, + status=200, + streaming=False, + **headers + ): self.method = method self.uri = uri @@ -554,7 +561,7 @@ class Entry(BaseClass): self.streaming = streaming if not streaming and not self.body_is_callable: - self.body_length = len(self.body or '') + self.body_length = len(self.body or "") else: self.body_length = 0 @@ -569,10 +576,9 @@ class Entry(BaseClass): self.validate() def validate(self): - content_length_keys = 'Content-Length', 'content-length' + content_length_keys = "Content-Length", "content-length" for key in content_length_keys: - got = self.adding_headers.get( - key, self.forcing_headers.get(key, None)) + got = self.adding_headers.get(key, self.forcing_headers.get(key, None)) if got is None: continue @@ -581,28 +587,25 @@ class Entry(BaseClass): igot = int(got) except ValueError: warnings.warn( - 'HTTPretty got to register the Content-Length header ' - 'with "%r" which is not a number' % got, + "HTTPretty got to register the Content-Length header " + 'with "%r" which is not a number' % got ) if igot > self.body_length: raise HTTPrettyError( - 'HTTPretty got inconsistent parameters. The header ' + "HTTPretty got inconsistent parameters. The header " 'Content-Length you registered expects size "%d" but ' - 'the body you registered for that has actually length ' - '"%d".' % ( - igot, self.body_length, - ) + "the body you registered for that has actually length " + '"%d".' % (igot, self.body_length) ) def __str__(self): - return r'' % ( - self.method, self.uri, self.status) + return r"" % (self.method, self.uri, self.status) def normalize_headers(self, headers): new = {} for k in headers: - new_k = '-'.join([s.lower() for s in k.split('-')]) + new_k = "-".join([s.lower() for s in k.split("-")]) new[new_k] = headers[k] return new @@ -611,10 +614,10 @@ class Entry(BaseClass): now = datetime.utcnow() headers = { - 'status': self.status, - 'date': now.strftime('%a, %d %b %Y %H:%M:%S GMT'), - 'server': 'Python/HTTPretty', - 'connection': 'close', + "status": self.status, + "date": now.strftime("%a, %d %b %Y %H:%M:%S GMT"), + "server": "Python/HTTPretty", + "connection": "close", } if self.forcing_headers: @@ -624,44 +627,38 @@ class Entry(BaseClass): headers.update(self.normalize_headers(self.adding_headers)) headers = self.normalize_headers(headers) - status = headers.get('status', self.status) + status = headers.get("status", self.status) if self.body_is_callable: status, headers, self.body = self.callable_body( - self.request, self.info.full_url(), headers) + self.request, self.info.full_url(), headers + ) headers = self.normalize_headers(headers) if self.request.method != "HEAD": - headers.update({ - 'content-length': len(self.body) - }) + headers.update({"content-length": len(self.body)}) - string_list = [ - 'HTTP/1.1 %d %s' % (status, STATUSES[status]), - ] + string_list = ["HTTP/1.1 %d %s" % (status, STATUSES[status])] - if 'date' in headers: - string_list.append('date: %s' % headers.pop('date')) + if "date" in headers: + string_list.append("date: %s" % headers.pop("date")) if not self.forcing_headers: - content_type = headers.pop('content-type', - 'text/plain; charset=utf-8') + content_type = headers.pop("content-type", "text/plain; charset=utf-8") - content_length = headers.pop('content-length', self.body_length) + content_length = headers.pop("content-length", self.body_length) - string_list.append('content-type: %s' % content_type) + string_list.append("content-type: %s" % content_type) if not self.streaming: - string_list.append('content-length: %s' % content_length) + string_list.append("content-length: %s" % content_length) - string_list.append('server: %s' % headers.pop('server')) + string_list.append("server: %s" % headers.pop("server")) for k, v in headers.items(): - string_list.append( - '{0}: {1}'.format(k, v), - ) + string_list.append("{0}: {1}".format(k, v)) for item in string_list: - fk.write(utf8(item) + b'\n') + fk.write(utf8(item) + b"\n") - fk.write(b'\r\n') + fk.write(b"\r\n") if self.streaming: self.body, body = itertools.tee(self.body) @@ -673,58 +670,53 @@ class Entry(BaseClass): fk.seek(0) -def url_fix(s, charset='utf-8'): +def url_fix(s, charset="utf-8"): scheme, netloc, path, querystring, fragment = urlsplit(s) - path = quote(path, b'/%') - querystring = quote_plus(querystring, b':&=') + path = quote(path, b"/%") + querystring = quote_plus(querystring, b":&=") return urlunsplit((scheme, netloc, path, querystring, fragment)) class URIInfo(BaseClass): + def __init__( + self, + username="", + password="", + hostname="", + port=80, + path="/", + query="", + fragment="", + scheme="", + last_request=None, + ): - def __init__(self, - username='', - password='', - hostname='', - port=80, - path='/', - query='', - fragment='', - scheme='', - last_request=None): - - self.username = username or '' - self.password = password or '' - self.hostname = hostname or '' + self.username = username or "" + self.password = password or "" + self.hostname = hostname or "" if port: port = int(port) - elif scheme == 'https': + elif scheme == "https": port = 443 self.port = port or 80 - self.path = path or '' - self.query = query or '' + self.path = path or "" + self.query = query or "" if scheme: self.scheme = scheme elif self.port in POTENTIAL_HTTPS_PORTS: - self.scheme = 'https' + self.scheme = "https" else: - self.scheme = 'http' - self.fragment = fragment or '' + self.scheme = "http" + self.fragment = fragment or "" self.last_request = last_request def __str__(self): - attrs = ( - 'username', - 'password', - 'hostname', - 'port', - 'path', - ) - fmt = ", ".join(['%s="%s"' % (k, getattr(self, k, '')) for k in attrs]) - return r'' % fmt + attrs = ("username", "password", "hostname", "port", "path") + fmt = ", ".join(['%s="%s"' % (k, getattr(self, k, "")) for k in attrs]) + return r"" % fmt def __hash__(self): return hash(text_type(self)) @@ -745,8 +737,7 @@ class URIInfo(BaseClass): def full_url(self, use_querystring=True): credentials = "" if self.password: - credentials = "{0}:{1}@".format( - self.username, self.password) + credentials = "{0}:{1}@".format(self.username, self.password) query = "" if use_querystring and self.query: @@ -757,7 +748,7 @@ class URIInfo(BaseClass): credentials=credentials, domain=self.get_full_domain(), path=decode_utf8(self.path), - query=query + query=query, ) return result @@ -772,19 +763,21 @@ class URIInfo(BaseClass): @classmethod def from_uri(cls, uri, entry): result = urlsplit(uri) - if result.scheme == 'https': + if result.scheme == "https": POTENTIAL_HTTPS_PORTS.add(int(result.port or 443)) else: POTENTIAL_HTTP_PORTS.add(int(result.port or 80)) - return cls(result.username, - result.password, - result.hostname, - result.port, - result.path, - result.query, - result.fragment, - result.scheme, - entry) + return cls( + result.username, + result.password, + result.hostname, + result.port, + result.path, + result.query, + result.fragment, + result.scheme, + entry, + ) class URIMatcher(object): @@ -793,10 +786,10 @@ class URIMatcher(object): def __init__(self, uri, entries, match_querystring=False): self._match_querystring = match_querystring - if type(uri).__name__ in ('SRE_Pattern', 'Pattern'): + if type(uri).__name__ in ("SRE_Pattern", "Pattern"): self.regex = uri result = urlsplit(uri.pattern) - if result.scheme == 'https': + if result.scheme == "https": POTENTIAL_HTTPS_PORTS.add(int(result.port or 443)) else: POTENTIAL_HTTP_PORTS.add(int(result.port or 80)) @@ -812,11 +805,12 @@ class URIMatcher(object): if self.info: return self.info == info else: - return self.regex.search(info.full_url( - use_querystring=self._match_querystring)) + return self.regex.search( + info.full_url(use_querystring=self._match_querystring) + ) def __str__(self): - wrap = 'URLMatcher({0})' + wrap = "URLMatcher({0})" if self.info: return wrap.format(text_type(self.info)) else: @@ -836,8 +830,7 @@ class URIMatcher(object): self.current_entries[method] = -1 if not self.entries or not entries_for_method: - raise ValueError('I have no entries for method %s: %s' - % (method, self)) + raise ValueError("I have no entries for method %s: %s" % (method, self)) entry = entries_for_method[self.current_entries[method]] if self.current_entries[method] != -1: @@ -861,6 +854,7 @@ class URIMatcher(object): class httpretty(HttpBaseClass): """The URI registration class""" + _entries = {} latest_requests = [] @@ -878,12 +872,13 @@ class httpretty(HttpBaseClass): @classmethod @contextlib.contextmanager - def record(cls, filename, indentation=4, encoding='utf-8'): + def record(cls, filename, indentation=4, encoding="utf-8"): try: import urllib3 except ImportError: raise RuntimeError( - 'HTTPretty requires urllib3 installed for recording actual requests.') + "HTTPretty requires urllib3 installed for recording actual requests." + ) http = urllib3.PoolManager() @@ -894,30 +889,31 @@ class httpretty(HttpBaseClass): cls.disable() response = http.request(request.method, uri) - calls.append({ - 'request': { - 'uri': uri, - 'method': request.method, - 'headers': dict(request.headers), - 'body': decode_utf8(request.body), - 'querystring': request.querystring - }, - 'response': { - 'status': response.status, - 'body': decode_utf8(response.data), - 'headers': dict(response.headers) + calls.append( + { + "request": { + "uri": uri, + "method": request.method, + "headers": dict(request.headers), + "body": decode_utf8(request.body), + "querystring": request.querystring, + }, + "response": { + "status": response.status, + "body": decode_utf8(response.data), + "headers": dict(response.headers), + }, } - }) + ) cls.enable() return response.status, response.headers, response.data for method in cls.METHODS: - cls.register_uri(method, re.compile( - r'.*', re.M), body=record_request) + cls.register_uri(method, re.compile(r".*", re.M), body=record_request) yield cls.disable() - with codecs.open(filename, 'w', encoding) as f: + with codecs.open(filename, "w", encoding) as f: f.write(json.dumps(calls, indent=indentation)) @classmethod @@ -927,10 +923,14 @@ class httpretty(HttpBaseClass): data = json.loads(open(origin).read()) for item in data: - uri = item['request']['uri'] - method = item['request']['method'] - cls.register_uri(method, uri, body=item['response'][ - 'body'], forcing_headers=item['response']['headers']) + uri = item["request"]["uri"] + method = item["request"]["method"] + cls.register_uri( + method, + uri, + body=item["response"]["body"], + forcing_headers=item["response"]["headers"], + ) yield cls.disable() @@ -944,7 +944,7 @@ class httpretty(HttpBaseClass): cls.last_request = HTTPrettyRequestEmpty() @classmethod - def historify_request(cls, headers, body='', append=True): + def historify_request(cls, headers, body="", append=True): request = HTTPrettyRequest(headers, body) cls.last_request = request if append or not cls.latest_requests: @@ -954,17 +954,23 @@ class httpretty(HttpBaseClass): return request @classmethod - def register_uri(cls, method, uri, body='HTTPretty :)', - adding_headers=None, - forcing_headers=None, - status=200, - responses=None, match_querystring=False, - **headers): + def register_uri( + cls, + method, + uri, + body="HTTPretty :)", + adding_headers=None, + forcing_headers=None, + status=200, + responses=None, + match_querystring=False, + **headers + ): uri_is_string = isinstance(uri, basestring) - if uri_is_string and re.search(r'^\w+://[^/]+[.]\w{2,}$', uri): - uri += '/' + if uri_is_string and re.search(r"^\w+://[^/]+[.]\w{2,}$", uri): + uri += "/" if isinstance(responses, list) and len(responses) > 0: for response in responses: @@ -972,17 +978,14 @@ class httpretty(HttpBaseClass): response.method = method entries_for_this_uri = responses else: - headers[str('body')] = body - headers[str('adding_headers')] = adding_headers - headers[str('forcing_headers')] = forcing_headers - headers[str('status')] = status + headers[str("body")] = body + headers[str("adding_headers")] = adding_headers + headers[str("forcing_headers")] = forcing_headers + headers[str("status")] = status - entries_for_this_uri = [ - cls.Response(method=method, uri=uri, **headers), - ] + entries_for_this_uri = [cls.Response(method=method, uri=uri, **headers)] - matcher = URIMatcher(uri, entries_for_this_uri, - match_querystring) + matcher = URIMatcher(uri, entries_for_this_uri, match_querystring) if matcher in cls._entries: matcher.entries.extend(cls._entries[matcher]) del cls._entries[matcher] @@ -990,17 +993,26 @@ class httpretty(HttpBaseClass): cls._entries[matcher] = entries_for_this_uri def __str__(self): - return '' % len(self._entries) + return "" % len(self._entries) @classmethod - def Response(cls, body, method=None, uri=None, adding_headers=None, forcing_headers=None, - status=200, streaming=False, **headers): + def Response( + cls, + body, + method=None, + uri=None, + adding_headers=None, + forcing_headers=None, + status=200, + streaming=False, + **headers + ): - headers[str('body')] = body - headers[str('adding_headers')] = adding_headers - headers[str('forcing_headers')] = forcing_headers - headers[str('status')] = int(status) - headers[str('streaming')] = streaming + headers[str("body")] = body + headers[str("adding_headers")] = adding_headers + headers[str("forcing_headers")] = forcing_headers + headers[str("status")] = int(status) + headers[str("streaming")] = streaming return Entry(method, uri, **headers) @classmethod @@ -1016,19 +1028,19 @@ class httpretty(HttpBaseClass): socket.gethostbyname = old_gethostbyname socket.getaddrinfo = old_getaddrinfo - socket.__dict__['socket'] = old_socket - socket.__dict__['_socketobject'] = old_socket + socket.__dict__["socket"] = old_socket + socket.__dict__["_socketobject"] = old_socket if not BAD_SOCKET_SHADOW: - socket.__dict__['SocketType'] = old_socket + socket.__dict__["SocketType"] = old_socket - socket.__dict__['create_connection'] = old_create_connection - socket.__dict__['gethostname'] = old_gethostname - socket.__dict__['gethostbyname'] = old_gethostbyname - socket.__dict__['getaddrinfo'] = old_getaddrinfo + socket.__dict__["create_connection"] = old_create_connection + socket.__dict__["gethostname"] = old_gethostname + socket.__dict__["gethostbyname"] = old_gethostbyname + socket.__dict__["getaddrinfo"] = old_getaddrinfo if socks: socks.socksocket = old_socksocket - socks.__dict__['socksocket'] = old_socksocket + socks.__dict__["socksocket"] = old_socksocket if ssl: ssl.wrap_socket = old_ssl_wrap_socket @@ -1037,12 +1049,12 @@ class httpretty(HttpBaseClass): ssl.SSLContext.wrap_socket = old_sslcontext_wrap_socket except AttributeError: pass - ssl.__dict__['wrap_socket'] = old_ssl_wrap_socket - ssl.__dict__['SSLSocket'] = old_sslsocket + ssl.__dict__["wrap_socket"] = old_ssl_wrap_socket + ssl.__dict__["SSLSocket"] = old_sslsocket if not PY3: ssl.sslwrap_simple = old_sslwrap_simple - ssl.__dict__['sslwrap_simple'] = old_sslwrap_simple + ssl.__dict__["sslwrap_simple"] = old_sslwrap_simple if pyopenssl_override: inject_into_urllib3() @@ -1065,25 +1077,26 @@ class httpretty(HttpBaseClass): socket.gethostbyname = fake_gethostbyname socket.getaddrinfo = fake_getaddrinfo - socket.__dict__['socket'] = fakesock.socket - socket.__dict__['_socketobject'] = fakesock.socket + socket.__dict__["socket"] = fakesock.socket + socket.__dict__["_socketobject"] = fakesock.socket if not BAD_SOCKET_SHADOW: - socket.__dict__['SocketType'] = fakesock.socket + socket.__dict__["SocketType"] = fakesock.socket - socket.__dict__['create_connection'] = create_fake_connection - socket.__dict__['gethostname'] = fake_gethostname - socket.__dict__['gethostbyname'] = fake_gethostbyname - socket.__dict__['getaddrinfo'] = fake_getaddrinfo + socket.__dict__["create_connection"] = create_fake_connection + socket.__dict__["gethostname"] = fake_gethostname + socket.__dict__["gethostbyname"] = fake_gethostbyname + socket.__dict__["getaddrinfo"] = fake_getaddrinfo if socks: socks.socksocket = fakesock.socket - socks.__dict__['socksocket'] = fakesock.socket + socks.__dict__["socksocket"] = fakesock.socket if ssl: ssl.wrap_socket = fake_wrap_socket ssl.SSLSocket = FakeSSLSocket try: + def fake_sslcontext_wrap_socket(cls, *args, **kwargs): return fake_wrap_socket(*args, **kwargs) @@ -1091,12 +1104,12 @@ class httpretty(HttpBaseClass): except AttributeError: pass - ssl.__dict__['wrap_socket'] = fake_wrap_socket - ssl.__dict__['SSLSocket'] = FakeSSLSocket + ssl.__dict__["wrap_socket"] = fake_wrap_socket + ssl.__dict__["SSLSocket"] = FakeSSLSocket if not PY3: ssl.sslwrap_simple = fake_wrap_socket - ssl.__dict__['sslwrap_simple'] = fake_wrap_socket + ssl.__dict__["sslwrap_simple"] = fake_wrap_socket if pyopenssl_override: extract_from_urllib3() @@ -1104,9 +1117,10 @@ class httpretty(HttpBaseClass): def httprettified(test): "A decorator tests that use HTTPretty" + def decorate_class(klass): for attr in dir(klass): - if not attr.startswith('test_'): + if not attr.startswith("test_"): continue attr_value = getattr(klass, attr) @@ -1125,8 +1139,9 @@ def httprettified(test): return test(*args, **kw) finally: httpretty.disable() + return wrapper if isinstance(test, ClassTypes): return decorate_class(test) - return decorate_callable(test) \ No newline at end of file + return decorate_callable(test) diff --git a/moto/packages/httpretty/errors.py b/moto/packages/httpretty/errors.py index e2dcad357..8221e5f66 100644 --- a/moto/packages/httpretty/errors.py +++ b/moto/packages/httpretty/errors.py @@ -32,9 +32,8 @@ class HTTPrettyError(Exception): class UnmockedError(HTTPrettyError): - def __init__(self): super(UnmockedError, self).__init__( - 'No mocking was registered, and real connections are ' - 'not allowed (httpretty.allow_net_connect = False).' + "No mocking was registered, and real connections are " + "not allowed (httpretty.allow_net_connect = False)." ) diff --git a/moto/packages/httpretty/http.py b/moto/packages/httpretty/http.py index ee1625905..20c00707e 100644 --- a/moto/packages/httpretty/http.py +++ b/moto/packages/httpretty/http.py @@ -109,14 +109,14 @@ STATUSES = { class HttpBaseClass(BaseClass): - GET = 'GET' - PUT = 'PUT' - POST = 'POST' - DELETE = 'DELETE' - HEAD = 'HEAD' - PATCH = 'PATCH' - OPTIONS = 'OPTIONS' - CONNECT = 'CONNECT' + GET = "GET" + PUT = "PUT" + POST = "POST" + DELETE = "DELETE" + HEAD = "HEAD" + PATCH = "PATCH" + OPTIONS = "OPTIONS" + CONNECT = "CONNECT" METHODS = (GET, PUT, POST, DELETE, HEAD, PATCH, OPTIONS, CONNECT) @@ -133,12 +133,12 @@ def parse_requestline(s): ... ValueError: Not a Request-Line """ - methods = '|'.join(HttpBaseClass.METHODS) - m = re.match(r'(' + methods + ')\s+(.*)\s+HTTP/(1.[0|1])', s, re.I) + methods = "|".join(HttpBaseClass.METHODS) + m = re.match(r"(" + methods + ")\s+(.*)\s+HTTP/(1.[0|1])", s, re.I) if m: return m.group(1).upper(), m.group(2), m.group(3) else: - raise ValueError('Not a Request-Line') + raise ValueError("Not a Request-Line") def last_requestline(sent_data): diff --git a/moto/packages/httpretty/utils.py b/moto/packages/httpretty/utils.py index caa8fa13b..2bf5d0829 100644 --- a/moto/packages/httpretty/utils.py +++ b/moto/packages/httpretty/utils.py @@ -25,14 +25,12 @@ # OTHER DEALINGS IN THE SOFTWARE. from __future__ import unicode_literals -from .compat import ( - byte_type, text_type -) +from .compat import byte_type, text_type def utf8(s): if isinstance(s, text_type): - s = s.encode('utf-8') + s = s.encode("utf-8") elif s is None: return byte_type() diff --git a/moto/polly/__init__.py b/moto/polly/__init__.py index 9c2281126..6db0215de 100644 --- a/moto/polly/__init__.py +++ b/moto/polly/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import polly_backends from ..core.models import base_decorator -polly_backend = polly_backends['us-east-1'] +polly_backend = polly_backends["us-east-1"] mock_polly = base_decorator(polly_backends) diff --git a/moto/polly/models.py b/moto/polly/models.py index e7b7117dc..3be5b7a0b 100644 --- a/moto/polly/models.py +++ b/moto/polly/models.py @@ -32,33 +32,36 @@ class Lexicon(BaseModel): try: root = ET.fromstring(self.content) self.size = len(self.content) - self.last_modified = int((datetime.datetime.now() - - datetime.datetime(1970, 1, 1)).total_seconds()) - self.lexemes_count = len(root.findall('.')) + self.last_modified = int( + ( + datetime.datetime.now() - datetime.datetime(1970, 1, 1) + ).total_seconds() + ) + self.lexemes_count = len(root.findall(".")) for key, value in root.attrib.items(): - if key.endswith('alphabet'): + if key.endswith("alphabet"): self.alphabet = value - elif key.endswith('lang'): + elif key.endswith("lang"): self.language_code = value except Exception as err: - raise ValueError('Failure parsing XML: {0}'.format(err)) + raise ValueError("Failure parsing XML: {0}".format(err)) def to_dict(self): return { - 'Attributes': { - 'Alphabet': self.alphabet, - 'LanguageCode': self.language_code, - 'LastModified': self.last_modified, - 'LexemesCount': self.lexemes_count, - 'LexiconArn': self.arn, - 'Size': self.size + "Attributes": { + "Alphabet": self.alphabet, + "LanguageCode": self.language_code, + "LastModified": self.last_modified, + "LexemesCount": self.lexemes_count, + "LexiconArn": self.arn, + "Size": self.size, } } def __repr__(self): - return ''.format(self.name) + return "".format(self.name) class PollyBackend(BaseBackend): @@ -77,7 +80,7 @@ class PollyBackend(BaseBackend): if language_code is None: return VOICE_DATA - return [item for item in VOICE_DATA if item['LanguageCode'] == language_code] + return [item for item in VOICE_DATA if item["LanguageCode"] == language_code] def delete_lexicon(self, name): # implement here @@ -93,7 +96,7 @@ class PollyBackend(BaseBackend): for name, lexicon in self._lexicons.items(): lexicon_dict = lexicon.to_dict() - lexicon_dict['Name'] = name + lexicon_dict["Name"] = name result.append(lexicon_dict) @@ -111,4 +114,6 @@ class PollyBackend(BaseBackend): available_regions = boto3.session.Session().get_available_regions("polly") -polly_backends = {region: PollyBackend(region_name=region) for region in available_regions} +polly_backends = { + region: PollyBackend(region_name=region) for region in available_regions +} diff --git a/moto/polly/resources.py b/moto/polly/resources.py index f4ad69a98..560e62b7b 100644 --- a/moto/polly/resources.py +++ b/moto/polly/resources.py @@ -1,63 +1,418 @@ # -*- coding: utf-8 -*- VOICE_DATA = [ - {'Id': 'Joanna', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Joanna'}, - {'Id': 'Mizuki', 'LanguageCode': 'ja-JP', 'LanguageName': 'Japanese', 'Gender': 'Female', 'Name': 'Mizuki'}, - {'Id': 'Filiz', 'LanguageCode': 'tr-TR', 'LanguageName': 'Turkish', 'Gender': 'Female', 'Name': 'Filiz'}, - {'Id': 'Astrid', 'LanguageCode': 'sv-SE', 'LanguageName': 'Swedish', 'Gender': 'Female', 'Name': 'Astrid'}, - {'Id': 'Tatyana', 'LanguageCode': 'ru-RU', 'LanguageName': 'Russian', 'Gender': 'Female', 'Name': 'Tatyana'}, - {'Id': 'Maxim', 'LanguageCode': 'ru-RU', 'LanguageName': 'Russian', 'Gender': 'Male', 'Name': 'Maxim'}, - {'Id': 'Carmen', 'LanguageCode': 'ro-RO', 'LanguageName': 'Romanian', 'Gender': 'Female', 'Name': 'Carmen'}, - {'Id': 'Ines', 'LanguageCode': 'pt-PT', 'LanguageName': 'Portuguese', 'Gender': 'Female', 'Name': 'Inês'}, - {'Id': 'Cristiano', 'LanguageCode': 'pt-PT', 'LanguageName': 'Portuguese', 'Gender': 'Male', 'Name': 'Cristiano'}, - {'Id': 'Vitoria', 'LanguageCode': 'pt-BR', 'LanguageName': 'Brazilian Portuguese', 'Gender': 'Female', 'Name': 'Vitória'}, - {'Id': 'Ricardo', 'LanguageCode': 'pt-BR', 'LanguageName': 'Brazilian Portuguese', 'Gender': 'Male', 'Name': 'Ricardo'}, - {'Id': 'Maja', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Female', 'Name': 'Maja'}, - {'Id': 'Jan', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Male', 'Name': 'Jan'}, - {'Id': 'Ewa', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Female', 'Name': 'Ewa'}, - {'Id': 'Ruben', 'LanguageCode': 'nl-NL', 'LanguageName': 'Dutch', 'Gender': 'Male', 'Name': 'Ruben'}, - {'Id': 'Lotte', 'LanguageCode': 'nl-NL', 'LanguageName': 'Dutch', 'Gender': 'Female', 'Name': 'Lotte'}, - {'Id': 'Liv', 'LanguageCode': 'nb-NO', 'LanguageName': 'Norwegian', 'Gender': 'Female', 'Name': 'Liv'}, - {'Id': 'Giorgio', 'LanguageCode': 'it-IT', 'LanguageName': 'Italian', 'Gender': 'Male', 'Name': 'Giorgio'}, - {'Id': 'Carla', 'LanguageCode': 'it-IT', 'LanguageName': 'Italian', 'Gender': 'Female', 'Name': 'Carla'}, - {'Id': 'Karl', 'LanguageCode': 'is-IS', 'LanguageName': 'Icelandic', 'Gender': 'Male', 'Name': 'Karl'}, - {'Id': 'Dora', 'LanguageCode': 'is-IS', 'LanguageName': 'Icelandic', 'Gender': 'Female', 'Name': 'Dóra'}, - {'Id': 'Mathieu', 'LanguageCode': 'fr-FR', 'LanguageName': 'French', 'Gender': 'Male', 'Name': 'Mathieu'}, - {'Id': 'Celine', 'LanguageCode': 'fr-FR', 'LanguageName': 'French', 'Gender': 'Female', 'Name': 'Céline'}, - {'Id': 'Chantal', 'LanguageCode': 'fr-CA', 'LanguageName': 'Canadian French', 'Gender': 'Female', 'Name': 'Chantal'}, - {'Id': 'Penelope', 'LanguageCode': 'es-US', 'LanguageName': 'US Spanish', 'Gender': 'Female', 'Name': 'Penélope'}, - {'Id': 'Miguel', 'LanguageCode': 'es-US', 'LanguageName': 'US Spanish', 'Gender': 'Male', 'Name': 'Miguel'}, - {'Id': 'Enrique', 'LanguageCode': 'es-ES', 'LanguageName': 'Castilian Spanish', 'Gender': 'Male', 'Name': 'Enrique'}, - {'Id': 'Conchita', 'LanguageCode': 'es-ES', 'LanguageName': 'Castilian Spanish', 'Gender': 'Female', 'Name': 'Conchita'}, - {'Id': 'Geraint', 'LanguageCode': 'en-GB-WLS', 'LanguageName': 'Welsh English', 'Gender': 'Male', 'Name': 'Geraint'}, - {'Id': 'Salli', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Salli'}, - {'Id': 'Kimberly', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Kimberly'}, - {'Id': 'Kendra', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Kendra'}, - {'Id': 'Justin', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Male', 'Name': 'Justin'}, - {'Id': 'Joey', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Male', 'Name': 'Joey'}, - {'Id': 'Ivy', 'LanguageCode': 'en-US', 'LanguageName': 'US English', 'Gender': 'Female', 'Name': 'Ivy'}, - {'Id': 'Raveena', 'LanguageCode': 'en-IN', 'LanguageName': 'Indian English', 'Gender': 'Female', 'Name': 'Raveena'}, - {'Id': 'Emma', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Female', 'Name': 'Emma'}, - {'Id': 'Brian', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Male', 'Name': 'Brian'}, - {'Id': 'Amy', 'LanguageCode': 'en-GB', 'LanguageName': 'British English', 'Gender': 'Female', 'Name': 'Amy'}, - {'Id': 'Russell', 'LanguageCode': 'en-AU', 'LanguageName': 'Australian English', 'Gender': 'Male', 'Name': 'Russell'}, - {'Id': 'Nicole', 'LanguageCode': 'en-AU', 'LanguageName': 'Australian English', 'Gender': 'Female', 'Name': 'Nicole'}, - {'Id': 'Vicki', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Female', 'Name': 'Vicki'}, - {'Id': 'Marlene', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Female', 'Name': 'Marlene'}, - {'Id': 'Hans', 'LanguageCode': 'de-DE', 'LanguageName': 'German', 'Gender': 'Male', 'Name': 'Hans'}, - {'Id': 'Naja', 'LanguageCode': 'da-DK', 'LanguageName': 'Danish', 'Gender': 'Female', 'Name': 'Naja'}, - {'Id': 'Mads', 'LanguageCode': 'da-DK', 'LanguageName': 'Danish', 'Gender': 'Male', 'Name': 'Mads'}, - {'Id': 'Gwyneth', 'LanguageCode': 'cy-GB', 'LanguageName': 'Welsh', 'Gender': 'Female', 'Name': 'Gwyneth'}, - {'Id': 'Jacek', 'LanguageCode': 'pl-PL', 'LanguageName': 'Polish', 'Gender': 'Male', 'Name': 'Jacek'} + { + "Id": "Joanna", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Female", + "Name": "Joanna", + }, + { + "Id": "Mizuki", + "LanguageCode": "ja-JP", + "LanguageName": "Japanese", + "Gender": "Female", + "Name": "Mizuki", + }, + { + "Id": "Filiz", + "LanguageCode": "tr-TR", + "LanguageName": "Turkish", + "Gender": "Female", + "Name": "Filiz", + }, + { + "Id": "Astrid", + "LanguageCode": "sv-SE", + "LanguageName": "Swedish", + "Gender": "Female", + "Name": "Astrid", + }, + { + "Id": "Tatyana", + "LanguageCode": "ru-RU", + "LanguageName": "Russian", + "Gender": "Female", + "Name": "Tatyana", + }, + { + "Id": "Maxim", + "LanguageCode": "ru-RU", + "LanguageName": "Russian", + "Gender": "Male", + "Name": "Maxim", + }, + { + "Id": "Carmen", + "LanguageCode": "ro-RO", + "LanguageName": "Romanian", + "Gender": "Female", + "Name": "Carmen", + }, + { + "Id": "Ines", + "LanguageCode": "pt-PT", + "LanguageName": "Portuguese", + "Gender": "Female", + "Name": "Inês", + }, + { + "Id": "Cristiano", + "LanguageCode": "pt-PT", + "LanguageName": "Portuguese", + "Gender": "Male", + "Name": "Cristiano", + }, + { + "Id": "Vitoria", + "LanguageCode": "pt-BR", + "LanguageName": "Brazilian Portuguese", + "Gender": "Female", + "Name": "Vitória", + }, + { + "Id": "Ricardo", + "LanguageCode": "pt-BR", + "LanguageName": "Brazilian Portuguese", + "Gender": "Male", + "Name": "Ricardo", + }, + { + "Id": "Maja", + "LanguageCode": "pl-PL", + "LanguageName": "Polish", + "Gender": "Female", + "Name": "Maja", + }, + { + "Id": "Jan", + "LanguageCode": "pl-PL", + "LanguageName": "Polish", + "Gender": "Male", + "Name": "Jan", + }, + { + "Id": "Ewa", + "LanguageCode": "pl-PL", + "LanguageName": "Polish", + "Gender": "Female", + "Name": "Ewa", + }, + { + "Id": "Ruben", + "LanguageCode": "nl-NL", + "LanguageName": "Dutch", + "Gender": "Male", + "Name": "Ruben", + }, + { + "Id": "Lotte", + "LanguageCode": "nl-NL", + "LanguageName": "Dutch", + "Gender": "Female", + "Name": "Lotte", + }, + { + "Id": "Liv", + "LanguageCode": "nb-NO", + "LanguageName": "Norwegian", + "Gender": "Female", + "Name": "Liv", + }, + { + "Id": "Giorgio", + "LanguageCode": "it-IT", + "LanguageName": "Italian", + "Gender": "Male", + "Name": "Giorgio", + }, + { + "Id": "Carla", + "LanguageCode": "it-IT", + "LanguageName": "Italian", + "Gender": "Female", + "Name": "Carla", + }, + { + "Id": "Karl", + "LanguageCode": "is-IS", + "LanguageName": "Icelandic", + "Gender": "Male", + "Name": "Karl", + }, + { + "Id": "Dora", + "LanguageCode": "is-IS", + "LanguageName": "Icelandic", + "Gender": "Female", + "Name": "Dóra", + }, + { + "Id": "Mathieu", + "LanguageCode": "fr-FR", + "LanguageName": "French", + "Gender": "Male", + "Name": "Mathieu", + }, + { + "Id": "Celine", + "LanguageCode": "fr-FR", + "LanguageName": "French", + "Gender": "Female", + "Name": "Céline", + }, + { + "Id": "Chantal", + "LanguageCode": "fr-CA", + "LanguageName": "Canadian French", + "Gender": "Female", + "Name": "Chantal", + }, + { + "Id": "Penelope", + "LanguageCode": "es-US", + "LanguageName": "US Spanish", + "Gender": "Female", + "Name": "Penélope", + }, + { + "Id": "Miguel", + "LanguageCode": "es-US", + "LanguageName": "US Spanish", + "Gender": "Male", + "Name": "Miguel", + }, + { + "Id": "Enrique", + "LanguageCode": "es-ES", + "LanguageName": "Castilian Spanish", + "Gender": "Male", + "Name": "Enrique", + }, + { + "Id": "Conchita", + "LanguageCode": "es-ES", + "LanguageName": "Castilian Spanish", + "Gender": "Female", + "Name": "Conchita", + }, + { + "Id": "Geraint", + "LanguageCode": "en-GB-WLS", + "LanguageName": "Welsh English", + "Gender": "Male", + "Name": "Geraint", + }, + { + "Id": "Salli", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Female", + "Name": "Salli", + }, + { + "Id": "Kimberly", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Female", + "Name": "Kimberly", + }, + { + "Id": "Kendra", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Female", + "Name": "Kendra", + }, + { + "Id": "Justin", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Male", + "Name": "Justin", + }, + { + "Id": "Joey", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Male", + "Name": "Joey", + }, + { + "Id": "Ivy", + "LanguageCode": "en-US", + "LanguageName": "US English", + "Gender": "Female", + "Name": "Ivy", + }, + { + "Id": "Raveena", + "LanguageCode": "en-IN", + "LanguageName": "Indian English", + "Gender": "Female", + "Name": "Raveena", + }, + { + "Id": "Emma", + "LanguageCode": "en-GB", + "LanguageName": "British English", + "Gender": "Female", + "Name": "Emma", + }, + { + "Id": "Brian", + "LanguageCode": "en-GB", + "LanguageName": "British English", + "Gender": "Male", + "Name": "Brian", + }, + { + "Id": "Amy", + "LanguageCode": "en-GB", + "LanguageName": "British English", + "Gender": "Female", + "Name": "Amy", + }, + { + "Id": "Russell", + "LanguageCode": "en-AU", + "LanguageName": "Australian English", + "Gender": "Male", + "Name": "Russell", + }, + { + "Id": "Nicole", + "LanguageCode": "en-AU", + "LanguageName": "Australian English", + "Gender": "Female", + "Name": "Nicole", + }, + { + "Id": "Vicki", + "LanguageCode": "de-DE", + "LanguageName": "German", + "Gender": "Female", + "Name": "Vicki", + }, + { + "Id": "Marlene", + "LanguageCode": "de-DE", + "LanguageName": "German", + "Gender": "Female", + "Name": "Marlene", + }, + { + "Id": "Hans", + "LanguageCode": "de-DE", + "LanguageName": "German", + "Gender": "Male", + "Name": "Hans", + }, + { + "Id": "Naja", + "LanguageCode": "da-DK", + "LanguageName": "Danish", + "Gender": "Female", + "Name": "Naja", + }, + { + "Id": "Mads", + "LanguageCode": "da-DK", + "LanguageName": "Danish", + "Gender": "Male", + "Name": "Mads", + }, + { + "Id": "Gwyneth", + "LanguageCode": "cy-GB", + "LanguageName": "Welsh", + "Gender": "Female", + "Name": "Gwyneth", + }, + { + "Id": "Jacek", + "LanguageCode": "pl-PL", + "LanguageName": "Polish", + "Gender": "Male", + "Name": "Jacek", + }, ] # {...} is also shorthand set syntax -LANGUAGE_CODES = {'cy-GB', 'da-DK', 'de-DE', 'en-AU', 'en-GB', 'en-GB-WLS', 'en-IN', 'en-US', 'es-ES', 'es-US', - 'fr-CA', 'fr-FR', 'is-IS', 'it-IT', 'ja-JP', 'nb-NO', 'nl-NL', 'pl-PL', 'pt-BR', 'pt-PT', 'ro-RO', - 'ru-RU', 'sv-SE', 'tr-TR'} +LANGUAGE_CODES = { + "cy-GB", + "da-DK", + "de-DE", + "en-AU", + "en-GB", + "en-GB-WLS", + "en-IN", + "en-US", + "es-ES", + "es-US", + "fr-CA", + "fr-FR", + "is-IS", + "it-IT", + "ja-JP", + "nb-NO", + "nl-NL", + "pl-PL", + "pt-BR", + "pt-PT", + "ro-RO", + "ru-RU", + "sv-SE", + "tr-TR", +} -VOICE_IDS = {'Geraint', 'Gwyneth', 'Mads', 'Naja', 'Hans', 'Marlene', 'Nicole', 'Russell', 'Amy', 'Brian', 'Emma', - 'Raveena', 'Ivy', 'Joanna', 'Joey', 'Justin', 'Kendra', 'Kimberly', 'Salli', 'Conchita', 'Enrique', - 'Miguel', 'Penelope', 'Chantal', 'Celine', 'Mathieu', 'Dora', 'Karl', 'Carla', 'Giorgio', 'Mizuki', - 'Liv', 'Lotte', 'Ruben', 'Ewa', 'Jacek', 'Jan', 'Maja', 'Ricardo', 'Vitoria', 'Cristiano', 'Ines', - 'Carmen', 'Maxim', 'Tatyana', 'Astrid', 'Filiz'} +VOICE_IDS = { + "Geraint", + "Gwyneth", + "Mads", + "Naja", + "Hans", + "Marlene", + "Nicole", + "Russell", + "Amy", + "Brian", + "Emma", + "Raveena", + "Ivy", + "Joanna", + "Joey", + "Justin", + "Kendra", + "Kimberly", + "Salli", + "Conchita", + "Enrique", + "Miguel", + "Penelope", + "Chantal", + "Celine", + "Mathieu", + "Dora", + "Karl", + "Carla", + "Giorgio", + "Mizuki", + "Liv", + "Lotte", + "Ruben", + "Ewa", + "Jacek", + "Jan", + "Maja", + "Ricardo", + "Vitoria", + "Cristiano", + "Ines", + "Carmen", + "Maxim", + "Tatyana", + "Astrid", + "Filiz", +} diff --git a/moto/polly/responses.py b/moto/polly/responses.py index 810264424..e7de01b2b 100644 --- a/moto/polly/responses.py +++ b/moto/polly/responses.py @@ -9,7 +9,7 @@ from moto.core.responses import BaseResponse from .models import polly_backends from .resources import LANGUAGE_CODES, VOICE_IDS -LEXICON_NAME_REGEX = re.compile(r'^[0-9A-Za-z]{1,20}$') +LEXICON_NAME_REGEX = re.compile(r"^[0-9A-Za-z]{1,20}$") class PollyResponse(BaseResponse): @@ -19,71 +19,75 @@ class PollyResponse(BaseResponse): @property def json(self): - if not hasattr(self, '_json'): + if not hasattr(self, "_json"): self._json = json.loads(self.body) return self._json def _error(self, code, message): - return json.dumps({'__type': code, 'message': message}), dict(status=400) + return json.dumps({"__type": code, "message": message}), dict(status=400) def _get_action(self): # Amazon is now naming things /v1/api_name - url_parts = urlsplit(self.uri).path.lstrip('/').split('/') + url_parts = urlsplit(self.uri).path.lstrip("/").split("/") # [0] = 'v1' return url_parts[1] # DescribeVoices def voices(self): - language_code = self._get_param('LanguageCode') - next_token = self._get_param('NextToken') + language_code = self._get_param("LanguageCode") + next_token = self._get_param("NextToken") if language_code is not None and language_code not in LANGUAGE_CODES: - msg = "1 validation error detected: Value '{0}' at 'languageCode' failed to satisfy constraint: " \ - "Member must satisfy enum value set: [{1}]".format(language_code, ', '.join(LANGUAGE_CODES)) + msg = ( + "1 validation error detected: Value '{0}' at 'languageCode' failed to satisfy constraint: " + "Member must satisfy enum value set: [{1}]".format( + language_code, ", ".join(LANGUAGE_CODES) + ) + ) return msg, dict(status=400) voices = self.polly_backend.describe_voices(language_code, next_token) - return json.dumps({'Voices': voices}) + return json.dumps({"Voices": voices}) def lexicons(self): # Dish out requests based on methods # anything after the /v1/lexicons/ - args = urlsplit(self.uri).path.lstrip('/').split('/')[2:] + args = urlsplit(self.uri).path.lstrip("/").split("/")[2:] - if self.method == 'GET': + if self.method == "GET": if len(args) == 0: return self._get_lexicons_list() else: return self._get_lexicon(*args) - elif self.method == 'PUT': + elif self.method == "PUT": return self._put_lexicons(*args) - elif self.method == 'DELETE': + elif self.method == "DELETE": return self._delete_lexicon(*args) - return self._error('InvalidAction', 'Bad route') + return self._error("InvalidAction", "Bad route") # PutLexicon def _put_lexicons(self, lexicon_name): if LEXICON_NAME_REGEX.match(lexicon_name) is None: - return self._error('InvalidParameterValue', 'Lexicon name must match [0-9A-Za-z]{1,20}') + return self._error( + "InvalidParameterValue", "Lexicon name must match [0-9A-Za-z]{1,20}" + ) - if 'Content' not in self.json: - return self._error('MissingParameter', 'Content is missing from the body') + if "Content" not in self.json: + return self._error("MissingParameter", "Content is missing from the body") - self.polly_backend.put_lexicon(lexicon_name, self.json['Content']) + self.polly_backend.put_lexicon(lexicon_name, self.json["Content"]) - return '' + return "" # ListLexicons def _get_lexicons_list(self): - next_token = self._get_param('NextToken') + next_token = self._get_param("NextToken") - result = { - 'Lexicons': self.polly_backend.list_lexicons(next_token) - } + result = {"Lexicons": self.polly_backend.list_lexicons(next_token)} return json.dumps(result) @@ -92,14 +96,11 @@ class PollyResponse(BaseResponse): try: lexicon = self.polly_backend.get_lexicon(lexicon_name) except KeyError: - return self._error('LexiconNotFoundException', 'Lexicon not found') + return self._error("LexiconNotFoundException", "Lexicon not found") result = { - 'Lexicon': { - 'Name': lexicon_name, - 'Content': lexicon.content - }, - 'LexiconAttributes': lexicon.to_dict()['Attributes'] + "Lexicon": {"Name": lexicon_name, "Content": lexicon.content}, + "LexiconAttributes": lexicon.to_dict()["Attributes"], } return json.dumps(result) @@ -109,80 +110,94 @@ class PollyResponse(BaseResponse): try: self.polly_backend.delete_lexicon(lexicon_name) except KeyError: - return self._error('LexiconNotFoundException', 'Lexicon not found') + return self._error("LexiconNotFoundException", "Lexicon not found") - return '' + return "" # SynthesizeSpeech def speech(self): # Sanity check params args = { - 'lexicon_names': None, - 'sample_rate': 22050, - 'speech_marks': None, - 'text': None, - 'text_type': 'text' + "lexicon_names": None, + "sample_rate": 22050, + "speech_marks": None, + "text": None, + "text_type": "text", } - if 'LexiconNames' in self.json: - for lex in self.json['LexiconNames']: + if "LexiconNames" in self.json: + for lex in self.json["LexiconNames"]: try: self.polly_backend.get_lexicon(lex) except KeyError: - return self._error('LexiconNotFoundException', 'Lexicon not found') + return self._error("LexiconNotFoundException", "Lexicon not found") - args['lexicon_names'] = self.json['LexiconNames'] + args["lexicon_names"] = self.json["LexiconNames"] - if 'OutputFormat' not in self.json: - return self._error('MissingParameter', 'Missing parameter OutputFormat') - if self.json['OutputFormat'] not in ('json', 'mp3', 'ogg_vorbis', 'pcm'): - return self._error('InvalidParameterValue', 'Not one of json, mp3, ogg_vorbis, pcm') - args['output_format'] = self.json['OutputFormat'] + if "OutputFormat" not in self.json: + return self._error("MissingParameter", "Missing parameter OutputFormat") + if self.json["OutputFormat"] not in ("json", "mp3", "ogg_vorbis", "pcm"): + return self._error( + "InvalidParameterValue", "Not one of json, mp3, ogg_vorbis, pcm" + ) + args["output_format"] = self.json["OutputFormat"] - if 'SampleRate' in self.json: - sample_rate = int(self.json['SampleRate']) + if "SampleRate" in self.json: + sample_rate = int(self.json["SampleRate"]) if sample_rate not in (8000, 16000, 22050): - return self._error('InvalidSampleRateException', 'The specified sample rate is not valid.') - args['sample_rate'] = sample_rate + return self._error( + "InvalidSampleRateException", + "The specified sample rate is not valid.", + ) + args["sample_rate"] = sample_rate - if 'SpeechMarkTypes' in self.json: - for value in self.json['SpeechMarkTypes']: - if value not in ('sentance', 'ssml', 'viseme', 'word'): - return self._error('InvalidParameterValue', 'Not one of sentance, ssml, viseme, word') - args['speech_marks'] = self.json['SpeechMarkTypes'] + if "SpeechMarkTypes" in self.json: + for value in self.json["SpeechMarkTypes"]: + if value not in ("sentance", "ssml", "viseme", "word"): + return self._error( + "InvalidParameterValue", + "Not one of sentance, ssml, viseme, word", + ) + args["speech_marks"] = self.json["SpeechMarkTypes"] - if 'Text' not in self.json: - return self._error('MissingParameter', 'Missing parameter Text') - args['text'] = self.json['Text'] + if "Text" not in self.json: + return self._error("MissingParameter", "Missing parameter Text") + args["text"] = self.json["Text"] - if 'TextType' in self.json: - if self.json['TextType'] not in ('ssml', 'text'): - return self._error('InvalidParameterValue', 'Not one of ssml, text') - args['text_type'] = self.json['TextType'] + if "TextType" in self.json: + if self.json["TextType"] not in ("ssml", "text"): + return self._error("InvalidParameterValue", "Not one of ssml, text") + args["text_type"] = self.json["TextType"] - if 'VoiceId' not in self.json: - return self._error('MissingParameter', 'Missing parameter VoiceId') - if self.json['VoiceId'] not in VOICE_IDS: - return self._error('InvalidParameterValue', 'Not one of {0}'.format(', '.join(VOICE_IDS))) - args['voice_id'] = self.json['VoiceId'] + if "VoiceId" not in self.json: + return self._error("MissingParameter", "Missing parameter VoiceId") + if self.json["VoiceId"] not in VOICE_IDS: + return self._error( + "InvalidParameterValue", "Not one of {0}".format(", ".join(VOICE_IDS)) + ) + args["voice_id"] = self.json["VoiceId"] # More validation - if len(args['text']) > 3000: - return self._error('TextLengthExceededException', 'Text too long') + if len(args["text"]) > 3000: + return self._error("TextLengthExceededException", "Text too long") - if args['speech_marks'] is not None and args['output_format'] != 'json': - return self._error('MarksNotSupportedForFormatException', 'OutputFormat must be json') - if args['speech_marks'] is not None and args['text_type'] == 'text': - return self._error('SsmlMarksNotSupportedForTextTypeException', 'TextType must be ssml') + if args["speech_marks"] is not None and args["output_format"] != "json": + return self._error( + "MarksNotSupportedForFormatException", "OutputFormat must be json" + ) + if args["speech_marks"] is not None and args["text_type"] == "text": + return self._error( + "SsmlMarksNotSupportedForTextTypeException", "TextType must be ssml" + ) - content_type = 'audio/json' - if args['output_format'] == 'mp3': - content_type = 'audio/mpeg' - elif args['output_format'] == 'ogg_vorbis': - content_type = 'audio/ogg' - elif args['output_format'] == 'pcm': - content_type = 'audio/pcm' + content_type = "audio/json" + if args["output_format"] == "mp3": + content_type = "audio/mpeg" + elif args["output_format"] == "ogg_vorbis": + content_type = "audio/ogg" + elif args["output_format"] == "pcm": + content_type = "audio/pcm" - headers = {'Content-Type': content_type} + headers = {"Content-Type": content_type} - return '\x00\x00\x00\x00\x00\x00\x00\x00', headers + return "\x00\x00\x00\x00\x00\x00\x00\x00", headers diff --git a/moto/polly/urls.py b/moto/polly/urls.py index bd4057a0b..5408c8cc1 100644 --- a/moto/polly/urls.py +++ b/moto/polly/urls.py @@ -1,13 +1,11 @@ from __future__ import unicode_literals from .responses import PollyResponse -url_bases = [ - "https?://polly.(.+).amazonaws.com", -] +url_bases = ["https?://polly.(.+).amazonaws.com"] url_paths = { - '{0}/v1/voices': PollyResponse.dispatch, - '{0}/v1/lexicons/(?P[^/]+)': PollyResponse.dispatch, - '{0}/v1/lexicons': PollyResponse.dispatch, - '{0}/v1/speech': PollyResponse.dispatch, + "{0}/v1/voices": PollyResponse.dispatch, + "{0}/v1/lexicons/(?P[^/]+)": PollyResponse.dispatch, + "{0}/v1/lexicons": PollyResponse.dispatch, + "{0}/v1/speech": PollyResponse.dispatch, } diff --git a/moto/rds/__init__.py b/moto/rds/__init__.py index a4086d89c..bd260d023 100644 --- a/moto/rds/__init__.py +++ b/moto/rds/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import rds_backends from ..core.models import base_decorator, deprecated_base_decorator -rds_backend = rds_backends['us-east-1'] +rds_backend = rds_backends["us-east-1"] mock_rds = base_decorator(rds_backends) mock_rds_deprecated = deprecated_base_decorator(rds_backends) diff --git a/moto/rds/exceptions.py b/moto/rds/exceptions.py index 5bcc95560..cf9b9aac6 100644 --- a/moto/rds/exceptions.py +++ b/moto/rds/exceptions.py @@ -5,38 +5,34 @@ from werkzeug.exceptions import BadRequest class RDSClientError(BadRequest): - def __init__(self, code, message): super(RDSClientError, self).__init__() - self.description = json.dumps({ - "Error": { - "Code": code, - "Message": message, - 'Type': 'Sender', - }, - 'RequestId': '6876f774-7273-11e4-85dc-39e55ca848d1', - }) + self.description = json.dumps( + { + "Error": {"Code": code, "Message": message, "Type": "Sender"}, + "RequestId": "6876f774-7273-11e4-85dc-39e55ca848d1", + } + ) class DBInstanceNotFoundError(RDSClientError): - def __init__(self, database_identifier): super(DBInstanceNotFoundError, self).__init__( - 'DBInstanceNotFound', - "Database {0} not found.".format(database_identifier)) + "DBInstanceNotFound", "Database {0} not found.".format(database_identifier) + ) class DBSecurityGroupNotFoundError(RDSClientError): - def __init__(self, security_group_name): super(DBSecurityGroupNotFoundError, self).__init__( - 'DBSecurityGroupNotFound', - "Security Group {0} not found.".format(security_group_name)) + "DBSecurityGroupNotFound", + "Security Group {0} not found.".format(security_group_name), + ) class DBSubnetGroupNotFoundError(RDSClientError): - def __init__(self, subnet_group_name): super(DBSubnetGroupNotFoundError, self).__init__( - 'DBSubnetGroupNotFound', - "Subnet Group {0} not found.".format(subnet_group_name)) + "DBSubnetGroupNotFound", + "Subnet Group {0} not found.".format(subnet_group_name), + ) diff --git a/moto/rds/models.py b/moto/rds/models.py index 592516b34..421f3784b 100644 --- a/moto/rds/models.py +++ b/moto/rds/models.py @@ -11,33 +11,34 @@ from moto.rds2.models import rds2_backends class Database(BaseModel): - def get_cfn_attribute(self, attribute_name): - if attribute_name == 'Endpoint.Address': + if attribute_name == "Endpoint.Address": return self.address - elif attribute_name == 'Endpoint.Port': + elif attribute_name == "Endpoint.Port": return self.port raise UnformattedGetAttTemplateException() @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - db_instance_identifier = properties.get('DBInstanceIdentifier') + db_instance_identifier = properties.get("DBInstanceIdentifier") if not db_instance_identifier: db_instance_identifier = resource_name.lower() + get_random_hex(12) - db_security_groups = properties.get('DBSecurityGroups') + db_security_groups = properties.get("DBSecurityGroups") if not db_security_groups: db_security_groups = [] security_groups = [group.group_name for group in db_security_groups] db_subnet_group = properties.get("DBSubnetGroupName") db_subnet_group_name = db_subnet_group.subnet_name if db_subnet_group else None db_kwargs = { - "auto_minor_version_upgrade": properties.get('AutoMinorVersionUpgrade'), - "allocated_storage": properties.get('AllocatedStorage'), + "auto_minor_version_upgrade": properties.get("AutoMinorVersionUpgrade"), + "allocated_storage": properties.get("AllocatedStorage"), "availability_zone": properties.get("AvailabilityZone"), "backup_retention_period": properties.get("BackupRetentionPeriod"), - "db_instance_class": properties.get('DBInstanceClass'), + "db_instance_class": properties.get("DBInstanceClass"), "db_instance_identifier": db_instance_identifier, "db_name": properties.get("DBName"), "db_subnet_group_name": db_subnet_group_name, @@ -45,10 +46,10 @@ class Database(BaseModel): "engine_version": properties.get("EngineVersion"), "iops": properties.get("Iops"), "kms_key_id": properties.get("KmsKeyId"), - "master_password": properties.get('MasterUserPassword'), - "master_username": properties.get('MasterUsername'), + "master_password": properties.get("MasterUserPassword"), + "master_username": properties.get("MasterUsername"), "multi_az": properties.get("MultiAZ"), - "port": properties.get('Port', 3306), + "port": properties.get("Port", 3306), "publicly_accessible": properties.get("PubliclyAccessible"), "copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"), "region": region_name, @@ -69,7 +70,8 @@ class Database(BaseModel): return database def to_xml(self): - template = Template(""" + template = Template( + """ {{ database.backup_retention_period }} {{ database.status }} {{ database.multi_az }} @@ -152,7 +154,8 @@ class Database(BaseModel): {{ database.port }} {{ database.db_instance_arn }} - """) + """ + ) return template.render(database=self) def delete(self, region_name): @@ -161,7 +164,6 @@ class Database(BaseModel): class SecurityGroup(BaseModel): - def __init__(self, group_name, description): self.group_name = group_name self.description = description @@ -170,7 +172,8 @@ class SecurityGroup(BaseModel): self.ec2_security_groups = [] def to_xml(self): - template = Template(""" + template = Template( + """ {% for security_group in security_group.ec2_security_groups %} @@ -193,7 +196,8 @@ class SecurityGroup(BaseModel): {{ security_group.ownder_id }} {{ security_group.group_name }} - """) + """ + ) return template.render(security_group=self) def authorize_cidr(self, cidr_ip): @@ -203,20 +207,19 @@ class SecurityGroup(BaseModel): self.ec2_security_groups.append(security_group) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] group_name = resource_name.lower() + get_random_hex(12) - description = properties['GroupDescription'] - security_group_ingress_rules = properties.get( - 'DBSecurityGroupIngress', []) - tags = properties.get('Tags') + description = properties["GroupDescription"] + security_group_ingress_rules = properties.get("DBSecurityGroupIngress", []) + tags = properties.get("Tags") ec2_backend = ec2_backends[region_name] rds_backend = rds_backends[region_name] security_group = rds_backend.create_security_group( - group_name, - description, - tags, + group_name, description, tags ) for security_group_ingress in security_group_ingress_rules: @@ -224,12 +227,10 @@ class SecurityGroup(BaseModel): if ingress_type == "CIDRIP": security_group.authorize_cidr(ingress_value) elif ingress_type == "EC2SecurityGroupName": - subnet = ec2_backend.get_security_group_from_name( - ingress_value) + subnet = ec2_backend.get_security_group_from_name(ingress_value) security_group.authorize_security_group(subnet) elif ingress_type == "EC2SecurityGroupId": - subnet = ec2_backend.get_security_group_from_id( - ingress_value) + subnet = ec2_backend.get_security_group_from_id(ingress_value) security_group.authorize_security_group(subnet) return security_group @@ -239,7 +240,6 @@ class SecurityGroup(BaseModel): class SubnetGroup(BaseModel): - def __init__(self, subnet_name, description, subnets): self.subnet_name = subnet_name self.description = description @@ -249,7 +249,8 @@ class SubnetGroup(BaseModel): self.vpc_id = self.subnets[0].vpc_id def to_xml(self): - template = Template(""" + template = Template( + """ {{ subnet_group.vpc_id }} {{ subnet_group.status }} {{ subnet_group.description }} @@ -266,27 +267,26 @@ class SubnetGroup(BaseModel): {% endfor %} - """) + """ + ) return template.render(subnet_group=self) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] subnet_name = resource_name.lower() + get_random_hex(12) - description = properties['DBSubnetGroupDescription'] - subnet_ids = properties['SubnetIds'] - tags = properties.get('Tags') + description = properties["DBSubnetGroupDescription"] + subnet_ids = properties["SubnetIds"] + tags = properties.get("Tags") ec2_backend = ec2_backends[region_name] - subnets = [ec2_backend.get_subnet(subnet_id) - for subnet_id in subnet_ids] + subnets = [ec2_backend.get_subnet(subnet_id) for subnet_id in subnet_ids] rds_backend = rds_backends[region_name] subnet_group = rds_backend.create_subnet_group( - subnet_name, - description, - subnets, - tags, + subnet_name, description, subnets, tags ) return subnet_group @@ -296,7 +296,6 @@ class SubnetGroup(BaseModel): class RDSBackend(BaseBackend): - def __init__(self, region): self.region = region @@ -314,5 +313,6 @@ class RDSBackend(BaseBackend): return rds2_backends[self.region] -rds_backends = dict((region.name, RDSBackend(region.name)) - for region in boto.rds.regions()) +rds_backends = dict( + (region.name, RDSBackend(region.name)) for region in boto.rds.regions() +) diff --git a/moto/rds/responses.py b/moto/rds/responses.py index 0afb03979..e3d37effc 100644 --- a/moto/rds/responses.py +++ b/moto/rds/responses.py @@ -6,19 +6,18 @@ from .models import rds_backends class RDSResponse(BaseResponse): - @property def backend(self): return rds_backends[self.region] def _get_db_kwargs(self): args = { - "auto_minor_version_upgrade": self._get_param('AutoMinorVersionUpgrade'), - "allocated_storage": self._get_int_param('AllocatedStorage'), + "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), + "allocated_storage": self._get_int_param("AllocatedStorage"), "availability_zone": self._get_param("AvailabilityZone"), "backup_retention_period": self._get_param("BackupRetentionPeriod"), - "db_instance_class": self._get_param('DBInstanceClass'), - "db_instance_identifier": self._get_param('DBInstanceIdentifier'), + "db_instance_class": self._get_param("DBInstanceClass"), + "db_instance_identifier": self._get_param("DBInstanceIdentifier"), "db_name": self._get_param("DBName"), # DBParameterGroupName "db_subnet_group_name": self._get_param("DBSubnetGroupName"), @@ -26,48 +25,48 @@ class RDSResponse(BaseResponse): "engine_version": self._get_param("EngineVersion"), "iops": self._get_int_param("Iops"), "kms_key_id": self._get_param("KmsKeyId"), - "master_password": self._get_param('MasterUserPassword'), - "master_username": self._get_param('MasterUsername'), + "master_password": self._get_param("MasterUserPassword"), + "master_username": self._get_param("MasterUsername"), "multi_az": self._get_bool_param("MultiAZ"), # OptionGroupName - "port": self._get_param('Port'), + "port": self._get_param("Port"), # PreferredBackupWindow # PreferredMaintenanceWindow "publicly_accessible": self._get_param("PubliclyAccessible"), "region": self.region, - "security_groups": self._get_multi_param('DBSecurityGroups.member'), + "security_groups": self._get_multi_param("DBSecurityGroups.member"), "storage_encrypted": self._get_param("StorageEncrypted"), "storage_type": self._get_param("StorageType"), # VpcSecurityGroupIds.member.N "tags": list(), } - args['tags'] = self.unpack_complex_list_params( - 'Tags.Tag', ('Key', 'Value')) + args["tags"] = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) return args def _get_db_replica_kwargs(self): return { - "auto_minor_version_upgrade": self._get_param('AutoMinorVersionUpgrade'), + "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), "availability_zone": self._get_param("AvailabilityZone"), - "db_instance_class": self._get_param('DBInstanceClass'), - "db_instance_identifier": self._get_param('DBInstanceIdentifier'), + "db_instance_class": self._get_param("DBInstanceClass"), + "db_instance_identifier": self._get_param("DBInstanceIdentifier"), "db_subnet_group_name": self._get_param("DBSubnetGroupName"), "iops": self._get_int_param("Iops"), # OptionGroupName - "port": self._get_param('Port'), + "port": self._get_param("Port"), "publicly_accessible": self._get_param("PubliclyAccessible"), - "source_db_identifier": self._get_param('SourceDBInstanceIdentifier'), + "source_db_identifier": self._get_param("SourceDBInstanceIdentifier"), "storage_type": self._get_param("StorageType"), } def unpack_complex_list_params(self, label, names): unpacked_list = list() count = 1 - while self._get_param('{0}.{1}.{2}'.format(label, count, names[0])): + while self._get_param("{0}.{1}.{2}".format(label, count, names[0])): param = dict() for i in range(len(names)): param[names[i]] = self._get_param( - '{0}.{1}.{2}'.format(label, count, names[i])) + "{0}.{1}.{2}".format(label, count, names[i]) + ) unpacked_list.append(param) count += 1 return unpacked_list @@ -87,16 +86,18 @@ class RDSResponse(BaseResponse): return template.render(database=database) def describe_db_instances(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") all_instances = list(self.backend.describe_databases(db_instance_identifier)) - marker = self._get_param('Marker') + marker = self._get_param("Marker") all_ids = [instance.db_instance_identifier for instance in all_instances] if marker: start = all_ids.index(marker) + 1 else: start = 0 - page_size = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier - instances_resp = all_instances[start:start + page_size] + page_size = self._get_int_param( + "MaxRecords", 50 + ) # the default is 100, but using 50 to make testing easier + instances_resp = all_instances[start : start + page_size] next_marker = None if len(all_instances) > start + page_size: next_marker = instances_resp[-1].db_instance_identifier @@ -105,73 +106,74 @@ class RDSResponse(BaseResponse): return template.render(databases=instances_resp, marker=next_marker) def modify_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") db_kwargs = self._get_db_kwargs() - new_db_instance_identifier = self._get_param('NewDBInstanceIdentifier') + new_db_instance_identifier = self._get_param("NewDBInstanceIdentifier") if new_db_instance_identifier: - db_kwargs['new_db_instance_identifier'] = new_db_instance_identifier - database = self.backend.modify_database( - db_instance_identifier, db_kwargs) + db_kwargs["new_db_instance_identifier"] = new_db_instance_identifier + database = self.backend.modify_database(db_instance_identifier, db_kwargs) template = self.response_template(MODIFY_DATABASE_TEMPLATE) return template.render(database=database) def delete_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") database = self.backend.delete_database(db_instance_identifier) template = self.response_template(DELETE_DATABASE_TEMPLATE) return template.render(database=database) def create_db_security_group(self): - group_name = self._get_param('DBSecurityGroupName') - description = self._get_param('DBSecurityGroupDescription') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + group_name = self._get_param("DBSecurityGroupName") + description = self._get_param("DBSecurityGroupDescription") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) security_group = self.backend.create_security_group( - group_name, description, tags) + group_name, description, tags + ) template = self.response_template(CREATE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def describe_db_security_groups(self): - security_group_name = self._get_param('DBSecurityGroupName') - security_groups = self.backend.describe_security_groups( - security_group_name) + security_group_name = self._get_param("DBSecurityGroupName") + security_groups = self.backend.describe_security_groups(security_group_name) template = self.response_template(DESCRIBE_SECURITY_GROUPS_TEMPLATE) return template.render(security_groups=security_groups) def delete_db_security_group(self): - security_group_name = self._get_param('DBSecurityGroupName') - security_group = self.backend.delete_security_group( - security_group_name) + security_group_name = self._get_param("DBSecurityGroupName") + security_group = self.backend.delete_security_group(security_group_name) template = self.response_template(DELETE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def authorize_db_security_group_ingress(self): - security_group_name = self._get_param('DBSecurityGroupName') - cidr_ip = self._get_param('CIDRIP') + security_group_name = self._get_param("DBSecurityGroupName") + cidr_ip = self._get_param("CIDRIP") security_group = self.backend.authorize_security_group( - security_group_name, cidr_ip) + security_group_name, cidr_ip + ) template = self.response_template(AUTHORIZE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def create_db_subnet_group(self): - subnet_name = self._get_param('DBSubnetGroupName') - description = self._get_param('DBSubnetGroupDescription') - subnet_ids = self._get_multi_param('SubnetIds.member') - subnets = [ec2_backends[self.region].get_subnet( - subnet_id) for subnet_id in subnet_ids] - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + subnet_name = self._get_param("DBSubnetGroupName") + description = self._get_param("DBSubnetGroupDescription") + subnet_ids = self._get_multi_param("SubnetIds.member") + subnets = [ + ec2_backends[self.region].get_subnet(subnet_id) for subnet_id in subnet_ids + ] + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) subnet_group = self.backend.create_subnet_group( - subnet_name, description, subnets, tags) + subnet_name, description, subnets, tags + ) template = self.response_template(CREATE_SUBNET_GROUP_TEMPLATE) return template.render(subnet_group=subnet_group) def describe_db_subnet_groups(self): - subnet_name = self._get_param('DBSubnetGroupName') + subnet_name = self._get_param("DBSubnetGroupName") subnet_groups = self.backend.describe_subnet_groups(subnet_name) template = self.response_template(DESCRIBE_SUBNET_GROUPS_TEMPLATE) return template.render(subnet_groups=subnet_groups) def delete_db_subnet_group(self): - subnet_name = self._get_param('DBSubnetGroupName') + subnet_name = self._get_param("DBSubnetGroupName") subnet_group = self.backend.delete_subnet_group(subnet_name) template = self.response_template(DELETE_SUBNET_GROUP_TEMPLATE) return template.render(subnet_group=subnet_group) diff --git a/moto/rds/urls.py b/moto/rds/urls.py index 646f17304..9c7570167 100644 --- a/moto/rds/urls.py +++ b/moto/rds/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import RDSResponse -url_bases = [ - "https?://rds(\..+)?.amazonaws.com", -] +url_bases = ["https?://rds(\..+)?.amazonaws.com"] -url_paths = { - '{0}/$': RDSResponse.dispatch, -} +url_paths = {"{0}/$": RDSResponse.dispatch} diff --git a/moto/rds2/__init__.py b/moto/rds2/__init__.py index 723fa0968..acc8564e2 100644 --- a/moto/rds2/__init__.py +++ b/moto/rds2/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import rds2_backends from ..core.models import base_decorator, deprecated_base_decorator -rds2_backend = rds2_backends['us-west-1'] +rds2_backend = rds2_backends["us-west-1"] mock_rds2 = base_decorator(rds2_backends) mock_rds2_deprecated = deprecated_base_decorator(rds2_backends) diff --git a/moto/rds2/exceptions.py b/moto/rds2/exceptions.py index e82ae7077..b6dc5bb99 100644 --- a/moto/rds2/exceptions.py +++ b/moto/rds2/exceptions.py @@ -5,10 +5,10 @@ from werkzeug.exceptions import BadRequest class RDSClientError(BadRequest): - def __init__(self, code, message): super(RDSClientError, self).__init__() - template = Template(""" + template = Template( + """ {{ code }} @@ -16,87 +16,94 @@ class RDSClientError(BadRequest): Sender 6876f774-7273-11e4-85dc-39e55ca848d1 - """) + """ + ) self.description = template.render(code=code, message=message) class DBInstanceNotFoundError(RDSClientError): - def __init__(self, database_identifier): super(DBInstanceNotFoundError, self).__init__( - 'DBInstanceNotFound', - "Database {0} not found.".format(database_identifier)) + "DBInstanceNotFound", "Database {0} not found.".format(database_identifier) + ) class DBSnapshotNotFoundError(RDSClientError): - def __init__(self): super(DBSnapshotNotFoundError, self).__init__( - 'DBSnapshotNotFound', - "DBSnapshotIdentifier does not refer to an existing DB snapshot.") + "DBSnapshotNotFound", + "DBSnapshotIdentifier does not refer to an existing DB snapshot.", + ) class DBSecurityGroupNotFoundError(RDSClientError): - def __init__(self, security_group_name): super(DBSecurityGroupNotFoundError, self).__init__( - 'DBSecurityGroupNotFound', - "Security Group {0} not found.".format(security_group_name)) + "DBSecurityGroupNotFound", + "Security Group {0} not found.".format(security_group_name), + ) class DBSubnetGroupNotFoundError(RDSClientError): - def __init__(self, subnet_group_name): super(DBSubnetGroupNotFoundError, self).__init__( - 'DBSubnetGroupNotFound', - "Subnet Group {0} not found.".format(subnet_group_name)) + "DBSubnetGroupNotFound", + "Subnet Group {0} not found.".format(subnet_group_name), + ) class DBParameterGroupNotFoundError(RDSClientError): - def __init__(self, db_parameter_group_name): super(DBParameterGroupNotFoundError, self).__init__( - 'DBParameterGroupNotFound', - 'DB Parameter Group {0} not found.'.format(db_parameter_group_name)) + "DBParameterGroupNotFound", + "DB Parameter Group {0} not found.".format(db_parameter_group_name), + ) class OptionGroupNotFoundFaultError(RDSClientError): - def __init__(self, option_group_name): super(OptionGroupNotFoundFaultError, self).__init__( - 'OptionGroupNotFoundFault', - 'Specified OptionGroupName: {0} not found.'.format(option_group_name) + "OptionGroupNotFoundFault", + "Specified OptionGroupName: {0} not found.".format(option_group_name), ) class InvalidDBClusterStateFaultError(RDSClientError): - def __init__(self, database_identifier): super(InvalidDBClusterStateFaultError, self).__init__( - 'InvalidDBClusterStateFault', - 'Invalid DB type, when trying to perform StopDBInstance on {0}e. See AWS RDS documentation on rds.stop_db_instance'.format(database_identifier)) + "InvalidDBClusterStateFault", + "Invalid DB type, when trying to perform StopDBInstance on {0}e. See AWS RDS documentation on rds.stop_db_instance".format( + database_identifier + ), + ) class InvalidDBInstanceStateError(RDSClientError): - def __init__(self, database_identifier, istate): - estate = "in available state" if istate == 'stop' else "stopped, it cannot be started" + estate = ( + "in available state" + if istate == "stop" + else "stopped, it cannot be started" + ) super(InvalidDBInstanceStateError, self).__init__( - 'InvalidDBInstanceState', - 'Instance {} is not {}.'.format(database_identifier, estate)) + "InvalidDBInstanceState", + "Instance {} is not {}.".format(database_identifier, estate), + ) class SnapshotQuotaExceededError(RDSClientError): - def __init__(self): super(SnapshotQuotaExceededError, self).__init__( - 'SnapshotQuotaExceeded', - 'The request cannot be processed because it would exceed the maximum number of snapshots.') + "SnapshotQuotaExceeded", + "The request cannot be processed because it would exceed the maximum number of snapshots.", + ) class DBSnapshotAlreadyExistsError(RDSClientError): - def __init__(self, database_snapshot_identifier): super(DBSnapshotAlreadyExistsError, self).__init__( - 'DBSnapshotAlreadyExists', - 'Cannot create the snapshot because a snapshot with the identifier {} already exists.'.format(database_snapshot_identifier)) + "DBSnapshotAlreadyExists", + "Cannot create the snapshot because a snapshot with the identifier {} already exists.".format( + database_snapshot_identifier + ), + ) diff --git a/moto/rds2/models.py b/moto/rds2/models.py index cd56599e6..686d22ccf 100644 --- a/moto/rds2/models.py +++ b/moto/rds2/models.py @@ -14,39 +14,41 @@ from moto.core import BaseBackend, BaseModel from moto.core.utils import get_random_hex from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.ec2.models import ec2_backends -from .exceptions import (RDSClientError, - DBInstanceNotFoundError, - DBSnapshotNotFoundError, - DBSecurityGroupNotFoundError, - DBSubnetGroupNotFoundError, - DBParameterGroupNotFoundError, - OptionGroupNotFoundFaultError, - InvalidDBClusterStateFaultError, - InvalidDBInstanceStateError, - SnapshotQuotaExceededError, - DBSnapshotAlreadyExistsError) +from .exceptions import ( + RDSClientError, + DBInstanceNotFoundError, + DBSnapshotNotFoundError, + DBSecurityGroupNotFoundError, + DBSubnetGroupNotFoundError, + DBParameterGroupNotFoundError, + OptionGroupNotFoundFaultError, + InvalidDBClusterStateFaultError, + InvalidDBInstanceStateError, + SnapshotQuotaExceededError, + DBSnapshotAlreadyExistsError, +) class Database(BaseModel): - def __init__(self, **kwargs): self.status = "available" self.is_replica = False self.replicas = [] - self.region = kwargs.get('region') + self.region = kwargs.get("region") self.engine = kwargs.get("engine") self.engine_version = kwargs.get("engine_version", None) - self.default_engine_versions = {"MySQL": "5.6.21", - "mysql": "5.6.21", - "oracle-se1": "11.2.0.4.v3", - "oracle-se": "11.2.0.4.v3", - "oracle-ee": "11.2.0.4.v3", - "sqlserver-ee": "11.00.2100.60.v1", - "sqlserver-se": "11.00.2100.60.v1", - "sqlserver-ex": "11.00.2100.60.v1", - "sqlserver-web": "11.00.2100.60.v1", - "postgres": "9.3.3" - } + self.default_engine_versions = { + "MySQL": "5.6.21", + "mysql": "5.6.21", + "oracle-se1": "11.2.0.4.v3", + "oracle-se": "11.2.0.4.v3", + "oracle-ee": "11.2.0.4.v3", + "sqlserver-ee": "11.00.2100.60.v1", + "sqlserver-se": "11.00.2100.60.v1", + "sqlserver-ex": "11.00.2100.60.v1", + "sqlserver-web": "11.00.2100.60.v1", + "postgres": "9.3.3", + } if not self.engine_version and self.engine in self.default_engine_versions: self.engine_version = self.default_engine_versions[self.engine] self.iops = kwargs.get("iops") @@ -58,24 +60,27 @@ class Database(BaseModel): self.storage_type = kwargs.get("storage_type") if self.storage_type is None: self.storage_type = Database.default_storage_type(iops=self.iops) - self.master_username = kwargs.get('master_username') - self.master_user_password = kwargs.get('master_user_password') - self.auto_minor_version_upgrade = kwargs.get( - 'auto_minor_version_upgrade') + self.master_username = kwargs.get("master_username") + self.master_user_password = kwargs.get("master_user_password") + self.auto_minor_version_upgrade = kwargs.get("auto_minor_version_upgrade") if self.auto_minor_version_upgrade is None: self.auto_minor_version_upgrade = True - self.allocated_storage = kwargs.get('allocated_storage') + self.allocated_storage = kwargs.get("allocated_storage") if self.allocated_storage is None: - self.allocated_storage = Database.default_allocated_storage(engine=self.engine, storage_type=self.storage_type) - self.db_instance_identifier = kwargs.get('db_instance_identifier') + self.allocated_storage = Database.default_allocated_storage( + engine=self.engine, storage_type=self.storage_type + ) + self.db_instance_identifier = kwargs.get("db_instance_identifier") self.source_db_identifier = kwargs.get("source_db_identifier") - self.db_instance_class = kwargs.get('db_instance_class') - self.port = kwargs.get('port') + self.db_instance_class = kwargs.get("db_instance_class") + self.port = kwargs.get("port") if self.port is None: self.port = Database.default_port(self.engine) - self.db_instance_identifier = kwargs.get('db_instance_identifier') + self.db_instance_identifier = kwargs.get("db_instance_identifier") self.db_name = kwargs.get("db_name") - self.instance_create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) + self.instance_create_time = iso_8601_datetime_with_milliseconds( + datetime.datetime.now() + ) self.publicly_accessible = kwargs.get("publicly_accessible") if self.publicly_accessible is None: self.publicly_accessible = True @@ -89,39 +94,51 @@ class Database(BaseModel): self.multi_az = kwargs.get("multi_az") self.db_subnet_group_name = kwargs.get("db_subnet_group_name") if self.db_subnet_group_name: - self.db_subnet_group = rds2_backends[ - self.region].describe_subnet_groups(self.db_subnet_group_name)[0] + self.db_subnet_group = rds2_backends[self.region].describe_subnet_groups( + self.db_subnet_group_name + )[0] else: self.db_subnet_group = None - self.security_groups = kwargs.get('security_groups', []) - self.vpc_security_group_ids = kwargs.get('vpc_security_group_ids', []) + self.security_groups = kwargs.get("security_groups", []) + self.vpc_security_group_ids = kwargs.get("vpc_security_group_ids", []) self.preferred_maintenance_window = kwargs.get( - 'preferred_maintenance_window', 'wed:06:38-wed:07:08') - self.db_parameter_group_name = kwargs.get('db_parameter_group_name') - if self.db_parameter_group_name and self.db_parameter_group_name not in rds2_backends[self.region].db_parameter_groups: + "preferred_maintenance_window", "wed:06:38-wed:07:08" + ) + self.db_parameter_group_name = kwargs.get("db_parameter_group_name") + if ( + self.db_parameter_group_name + and self.db_parameter_group_name + not in rds2_backends[self.region].db_parameter_groups + ): raise DBParameterGroupNotFoundError(self.db_parameter_group_name) self.preferred_backup_window = kwargs.get( - 'preferred_backup_window', '13:14-13:44') - self.license_model = kwargs.get('license_model', 'general-public-license') - self.option_group_name = kwargs.get('option_group_name', None) - if self.option_group_name and self.option_group_name not in rds2_backends[self.region].option_groups: + "preferred_backup_window", "13:14-13:44" + ) + self.license_model = kwargs.get("license_model", "general-public-license") + self.option_group_name = kwargs.get("option_group_name", None) + if ( + self.option_group_name + and self.option_group_name not in rds2_backends[self.region].option_groups + ): raise OptionGroupNotFoundFaultError(self.option_group_name) - self.default_option_groups = {"MySQL": "default.mysql5.6", - "mysql": "default.mysql5.6", - "postgres": "default.postgres9.3" - } + self.default_option_groups = { + "MySQL": "default.mysql5.6", + "mysql": "default.mysql5.6", + "postgres": "default.postgres9.3", + } if not self.option_group_name and self.engine in self.default_option_groups: self.option_group_name = self.default_option_groups[self.engine] - self.character_set_name = kwargs.get('character_set_name', None) + self.character_set_name = kwargs.get("character_set_name", None) self.iam_database_authentication_enabled = False self.dbi_resource_id = "db-M5ENSHXFPU6XHZ4G4ZEI5QIO2U" - self.tags = kwargs.get('tags', []) + self.tags = kwargs.get("tags", []) @property def db_instance_arn(self): return "arn:aws:rds:{0}:1234567890:db:{1}".format( - self.region, self.db_instance_identifier) + self.region, self.db_instance_identifier + ) @property def physical_resource_id(self): @@ -129,26 +146,38 @@ class Database(BaseModel): def db_parameter_groups(self): if not self.db_parameter_group_name: - db_family, db_parameter_group_name = self.default_db_parameter_group_details() - description = 'Default parameter group for {0}'.format(db_family) - return [DBParameterGroup(name=db_parameter_group_name, - family=db_family, - description=description, - tags={})] + ( + db_family, + db_parameter_group_name, + ) = self.default_db_parameter_group_details() + description = "Default parameter group for {0}".format(db_family) + return [ + DBParameterGroup( + name=db_parameter_group_name, + family=db_family, + description=description, + tags={}, + ) + ] else: - return [rds2_backends[self.region].db_parameter_groups[self.db_parameter_group_name]] + return [ + rds2_backends[self.region].db_parameter_groups[ + self.db_parameter_group_name + ] + ] def default_db_parameter_group_details(self): if not self.engine_version: return (None, None) - minor_engine_version = '.'.join(self.engine_version.rsplit('.')[:-1]) - db_family = '{0}{1}'.format(self.engine.lower(), minor_engine_version) + minor_engine_version = ".".join(self.engine_version.rsplit(".")[:-1]) + db_family = "{0}{1}".format(self.engine.lower(), minor_engine_version) - return db_family, 'default.{0}'.format(db_family) + return db_family, "default.{0}".format(db_family) def to_xml(self): - template = Template(""" + template = Template( + """ {{ database.backup_retention_period }} {{ database.status }} {% if database.db_name %}{{ database.db_name }}{% endif %} @@ -251,12 +280,15 @@ class Database(BaseModel): {{ database.port }} {{ database.db_instance_arn }} - """) + """ + ) return template.render(database=self) @property def address(self): - return "{0}.aaaaaaaaaa.{1}.rds.amazonaws.com".format(self.db_instance_identifier, self.region) + return "{0}.aaaaaaaaaa.{1}.rds.amazonaws.com".format( + self.db_instance_identifier, self.region + ) def add_replica(self, replica): self.replicas.append(replica.db_instance_identifier) @@ -274,119 +306,73 @@ class Database(BaseModel): setattr(self, key, value) def get_cfn_attribute(self, attribute_name): - if attribute_name == 'Endpoint.Address': + if attribute_name == "Endpoint.Address": return self.address - elif attribute_name == 'Endpoint.Port': + elif attribute_name == "Endpoint.Port": return self.port raise UnformattedGetAttTemplateException() @staticmethod def default_port(engine): return { - 'mysql': 3306, - 'mariadb': 3306, - 'postgres': 5432, - 'oracle-ee': 1521, - 'oracle-se2': 1521, - 'oracle-se1': 1521, - 'oracle-se': 1521, - 'sqlserver-ee': 1433, - 'sqlserver-ex': 1433, - 'sqlserver-se': 1433, - 'sqlserver-web': 1433, + "mysql": 3306, + "mariadb": 3306, + "postgres": 5432, + "oracle-ee": 1521, + "oracle-se2": 1521, + "oracle-se1": 1521, + "oracle-se": 1521, + "sqlserver-ee": 1433, + "sqlserver-ex": 1433, + "sqlserver-se": 1433, + "sqlserver-web": 1433, }[engine] @staticmethod def default_storage_type(iops): if iops is None: - return 'gp2' + return "gp2" else: - return 'io1' + return "io1" @staticmethod def default_allocated_storage(engine, storage_type): return { - 'aurora': { - 'gp2': 0, - 'io1': 0, - 'standard': 0, - }, - 'mysql': { - 'gp2': 20, - 'io1': 100, - 'standard': 5, - }, - 'mariadb': { - 'gp2': 20, - 'io1': 100, - 'standard': 5, - }, - 'postgres': { - 'gp2': 20, - 'io1': 100, - 'standard': 5, - }, - 'oracle-ee': { - 'gp2': 20, - 'io1': 100, - 'standard': 10, - }, - 'oracle-se2': { - 'gp2': 20, - 'io1': 100, - 'standard': 10, - }, - 'oracle-se1': { - 'gp2': 20, - 'io1': 100, - 'standard': 10, - }, - 'oracle-se': { - 'gp2': 20, - 'io1': 100, - 'standard': 10, - }, - 'sqlserver-ee': { - 'gp2': 200, - 'io1': 200, - 'standard': 200, - }, - 'sqlserver-ex': { - 'gp2': 20, - 'io1': 100, - 'standard': 20, - }, - 'sqlserver-se': { - 'gp2': 200, - 'io1': 200, - 'standard': 200, - }, - 'sqlserver-web': { - 'gp2': 20, - 'io1': 100, - 'standard': 20, - }, + "aurora": {"gp2": 0, "io1": 0, "standard": 0}, + "mysql": {"gp2": 20, "io1": 100, "standard": 5}, + "mariadb": {"gp2": 20, "io1": 100, "standard": 5}, + "postgres": {"gp2": 20, "io1": 100, "standard": 5}, + "oracle-ee": {"gp2": 20, "io1": 100, "standard": 10}, + "oracle-se2": {"gp2": 20, "io1": 100, "standard": 10}, + "oracle-se1": {"gp2": 20, "io1": 100, "standard": 10}, + "oracle-se": {"gp2": 20, "io1": 100, "standard": 10}, + "sqlserver-ee": {"gp2": 200, "io1": 200, "standard": 200}, + "sqlserver-ex": {"gp2": 20, "io1": 100, "standard": 20}, + "sqlserver-se": {"gp2": 200, "io1": 200, "standard": 200}, + "sqlserver-web": {"gp2": 20, "io1": 100, "standard": 20}, }[engine][storage_type] @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] - db_instance_identifier = properties.get('DBInstanceIdentifier') + db_instance_identifier = properties.get("DBInstanceIdentifier") if not db_instance_identifier: db_instance_identifier = resource_name.lower() + get_random_hex(12) - db_security_groups = properties.get('DBSecurityGroups') + db_security_groups = properties.get("DBSecurityGroups") if not db_security_groups: db_security_groups = [] security_groups = [group.group_name for group in db_security_groups] db_subnet_group = properties.get("DBSubnetGroupName") db_subnet_group_name = db_subnet_group.subnet_name if db_subnet_group else None db_kwargs = { - "auto_minor_version_upgrade": properties.get('AutoMinorVersionUpgrade'), - "allocated_storage": properties.get('AllocatedStorage'), + "auto_minor_version_upgrade": properties.get("AutoMinorVersionUpgrade"), + "allocated_storage": properties.get("AllocatedStorage"), "availability_zone": properties.get("AvailabilityZone"), "backup_retention_period": properties.get("BackupRetentionPeriod"), - "db_instance_class": properties.get('DBInstanceClass'), + "db_instance_class": properties.get("DBInstanceClass"), "db_instance_identifier": db_instance_identifier, "db_name": properties.get("DBName"), "db_subnet_group_name": db_subnet_group_name, @@ -394,11 +380,11 @@ class Database(BaseModel): "engine_version": properties.get("EngineVersion"), "iops": properties.get("Iops"), "kms_key_id": properties.get("KmsKeyId"), - "master_user_password": properties.get('MasterUserPassword'), - "master_username": properties.get('MasterUsername'), + "master_user_password": properties.get("MasterUserPassword"), + "master_username": properties.get("MasterUsername"), "multi_az": properties.get("MultiAZ"), - "db_parameter_group_name": properties.get('DBParameterGroupName'), - "port": properties.get('Port', 3306), + "db_parameter_group_name": properties.get("DBParameterGroupName"), + "port": properties.get("Port", 3306), "publicly_accessible": properties.get("PubliclyAccessible"), "copy_tags_to_snapshot": properties.get("CopyTagsToSnapshot"), "region": region_name, @@ -406,7 +392,7 @@ class Database(BaseModel): "storage_encrypted": properties.get("StorageEncrypted"), "storage_type": properties.get("StorageType"), "tags": properties.get("Tags"), - "vpc_security_group_ids": properties.get('VpcSecurityGroupIds', []), + "vpc_security_group_ids": properties.get("VpcSecurityGroupIds", []), } rds2_backend = rds2_backends[region_name] @@ -420,7 +406,8 @@ class Database(BaseModel): return database def to_json(self): - template = Template("""{ + template = Template( + """{ "AllocatedStorage": 10, "AutoMinorVersionUpgrade": "{{ database.auto_minor_version_upgrade }}", "AvailabilityZone": "{{ database.availability_zone }}", @@ -489,22 +476,21 @@ class Database(BaseModel): {% endfor %} ], "DBInstanceArn": "{{ database.db_instance_arn }}" - }""") + }""" + ) return template.render(database=self) def get_tags(self): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] def delete(self, region_name): backend = rds2_backends[region_name] @@ -520,10 +506,13 @@ class Snapshot(BaseModel): @property def snapshot_arn(self): - return "arn:aws:rds:{0}:1234567890:snapshot:{1}".format(self.database.region, self.snapshot_id) + return "arn:aws:rds:{0}:1234567890:snapshot:{1}".format( + self.database.region, self.snapshot_id + ) def to_xml(self): - template = Template(""" + template = Template( + """ {{ snapshot.snapshot_id }} {{ database.db_instance_identifier }} {{ snapshot.created_at }} @@ -554,26 +543,24 @@ class Snapshot(BaseModel): {{ snapshot.snapshot_arn }} false - """) + """ + ) return template.render(snapshot=self, database=self.database) def get_tags(self): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] class SecurityGroup(BaseModel): - def __init__(self, group_name, description, tags): self.group_name = group_name self.description = description @@ -581,11 +568,12 @@ class SecurityGroup(BaseModel): self.ip_ranges = [] self.ec2_security_groups = [] self.tags = tags - self.owner_id = '1234567890' + self.owner_id = "1234567890" self.vpc_id = None def to_xml(self): - template = Template(""" + template = Template( + """ {% for security_group in security_group.ec2_security_groups %} @@ -608,11 +596,13 @@ class SecurityGroup(BaseModel): {{ security_group.ownder_id }} {{ security_group.group_name }} - """) + """ + ) return template.render(security_group=self) def to_json(self): - template = Template("""{ + template = Template( + """{ "DBSecurityGroupDescription": "{{ security_group.description }}", "DBSecurityGroupName": "{{ security_group.group_name }}", "EC2SecurityGroups": {{ security_group.ec2_security_groups }}, @@ -623,7 +613,8 @@ class SecurityGroup(BaseModel): ], "OwnerId": "{{ security_group.owner_id }}", "VpcId": "{{ security_group.vpc_id }}" - }""") + }""" + ) return template.render(security_group=self) def authorize_cidr(self, cidr_ip): @@ -633,32 +624,29 @@ class SecurityGroup(BaseModel): self.ec2_security_groups.append(security_group) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] group_name = resource_name.lower() + get_random_hex(12) - description = properties['GroupDescription'] - security_group_ingress_rules = properties.get( - 'DBSecurityGroupIngress', []) - tags = properties.get('Tags') + description = properties["GroupDescription"] + security_group_ingress_rules = properties.get("DBSecurityGroupIngress", []) + tags = properties.get("Tags") ec2_backend = ec2_backends[region_name] rds2_backend = rds2_backends[region_name] security_group = rds2_backend.create_security_group( - group_name, - description, - tags, + group_name, description, tags ) for security_group_ingress in security_group_ingress_rules: for ingress_type, ingress_value in security_group_ingress.items(): if ingress_type == "CIDRIP": security_group.authorize_cidr(ingress_value) elif ingress_type == "EC2SecurityGroupName": - subnet = ec2_backend.get_security_group_from_name( - ingress_value) + subnet = ec2_backend.get_security_group_from_name(ingress_value) security_group.authorize_security_group(subnet) elif ingress_type == "EC2SecurityGroupId": - subnet = ec2_backend.get_security_group_from_id( - ingress_value) + subnet = ec2_backend.get_security_group_from_id(ingress_value) security_group.authorize_security_group(subnet) return security_group @@ -666,15 +654,13 @@ class SecurityGroup(BaseModel): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] def delete(self, region_name): backend = rds2_backends[region_name] @@ -682,7 +668,6 @@ class SecurityGroup(BaseModel): class SubnetGroup(BaseModel): - def __init__(self, subnet_name, description, subnets, tags): self.subnet_name = subnet_name self.description = description @@ -692,7 +677,8 @@ class SubnetGroup(BaseModel): self.vpc_id = self.subnets[0].vpc_id def to_xml(self): - template = Template(""" + template = Template( + """ {{ subnet_group.vpc_id }} {{ subnet_group.status }} {{ subnet_group.description }} @@ -709,11 +695,13 @@ class SubnetGroup(BaseModel): {% endfor %} - """) + """ + ) return template.render(subnet_group=self) def to_json(self): - template = Template(""""DBSubnetGroup": { + template = Template( + """"DBSubnetGroup": { "VpcId": "{{ subnet_group.vpc_id }}", "SubnetGroupStatus": "{{ subnet_group.status }}", "DBSubnetGroupDescription": "{{ subnet_group.description }}", @@ -730,27 +718,26 @@ class SubnetGroup(BaseModel): }{%- if not loop.last -%},{%- endif -%}{% endfor %} ] } - }""") + }""" + ) return template.render(subnet_group=self) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] subnet_name = resource_name.lower() + get_random_hex(12) - description = properties['DBSubnetGroupDescription'] - subnet_ids = properties['SubnetIds'] - tags = properties.get('Tags') + description = properties["DBSubnetGroupDescription"] + subnet_ids = properties["SubnetIds"] + tags = properties.get("Tags") ec2_backend = ec2_backends[region_name] - subnets = [ec2_backend.get_subnet(subnet_id) - for subnet_id in subnet_ids] + subnets = [ec2_backend.get_subnet(subnet_id) for subnet_id in subnet_ids] rds2_backend = rds2_backends[region_name] subnet_group = rds2_backend.create_subnet_group( - subnet_name, - description, - subnets, - tags, + subnet_name, description, subnets, tags ) return subnet_group @@ -758,15 +745,13 @@ class SubnetGroup(BaseModel): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] def delete(self, region_name): backend = rds2_backends[region_name] @@ -774,11 +759,11 @@ class SubnetGroup(BaseModel): class RDS2Backend(BaseBackend): - def __init__(self, region): self.region = region self.arn_regex = re_compile( - r'^arn:aws:rds:.*:[0-9]*:(db|es|og|pg|ri|secgrp|snapshot|subgrp):.*$') + r"^arn:aws:rds:.*:[0-9]*:(db|es|og|pg|ri|secgrp|snapshot|subgrp):.*$" + ) self.databases = OrderedDict() self.snapshots = OrderedDict() self.db_parameter_groups = {} @@ -793,18 +778,20 @@ class RDS2Backend(BaseBackend): self.__init__(region) def create_database(self, db_kwargs): - database_id = db_kwargs['db_instance_identifier'] + database_id = db_kwargs["db_instance_identifier"] database = Database(**db_kwargs) self.databases[database_id] = database return database - def create_snapshot(self, db_instance_identifier, db_snapshot_identifier, tags=None): + def create_snapshot( + self, db_instance_identifier, db_snapshot_identifier, tags=None + ): database = self.databases.get(db_instance_identifier) if not database: raise DBInstanceNotFoundError(db_instance_identifier) if db_snapshot_identifier in self.snapshots: raise DBSnapshotAlreadyExistsError(db_snapshot_identifier) - if len(self.snapshots) >= int(os.environ.get('MOTO_RDS_SNAPSHOT_LIMIT', '100')): + if len(self.snapshots) >= int(os.environ.get("MOTO_RDS_SNAPSHOT_LIMIT", "100")): raise SnapshotQuotaExceededError() if tags is None: tags = list() @@ -821,11 +808,11 @@ class RDS2Backend(BaseBackend): return self.snapshots.pop(db_snapshot_identifier) def create_database_replica(self, db_kwargs): - database_id = db_kwargs['db_instance_identifier'] - source_database_id = db_kwargs['source_db_identifier'] + database_id = db_kwargs["db_instance_identifier"] + source_database_id = db_kwargs["source_db_identifier"] primary = self.find_db_from_id(source_database_id) if self.arn_regex.match(source_database_id): - db_kwargs['region'] = self.region + db_kwargs["region"] = self.region # Shouldn't really copy here as the instance is duplicated. RDS replicas have different instances. replica = copy.copy(primary) @@ -860,9 +847,11 @@ class RDS2Backend(BaseBackend): def modify_database(self, db_instance_identifier, db_kwargs): database = self.describe_databases(db_instance_identifier)[0] - if 'new_db_instance_identifier' in db_kwargs: + if "new_db_instance_identifier" in db_kwargs: del self.databases[db_instance_identifier] - db_instance_identifier = db_kwargs['db_instance_identifier'] = db_kwargs.pop('new_db_instance_identifier') + db_instance_identifier = db_kwargs[ + "db_instance_identifier" + ] = db_kwargs.pop("new_db_instance_identifier") self.databases[db_instance_identifier] = database database.update(db_kwargs) return database @@ -875,26 +864,26 @@ class RDS2Backend(BaseBackend): database = self.describe_databases(db_instance_identifier)[0] # todo: certain rds types not allowed to be stopped at this time. if database.is_replica or database.multi_az: - # todo: more db types not supported by stop/start instance api - raise InvalidDBClusterStateFaultError(db_instance_identifier) - if database.status != 'available': - raise InvalidDBInstanceStateError(db_instance_identifier, 'stop') + # todo: more db types not supported by stop/start instance api + raise InvalidDBClusterStateFaultError(db_instance_identifier) + if database.status != "available": + raise InvalidDBInstanceStateError(db_instance_identifier, "stop") if db_snapshot_identifier: self.create_snapshot(db_instance_identifier, db_snapshot_identifier) - database.status = 'stopped' + database.status = "stopped" return database def start_database(self, db_instance_identifier): database = self.describe_databases(db_instance_identifier)[0] # todo: bunch of different error messages to be generated from this api call - if database.status != 'stopped': - raise InvalidDBInstanceStateError(db_instance_identifier, 'start') - database.status = 'available' + if database.status != "stopped": + raise InvalidDBInstanceStateError(db_instance_identifier, "start") + database.status = "available" return database def find_db_from_id(self, db_id): if self.arn_regex.match(db_id): - arn_breakdown = db_id.split(':') + arn_breakdown = db_id.split(":") region = arn_breakdown[3] backend = rds2_backends[region] db_name = arn_breakdown[-1] @@ -912,7 +901,7 @@ class RDS2Backend(BaseBackend): if database.is_replica: primary = self.find_db_from_id(database.source_db_identifier) primary.remove_replica(database) - database.status = 'deleting' + database.status = "deleting" return database else: raise DBInstanceNotFoundError(db_instance_identifier) @@ -967,34 +956,49 @@ class RDS2Backend(BaseBackend): raise DBSubnetGroupNotFoundError(subnet_name) def create_option_group(self, option_group_kwargs): - option_group_id = option_group_kwargs['name'] - valid_option_group_engines = {'mariadb': ['10.0', '10.1', '10.2', '10.3'], - 'mysql': ['5.5', '5.6', '5.7', '8.0'], - 'oracle-se2': ['11.2', '12.1', '12.2'], - 'oracle-se1': ['11.2', '12.1', '12.2'], - 'oracle-se': ['11.2', '12.1', '12.2'], - 'oracle-ee': ['11.2', '12.1', '12.2'], - 'sqlserver-se': ['10.50', '11.00'], - 'sqlserver-ee': ['10.50', '11.00'], - 'sqlserver-ex': ['10.50', '11.00'], - 'sqlserver-web': ['10.50', '11.00']} - if option_group_kwargs['name'] in self.option_groups: - raise RDSClientError('OptionGroupAlreadyExistsFault', - 'An option group named {0} already exists.'.format(option_group_kwargs['name'])) - if 'description' not in option_group_kwargs or not option_group_kwargs['description']: - raise RDSClientError('InvalidParameterValue', - 'The parameter OptionGroupDescription must be provided and must not be blank.') - if option_group_kwargs['engine_name'] not in valid_option_group_engines.keys(): - raise RDSClientError('InvalidParameterValue', - 'Invalid DB engine: non-existant') - if option_group_kwargs['major_engine_version'] not in\ - valid_option_group_engines[option_group_kwargs['engine_name']]: - raise RDSClientError('InvalidParameterCombination', - 'Cannot find major version {0} for {1}'.format( - option_group_kwargs[ - 'major_engine_version'], - option_group_kwargs['engine_name'] - )) + option_group_id = option_group_kwargs["name"] + valid_option_group_engines = { + "mariadb": ["10.0", "10.1", "10.2", "10.3"], + "mysql": ["5.5", "5.6", "5.7", "8.0"], + "oracle-se2": ["11.2", "12.1", "12.2"], + "oracle-se1": ["11.2", "12.1", "12.2"], + "oracle-se": ["11.2", "12.1", "12.2"], + "oracle-ee": ["11.2", "12.1", "12.2"], + "sqlserver-se": ["10.50", "11.00"], + "sqlserver-ee": ["10.50", "11.00"], + "sqlserver-ex": ["10.50", "11.00"], + "sqlserver-web": ["10.50", "11.00"], + } + if option_group_kwargs["name"] in self.option_groups: + raise RDSClientError( + "OptionGroupAlreadyExistsFault", + "An option group named {0} already exists.".format( + option_group_kwargs["name"] + ), + ) + if ( + "description" not in option_group_kwargs + or not option_group_kwargs["description"] + ): + raise RDSClientError( + "InvalidParameterValue", + "The parameter OptionGroupDescription must be provided and must not be blank.", + ) + if option_group_kwargs["engine_name"] not in valid_option_group_engines.keys(): + raise RDSClientError( + "InvalidParameterValue", "Invalid DB engine: non-existant" + ) + if ( + option_group_kwargs["major_engine_version"] + not in valid_option_group_engines[option_group_kwargs["engine_name"]] + ): + raise RDSClientError( + "InvalidParameterCombination", + "Cannot find major version {0} for {1}".format( + option_group_kwargs["major_engine_version"], + option_group_kwargs["engine_name"], + ), + ) option_group = OptionGroup(**option_group_kwargs) self.option_groups[option_group_id] = option_group return option_group @@ -1008,82 +1012,129 @@ class RDS2Backend(BaseBackend): def describe_option_groups(self, option_group_kwargs): option_group_list = [] - if option_group_kwargs['marker']: - marker = option_group_kwargs['marker'] + if option_group_kwargs["marker"]: + marker = option_group_kwargs["marker"] else: marker = 0 - if option_group_kwargs['max_records']: - if option_group_kwargs['max_records'] < 20 or option_group_kwargs['max_records'] > 100: - raise RDSClientError('InvalidParameterValue', - 'Invalid value for max records. Must be between 20 and 100') - max_records = option_group_kwargs['max_records'] + if option_group_kwargs["max_records"]: + if ( + option_group_kwargs["max_records"] < 20 + or option_group_kwargs["max_records"] > 100 + ): + raise RDSClientError( + "InvalidParameterValue", + "Invalid value for max records. Must be between 20 and 100", + ) + max_records = option_group_kwargs["max_records"] else: max_records = 100 for option_group_name, option_group in self.option_groups.items(): - if option_group_kwargs['name'] and option_group.name != option_group_kwargs['name']: + if ( + option_group_kwargs["name"] + and option_group.name != option_group_kwargs["name"] + ): continue - elif option_group_kwargs['engine_name'] and \ - option_group.engine_name != option_group_kwargs['engine_name']: + elif ( + option_group_kwargs["engine_name"] + and option_group.engine_name != option_group_kwargs["engine_name"] + ): continue - elif option_group_kwargs['major_engine_version'] and \ - option_group.major_engine_version != option_group_kwargs['major_engine_version']: + elif ( + option_group_kwargs["major_engine_version"] + and option_group.major_engine_version + != option_group_kwargs["major_engine_version"] + ): continue else: option_group_list.append(option_group) if not len(option_group_list): - raise OptionGroupNotFoundFaultError(option_group_kwargs['name']) - return option_group_list[marker:max_records + marker] + raise OptionGroupNotFoundFaultError(option_group_kwargs["name"]) + return option_group_list[marker : max_records + marker] @staticmethod def describe_option_group_options(engine_name, major_engine_version=None): - default_option_group_options = {'mysql': {'5.6': '\n \n \n \n 5.611211TrueInnodb Memcached for MySQLMEMCACHED1-4294967295STATIC1TrueSpecifies how many memcached read operations (get) to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_R_BATCH_SIZE1-4294967295STATIC1TrueSpecifies how many memcached write operations, such as add, set, or incr, to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_W_BATCH_SIZE1-1073741824DYNAMIC5TrueSpecifies how often to auto-commit idle connections that use the InnoDB memcached interface.INNODB_API_BK_COMMIT_INTERVAL0,1STATIC0TrueDisables the use of row locks when using the InnoDB memcached interface.INNODB_API_DISABLE_ROWLOCK0,1STATIC0TrueLocks the table used by the InnoDB memcached plugin, so that it cannot be dropped or altered by DDL through the SQL interface.INNODB_API_ENABLE_MDL0-3STATIC0TrueLets you control the transaction isolation level on queries processed by the memcached interface.INNODB_API_TRX_LEVELauto,ascii,binarySTATICautoTrueThe binding protocol to use which can be either auto, ascii, or binary. The default is auto which means the server automatically negotiates the protocol with the client.BINDING_PROTOCOL1-2048STATIC1024TrueThe backlog queue configures how many network connections can be waiting to be processed by memcachedBACKLOG_QUEUE_LIMIT0,1STATIC0TrueDisable the use of compare and swap (CAS) which reduces the per-item size by 8 bytes.CAS_DISABLED1-48STATIC48TrueMinimum chunk size in bytes to allocate for the smallest item\'s key, value, and flags. The default is 48 and you can get a significant memory efficiency gain with a lower value.CHUNK_SIZE1-2STATIC1.25TrueChunk size growth factor that controls the size of each successive chunk with each chunk growing times this amount larger than the previous chunk.CHUNK_SIZE_GROWTH_FACTOR0,1STATIC0TrueIf enabled when there is no more memory to store items, memcached will return an error rather than evicting items.ERROR_ON_MEMORY_EXHAUSTED10-1024STATIC1024TrueMaximum number of concurrent connections. Setting this value to anything less than 10 prevents MySQL from starting.MAX_SIMULTANEOUS_CONNECTIONSv,vv,vvvSTATICvTrueVerbose level for memcached.VERBOSITYmysql\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - 'all': '\n \n \n \n 5.611211TrueInnodb Memcached for MySQLMEMCACHED1-4294967295STATIC1TrueSpecifies how many memcached read operations (get) to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_R_BATCH_SIZE1-4294967295STATIC1TrueSpecifies how many memcached write operations, such as add, set, or incr, to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_W_BATCH_SIZE1-1073741824DYNAMIC5TrueSpecifies how often to auto-commit idle connections that use the InnoDB memcached interface.INNODB_API_BK_COMMIT_INTERVAL0,1STATIC0TrueDisables the use of row locks when using the InnoDB memcached interface.INNODB_API_DISABLE_ROWLOCK0,1STATIC0TrueLocks the table used by the InnoDB memcached plugin, so that it cannot be dropped or altered by DDL through the SQL interface.INNODB_API_ENABLE_MDL0-3STATIC0TrueLets you control the transaction isolation level on queries processed by the memcached interface.INNODB_API_TRX_LEVELauto,ascii,binarySTATICautoTrueThe binding protocol to use which can be either auto, ascii, or binary. The default is auto which means the server automatically negotiates the protocol with the client.BINDING_PROTOCOL1-2048STATIC1024TrueThe backlog queue configures how many network connections can be waiting to be processed by memcachedBACKLOG_QUEUE_LIMIT0,1STATIC0TrueDisable the use of compare and swap (CAS) which reduces the per-item size by 8 bytes.CAS_DISABLED1-48STATIC48TrueMinimum chunk size in bytes to allocate for the smallest item\'s key, value, and flags. The default is 48 and you can get a significant memory efficiency gain with a lower value.CHUNK_SIZE1-2STATIC1.25TrueChunk size growth factor that controls the size of each successive chunk with each chunk growing times this amount larger than the previous chunk.CHUNK_SIZE_GROWTH_FACTOR0,1STATIC0TrueIf enabled when there is no more memory to store items, memcached will return an error rather than evicting items.ERROR_ON_MEMORY_EXHAUSTED10-1024STATIC1024TrueMaximum number of concurrent connections. Setting this value to anything less than 10 prevents MySQL from starting.MAX_SIMULTANEOUS_CONNECTIONSv,vv,vvvSTATICvTrueVerbose level for memcached.VERBOSITYmysql\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n'}, - 'oracle-ee': {'11.2': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - 'all': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n'}, - 'oracle-sa': {'11.2': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - 'all': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n'}, - 'oracle-sa1': {'11.2': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - 'all': '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n'}, - 'sqlserver-ee': {'10.50': '\n \n \n \n 10.50SQLServer Database MirroringMirroringsqlserver-ee\n \n 10.50TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - '11.00': '\n \n \n \n 11.00SQLServer Database MirroringMirroringsqlserver-ee\n \n 11.00TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', - 'all': '\n \n \n \n 10.50SQLServer Database MirroringMirroringsqlserver-ee\n \n 10.50TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n 11.00SQLServer Database MirroringMirroringsqlserver-ee\n \n 11.00TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n'}} + default_option_group_options = { + "mysql": { + "5.6": '\n \n \n \n 5.611211TrueInnodb Memcached for MySQLMEMCACHED1-4294967295STATIC1TrueSpecifies how many memcached read operations (get) to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_R_BATCH_SIZE1-4294967295STATIC1TrueSpecifies how many memcached write operations, such as add, set, or incr, to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_W_BATCH_SIZE1-1073741824DYNAMIC5TrueSpecifies how often to auto-commit idle connections that use the InnoDB memcached interface.INNODB_API_BK_COMMIT_INTERVAL0,1STATIC0TrueDisables the use of row locks when using the InnoDB memcached interface.INNODB_API_DISABLE_ROWLOCK0,1STATIC0TrueLocks the table used by the InnoDB memcached plugin, so that it cannot be dropped or altered by DDL through the SQL interface.INNODB_API_ENABLE_MDL0-3STATIC0TrueLets you control the transaction isolation level on queries processed by the memcached interface.INNODB_API_TRX_LEVELauto,ascii,binarySTATICautoTrueThe binding protocol to use which can be either auto, ascii, or binary. The default is auto which means the server automatically negotiates the protocol with the client.BINDING_PROTOCOL1-2048STATIC1024TrueThe backlog queue configures how many network connections can be waiting to be processed by memcachedBACKLOG_QUEUE_LIMIT0,1STATIC0TrueDisable the use of compare and swap (CAS) which reduces the per-item size by 8 bytes.CAS_DISABLED1-48STATIC48TrueMinimum chunk size in bytes to allocate for the smallest item\'s key, value, and flags. The default is 48 and you can get a significant memory efficiency gain with a lower value.CHUNK_SIZE1-2STATIC1.25TrueChunk size growth factor that controls the size of each successive chunk with each chunk growing times this amount larger than the previous chunk.CHUNK_SIZE_GROWTH_FACTOR0,1STATIC0TrueIf enabled when there is no more memory to store items, memcached will return an error rather than evicting items.ERROR_ON_MEMORY_EXHAUSTED10-1024STATIC1024TrueMaximum number of concurrent connections. Setting this value to anything less than 10 prevents MySQL from starting.MAX_SIMULTANEOUS_CONNECTIONSv,vv,vvvSTATICvTrueVerbose level for memcached.VERBOSITYmysql\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "all": '\n \n \n \n 5.611211TrueInnodb Memcached for MySQLMEMCACHED1-4294967295STATIC1TrueSpecifies how many memcached read operations (get) to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_R_BATCH_SIZE1-4294967295STATIC1TrueSpecifies how many memcached write operations, such as add, set, or incr, to perform before doing a COMMIT to start a new transactionDAEMON_MEMCACHED_W_BATCH_SIZE1-1073741824DYNAMIC5TrueSpecifies how often to auto-commit idle connections that use the InnoDB memcached interface.INNODB_API_BK_COMMIT_INTERVAL0,1STATIC0TrueDisables the use of row locks when using the InnoDB memcached interface.INNODB_API_DISABLE_ROWLOCK0,1STATIC0TrueLocks the table used by the InnoDB memcached plugin, so that it cannot be dropped or altered by DDL through the SQL interface.INNODB_API_ENABLE_MDL0-3STATIC0TrueLets you control the transaction isolation level on queries processed by the memcached interface.INNODB_API_TRX_LEVELauto,ascii,binarySTATICautoTrueThe binding protocol to use which can be either auto, ascii, or binary. The default is auto which means the server automatically negotiates the protocol with the client.BINDING_PROTOCOL1-2048STATIC1024TrueThe backlog queue configures how many network connections can be waiting to be processed by memcachedBACKLOG_QUEUE_LIMIT0,1STATIC0TrueDisable the use of compare and swap (CAS) which reduces the per-item size by 8 bytes.CAS_DISABLED1-48STATIC48TrueMinimum chunk size in bytes to allocate for the smallest item\'s key, value, and flags. The default is 48 and you can get a significant memory efficiency gain with a lower value.CHUNK_SIZE1-2STATIC1.25TrueChunk size growth factor that controls the size of each successive chunk with each chunk growing times this amount larger than the previous chunk.CHUNK_SIZE_GROWTH_FACTOR0,1STATIC0TrueIf enabled when there is no more memory to store items, memcached will return an error rather than evicting items.ERROR_ON_MEMORY_EXHAUSTED10-1024STATIC1024TrueMaximum number of concurrent connections. Setting this value to anything less than 10 prevents MySQL from starting.MAX_SIMULTANEOUS_CONNECTIONSv,vv,vvvSTATICvTrueVerbose level for memcached.VERBOSITYmysql\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + }, + "oracle-ee": { + "11.2": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "all": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + }, + "oracle-sa": { + "11.2": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "all": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + }, + "oracle-sa1": { + "11.2": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "all": '\n \n \n \n 11.2XMLDBOracle Application Express Runtime EnvironmentAPEXoracle-ee\n \n 11.2APEXOracle Application Express Development EnvironmentAPEX-DEVoracle-ee\n \n 11.2Oracle Advanced Security - Native Network EncryptionNATIVE_NETWORK_ENCRYPTIONACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired encryption behaviorSQLNET.ENCRYPTION_SERVERACCEPTED,REJECTED,REQUESTED,REQUIREDSTATICREQUESTEDTrueSpecifies the desired data integrity behaviorSQLNET.CRYPTO_CHECKSUM_SERVERRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40STATICRC4_256,AES256,AES192,3DES168,RC4_128,AES128,3DES112,RC4_56,DES,RC4_40,DES40TrueSpecifies list of encryption algorithms in order of intended useSQLNET.ENCRYPTION_TYPES_SERVERSHA1,MD5STATICSHA1,MD5TrueSpecifies list of checksumming algorithms in order of intended useSQLNET.CRYPTO_CHECKSUM_TYPES_SERVERoracle-ee\n \n 11.21158TrueOracle Enterprise Manager (Database Control only)OEMoracle-ee\n \n 11.2Oracle StatspackSTATSPACKoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - Transparent Data EncryptionTDEoracle-ee\n \n 11.2TrueTrueOracle Advanced Security - TDE with HSMTDE_HSMoracle-ee\n \n 11.2TrueTrueChange time zoneTimezoneAfrica/Cairo,Africa/Casablanca,Africa/Harare,Africa/Monrovia,Africa/Nairobi,Africa/Tripoli,Africa/Windhoek,America/Araguaina,America/Asuncion,America/Bogota,America/Caracas,America/Chihuahua,America/Cuiaba,America/Denver,America/Fortaleza,America/Guatemala,America/Halifax,America/Manaus,America/Matamoros,America/Monterrey,America/Montevideo,America/Phoenix,America/Santiago,America/Tijuana,Asia/Amman,Asia/Ashgabat,Asia/Baghdad,Asia/Baku,Asia/Bangkok,Asia/Beirut,Asia/Calcutta,Asia/Damascus,Asia/Dhaka,Asia/Irkutsk,Asia/Jerusalem,Asia/Kabul,Asia/Karachi,Asia/Kathmandu,Asia/Krasnoyarsk,Asia/Magadan,Asia/Muscat,Asia/Novosibirsk,Asia/Riyadh,Asia/Seoul,Asia/Shanghai,Asia/Singapore,Asia/Taipei,Asia/Tehran,Asia/Tokyo,Asia/Ulaanbaatar,Asia/Vladivostok,Asia/Yakutsk,Asia/Yerevan,Atlantic/Azores,Australia/Adelaide,Australia/Brisbane,Australia/Darwin,Australia/Hobart,Australia/Perth,Australia/Sydney,Brazil/East,Canada/Newfoundland,Canada/Saskatchewan,Europe/Amsterdam,Europe/Athens,Europe/Dublin,Europe/Helsinki,Europe/Istanbul,Europe/Kaliningrad,Europe/Moscow,Europe/Paris,Europe/Prague,Europe/Sarajevo,Pacific/Auckland,Pacific/Fiji,Pacific/Guam,Pacific/Honolulu,Pacific/Samoa,US/Alaska,US/Central,US/Eastern,US/East-Indiana,US/Pacific,UTCDYNAMICUTCTrueSpecifies the timezone the user wants to change the system time toTIME_ZONEoracle-ee\n \n 11.2Oracle XMLDB RepositoryXMLDBoracle-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + }, + "sqlserver-ee": { + "10.50": '\n \n \n \n 10.50SQLServer Database MirroringMirroringsqlserver-ee\n \n 10.50TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "11.00": '\n \n \n \n 11.00SQLServer Database MirroringMirroringsqlserver-ee\n \n 11.00TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + "all": '\n \n \n \n 10.50SQLServer Database MirroringMirroringsqlserver-ee\n \n 10.50TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n 11.00SQLServer Database MirroringMirroringsqlserver-ee\n \n 11.00TrueSQL Server - Transparent Data EncryptionTDEsqlserver-ee\n \n \n \n \n 457f7bb8-9fbf-11e4-9084-5754f80d5144\n \n', + }, + } if engine_name not in default_option_group_options: - raise RDSClientError('InvalidParameterValue', - 'Invalid DB engine: {0}'.format(engine_name)) - if major_engine_version and major_engine_version not in default_option_group_options[engine_name]: - raise RDSClientError('InvalidParameterCombination', - 'Cannot find major version {0} for {1}'.format(major_engine_version, engine_name)) + raise RDSClientError( + "InvalidParameterValue", "Invalid DB engine: {0}".format(engine_name) + ) + if ( + major_engine_version + and major_engine_version not in default_option_group_options[engine_name] + ): + raise RDSClientError( + "InvalidParameterCombination", + "Cannot find major version {0} for {1}".format( + major_engine_version, engine_name + ), + ) if major_engine_version: return default_option_group_options[engine_name][major_engine_version] - return default_option_group_options[engine_name]['all'] + return default_option_group_options[engine_name]["all"] - def modify_option_group(self, option_group_name, options_to_include=None, options_to_remove=None, apply_immediately=None): + def modify_option_group( + self, + option_group_name, + options_to_include=None, + options_to_remove=None, + apply_immediately=None, + ): if option_group_name not in self.option_groups: raise OptionGroupNotFoundFaultError(option_group_name) if not options_to_include and not options_to_remove: - raise RDSClientError('InvalidParameterValue', - 'At least one option must be added, modified, or removed.') + raise RDSClientError( + "InvalidParameterValue", + "At least one option must be added, modified, or removed.", + ) if options_to_remove: - self.option_groups[option_group_name].remove_options( - options_to_remove) + self.option_groups[option_group_name].remove_options(options_to_remove) if options_to_include: - self.option_groups[option_group_name].add_options( - options_to_include) + self.option_groups[option_group_name].add_options(options_to_include) return self.option_groups[option_group_name] def create_db_parameter_group(self, db_parameter_group_kwargs): - db_parameter_group_id = db_parameter_group_kwargs['name'] - if db_parameter_group_kwargs['name'] in self.db_parameter_groups: - raise RDSClientError('DBParameterGroupAlreadyExistsFault', - 'A DB parameter group named {0} already exists.'.format(db_parameter_group_kwargs['name'])) - if not db_parameter_group_kwargs.get('description'): - raise RDSClientError('InvalidParameterValue', - 'The parameter Description must be provided and must not be blank.') - if not db_parameter_group_kwargs.get('family'): - raise RDSClientError('InvalidParameterValue', - 'The parameter DBParameterGroupName must be provided and must not be blank.') + db_parameter_group_id = db_parameter_group_kwargs["name"] + if db_parameter_group_kwargs["name"] in self.db_parameter_groups: + raise RDSClientError( + "DBParameterGroupAlreadyExistsFault", + "A DB parameter group named {0} already exists.".format( + db_parameter_group_kwargs["name"] + ), + ) + if not db_parameter_group_kwargs.get("description"): + raise RDSClientError( + "InvalidParameterValue", + "The parameter Description must be provided and must not be blank.", + ) + if not db_parameter_group_kwargs.get("family"): + raise RDSClientError( + "InvalidParameterValue", + "The parameter DBParameterGroupName must be provided and must not be blank.", + ) db_parameter_group = DBParameterGroup(**db_parameter_group_kwargs) self.db_parameter_groups[db_parameter_group_id] = db_parameter_group @@ -1092,27 +1143,39 @@ class RDS2Backend(BaseBackend): def describe_db_parameter_groups(self, db_parameter_group_kwargs): db_parameter_group_list = [] - if db_parameter_group_kwargs.get('marker'): - marker = db_parameter_group_kwargs['marker'] + if db_parameter_group_kwargs.get("marker"): + marker = db_parameter_group_kwargs["marker"] else: marker = 0 - if db_parameter_group_kwargs.get('max_records'): - if db_parameter_group_kwargs['max_records'] < 20 or db_parameter_group_kwargs['max_records'] > 100: - raise RDSClientError('InvalidParameterValue', - 'Invalid value for max records. Must be between 20 and 100') - max_records = db_parameter_group_kwargs['max_records'] + if db_parameter_group_kwargs.get("max_records"): + if ( + db_parameter_group_kwargs["max_records"] < 20 + or db_parameter_group_kwargs["max_records"] > 100 + ): + raise RDSClientError( + "InvalidParameterValue", + "Invalid value for max records. Must be between 20 and 100", + ) + max_records = db_parameter_group_kwargs["max_records"] else: max_records = 100 - for db_parameter_group_name, db_parameter_group in self.db_parameter_groups.items(): - if not db_parameter_group_kwargs.get('name') or db_parameter_group.name == db_parameter_group_kwargs.get('name'): + for ( + db_parameter_group_name, + db_parameter_group, + ) in self.db_parameter_groups.items(): + if not db_parameter_group_kwargs.get( + "name" + ) or db_parameter_group.name == db_parameter_group_kwargs.get("name"): db_parameter_group_list.append(db_parameter_group) else: continue - return db_parameter_group_list[marker:max_records + marker] + return db_parameter_group_list[marker : max_records + marker] - def modify_db_parameter_group(self, db_parameter_group_name, db_parameter_group_parameters): + def modify_db_parameter_group( + self, db_parameter_group_name, db_parameter_group_parameters + ): if db_parameter_group_name not in self.db_parameter_groups: raise DBParameterGroupNotFoundError(db_parameter_group_name) @@ -1123,103 +1186,105 @@ class RDS2Backend(BaseBackend): def list_tags_for_resource(self, arn): if self.arn_regex.match(arn): - arn_breakdown = arn.split(':') + arn_breakdown = arn.split(":") resource_type = arn_breakdown[len(arn_breakdown) - 2] resource_name = arn_breakdown[len(arn_breakdown) - 1] - if resource_type == 'db': # Database + if resource_type == "db": # Database if resource_name in self.databases: return self.databases[resource_name].get_tags() - elif resource_type == 'es': # Event Subscription + elif resource_type == "es": # Event Subscription # TODO: Complete call to tags on resource type Event # Subscription return [] - elif resource_type == 'og': # Option Group + elif resource_type == "og": # Option Group if resource_name in self.option_groups: return self.option_groups[resource_name].get_tags() - elif resource_type == 'pg': # Parameter Group + elif resource_type == "pg": # Parameter Group if resource_name in self.db_parameter_groups: return self.db_parameter_groups[resource_name].get_tags() - elif resource_type == 'ri': # Reserved DB instance + elif resource_type == "ri": # Reserved DB instance # TODO: Complete call to tags on resource type Reserved DB # instance return [] - elif resource_type == 'secgrp': # DB security group + elif resource_type == "secgrp": # DB security group if resource_name in self.security_groups: return self.security_groups[resource_name].get_tags() - elif resource_type == 'snapshot': # DB Snapshot + elif resource_type == "snapshot": # DB Snapshot if resource_name in self.snapshots: return self.snapshots[resource_name].get_tags() - elif resource_type == 'subgrp': # DB subnet group + elif resource_type == "subgrp": # DB subnet group if resource_name in self.subnet_groups: return self.subnet_groups[resource_name].get_tags() else: - raise RDSClientError('InvalidParameterValue', - 'Invalid resource name: {0}'.format(arn)) + raise RDSClientError( + "InvalidParameterValue", "Invalid resource name: {0}".format(arn) + ) return [] def remove_tags_from_resource(self, arn, tag_keys): if self.arn_regex.match(arn): - arn_breakdown = arn.split(':') + arn_breakdown = arn.split(":") resource_type = arn_breakdown[len(arn_breakdown) - 2] resource_name = arn_breakdown[len(arn_breakdown) - 1] - if resource_type == 'db': # Database + if resource_type == "db": # Database if resource_name in self.databases: self.databases[resource_name].remove_tags(tag_keys) - elif resource_type == 'es': # Event Subscription + elif resource_type == "es": # Event Subscription return None - elif resource_type == 'og': # Option Group + elif resource_type == "og": # Option Group if resource_name in self.option_groups: return self.option_groups[resource_name].remove_tags(tag_keys) - elif resource_type == 'pg': # Parameter Group + elif resource_type == "pg": # Parameter Group return None - elif resource_type == 'ri': # Reserved DB instance + elif resource_type == "ri": # Reserved DB instance return None - elif resource_type == 'secgrp': # DB security group + elif resource_type == "secgrp": # DB security group if resource_name in self.security_groups: return self.security_groups[resource_name].remove_tags(tag_keys) - elif resource_type == 'snapshot': # DB Snapshot + elif resource_type == "snapshot": # DB Snapshot if resource_name in self.snapshots: return self.snapshots[resource_name].remove_tags(tag_keys) - elif resource_type == 'subgrp': # DB subnet group + elif resource_type == "subgrp": # DB subnet group if resource_name in self.subnet_groups: return self.subnet_groups[resource_name].remove_tags(tag_keys) else: - raise RDSClientError('InvalidParameterValue', - 'Invalid resource name: {0}'.format(arn)) + raise RDSClientError( + "InvalidParameterValue", "Invalid resource name: {0}".format(arn) + ) def add_tags_to_resource(self, arn, tags): if self.arn_regex.match(arn): - arn_breakdown = arn.split(':') + arn_breakdown = arn.split(":") resource_type = arn_breakdown[len(arn_breakdown) - 2] resource_name = arn_breakdown[len(arn_breakdown) - 1] - if resource_type == 'db': # Database + if resource_type == "db": # Database if resource_name in self.databases: return self.databases[resource_name].add_tags(tags) - elif resource_type == 'es': # Event Subscription + elif resource_type == "es": # Event Subscription return [] - elif resource_type == 'og': # Option Group + elif resource_type == "og": # Option Group if resource_name in self.option_groups: return self.option_groups[resource_name].add_tags(tags) - elif resource_type == 'pg': # Parameter Group + elif resource_type == "pg": # Parameter Group return [] - elif resource_type == 'ri': # Reserved DB instance + elif resource_type == "ri": # Reserved DB instance return [] - elif resource_type == 'secgrp': # DB security group + elif resource_type == "secgrp": # DB security group if resource_name in self.security_groups: return self.security_groups[resource_name].add_tags(tags) - elif resource_type == 'snapshot': # DB Snapshot + elif resource_type == "snapshot": # DB Snapshot if resource_name in self.snapshots: return self.snapshots[resource_name].add_tags(tags) - elif resource_type == 'subgrp': # DB subnet group + elif resource_type == "subgrp": # DB subnet group if resource_name in self.subnet_groups: return self.subnet_groups[resource_name].add_tags(tags) else: - raise RDSClientError('InvalidParameterValue', - 'Invalid resource name: {0}'.format(arn)) + raise RDSClientError( + "InvalidParameterValue", "Invalid resource name: {0}".format(arn) + ) class OptionGroup(object): - def __init__(self, name, engine_name, major_engine_version, description=None): self.engine_name = engine_name self.major_engine_version = major_engine_version @@ -1227,11 +1292,12 @@ class OptionGroup(object): self.name = name self.vpc_and_non_vpc_instance_memberships = False self.options = {} - self.vpcId = 'null' + self.vpcId = "null" self.tags = [] def to_json(self): - template = Template("""{ + template = Template( + """{ "VpcId": null, "MajorEngineVersion": "{{ option_group.major_engine_version }}", "OptionGroupDescription": "{{ option_group.description }}", @@ -1239,18 +1305,21 @@ class OptionGroup(object): "EngineName": "{{ option_group.engine_name }}", "Options": [], "OptionGroupName": "{{ option_group.name }}" -}""") +}""" + ) return template.render(option_group=self) def to_xml(self): - template = Template(""" + template = Template( + """ {{ option_group.name }} {{ option_group.vpc_and_non_vpc_instance_memberships }} {{ option_group.major_engine_version }} {{ option_group.engine_name }} {{ option_group.description }} - """) + """ + ) return template.render(option_group=self) def remove_options(self, options_to_remove): @@ -1267,37 +1336,39 @@ class OptionGroup(object): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] class OptionGroupOption(object): - def __init__(self, **kwargs): - self.default_port = kwargs.get('default_port') - self.description = kwargs.get('description') - self.engine_name = kwargs.get('engine_name') - self.major_engine_version = kwargs.get('major_engine_version') - self.name = kwargs.get('name') + self.default_port = kwargs.get("default_port") + self.description = kwargs.get("description") + self.engine_name = kwargs.get("engine_name") + self.major_engine_version = kwargs.get("major_engine_version") + self.name = kwargs.get("name") self.option_group_option_settings = self._make_option_group_option_settings( - kwargs.get('option_group_option_settings', [])) - self.options_depended_on = kwargs.get('options_depended_on', []) - self.permanent = kwargs.get('permanent') - self.persistent = kwargs.get('persistent') - self.port_required = kwargs.get('port_required') + kwargs.get("option_group_option_settings", []) + ) + self.options_depended_on = kwargs.get("options_depended_on", []) + self.permanent = kwargs.get("permanent") + self.persistent = kwargs.get("persistent") + self.port_required = kwargs.get("port_required") def _make_option_group_option_settings(self, option_group_option_settings_kwargs): - return [OptionGroupOptionSetting(**setting_kwargs) for setting_kwargs in option_group_option_settings_kwargs] + return [ + OptionGroupOptionSetting(**setting_kwargs) + for setting_kwargs in option_group_option_settings_kwargs + ] def to_json(self): - template = Template("""{ "MinimumRequiredMinorEngineVersion": + template = Template( + """{ "MinimumRequiredMinorEngineVersion": "2789.0.v1", "OptionsDependedOn": [], "MajorEngineVersion": "10.50", @@ -1309,11 +1380,13 @@ class OptionGroupOption(object): "Name": "Mirroring", "PortRequired": false, "Description": "SQLServer Database Mirroring" - }""") + }""" + ) return template.render(option_group=self) def to_xml(self): - template = Template(""" + template = Template( + """ {{ option_group.major_engine_version }} {{ option_group.default_port }} {{ option_group.port_required }} @@ -1333,34 +1406,35 @@ class OptionGroupOption(object): {{ option_group.engine_name }} {{ option_group.minimum_required_minor_engine_version }} -""") +""" + ) return template.render(option_group=self) class OptionGroupOptionSetting(object): - def __init__(self, *kwargs): - self.allowed_values = kwargs.get('allowed_values') - self.apply_type = kwargs.get('apply_type') - self.default_value = kwargs.get('default_value') - self.is_modifiable = kwargs.get('is_modifiable') - self.setting_description = kwargs.get('setting_description') - self.setting_name = kwargs.get('setting_name') + self.allowed_values = kwargs.get("allowed_values") + self.apply_type = kwargs.get("apply_type") + self.default_value = kwargs.get("default_value") + self.is_modifiable = kwargs.get("is_modifiable") + self.setting_description = kwargs.get("setting_description") + self.setting_name = kwargs.get("setting_name") def to_xml(self): - template = Template(""" + template = Template( + """ {{ option_group_option_setting.allowed_values }} {{ option_group_option_setting.apply_type }} {{ option_group_option_setting.default_value }} {{ option_group_option_setting.is_modifiable }} {{ option_group_option_setting.setting_description }} {{ option_group_option_setting.setting_name }} -""") +""" + ) return template.render(option_group_option_setting=self) class DBParameterGroup(object): - def __init__(self, name, description, family, tags): self.name = name self.description = description @@ -1369,30 +1443,30 @@ class DBParameterGroup(object): self.parameters = defaultdict(dict) def to_xml(self): - template = Template(""" + template = Template( + """ {{ param_group.name }} {{ param_group.family }} {{ param_group.description }} - """) + """ + ) return template.render(param_group=self) def get_tags(self): return self.tags def add_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def remove_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags if tag_set[ - 'Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] def update_parameters(self, new_parameters): for new_parameter in new_parameters: - parameter = self.parameters[new_parameter['ParameterName']] + parameter = self.parameters[new_parameter["ParameterName"]] parameter.update(new_parameter) def delete(self, region_name): @@ -1400,28 +1474,33 @@ class DBParameterGroup(object): backend.delete_db_parameter_group(self.name) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] db_parameter_group_kwargs = { - 'description': properties['Description'], - 'family': properties['Family'], - 'name': resource_name.lower(), - 'tags': properties.get("Tags"), + "description": properties["Description"], + "family": properties["Family"], + "name": resource_name.lower(), + "tags": properties.get("Tags"), } db_parameter_group_parameters = [] - for db_parameter, db_parameter_value in properties.get('Parameters', {}).items(): - db_parameter_group_parameters.append({ - 'ParameterName': db_parameter, - 'ParameterValue': db_parameter_value, - }) + for db_parameter, db_parameter_value in properties.get( + "Parameters", {} + ).items(): + db_parameter_group_parameters.append( + {"ParameterName": db_parameter, "ParameterValue": db_parameter_value} + ) rds2_backend = rds2_backends[region_name] db_parameter_group = rds2_backend.create_db_parameter_group( - db_parameter_group_kwargs) + db_parameter_group_kwargs + ) db_parameter_group.update_parameters(db_parameter_group_parameters) return db_parameter_group -rds2_backends = dict((region.name, RDS2Backend(region.name)) - for region in boto.rds2.regions()) +rds2_backends = dict( + (region.name, RDS2Backend(region.name)) for region in boto.rds2.regions() +) diff --git a/moto/rds2/responses.py b/moto/rds2/responses.py index fdba73248..625838d4d 100644 --- a/moto/rds2/responses.py +++ b/moto/rds2/responses.py @@ -8,87 +8,90 @@ from .exceptions import DBParameterGroupNotFoundError class RDS2Response(BaseResponse): - @property def backend(self): return rds2_backends[self.region] def _get_db_kwargs(self): args = { - "auto_minor_version_upgrade": self._get_param('AutoMinorVersionUpgrade'), - "allocated_storage": self._get_int_param('AllocatedStorage'), + "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), + "allocated_storage": self._get_int_param("AllocatedStorage"), "availability_zone": self._get_param("AvailabilityZone"), "backup_retention_period": self._get_param("BackupRetentionPeriod"), "copy_tags_to_snapshot": self._get_param("CopyTagsToSnapshot"), - "db_instance_class": self._get_param('DBInstanceClass'), - "db_instance_identifier": self._get_param('DBInstanceIdentifier'), + "db_instance_class": self._get_param("DBInstanceClass"), + "db_instance_identifier": self._get_param("DBInstanceIdentifier"), "db_name": self._get_param("DBName"), "db_parameter_group_name": self._get_param("DBParameterGroupName"), - "db_snapshot_identifier": self._get_param('DBSnapshotIdentifier'), + "db_snapshot_identifier": self._get_param("DBSnapshotIdentifier"), "db_subnet_group_name": self._get_param("DBSubnetGroupName"), "engine": self._get_param("Engine"), "engine_version": self._get_param("EngineVersion"), "license_model": self._get_param("LicenseModel"), "iops": self._get_int_param("Iops"), "kms_key_id": self._get_param("KmsKeyId"), - "master_user_password": self._get_param('MasterUserPassword'), - "master_username": self._get_param('MasterUsername'), + "master_user_password": self._get_param("MasterUserPassword"), + "master_username": self._get_param("MasterUsername"), "multi_az": self._get_bool_param("MultiAZ"), "option_group_name": self._get_param("OptionGroupName"), - "port": self._get_param('Port'), + "port": self._get_param("Port"), # PreferredBackupWindow # PreferredMaintenanceWindow "publicly_accessible": self._get_param("PubliclyAccessible"), "region": self.region, - "security_groups": self._get_multi_param('DBSecurityGroups.DBSecurityGroupName'), + "security_groups": self._get_multi_param( + "DBSecurityGroups.DBSecurityGroupName" + ), "storage_encrypted": self._get_param("StorageEncrypted"), "storage_type": self._get_param("StorageType", None), - "vpc_security_group_ids": self._get_multi_param("VpcSecurityGroupIds.VpcSecurityGroupId"), + "vpc_security_group_ids": self._get_multi_param( + "VpcSecurityGroupIds.VpcSecurityGroupId" + ), "tags": list(), } - args['tags'] = self.unpack_complex_list_params( - 'Tags.Tag', ('Key', 'Value')) + args["tags"] = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) return args def _get_db_replica_kwargs(self): return { - "auto_minor_version_upgrade": self._get_param('AutoMinorVersionUpgrade'), + "auto_minor_version_upgrade": self._get_param("AutoMinorVersionUpgrade"), "availability_zone": self._get_param("AvailabilityZone"), - "db_instance_class": self._get_param('DBInstanceClass'), - "db_instance_identifier": self._get_param('DBInstanceIdentifier'), + "db_instance_class": self._get_param("DBInstanceClass"), + "db_instance_identifier": self._get_param("DBInstanceIdentifier"), "db_subnet_group_name": self._get_param("DBSubnetGroupName"), "iops": self._get_int_param("Iops"), # OptionGroupName - "port": self._get_param('Port'), + "port": self._get_param("Port"), "publicly_accessible": self._get_param("PubliclyAccessible"), - "source_db_identifier": self._get_param('SourceDBInstanceIdentifier'), + "source_db_identifier": self._get_param("SourceDBInstanceIdentifier"), "storage_type": self._get_param("StorageType"), } def _get_option_group_kwargs(self): return { - 'major_engine_version': self._get_param('MajorEngineVersion'), - 'description': self._get_param('OptionGroupDescription'), - 'engine_name': self._get_param('EngineName'), - 'name': self._get_param('OptionGroupName') + "major_engine_version": self._get_param("MajorEngineVersion"), + "description": self._get_param("OptionGroupDescription"), + "engine_name": self._get_param("EngineName"), + "name": self._get_param("OptionGroupName"), } def _get_db_parameter_group_kwargs(self): return { - 'description': self._get_param('Description'), - 'family': self._get_param('DBParameterGroupFamily'), - 'name': self._get_param('DBParameterGroupName'), - 'tags': self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')), + "description": self._get_param("Description"), + "family": self._get_param("DBParameterGroupFamily"), + "name": self._get_param("DBParameterGroupName"), + "tags": self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")), } def unpack_complex_list_params(self, label, names): unpacked_list = list() count = 1 - while self._get_param('{0}.{1}.{2}'.format(label, count, names[0])): + while self._get_param("{0}.{1}.{2}".format(label, count, names[0])): param = dict() for i in range(len(names)): param[names[i]] = self._get_param( - '{0}.{1}.{2}'.format(label, count, names[i])) + "{0}.{1}.{2}".format(label, count, names[i]) + ) unpacked_list.append(param) count += 1 return unpacked_list @@ -96,9 +99,8 @@ class RDS2Response(BaseResponse): def unpack_list_params(self, label): unpacked_list = list() count = 1 - while self._get_param('{0}.{1}'.format(label, count)): - unpacked_list.append(self._get_param( - '{0}.{1}'.format(label, count))) + while self._get_param("{0}.{1}".format(label, count)): + unpacked_list.append(self._get_param("{0}.{1}".format(label, count))) count += 1 return unpacked_list @@ -116,16 +118,18 @@ class RDS2Response(BaseResponse): return template.render(database=database) def describe_db_instances(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") all_instances = list(self.backend.describe_databases(db_instance_identifier)) - marker = self._get_param('Marker') + marker = self._get_param("Marker") all_ids = [instance.db_instance_identifier for instance in all_instances] if marker: start = all_ids.index(marker) + 1 else: start = 0 - page_size = self._get_int_param('MaxRecords', 50) # the default is 100, but using 50 to make testing easier - instances_resp = all_instances[start:start + page_size] + page_size = self._get_int_param( + "MaxRecords", 50 + ) # the default is 100, but using 50 to make testing easier + instances_resp = all_instances[start : start + page_size] next_marker = None if len(all_instances) > start + page_size: next_marker = instances_resp[-1].db_instance_identifier @@ -134,134 +138,143 @@ class RDS2Response(BaseResponse): return template.render(databases=instances_resp, marker=next_marker) def modify_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") db_kwargs = self._get_db_kwargs() - new_db_instance_identifier = self._get_param('NewDBInstanceIdentifier') + new_db_instance_identifier = self._get_param("NewDBInstanceIdentifier") if new_db_instance_identifier: - db_kwargs['new_db_instance_identifier'] = new_db_instance_identifier - database = self.backend.modify_database( - db_instance_identifier, db_kwargs) + db_kwargs["new_db_instance_identifier"] = new_db_instance_identifier + database = self.backend.modify_database(db_instance_identifier, db_kwargs) template = self.response_template(MODIFY_DATABASE_TEMPLATE) return template.render(database=database) def delete_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') - db_snapshot_name = self._get_param('FinalDBSnapshotIdentifier') - database = self.backend.delete_database(db_instance_identifier, db_snapshot_name) + db_instance_identifier = self._get_param("DBInstanceIdentifier") + db_snapshot_name = self._get_param("FinalDBSnapshotIdentifier") + database = self.backend.delete_database( + db_instance_identifier, db_snapshot_name + ) template = self.response_template(DELETE_DATABASE_TEMPLATE) return template.render(database=database) def reboot_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") database = self.backend.reboot_db_instance(db_instance_identifier) template = self.response_template(REBOOT_DATABASE_TEMPLATE) return template.render(database=database) def create_db_snapshot(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') - db_snapshot_identifier = self._get_param('DBSnapshotIdentifier') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) - snapshot = self.backend.create_snapshot(db_instance_identifier, db_snapshot_identifier, tags) + db_instance_identifier = self._get_param("DBInstanceIdentifier") + db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) + snapshot = self.backend.create_snapshot( + db_instance_identifier, db_snapshot_identifier, tags + ) template = self.response_template(CREATE_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) def describe_db_snapshots(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') - db_snapshot_identifier = self._get_param('DBSnapshotIdentifier') - snapshots = self.backend.describe_snapshots(db_instance_identifier, db_snapshot_identifier) + db_instance_identifier = self._get_param("DBInstanceIdentifier") + db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") + snapshots = self.backend.describe_snapshots( + db_instance_identifier, db_snapshot_identifier + ) template = self.response_template(DESCRIBE_SNAPSHOTS_TEMPLATE) return template.render(snapshots=snapshots) def delete_db_snapshot(self): - db_snapshot_identifier = self._get_param('DBSnapshotIdentifier') + db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") snapshot = self.backend.delete_snapshot(db_snapshot_identifier) template = self.response_template(DELETE_SNAPSHOT_TEMPLATE) return template.render(snapshot=snapshot) def list_tags_for_resource(self): - arn = self._get_param('ResourceName') + arn = self._get_param("ResourceName") template = self.response_template(LIST_TAGS_FOR_RESOURCE_TEMPLATE) tags = self.backend.list_tags_for_resource(arn) return template.render(tags=tags) def add_tags_to_resource(self): - arn = self._get_param('ResourceName') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + arn = self._get_param("ResourceName") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) tags = self.backend.add_tags_to_resource(arn, tags) template = self.response_template(ADD_TAGS_TO_RESOURCE_TEMPLATE) return template.render(tags=tags) def remove_tags_from_resource(self): - arn = self._get_param('ResourceName') - tag_keys = self.unpack_list_params('TagKeys.member') + arn = self._get_param("ResourceName") + tag_keys = self.unpack_list_params("TagKeys.member") self.backend.remove_tags_from_resource(arn, tag_keys) template = self.response_template(REMOVE_TAGS_FROM_RESOURCE_TEMPLATE) return template.render() def stop_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') - db_snapshot_identifier = self._get_param('DBSnapshotIdentifier') - database = self.backend.stop_database(db_instance_identifier, db_snapshot_identifier) + db_instance_identifier = self._get_param("DBInstanceIdentifier") + db_snapshot_identifier = self._get_param("DBSnapshotIdentifier") + database = self.backend.stop_database( + db_instance_identifier, db_snapshot_identifier + ) template = self.response_template(STOP_DATABASE_TEMPLATE) return template.render(database=database) def start_db_instance(self): - db_instance_identifier = self._get_param('DBInstanceIdentifier') + db_instance_identifier = self._get_param("DBInstanceIdentifier") database = self.backend.start_database(db_instance_identifier) template = self.response_template(START_DATABASE_TEMPLATE) return template.render(database=database) def create_db_security_group(self): - group_name = self._get_param('DBSecurityGroupName') - description = self._get_param('DBSecurityGroupDescription') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + group_name = self._get_param("DBSecurityGroupName") + description = self._get_param("DBSecurityGroupDescription") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) security_group = self.backend.create_security_group( - group_name, description, tags) + group_name, description, tags + ) template = self.response_template(CREATE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def describe_db_security_groups(self): - security_group_name = self._get_param('DBSecurityGroupName') - security_groups = self.backend.describe_security_groups( - security_group_name) + security_group_name = self._get_param("DBSecurityGroupName") + security_groups = self.backend.describe_security_groups(security_group_name) template = self.response_template(DESCRIBE_SECURITY_GROUPS_TEMPLATE) return template.render(security_groups=security_groups) def delete_db_security_group(self): - security_group_name = self._get_param('DBSecurityGroupName') - security_group = self.backend.delete_security_group( - security_group_name) + security_group_name = self._get_param("DBSecurityGroupName") + security_group = self.backend.delete_security_group(security_group_name) template = self.response_template(DELETE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def authorize_db_security_group_ingress(self): - security_group_name = self._get_param('DBSecurityGroupName') - cidr_ip = self._get_param('CIDRIP') + security_group_name = self._get_param("DBSecurityGroupName") + cidr_ip = self._get_param("CIDRIP") security_group = self.backend.authorize_security_group( - security_group_name, cidr_ip) + security_group_name, cidr_ip + ) template = self.response_template(AUTHORIZE_SECURITY_GROUP_TEMPLATE) return template.render(security_group=security_group) def create_db_subnet_group(self): - subnet_name = self._get_param('DBSubnetGroupName') - description = self._get_param('DBSubnetGroupDescription') - subnet_ids = self._get_multi_param('SubnetIds.SubnetIdentifier') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) - subnets = [ec2_backends[self.region].get_subnet( - subnet_id) for subnet_id in subnet_ids] + subnet_name = self._get_param("DBSubnetGroupName") + description = self._get_param("DBSubnetGroupDescription") + subnet_ids = self._get_multi_param("SubnetIds.SubnetIdentifier") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) + subnets = [ + ec2_backends[self.region].get_subnet(subnet_id) for subnet_id in subnet_ids + ] subnet_group = self.backend.create_subnet_group( - subnet_name, description, subnets, tags) + subnet_name, description, subnets, tags + ) template = self.response_template(CREATE_SUBNET_GROUP_TEMPLATE) return template.render(subnet_group=subnet_group) def describe_db_subnet_groups(self): - subnet_name = self._get_param('DBSubnetGroupName') + subnet_name = self._get_param("DBSubnetGroupName") subnet_groups = self.backend.describe_subnet_groups(subnet_name) template = self.response_template(DESCRIBE_SUBNET_GROUPS_TEMPLATE) return template.render(subnet_groups=subnet_groups) def delete_db_subnet_group(self): - subnet_name = self._get_param('DBSubnetGroupName') + subnet_name = self._get_param("DBSubnetGroupName") subnet_group = self.backend.delete_subnet_group(subnet_name) template = self.response_template(DELETE_SUBNET_GROUP_TEMPLATE) return template.render(subnet_group=subnet_group) @@ -274,50 +287,67 @@ class RDS2Response(BaseResponse): def delete_option_group(self): kwargs = self._get_option_group_kwargs() - option_group = self.backend.delete_option_group(kwargs['name']) + option_group = self.backend.delete_option_group(kwargs["name"]) template = self.response_template(DELETE_OPTION_GROUP_TEMPLATE) return template.render(option_group=option_group) def describe_option_groups(self): kwargs = self._get_option_group_kwargs() - kwargs['max_records'] = self._get_int_param('MaxRecords') - kwargs['marker'] = self._get_param('Marker') + kwargs["max_records"] = self._get_int_param("MaxRecords") + kwargs["marker"] = self._get_param("Marker") option_groups = self.backend.describe_option_groups(kwargs) template = self.response_template(DESCRIBE_OPTION_GROUP_TEMPLATE) return template.render(option_groups=option_groups) def describe_option_group_options(self): - engine_name = self._get_param('EngineName') - major_engine_version = self._get_param('MajorEngineVersion') + engine_name = self._get_param("EngineName") + major_engine_version = self._get_param("MajorEngineVersion") option_group_options = self.backend.describe_option_group_options( - engine_name, major_engine_version) + engine_name, major_engine_version + ) return option_group_options def modify_option_group(self): - option_group_name = self._get_param('OptionGroupName') + option_group_name = self._get_param("OptionGroupName") count = 1 options_to_include = [] - while self._get_param('OptionsToInclude.member.{0}.OptionName'.format(count)): - options_to_include.append({ - 'Port': self._get_param('OptionsToInclude.member.{0}.Port'.format(count)), - 'OptionName': self._get_param('OptionsToInclude.member.{0}.OptionName'.format(count)), - 'DBSecurityGroupMemberships': self._get_param('OptionsToInclude.member.{0}.DBSecurityGroupMemberships'.format(count)), - 'OptionSettings': self._get_param('OptionsToInclude.member.{0}.OptionSettings'.format(count)), - 'VpcSecurityGroupMemberships': self._get_param('OptionsToInclude.member.{0}.VpcSecurityGroupMemberships'.format(count)) - }) + while self._get_param("OptionsToInclude.member.{0}.OptionName".format(count)): + options_to_include.append( + { + "Port": self._get_param( + "OptionsToInclude.member.{0}.Port".format(count) + ), + "OptionName": self._get_param( + "OptionsToInclude.member.{0}.OptionName".format(count) + ), + "DBSecurityGroupMemberships": self._get_param( + "OptionsToInclude.member.{0}.DBSecurityGroupMemberships".format( + count + ) + ), + "OptionSettings": self._get_param( + "OptionsToInclude.member.{0}.OptionSettings".format(count) + ), + "VpcSecurityGroupMemberships": self._get_param( + "OptionsToInclude.member.{0}.VpcSecurityGroupMemberships".format( + count + ) + ), + } + ) count += 1 count = 1 options_to_remove = [] - while self._get_param('OptionsToRemove.member.{0}'.format(count)): - options_to_remove.append(self._get_param( - 'OptionsToRemove.member.{0}'.format(count))) + while self._get_param("OptionsToRemove.member.{0}".format(count)): + options_to_remove.append( + self._get_param("OptionsToRemove.member.{0}".format(count)) + ) count += 1 - apply_immediately = self._get_param('ApplyImmediately') - option_group = self.backend.modify_option_group(option_group_name, - options_to_include, - options_to_remove, - apply_immediately) + apply_immediately = self._get_param("ApplyImmediately") + option_group = self.backend.modify_option_group( + option_group_name, options_to_include, options_to_remove, apply_immediately + ) template = self.response_template(MODIFY_OPTION_GROUP_TEMPLATE) return template.render(option_group=option_group) @@ -329,28 +359,28 @@ class RDS2Response(BaseResponse): def describe_db_parameter_groups(self): kwargs = self._get_db_parameter_group_kwargs() - kwargs['max_records'] = self._get_int_param('MaxRecords') - kwargs['marker'] = self._get_param('Marker') + kwargs["max_records"] = self._get_int_param("MaxRecords") + kwargs["marker"] = self._get_param("Marker") db_parameter_groups = self.backend.describe_db_parameter_groups(kwargs) - template = self.response_template( - DESCRIBE_DB_PARAMETER_GROUPS_TEMPLATE) + template = self.response_template(DESCRIBE_DB_PARAMETER_GROUPS_TEMPLATE) return template.render(db_parameter_groups=db_parameter_groups) def modify_db_parameter_group(self): - db_parameter_group_name = self._get_param('DBParameterGroupName') + db_parameter_group_name = self._get_param("DBParameterGroupName") db_parameter_group_parameters = self._get_db_parameter_group_paramters() - db_parameter_group = self.backend.modify_db_parameter_group(db_parameter_group_name, - db_parameter_group_parameters) + db_parameter_group = self.backend.modify_db_parameter_group( + db_parameter_group_name, db_parameter_group_parameters + ) template = self.response_template(MODIFY_DB_PARAMETER_GROUP_TEMPLATE) return template.render(db_parameter_group=db_parameter_group) def _get_db_parameter_group_paramters(self): parameter_group_parameters = defaultdict(dict) for param_name, value in self.querystring.items(): - if not param_name.startswith('Parameters.Parameter'): + if not param_name.startswith("Parameters.Parameter"): continue - split_param_name = param_name.split('.') + split_param_name = param_name.split(".") param_id = split_param_name[2] param_setting = split_param_name[3] @@ -359,9 +389,10 @@ class RDS2Response(BaseResponse): return parameter_group_parameters.values() def describe_db_parameters(self): - db_parameter_group_name = self._get_param('DBParameterGroupName') + db_parameter_group_name = self._get_param("DBParameterGroupName") db_parameter_groups = self.backend.describe_db_parameter_groups( - {'name': db_parameter_group_name}) + {"name": db_parameter_group_name} + ) if not db_parameter_groups: raise DBParameterGroupNotFoundError(db_parameter_group_name) @@ -370,8 +401,7 @@ class RDS2Response(BaseResponse): def delete_db_parameter_group(self): kwargs = self._get_db_parameter_group_kwargs() - db_parameter_group = self.backend.delete_db_parameter_group(kwargs[ - 'name']) + db_parameter_group = self.backend.delete_db_parameter_group(kwargs["name"]) template = self.response_template(DELETE_DB_PARAMETER_GROUP_TEMPLATE) return template.render(db_parameter_group=db_parameter_group) diff --git a/moto/rds2/urls.py b/moto/rds2/urls.py index d19dc2785..d937554e0 100644 --- a/moto/rds2/urls.py +++ b/moto/rds2/urls.py @@ -1,11 +1,6 @@ from __future__ import unicode_literals from .responses import RDS2Response -url_bases = [ - "https?://rds.(.+).amazonaws.com", - "https?://rds.amazonaws.com", -] +url_bases = ["https?://rds.(.+).amazonaws.com", "https?://rds.amazonaws.com"] -url_paths = { - '{0}/$': RDS2Response.dispatch, -} +url_paths = {"{0}/$": RDS2Response.dispatch} diff --git a/moto/redshift/__init__.py b/moto/redshift/__init__.py index 06f778e8d..47cbf3b58 100644 --- a/moto/redshift/__init__.py +++ b/moto/redshift/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import redshift_backends from ..core.models import base_decorator, deprecated_base_decorator -redshift_backend = redshift_backends['us-east-1'] +redshift_backend = redshift_backends["us-east-1"] mock_redshift = base_decorator(redshift_backends) mock_redshift_deprecated = deprecated_base_decorator(redshift_backends) diff --git a/moto/redshift/exceptions.py b/moto/redshift/exceptions.py index b0cef57ad..0a17e8aab 100644 --- a/moto/redshift/exceptions.py +++ b/moto/redshift/exceptions.py @@ -5,94 +5,93 @@ from werkzeug.exceptions import BadRequest class RedshiftClientError(BadRequest): - def __init__(self, code, message): super(RedshiftClientError, self).__init__() - self.description = json.dumps({ - "Error": { - "Code": code, - "Message": message, - 'Type': 'Sender', - }, - 'RequestId': '6876f774-7273-11e4-85dc-39e55ca848d1', - }) + self.description = json.dumps( + { + "Error": {"Code": code, "Message": message, "Type": "Sender"}, + "RequestId": "6876f774-7273-11e4-85dc-39e55ca848d1", + } + ) class ClusterNotFoundError(RedshiftClientError): - def __init__(self, cluster_identifier): super(ClusterNotFoundError, self).__init__( - 'ClusterNotFound', - "Cluster {0} not found.".format(cluster_identifier)) + "ClusterNotFound", "Cluster {0} not found.".format(cluster_identifier) + ) class ClusterSubnetGroupNotFoundError(RedshiftClientError): - def __init__(self, subnet_identifier): super(ClusterSubnetGroupNotFoundError, self).__init__( - 'ClusterSubnetGroupNotFound', - "Subnet group {0} not found.".format(subnet_identifier)) + "ClusterSubnetGroupNotFound", + "Subnet group {0} not found.".format(subnet_identifier), + ) class ClusterSecurityGroupNotFoundError(RedshiftClientError): - def __init__(self, group_identifier): super(ClusterSecurityGroupNotFoundError, self).__init__( - 'ClusterSecurityGroupNotFound', - "Security group {0} not found.".format(group_identifier)) + "ClusterSecurityGroupNotFound", + "Security group {0} not found.".format(group_identifier), + ) class ClusterParameterGroupNotFoundError(RedshiftClientError): - def __init__(self, group_identifier): super(ClusterParameterGroupNotFoundError, self).__init__( - 'ClusterParameterGroupNotFound', - "Parameter group {0} not found.".format(group_identifier)) + "ClusterParameterGroupNotFound", + "Parameter group {0} not found.".format(group_identifier), + ) class InvalidSubnetError(RedshiftClientError): - def __init__(self, subnet_identifier): super(InvalidSubnetError, self).__init__( - 'InvalidSubnet', - "Subnet {0} not found.".format(subnet_identifier)) + "InvalidSubnet", "Subnet {0} not found.".format(subnet_identifier) + ) class SnapshotCopyGrantAlreadyExistsFaultError(RedshiftClientError): def __init__(self, snapshot_copy_grant_name): super(SnapshotCopyGrantAlreadyExistsFaultError, self).__init__( - 'SnapshotCopyGrantAlreadyExistsFault', + "SnapshotCopyGrantAlreadyExistsFault", "Cannot create the snapshot copy grant because a grant " - "with the identifier '{0}' already exists".format(snapshot_copy_grant_name)) + "with the identifier '{0}' already exists".format(snapshot_copy_grant_name), + ) class SnapshotCopyGrantNotFoundFaultError(RedshiftClientError): def __init__(self, snapshot_copy_grant_name): super(SnapshotCopyGrantNotFoundFaultError, self).__init__( - 'SnapshotCopyGrantNotFoundFault', - "Snapshot copy grant not found: {0}".format(snapshot_copy_grant_name)) + "SnapshotCopyGrantNotFoundFault", + "Snapshot copy grant not found: {0}".format(snapshot_copy_grant_name), + ) class ClusterSnapshotNotFoundError(RedshiftClientError): def __init__(self, snapshot_identifier): super(ClusterSnapshotNotFoundError, self).__init__( - 'ClusterSnapshotNotFound', - "Snapshot {0} not found.".format(snapshot_identifier)) + "ClusterSnapshotNotFound", + "Snapshot {0} not found.".format(snapshot_identifier), + ) class ClusterSnapshotAlreadyExistsError(RedshiftClientError): def __init__(self, snapshot_identifier): super(ClusterSnapshotAlreadyExistsError, self).__init__( - 'ClusterSnapshotAlreadyExists', + "ClusterSnapshotAlreadyExists", "Cannot create the snapshot because a snapshot with the " - "identifier {0} already exists".format(snapshot_identifier)) + "identifier {0} already exists".format(snapshot_identifier), + ) class InvalidParameterValueError(RedshiftClientError): def __init__(self, message): super(InvalidParameterValueError, self).__init__( - 'InvalidParameterValue', - message) + "InvalidParameterValue", message + ) class ResourceNotFoundFaultError(RedshiftClientError): @@ -106,26 +105,34 @@ class ResourceNotFoundFaultError(RedshiftClientError): msg = "{0} ({1}) not found.".format(resource_type, resource_name) if message: msg = message - super(ResourceNotFoundFaultError, self).__init__( - 'ResourceNotFoundFault', msg) + super(ResourceNotFoundFaultError, self).__init__("ResourceNotFoundFault", msg) class SnapshotCopyDisabledFaultError(RedshiftClientError): def __init__(self, cluster_identifier): super(SnapshotCopyDisabledFaultError, self).__init__( - 'SnapshotCopyDisabledFault', - "Cannot modify retention period because snapshot copy is disabled on Cluster {0}.".format(cluster_identifier)) + "SnapshotCopyDisabledFault", + "Cannot modify retention period because snapshot copy is disabled on Cluster {0}.".format( + cluster_identifier + ), + ) class SnapshotCopyAlreadyDisabledFaultError(RedshiftClientError): def __init__(self, cluster_identifier): super(SnapshotCopyAlreadyDisabledFaultError, self).__init__( - 'SnapshotCopyAlreadyDisabledFault', - "Snapshot Copy is already disabled on Cluster {0}.".format(cluster_identifier)) + "SnapshotCopyAlreadyDisabledFault", + "Snapshot Copy is already disabled on Cluster {0}.".format( + cluster_identifier + ), + ) class SnapshotCopyAlreadyEnabledFaultError(RedshiftClientError): def __init__(self, cluster_identifier): super(SnapshotCopyAlreadyEnabledFaultError, self).__init__( - 'SnapshotCopyAlreadyEnabledFault', - "Snapshot Copy is already enabled on Cluster {0}.".format(cluster_identifier)) + "SnapshotCopyAlreadyEnabledFault", + "Snapshot Copy is already enabled on Cluster {0}.".format( + cluster_identifier + ), + ) diff --git a/moto/redshift/models.py b/moto/redshift/models.py index 8a2b7e6b6..3eac565f8 100644 --- a/moto/redshift/models.py +++ b/moto/redshift/models.py @@ -48,59 +48,91 @@ class TaggableResourceMixin(object): region=self.region, account_id=ACCOUNT_ID, resource_type=self.resource_type, - resource_id=self.resource_id) + resource_id=self.resource_id, + ) def create_tags(self, tags): - new_keys = [tag_set['Key'] for tag_set in tags] - self.tags = [tag_set for tag_set in self.tags - if tag_set['Key'] not in new_keys] + new_keys = [tag_set["Key"] for tag_set in tags] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in new_keys] self.tags.extend(tags) return self.tags def delete_tags(self, tag_keys): - self.tags = [tag_set for tag_set in self.tags - if tag_set['Key'] not in tag_keys] + self.tags = [tag_set for tag_set in self.tags if tag_set["Key"] not in tag_keys] return self.tags class Cluster(TaggableResourceMixin, BaseModel): - resource_type = 'cluster' + resource_type = "cluster" - def __init__(self, redshift_backend, cluster_identifier, node_type, master_username, - master_user_password, db_name, cluster_type, cluster_security_groups, - vpc_security_group_ids, cluster_subnet_group_name, availability_zone, - preferred_maintenance_window, cluster_parameter_group_name, - automated_snapshot_retention_period, port, cluster_version, - allow_version_upgrade, number_of_nodes, publicly_accessible, - encrypted, region_name, tags=None, iam_roles_arn=None, - enhanced_vpc_routing=None, restored_from_snapshot=False): + def __init__( + self, + redshift_backend, + cluster_identifier, + node_type, + master_username, + master_user_password, + db_name, + cluster_type, + cluster_security_groups, + vpc_security_group_ids, + cluster_subnet_group_name, + availability_zone, + preferred_maintenance_window, + cluster_parameter_group_name, + automated_snapshot_retention_period, + port, + cluster_version, + allow_version_upgrade, + number_of_nodes, + publicly_accessible, + encrypted, + region_name, + tags=None, + iam_roles_arn=None, + enhanced_vpc_routing=None, + restored_from_snapshot=False, + ): super(Cluster, self).__init__(region_name, tags) self.redshift_backend = redshift_backend self.cluster_identifier = cluster_identifier - self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()) - self.status = 'available' + self.create_time = iso_8601_datetime_with_milliseconds( + datetime.datetime.utcnow() + ) + self.status = "available" self.node_type = node_type self.master_username = master_username self.master_user_password = master_user_password self.db_name = db_name if db_name else "dev" self.vpc_security_group_ids = vpc_security_group_ids - self.enhanced_vpc_routing = enhanced_vpc_routing if enhanced_vpc_routing is not None else False + self.enhanced_vpc_routing = ( + enhanced_vpc_routing if enhanced_vpc_routing is not None else False + ) self.cluster_subnet_group_name = cluster_subnet_group_name self.publicly_accessible = publicly_accessible self.encrypted = encrypted - self.allow_version_upgrade = allow_version_upgrade if allow_version_upgrade is not None else True + self.allow_version_upgrade = ( + allow_version_upgrade if allow_version_upgrade is not None else True + ) self.cluster_version = cluster_version if cluster_version else "1.0" self.port = int(port) if port else 5439 - self.automated_snapshot_retention_period = int( - automated_snapshot_retention_period) if automated_snapshot_retention_period else 1 - self.preferred_maintenance_window = preferred_maintenance_window if preferred_maintenance_window else "Mon:03:00-Mon:03:30" + self.automated_snapshot_retention_period = ( + int(automated_snapshot_retention_period) + if automated_snapshot_retention_period + else 1 + ) + self.preferred_maintenance_window = ( + preferred_maintenance_window + if preferred_maintenance_window + else "Mon:03:00-Mon:03:30" + ) if cluster_parameter_group_name: self.cluster_parameter_group_name = [cluster_parameter_group_name] else: - self.cluster_parameter_group_name = ['default.redshift-1.0'] + self.cluster_parameter_group_name = ["default.redshift-1.0"] if cluster_security_groups: self.cluster_security_groups = cluster_security_groups @@ -114,7 +146,7 @@ class Cluster(TaggableResourceMixin, BaseModel): # way to pull AZs for a region in boto self.availability_zone = region_name + "a" - if cluster_type == 'single-node': + if cluster_type == "single-node": self.number_of_nodes = 1 elif number_of_nodes: self.number_of_nodes = int(number_of_nodes) @@ -125,38 +157,39 @@ class Cluster(TaggableResourceMixin, BaseModel): self.restored_from_snapshot = restored_from_snapshot @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): redshift_backend = redshift_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] - if 'ClusterSubnetGroupName' in properties: + if "ClusterSubnetGroupName" in properties: subnet_group_name = properties[ - 'ClusterSubnetGroupName'].cluster_subnet_group_name + "ClusterSubnetGroupName" + ].cluster_subnet_group_name else: subnet_group_name = None cluster = redshift_backend.create_cluster( cluster_identifier=resource_name, - node_type=properties.get('NodeType'), - master_username=properties.get('MasterUsername'), - master_user_password=properties.get('MasterUserPassword'), - db_name=properties.get('DBName'), - cluster_type=properties.get('ClusterType'), - cluster_security_groups=properties.get( - 'ClusterSecurityGroups', []), - vpc_security_group_ids=properties.get('VpcSecurityGroupIds', []), + node_type=properties.get("NodeType"), + master_username=properties.get("MasterUsername"), + master_user_password=properties.get("MasterUserPassword"), + db_name=properties.get("DBName"), + cluster_type=properties.get("ClusterType"), + cluster_security_groups=properties.get("ClusterSecurityGroups", []), + vpc_security_group_ids=properties.get("VpcSecurityGroupIds", []), cluster_subnet_group_name=subnet_group_name, - availability_zone=properties.get('AvailabilityZone'), - preferred_maintenance_window=properties.get( - 'PreferredMaintenanceWindow'), - cluster_parameter_group_name=properties.get( - 'ClusterParameterGroupName'), + availability_zone=properties.get("AvailabilityZone"), + preferred_maintenance_window=properties.get("PreferredMaintenanceWindow"), + cluster_parameter_group_name=properties.get("ClusterParameterGroupName"), automated_snapshot_retention_period=properties.get( - 'AutomatedSnapshotRetentionPeriod'), - port=properties.get('Port'), - cluster_version=properties.get('ClusterVersion'), - allow_version_upgrade=properties.get('AllowVersionUpgrade'), - enhanced_vpc_routing=properties.get('EnhancedVpcRouting'), - number_of_nodes=properties.get('NumberOfNodes'), + "AutomatedSnapshotRetentionPeriod" + ), + port=properties.get("Port"), + cluster_version=properties.get("ClusterVersion"), + allow_version_upgrade=properties.get("AllowVersionUpgrade"), + enhanced_vpc_routing=properties.get("EnhancedVpcRouting"), + number_of_nodes=properties.get("NumberOfNodes"), publicly_accessible=properties.get("PubliclyAccessible"), encrypted=properties.get("Encrypted"), region_name=region_name, @@ -165,41 +198,43 @@ class Cluster(TaggableResourceMixin, BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Endpoint.Address': + + if attribute_name == "Endpoint.Address": return self.endpoint - elif attribute_name == 'Endpoint.Port': + elif attribute_name == "Endpoint.Port": return self.port raise UnformattedGetAttTemplateException() @property def endpoint(self): return "{0}.cg034hpkmmjt.{1}.redshift.amazonaws.com".format( - self.cluster_identifier, - self.region, + self.cluster_identifier, self.region ) @property def security_groups(self): return [ - security_group for security_group - in self.redshift_backend.describe_cluster_security_groups() - if security_group.cluster_security_group_name in self.cluster_security_groups + security_group + for security_group in self.redshift_backend.describe_cluster_security_groups() + if security_group.cluster_security_group_name + in self.cluster_security_groups ] @property def vpc_security_groups(self): return [ - security_group for security_group - in self.redshift_backend.ec2_backend.describe_security_groups() + security_group + for security_group in self.redshift_backend.ec2_backend.describe_security_groups() if security_group.id in self.vpc_security_group_ids ] @property def parameter_groups(self): return [ - parameter_group for parameter_group - in self.redshift_backend.describe_cluster_parameter_groups() - if parameter_group.cluster_parameter_group_name in self.cluster_parameter_group_name + parameter_group + for parameter_group in self.redshift_backend.describe_cluster_parameter_groups() + if parameter_group.cluster_parameter_group_name + in self.cluster_parameter_group_name ] @property @@ -211,10 +246,10 @@ class Cluster(TaggableResourceMixin, BaseModel): "MasterUsername": self.master_username, "MasterUserPassword": "****", "ClusterVersion": self.cluster_version, - "VpcSecurityGroups": [{ - "Status": "active", - "VpcSecurityGroupId": group.id - } for group in self.vpc_security_groups], + "VpcSecurityGroups": [ + {"Status": "active", "VpcSecurityGroupId": group.id} + for group in self.vpc_security_groups + ], "ClusterSubnetGroupName": self.cluster_subnet_group_name, "AvailabilityZone": self.availability_zone, "ClusterStatus": self.status, @@ -224,42 +259,47 @@ class Cluster(TaggableResourceMixin, BaseModel): "Encrypted": self.encrypted, "DBName": self.db_name, "PreferredMaintenanceWindow": self.preferred_maintenance_window, - "ClusterParameterGroups": [{ - "ParameterApplyStatus": "in-sync", - "ParameterGroupName": group.cluster_parameter_group_name, - } for group in self.parameter_groups], - "ClusterSecurityGroups": [{ - "Status": "active", - "ClusterSecurityGroupName": group.cluster_security_group_name, - } for group in self.security_groups], + "ClusterParameterGroups": [ + { + "ParameterApplyStatus": "in-sync", + "ParameterGroupName": group.cluster_parameter_group_name, + } + for group in self.parameter_groups + ], + "ClusterSecurityGroups": [ + { + "Status": "active", + "ClusterSecurityGroupName": group.cluster_security_group_name, + } + for group in self.security_groups + ], "Port": self.port, "NodeType": self.node_type, "ClusterIdentifier": self.cluster_identifier, "AllowVersionUpgrade": self.allow_version_upgrade, - "Endpoint": { - "Address": self.endpoint, - "Port": self.port - }, - 'ClusterCreateTime': self.create_time, + "Endpoint": {"Address": self.endpoint, "Port": self.port}, + "ClusterCreateTime": self.create_time, "PendingModifiedValues": [], "Tags": self.tags, "EnhancedVpcRouting": self.enhanced_vpc_routing, - "IamRoles": [{ - "ApplyStatus": "in-sync", - "IamRoleArn": iam_role_arn - } for iam_role_arn in self.iam_roles_arn] + "IamRoles": [ + {"ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn} + for iam_role_arn in self.iam_roles_arn + ], } if self.restored_from_snapshot: - json_response['RestoreStatus'] = { - 'Status': 'completed', - 'CurrentRestoreRateInMegaBytesPerSecond': 123.0, - 'SnapshotSizeInMegaBytes': 123, - 'ProgressInMegaBytes': 123, - 'ElapsedTimeInSeconds': 123, - 'EstimatedTimeToCompletionInSeconds': 123 + json_response["RestoreStatus"] = { + "Status": "completed", + "CurrentRestoreRateInMegaBytesPerSecond": 123.0, + "SnapshotSizeInMegaBytes": 123, + "ProgressInMegaBytes": 123, + "ElapsedTimeInSeconds": 123, + "EstimatedTimeToCompletionInSeconds": 123, } try: - json_response['ClusterSnapshotCopyStatus'] = self.cluster_snapshot_copy_status + json_response[ + "ClusterSnapshotCopyStatus" + ] = self.cluster_snapshot_copy_status except AttributeError: pass return json_response @@ -267,7 +307,7 @@ class Cluster(TaggableResourceMixin, BaseModel): class SnapshotCopyGrant(TaggableResourceMixin, BaseModel): - resource_type = 'snapshotcopygrant' + resource_type = "snapshotcopygrant" def __init__(self, snapshot_copy_grant_name, kms_key_id): self.snapshot_copy_grant_name = snapshot_copy_grant_name @@ -276,16 +316,23 @@ class SnapshotCopyGrant(TaggableResourceMixin, BaseModel): def to_json(self): return { "SnapshotCopyGrantName": self.snapshot_copy_grant_name, - "KmsKeyId": self.kms_key_id + "KmsKeyId": self.kms_key_id, } class SubnetGroup(TaggableResourceMixin, BaseModel): - resource_type = 'subnetgroup' + resource_type = "subnetgroup" - def __init__(self, ec2_backend, cluster_subnet_group_name, description, subnet_ids, - region_name, tags=None): + def __init__( + self, + ec2_backend, + cluster_subnet_group_name, + description, + subnet_ids, + region_name, + tags=None, + ): super(SubnetGroup, self).__init__(region_name, tags) self.ec2_backend = ec2_backend self.cluster_subnet_group_name = cluster_subnet_group_name @@ -295,21 +342,23 @@ class SubnetGroup(TaggableResourceMixin, BaseModel): raise InvalidSubnetError(subnet_ids) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): redshift_backend = redshift_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] subnet_group = redshift_backend.create_cluster_subnet_group( cluster_subnet_group_name=resource_name, description=properties.get("Description"), subnet_ids=properties.get("SubnetIds", []), - region_name=region_name + region_name=region_name, ) return subnet_group @property def subnets(self): - return self.ec2_backend.get_all_subnets(filters={'subnet-id': self.subnet_ids}) + return self.ec2_backend.get_all_subnets(filters={"subnet-id": self.subnet_ids}) @property def vpc_id(self): @@ -325,22 +374,25 @@ class SubnetGroup(TaggableResourceMixin, BaseModel): "Description": self.description, "ClusterSubnetGroupName": self.cluster_subnet_group_name, "SubnetGroupStatus": "Complete", - "Subnets": [{ - "SubnetStatus": "Active", - "SubnetIdentifier": subnet.id, - "SubnetAvailabilityZone": { - "Name": subnet.availability_zone - }, - } for subnet in self.subnets], - "Tags": self.tags + "Subnets": [ + { + "SubnetStatus": "Active", + "SubnetIdentifier": subnet.id, + "SubnetAvailabilityZone": {"Name": subnet.availability_zone}, + } + for subnet in self.subnets + ], + "Tags": self.tags, } class SecurityGroup(TaggableResourceMixin, BaseModel): - resource_type = 'securitygroup' + resource_type = "securitygroup" - def __init__(self, cluster_security_group_name, description, region_name, tags=None): + def __init__( + self, cluster_security_group_name, description, region_name, tags=None + ): super(SecurityGroup, self).__init__(region_name, tags) self.cluster_security_group_name = cluster_security_group_name self.description = description @@ -355,30 +407,39 @@ class SecurityGroup(TaggableResourceMixin, BaseModel): "IPRanges": [], "Description": self.description, "ClusterSecurityGroupName": self.cluster_security_group_name, - "Tags": self.tags + "Tags": self.tags, } class ParameterGroup(TaggableResourceMixin, BaseModel): - resource_type = 'parametergroup' + resource_type = "parametergroup" - def __init__(self, cluster_parameter_group_name, group_family, description, region_name, tags=None): + def __init__( + self, + cluster_parameter_group_name, + group_family, + description, + region_name, + tags=None, + ): super(ParameterGroup, self).__init__(region_name, tags) self.cluster_parameter_group_name = cluster_parameter_group_name self.group_family = group_family self.description = description @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): redshift_backend = redshift_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] parameter_group = redshift_backend.create_cluster_parameter_group( cluster_parameter_group_name=resource_name, description=properties.get("Description"), group_family=properties.get("ParameterGroupFamily"), - region_name=region_name + region_name=region_name, ) return parameter_group @@ -391,78 +452,81 @@ class ParameterGroup(TaggableResourceMixin, BaseModel): "ParameterGroupFamily": self.group_family, "Description": self.description, "ParameterGroupName": self.cluster_parameter_group_name, - "Tags": self.tags + "Tags": self.tags, } class Snapshot(TaggableResourceMixin, BaseModel): - resource_type = 'snapshot' + resource_type = "snapshot" - def __init__(self, cluster, snapshot_identifier, region_name, tags=None, iam_roles_arn=None): + def __init__( + self, cluster, snapshot_identifier, region_name, tags=None, iam_roles_arn=None + ): super(Snapshot, self).__init__(region_name, tags) self.cluster = copy.copy(cluster) self.snapshot_identifier = snapshot_identifier - self.snapshot_type = 'manual' - self.status = 'available' - self.create_time = iso_8601_datetime_with_milliseconds( - datetime.datetime.now()) + self.snapshot_type = "manual" + self.status = "available" + self.create_time = iso_8601_datetime_with_milliseconds(datetime.datetime.now()) self.iam_roles_arn = iam_roles_arn or [] @property def resource_id(self): return "{cluster_id}/{snapshot_id}".format( cluster_id=self.cluster.cluster_identifier, - snapshot_id=self.snapshot_identifier) + snapshot_id=self.snapshot_identifier, + ) def to_json(self): return { - 'SnapshotIdentifier': self.snapshot_identifier, - 'ClusterIdentifier': self.cluster.cluster_identifier, - 'SnapshotCreateTime': self.create_time, - 'Status': self.status, - 'Port': self.cluster.port, - 'AvailabilityZone': self.cluster.availability_zone, - 'MasterUsername': self.cluster.master_username, - 'ClusterVersion': self.cluster.cluster_version, - 'SnapshotType': self.snapshot_type, - 'NodeType': self.cluster.node_type, - 'NumberOfNodes': self.cluster.number_of_nodes, - 'DBName': self.cluster.db_name, - 'Tags': self.tags, - 'EnhancedVpcRouting': self.cluster.enhanced_vpc_routing, - "IamRoles": [{ - "ApplyStatus": "in-sync", - "IamRoleArn": iam_role_arn - } for iam_role_arn in self.iam_roles_arn] + "SnapshotIdentifier": self.snapshot_identifier, + "ClusterIdentifier": self.cluster.cluster_identifier, + "SnapshotCreateTime": self.create_time, + "Status": self.status, + "Port": self.cluster.port, + "AvailabilityZone": self.cluster.availability_zone, + "MasterUsername": self.cluster.master_username, + "ClusterVersion": self.cluster.cluster_version, + "SnapshotType": self.snapshot_type, + "NodeType": self.cluster.node_type, + "NumberOfNodes": self.cluster.number_of_nodes, + "DBName": self.cluster.db_name, + "Tags": self.tags, + "EnhancedVpcRouting": self.cluster.enhanced_vpc_routing, + "IamRoles": [ + {"ApplyStatus": "in-sync", "IamRoleArn": iam_role_arn} + for iam_role_arn in self.iam_roles_arn + ], } class RedshiftBackend(BaseBackend): - def __init__(self, ec2_backend, region_name): self.region = region_name self.clusters = {} self.subnet_groups = {} self.security_groups = { - "Default": SecurityGroup("Default", "Default Redshift Security Group", self.region) + "Default": SecurityGroup( + "Default", "Default Redshift Security Group", self.region + ) } self.parameter_groups = { "default.redshift-1.0": ParameterGroup( "default.redshift-1.0", "redshift-1.0", "Default Redshift parameter group", - self.region + self.region, ) } self.ec2_backend = ec2_backend self.snapshots = OrderedDict() self.RESOURCE_TYPE_MAP = { - 'cluster': self.clusters, - 'parametergroup': self.parameter_groups, - 'securitygroup': self.security_groups, - 'snapshot': self.snapshots, - 'subnetgroup': self.subnet_groups + "cluster": self.clusters, + "parametergroup": self.parameter_groups, + "securitygroup": self.security_groups, + "snapshot": self.snapshots, + "subnetgroup": self.subnet_groups, } self.snapshot_copy_grants = {} @@ -473,19 +537,22 @@ class RedshiftBackend(BaseBackend): self.__init__(ec2_backend, region_name) def enable_snapshot_copy(self, **kwargs): - cluster_identifier = kwargs['cluster_identifier'] + cluster_identifier = kwargs["cluster_identifier"] cluster = self.clusters[cluster_identifier] - if not hasattr(cluster, 'cluster_snapshot_copy_status'): - if cluster.encrypted == 'true' and kwargs['snapshot_copy_grant_name'] is None: + if not hasattr(cluster, "cluster_snapshot_copy_status"): + if ( + cluster.encrypted == "true" + and kwargs["snapshot_copy_grant_name"] is None + ): raise ClientError( - 'InvalidParameterValue', - 'SnapshotCopyGrantName is required for Snapshot Copy ' - 'on KMS encrypted clusters.' + "InvalidParameterValue", + "SnapshotCopyGrantName is required for Snapshot Copy " + "on KMS encrypted clusters.", ) status = { - 'DestinationRegion': kwargs['destination_region'], - 'RetentionPeriod': kwargs['retention_period'], - 'SnapshotCopyGrantName': kwargs['snapshot_copy_grant_name'], + "DestinationRegion": kwargs["destination_region"], + "RetentionPeriod": kwargs["retention_period"], + "SnapshotCopyGrantName": kwargs["snapshot_copy_grant_name"], } cluster.cluster_snapshot_copy_status = status return cluster @@ -493,24 +560,26 @@ class RedshiftBackend(BaseBackend): raise SnapshotCopyAlreadyEnabledFaultError(cluster_identifier) def disable_snapshot_copy(self, **kwargs): - cluster_identifier = kwargs['cluster_identifier'] + cluster_identifier = kwargs["cluster_identifier"] cluster = self.clusters[cluster_identifier] - if hasattr(cluster, 'cluster_snapshot_copy_status'): + if hasattr(cluster, "cluster_snapshot_copy_status"): del cluster.cluster_snapshot_copy_status return cluster else: raise SnapshotCopyAlreadyDisabledFaultError(cluster_identifier) - def modify_snapshot_copy_retention_period(self, cluster_identifier, retention_period): + def modify_snapshot_copy_retention_period( + self, cluster_identifier, retention_period + ): cluster = self.clusters[cluster_identifier] - if hasattr(cluster, 'cluster_snapshot_copy_status'): - cluster.cluster_snapshot_copy_status['RetentionPeriod'] = retention_period + if hasattr(cluster, "cluster_snapshot_copy_status"): + cluster.cluster_snapshot_copy_status["RetentionPeriod"] = retention_period return cluster else: raise SnapshotCopyDisabledFaultError(cluster_identifier) def create_cluster(self, **cluster_kwargs): - cluster_identifier = cluster_kwargs['cluster_identifier'] + cluster_identifier = cluster_kwargs["cluster_identifier"] cluster = Cluster(self, **cluster_kwargs) self.clusters[cluster_identifier] = cluster return cluster @@ -525,9 +594,8 @@ class RedshiftBackend(BaseBackend): return clusters def modify_cluster(self, **cluster_kwargs): - cluster_identifier = cluster_kwargs.pop('cluster_identifier') - new_cluster_identifier = cluster_kwargs.pop( - 'new_cluster_identifier', None) + cluster_identifier = cluster_kwargs.pop("cluster_identifier") + new_cluster_identifier = cluster_kwargs.pop("new_cluster_identifier", None) cluster = self.describe_clusters(cluster_identifier)[0] @@ -538,7 +606,7 @@ class RedshiftBackend(BaseBackend): dic = { "cluster_identifier": cluster_identifier, "skip_final_snapshot": True, - "final_cluster_snapshot_identifier": None + "final_cluster_snapshot_identifier": None, } self.delete_cluster(**dic) cluster.cluster_identifier = new_cluster_identifier @@ -549,30 +617,46 @@ class RedshiftBackend(BaseBackend): def delete_cluster(self, **cluster_kwargs): cluster_identifier = cluster_kwargs.pop("cluster_identifier") cluster_skip_final_snapshot = cluster_kwargs.pop("skip_final_snapshot") - cluster_snapshot_identifer = cluster_kwargs.pop("final_cluster_snapshot_identifier") + cluster_snapshot_identifer = cluster_kwargs.pop( + "final_cluster_snapshot_identifier" + ) if cluster_identifier in self.clusters: - if cluster_skip_final_snapshot is False and cluster_snapshot_identifer is None: + if ( + cluster_skip_final_snapshot is False + and cluster_snapshot_identifer is None + ): raise ClientError( "InvalidParameterValue", - 'FinalSnapshotIdentifier is required for Snapshot copy ' - 'when SkipFinalSnapshot is False' + "FinalSnapshotIdentifier is required for Snapshot copy " + "when SkipFinalSnapshot is False", ) - elif cluster_skip_final_snapshot is False and cluster_snapshot_identifer is not None: # create snapshot + elif ( + cluster_skip_final_snapshot is False + and cluster_snapshot_identifer is not None + ): # create snapshot cluster = self.describe_clusters(cluster_identifier)[0] self.create_cluster_snapshot( cluster_identifier, cluster_snapshot_identifer, cluster.region, - cluster.tags) + cluster.tags, + ) return self.clusters.pop(cluster_identifier) raise ClusterNotFoundError(cluster_identifier) - def create_cluster_subnet_group(self, cluster_subnet_group_name, description, subnet_ids, - region_name, tags=None): + def create_cluster_subnet_group( + self, cluster_subnet_group_name, description, subnet_ids, region_name, tags=None + ): subnet_group = SubnetGroup( - self.ec2_backend, cluster_subnet_group_name, description, subnet_ids, region_name, tags) + self.ec2_backend, + cluster_subnet_group_name, + description, + subnet_ids, + region_name, + tags, + ) self.subnet_groups[cluster_subnet_group_name] = subnet_group return subnet_group @@ -590,9 +674,12 @@ class RedshiftBackend(BaseBackend): return self.subnet_groups.pop(subnet_identifier) raise ClusterSubnetGroupNotFoundError(subnet_identifier) - def create_cluster_security_group(self, cluster_security_group_name, description, region_name, tags=None): + def create_cluster_security_group( + self, cluster_security_group_name, description, region_name, tags=None + ): security_group = SecurityGroup( - cluster_security_group_name, description, region_name, tags) + cluster_security_group_name, description, region_name, tags + ) self.security_groups[cluster_security_group_name] = security_group return security_group @@ -610,10 +697,17 @@ class RedshiftBackend(BaseBackend): return self.security_groups.pop(security_group_identifier) raise ClusterSecurityGroupNotFoundError(security_group_identifier) - def create_cluster_parameter_group(self, cluster_parameter_group_name, - group_family, description, region_name, tags=None): + def create_cluster_parameter_group( + self, + cluster_parameter_group_name, + group_family, + description, + region_name, + tags=None, + ): parameter_group = ParameterGroup( - cluster_parameter_group_name, group_family, description, region_name, tags) + cluster_parameter_group_name, group_family, description, region_name, tags + ) self.parameter_groups[cluster_parameter_group_name] = parameter_group return parameter_group @@ -632,7 +726,9 @@ class RedshiftBackend(BaseBackend): return self.parameter_groups.pop(parameter_group_name) raise ClusterParameterGroupNotFoundError(parameter_group_name) - def create_cluster_snapshot(self, cluster_identifier, snapshot_identifier, region_name, tags): + def create_cluster_snapshot( + self, cluster_identifier, snapshot_identifier, region_name, tags + ): cluster = self.clusters.get(cluster_identifier) if not cluster: raise ClusterNotFoundError(cluster_identifier) @@ -642,7 +738,9 @@ class RedshiftBackend(BaseBackend): self.snapshots[snapshot_identifier] = snapshot return snapshot - def describe_cluster_snapshots(self, cluster_identifier=None, snapshot_identifier=None): + def describe_cluster_snapshots( + self, cluster_identifier=None, snapshot_identifier=None + ): if cluster_identifier: cluster_snapshots = [] for snapshot in self.snapshots.values(): @@ -664,18 +762,22 @@ class RedshiftBackend(BaseBackend): raise ClusterSnapshotNotFoundError(snapshot_identifier) deleted_snapshot = self.snapshots.pop(snapshot_identifier) - deleted_snapshot.status = 'deleted' + deleted_snapshot.status = "deleted" return deleted_snapshot def restore_from_cluster_snapshot(self, **kwargs): - snapshot_identifier = kwargs.pop('snapshot_identifier') - snapshot = self.describe_cluster_snapshots(snapshot_identifier=snapshot_identifier)[0] + snapshot_identifier = kwargs.pop("snapshot_identifier") + snapshot = self.describe_cluster_snapshots( + snapshot_identifier=snapshot_identifier + )[0] create_kwargs = { "node_type": snapshot.cluster.node_type, "master_username": snapshot.cluster.master_username, "master_user_password": snapshot.cluster.master_user_password, "db_name": snapshot.cluster.db_name, - "cluster_type": 'multi-node' if snapshot.cluster.number_of_nodes > 1 else 'single-node', + "cluster_type": "multi-node" + if snapshot.cluster.number_of_nodes > 1 + else "single-node", "availability_zone": snapshot.cluster.availability_zone, "port": snapshot.cluster.port, "cluster_version": snapshot.cluster.cluster_version, @@ -683,29 +785,31 @@ class RedshiftBackend(BaseBackend): "encrypted": snapshot.cluster.encrypted, "tags": snapshot.cluster.tags, "restored_from_snapshot": True, - "enhanced_vpc_routing": snapshot.cluster.enhanced_vpc_routing + "enhanced_vpc_routing": snapshot.cluster.enhanced_vpc_routing, } create_kwargs.update(kwargs) return self.create_cluster(**create_kwargs) def create_snapshot_copy_grant(self, **kwargs): - snapshot_copy_grant_name = kwargs['snapshot_copy_grant_name'] - kms_key_id = kwargs['kms_key_id'] + snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] + kms_key_id = kwargs["kms_key_id"] if snapshot_copy_grant_name not in self.snapshot_copy_grants: - snapshot_copy_grant = SnapshotCopyGrant(snapshot_copy_grant_name, kms_key_id) + snapshot_copy_grant = SnapshotCopyGrant( + snapshot_copy_grant_name, kms_key_id + ) self.snapshot_copy_grants[snapshot_copy_grant_name] = snapshot_copy_grant return snapshot_copy_grant raise SnapshotCopyGrantAlreadyExistsFaultError(snapshot_copy_grant_name) def delete_snapshot_copy_grant(self, **kwargs): - snapshot_copy_grant_name = kwargs['snapshot_copy_grant_name'] + snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] if snapshot_copy_grant_name in self.snapshot_copy_grants: return self.snapshot_copy_grants.pop(snapshot_copy_grant_name) raise SnapshotCopyGrantNotFoundFaultError(snapshot_copy_grant_name) def describe_snapshot_copy_grants(self, **kwargs): copy_grants = self.snapshot_copy_grants.values() - snapshot_copy_grant_name = kwargs['snapshot_copy_grant_name'] + snapshot_copy_grant_name = kwargs["snapshot_copy_grant_name"] if snapshot_copy_grant_name: if snapshot_copy_grant_name in self.snapshot_copy_grants: return [self.snapshot_copy_grants[snapshot_copy_grant_name]] @@ -715,10 +819,10 @@ class RedshiftBackend(BaseBackend): def _get_resource_from_arn(self, arn): try: - arn_breakdown = arn.split(':') + arn_breakdown = arn.split(":") resource_type = arn_breakdown[5] - if resource_type == 'snapshot': - resource_id = arn_breakdown[6].split('/')[1] + if resource_type == "snapshot": + resource_id = arn_breakdown[6].split("/")[1] else: resource_id = arn_breakdown[6] except IndexError: @@ -728,7 +832,8 @@ class RedshiftBackend(BaseBackend): message = ( "Tagging is not supported for this type of resource: '{0}' " "(the ARN is potentially malformed, please check the ARN " - "documentation for more information)".format(resource_type)) + "documentation for more information)".format(resource_type) + ) raise ResourceNotFoundFaultError(message=message) try: resource = resources[resource_id] @@ -743,12 +848,9 @@ class RedshiftBackend(BaseBackend): for resource in resources: for tag in resource.tags: data = { - 'ResourceName': resource.arn, - 'ResourceType': resource.resource_type, - 'Tag': { - 'Key': tag['Key'], - 'Value': tag['Value'] - } + "ResourceName": resource.arn, + "ResourceType": resource.resource_type, + "Tag": {"Key": tag["Key"], "Value": tag["Value"]}, } tagged_resources.append(data) return tagged_resources @@ -773,7 +875,8 @@ class RedshiftBackend(BaseBackend): "You cannot filter a list of resources using an Amazon " "Resource Name (ARN) and a resource type together in the " "same request. Retry the request using either an ARN or " - "a resource type, but not both.") + "a resource type, but not both." + ) if resource_type: return self._describe_tags_for_resource_type(resource_type.lower()) if resource_name: @@ -795,4 +898,6 @@ class RedshiftBackend(BaseBackend): redshift_backends = {} for region in boto.redshift.regions(): - redshift_backends[region.name] = RedshiftBackend(ec2_backends[region.name], region.name) + redshift_backends[region.name] = RedshiftBackend( + ec2_backends[region.name], region.name + ) diff --git a/moto/redshift/responses.py b/moto/redshift/responses.py index 7ac73d470..a4094949f 100644 --- a/moto/redshift/responses.py +++ b/moto/redshift/responses.py @@ -13,9 +13,10 @@ from .models import redshift_backends def convert_json_error_to_xml(json_error): error = json.loads(json_error) - code = error['Error']['Code'] - message = error['Error']['Message'] - template = Template(""" + code = error["Error"]["Code"] + message = error["Error"]["Message"] + template = Template( + """ {{ code }} @@ -23,7 +24,8 @@ def convert_json_error_to_xml(json_error): Sender 6876f774-7273-11e4-85dc-39e55ca848d1 - """) + """ + ) return template.render(code=code, message=message) @@ -40,13 +42,12 @@ def itemize(data): ret[key] = itemize(data[key]) return ret elif isinstance(data, list): - return {'item': [itemize(value) for value in data]} + return {"item": [itemize(value) for value in data]} else: return data class RedshiftResponse(BaseResponse): - @property def redshift_backend(self): return redshift_backends[self.region] @@ -56,8 +57,8 @@ class RedshiftResponse(BaseResponse): return json.dumps(response) else: xml = xmltodict.unparse(itemize(response), full_document=False) - if hasattr(xml, 'decode'): - xml = xml.decode('utf-8') + if hasattr(xml, "decode"): + xml = xml.decode("utf-8") return xml def call_action(self): @@ -69,11 +70,12 @@ class RedshiftResponse(BaseResponse): def unpack_complex_list_params(self, label, names): unpacked_list = list() count = 1 - while self._get_param('{0}.{1}.{2}'.format(label, count, names[0])): + while self._get_param("{0}.{1}.{2}".format(label, count, names[0])): param = dict() for i in range(len(names)): param[names[i]] = self._get_param( - '{0}.{1}.{2}'.format(label, count, names[i])) + "{0}.{1}.{2}".format(label, count, names[i]) + ) unpacked_list.append(param) count += 1 return unpacked_list @@ -81,148 +83,168 @@ class RedshiftResponse(BaseResponse): def unpack_list_params(self, label): unpacked_list = list() count = 1 - while self._get_param('{0}.{1}'.format(label, count)): - unpacked_list.append(self._get_param( - '{0}.{1}'.format(label, count))) + while self._get_param("{0}.{1}".format(label, count)): + unpacked_list.append(self._get_param("{0}.{1}".format(label, count))) count += 1 return unpacked_list def _get_cluster_security_groups(self): - cluster_security_groups = self._get_multi_param('ClusterSecurityGroups.member') + cluster_security_groups = self._get_multi_param("ClusterSecurityGroups.member") if not cluster_security_groups: - cluster_security_groups = self._get_multi_param('ClusterSecurityGroups.ClusterSecurityGroupName') + cluster_security_groups = self._get_multi_param( + "ClusterSecurityGroups.ClusterSecurityGroupName" + ) return cluster_security_groups def _get_vpc_security_group_ids(self): - vpc_security_group_ids = self._get_multi_param('VpcSecurityGroupIds.member') + vpc_security_group_ids = self._get_multi_param("VpcSecurityGroupIds.member") if not vpc_security_group_ids: - vpc_security_group_ids = self._get_multi_param('VpcSecurityGroupIds.VpcSecurityGroupId') + vpc_security_group_ids = self._get_multi_param( + "VpcSecurityGroupIds.VpcSecurityGroupId" + ) return vpc_security_group_ids def _get_iam_roles(self): - iam_roles = self._get_multi_param('IamRoles.member') + iam_roles = self._get_multi_param("IamRoles.member") if not iam_roles: - iam_roles = self._get_multi_param('IamRoles.IamRoleArn') + iam_roles = self._get_multi_param("IamRoles.IamRoleArn") return iam_roles def _get_subnet_ids(self): - subnet_ids = self._get_multi_param('SubnetIds.member') + subnet_ids = self._get_multi_param("SubnetIds.member") if not subnet_ids: - subnet_ids = self._get_multi_param('SubnetIds.SubnetIdentifier') + subnet_ids = self._get_multi_param("SubnetIds.SubnetIdentifier") return subnet_ids def create_cluster(self): cluster_kwargs = { - "cluster_identifier": self._get_param('ClusterIdentifier'), - "node_type": self._get_param('NodeType'), - "master_username": self._get_param('MasterUsername'), - "master_user_password": self._get_param('MasterUserPassword'), - "db_name": self._get_param('DBName'), - "cluster_type": self._get_param('ClusterType'), + "cluster_identifier": self._get_param("ClusterIdentifier"), + "node_type": self._get_param("NodeType"), + "master_username": self._get_param("MasterUsername"), + "master_user_password": self._get_param("MasterUserPassword"), + "db_name": self._get_param("DBName"), + "cluster_type": self._get_param("ClusterType"), "cluster_security_groups": self._get_cluster_security_groups(), "vpc_security_group_ids": self._get_vpc_security_group_ids(), - "cluster_subnet_group_name": self._get_param('ClusterSubnetGroupName'), - "availability_zone": self._get_param('AvailabilityZone'), - "preferred_maintenance_window": self._get_param('PreferredMaintenanceWindow'), - "cluster_parameter_group_name": self._get_param('ClusterParameterGroupName'), - "automated_snapshot_retention_period": self._get_int_param('AutomatedSnapshotRetentionPeriod'), - "port": self._get_int_param('Port'), - "cluster_version": self._get_param('ClusterVersion'), - "allow_version_upgrade": self._get_bool_param('AllowVersionUpgrade'), - "number_of_nodes": self._get_int_param('NumberOfNodes'), + "cluster_subnet_group_name": self._get_param("ClusterSubnetGroupName"), + "availability_zone": self._get_param("AvailabilityZone"), + "preferred_maintenance_window": self._get_param( + "PreferredMaintenanceWindow" + ), + "cluster_parameter_group_name": self._get_param( + "ClusterParameterGroupName" + ), + "automated_snapshot_retention_period": self._get_int_param( + "AutomatedSnapshotRetentionPeriod" + ), + "port": self._get_int_param("Port"), + "cluster_version": self._get_param("ClusterVersion"), + "allow_version_upgrade": self._get_bool_param("AllowVersionUpgrade"), + "number_of_nodes": self._get_int_param("NumberOfNodes"), "publicly_accessible": self._get_param("PubliclyAccessible"), "encrypted": self._get_param("Encrypted"), "region_name": self.region, - "tags": self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')), + "tags": self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")), "iam_roles_arn": self._get_iam_roles(), - "enhanced_vpc_routing": self._get_param('EnhancedVpcRouting'), + "enhanced_vpc_routing": self._get_param("EnhancedVpcRouting"), } cluster = self.redshift_backend.create_cluster(**cluster_kwargs).to_json() - cluster['ClusterStatus'] = 'creating' - return self.get_response({ - "CreateClusterResponse": { - "CreateClusterResult": { - "Cluster": cluster, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + cluster["ClusterStatus"] = "creating" + return self.get_response( + { + "CreateClusterResponse": { + "CreateClusterResult": {"Cluster": cluster}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def restore_from_cluster_snapshot(self): - enhanced_vpc_routing = self._get_bool_param('EnhancedVpcRouting') + enhanced_vpc_routing = self._get_bool_param("EnhancedVpcRouting") restore_kwargs = { - "snapshot_identifier": self._get_param('SnapshotIdentifier'), - "cluster_identifier": self._get_param('ClusterIdentifier'), - "port": self._get_int_param('Port'), - "availability_zone": self._get_param('AvailabilityZone'), - "allow_version_upgrade": self._get_bool_param( - 'AllowVersionUpgrade'), - "cluster_subnet_group_name": self._get_param( - 'ClusterSubnetGroupName'), + "snapshot_identifier": self._get_param("SnapshotIdentifier"), + "cluster_identifier": self._get_param("ClusterIdentifier"), + "port": self._get_int_param("Port"), + "availability_zone": self._get_param("AvailabilityZone"), + "allow_version_upgrade": self._get_bool_param("AllowVersionUpgrade"), + "cluster_subnet_group_name": self._get_param("ClusterSubnetGroupName"), "publicly_accessible": self._get_param("PubliclyAccessible"), "cluster_parameter_group_name": self._get_param( - 'ClusterParameterGroupName'), + "ClusterParameterGroupName" + ), "cluster_security_groups": self._get_cluster_security_groups(), "vpc_security_group_ids": self._get_vpc_security_group_ids(), "preferred_maintenance_window": self._get_param( - 'PreferredMaintenanceWindow'), + "PreferredMaintenanceWindow" + ), "automated_snapshot_retention_period": self._get_int_param( - 'AutomatedSnapshotRetentionPeriod'), + "AutomatedSnapshotRetentionPeriod" + ), "region_name": self.region, "iam_roles_arn": self._get_iam_roles(), } if enhanced_vpc_routing is not None: - restore_kwargs['enhanced_vpc_routing'] = enhanced_vpc_routing - cluster = self.redshift_backend.restore_from_cluster_snapshot(**restore_kwargs).to_json() - cluster['ClusterStatus'] = 'creating' - return self.get_response({ - "RestoreFromClusterSnapshotResponse": { - "RestoreFromClusterSnapshotResult": { - "Cluster": cluster, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + restore_kwargs["enhanced_vpc_routing"] = enhanced_vpc_routing + cluster = self.redshift_backend.restore_from_cluster_snapshot( + **restore_kwargs + ).to_json() + cluster["ClusterStatus"] = "creating" + return self.get_response( + { + "RestoreFromClusterSnapshotResponse": { + "RestoreFromClusterSnapshotResult": {"Cluster": cluster}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def describe_clusters(self): cluster_identifier = self._get_param("ClusterIdentifier") clusters = self.redshift_backend.describe_clusters(cluster_identifier) - return self.get_response({ - "DescribeClustersResponse": { - "DescribeClustersResult": { - "Clusters": [cluster.to_json() for cluster in clusters] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DescribeClustersResponse": { + "DescribeClustersResult": { + "Clusters": [cluster.to_json() for cluster in clusters] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def modify_cluster(self): request_kwargs = { - "cluster_identifier": self._get_param('ClusterIdentifier'), - "new_cluster_identifier": self._get_param('NewClusterIdentifier'), - "node_type": self._get_param('NodeType'), - "master_user_password": self._get_param('MasterUserPassword'), - "cluster_type": self._get_param('ClusterType'), + "cluster_identifier": self._get_param("ClusterIdentifier"), + "new_cluster_identifier": self._get_param("NewClusterIdentifier"), + "node_type": self._get_param("NodeType"), + "master_user_password": self._get_param("MasterUserPassword"), + "cluster_type": self._get_param("ClusterType"), "cluster_security_groups": self._get_cluster_security_groups(), "vpc_security_group_ids": self._get_vpc_security_group_ids(), - "cluster_subnet_group_name": self._get_param('ClusterSubnetGroupName'), - "preferred_maintenance_window": self._get_param('PreferredMaintenanceWindow'), - "cluster_parameter_group_name": self._get_param('ClusterParameterGroupName'), - "automated_snapshot_retention_period": self._get_int_param('AutomatedSnapshotRetentionPeriod'), - "cluster_version": self._get_param('ClusterVersion'), - "allow_version_upgrade": self._get_bool_param('AllowVersionUpgrade'), - "number_of_nodes": self._get_int_param('NumberOfNodes'), + "cluster_subnet_group_name": self._get_param("ClusterSubnetGroupName"), + "preferred_maintenance_window": self._get_param( + "PreferredMaintenanceWindow" + ), + "cluster_parameter_group_name": self._get_param( + "ClusterParameterGroupName" + ), + "automated_snapshot_retention_period": self._get_int_param( + "AutomatedSnapshotRetentionPeriod" + ), + "cluster_version": self._get_param("ClusterVersion"), + "allow_version_upgrade": self._get_bool_param("AllowVersionUpgrade"), + "number_of_nodes": self._get_int_param("NumberOfNodes"), "publicly_accessible": self._get_param("PubliclyAccessible"), "encrypted": self._get_param("Encrypted"), "iam_roles_arn": self._get_iam_roles(), - "enhanced_vpc_routing": self._get_param("EnhancedVpcRouting") + "enhanced_vpc_routing": self._get_param("EnhancedVpcRouting"), } cluster_kwargs = {} # We only want parameters that were actually passed in, otherwise @@ -233,394 +255,442 @@ class RedshiftResponse(BaseResponse): cluster = self.redshift_backend.modify_cluster(**cluster_kwargs) - return self.get_response({ - "ModifyClusterResponse": { - "ModifyClusterResult": { - "Cluster": cluster.to_json(), - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "ModifyClusterResponse": { + "ModifyClusterResult": {"Cluster": cluster.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_cluster(self): request_kwargs = { "cluster_identifier": self._get_param("ClusterIdentifier"), - "final_cluster_snapshot_identifier": self._get_param("FinalClusterSnapshotIdentifier"), - "skip_final_snapshot": self._get_bool_param("SkipFinalClusterSnapshot") + "final_cluster_snapshot_identifier": self._get_param( + "FinalClusterSnapshotIdentifier" + ), + "skip_final_snapshot": self._get_bool_param("SkipFinalClusterSnapshot"), } cluster = self.redshift_backend.delete_cluster(**request_kwargs) - return self.get_response({ - "DeleteClusterResponse": { - "DeleteClusterResult": { - "Cluster": cluster.to_json() - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteClusterResponse": { + "DeleteClusterResult": {"Cluster": cluster.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def create_cluster_subnet_group(self): - cluster_subnet_group_name = self._get_param('ClusterSubnetGroupName') - description = self._get_param('Description') + cluster_subnet_group_name = self._get_param("ClusterSubnetGroupName") + description = self._get_param("Description") subnet_ids = self._get_subnet_ids() - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) subnet_group = self.redshift_backend.create_cluster_subnet_group( cluster_subnet_group_name=cluster_subnet_group_name, description=description, subnet_ids=subnet_ids, region_name=self.region, - tags=tags + tags=tags, ) - return self.get_response({ - "CreateClusterSubnetGroupResponse": { - "CreateClusterSubnetGroupResult": { - "ClusterSubnetGroup": subnet_group.to_json(), - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "CreateClusterSubnetGroupResponse": { + "CreateClusterSubnetGroupResult": { + "ClusterSubnetGroup": subnet_group.to_json() + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def describe_cluster_subnet_groups(self): subnet_identifier = self._get_param("ClusterSubnetGroupName") subnet_groups = self.redshift_backend.describe_cluster_subnet_groups( - subnet_identifier) + subnet_identifier + ) - return self.get_response({ - "DescribeClusterSubnetGroupsResponse": { - "DescribeClusterSubnetGroupsResult": { - "ClusterSubnetGroups": [subnet_group.to_json() for subnet_group in subnet_groups] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DescribeClusterSubnetGroupsResponse": { + "DescribeClusterSubnetGroupsResult": { + "ClusterSubnetGroups": [ + subnet_group.to_json() for subnet_group in subnet_groups + ] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_cluster_subnet_group(self): subnet_identifier = self._get_param("ClusterSubnetGroupName") self.redshift_backend.delete_cluster_subnet_group(subnet_identifier) - return self.get_response({ - "DeleteClusterSubnetGroupResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteClusterSubnetGroupResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) + ) def create_cluster_security_group(self): - cluster_security_group_name = self._get_param( - 'ClusterSecurityGroupName') - description = self._get_param('Description') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + cluster_security_group_name = self._get_param("ClusterSecurityGroupName") + description = self._get_param("Description") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) security_group = self.redshift_backend.create_cluster_security_group( cluster_security_group_name=cluster_security_group_name, description=description, region_name=self.region, - tags=tags + tags=tags, ) - return self.get_response({ - "CreateClusterSecurityGroupResponse": { - "CreateClusterSecurityGroupResult": { - "ClusterSecurityGroup": security_group.to_json(), - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "CreateClusterSecurityGroupResponse": { + "CreateClusterSecurityGroupResult": { + "ClusterSecurityGroup": security_group.to_json() + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def describe_cluster_security_groups(self): - cluster_security_group_name = self._get_param( - "ClusterSecurityGroupName") + cluster_security_group_name = self._get_param("ClusterSecurityGroupName") security_groups = self.redshift_backend.describe_cluster_security_groups( - cluster_security_group_name) + cluster_security_group_name + ) - return self.get_response({ - "DescribeClusterSecurityGroupsResponse": { - "DescribeClusterSecurityGroupsResult": { - "ClusterSecurityGroups": [security_group.to_json() for security_group in security_groups] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DescribeClusterSecurityGroupsResponse": { + "DescribeClusterSecurityGroupsResult": { + "ClusterSecurityGroups": [ + security_group.to_json() + for security_group in security_groups + ] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_cluster_security_group(self): security_group_identifier = self._get_param("ClusterSecurityGroupName") - self.redshift_backend.delete_cluster_security_group( - security_group_identifier) + self.redshift_backend.delete_cluster_security_group(security_group_identifier) - return self.get_response({ - "DeleteClusterSecurityGroupResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteClusterSecurityGroupResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) - - def create_cluster_parameter_group(self): - cluster_parameter_group_name = self._get_param('ParameterGroupName') - group_family = self._get_param('ParameterGroupFamily') - description = self._get_param('Description') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) - - parameter_group = self.redshift_backend.create_cluster_parameter_group( - cluster_parameter_group_name, - group_family, - description, - self.region, - tags ) - return self.get_response({ - "CreateClusterParameterGroupResponse": { - "CreateClusterParameterGroupResult": { - "ClusterParameterGroup": parameter_group.to_json(), - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + def create_cluster_parameter_group(self): + cluster_parameter_group_name = self._get_param("ParameterGroupName") + group_family = self._get_param("ParameterGroupFamily") + description = self._get_param("Description") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) + + parameter_group = self.redshift_backend.create_cluster_parameter_group( + cluster_parameter_group_name, group_family, description, self.region, tags + ) + + return self.get_response( + { + "CreateClusterParameterGroupResponse": { + "CreateClusterParameterGroupResult": { + "ClusterParameterGroup": parameter_group.to_json() + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def describe_cluster_parameter_groups(self): cluster_parameter_group_name = self._get_param("ParameterGroupName") parameter_groups = self.redshift_backend.describe_cluster_parameter_groups( - cluster_parameter_group_name) + cluster_parameter_group_name + ) - return self.get_response({ - "DescribeClusterParameterGroupsResponse": { - "DescribeClusterParameterGroupsResult": { - "ParameterGroups": [parameter_group.to_json() for parameter_group in parameter_groups] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DescribeClusterParameterGroupsResponse": { + "DescribeClusterParameterGroupsResult": { + "ParameterGroups": [ + parameter_group.to_json() + for parameter_group in parameter_groups + ] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_cluster_parameter_group(self): cluster_parameter_group_name = self._get_param("ParameterGroupName") self.redshift_backend.delete_cluster_parameter_group( - cluster_parameter_group_name) + cluster_parameter_group_name + ) - return self.get_response({ - "DeleteClusterParameterGroupResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteClusterParameterGroupResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) + ) def create_cluster_snapshot(self): - cluster_identifier = self._get_param('ClusterIdentifier') - snapshot_identifier = self._get_param('SnapshotIdentifier') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + cluster_identifier = self._get_param("ClusterIdentifier") + snapshot_identifier = self._get_param("SnapshotIdentifier") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) - snapshot = self.redshift_backend.create_cluster_snapshot(cluster_identifier, - snapshot_identifier, - self.region, - tags) - return self.get_response({ - 'CreateClusterSnapshotResponse': { - "CreateClusterSnapshotResult": { - "Snapshot": snapshot.to_json(), - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + snapshot = self.redshift_backend.create_cluster_snapshot( + cluster_identifier, snapshot_identifier, self.region, tags + ) + return self.get_response( + { + "CreateClusterSnapshotResponse": { + "CreateClusterSnapshotResult": {"Snapshot": snapshot.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def describe_cluster_snapshots(self): - cluster_identifier = self._get_param('ClusterIdentifier') - snapshot_identifier = self._get_param('SnapshotIdentifier') - snapshots = self.redshift_backend.describe_cluster_snapshots(cluster_identifier, - snapshot_identifier) - return self.get_response({ - "DescribeClusterSnapshotsResponse": { - "DescribeClusterSnapshotsResult": { - "Snapshots": [snapshot.to_json() for snapshot in snapshots] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + cluster_identifier = self._get_param("ClusterIdentifier") + snapshot_identifier = self._get_param("SnapshotIdentifier") + snapshots = self.redshift_backend.describe_cluster_snapshots( + cluster_identifier, snapshot_identifier + ) + return self.get_response( + { + "DescribeClusterSnapshotsResponse": { + "DescribeClusterSnapshotsResult": { + "Snapshots": [snapshot.to_json() for snapshot in snapshots] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_cluster_snapshot(self): - snapshot_identifier = self._get_param('SnapshotIdentifier') + snapshot_identifier = self._get_param("SnapshotIdentifier") snapshot = self.redshift_backend.delete_cluster_snapshot(snapshot_identifier) - return self.get_response({ - "DeleteClusterSnapshotResponse": { - "DeleteClusterSnapshotResult": { - "Snapshot": snapshot.to_json() - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteClusterSnapshotResponse": { + "DeleteClusterSnapshotResult": {"Snapshot": snapshot.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def create_snapshot_copy_grant(self): copy_grant_kwargs = { - 'snapshot_copy_grant_name': self._get_param('SnapshotCopyGrantName'), - 'kms_key_id': self._get_param('KmsKeyId'), - 'region_name': self._get_param('Region'), + "snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName"), + "kms_key_id": self._get_param("KmsKeyId"), + "region_name": self._get_param("Region"), } - copy_grant = self.redshift_backend.create_snapshot_copy_grant(**copy_grant_kwargs) - return self.get_response({ - "CreateSnapshotCopyGrantResponse": { - "CreateSnapshotCopyGrantResult": { - "SnapshotCopyGrant": copy_grant.to_json() - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + copy_grant = self.redshift_backend.create_snapshot_copy_grant( + **copy_grant_kwargs + ) + return self.get_response( + { + "CreateSnapshotCopyGrantResponse": { + "CreateSnapshotCopyGrantResult": { + "SnapshotCopyGrant": copy_grant.to_json() + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_snapshot_copy_grant(self): copy_grant_kwargs = { - 'snapshot_copy_grant_name': self._get_param('SnapshotCopyGrantName'), + "snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName") } self.redshift_backend.delete_snapshot_copy_grant(**copy_grant_kwargs) - return self.get_response({ - "DeleteSnapshotCopyGrantResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteSnapshotCopyGrantResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) + ) def describe_snapshot_copy_grants(self): copy_grant_kwargs = { - 'snapshot_copy_grant_name': self._get_param('SnapshotCopyGrantName'), + "snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName") } - copy_grants = self.redshift_backend.describe_snapshot_copy_grants(**copy_grant_kwargs) - return self.get_response({ - "DescribeSnapshotCopyGrantsResponse": { - "DescribeSnapshotCopyGrantsResult": { - "SnapshotCopyGrants": [copy_grant.to_json() for copy_grant in copy_grants] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + copy_grants = self.redshift_backend.describe_snapshot_copy_grants( + **copy_grant_kwargs + ) + return self.get_response( + { + "DescribeSnapshotCopyGrantsResponse": { + "DescribeSnapshotCopyGrantsResult": { + "SnapshotCopyGrants": [ + copy_grant.to_json() for copy_grant in copy_grants + ] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def create_tags(self): - resource_name = self._get_param('ResourceName') - tags = self.unpack_complex_list_params('Tags.Tag', ('Key', 'Value')) + resource_name = self._get_param("ResourceName") + tags = self.unpack_complex_list_params("Tags.Tag", ("Key", "Value")) self.redshift_backend.create_tags(resource_name, tags) - return self.get_response({ - "CreateTagsResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "CreateTagsResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) + ) def describe_tags(self): - resource_name = self._get_param('ResourceName') - resource_type = self._get_param('ResourceType') + resource_name = self._get_param("ResourceName") + resource_type = self._get_param("ResourceType") - tagged_resources = self.redshift_backend.describe_tags(resource_name, - resource_type) - return self.get_response({ - "DescribeTagsResponse": { - "DescribeTagsResult": { - "TaggedResources": tagged_resources - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + tagged_resources = self.redshift_backend.describe_tags( + resource_name, resource_type + ) + return self.get_response( + { + "DescribeTagsResponse": { + "DescribeTagsResult": {"TaggedResources": tagged_resources}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def delete_tags(self): - resource_name = self._get_param('ResourceName') - tag_keys = self.unpack_list_params('TagKeys.TagKey') + resource_name = self._get_param("ResourceName") + tag_keys = self.unpack_list_params("TagKeys.TagKey") self.redshift_backend.delete_tags(resource_name, tag_keys) - return self.get_response({ - "DeleteTagsResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DeleteTagsResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + } } } - }) + ) def enable_snapshot_copy(self): snapshot_copy_kwargs = { - 'cluster_identifier': self._get_param('ClusterIdentifier'), - 'destination_region': self._get_param('DestinationRegion'), - 'retention_period': self._get_param('RetentionPeriod', 7), - 'snapshot_copy_grant_name': self._get_param('SnapshotCopyGrantName'), + "cluster_identifier": self._get_param("ClusterIdentifier"), + "destination_region": self._get_param("DestinationRegion"), + "retention_period": self._get_param("RetentionPeriod", 7), + "snapshot_copy_grant_name": self._get_param("SnapshotCopyGrantName"), } cluster = self.redshift_backend.enable_snapshot_copy(**snapshot_copy_kwargs) - return self.get_response({ - "EnableSnapshotCopyResponse": { - "EnableSnapshotCopyResult": { - "Cluster": cluster.to_json() - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "EnableSnapshotCopyResponse": { + "EnableSnapshotCopyResult": {"Cluster": cluster.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def disable_snapshot_copy(self): snapshot_copy_kwargs = { - 'cluster_identifier': self._get_param('ClusterIdentifier'), + "cluster_identifier": self._get_param("ClusterIdentifier") } cluster = self.redshift_backend.disable_snapshot_copy(**snapshot_copy_kwargs) - return self.get_response({ - "DisableSnapshotCopyResponse": { - "DisableSnapshotCopyResult": { - "Cluster": cluster.to_json() - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "DisableSnapshotCopyResponse": { + "DisableSnapshotCopyResult": {"Cluster": cluster.to_json()}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) def modify_snapshot_copy_retention_period(self): snapshot_copy_kwargs = { - 'cluster_identifier': self._get_param('ClusterIdentifier'), - 'retention_period': self._get_param('RetentionPeriod'), + "cluster_identifier": self._get_param("ClusterIdentifier"), + "retention_period": self._get_param("RetentionPeriod"), } - cluster = self.redshift_backend.modify_snapshot_copy_retention_period(**snapshot_copy_kwargs) + cluster = self.redshift_backend.modify_snapshot_copy_retention_period( + **snapshot_copy_kwargs + ) - return self.get_response({ - "ModifySnapshotCopyRetentionPeriodResponse": { - "ModifySnapshotCopyRetentionPeriodResult": { - "Clusters": [cluster.to_json()] - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return self.get_response( + { + "ModifySnapshotCopyRetentionPeriodResponse": { + "ModifySnapshotCopyRetentionPeriodResult": { + "Clusters": [cluster.to_json()] + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) diff --git a/moto/redshift/urls.py b/moto/redshift/urls.py index ebef59e86..8494669ee 100644 --- a/moto/redshift/urls.py +++ b/moto/redshift/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import RedshiftResponse -url_bases = [ - "https?://redshift.(.+).amazonaws.com", -] +url_bases = ["https?://redshift.(.+).amazonaws.com"] -url_paths = { - '{0}/$': RedshiftResponse.dispatch, -} +url_paths = {"{0}/$": RedshiftResponse.dispatch} diff --git a/moto/resourcegroups/__init__.py b/moto/resourcegroups/__init__.py index 74b0eb598..13ff17307 100644 --- a/moto/resourcegroups/__init__.py +++ b/moto/resourcegroups/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import resourcegroups_backends from ..core.models import base_decorator -resourcegroups_backend = resourcegroups_backends['us-east-1'] +resourcegroups_backend = resourcegroups_backends["us-east-1"] mock_resourcegroups = base_decorator(resourcegroups_backends) diff --git a/moto/resourcegroups/exceptions.py b/moto/resourcegroups/exceptions.py index a8e542979..6c0f470be 100644 --- a/moto/resourcegroups/exceptions.py +++ b/moto/resourcegroups/exceptions.py @@ -9,5 +9,6 @@ class BadRequestException(HTTPException): def __init__(self, message, **kwargs): super(BadRequestException, self).__init__( - description=json.dumps({"Message": message, "Code": "BadRequestException"}), **kwargs + description=json.dumps({"Message": message, "Code": "BadRequestException"}), + **kwargs ) diff --git a/moto/resourcegroups/models.py b/moto/resourcegroups/models.py index 6734bd48a..5dd54d197 100644 --- a/moto/resourcegroups/models.py +++ b/moto/resourcegroups/models.py @@ -23,14 +23,14 @@ class FakeResourceGroup(BaseModel): if self._validate_tags(value=tags): self._tags = tags self._raise_errors() - self.arn = "arn:aws:resource-groups:us-west-1:123456789012:{name}".format(name=name) + self.arn = "arn:aws:resource-groups:us-west-1:123456789012:{name}".format( + name=name + ) @staticmethod def _format_error(key, value, constraint): return "Value '{value}' at '{key}' failed to satisfy constraint: {constraint}".format( - constraint=constraint, - key=key, - value=value, + constraint=constraint, key=key, value=value ) def _raise_errors(self): @@ -38,24 +38,30 @@ class FakeResourceGroup(BaseModel): errors_len = len(self.errors) plural = "s" if len(self.errors) > 1 else "" errors = "; ".join(self.errors) - raise BadRequestException("{errors_len} validation error{plural} detected: {errors}".format( - errors_len=errors_len, plural=plural, errors=errors, - )) + raise BadRequestException( + "{errors_len} validation error{plural} detected: {errors}".format( + errors_len=errors_len, plural=plural, errors=errors + ) + ) def _validate_description(self, value): errors = [] if len(value) > 511: - errors.append(self._format_error( - key="description", - value=value, - constraint="Member must have length less than or equal to 512", - )) + errors.append( + self._format_error( + key="description", + value=value, + constraint="Member must have length less than or equal to 512", + ) + ) if not re.match(r"^[\sa-zA-Z0-9_.-]*$", value): - errors.append(self._format_error( - key="name", - value=value, - constraint=r"Member must satisfy regular expression pattern: [\sa-zA-Z0-9_\.-]*", - )) + errors.append( + self._format_error( + key="name", + value=value, + constraint=r"Member must satisfy regular expression pattern: [\sa-zA-Z0-9_\.-]*", + ) + ) if errors: self.errors += errors return False @@ -64,18 +70,22 @@ class FakeResourceGroup(BaseModel): def _validate_name(self, value): errors = [] if len(value) > 128: - errors.append(self._format_error( - key="name", - value=value, - constraint="Member must have length less than or equal to 128", - )) + errors.append( + self._format_error( + key="name", + value=value, + constraint="Member must have length less than or equal to 128", + ) + ) # Note \ is a character to match not an escape. if not re.match(r"^[a-zA-Z0-9_\\.-]+$", value): - errors.append(self._format_error( - key="name", - value=value, - constraint=r"Member must satisfy regular expression pattern: [a-zA-Z0-9_\.-]+", - )) + errors.append( + self._format_error( + key="name", + value=value, + constraint=r"Member must satisfy regular expression pattern: [a-zA-Z0-9_\.-]+", + ) + ) if errors: self.errors += errors return False @@ -84,17 +94,21 @@ class FakeResourceGroup(BaseModel): def _validate_resource_query(self, value): errors = [] if value["Type"] not in {"CLOUDFORMATION_STACK_1_0", "TAG_FILTERS_1_0"}: - errors.append(self._format_error( - key="resourceQuery.type", - value=value, - constraint="Member must satisfy enum value set: [CLOUDFORMATION_STACK_1_0, TAG_FILTERS_1_0]", - )) + errors.append( + self._format_error( + key="resourceQuery.type", + value=value, + constraint="Member must satisfy enum value set: [CLOUDFORMATION_STACK_1_0, TAG_FILTERS_1_0]", + ) + ) if len(value["Query"]) > 2048: - errors.append(self._format_error( - key="resourceQuery.query", - value=value, - constraint="Member must have length less than or equal to 2048", - )) + errors.append( + self._format_error( + key="resourceQuery.query", + value=value, + constraint="Member must have length less than or equal to 2048", + ) + ) if errors: self.errors += errors return False @@ -183,7 +197,7 @@ class FakeResourceGroup(BaseModel): self._tags = value -class ResourceGroups(): +class ResourceGroups: def __init__(self): self.by_name = {} self.by_arn = {} @@ -213,7 +227,9 @@ class ResourceGroupsBackend(BaseBackend): type = resource_query["Type"] query = json.loads(resource_query["Query"]) query_keys = set(query.keys()) - invalid_json_exception = BadRequestException("Invalid query: Invalid query format: check JSON syntax") + invalid_json_exception = BadRequestException( + "Invalid query: Invalid query format: check JSON syntax" + ) if not isinstance(query["ResourceTypeFilters"], list): raise invalid_json_exception if type == "CLOUDFORMATION_STACK_1_0": @@ -255,7 +271,9 @@ class ResourceGroupsBackend(BaseBackend): "Invalid query: The TagFilter element cannot have empty or null Key field" ) if len(key) > 128: - raise BadRequestException("Invalid query: The maximum length for a tag Key is 128") + raise BadRequestException( + "Invalid query: The maximum length for a tag Key is 128" + ) values = tag_filter["Values"] if not isinstance(values, list): raise invalid_json_exception @@ -274,16 +292,13 @@ class ResourceGroupsBackend(BaseBackend): @staticmethod def _validate_tags(tags): for tag in tags: - if tag.lower().startswith('aws:'): + if tag.lower().startswith("aws:"): raise BadRequestException("Tag keys must not start with 'aws:'") def create_group(self, name, resource_query, description=None, tags=None): tags = tags or {} group = FakeResourceGroup( - name=name, - resource_query=resource_query, - description=description, - tags=tags, + name=name, resource_query=resource_query, description=description, tags=tags ) if name in self.groups: raise BadRequestException("Cannot create group: group already exists") @@ -335,4 +350,6 @@ class ResourceGroupsBackend(BaseBackend): available_regions = boto3.session.Session().get_available_regions("resource-groups") -resourcegroups_backends = {region: ResourceGroupsBackend(region_name=region) for region in available_regions} +resourcegroups_backends = { + region: ResourceGroupsBackend(region_name=region) for region in available_regions +} diff --git a/moto/resourcegroups/responses.py b/moto/resourcegroups/responses.py index 02ea14c1a..77edff19d 100644 --- a/moto/resourcegroups/responses.py +++ b/moto/resourcegroups/responses.py @@ -11,7 +11,7 @@ from .models import resourcegroups_backends class ResourceGroupsResponse(BaseResponse): - SERVICE_NAME = 'resource-groups' + SERVICE_NAME = "resource-groups" @property def resourcegroups_backend(self): @@ -23,140 +23,145 @@ class ResourceGroupsResponse(BaseResponse): resource_query = self._get_param("ResourceQuery") tags = self._get_param("Tags") group = self.resourcegroups_backend.create_group( - name=name, - description=description, - resource_query=resource_query, - tags=tags, + name=name, description=description, resource_query=resource_query, tags=tags + ) + return json.dumps( + { + "Group": { + "GroupArn": group.arn, + "Name": group.name, + "Description": group.description, + }, + "ResourceQuery": group.resource_query, + "Tags": group.tags, + } ) - return json.dumps({ - "Group": { - "GroupArn": group.arn, - "Name": group.name, - "Description": group.description - }, - "ResourceQuery": group.resource_query, - "Tags": group.tags - }) def delete_group(self): group_name = self._get_param("GroupName") group = self.resourcegroups_backend.delete_group(group_name=group_name) - return json.dumps({ - "Group": { - "GroupArn": group.arn, - "Name": group.name, - "Description": group.description - }, - }) + return json.dumps( + { + "Group": { + "GroupArn": group.arn, + "Name": group.name, + "Description": group.description, + } + } + ) def get_group(self): group_name = self._get_param("GroupName") group = self.resourcegroups_backend.get_group(group_name=group_name) - return json.dumps({ - "Group": { - "GroupArn": group.arn, - "Name": group.name, - "Description": group.description, + return json.dumps( + { + "Group": { + "GroupArn": group.arn, + "Name": group.name, + "Description": group.description, + } } - }) + ) def get_group_query(self): group_name = self._get_param("GroupName") group = self.resourcegroups_backend.get_group(group_name=group_name) - return json.dumps({ - "GroupQuery": { - "GroupName": group.name, - "ResourceQuery": group.resource_query, + return json.dumps( + { + "GroupQuery": { + "GroupName": group.name, + "ResourceQuery": group.resource_query, + } } - }) + ) def get_tags(self): arn = unquote(self._get_param("Arn")) - return json.dumps({ - "Arn": arn, - "Tags": self.resourcegroups_backend.get_tags(arn=arn) - }) + return json.dumps( + {"Arn": arn, "Tags": self.resourcegroups_backend.get_tags(arn=arn)} + ) def list_group_resources(self): - raise NotImplementedError('ResourceGroups.list_group_resources is not yet implemented') + raise NotImplementedError( + "ResourceGroups.list_group_resources is not yet implemented" + ) def list_groups(self): filters = self._get_param("Filters") if filters: raise NotImplementedError( - 'ResourceGroups.list_groups with filter parameter is not yet implemented' + "ResourceGroups.list_groups with filter parameter is not yet implemented" ) max_results = self._get_int_param("MaxResults", 50) next_token = self._get_param("NextToken") groups = self.resourcegroups_backend.list_groups( - filters=filters, - max_results=max_results, - next_token=next_token + filters=filters, max_results=max_results, next_token=next_token + ) + return json.dumps( + { + "GroupIdentifiers": [ + {"GroupName": group.name, "GroupArn": group.arn} + for group in groups.values() + ], + "Groups": [ + { + "GroupArn": group.arn, + "Name": group.name, + "Description": group.description, + } + for group in groups.values() + ], + "NextToken": next_token, + } ) - return json.dumps({ - "GroupIdentifiers": [{ - "GroupName": group.name, - "GroupArn": group.arn, - } for group in groups.values()], - "Groups": [{ - "GroupArn": group.arn, - "Name": group.name, - "Description": group.description, - } for group in groups.values()], - "NextToken": next_token, - }) def search_resources(self): - raise NotImplementedError('ResourceGroups.search_resources is not yet implemented') + raise NotImplementedError( + "ResourceGroups.search_resources is not yet implemented" + ) def tag(self): arn = unquote(self._get_param("Arn")) tags = self._get_param("Tags") if arn not in self.resourcegroups_backend.groups.by_arn: raise NotImplementedError( - 'ResourceGroups.tag with non-resource-group Arn parameter is not yet implemented' + "ResourceGroups.tag with non-resource-group Arn parameter is not yet implemented" ) self.resourcegroups_backend.tag(arn=arn, tags=tags) - return json.dumps({ - "Arn": arn, - "Tags": tags - }) + return json.dumps({"Arn": arn, "Tags": tags}) def untag(self): arn = unquote(self._get_param("Arn")) keys = self._get_param("Keys") if arn not in self.resourcegroups_backend.groups.by_arn: raise NotImplementedError( - 'ResourceGroups.untag with non-resource-group Arn parameter is not yet implemented' + "ResourceGroups.untag with non-resource-group Arn parameter is not yet implemented" ) self.resourcegroups_backend.untag(arn=arn, keys=keys) - return json.dumps({ - "Arn": arn, - "Keys": keys - }) + return json.dumps({"Arn": arn, "Keys": keys}) def update_group(self): group_name = self._get_param("GroupName") description = self._get_param("Description", "") - group = self.resourcegroups_backend.update_group(group_name=group_name, description=description) - return json.dumps({ - "Group": { - "GroupArn": group.arn, - "Name": group.name, - "Description": group.description - }, - }) + group = self.resourcegroups_backend.update_group( + group_name=group_name, description=description + ) + return json.dumps( + { + "Group": { + "GroupArn": group.arn, + "Name": group.name, + "Description": group.description, + } + } + ) def update_group_query(self): group_name = self._get_param("GroupName") resource_query = self._get_param("ResourceQuery") group = self.resourcegroups_backend.update_group_query( - group_name=group_name, - resource_query=resource_query + group_name=group_name, resource_query=resource_query + ) + return json.dumps( + {"GroupQuery": {"GroupName": group.name, "ResourceQuery": resource_query}} ) - return json.dumps({ - "GroupQuery": { - "GroupName": group.name, - "ResourceQuery": resource_query - } - }) diff --git a/moto/resourcegroups/urls.py b/moto/resourcegroups/urls.py index 518dde766..b40179145 100644 --- a/moto/resourcegroups/urls.py +++ b/moto/resourcegroups/urls.py @@ -1,14 +1,12 @@ from __future__ import unicode_literals from .responses import ResourceGroupsResponse -url_bases = [ - "https?://resource-groups(-fips)?.(.+).amazonaws.com", -] +url_bases = ["https?://resource-groups(-fips)?.(.+).amazonaws.com"] url_paths = { - '{0}/groups$': ResourceGroupsResponse.dispatch, - '{0}/groups/(?P[^/]+)$': ResourceGroupsResponse.dispatch, - '{0}/groups/(?P[^/]+)/query$': ResourceGroupsResponse.dispatch, - '{0}/groups-list$': ResourceGroupsResponse.dispatch, - '{0}/resources/(?P[^/]+)/tags$': ResourceGroupsResponse.dispatch, + "{0}/groups$": ResourceGroupsResponse.dispatch, + "{0}/groups/(?P[^/]+)$": ResourceGroupsResponse.dispatch, + "{0}/groups/(?P[^/]+)/query$": ResourceGroupsResponse.dispatch, + "{0}/groups-list$": ResourceGroupsResponse.dispatch, + "{0}/resources/(?P[^/]+)/tags$": ResourceGroupsResponse.dispatch, } diff --git a/moto/resourcegroupstaggingapi/__init__.py b/moto/resourcegroupstaggingapi/__init__.py index bd0c4a7df..2dff989b6 100644 --- a/moto/resourcegroupstaggingapi/__init__.py +++ b/moto/resourcegroupstaggingapi/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import resourcegroupstaggingapi_backends from ..core.models import base_decorator -resourcegroupstaggingapi_backend = resourcegroupstaggingapi_backends['us-east-1'] +resourcegroupstaggingapi_backend = resourcegroupstaggingapi_backends["us-east-1"] mock_resourcegroupstaggingapi = base_decorator(resourcegroupstaggingapi_backends) diff --git a/moto/resourcegroupstaggingapi/models.py b/moto/resourcegroupstaggingapi/models.py index 3f15017cc..7b0c03a88 100644 --- a/moto/resourcegroupstaggingapi/models.py +++ b/moto/resourcegroupstaggingapi/models.py @@ -42,7 +42,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): """ :rtype: moto.s3.models.S3Backend """ - return s3_backends['global'] + return s3_backends["global"] @property def ec2_backend(self): @@ -114,16 +114,18 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): # TODO move these to their respective backends filters = [lambda t, v: True] for tag_filter_dict in tag_filters: - values = tag_filter_dict.get('Values', []) + values = tag_filter_dict.get("Values", []) if len(values) == 0: # Check key matches - filters.append(lambda t, v: t == tag_filter_dict['Key']) + filters.append(lambda t, v: t == tag_filter_dict["Key"]) elif len(values) == 1: # Check its exactly the same as key, value - filters.append(lambda t, v: t == tag_filter_dict['Key'] and v == values[0]) + filters.append( + lambda t, v: t == tag_filter_dict["Key"] and v == values[0] + ) else: # Check key matches and value is one of the provided values - filters.append(lambda t, v: t == tag_filter_dict['Key'] and v in values) + filters.append(lambda t, v: t == tag_filter_dict["Key"] and v in values) def tag_filter(tag_list): result = [] @@ -131,7 +133,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): for tag in tag_list: temp_result = [] for f in filters: - f_result = f(tag['Key'], tag['Value']) + f_result = f(tag["Key"], tag["Value"]) temp_result.append(f_result) result.append(all(temp_result)) @@ -140,82 +142,150 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): return True # Do S3, resource type s3 - if not resource_type_filters or 's3' in resource_type_filters: + if not resource_type_filters or "s3" in resource_type_filters: for bucket in self.s3_backend.buckets.values(): tags = [] for tag in bucket.tags.tag_set.tags: - tags.append({'Key': tag.key, 'Value': tag.value}) + tags.append({"Key": tag.key, "Value": tag.value}) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:s3:::' + bucket.name, 'Tags': tags} + yield {"ResourceARN": "arn:aws:s3:::" + bucket.name, "Tags": tags} # EC2 tags def get_ec2_tags(res_id): result = [] for key, value in self.ec2_backend.tags.get(res_id, {}).items(): - result.append({'Key': key, 'Value': value}) + result.append({"Key": key, "Value": value}) return result # EC2 AMI, resource type ec2:image - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:image' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:image" in resource_type_filters + ): for ami in self.ec2_backend.amis.values(): tags = get_ec2_tags(ami.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::image/{1}'.format(self.region_name, ami.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::image/{1}".format( + self.region_name, ami.id + ), + "Tags": tags, + } # EC2 Instance, resource type ec2:instance - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:instance' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:instance" in resource_type_filters + ): for reservation in self.ec2_backend.reservations.values(): for instance in reservation.instances: tags = get_ec2_tags(instance.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::instance/{1}'.format(self.region_name, instance.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::instance/{1}".format( + self.region_name, instance.id + ), + "Tags": tags, + } # EC2 NetworkInterface, resource type ec2:network-interface - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:network-interface' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:network-interface" in resource_type_filters + ): for eni in self.ec2_backend.enis.values(): tags = get_ec2_tags(eni.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::network-interface/{1}'.format(self.region_name, eni.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::network-interface/{1}".format( + self.region_name, eni.id + ), + "Tags": tags, + } # TODO EC2 ReservedInstance # EC2 SecurityGroup, resource type ec2:security-group - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:security-group' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:security-group" in resource_type_filters + ): for vpc in self.ec2_backend.groups.values(): for sg in vpc.values(): tags = get_ec2_tags(sg.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::security-group/{1}'.format(self.region_name, sg.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::security-group/{1}".format( + self.region_name, sg.id + ), + "Tags": tags, + } # EC2 Snapshot, resource type ec2:snapshot - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:snapshot' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:snapshot" in resource_type_filters + ): for snapshot in self.ec2_backend.snapshots.values(): tags = get_ec2_tags(snapshot.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::snapshot/{1}'.format(self.region_name, snapshot.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::snapshot/{1}".format( + self.region_name, snapshot.id + ), + "Tags": tags, + } # TODO EC2 SpotInstanceRequest # EC2 Volume, resource type ec2:volume - if not resource_type_filters or 'ec2' in resource_type_filters or 'ec2:volume' in resource_type_filters: + if ( + not resource_type_filters + or "ec2" in resource_type_filters + or "ec2:volume" in resource_type_filters + ): for volume in self.ec2_backend.volumes.values(): tags = get_ec2_tags(volume.id) - if not tags or not tag_filter(tags): # Skip if no tags, or invalid filter + if not tags or not tag_filter( + tags + ): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': 'arn:aws:ec2:{0}::volume/{1}'.format(self.region_name, volume.id), 'Tags': tags} + yield { + "ResourceARN": "arn:aws:ec2:{0}::volume/{1}".format( + self.region_name, volume.id + ), + "Tags": tags, + } # TODO add these to the keys and values functions / combine functions # ELB @@ -223,16 +293,20 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): def get_elbv2_tags(arn): result = [] for key, value in self.elbv2_backend.load_balancers[elb.arn].tags.items(): - result.append({'Key': key, 'Value': value}) + result.append({"Key": key, "Value": value}) return result - if not resource_type_filters or 'elasticloadbalancer' in resource_type_filters or 'elasticloadbalancer:loadbalancer' in resource_type_filters: + if ( + not resource_type_filters + or "elasticloadbalancer" in resource_type_filters + or "elasticloadbalancer:loadbalancer" in resource_type_filters + ): for elb in self.elbv2_backend.load_balancers.values(): tags = get_elbv2_tags(elb.arn) if not tag_filter(tags): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': '{0}'.format(elb.arn), 'Tags': tags} + yield {"ResourceARN": "{0}".format(elb.arn), "Tags": tags} # EMR Cluster @@ -244,16 +318,16 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): def get_kms_tags(kms_key_id): result = [] for tag in self.kms_backend.list_resource_tags(kms_key_id): - result.append({'Key': tag['TagKey'], 'Value': tag['TagValue']}) + result.append({"Key": tag["TagKey"], "Value": tag["TagValue"]}) return result - if not resource_type_filters or 'kms' in resource_type_filters: + if not resource_type_filters or "kms" in resource_type_filters: for kms_key in self.kms_backend.list_keys(): tags = get_kms_tags(kms_key.id) if not tag_filter(tags): # Skip if no tags, or invalid filter continue - yield {'ResourceARN': '{0}'.format(kms_key.arn), 'Tags': tags} + yield {"ResourceARN": "{0}".format(kms_key.arn), "Tags": tags} # RDS Instance # RDS Reserved Database Instance @@ -387,25 +461,37 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): for value in get_ec2_values(volume.id): yield value - def get_resources(self, pagination_token=None, - resources_per_page=50, tags_per_page=100, - tag_filters=None, resource_type_filters=None): + def get_resources( + self, + pagination_token=None, + resources_per_page=50, + tags_per_page=100, + tag_filters=None, + resource_type_filters=None, + ): # Simple range checking if 100 >= tags_per_page >= 500: - raise RESTError('InvalidParameterException', 'TagsPerPage must be between 100 and 500') + raise RESTError( + "InvalidParameterException", "TagsPerPage must be between 100 and 500" + ) if 1 >= resources_per_page >= 50: - raise RESTError('InvalidParameterException', 'ResourcesPerPage must be between 1 and 50') + raise RESTError( + "InvalidParameterException", "ResourcesPerPage must be between 1 and 50" + ) # If we have a token, go and find the respective generator, or error if pagination_token: if pagination_token not in self._pages: - raise RESTError('PaginationTokenExpiredException', 'Token does not exist') + raise RESTError( + "PaginationTokenExpiredException", "Token does not exist" + ) - generator = self._pages[pagination_token]['gen'] - left_over = self._pages[pagination_token]['misc'] + generator = self._pages[pagination_token]["gen"] + left_over = self._pages[pagination_token]["misc"] else: - generator = self._get_resources_generator(tag_filters=tag_filters, - resource_type_filters=resource_type_filters) + generator = self._get_resources_generator( + tag_filters=tag_filters, resource_type_filters=resource_type_filters + ) left_over = None result = [] @@ -414,13 +500,13 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): if left_over: result.append(left_over) current_resources += 1 - current_tags += len(left_over['Tags']) + current_tags += len(left_over["Tags"]) try: while True: # Generator format: [{'ResourceARN': str, 'Tags': [{'Key': str, 'Value': str]}, ...] next_item = six.next(generator) - resource_tags = len(next_item['Tags']) + resource_tags = len(next_item["Tags"]) if current_resources >= resources_per_page: break @@ -438,7 +524,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): # Didn't hit StopIteration so there's stuff left in generator new_token = str(uuid.uuid4()) - self._pages[new_token] = {'gen': generator, 'misc': next_item} + self._pages[new_token] = {"gen": generator, "misc": next_item} # Token used up, might as well bin now, if you call it again your an idiot if pagination_token: @@ -450,10 +536,12 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): if pagination_token: if pagination_token not in self._pages: - raise RESTError('PaginationTokenExpiredException', 'Token does not exist') + raise RESTError( + "PaginationTokenExpiredException", "Token does not exist" + ) - generator = self._pages[pagination_token]['gen'] - left_over = self._pages[pagination_token]['misc'] + generator = self._pages[pagination_token]["gen"] + left_over = self._pages[pagination_token]["misc"] else: generator = self._get_tag_keys_generator() left_over = None @@ -482,7 +570,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): # Didn't hit StopIteration so there's stuff left in generator new_token = str(uuid.uuid4()) - self._pages[new_token] = {'gen': generator, 'misc': next_item} + self._pages[new_token] = {"gen": generator, "misc": next_item} # Token used up, might as well bin now, if you call it again your an idiot if pagination_token: @@ -494,10 +582,12 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): if pagination_token: if pagination_token not in self._pages: - raise RESTError('PaginationTokenExpiredException', 'Token does not exist') + raise RESTError( + "PaginationTokenExpiredException", "Token does not exist" + ) - generator = self._pages[pagination_token]['gen'] - left_over = self._pages[pagination_token]['misc'] + generator = self._pages[pagination_token]["gen"] + left_over = self._pages[pagination_token]["misc"] else: generator = self._get_tag_values_generator(key) left_over = None @@ -526,7 +616,7 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): # Didn't hit StopIteration so there's stuff left in generator new_token = str(uuid.uuid4()) - self._pages[new_token] = {'gen': generator, 'misc': next_item} + self._pages[new_token] = {"gen": generator, "misc": next_item} # Token used up, might as well bin now, if you call it again your an idiot if pagination_token: @@ -546,5 +636,9 @@ class ResourceGroupsTaggingAPIBackend(BaseBackend): # return failed_resources_map -available_regions = boto3.session.Session().get_available_regions("resourcegroupstaggingapi") -resourcegroupstaggingapi_backends = {region: ResourceGroupsTaggingAPIBackend(region) for region in available_regions} +available_regions = boto3.session.Session().get_available_regions( + "resourcegroupstaggingapi" +) +resourcegroupstaggingapi_backends = { + region: ResourceGroupsTaggingAPIBackend(region) for region in available_regions +} diff --git a/moto/resourcegroupstaggingapi/responses.py b/moto/resourcegroupstaggingapi/responses.py index 966778f29..02f5b5484 100644 --- a/moto/resourcegroupstaggingapi/responses.py +++ b/moto/resourcegroupstaggingapi/responses.py @@ -5,7 +5,7 @@ import json class ResourceGroupsTaggingAPIResponse(BaseResponse): - SERVICE_NAME = 'resourcegroupstaggingapi' + SERVICE_NAME = "resourcegroupstaggingapi" @property def backend(self): @@ -32,25 +32,21 @@ class ResourceGroupsTaggingAPIResponse(BaseResponse): ) # Format tag response - response = { - 'ResourceTagMappingList': resource_tag_mapping_list - } + response = {"ResourceTagMappingList": resource_tag_mapping_list} if pagination_token: - response['PaginationToken'] = pagination_token + response["PaginationToken"] = pagination_token return json.dumps(response) def get_tag_keys(self): pagination_token = self._get_param("PaginationToken") pagination_token, tag_keys = self.backend.get_tag_keys( - pagination_token=pagination_token, + pagination_token=pagination_token ) - response = { - 'TagKeys': tag_keys - } + response = {"TagKeys": tag_keys} if pagination_token: - response['PaginationToken'] = pagination_token + response["PaginationToken"] = pagination_token return json.dumps(response) @@ -58,15 +54,12 @@ class ResourceGroupsTaggingAPIResponse(BaseResponse): pagination_token = self._get_param("PaginationToken") key = self._get_param("Key") pagination_token, tag_values = self.backend.get_tag_values( - pagination_token=pagination_token, - key=key, + pagination_token=pagination_token, key=key ) - response = { - 'TagValues': tag_values - } + response = {"TagValues": tag_values} if pagination_token: - response['PaginationToken'] = pagination_token + response["PaginationToken"] = pagination_token return json.dumps(response) diff --git a/moto/resourcegroupstaggingapi/urls.py b/moto/resourcegroupstaggingapi/urls.py index a972df276..3b0182ee9 100644 --- a/moto/resourcegroupstaggingapi/urls.py +++ b/moto/resourcegroupstaggingapi/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import ResourceGroupsTaggingAPIResponse -url_bases = [ - "https?://tagging.(.+).amazonaws.com", -] +url_bases = ["https?://tagging.(.+).amazonaws.com"] -url_paths = { - '{0}/$': ResourceGroupsTaggingAPIResponse.dispatch, -} +url_paths = {"{0}/$": ResourceGroupsTaggingAPIResponse.dispatch} diff --git a/moto/route53/models.py b/moto/route53/models.py index 77a0e59e6..2ae03e54d 100644 --- a/moto/route53/models.py +++ b/moto/route53/models.py @@ -15,11 +15,10 @@ ROUTE53_ID_CHOICE = string.ascii_uppercase + string.digits def create_route53_zone_id(): # New ID's look like this Z1RWWTK7Y8UDDQ - return ''.join([random.choice(ROUTE53_ID_CHOICE) for _ in range(0, 15)]) + return "".join([random.choice(ROUTE53_ID_CHOICE) for _ in range(0, 15)]) class HealthCheck(BaseModel): - def __init__(self, health_check_id, health_check_args): self.id = health_check_id self.ip_address = health_check_args.get("ip_address") @@ -36,23 +35,26 @@ class HealthCheck(BaseModel): return self.id @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties']['HealthCheckConfig'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"]["HealthCheckConfig"] health_check_args = { - "ip_address": properties.get('IPAddress'), - "port": properties.get('Port'), - "type": properties['Type'], - "resource_path": properties.get('ResourcePath'), - "fqdn": properties.get('FullyQualifiedDomainName'), - "search_string": properties.get('SearchString'), - "request_interval": properties.get('RequestInterval'), - "failure_threshold": properties.get('FailureThreshold'), + "ip_address": properties.get("IPAddress"), + "port": properties.get("Port"), + "type": properties["Type"], + "resource_path": properties.get("ResourcePath"), + "fqdn": properties.get("FullyQualifiedDomainName"), + "search_string": properties.get("SearchString"), + "request_interval": properties.get("RequestInterval"), + "failure_threshold": properties.get("FailureThreshold"), } health_check = route53_backend.create_health_check(health_check_args) return health_check def to_xml(self): - template = Template(""" + template = Template( + """ {{ health_check.id }} example.com 192.0.2.17 @@ -68,59 +70,66 @@ class HealthCheck(BaseModel): {% endif %} 1 - """) + """ + ) return template.render(health_check=self) class RecordSet(BaseModel): - def __init__(self, kwargs): - self.name = kwargs.get('Name') - self.type_ = kwargs.get('Type') - self.ttl = kwargs.get('TTL') - self.records = kwargs.get('ResourceRecords', []) - self.set_identifier = kwargs.get('SetIdentifier') - self.weight = kwargs.get('Weight') - self.region = kwargs.get('Region') - self.health_check = kwargs.get('HealthCheckId') - self.hosted_zone_name = kwargs.get('HostedZoneName') - self.hosted_zone_id = kwargs.get('HostedZoneId') - self.alias_target = kwargs.get('AliasTarget') + self.name = kwargs.get("Name") + self.type_ = kwargs.get("Type") + self.ttl = kwargs.get("TTL") + self.records = kwargs.get("ResourceRecords", []) + self.set_identifier = kwargs.get("SetIdentifier") + self.weight = kwargs.get("Weight") + self.region = kwargs.get("Region") + self.health_check = kwargs.get("HealthCheckId") + self.hosted_zone_name = kwargs.get("HostedZoneName") + self.hosted_zone_id = kwargs.get("HostedZoneId") + self.alias_target = kwargs.get("AliasTarget") @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] zone_name = properties.get("HostedZoneName") if zone_name: hosted_zone = route53_backend.get_hosted_zone_by_name(zone_name) else: - hosted_zone = route53_backend.get_hosted_zone( - properties["HostedZoneId"]) + hosted_zone = route53_backend.get_hosted_zone(properties["HostedZoneId"]) record_set = hosted_zone.add_rrset(properties) return record_set @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): cls.delete_from_cloudformation_json( - original_resource.name, cloudformation_json, region_name) - return cls.create_from_cloudformation_json(new_resource_name, cloudformation_json, region_name) + original_resource.name, cloudformation_json, region_name + ) + return cls.create_from_cloudformation_json( + new_resource_name, cloudformation_json, region_name + ) @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): # this will break if you changed the zone the record is in, # unfortunately - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] zone_name = properties.get("HostedZoneName") if zone_name: hosted_zone = route53_backend.get_hosted_zone_by_name(zone_name) else: - hosted_zone = route53_backend.get_hosted_zone( - properties["HostedZoneId"]) + hosted_zone = route53_backend.get_hosted_zone(properties["HostedZoneId"]) try: - hosted_zone.delete_rrset({'Name': resource_name}) + hosted_zone.delete_rrset({"Name": resource_name}) except KeyError: pass @@ -129,7 +138,8 @@ class RecordSet(BaseModel): return self.name def to_xml(self): - template = Template(""" + template = Template( + """ {{ record_set.name }} {{ record_set.type_ }} {% if record_set.set_identifier %} @@ -162,26 +172,25 @@ class RecordSet(BaseModel): {% if record_set.health_check %} {{ record_set.health_check }} {% endif %} - """) + """ + ) return template.render(record_set=self) def delete(self, *args, **kwargs): - ''' Not exposed as part of the Route 53 API - used for CloudFormation. args are ignored ''' - hosted_zone = route53_backend.get_hosted_zone_by_name( - self.hosted_zone_name) + """ Not exposed as part of the Route 53 API - used for CloudFormation. args are ignored """ + hosted_zone = route53_backend.get_hosted_zone_by_name(self.hosted_zone_name) if not hosted_zone: hosted_zone = route53_backend.get_hosted_zone(self.hosted_zone_id) - hosted_zone.delete_rrset({'Name': self.name, 'Type': self.type_}) + hosted_zone.delete_rrset({"Name": self.name, "Type": self.type_}) def reverse_domain_name(domain_name): - if domain_name.endswith('.'): # normalize without trailing dot + if domain_name.endswith("."): # normalize without trailing dot domain_name = domain_name[:-1] - return '.'.join(reversed(domain_name.split('.'))) + return ".".join(reversed(domain_name.split("."))) class FakeZone(BaseModel): - def __init__(self, name, id_, private_zone, comment=None): self.name = name self.id = id_ @@ -198,7 +207,11 @@ class FakeZone(BaseModel): def upsert_rrset(self, record_set): new_rrset = RecordSet(record_set) for i, rrset in enumerate(self.rrsets): - if rrset.name == new_rrset.name and rrset.type_ == new_rrset.type_ and rrset.set_identifier == new_rrset.set_identifier: + if ( + rrset.name == new_rrset.name + and rrset.type_ == new_rrset.type_ + and rrset.set_identifier == new_rrset.set_identifier + ): self.rrsets[i] = new_rrset break else: @@ -209,13 +222,16 @@ class FakeZone(BaseModel): self.rrsets = [ record_set for record_set in self.rrsets - if record_set.name != rrset['Name'] or - (rrset.get('Type') is not None and record_set.type_ != rrset['Type']) + if record_set.name != rrset["Name"] + or (rrset.get("Type") is not None and record_set.type_ != rrset["Type"]) ] def delete_rrset_by_id(self, set_identifier): self.rrsets = [ - record_set for record_set in self.rrsets if record_set.set_identifier != set_identifier] + record_set + for record_set in self.rrsets + if record_set.set_identifier != set_identifier + ] def get_record_sets(self, start_type, start_name): record_sets = list(self.rrsets) # Copy the list @@ -223,11 +239,15 @@ class FakeZone(BaseModel): record_sets = [ record_set for record_set in record_sets - if reverse_domain_name(record_set.name) >= reverse_domain_name(start_name) + if reverse_domain_name(record_set.name) + >= reverse_domain_name(start_name) ] if start_type: record_sets = [ - record_set for record_set in record_sets if record_set.type_ >= start_type] + record_set + for record_set in record_sets + if record_set.type_ >= start_type + ] return record_sets @@ -236,17 +256,17 @@ class FakeZone(BaseModel): return self.id @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] name = properties["Name"] - hosted_zone = route53_backend.create_hosted_zone( - name, private_zone=False) + hosted_zone = route53_backend.create_hosted_zone(name, private_zone=False) return hosted_zone class RecordSetGroup(BaseModel): - def __init__(self, hosted_zone_id, record_sets): self.hosted_zone_id = hosted_zone_id self.record_sets = record_sets @@ -256,8 +276,10 @@ class RecordSetGroup(BaseModel): return "arn:aws:route53:::hostedzone/{0}".format(self.hosted_zone_id) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] zone_name = properties.get("HostedZoneName") if zone_name: @@ -273,7 +295,6 @@ class RecordSetGroup(BaseModel): class Route53Backend(BaseBackend): - def __init__(self): self.zones = {} self.health_checks = {} @@ -281,26 +302,25 @@ class Route53Backend(BaseBackend): def create_hosted_zone(self, name, private_zone, comment=None): new_id = create_route53_zone_id() - new_zone = FakeZone( - name, new_id, private_zone=private_zone, comment=comment) + new_zone = FakeZone(name, new_id, private_zone=private_zone, comment=comment) self.zones[new_id] = new_zone return new_zone def change_tags_for_resource(self, resource_id, tags): - if 'Tag' in tags: - if isinstance(tags['Tag'], list): - for tag in tags['Tag']: - self.resource_tags[resource_id][tag['Key']] = tag['Value'] + if "Tag" in tags: + if isinstance(tags["Tag"], list): + for tag in tags["Tag"]: + self.resource_tags[resource_id][tag["Key"]] = tag["Value"] else: - key, value = (tags['Tag']['Key'], tags['Tag']['Value']) + key, value = (tags["Tag"]["Key"], tags["Tag"]["Value"]) self.resource_tags[resource_id][key] = value else: - if 'Key' in tags: - if isinstance(tags['Key'], list): - for key in tags['Key']: - del(self.resource_tags[resource_id][key]) + if "Key" in tags: + if isinstance(tags["Key"], list): + for key in tags["Key"]: + del self.resource_tags[resource_id][key] else: - del(self.resource_tags[resource_id][tags['Key']]) + del self.resource_tags[resource_id][tags["Key"]] def list_tags_for_resource(self, resource_id): if resource_id in self.resource_tags: diff --git a/moto/route53/responses.py b/moto/route53/responses.py index f933c575a..3e688b65d 100644 --- a/moto/route53/responses.py +++ b/moto/route53/responses.py @@ -8,23 +8,24 @@ import xmltodict class Route53(BaseResponse): - def list_or_create_hostzone_response(self, request, full_url, headers): self.setup_class(request, full_url, headers) if request.method == "POST": elements = xmltodict.parse(self.body) if "HostedZoneConfig" in elements["CreateHostedZoneRequest"]: - comment = elements["CreateHostedZoneRequest"][ - "HostedZoneConfig"]["Comment"] + comment = elements["CreateHostedZoneRequest"]["HostedZoneConfig"][ + "Comment" + ] try: # in boto3, this field is set directly in the xml private_zone = elements["CreateHostedZoneRequest"][ - "HostedZoneConfig"]["PrivateZone"] + "HostedZoneConfig" + ]["PrivateZone"] except KeyError: # if a VPC subsection is only included in xmls params when private_zone=True, # see boto: boto/route53/connection.py - private_zone = 'VPC' in elements["CreateHostedZoneRequest"] + private_zone = "VPC" in elements["CreateHostedZoneRequest"] else: comment = None private_zone = False @@ -35,9 +36,7 @@ class Route53(BaseResponse): name += "." new_zone = route53_backend.create_hosted_zone( - name, - comment=comment, - private_zone=private_zone, + name, comment=comment, private_zone=private_zone ) template = Template(CREATE_HOSTED_ZONE_RESPONSE) return 201, headers, template.render(zone=new_zone) @@ -54,9 +53,15 @@ class Route53(BaseResponse): dnsname = query_params.get("dnsname") if dnsname: - dnsname = dnsname[0] # parse_qs gives us a list, but this parameter doesn't repeat + dnsname = dnsname[ + 0 + ] # parse_qs gives us a list, but this parameter doesn't repeat # return all zones with that name (there can be more than one) - zones = [zone for zone in route53_backend.get_all_hosted_zones() if zone.name == dnsname] + zones = [ + zone + for zone in route53_backend.get_all_hosted_zones() + if zone.name == dnsname + ] else: # sort by names, but with domain components reversed # see http://boto3.readthedocs.io/en/latest/reference/services/route53.html#Route53.Client.list_hosted_zones_by_name @@ -76,7 +81,7 @@ class Route53(BaseResponse): def get_or_delete_hostzone_response(self, request, full_url, headers): self.setup_class(request, full_url, headers) parsed_url = urlparse(full_url) - zoneid = parsed_url.path.rstrip('/').rsplit('/', 1)[1] + zoneid = parsed_url.path.rstrip("/").rsplit("/", 1)[1] the_zone = route53_backend.get_hosted_zone(zoneid) if not the_zone: return 404, headers, "Zone %s not Found" % zoneid @@ -95,7 +100,7 @@ class Route53(BaseResponse): parsed_url = urlparse(full_url) method = request.method - zoneid = parsed_url.path.rstrip('/').rsplit('/', 2)[1] + zoneid = parsed_url.path.rstrip("/").rsplit("/", 2)[1] the_zone = route53_backend.get_hosted_zone(zoneid) if not the_zone: return 404, headers, "Zone %s Not Found" % zoneid @@ -103,46 +108,55 @@ class Route53(BaseResponse): if method == "POST": elements = xmltodict.parse(self.body) - change_list = elements['ChangeResourceRecordSetsRequest'][ - 'ChangeBatch']['Changes']['Change'] + change_list = elements["ChangeResourceRecordSetsRequest"]["ChangeBatch"][ + "Changes" + ]["Change"] if not isinstance(change_list, list): - change_list = [elements['ChangeResourceRecordSetsRequest'][ - 'ChangeBatch']['Changes']['Change']] + change_list = [ + elements["ChangeResourceRecordSetsRequest"]["ChangeBatch"][ + "Changes" + ]["Change"] + ] for value in change_list: - action = value['Action'] - record_set = value['ResourceRecordSet'] + action = value["Action"] + record_set = value["ResourceRecordSet"] - cleaned_record_name = record_set['Name'].strip('.') - cleaned_hosted_zone_name = the_zone.name.strip('.') + cleaned_record_name = record_set["Name"].strip(".") + cleaned_hosted_zone_name = the_zone.name.strip(".") if not cleaned_record_name.endswith(cleaned_hosted_zone_name): error_msg = """ An error occurred (InvalidChangeBatch) when calling the ChangeResourceRecordSets operation: RRSet with DNS name %s is not permitted in zone %s - """ % (record_set['Name'], the_zone.name) + """ % ( + record_set["Name"], + the_zone.name, + ) return 400, headers, error_msg - if not record_set['Name'].endswith('.'): - record_set['Name'] += '.' + if not record_set["Name"].endswith("."): + record_set["Name"] += "." - if action in ('CREATE', 'UPSERT'): - if 'ResourceRecords' in record_set: - resource_records = list( - record_set['ResourceRecords'].values())[0] + if action in ("CREATE", "UPSERT"): + if "ResourceRecords" in record_set: + resource_records = list(record_set["ResourceRecords"].values())[ + 0 + ] if not isinstance(resource_records, list): # Depending on how many records there are, this may # or may not be a list resource_records = [resource_records] - record_set['ResourceRecords'] = [x['Value'] for x in resource_records] - if action == 'CREATE': + record_set["ResourceRecords"] = [ + x["Value"] for x in resource_records + ] + if action == "CREATE": the_zone.add_rrset(record_set) else: the_zone.upsert_rrset(record_set) elif action == "DELETE": - if 'SetIdentifier' in record_set: - the_zone.delete_rrset_by_id( - record_set["SetIdentifier"]) + if "SetIdentifier" in record_set: + the_zone.delete_rrset_by_id(record_set["SetIdentifier"]) else: the_zone.delete_rrset(record_set) @@ -163,20 +177,20 @@ class Route53(BaseResponse): method = request.method if method == "POST": - properties = xmltodict.parse(self.body)['CreateHealthCheckRequest'][ - 'HealthCheckConfig'] + properties = xmltodict.parse(self.body)["CreateHealthCheckRequest"][ + "HealthCheckConfig" + ] health_check_args = { - "ip_address": properties.get('IPAddress'), - "port": properties.get('Port'), - "type": properties['Type'], - "resource_path": properties.get('ResourcePath'), - "fqdn": properties.get('FullyQualifiedDomainName'), - "search_string": properties.get('SearchString'), - "request_interval": properties.get('RequestInterval'), - "failure_threshold": properties.get('FailureThreshold'), + "ip_address": properties.get("IPAddress"), + "port": properties.get("Port"), + "type": properties["Type"], + "resource_path": properties.get("ResourcePath"), + "fqdn": properties.get("FullyQualifiedDomainName"), + "search_string": properties.get("SearchString"), + "request_interval": properties.get("RequestInterval"), + "failure_threshold": properties.get("FailureThreshold"), } - health_check = route53_backend.create_health_check( - health_check_args) + health_check = route53_backend.create_health_check(health_check_args) template = Template(CREATE_HEALTH_CHECK_RESPONSE) return 201, headers, template.render(health_check=health_check) elif method == "DELETE": @@ -191,13 +205,14 @@ class Route53(BaseResponse): def not_implemented_response(self, request, full_url, headers): self.setup_class(request, full_url, headers) - action = '' - if 'tags' in full_url: - action = 'tags' - elif 'trafficpolicyinstances' in full_url: - action = 'policies' + action = "" + if "tags" in full_url: + action = "tags" + elif "trafficpolicyinstances" in full_url: + action = "policies" raise NotImplementedError( - "The action for {0} has not been implemented for route 53".format(action)) + "The action for {0} has not been implemented for route 53".format(action) + ) def list_or_change_tags_for_resource_request(self, request, full_url, headers): self.setup_class(request, full_url, headers) @@ -209,17 +224,19 @@ class Route53(BaseResponse): if request.method == "GET": tags = route53_backend.list_tags_for_resource(id_) template = Template(LIST_TAGS_FOR_RESOURCE_RESPONSE) - return 200, headers, template.render( - resource_type=type_, resource_id=id_, tags=tags) + return ( + 200, + headers, + template.render(resource_type=type_, resource_id=id_, tags=tags), + ) if request.method == "POST": - tags = xmltodict.parse( - self.body)['ChangeTagsForResourceRequest'] + tags = xmltodict.parse(self.body)["ChangeTagsForResourceRequest"] - if 'AddTags' in tags: - tags = tags['AddTags'] - elif 'RemoveTagKeys' in tags: - tags = tags['RemoveTagKeys'] + if "AddTags" in tags: + tags = tags["AddTags"] + elif "RemoveTagKeys" in tags: + tags = tags["RemoveTagKeys"] route53_backend.change_tags_for_resource(id_, tags) template = Template(CHANGE_TAGS_FOR_RESOURCE_RESPONSE) diff --git a/moto/route53/urls.py b/moto/route53/urls.py index 53abf23a2..a697d258a 100644 --- a/moto/route53/urls.py +++ b/moto/route53/urls.py @@ -1,9 +1,7 @@ from __future__ import unicode_literals from .responses import Route53 -url_bases = [ - "https?://route53(.*).amazonaws.com", -] +url_bases = ["https?://route53(.*).amazonaws.com"] def tag_response1(*args, **kwargs): @@ -15,12 +13,12 @@ def tag_response2(*args, **kwargs): url_paths = { - '{0}/(?P[\d_-]+)/hostedzone$': Route53().list_or_create_hostzone_response, - '{0}/(?P[\d_-]+)/hostedzone/(?P[^/]+)$': Route53().get_or_delete_hostzone_response, - '{0}/(?P[\d_-]+)/hostedzone/(?P[^/]+)/rrset/?$': Route53().rrset_response, - '{0}/(?P[\d_-]+)/hostedzonesbyname': Route53().list_hosted_zones_by_name_response, - '{0}/(?P[\d_-]+)/healthcheck': Route53().health_check_response, - '{0}/(?P[\d_-]+)/tags/healthcheck/(?P[^/]+)$': tag_response1, - '{0}/(?P[\d_-]+)/tags/hostedzone/(?P[^/]+)$': tag_response2, - '{0}/(?P[\d_-]+)/trafficpolicyinstances/*': Route53().not_implemented_response + "{0}/(?P[\d_-]+)/hostedzone$": Route53().list_or_create_hostzone_response, + "{0}/(?P[\d_-]+)/hostedzone/(?P[^/]+)$": Route53().get_or_delete_hostzone_response, + "{0}/(?P[\d_-]+)/hostedzone/(?P[^/]+)/rrset/?$": Route53().rrset_response, + "{0}/(?P[\d_-]+)/hostedzonesbyname": Route53().list_hosted_zones_by_name_response, + "{0}/(?P[\d_-]+)/healthcheck": Route53().health_check_response, + "{0}/(?P[\d_-]+)/tags/healthcheck/(?P[^/]+)$": tag_response1, + "{0}/(?P[\d_-]+)/tags/hostedzone/(?P[^/]+)$": tag_response2, + "{0}/(?P[\d_-]+)/trafficpolicyinstances/*": Route53().not_implemented_response, } diff --git a/moto/s3/config.py b/moto/s3/config.py index 11a16a071..8098addfc 100644 --- a/moto/s3/config.py +++ b/moto/s3/config.py @@ -6,8 +6,15 @@ from moto.s3 import s3_backends class S3ConfigQuery(ConfigQueryModel): - - def list_config_service_resources(self, resource_ids, resource_name, limit, next_token, backend_region=None, resource_region=None): + def list_config_service_resources( + self, + resource_ids, + resource_name, + limit, + next_token, + backend_region=None, + resource_region=None, + ): # The resource_region only matters for aggregated queries as you can filter on bucket regions for them. # For other resource types, you would need to iterate appropriately for the backend_region. @@ -20,14 +27,14 @@ class S3ConfigQuery(ConfigQueryModel): # If no filter was passed in for resource names/ids then return them all: if not resource_ids and not resource_name: - bucket_list = list(self.backends['global'].buckets.keys()) + bucket_list = list(self.backends["global"].buckets.keys()) else: # Match the resource name / ID: bucket_list = [] filter_buckets = [resource_name] if resource_name else resource_ids - for bucket in self.backends['global'].buckets.keys(): + for bucket in self.backends["global"].buckets.keys(): if bucket in filter_buckets: bucket_list.append(bucket) @@ -37,7 +44,7 @@ class S3ConfigQuery(ConfigQueryModel): region_buckets = [] for bucket in bucket_list: - if self.backends['global'].buckets[bucket].region_name == region_filter: + if self.backends["global"].buckets[bucket].region_name == region_filter: region_buckets.append(bucket) bucket_list = region_buckets @@ -61,17 +68,29 @@ class S3ConfigQuery(ConfigQueryModel): start = sorted_buckets.index(next_token) # Get the list of items to collect: - bucket_list = sorted_buckets[start:(start + limit)] + bucket_list = sorted_buckets[start : (start + limit)] if len(sorted_buckets) > (start + limit): new_token = sorted_buckets[start + limit] - return [{'type': 'AWS::S3::Bucket', 'id': bucket, 'name': bucket, 'region': self.backends['global'].buckets[bucket].region_name} - for bucket in bucket_list], new_token + return ( + [ + { + "type": "AWS::S3::Bucket", + "id": bucket, + "name": bucket, + "region": self.backends["global"].buckets[bucket].region_name, + } + for bucket in bucket_list + ], + new_token, + ) - def get_config_resource(self, resource_id, resource_name=None, backend_region=None, resource_region=None): + def get_config_resource( + self, resource_id, resource_name=None, backend_region=None, resource_region=None + ): # Get the bucket: - bucket = self.backends['global'].buckets.get(resource_id, {}) + bucket = self.backends["global"].buckets.get(resource_id, {}) if not bucket: return @@ -89,12 +108,12 @@ class S3ConfigQuery(ConfigQueryModel): config_data = bucket.to_config_dict() # The 'configuration' field is also a JSON string: - config_data['configuration'] = json.dumps(config_data['configuration']) + config_data["configuration"] = json.dumps(config_data["configuration"]) # Supplementary config need all values converted to JSON strings if they are not strings already: - for field, value in config_data['supplementaryConfiguration'].items(): + for field, value in config_data["supplementaryConfiguration"].items(): if not isinstance(value, str): - config_data['supplementaryConfiguration'][field] = json.dumps(value) + config_data["supplementaryConfiguration"][field] = json.dumps(value) return config_data diff --git a/moto/s3/exceptions.py b/moto/s3/exceptions.py index 8d2326fa1..c8236398f 100644 --- a/moto/s3/exceptions.py +++ b/moto/s3/exceptions.py @@ -12,18 +12,16 @@ ERROR_WITH_KEY_NAME = """{% extends 'single_error' %} class S3ClientError(RESTError): - def __init__(self, *args, **kwargs): - kwargs.setdefault('template', 'single_error') - self.templates['bucket_error'] = ERROR_WITH_BUCKET_NAME + kwargs.setdefault("template", "single_error") + self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME super(S3ClientError, self).__init__(*args, **kwargs) class BucketError(S3ClientError): - def __init__(self, *args, **kwargs): - kwargs.setdefault('template', 'bucket_error') - self.templates['bucket_error'] = ERROR_WITH_BUCKET_NAME + kwargs.setdefault("template", "bucket_error") + self.templates["bucket_error"] = ERROR_WITH_BUCKET_NAME super(BucketError, self).__init__(*args, **kwargs) @@ -33,10 +31,14 @@ class BucketAlreadyExists(BucketError): def __init__(self, *args, **kwargs): super(BucketAlreadyExists, self).__init__( "BucketAlreadyExists", - ("The requested bucket name is not available. The bucket " - "namespace is shared by all users of the system. Please " - "select a different name and try again"), - *args, **kwargs) + ( + "The requested bucket name is not available. The bucket " + "namespace is shared by all users of the system. Please " + "select a different name and try again" + ), + *args, + **kwargs + ) class MissingBucket(BucketError): @@ -44,9 +46,8 @@ class MissingBucket(BucketError): def __init__(self, *args, **kwargs): super(MissingBucket, self).__init__( - "NoSuchBucket", - "The specified bucket does not exist", - *args, **kwargs) + "NoSuchBucket", "The specified bucket does not exist", *args, **kwargs + ) class MissingKey(S3ClientError): @@ -54,9 +55,7 @@ class MissingKey(S3ClientError): def __init__(self, key_name): super(MissingKey, self).__init__( - "NoSuchKey", - "The specified key does not exist.", - Key=key_name, + "NoSuchKey", "The specified key does not exist.", Key=key_name ) @@ -77,9 +76,13 @@ class InvalidPartOrder(S3ClientError): def __init__(self, *args, **kwargs): super(InvalidPartOrder, self).__init__( "InvalidPartOrder", - ("The list of parts was not in ascending order. The parts " - "list must be specified in order by part number."), - *args, **kwargs) + ( + "The list of parts was not in ascending order. The parts " + "list must be specified in order by part number." + ), + *args, + **kwargs + ) class InvalidPart(S3ClientError): @@ -88,10 +91,14 @@ class InvalidPart(S3ClientError): def __init__(self, *args, **kwargs): super(InvalidPart, self).__init__( "InvalidPart", - ("One or more of the specified parts could not be found. " - "The part might not have been uploaded, or the specified " - "entity tag might not have matched the part's entity tag."), - *args, **kwargs) + ( + "One or more of the specified parts could not be found. " + "The part might not have been uploaded, or the specified " + "entity tag might not have matched the part's entity tag." + ), + *args, + **kwargs + ) class EntityTooSmall(S3ClientError): @@ -101,7 +108,9 @@ class EntityTooSmall(S3ClientError): super(EntityTooSmall, self).__init__( "EntityTooSmall", "Your proposed upload is smaller than the minimum allowed object size.", - *args, **kwargs) + *args, + **kwargs + ) class InvalidRequest(S3ClientError): @@ -110,8 +119,12 @@ class InvalidRequest(S3ClientError): def __init__(self, method, *args, **kwargs): super(InvalidRequest, self).__init__( "InvalidRequest", - "Found unsupported HTTP method in CORS config. Unsupported method is {}".format(method), - *args, **kwargs) + "Found unsupported HTTP method in CORS config. Unsupported method is {}".format( + method + ), + *args, + **kwargs + ) class MalformedXML(S3ClientError): @@ -121,7 +134,9 @@ class MalformedXML(S3ClientError): super(MalformedXML, self).__init__( "MalformedXML", "The XML you provided was not well-formed or did not validate against our published schema", - *args, **kwargs) + *args, + **kwargs + ) class MalformedACLError(S3ClientError): @@ -131,14 +146,18 @@ class MalformedACLError(S3ClientError): super(MalformedACLError, self).__init__( "MalformedACLError", "The XML you provided was not well-formed or did not validate against our published schema", - *args, **kwargs) + *args, + **kwargs + ) class InvalidTargetBucketForLogging(S3ClientError): code = 400 def __init__(self, msg): - super(InvalidTargetBucketForLogging, self).__init__("InvalidTargetBucketForLogging", msg) + super(InvalidTargetBucketForLogging, self).__init__( + "InvalidTargetBucketForLogging", msg + ) class CrossLocationLoggingProhibitted(S3ClientError): @@ -146,8 +165,7 @@ class CrossLocationLoggingProhibitted(S3ClientError): def __init__(self): super(CrossLocationLoggingProhibitted, self).__init__( - "CrossLocationLoggingProhibitted", - "Cross S3 location logging not allowed." + "CrossLocationLoggingProhibitted", "Cross S3 location logging not allowed." ) @@ -156,9 +174,8 @@ class InvalidNotificationARN(S3ClientError): def __init__(self, *args, **kwargs): super(InvalidNotificationARN, self).__init__( - "InvalidArgument", - "The ARN is not well formed", - *args, **kwargs) + "InvalidArgument", "The ARN is not well formed", *args, **kwargs + ) class InvalidNotificationDestination(S3ClientError): @@ -168,7 +185,9 @@ class InvalidNotificationDestination(S3ClientError): super(InvalidNotificationDestination, self).__init__( "InvalidArgument", "The notification destination service region is not valid for the bucket location constraint", - *args, **kwargs) + *args, + **kwargs + ) class InvalidNotificationEvent(S3ClientError): @@ -178,7 +197,9 @@ class InvalidNotificationEvent(S3ClientError): super(InvalidNotificationEvent, self).__init__( "InvalidArgument", "The event is not supported for notifications", - *args, **kwargs) + *args, + **kwargs + ) class InvalidStorageClass(S3ClientError): @@ -188,7 +209,9 @@ class InvalidStorageClass(S3ClientError): super(InvalidStorageClass, self).__init__( "InvalidStorageClass", "The storage class you specified is not valid", - *args, **kwargs) + *args, + **kwargs + ) class InvalidBucketName(S3ClientError): @@ -196,9 +219,7 @@ class InvalidBucketName(S3ClientError): def __init__(self, *args, **kwargs): super(InvalidBucketName, self).__init__( - "InvalidBucketName", - "The specified bucket is not valid.", - *args, **kwargs + "InvalidBucketName", "The specified bucket is not valid.", *args, **kwargs ) @@ -209,35 +230,51 @@ class DuplicateTagKeys(S3ClientError): super(DuplicateTagKeys, self).__init__( "InvalidTag", "Cannot provide multiple Tags with the same key", - *args, **kwargs) + *args, + **kwargs + ) class S3AccessDeniedError(S3ClientError): code = 403 def __init__(self, *args, **kwargs): - super(S3AccessDeniedError, self).__init__('AccessDenied', 'Access Denied', *args, **kwargs) + super(S3AccessDeniedError, self).__init__( + "AccessDenied", "Access Denied", *args, **kwargs + ) class BucketAccessDeniedError(BucketError): code = 403 def __init__(self, *args, **kwargs): - super(BucketAccessDeniedError, self).__init__('AccessDenied', 'Access Denied', *args, **kwargs) + super(BucketAccessDeniedError, self).__init__( + "AccessDenied", "Access Denied", *args, **kwargs + ) class S3InvalidTokenError(S3ClientError): code = 400 def __init__(self, *args, **kwargs): - super(S3InvalidTokenError, self).__init__('InvalidToken', 'The provided token is malformed or otherwise invalid.', *args, **kwargs) + super(S3InvalidTokenError, self).__init__( + "InvalidToken", + "The provided token is malformed or otherwise invalid.", + *args, + **kwargs + ) class BucketInvalidTokenError(BucketError): code = 400 def __init__(self, *args, **kwargs): - super(BucketInvalidTokenError, self).__init__('InvalidToken', 'The provided token is malformed or otherwise invalid.', *args, **kwargs) + super(BucketInvalidTokenError, self).__init__( + "InvalidToken", + "The provided token is malformed or otherwise invalid.", + *args, + **kwargs + ) class S3InvalidAccessKeyIdError(S3ClientError): @@ -245,8 +282,11 @@ class S3InvalidAccessKeyIdError(S3ClientError): def __init__(self, *args, **kwargs): super(S3InvalidAccessKeyIdError, self).__init__( - 'InvalidAccessKeyId', - "The AWS Access Key Id you provided does not exist in our records.", *args, **kwargs) + "InvalidAccessKeyId", + "The AWS Access Key Id you provided does not exist in our records.", + *args, + **kwargs + ) class BucketInvalidAccessKeyIdError(S3ClientError): @@ -254,8 +294,11 @@ class BucketInvalidAccessKeyIdError(S3ClientError): def __init__(self, *args, **kwargs): super(BucketInvalidAccessKeyIdError, self).__init__( - 'InvalidAccessKeyId', - "The AWS Access Key Id you provided does not exist in our records.", *args, **kwargs) + "InvalidAccessKeyId", + "The AWS Access Key Id you provided does not exist in our records.", + *args, + **kwargs + ) class S3SignatureDoesNotMatchError(S3ClientError): @@ -263,8 +306,11 @@ class S3SignatureDoesNotMatchError(S3ClientError): def __init__(self, *args, **kwargs): super(S3SignatureDoesNotMatchError, self).__init__( - 'SignatureDoesNotMatch', - "The request signature we calculated does not match the signature you provided. Check your key and signing method.", *args, **kwargs) + "SignatureDoesNotMatch", + "The request signature we calculated does not match the signature you provided. Check your key and signing method.", + *args, + **kwargs + ) class BucketSignatureDoesNotMatchError(S3ClientError): @@ -272,5 +318,8 @@ class BucketSignatureDoesNotMatchError(S3ClientError): def __init__(self, *args, **kwargs): super(BucketSignatureDoesNotMatchError, self).__init__( - 'SignatureDoesNotMatch', - "The request signature we calculated does not match the signature you provided. Check your key and signing method.", *args, **kwargs) + "SignatureDoesNotMatch", + "The request signature we calculated does not match the signature you provided. Check your key and signing method.", + *args, + **kwargs + ) diff --git a/moto/s3/models.py b/moto/s3/models.py index 8c4a058ee..9c8f64242 100644 --- a/moto/s3/models.py +++ b/moto/s3/models.py @@ -22,9 +22,19 @@ from bisect import insort from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_with_milliseconds, rfc_1123_datetime from .exceptions import ( - BucketAlreadyExists, MissingBucket, InvalidBucketName, InvalidPart, InvalidRequest, - EntityTooSmall, MissingKey, InvalidNotificationDestination, MalformedXML, InvalidStorageClass, - InvalidTargetBucketForLogging, DuplicateTagKeys, CrossLocationLoggingProhibitted + BucketAlreadyExists, + MissingBucket, + InvalidBucketName, + InvalidPart, + InvalidRequest, + EntityTooSmall, + MissingKey, + InvalidNotificationDestination, + MalformedXML, + InvalidStorageClass, + InvalidTargetBucketForLogging, + DuplicateTagKeys, + CrossLocationLoggingProhibitted, ) from .utils import clean_key_name, _VersionedKeyStore @@ -32,15 +42,21 @@ MAX_BUCKET_NAME_LENGTH = 63 MIN_BUCKET_NAME_LENGTH = 3 UPLOAD_ID_BYTES = 43 UPLOAD_PART_MIN_SIZE = 5242880 -STORAGE_CLASS = ["STANDARD", "REDUCED_REDUNDANCY", "STANDARD_IA", "ONEZONE_IA", - "INTELLIGENT_TIERING", "GLACIER", "DEEP_ARCHIVE"] +STORAGE_CLASS = [ + "STANDARD", + "REDUCED_REDUNDANCY", + "STANDARD_IA", + "ONEZONE_IA", + "INTELLIGENT_TIERING", + "GLACIER", + "DEEP_ARCHIVE", +] DEFAULT_KEY_BUFFER_SIZE = 16 * 1024 * 1024 DEFAULT_TEXT_ENCODING = sys.getdefaultencoding() -OWNER = '75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a' +OWNER = "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a" class FakeDeleteMarker(BaseModel): - def __init__(self, key): self.key = key self.name = key.name @@ -57,7 +73,6 @@ class FakeDeleteMarker(BaseModel): class FakeKey(BaseModel): - def __init__( self, name, @@ -67,11 +82,11 @@ class FakeKey(BaseModel): is_versioned=False, version_id=0, max_buffer_size=DEFAULT_KEY_BUFFER_SIZE, - multipart=None + multipart=None, ): self.name = name self.last_modified = datetime.datetime.utcnow() - self.acl = get_canned_acl('private') + self.acl = get_canned_acl("private") self.website_redirect_location = None self._storage_class = storage if storage else "STANDARD" self._metadata = {} @@ -184,21 +199,21 @@ class FakeKey(BaseModel): @property def response_dict(self): res = { - 'ETag': self.etag, - 'last-modified': self.last_modified_RFC1123, - 'content-length': str(self.size), + "ETag": self.etag, + "last-modified": self.last_modified_RFC1123, + "content-length": str(self.size), } - if self._storage_class != 'STANDARD': - res['x-amz-storage-class'] = self._storage_class + if self._storage_class != "STANDARD": + res["x-amz-storage-class"] = self._storage_class if self._expiry is not None: rhdr = 'ongoing-request="false", expiry-date="{0}"' - res['x-amz-restore'] = rhdr.format(self.expiry_date) + res["x-amz-restore"] = rhdr.format(self.expiry_date) if self._is_versioned: - res['x-amz-version-id'] = str(self.version_id) + res["x-amz-version-id"] = str(self.version_id) if self.website_redirect_location: - res['x-amz-website-redirect-location'] = self.website_redirect_location + res["x-amz-website-redirect-location"] = self.website_redirect_location return res @@ -222,30 +237,27 @@ class FakeKey(BaseModel): # https://docs.python.org/3/library/pickle.html#handling-stateful-objects def __getstate__(self): state = self.__dict__.copy() - state['value'] = self.value - del state['_value_buffer'] + state["value"] = self.value + del state["_value_buffer"] return state def __setstate__(self, state): - self.__dict__.update({ - k: v for k, v in six.iteritems(state) - if k != 'value' - }) + self.__dict__.update({k: v for k, v in six.iteritems(state) if k != "value"}) - self._value_buffer = \ - tempfile.SpooledTemporaryFile(max_size=self._max_buffer_size) - self.value = state['value'] + self._value_buffer = tempfile.SpooledTemporaryFile( + max_size=self._max_buffer_size + ) + self.value = state["value"] class FakeMultipart(BaseModel): - def __init__(self, key_name, metadata): self.key_name = key_name self.metadata = metadata self.parts = {} self.partlist = [] # ordered list of part ID's rand_b64 = base64.b64encode(os.urandom(UPLOAD_ID_BYTES)) - self.id = rand_b64.decode('utf-8').replace('=', '').replace('+', '') + self.id = rand_b64.decode("utf-8").replace("=", "").replace("+", "") def complete(self, body): decode_hex = codecs.getdecoder("hex_codec") @@ -258,8 +270,8 @@ class FakeMultipart(BaseModel): part = self.parts.get(pn) part_etag = None if part is not None: - part_etag = part.etag.replace('"', '') - etag = etag.replace('"', '') + part_etag = part.etag.replace('"', "") + etag = etag.replace('"', "") if part is None or part_etag != etag: raise InvalidPart() if last is not None and len(last.value) < UPLOAD_PART_MIN_SIZE: @@ -289,8 +301,7 @@ class FakeMultipart(BaseModel): class FakeGrantee(BaseModel): - - def __init__(self, id='', uri='', display_name=''): + def __init__(self, id="", uri="", display_name=""): self.id = id self.uri = uri self.display_name = display_name @@ -298,50 +309,55 @@ class FakeGrantee(BaseModel): def __eq__(self, other): if not isinstance(other, FakeGrantee): return False - return self.id == other.id and self.uri == other.uri and self.display_name == other.display_name + return ( + self.id == other.id + and self.uri == other.uri + and self.display_name == other.display_name + ) @property def type(self): - return 'Group' if self.uri else 'CanonicalUser' + return "Group" if self.uri else "CanonicalUser" def __repr__(self): - return "FakeGrantee(display_name: '{}', id: '{}', uri: '{}')".format(self.display_name, self.id, self.uri) + return "FakeGrantee(display_name: '{}', id: '{}', uri: '{}')".format( + self.display_name, self.id, self.uri + ) -ALL_USERS_GRANTEE = FakeGrantee( - uri='http://acs.amazonaws.com/groups/global/AllUsers') +ALL_USERS_GRANTEE = FakeGrantee(uri="http://acs.amazonaws.com/groups/global/AllUsers") AUTHENTICATED_USERS_GRANTEE = FakeGrantee( - uri='http://acs.amazonaws.com/groups/global/AuthenticatedUsers') -LOG_DELIVERY_GRANTEE = FakeGrantee( - uri='http://acs.amazonaws.com/groups/s3/LogDelivery') + uri="http://acs.amazonaws.com/groups/global/AuthenticatedUsers" +) +LOG_DELIVERY_GRANTEE = FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery") -PERMISSION_FULL_CONTROL = 'FULL_CONTROL' -PERMISSION_WRITE = 'WRITE' -PERMISSION_READ = 'READ' -PERMISSION_WRITE_ACP = 'WRITE_ACP' -PERMISSION_READ_ACP = 'READ_ACP' +PERMISSION_FULL_CONTROL = "FULL_CONTROL" +PERMISSION_WRITE = "WRITE" +PERMISSION_READ = "READ" +PERMISSION_WRITE_ACP = "WRITE_ACP" +PERMISSION_READ_ACP = "READ_ACP" CAMEL_CASED_PERMISSIONS = { - 'FULL_CONTROL': 'FullControl', - 'WRITE': 'Write', - 'READ': 'Read', - 'WRITE_ACP': 'WriteAcp', - 'READ_ACP': 'ReadAcp' + "FULL_CONTROL": "FullControl", + "WRITE": "Write", + "READ": "Read", + "WRITE_ACP": "WriteAcp", + "READ_ACP": "ReadAcp", } class FakeGrant(BaseModel): - def __init__(self, grantees, permissions): self.grantees = grantees self.permissions = permissions def __repr__(self): - return "FakeGrant(grantees: {}, permissions: {})".format(self.grantees, self.permissions) + return "FakeGrant(grantees: {}, permissions: {})".format( + self.grantees, self.permissions + ) class FakeAcl(BaseModel): - def __init__(self, grants=None): grants = grants or [] self.grants = grants @@ -362,34 +378,48 @@ class FakeAcl(BaseModel): def to_config_dict(self): """Returns the object into the format expected by AWS Config""" data = { - 'grantSet': None, # Always setting this to None. Feel free to change. - 'owner': {'displayName': None, 'id': OWNER} + "grantSet": None, # Always setting this to None. Feel free to change. + "owner": {"displayName": None, "id": OWNER}, } # Add details for each Grant: grant_list = [] for grant in self.grants: - permissions = grant.permissions if isinstance(grant.permissions, list) else [grant.permissions] + permissions = ( + grant.permissions + if isinstance(grant.permissions, list) + else [grant.permissions] + ) for permission in permissions: for grantee in grant.grantees: # Config does not add the owner if its permissions are FULL_CONTROL: - if permission == 'FULL_CONTROL' and grantee.id == OWNER: + if permission == "FULL_CONTROL" and grantee.id == OWNER: continue if grantee.uri: - grant_list.append({'grantee': grantee.uri.split('http://acs.amazonaws.com/groups/s3/')[1], - 'permission': CAMEL_CASED_PERMISSIONS[permission]}) + grant_list.append( + { + "grantee": grantee.uri.split( + "http://acs.amazonaws.com/groups/s3/" + )[1], + "permission": CAMEL_CASED_PERMISSIONS[permission], + } + ) else: - grant_list.append({ - 'grantee': { - 'id': grantee.id, - 'displayName': None if not grantee.display_name else grantee.display_name - }, - 'permission': CAMEL_CASED_PERMISSIONS[permission] - }) + grant_list.append( + { + "grantee": { + "id": grantee.id, + "displayName": None + if not grantee.display_name + else grantee.display_name, + }, + "permission": CAMEL_CASED_PERMISSIONS[permission], + } + ) if grant_list: - data['grantList'] = grant_list + data["grantList"] = grant_list return data @@ -397,51 +427,48 @@ class FakeAcl(BaseModel): def get_canned_acl(acl): owner_grantee = FakeGrantee(id=OWNER) grants = [FakeGrant([owner_grantee], [PERMISSION_FULL_CONTROL])] - if acl == 'private': + if acl == "private": pass # no other permissions - elif acl == 'public-read': + elif acl == "public-read": grants.append(FakeGrant([ALL_USERS_GRANTEE], [PERMISSION_READ])) - elif acl == 'public-read-write': - grants.append(FakeGrant([ALL_USERS_GRANTEE], [ - PERMISSION_READ, PERMISSION_WRITE])) - elif acl == 'authenticated-read': + elif acl == "public-read-write": grants.append( - FakeGrant([AUTHENTICATED_USERS_GRANTEE], [PERMISSION_READ])) - elif acl == 'bucket-owner-read': + FakeGrant([ALL_USERS_GRANTEE], [PERMISSION_READ, PERMISSION_WRITE]) + ) + elif acl == "authenticated-read": + grants.append(FakeGrant([AUTHENTICATED_USERS_GRANTEE], [PERMISSION_READ])) + elif acl == "bucket-owner-read": pass # TODO: bucket owner ACL - elif acl == 'bucket-owner-full-control': + elif acl == "bucket-owner-full-control": pass # TODO: bucket owner ACL - elif acl == 'aws-exec-read': + elif acl == "aws-exec-read": pass # TODO: bucket owner, EC2 Read - elif acl == 'log-delivery-write': - grants.append(FakeGrant([LOG_DELIVERY_GRANTEE], [ - PERMISSION_READ_ACP, PERMISSION_WRITE])) + elif acl == "log-delivery-write": + grants.append( + FakeGrant([LOG_DELIVERY_GRANTEE], [PERMISSION_READ_ACP, PERMISSION_WRITE]) + ) else: - assert False, 'Unknown canned acl: %s' % (acl,) + assert False, "Unknown canned acl: %s" % (acl,) return FakeAcl(grants=grants) class FakeTagging(BaseModel): - def __init__(self, tag_set=None): self.tag_set = tag_set or FakeTagSet() class FakeTagSet(BaseModel): - def __init__(self, tags=None): self.tags = tags or [] class FakeTag(BaseModel): - def __init__(self, key, value=None): self.key = key self.value = value class LifecycleFilter(BaseModel): - def __init__(self, prefix=None, tag=None, and_filter=None): self.prefix = prefix self.tag = tag @@ -450,34 +477,27 @@ class LifecycleFilter(BaseModel): def to_config_dict(self): if self.prefix is not None: return { - 'predicate': { - 'type': 'LifecyclePrefixPredicate', - 'prefix': self.prefix - } + "predicate": {"type": "LifecyclePrefixPredicate", "prefix": self.prefix} } elif self.tag: return { - 'predicate': { - 'type': 'LifecycleTagPredicate', - 'tag': { - 'key': self.tag.key, - 'value': self.tag.value - } + "predicate": { + "type": "LifecycleTagPredicate", + "tag": {"key": self.tag.key, "value": self.tag.value}, } } else: return { - 'predicate': { - 'type': 'LifecycleAndOperator', - 'operands': self.and_filter.to_config_dict() + "predicate": { + "type": "LifecycleAndOperator", + "operands": self.and_filter.to_config_dict(), } } class LifecycleAndFilter(BaseModel): - def __init__(self, prefix=None, tags=None): self.prefix = prefix self.tags = tags @@ -486,20 +506,37 @@ class LifecycleAndFilter(BaseModel): data = [] if self.prefix is not None: - data.append({'type': 'LifecyclePrefixPredicate', 'prefix': self.prefix}) + data.append({"type": "LifecyclePrefixPredicate", "prefix": self.prefix}) for tag in self.tags: - data.append({'type': 'LifecycleTagPredicate', 'tag': {'key': tag.key, 'value': tag.value}}) + data.append( + { + "type": "LifecycleTagPredicate", + "tag": {"key": tag.key, "value": tag.value}, + } + ) return data class LifecycleRule(BaseModel): - - def __init__(self, id=None, prefix=None, lc_filter=None, status=None, expiration_days=None, - expiration_date=None, transition_days=None, transition_date=None, storage_class=None, - expired_object_delete_marker=None, nve_noncurrent_days=None, nvt_noncurrent_days=None, - nvt_storage_class=None, aimu_days=None): + def __init__( + self, + id=None, + prefix=None, + lc_filter=None, + status=None, + expiration_days=None, + expiration_date=None, + transition_days=None, + transition_date=None, + storage_class=None, + expired_object_delete_marker=None, + nve_noncurrent_days=None, + nvt_noncurrent_days=None, + nvt_storage_class=None, + aimu_days=None, + ): self.id = id self.prefix = prefix self.filter = lc_filter @@ -527,49 +564,79 @@ class LifecycleRule(BaseModel): """ lifecycle_dict = { - 'id': self.id, - 'prefix': self.prefix, - 'status': self.status, - 'expirationInDays': int(self.expiration_days) if self.expiration_days else None, - 'expiredObjectDeleteMarker': self.expired_object_delete_marker, - 'noncurrentVersionExpirationInDays': -1 or int(self.nve_noncurrent_days), - 'expirationDate': self.expiration_date, - 'transitions': None, # Replace me with logic to fill in - 'noncurrentVersionTransitions': None, # Replace me with logic to fill in + "id": self.id, + "prefix": self.prefix, + "status": self.status, + "expirationInDays": int(self.expiration_days) + if self.expiration_days + else None, + "expiredObjectDeleteMarker": self.expired_object_delete_marker, + "noncurrentVersionExpirationInDays": -1 or int(self.nve_noncurrent_days), + "expirationDate": self.expiration_date, + "transitions": None, # Replace me with logic to fill in + "noncurrentVersionTransitions": None, # Replace me with logic to fill in } if self.aimu_days: - lifecycle_dict['abortIncompleteMultipartUpload'] = {'daysAfterInitiation': self.aimu_days} + lifecycle_dict["abortIncompleteMultipartUpload"] = { + "daysAfterInitiation": self.aimu_days + } else: - lifecycle_dict['abortIncompleteMultipartUpload'] = None + lifecycle_dict["abortIncompleteMultipartUpload"] = None # Format the filter: if self.prefix is None and self.filter is None: - lifecycle_dict['filter'] = {'predicate': None} + lifecycle_dict["filter"] = {"predicate": None} elif self.prefix: - lifecycle_dict['filter'] = None + lifecycle_dict["filter"] = None else: - lifecycle_dict['filter'] = self.filter.to_config_dict() + lifecycle_dict["filter"] = self.filter.to_config_dict() return lifecycle_dict class CorsRule(BaseModel): - - def __init__(self, allowed_methods, allowed_origins, allowed_headers=None, expose_headers=None, - max_age_seconds=None): - self.allowed_methods = [allowed_methods] if isinstance(allowed_methods, six.string_types) else allowed_methods - self.allowed_origins = [allowed_origins] if isinstance(allowed_origins, six.string_types) else allowed_origins - self.allowed_headers = [allowed_headers] if isinstance(allowed_headers, six.string_types) else allowed_headers - self.exposed_headers = [expose_headers] if isinstance(expose_headers, six.string_types) else expose_headers + def __init__( + self, + allowed_methods, + allowed_origins, + allowed_headers=None, + expose_headers=None, + max_age_seconds=None, + ): + self.allowed_methods = ( + [allowed_methods] + if isinstance(allowed_methods, six.string_types) + else allowed_methods + ) + self.allowed_origins = ( + [allowed_origins] + if isinstance(allowed_origins, six.string_types) + else allowed_origins + ) + self.allowed_headers = ( + [allowed_headers] + if isinstance(allowed_headers, six.string_types) + else allowed_headers + ) + self.exposed_headers = ( + [expose_headers] + if isinstance(expose_headers, six.string_types) + else expose_headers + ) self.max_age_seconds = max_age_seconds class Notification(BaseModel): - def __init__(self, arn, events, filters=None, id=None): - self.id = id if id else ''.join(random.choice(string.ascii_letters + string.digits) for _ in range(50)) + self.id = ( + id + if id + else "".join( + random.choice(string.ascii_letters + string.digits) for _ in range(50) + ) + ) self.arn = arn self.events = events self.filters = filters if filters else {} @@ -578,56 +645,90 @@ class Notification(BaseModel): data = {} # Type and ARN will be filled in by NotificationConfiguration's to_config_dict: - data['events'] = [event for event in self.events] + data["events"] = [event for event in self.events] if self.filters: - data['filter'] = {'s3KeyFilter': {'filterRules': [ - {'name': fr['Name'], 'value': fr['Value']} for fr in self.filters['S3Key']['FilterRule'] - ]}} + data["filter"] = { + "s3KeyFilter": { + "filterRules": [ + {"name": fr["Name"], "value": fr["Value"]} + for fr in self.filters["S3Key"]["FilterRule"] + ] + } + } else: - data['filter'] = None + data["filter"] = None - data['objectPrefixes'] = [] # Not sure why this is a thing since AWS just seems to return this as filters ¯\_(ツ)_/¯ + data[ + "objectPrefixes" + ] = ( + [] + ) # Not sure why this is a thing since AWS just seems to return this as filters ¯\_(ツ)_/¯ return data class NotificationConfiguration(BaseModel): - def __init__(self, topic=None, queue=None, cloud_function=None): - self.topic = [Notification(t["Topic"], t["Event"], filters=t.get("Filter"), id=t.get("Id")) for t in topic] \ - if topic else [] - self.queue = [Notification(q["Queue"], q["Event"], filters=q.get("Filter"), id=q.get("Id")) for q in queue] \ - if queue else [] - self.cloud_function = [Notification(c["CloudFunction"], c["Event"], filters=c.get("Filter"), id=c.get("Id")) - for c in cloud_function] if cloud_function else [] + self.topic = ( + [ + Notification( + t["Topic"], t["Event"], filters=t.get("Filter"), id=t.get("Id") + ) + for t in topic + ] + if topic + else [] + ) + self.queue = ( + [ + Notification( + q["Queue"], q["Event"], filters=q.get("Filter"), id=q.get("Id") + ) + for q in queue + ] + if queue + else [] + ) + self.cloud_function = ( + [ + Notification( + c["CloudFunction"], + c["Event"], + filters=c.get("Filter"), + id=c.get("Id"), + ) + for c in cloud_function + ] + if cloud_function + else [] + ) def to_config_dict(self): - data = {'configurations': {}} + data = {"configurations": {}} for topic in self.topic: topic_config = topic.to_config_dict() - topic_config['topicARN'] = topic.arn - topic_config['type'] = 'TopicConfiguration' - data['configurations'][topic.id] = topic_config + topic_config["topicARN"] = topic.arn + topic_config["type"] = "TopicConfiguration" + data["configurations"][topic.id] = topic_config for queue in self.queue: queue_config = queue.to_config_dict() - queue_config['queueARN'] = queue.arn - queue_config['type'] = 'QueueConfiguration' - data['configurations'][queue.id] = queue_config + queue_config["queueARN"] = queue.arn + queue_config["type"] = "QueueConfiguration" + data["configurations"][queue.id] = queue_config for cloud_function in self.cloud_function: cf_config = cloud_function.to_config_dict() - cf_config['queueARN'] = cloud_function.arn - cf_config['type'] = 'LambdaConfiguration' - data['configurations'][cloud_function.id] = cf_config + cf_config["queueARN"] = cloud_function.arn + cf_config["type"] = "LambdaConfiguration" + data["configurations"][cloud_function.id] = cf_config return data class FakeBucket(BaseModel): - def __init__(self, name, region_name): self.name = name self.region_name = region_name @@ -637,13 +738,13 @@ class FakeBucket(BaseModel): self.rules = [] self.policy = None self.website_configuration = None - self.acl = get_canned_acl('private') + self.acl = get_canned_acl("private") self.tags = FakeTagging() self.cors = [] self.logging = {} self.notification_configuration = None self.accelerate_configuration = None - self.payer = 'BucketOwner' + self.payer = "BucketOwner" self.creation_date = datetime.datetime.utcnow() @property @@ -652,41 +753,52 @@ class FakeBucket(BaseModel): @property def is_versioned(self): - return self.versioning_status == 'Enabled' + return self.versioning_status == "Enabled" def set_lifecycle(self, rules): self.rules = [] for rule in rules: # Extract and validate actions from Lifecycle rule - expiration = rule.get('Expiration') - transition = rule.get('Transition') + expiration = rule.get("Expiration") + transition = rule.get("Transition") try: - top_level_prefix = rule['Prefix'] or '' # If it's `None` the set to the empty string + top_level_prefix = ( + rule["Prefix"] or "" + ) # If it's `None` the set to the empty string except KeyError: top_level_prefix = None nve_noncurrent_days = None - if rule.get('NoncurrentVersionExpiration') is not None: - if rule["NoncurrentVersionExpiration"].get('NoncurrentDays') is None: + if rule.get("NoncurrentVersionExpiration") is not None: + if rule["NoncurrentVersionExpiration"].get("NoncurrentDays") is None: raise MalformedXML() - nve_noncurrent_days = rule["NoncurrentVersionExpiration"]["NoncurrentDays"] + nve_noncurrent_days = rule["NoncurrentVersionExpiration"][ + "NoncurrentDays" + ] nvt_noncurrent_days = None nvt_storage_class = None - if rule.get('NoncurrentVersionTransition') is not None: - if rule["NoncurrentVersionTransition"].get('NoncurrentDays') is None: + if rule.get("NoncurrentVersionTransition") is not None: + if rule["NoncurrentVersionTransition"].get("NoncurrentDays") is None: raise MalformedXML() - if rule["NoncurrentVersionTransition"].get('StorageClass') is None: + if rule["NoncurrentVersionTransition"].get("StorageClass") is None: raise MalformedXML() - nvt_noncurrent_days = rule["NoncurrentVersionTransition"]["NoncurrentDays"] + nvt_noncurrent_days = rule["NoncurrentVersionTransition"][ + "NoncurrentDays" + ] nvt_storage_class = rule["NoncurrentVersionTransition"]["StorageClass"] aimu_days = None - if rule.get('AbortIncompleteMultipartUpload') is not None: - if rule["AbortIncompleteMultipartUpload"].get('DaysAfterInitiation') is None: + if rule.get("AbortIncompleteMultipartUpload") is not None: + if ( + rule["AbortIncompleteMultipartUpload"].get("DaysAfterInitiation") + is None + ): raise MalformedXML() - aimu_days = rule["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] + aimu_days = rule["AbortIncompleteMultipartUpload"][ + "DaysAfterInitiation" + ] eodm = None if expiration and expiration.get("ExpiredObjectDeleteMarker") is not None: @@ -708,7 +820,9 @@ class FakeBucket(BaseModel): filters = 0 try: - prefix_filter = rule['Filter']['Prefix'] or '' # If it's `None` the set to the empty string + prefix_filter = ( + rule["Filter"]["Prefix"] or "" + ) # If it's `None` the set to the empty string filters += 1 except KeyError: prefix_filter = None @@ -719,13 +833,17 @@ class FakeBucket(BaseModel): and_tags = [] if rule["Filter"]["And"].get("Tag"): if not isinstance(rule["Filter"]["And"]["Tag"], list): - rule["Filter"]["And"]["Tag"] = [rule["Filter"]["And"]["Tag"]] + rule["Filter"]["And"]["Tag"] = [ + rule["Filter"]["And"]["Tag"] + ] for t in rule["Filter"]["And"]["Tag"]: - and_tags.append(FakeTag(t["Key"], t.get("Value", ''))) + and_tags.append(FakeTag(t["Key"], t.get("Value", ""))) try: - and_prefix = rule["Filter"]["And"]["Prefix"] or '' # If it's `None` then set to the empty string + and_prefix = ( + rule["Filter"]["And"]["Prefix"] or "" + ) # If it's `None` then set to the empty string except KeyError: and_prefix = None @@ -734,37 +852,46 @@ class FakeBucket(BaseModel): filter_tag = None if rule["Filter"].get("Tag"): filters += 1 - filter_tag = FakeTag(rule["Filter"]["Tag"]["Key"], rule["Filter"]["Tag"].get("Value", '')) + filter_tag = FakeTag( + rule["Filter"]["Tag"]["Key"], + rule["Filter"]["Tag"].get("Value", ""), + ) # Can't have more than 1 filter: if filters > 1: raise MalformedXML() - lc_filter = LifecycleFilter(prefix=prefix_filter, tag=filter_tag, and_filter=and_filter) + lc_filter = LifecycleFilter( + prefix=prefix_filter, tag=filter_tag, and_filter=and_filter + ) # If no top level prefix and no filter is present, then this is invalid: if top_level_prefix is None: try: - rule['Filter'] + rule["Filter"] except KeyError: raise MalformedXML() - self.rules.append(LifecycleRule( - id=rule.get('ID'), - prefix=top_level_prefix, - lc_filter=lc_filter, - status=rule['Status'], - expiration_days=expiration.get('Days') if expiration else None, - expiration_date=expiration.get('Date') if expiration else None, - transition_days=transition.get('Days') if transition else None, - transition_date=transition.get('Date') if transition else None, - storage_class=transition.get('StorageClass') if transition else None, - expired_object_delete_marker=eodm, - nve_noncurrent_days=nve_noncurrent_days, - nvt_noncurrent_days=nvt_noncurrent_days, - nvt_storage_class=nvt_storage_class, - aimu_days=aimu_days, - )) + self.rules.append( + LifecycleRule( + id=rule.get("ID"), + prefix=top_level_prefix, + lc_filter=lc_filter, + status=rule["Status"], + expiration_days=expiration.get("Days") if expiration else None, + expiration_date=expiration.get("Date") if expiration else None, + transition_days=transition.get("Days") if transition else None, + transition_date=transition.get("Date") if transition else None, + storage_class=transition.get("StorageClass") + if transition + else None, + expired_object_delete_marker=eodm, + nve_noncurrent_days=nve_noncurrent_days, + nvt_noncurrent_days=nvt_noncurrent_days, + nvt_storage_class=nvt_storage_class, + aimu_days=aimu_days, + ) + ) def delete_lifecycle(self): self.rules = [] @@ -776,12 +903,18 @@ class FakeBucket(BaseModel): raise MalformedXML() for rule in rules: - assert isinstance(rule["AllowedMethod"], list) or isinstance(rule["AllowedMethod"], six.string_types) - assert isinstance(rule["AllowedOrigin"], list) or isinstance(rule["AllowedOrigin"], six.string_types) - assert isinstance(rule.get("AllowedHeader", []), list) or isinstance(rule.get("AllowedHeader", ""), - six.string_types) - assert isinstance(rule.get("ExposedHeader", []), list) or isinstance(rule.get("ExposedHeader", ""), - six.string_types) + assert isinstance(rule["AllowedMethod"], list) or isinstance( + rule["AllowedMethod"], six.string_types + ) + assert isinstance(rule["AllowedOrigin"], list) or isinstance( + rule["AllowedOrigin"], six.string_types + ) + assert isinstance(rule.get("AllowedHeader", []), list) or isinstance( + rule.get("AllowedHeader", ""), six.string_types + ) + assert isinstance(rule.get("ExposedHeader", []), list) or isinstance( + rule.get("ExposedHeader", ""), six.string_types + ) assert isinstance(rule.get("MaxAgeSeconds", "0"), six.string_types) if isinstance(rule["AllowedMethod"], six.string_types): @@ -793,13 +926,15 @@ class FakeBucket(BaseModel): if method not in ["GET", "PUT", "HEAD", "POST", "DELETE"]: raise InvalidRequest(method) - self.cors.append(CorsRule( - rule["AllowedMethod"], - rule["AllowedOrigin"], - rule.get("AllowedHeader"), - rule.get("ExposedHeader"), - rule.get("MaxAgeSecond") - )) + self.cors.append( + CorsRule( + rule["AllowedMethod"], + rule["AllowedOrigin"], + rule.get("AllowedHeader"), + rule.get("ExposedHeader"), + rule.get("MaxAgeSecond"), + ) + ) def delete_cors(self): self.cors = [] @@ -821,7 +956,9 @@ class FakeBucket(BaseModel): # Target bucket must exist in the same account (assuming all moto buckets are in the same account): if not bucket_backend.buckets.get(logging_config["TargetBucket"]): - raise InvalidTargetBucketForLogging("The target bucket for logging does not exist.") + raise InvalidTargetBucketForLogging( + "The target bucket for logging does not exist." + ) # Does the target bucket have the log-delivery WRITE and READ_ACP permissions? write = read_acp = False @@ -829,20 +966,31 @@ class FakeBucket(BaseModel): # Must be granted to: http://acs.amazonaws.com/groups/s3/LogDelivery for grantee in grant.grantees: if grantee.uri == "http://acs.amazonaws.com/groups/s3/LogDelivery": - if "WRITE" in grant.permissions or "FULL_CONTROL" in grant.permissions: + if ( + "WRITE" in grant.permissions + or "FULL_CONTROL" in grant.permissions + ): write = True - if "READ_ACP" in grant.permissions or "FULL_CONTROL" in grant.permissions: + if ( + "READ_ACP" in grant.permissions + or "FULL_CONTROL" in grant.permissions + ): read_acp = True break if not write or not read_acp: - raise InvalidTargetBucketForLogging("You must give the log-delivery group WRITE and READ_ACP" - " permissions to the target bucket") + raise InvalidTargetBucketForLogging( + "You must give the log-delivery group WRITE and READ_ACP" + " permissions to the target bucket" + ) # Buckets must also exist within the same region: - if bucket_backend.buckets[logging_config["TargetBucket"]].region_name != self.region_name: + if ( + bucket_backend.buckets[logging_config["TargetBucket"]].region_name + != self.region_name + ): raise CrossLocationLoggingProhibitted() # Checks pass -- set the logging config: @@ -856,7 +1004,7 @@ class FakeBucket(BaseModel): self.notification_configuration = NotificationConfiguration( topic=notification_config.get("TopicConfiguration"), queue=notification_config.get("QueueConfiguration"), - cloud_function=notification_config.get("CloudFunctionConfiguration") + cloud_function=notification_config.get("CloudFunctionConfiguration"), ) # Validate that the region is correct: @@ -867,7 +1015,7 @@ class FakeBucket(BaseModel): raise InvalidNotificationDestination() def set_accelerate_configuration(self, accelerate_config): - if self.accelerate_configuration is None and accelerate_config == 'Suspended': + if self.accelerate_configuration is None and accelerate_config == "Suspended": # Cannot "suspend" a not active acceleration. Leaves it undefined return @@ -878,12 +1026,11 @@ class FakeBucket(BaseModel): def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'DomainName': - raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "DomainName" ]"') - elif attribute_name == 'WebsiteURL': - raise NotImplementedError( - '"Fn::GetAtt" : [ "{0}" , "WebsiteURL" ]"') + + if attribute_name == "DomainName": + raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "DomainName" ]"') + elif attribute_name == "WebsiteURL": + raise NotImplementedError('"Fn::GetAtt" : [ "{0}" , "WebsiteURL" ]"') raise UnformattedGetAttTemplateException() def set_acl(self, acl): @@ -895,7 +1042,8 @@ class FakeBucket(BaseModel): @classmethod def create_from_cloudformation_json( - cls, resource_name, cloudformation_json, region_name): + cls, resource_name, cloudformation_json, region_name + ): bucket = s3_backend.create_bucket(resource_name, region_name) return bucket @@ -906,69 +1054,78 @@ class FakeBucket(BaseModel): - Bucket Accelerate Configuration """ config_dict = { - 'version': '1.3', - 'configurationItemCaptureTime': str(self.creation_date), - 'configurationItemStatus': 'ResourceDiscovered', - 'configurationStateId': str(int(time.mktime(self.creation_date.timetuple()))), # PY2 and 3 compatible - 'configurationItemMD5Hash': '', - 'arn': "arn:aws:s3:::{}".format(self.name), - 'resourceType': 'AWS::S3::Bucket', - 'resourceId': self.name, - 'resourceName': self.name, - 'awsRegion': self.region_name, - 'availabilityZone': 'Regional', - 'resourceCreationTime': str(self.creation_date), - 'relatedEvents': [], - 'relationships': [], - 'tags': {tag.key: tag.value for tag in self.tagging.tag_set.tags}, - 'configuration': { - 'name': self.name, - 'owner': {'id': OWNER}, - 'creationDate': self.creation_date.isoformat() - } + "version": "1.3", + "configurationItemCaptureTime": str(self.creation_date), + "configurationItemStatus": "ResourceDiscovered", + "configurationStateId": str( + int(time.mktime(self.creation_date.timetuple())) + ), # PY2 and 3 compatible + "configurationItemMD5Hash": "", + "arn": "arn:aws:s3:::{}".format(self.name), + "resourceType": "AWS::S3::Bucket", + "resourceId": self.name, + "resourceName": self.name, + "awsRegion": self.region_name, + "availabilityZone": "Regional", + "resourceCreationTime": str(self.creation_date), + "relatedEvents": [], + "relationships": [], + "tags": {tag.key: tag.value for tag in self.tagging.tag_set.tags}, + "configuration": { + "name": self.name, + "owner": {"id": OWNER}, + "creationDate": self.creation_date.isoformat(), + }, } # Make the supplementary configuration: # TODO: Implement Public Access Block Support # This is a dobule-wrapped JSON for some reason... - s_config = {'AccessControlList': json.dumps(json.dumps(self.acl.to_config_dict()))} + s_config = { + "AccessControlList": json.dumps(json.dumps(self.acl.to_config_dict())) + } # Tagging is special: - if config_dict['tags']: - s_config['BucketTaggingConfiguration'] = json.dumps({'tagSets': [{'tags': config_dict['tags']}]}) + if config_dict["tags"]: + s_config["BucketTaggingConfiguration"] = json.dumps( + {"tagSets": [{"tags": config_dict["tags"]}]} + ) # TODO implement Accelerate Configuration: - s_config['BucketAccelerateConfiguration'] = {'status': None} + s_config["BucketAccelerateConfiguration"] = {"status": None} if self.rules: - s_config['BucketLifecycleConfiguration'] = { + s_config["BucketLifecycleConfiguration"] = { "rules": [rule.to_config_dict() for rule in self.rules] } - s_config['BucketLoggingConfiguration'] = { - 'destinationBucketName': self.logging.get('TargetBucket', None), - 'logFilePrefix': self.logging.get('TargetPrefix', None) + s_config["BucketLoggingConfiguration"] = { + "destinationBucketName": self.logging.get("TargetBucket", None), + "logFilePrefix": self.logging.get("TargetPrefix", None), } - s_config['BucketPolicy'] = { - 'policyText': self.policy.decode('utf-8') if self.policy else None + s_config["BucketPolicy"] = { + "policyText": self.policy.decode("utf-8") if self.policy else None } - s_config['IsRequesterPaysEnabled'] = 'false' if self.payer == 'BucketOwner' else 'true' + s_config["IsRequesterPaysEnabled"] = ( + "false" if self.payer == "BucketOwner" else "true" + ) if self.notification_configuration: - s_config['BucketNotificationConfiguration'] = self.notification_configuration.to_config_dict() + s_config[ + "BucketNotificationConfiguration" + ] = self.notification_configuration.to_config_dict() else: - s_config['BucketNotificationConfiguration'] = {'configurations': {}} + s_config["BucketNotificationConfiguration"] = {"configurations": {}} - config_dict['supplementaryConfiguration'] = s_config + config_dict["supplementaryConfiguration"] = s_config return config_dict class S3Backend(BaseBackend): - def __init__(self): self.buckets = {} @@ -1014,27 +1171,33 @@ class S3Backend(BaseBackend): last_modified = version.last_modified version_id = version.version_id latest_modified_per_key[name] = max( - last_modified, - latest_modified_per_key.get(name, datetime.datetime.min) + last_modified, latest_modified_per_key.get(name, datetime.datetime.min) ) if last_modified == latest_modified_per_key[name]: latest_versions[name] = version_id return latest_versions - def get_bucket_versions(self, bucket_name, delimiter=None, - encoding_type=None, - key_marker=None, - max_keys=None, - version_id_marker=None, - prefix=''): + def get_bucket_versions( + self, + bucket_name, + delimiter=None, + encoding_type=None, + key_marker=None, + max_keys=None, + version_id_marker=None, + prefix="", + ): bucket = self.get_bucket(bucket_name) if any((delimiter, key_marker, version_id_marker)): raise NotImplementedError( - "Called get_bucket_versions with some of delimiter, encoding_type, key_marker, version_id_marker") + "Called get_bucket_versions with some of delimiter, encoding_type, key_marker, version_id_marker" + ) - return itertools.chain(*(l for key, l in bucket.keys.iterlists() if key.startswith(prefix))) + return itertools.chain( + *(l for key, l in bucket.keys.iterlists() if key.startswith(prefix)) + ) def get_bucket_policy(self, bucket_name): return self.get_bucket(bucket_name).policy @@ -1059,13 +1222,7 @@ class S3Backend(BaseBackend): return bucket.website_configuration def set_key( - self, - bucket_name, - key_name, - value, - storage=None, - etag=None, - multipart=None, + self, bucket_name, key_name, value, storage=None, etag=None, multipart=None ): key_name = clean_key_name(key_name) if storage is not None and storage not in STORAGE_CLASS: @@ -1084,7 +1241,8 @@ class S3Backend(BaseBackend): ) keys = [ - key for key in bucket.keys.getlist(key_name, []) + key + for key in bucket.keys.getlist(key_name, []) if key.version_id != new_key.version_id ] + [new_key] bucket.keys.setlist(key_name, keys) @@ -1155,13 +1313,15 @@ class S3Backend(BaseBackend): bucket = self.get_bucket(bucket_name) bucket.set_notification_configuration(notification_config) - def put_bucket_accelerate_configuration(self, bucket_name, accelerate_configuration): - if accelerate_configuration not in ['Enabled', 'Suspended']: + def put_bucket_accelerate_configuration( + self, bucket_name, accelerate_configuration + ): + if accelerate_configuration not in ["Enabled", "Suspended"]: raise MalformedXML() bucket = self.get_bucket(bucket_name) - if bucket.name.find('.') != -1: - raise InvalidRequest('PutBucketAccelerateConfiguration') + if bucket.name.find(".") != -1: + raise InvalidRequest("PutBucketAccelerateConfiguration") bucket.set_accelerate_configuration(accelerate_configuration) def initiate_multipart(self, bucket_name, key_name, metadata): @@ -1180,10 +1340,7 @@ class S3Backend(BaseBackend): del bucket.multiparts[multipart_id] key = self.set_key( - bucket_name, - multipart.key_name, - value, etag=etag, - multipart=multipart + bucket_name, multipart.key_name, value, etag=etag, multipart=multipart ) key.set_metadata(multipart.metadata) return key @@ -1205,14 +1362,25 @@ class S3Backend(BaseBackend): multipart = bucket.multiparts[multipart_id] return multipart.set_part(part_id, value) - def copy_part(self, dest_bucket_name, multipart_id, part_id, - src_bucket_name, src_key_name, src_version_id, start_byte, end_byte): + def copy_part( + self, + dest_bucket_name, + multipart_id, + part_id, + src_bucket_name, + src_key_name, + src_version_id, + start_byte, + end_byte, + ): dest_bucket = self.get_bucket(dest_bucket_name) multipart = dest_bucket.multiparts[multipart_id] - src_value = self.get_key(src_bucket_name, src_key_name, version_id=src_version_id).value + src_value = self.get_key( + src_bucket_name, src_key_name, version_id=src_version_id + ).value if start_byte is not None: - src_value = src_value[start_byte:end_byte + 1] + src_value = src_value[start_byte : end_byte + 1] return multipart.set_part(part_id, src_value) def prefix_query(self, bucket, prefix, delimiter): @@ -1224,33 +1392,33 @@ class S3Backend(BaseBackend): key_without_prefix = key_name.replace(prefix, "", 1) if delimiter and delimiter in key_without_prefix: # If delimiter, we need to split out folder_results - key_without_delimiter = key_without_prefix.split(delimiter)[ - 0] - folder_results.add("{0}{1}{2}".format( - prefix, key_without_delimiter, delimiter)) + key_without_delimiter = key_without_prefix.split(delimiter)[0] + folder_results.add( + "{0}{1}{2}".format(prefix, key_without_delimiter, delimiter) + ) else: key_results.add(key) else: for key_name, key in bucket.keys.items(): if delimiter and delimiter in key_name: # If delimiter, we need to split out folder_results - folder_results.add(key_name.split( - delimiter)[0] + delimiter) + folder_results.add(key_name.split(delimiter)[0] + delimiter) else: key_results.add(key) - key_results = filter(lambda key: not isinstance(key, FakeDeleteMarker), key_results) + key_results = filter( + lambda key: not isinstance(key, FakeDeleteMarker), key_results + ) key_results = sorted(key_results, key=lambda key: key.name) - folder_results = [folder_name for folder_name in sorted( - folder_results, key=lambda key: key)] + folder_results = [ + folder_name for folder_name in sorted(folder_results, key=lambda key: key) + ] return key_results, folder_results def _set_delete_marker(self, bucket_name, key_name): bucket = self.get_bucket(bucket_name) - bucket.keys[key_name] = FakeDeleteMarker( - key=bucket.keys[key_name] - ) + bucket.keys[key_name] = FakeDeleteMarker(key=bucket.keys[key_name]) def delete_key(self, bucket_name, key_name, version_id=None): key_name = clean_key_name(key_name) @@ -1271,7 +1439,7 @@ class S3Backend(BaseBackend): key for key in bucket.keys.getlist(key_name) if str(key.version_id) != str(version_id) - ] + ], ) if not bucket.keys.getlist(key_name): @@ -1280,13 +1448,20 @@ class S3Backend(BaseBackend): except KeyError: return False - def copy_key(self, src_bucket_name, src_key_name, dest_bucket_name, - dest_key_name, storage=None, acl=None, src_version_id=None): + def copy_key( + self, + src_bucket_name, + src_key_name, + dest_bucket_name, + dest_key_name, + storage=None, + acl=None, + src_version_id=None, + ): src_key_name = clean_key_name(src_key_name) dest_key_name = clean_key_name(dest_key_name) dest_bucket = self.get_bucket(dest_bucket_name) - key = self.get_key(src_bucket_name, src_key_name, - version_id=src_version_id) + key = self.get_key(src_bucket_name, src_key_name, version_id=src_version_id) new_key = key.copy(dest_key_name, dest_bucket.is_versioned) diff --git a/moto/s3/responses.py b/moto/s3/responses.py index a0dcbbaf4..fd3a7b2db 100644 --- a/moto/s3/responses.py +++ b/moto/s3/responses.py @@ -13,18 +13,46 @@ from moto.packages.httpretty.core import HTTPrettyRequest from moto.core.responses import _TemplateEnvironmentMixin, ActionAuthenticatorMixin from moto.core.utils import path_url -from moto.s3bucket_path.utils import bucket_name_from_url as bucketpath_bucket_name_from_url, \ - parse_key_name as bucketpath_parse_key_name, is_delete_keys as bucketpath_is_delete_keys +from moto.s3bucket_path.utils import ( + bucket_name_from_url as bucketpath_bucket_name_from_url, + parse_key_name as bucketpath_parse_key_name, + is_delete_keys as bucketpath_is_delete_keys, +) -from .exceptions import BucketAlreadyExists, S3ClientError, MissingBucket, MissingKey, InvalidPartOrder, MalformedXML, \ - MalformedACLError, InvalidNotificationARN, InvalidNotificationEvent, ObjectNotInActiveTierError -from .models import s3_backend, get_canned_acl, FakeGrantee, FakeGrant, FakeAcl, FakeKey, FakeTagging, FakeTagSet, \ - FakeTag -from .utils import bucket_name_from_url, clean_key_name, undo_clean_key_name, metadata_from_headers, parse_region_from_url +from .exceptions import ( + BucketAlreadyExists, + S3ClientError, + MissingBucket, + MissingKey, + InvalidPartOrder, + MalformedXML, + MalformedACLError, + InvalidNotificationARN, + InvalidNotificationEvent, + ObjectNotInActiveTierError, +) +from .models import ( + s3_backend, + get_canned_acl, + FakeGrantee, + FakeGrant, + FakeAcl, + FakeKey, + FakeTagging, + FakeTagSet, + FakeTag, +) +from .utils import ( + bucket_name_from_url, + clean_key_name, + undo_clean_key_name, + metadata_from_headers, + parse_region_from_url, +) from xml.dom import minidom -DEFAULT_REGION_NAME = 'us-east-1' +DEFAULT_REGION_NAME = "us-east-1" ACTION_MAP = { "BUCKET": { @@ -42,7 +70,7 @@ ACTION_MAP = { "notification": "GetBucketNotification", "accelerate": "GetAccelerateConfiguration", "versions": "ListBucketVersions", - "DEFAULT": "ListBucket" + "DEFAULT": "ListBucket", }, "PUT": { "lifecycle": "PutLifecycleConfiguration", @@ -55,15 +83,15 @@ ACTION_MAP = { "cors": "PutBucketCORS", "notification": "PutBucketNotification", "accelerate": "PutAccelerateConfiguration", - "DEFAULT": "CreateBucket" + "DEFAULT": "CreateBucket", }, "DELETE": { "lifecycle": "PutLifecycleConfiguration", "policy": "DeleteBucketPolicy", "tagging": "PutBucketTagging", "cors": "PutBucketCORS", - "DEFAULT": "DeleteBucket" - } + "DEFAULT": "DeleteBucket", + }, }, "KEY": { "GET": { @@ -71,25 +99,24 @@ ACTION_MAP = { "acl": "GetObjectAcl", "tagging": "GetObjectTagging", "versionId": "GetObjectVersion", - "DEFAULT": "GetObject" + "DEFAULT": "GetObject", }, "PUT": { "acl": "PutObjectAcl", "tagging": "PutObjectTagging", - "DEFAULT": "PutObject" + "DEFAULT": "PutObject", }, "DELETE": { "uploadId": "AbortMultipartUpload", "versionId": "DeleteObjectVersion", - "DEFAULT": " DeleteObject" + "DEFAULT": " DeleteObject", }, "POST": { "uploads": "PutObject", "restore": "RestoreObject", - "uploadId": "PutObject" - } - } - + "uploadId": "PutObject", + }, + }, } @@ -98,14 +125,12 @@ def parse_key_name(pth): def is_delete_keys(request, path, bucket_name): - return path == u'/?delete' or ( - path == u'/' and - getattr(request, "query_string", "") == "delete" + return path == "/?delete" or ( + path == "/" and getattr(request, "query_string", "") == "delete" ) class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): - def __init__(self, backend): super(ResponseObject, self).__init__() self.backend = backend @@ -128,34 +153,43 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return template.render(buckets=all_buckets) def subdomain_based_buckets(self, request): - host = request.headers.get('host', request.headers.get('Host')) + host = request.headers.get("host", request.headers.get("Host")) if not host: host = urlparse(request.url).netloc - if (not host or host.startswith('localhost') or host.startswith('localstack') or - re.match(r'^[^.]+$', host) or re.match(r'^.*\.svc\.cluster\.local$', host)): + if ( + not host + or host.startswith("localhost") + or host.startswith("localstack") + or re.match(r"^[^.]+$", host) + or re.match(r"^.*\.svc\.cluster\.local$", host) + ): # Default to path-based buckets for (1) localhost, (2) localstack hosts (e.g. localstack.dev), # (3) local host names that do not contain a "." (e.g., Docker container host names), or # (4) kubernetes host names return False - match = re.match(r'^([^\[\]:]+)(:\d+)?$', host) - if match: - match = re.match(r'((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(\.|$)){4}', - match.groups()[0]) - if match: - return False - - match = re.match(r'^\[(.+)\](:\d+)?$', host) + match = re.match(r"^([^\[\]:]+)(:\d+)?$", host) if match: match = re.match( - r'^(((?=.*(::))(?!.*\3.+\3))\3?|[\dA-F]{1,4}:)([\dA-F]{1,4}(\3|:\b)|\2){5}(([\dA-F]{1,4}(\3|:\b|$)|\2){2}|(((2[0-4]|1\d|[1-9])?\d|25[0-5])\.?\b){4})\Z', - match.groups()[0], re.IGNORECASE) + r"((25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)(\.|$)){4}", match.groups()[0] + ) if match: return False - path_based = (host == 's3.amazonaws.com' or re.match( - r"s3[\.\-]([^.]*)\.amazonaws\.com", host)) + match = re.match(r"^\[(.+)\](:\d+)?$", host) + if match: + match = re.match( + r"^(((?=.*(::))(?!.*\3.+\3))\3?|[\dA-F]{1,4}:)([\dA-F]{1,4}(\3|:\b)|\2){5}(([\dA-F]{1,4}(\3|:\b|$)|\2){2}|(((2[0-4]|1\d|[1-9])?\d|25[0-5])\.?\b){4})\Z", + match.groups()[0], + re.IGNORECASE, + ) + if match: + return False + + path_based = host == "s3.amazonaws.com" or re.match( + r"s3[\.\-]([^.]*)\.amazonaws\.com", host + ) return not path_based def is_delete_keys(self, request, path, bucket_name): @@ -189,8 +223,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self.method = request.method self.path = self._get_path(request) self.headers = request.headers - if 'host' not in self.headers: - self.headers['host'] = urlparse(full_url).netloc + if "host" not in self.headers: + self.headers["host"] = urlparse(full_url).netloc try: response = self._bucket_response(request, full_url, headers) except S3ClientError as s3error: @@ -221,31 +255,36 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self.data["BucketName"] = bucket_name - if hasattr(request, 'body'): + if hasattr(request, "body"): # Boto body = request.body else: # Flask server body = request.data if body is None: - body = b'' + body = b"" if isinstance(body, six.binary_type): - body = body.decode('utf-8') - body = u'{0}'.format(body).encode('utf-8') + body = body.decode("utf-8") + body = "{0}".format(body).encode("utf-8") - if method == 'HEAD': + if method == "HEAD": return self._bucket_response_head(bucket_name) - elif method == 'GET': + elif method == "GET": return self._bucket_response_get(bucket_name, querystring) - elif method == 'PUT': - return self._bucket_response_put(request, body, region_name, bucket_name, querystring) - elif method == 'DELETE': + elif method == "PUT": + return self._bucket_response_put( + request, body, region_name, bucket_name, querystring + ) + elif method == "DELETE": return self._bucket_response_delete(body, bucket_name, querystring) - elif method == 'POST': + elif method == "POST": return self._bucket_response_post(request, body, bucket_name) else: raise NotImplementedError( - "Method {0} has not been impelemented in the S3 backend yet".format(method)) + "Method {0} has not been impelemented in the S3 backend yet".format( + method + ) + ) @staticmethod def _get_querystring(full_url): @@ -268,22 +307,25 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self._set_action("BUCKET", "GET", querystring) self._authenticate_and_authorize_s3_action() - if 'uploads' in querystring: - for unsup in ('delimiter', 'max-uploads'): + if "uploads" in querystring: + for unsup in ("delimiter", "max-uploads"): if unsup in querystring: raise NotImplementedError( - "Listing multipart uploads with {} has not been implemented yet.".format(unsup)) - multiparts = list( - self.backend.get_all_multiparts(bucket_name).values()) - if 'prefix' in querystring: - prefix = querystring.get('prefix', [None])[0] + "Listing multipart uploads with {} has not been implemented yet.".format( + unsup + ) + ) + multiparts = list(self.backend.get_all_multiparts(bucket_name).values()) + if "prefix" in querystring: + prefix = querystring.get("prefix", [None])[0] multiparts = [ - upload for upload in multiparts if upload.key_name.startswith(prefix)] + upload + for upload in multiparts + if upload.key_name.startswith(prefix) + ] template = self.response_template(S3_ALL_MULTIPARTS) - return template.render( - bucket_name=bucket_name, - uploads=multiparts) - elif 'location' in querystring: + return template.render(bucket_name=bucket_name, uploads=multiparts) + elif "location" in querystring: bucket = self.backend.get_bucket(bucket_name) template = self.response_template(S3_BUCKET_LOCATION) @@ -293,36 +335,36 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): location = None return template.render(location=location) - elif 'lifecycle' in querystring: + elif "lifecycle" in querystring: bucket = self.backend.get_bucket(bucket_name) if not bucket.rules: template = self.response_template(S3_NO_LIFECYCLE) return 404, {}, template.render(bucket_name=bucket_name) - template = self.response_template( - S3_BUCKET_LIFECYCLE_CONFIGURATION) + template = self.response_template(S3_BUCKET_LIFECYCLE_CONFIGURATION) return template.render(rules=bucket.rules) - elif 'versioning' in querystring: + elif "versioning" in querystring: versioning = self.backend.get_bucket_versioning(bucket_name) template = self.response_template(S3_BUCKET_GET_VERSIONING) return template.render(status=versioning) - elif 'policy' in querystring: + elif "policy" in querystring: policy = self.backend.get_bucket_policy(bucket_name) if not policy: template = self.response_template(S3_NO_POLICY) return 404, {}, template.render(bucket_name=bucket_name) return 200, {}, policy - elif 'website' in querystring: + elif "website" in querystring: website_configuration = self.backend.get_bucket_website_configuration( - bucket_name) + bucket_name + ) if not website_configuration: template = self.response_template(S3_NO_BUCKET_WEBSITE_CONFIG) return 404, {}, template.render(bucket_name=bucket_name) return 200, {}, website_configuration - elif 'acl' in querystring: + elif "acl" in querystring: bucket = self.backend.get_bucket(bucket_name) template = self.response_template(S3_OBJECT_ACL_RESPONSE) return template.render(obj=bucket) - elif 'tagging' in querystring: + elif "tagging" in querystring: bucket = self.backend.get_bucket(bucket_name) # "Special Error" if no tags: if len(bucket.tagging.tag_set.tags) == 0: @@ -330,7 +372,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return 404, {}, template.render(bucket_name=bucket_name) template = self.response_template(S3_BUCKET_TAGGING_RESPONSE) return template.render(bucket=bucket) - elif 'logging' in querystring: + elif "logging" in querystring: bucket = self.backend.get_bucket(bucket_name) if not bucket.logging: template = self.response_template(S3_NO_LOGGING_CONFIG) @@ -358,13 +400,13 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): template = self.response_template(S3_BUCKET_ACCELERATE) return template.render(bucket=bucket) - elif 'versions' in querystring: - delimiter = querystring.get('delimiter', [None])[0] - encoding_type = querystring.get('encoding-type', [None])[0] - key_marker = querystring.get('key-marker', [None])[0] - max_keys = querystring.get('max-keys', [None])[0] - prefix = querystring.get('prefix', [''])[0] - version_id_marker = querystring.get('version-id-marker', [None])[0] + elif "versions" in querystring: + delimiter = querystring.get("delimiter", [None])[0] + encoding_type = querystring.get("encoding-type", [None])[0] + key_marker = querystring.get("key-marker", [None])[0] + max_keys = querystring.get("max-keys", [None])[0] + prefix = querystring.get("prefix", [""])[0] + version_id_marker = querystring.get("version-id-marker", [None])[0] bucket = self.backend.get_bucket(bucket_name) versions = self.backend.get_bucket_versions( @@ -374,7 +416,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): key_marker=key_marker, max_keys=max_keys, version_id_marker=version_id_marker, - prefix=prefix + prefix=prefix, ) latest_versions = self.backend.get_bucket_latest_versions( bucket_name=bucket_name @@ -387,49 +429,62 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: delete_marker_list.append(version) template = self.response_template(S3_BUCKET_GET_VERSIONS) - return 200, {}, template.render( - key_list=key_list, - delete_marker_list=delete_marker_list, - latest_versions=latest_versions, - bucket=bucket, - prefix='', - max_keys=1000, - delimiter='', - is_truncated='false', + return ( + 200, + {}, + template.render( + key_list=key_list, + delete_marker_list=delete_marker_list, + latest_versions=latest_versions, + bucket=bucket, + prefix="", + max_keys=1000, + delimiter="", + is_truncated="false", + ), ) - elif querystring.get('list-type', [None])[0] == '2': + elif querystring.get("list-type", [None])[0] == "2": return 200, {}, self._handle_list_objects_v2(bucket_name, querystring) bucket = self.backend.get_bucket(bucket_name) - prefix = querystring.get('prefix', [None])[0] + prefix = querystring.get("prefix", [None])[0] if prefix and isinstance(prefix, six.binary_type): prefix = prefix.decode("utf-8") - delimiter = querystring.get('delimiter', [None])[0] - max_keys = int(querystring.get('max-keys', [1000])[0]) - marker = querystring.get('marker', [None])[0] + delimiter = querystring.get("delimiter", [None])[0] + max_keys = int(querystring.get("max-keys", [1000])[0]) + marker = querystring.get("marker", [None])[0] result_keys, result_folders = self.backend.prefix_query( - bucket, prefix, delimiter) + bucket, prefix, delimiter + ) if marker: result_keys = self._get_results_from_token(result_keys, marker) - result_keys, is_truncated, next_marker = self._truncate_result(result_keys, max_keys) + result_keys, is_truncated, next_marker = self._truncate_result( + result_keys, max_keys + ) template = self.response_template(S3_BUCKET_GET_RESPONSE) - return 200, {}, template.render( - bucket=bucket, - prefix=prefix, - delimiter=delimiter, - result_keys=result_keys, - result_folders=result_folders, - is_truncated=is_truncated, - next_marker=next_marker, - max_keys=max_keys + return ( + 200, + {}, + template.render( + bucket=bucket, + prefix=prefix, + delimiter=delimiter, + result_keys=result_keys, + result_folders=result_folders, + is_truncated=is_truncated, + next_marker=next_marker, + max_keys=max_keys, + ), ) def _set_action(self, action_resource_type, method, querystring): action_set = False - for action_in_querystring, action in ACTION_MAP[action_resource_type][method].items(): + for action_in_querystring, action in ACTION_MAP[action_resource_type][ + method + ].items(): if action_in_querystring in querystring: self.data["Action"] = action action_set = True @@ -440,17 +495,18 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): template = self.response_template(S3_BUCKET_GET_RESPONSE_V2) bucket = self.backend.get_bucket(bucket_name) - prefix = querystring.get('prefix', [None])[0] + prefix = querystring.get("prefix", [None])[0] if prefix and isinstance(prefix, six.binary_type): prefix = prefix.decode("utf-8") - delimiter = querystring.get('delimiter', [None])[0] + delimiter = querystring.get("delimiter", [None])[0] result_keys, result_folders = self.backend.prefix_query( - bucket, prefix, delimiter) + bucket, prefix, delimiter + ) - fetch_owner = querystring.get('fetch-owner', [False])[0] - max_keys = int(querystring.get('max-keys', [1000])[0]) - continuation_token = querystring.get('continuation-token', [None])[0] - start_after = querystring.get('start-after', [None])[0] + fetch_owner = querystring.get("fetch-owner", [False])[0] + max_keys = int(querystring.get("max-keys", [1000])[0]) + continuation_token = querystring.get("continuation-token", [None])[0] + start_after = querystring.get("start-after", [None])[0] # sort the combination of folders and keys into lexicographical order all_keys = result_keys + result_folders @@ -460,14 +516,16 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): limit = continuation_token or start_after all_keys = self._get_results_from_token(all_keys, limit) - truncated_keys, is_truncated, next_continuation_token = self._truncate_result(all_keys, max_keys) + truncated_keys, is_truncated, next_continuation_token = self._truncate_result( + all_keys, max_keys + ) result_keys, result_folders = self._split_truncated_keys(truncated_keys) key_count = len(result_keys) + len(result_folders) return template.render( bucket=bucket, - prefix=prefix or '', + prefix=prefix or "", delimiter=delimiter, key_count=key_count, result_keys=result_keys, @@ -476,7 +534,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): max_keys=max_keys, is_truncated=is_truncated, next_continuation_token=next_continuation_token, - start_after=None if continuation_token else start_after + start_after=None if continuation_token else start_after, ) @staticmethod @@ -507,41 +565,43 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): def _truncate_result(self, result_keys, max_keys): if len(result_keys) > max_keys: - is_truncated = 'true' + is_truncated = "true" result_keys = result_keys[:max_keys] item = result_keys[-1] - next_continuation_token = (item.name if isinstance(item, FakeKey) else item) + next_continuation_token = item.name if isinstance(item, FakeKey) else item else: - is_truncated = 'false' + is_truncated = "false" next_continuation_token = None return result_keys, is_truncated, next_continuation_token - def _bucket_response_put(self, request, body, region_name, bucket_name, querystring): - if not request.headers.get('Content-Length'): + def _bucket_response_put( + self, request, body, region_name, bucket_name, querystring + ): + if not request.headers.get("Content-Length"): return 411, {}, "Content-Length required" self._set_action("BUCKET", "PUT", querystring) self._authenticate_and_authorize_s3_action() - if 'versioning' in querystring: - ver = re.search('([A-Za-z]+)', body.decode()) + if "versioning" in querystring: + ver = re.search("([A-Za-z]+)", body.decode()) if ver: self.backend.set_bucket_versioning(bucket_name, ver.group(1)) template = self.response_template(S3_BUCKET_VERSIONING) return template.render(bucket_versioning_status=ver.group(1)) else: return 404, {}, "" - elif 'lifecycle' in querystring: - rules = xmltodict.parse(body)['LifecycleConfiguration']['Rule'] + elif "lifecycle" in querystring: + rules = xmltodict.parse(body)["LifecycleConfiguration"]["Rule"] if not isinstance(rules, list): # If there is only one rule, xmldict returns just the item rules = [rules] self.backend.set_bucket_lifecycle(bucket_name, rules) return "" - elif 'policy' in querystring: + elif "policy" in querystring: self.backend.set_bucket_policy(bucket_name, body) - return 'True' - elif 'acl' in querystring: + return "True" + elif "acl" in querystring: # Headers are first. If not set, then look at the body (consistent with the documentation): acls = self._acl_from_headers(request.headers) if not acls: @@ -552,7 +612,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): tagging = self._bucket_tagging_from_xml(body) self.backend.put_bucket_tagging(bucket_name, tagging) return "" - elif 'website' in querystring: + elif "website" in querystring: self.backend.set_bucket_website_configuration(bucket_name, body) return "" elif "cors" in querystring: @@ -563,14 +623,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): raise MalformedXML() elif "logging" in querystring: try: - self.backend.put_bucket_logging(bucket_name, self._logging_from_xml(body)) + self.backend.put_bucket_logging( + bucket_name, self._logging_from_xml(body) + ) return "" except KeyError: raise MalformedXML() elif "notification" in querystring: try: - self.backend.put_bucket_notification_configuration(bucket_name, - self._notification_config_from_xml(body)) + self.backend.put_bucket_notification_configuration( + bucket_name, self._notification_config_from_xml(body) + ) return "" except KeyError: raise MalformedXML() @@ -579,7 +642,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): elif "accelerate" in querystring: try: accelerate_status = self._accelerate_config_from_xml(body) - self.backend.put_bucket_accelerate_configuration(bucket_name, accelerate_status) + self.backend.put_bucket_accelerate_configuration( + bucket_name, accelerate_status + ) return "" except KeyError: raise MalformedXML() @@ -592,12 +657,14 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): # - you should not use it as a location constraint --> it fails # - querying the location constraint returns None try: - forced_region = xmltodict.parse(body)['CreateBucketConfiguration']['LocationConstraint'] + forced_region = xmltodict.parse(body)["CreateBucketConfiguration"][ + "LocationConstraint" + ] if forced_region == DEFAULT_REGION_NAME: raise S3ClientError( - 'InvalidLocationConstraint', - 'The specified location-constraint is not valid' + "InvalidLocationConstraint", + "The specified location-constraint is not valid", ) else: region_name = forced_region @@ -605,8 +672,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): pass try: - new_bucket = self.backend.create_bucket( - bucket_name, region_name) + new_bucket = self.backend.create_bucket(bucket_name, region_name) except BucketAlreadyExists: if region_name == DEFAULT_REGION_NAME: # us-east-1 has different behavior @@ -614,9 +680,11 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: raise - if 'x-amz-acl' in request.headers: + if "x-amz-acl" in request.headers: # TODO: Support the XML-based ACL format - self.backend.set_bucket_acl(bucket_name, self._acl_from_headers(request.headers)) + self.backend.set_bucket_acl( + bucket_name, self._acl_from_headers(request.headers) + ) template = self.response_template(S3_BUCKET_CREATE_RESPONSE) return 200, {}, template.render(bucket=new_bucket) @@ -625,7 +693,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self._set_action("BUCKET", "DELETE", querystring) self._authenticate_and_authorize_s3_action() - if 'policy' in querystring: + if "policy" in querystring: self.backend.delete_bucket_policy(bucket_name, body) return 204, {}, "" elif "tagging" in querystring: @@ -634,7 +702,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): elif "cors" in querystring: self.backend.delete_bucket_cors(bucket_name) return 204, {}, "" - elif 'lifecycle' in querystring: + elif "lifecycle" in querystring: bucket = self.backend.get_bucket(bucket_name) bucket.delete_lifecycle() return 204, {}, "" @@ -647,12 +715,11 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return 204, {}, template.render(bucket=removed_bucket) else: # Tried to delete a bucket that still has keys - template = self.response_template( - S3_DELETE_BUCKET_WITH_ITEMS_ERROR) + template = self.response_template(S3_DELETE_BUCKET_WITH_ITEMS_ERROR) return 409, {}, template.render(bucket=removed_bucket) def _bucket_response_post(self, request, body, bucket_name): - if not request.headers.get('Content-Length'): + if not request.headers.get("Content-Length"): return 411, {}, "Content-Length required" path = self._get_path(request) @@ -667,7 +734,7 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self._authenticate_and_authorize_s3_action() # POST to bucket-url should create file from form - if hasattr(request, 'form'): + if hasattr(request, "form"): # Not HTTPretty form = request.form else: @@ -675,15 +742,15 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): body = body.decode() form = {} - for kv in body.split('&'): - k, v = kv.split('=') + for kv in body.split("&"): + k, v = kv.split("=") form[k] = v - key = form['key'] - if 'file' in form: - f = form['file'] + key = form["key"] + if "file" in form: + f = form["file"] else: - f = request.files['file'].stream.read() + f = request.files["file"].stream.read() new_key = self.backend.set_key(bucket_name, key, f) @@ -698,13 +765,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if isinstance(request, HTTPrettyRequest): path = request.path else: - path = request.full_path if hasattr(request, 'full_path') else path_url(request.url) + path = ( + request.full_path + if hasattr(request, "full_path") + else path_url(request.url) + ) return path def _bucket_response_delete_keys(self, request, body, bucket_name): template = self.response_template(S3_DELETE_KEYS_RESPONSE) - keys = minidom.parseString(body).getElementsByTagName('Key') + keys = minidom.parseString(body).getElementsByTagName("Key") deleted_names = [] error_names = [] if len(keys) == 0: @@ -712,27 +783,32 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): for k in keys: key_name = k.firstChild.nodeValue - success = self.backend.delete_key(bucket_name, undo_clean_key_name(key_name)) + success = self.backend.delete_key( + bucket_name, undo_clean_key_name(key_name) + ) if success: deleted_names.append(key_name) else: error_names.append(key_name) - return 200, {}, template.render(deleted=deleted_names, delete_errors=error_names) + return ( + 200, + {}, + template.render(deleted=deleted_names, delete_errors=error_names), + ) def _handle_range_header(self, request, headers, response_content): response_headers = {} length = len(response_content) last = length - 1 - _, rspec = request.headers.get('range').split('=') - if ',' in rspec: - raise NotImplementedError( - "Multiple range specifiers not supported") + _, rspec = request.headers.get("range").split("=") + if "," in rspec: + raise NotImplementedError("Multiple range specifiers not supported") def toint(i): return int(i) if i else None - begin, end = map(toint, rspec.split('-')) + begin, end = map(toint, rspec.split("-")) if begin is not None: # byte range end = last if end is None else min(end, last) elif end is not None: # suffix byte range @@ -742,16 +818,17 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return 400, response_headers, "" if begin < 0 or end > last or begin > min(end, last): return 416, response_headers, "" - response_headers['content-range'] = "bytes {0}-{1}/{2}".format( - begin, end, length) - return 206, response_headers, response_content[begin:end + 1] + response_headers["content-range"] = "bytes {0}-{1}/{2}".format( + begin, end, length + ) + return 206, response_headers, response_content[begin : end + 1] def key_response(self, request, full_url, headers): self.method = request.method self.path = self._get_path(request) self.headers = request.headers - if 'host' not in self.headers: - self.headers['host'] = urlparse(full_url).netloc + if "host" not in self.headers: + self.headers["host"] = urlparse(full_url).netloc response_headers = {} try: response = self._key_response(request, full_url, headers) @@ -764,8 +841,10 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): else: status_code, response_headers, response_content = response - if status_code == 200 and 'range' in request.headers: - return self._handle_range_header(request, response_headers, response_content) + if status_code == 200 and "range" in request.headers: + return self._handle_range_header( + request, response_headers, response_content + ) return status_code, response_headers, response_content def _key_response(self, request, full_url, headers): @@ -782,72 +861,84 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): # Here we deny public access to private files by checking the # ACL and checking for the mere presence of an Authorization # header. - if 'Authorization' not in request.headers: - if hasattr(request, 'url'): - signed_url = 'Signature=' in request.url - elif hasattr(request, 'requestline'): - signed_url = 'Signature=' in request.path + if "Authorization" not in request.headers: + if hasattr(request, "url"): + signed_url = "Signature=" in request.url + elif hasattr(request, "requestline"): + signed_url = "Signature=" in request.path key = self.backend.get_key(bucket_name, key_name) if key: if not key.acl.public_read and not signed_url: return 403, {}, "" - if hasattr(request, 'body'): + if hasattr(request, "body"): # Boto body = request.body - if hasattr(body, 'read'): + if hasattr(body, "read"): body = body.read() else: # Flask server body = request.data if body is None: - body = b'' + body = b"" - if method == 'GET': - return self._key_response_get(bucket_name, query, key_name, headers=request.headers) - elif method == 'PUT': - return self._key_response_put(request, body, bucket_name, query, key_name, headers) - elif method == 'HEAD': - return self._key_response_head(bucket_name, query, key_name, headers=request.headers) - elif method == 'DELETE': + if method == "GET": + return self._key_response_get( + bucket_name, query, key_name, headers=request.headers + ) + elif method == "PUT": + return self._key_response_put( + request, body, bucket_name, query, key_name, headers + ) + elif method == "HEAD": + return self._key_response_head( + bucket_name, query, key_name, headers=request.headers + ) + elif method == "DELETE": return self._key_response_delete(bucket_name, query, key_name) - elif method == 'POST': + elif method == "POST": return self._key_response_post(request, body, bucket_name, query, key_name) else: raise NotImplementedError( - "Method {0} has not been implemented in the S3 backend yet".format(method)) + "Method {0} has not been implemented in the S3 backend yet".format( + method + ) + ) def _key_response_get(self, bucket_name, query, key_name, headers): self._set_action("KEY", "GET", query) self._authenticate_and_authorize_s3_action() response_headers = {} - if query.get('uploadId'): - upload_id = query['uploadId'][0] + if query.get("uploadId"): + upload_id = query["uploadId"][0] parts = self.backend.list_multipart(bucket_name, upload_id) template = self.response_template(S3_MULTIPART_LIST_RESPONSE) - return 200, response_headers, template.render( - bucket_name=bucket_name, - key_name=key_name, - upload_id=upload_id, - count=len(parts), - parts=parts + return ( + 200, + response_headers, + template.render( + bucket_name=bucket_name, + key_name=key_name, + upload_id=upload_id, + count=len(parts), + parts=parts, + ), ) - version_id = query.get('versionId', [None])[0] - if_modified_since = headers.get('If-Modified-Since', None) - key = self.backend.get_key( - bucket_name, key_name, version_id=version_id) + version_id = query.get("versionId", [None])[0] + if_modified_since = headers.get("If-Modified-Since", None) + key = self.backend.get_key(bucket_name, key_name, version_id=version_id) if key is None: raise MissingKey(key_name) if if_modified_since: if_modified_since = str_to_rfc_1123_datetime(if_modified_since) if if_modified_since and key.last_modified < if_modified_since: - return 304, response_headers, 'Not Modified' - if 'acl' in query: + return 304, response_headers, "Not Modified" + if "acl" in query: template = self.response_template(S3_OBJECT_ACL_RESPONSE) return 200, response_headers, template.render(obj=key) - if 'tagging' in query: + if "tagging" in query: template = self.response_template(S3_OBJECT_TAGGING_RESPONSE) return 200, response_headers, template.render(obj=key) @@ -860,16 +951,21 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): self._authenticate_and_authorize_s3_action() response_headers = {} - if query.get('uploadId') and query.get('partNumber'): - upload_id = query['uploadId'][0] - part_number = int(query['partNumber'][0]) - if 'x-amz-copy-source' in request.headers: + if query.get("uploadId") and query.get("partNumber"): + upload_id = query["uploadId"][0] + part_number = int(query["partNumber"][0]) + if "x-amz-copy-source" in request.headers: src = unquote(request.headers.get("x-amz-copy-source")).lstrip("/") src_bucket, src_key = src.split("/", 1) - src_key, src_version_id = src_key.split("?versionId=") if "?versionId=" in src_key else (src_key, None) - src_range = request.headers.get( - 'x-amz-copy-source-range', '').split("bytes=")[-1] + src_key, src_version_id = ( + src_key.split("?versionId=") + if "?versionId=" in src_key + else (src_key, None) + ) + src_range = request.headers.get("x-amz-copy-source-range", "").split( + "bytes=" + )[-1] try: start_byte, end_byte = src_range.split("-") @@ -879,74 +975,87 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if self.backend.get_key(src_bucket, src_key, version_id=src_version_id): key = self.backend.copy_part( - bucket_name, upload_id, part_number, src_bucket, - src_key, src_version_id, start_byte, end_byte) + bucket_name, + upload_id, + part_number, + src_bucket, + src_key, + src_version_id, + start_byte, + end_byte, + ) else: return 404, response_headers, "" template = self.response_template(S3_MULTIPART_UPLOAD_RESPONSE) response = template.render(part=key) else: - key = self.backend.set_part( - bucket_name, upload_id, part_number, body) + key = self.backend.set_part(bucket_name, upload_id, part_number, body) response = "" response_headers.update(key.response_dict) return 200, response_headers, response - storage_class = request.headers.get('x-amz-storage-class', 'STANDARD') + storage_class = request.headers.get("x-amz-storage-class", "STANDARD") acl = self._acl_from_headers(request.headers) if acl is None: acl = self.backend.get_bucket(bucket_name).acl tagging = self._tagging_from_headers(request.headers) - if 'acl' in query: + if "acl" in query: key = self.backend.get_key(bucket_name, key_name) # TODO: Support the XML-based ACL format key.set_acl(acl) return 200, response_headers, "" - if 'tagging' in query: - if 'versionId' in query: - version_id = query['versionId'][0] + if "tagging" in query: + if "versionId" in query: + version_id = query["versionId"][0] else: version_id = None tagging = self._tagging_from_xml(body) self.backend.set_key_tagging(bucket_name, key_name, tagging, version_id) return 200, response_headers, "" - if 'x-amz-copy-source' in request.headers: + if "x-amz-copy-source" in request.headers: # Copy key # you can have a quoted ?version=abc with a version Id, so work on # we need to parse the unquoted string first src_key = request.headers.get("x-amz-copy-source") if isinstance(src_key, six.binary_type): - src_key = src_key.decode('utf-8') + src_key = src_key.decode("utf-8") src_key_parsed = urlparse(src_key) - src_bucket, src_key = clean_key_name(src_key_parsed.path).\ - lstrip("/").split("/", 1) - src_version_id = parse_qs(src_key_parsed.query).get( - 'versionId', [None])[0] + src_bucket, src_key = ( + clean_key_name(src_key_parsed.path).lstrip("/").split("/", 1) + ) + src_version_id = parse_qs(src_key_parsed.query).get("versionId", [None])[0] key = self.backend.get_key(src_bucket, src_key, version_id=src_version_id) if key is not None: if key.storage_class in ["GLACIER", "DEEP_ARCHIVE"]: raise ObjectNotInActiveTierError(key) - self.backend.copy_key(src_bucket, src_key, bucket_name, key_name, - storage=storage_class, acl=acl, src_version_id=src_version_id) + self.backend.copy_key( + src_bucket, + src_key, + bucket_name, + key_name, + storage=storage_class, + acl=acl, + src_version_id=src_version_id, + ) else: return 404, response_headers, "" new_key = self.backend.get_key(bucket_name, key_name) - mdirective = request.headers.get('x-amz-metadata-directive') - if mdirective is not None and mdirective == 'REPLACE': + mdirective = request.headers.get("x-amz-metadata-directive") + if mdirective is not None and mdirective == "REPLACE": metadata = metadata_from_headers(request.headers) new_key.set_metadata(metadata, replace=True) template = self.response_template(S3_OBJECT_COPY_RESPONSE) response_headers.update(new_key.response_dict) return 200, response_headers, template.render(key=new_key) - streaming_request = hasattr(request, 'streaming') and request.streaming - closing_connection = headers.get('connection') == 'close' + streaming_request = hasattr(request, "streaming") and request.streaming + closing_connection = headers.get("connection") == "close" if closing_connection and streaming_request: # Closing the connection of a streaming request. No more data new_key = self.backend.get_key(bucket_name, key_name) @@ -955,13 +1064,16 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): new_key = self.backend.append_to_key(bucket_name, key_name, body) else: # Initial data - new_key = self.backend.set_key(bucket_name, key_name, body, - storage=storage_class) + new_key = self.backend.set_key( + bucket_name, key_name, body, storage=storage_class + ) request.streaming = True metadata = metadata_from_headers(request.headers) new_key.set_metadata(metadata) new_key.set_acl(acl) - new_key.website_redirect_location = request.headers.get('x-amz-website-redirect-location') + new_key.website_redirect_location = request.headers.get( + "x-amz-website-redirect-location" + ) new_key.set_tagging(tagging) template = self.response_template(S3_OBJECT_RESPONSE) @@ -970,27 +1082,24 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): def _key_response_head(self, bucket_name, query, key_name, headers): response_headers = {} - version_id = query.get('versionId', [None])[0] - part_number = query.get('partNumber', [None])[0] + version_id = query.get("versionId", [None])[0] + part_number = query.get("partNumber", [None])[0] if part_number: part_number = int(part_number) - if_modified_since = headers.get('If-Modified-Since', None) + if_modified_since = headers.get("If-Modified-Since", None) if if_modified_since: if_modified_since = str_to_rfc_1123_datetime(if_modified_since) key = self.backend.get_key( - bucket_name, - key_name, - version_id=version_id, - part_number=part_number + bucket_name, key_name, version_id=version_id, part_number=part_number ) if key: response_headers.update(key.metadata) response_headers.update(key.response_dict) if if_modified_since and key.last_modified < if_modified_since: - return 304, response_headers, 'Not Modified' + return 304, response_headers, "Not Modified" else: return 200, response_headers, "" else: @@ -1013,20 +1122,20 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if not parsed_xml["AccessControlPolicy"]["AccessControlList"].get("Grant"): raise MalformedACLError() - permissions = [ - "READ", - "WRITE", - "READ_ACP", - "WRITE_ACP", - "FULL_CONTROL" - ] + permissions = ["READ", "WRITE", "READ_ACP", "WRITE_ACP", "FULL_CONTROL"] - if not isinstance(parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"], list): - parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"] = \ - [parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"]] + if not isinstance( + parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"], list + ): + parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"] = [ + parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"] + ] - grants = self._get_grants_from_xml(parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"], - MalformedACLError, permissions) + grants = self._get_grants_from_xml( + parsed_xml["AccessControlPolicy"]["AccessControlList"]["Grant"], + MalformedACLError, + permissions, + ) return FakeAcl(grants) def _get_grants_from_xml(self, grant_list, exception_type, permissions): @@ -1035,42 +1144,54 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): if grant.get("Permission", "") not in permissions: raise exception_type() - if grant["Grantee"].get("@xsi:type", "") not in ["CanonicalUser", "AmazonCustomerByEmail", "Group"]: + if grant["Grantee"].get("@xsi:type", "") not in [ + "CanonicalUser", + "AmazonCustomerByEmail", + "Group", + ]: raise exception_type() # TODO: Verify that the proper grantee data is supplied based on the type. - grants.append(FakeGrant( - [FakeGrantee(id=grant["Grantee"].get("ID", ""), display_name=grant["Grantee"].get("DisplayName", ""), - uri=grant["Grantee"].get("URI", ""))], - [grant["Permission"]]) + grants.append( + FakeGrant( + [ + FakeGrantee( + id=grant["Grantee"].get("ID", ""), + display_name=grant["Grantee"].get("DisplayName", ""), + uri=grant["Grantee"].get("URI", ""), + ) + ], + [grant["Permission"]], + ) ) return grants def _acl_from_headers(self, headers): - canned_acl = headers.get('x-amz-acl', '') + canned_acl = headers.get("x-amz-acl", "") if canned_acl: return get_canned_acl(canned_acl) grants = [] for header, value in headers.items(): - if not header.startswith('x-amz-grant-'): + if not header.startswith("x-amz-grant-"): continue permission = { - 'read': 'READ', - 'write': 'WRITE', - 'read-acp': 'READ_ACP', - 'write-acp': 'WRITE_ACP', - 'full-control': 'FULL_CONTROL', - }[header[len('x-amz-grant-'):]] + "read": "READ", + "write": "WRITE", + "read-acp": "READ_ACP", + "write-acp": "WRITE_ACP", + "full-control": "FULL_CONTROL", + }[header[len("x-amz-grant-") :]] grantees = [] for key_and_value in value.split(","): key, value = re.match( - '([^=]+)="([^"]+)"', key_and_value.strip()).groups() - if key.lower() == 'id': + '([^=]+)="([^"]+)"', key_and_value.strip() + ).groups() + if key.lower() == "id": grantees.append(FakeGrantee(id=value)) else: grantees.append(FakeGrantee(uri=value)) @@ -1082,8 +1203,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return None def _tagging_from_headers(self, headers): - if headers.get('x-amz-tagging'): - parsed_header = parse_qs(headers['x-amz-tagging'], keep_blank_values=True) + if headers.get("x-amz-tagging"): + parsed_header = parse_qs(headers["x-amz-tagging"], keep_blank_values=True) tags = [] for tag in parsed_header.items(): tags.append(FakeTag(tag[0], tag[1][0])) @@ -1095,11 +1216,11 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return FakeTagging() def _tagging_from_xml(self, xml): - parsed_xml = xmltodict.parse(xml, force_list={'Tag': True}) + parsed_xml = xmltodict.parse(xml, force_list={"Tag": True}) tags = [] - for tag in parsed_xml['Tagging']['TagSet']['Tag']: - tags.append(FakeTag(tag['Key'], tag['Value'])) + for tag in parsed_xml["Tagging"]["TagSet"]["Tag"]: + tags.append(FakeTag(tag["Key"], tag["Value"])) tag_set = FakeTagSet(tags) tagging = FakeTagging(tag_set) @@ -1110,14 +1231,18 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): tags = [] # Optional if no tags are being sent: - if parsed_xml['Tagging'].get('TagSet'): + if parsed_xml["Tagging"].get("TagSet"): # If there is only 1 tag, then it's not a list: - if not isinstance(parsed_xml['Tagging']['TagSet']['Tag'], list): - tags.append(FakeTag(parsed_xml['Tagging']['TagSet']['Tag']['Key'], - parsed_xml['Tagging']['TagSet']['Tag']['Value'])) + if not isinstance(parsed_xml["Tagging"]["TagSet"]["Tag"], list): + tags.append( + FakeTag( + parsed_xml["Tagging"]["TagSet"]["Tag"]["Key"], + parsed_xml["Tagging"]["TagSet"]["Tag"]["Value"], + ) + ) else: - for tag in parsed_xml['Tagging']['TagSet']['Tag']: - tags.append(FakeTag(tag['Key'], tag['Value'])) + for tag in parsed_xml["Tagging"]["TagSet"]["Tag"]: + tags.append(FakeTag(tag["Key"], tag["Value"])) tag_set = FakeTagSet(tags) tagging = FakeTagging(tag_set) @@ -1145,25 +1270,34 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): # Get the ACLs: if parsed_xml["BucketLoggingStatus"]["LoggingEnabled"].get("TargetGrants"): - permissions = [ - "READ", - "WRITE", - "FULL_CONTROL" - ] - if not isinstance(parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"]["Grant"], list): + permissions = ["READ", "WRITE", "FULL_CONTROL"] + if not isinstance( + parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"][ + "Grant" + ], + list, + ): target_grants = self._get_grants_from_xml( - [parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"]["Grant"]], + [ + parsed_xml["BucketLoggingStatus"]["LoggingEnabled"][ + "TargetGrants" + ]["Grant"] + ], MalformedXML, - permissions + permissions, ) else: target_grants = self._get_grants_from_xml( - parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"]["Grant"], + parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"][ + "Grant" + ], MalformedXML, - permissions + permissions, ) - parsed_xml["BucketLoggingStatus"]["LoggingEnabled"]["TargetGrants"] = target_grants + parsed_xml["BucketLoggingStatus"]["LoggingEnabled"][ + "TargetGrants" + ] = target_grants return parsed_xml["BucketLoggingStatus"]["LoggingEnabled"] @@ -1178,31 +1312,36 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): notification_fields = [ ("Topic", "sns"), ("Queue", "sqs"), - ("CloudFunction", "lambda") + ("CloudFunction", "lambda"), ] event_names = [ - 's3:ReducedRedundancyLostObject', - 's3:ObjectCreated:*', - 's3:ObjectCreated:Put', - 's3:ObjectCreated:Post', - 's3:ObjectCreated:Copy', - 's3:ObjectCreated:CompleteMultipartUpload', - 's3:ObjectRemoved:*', - 's3:ObjectRemoved:Delete', - 's3:ObjectRemoved:DeleteMarkerCreated' + "s3:ReducedRedundancyLostObject", + "s3:ObjectCreated:*", + "s3:ObjectCreated:Put", + "s3:ObjectCreated:Post", + "s3:ObjectCreated:Copy", + "s3:ObjectCreated:CompleteMultipartUpload", + "s3:ObjectRemoved:*", + "s3:ObjectRemoved:Delete", + "s3:ObjectRemoved:DeleteMarkerCreated", ] - found_notifications = 0 # Tripwire -- if this is not ever set, then there were no notifications + found_notifications = ( + 0 # Tripwire -- if this is not ever set, then there were no notifications + ) for name, arn_string in notification_fields: # 1st verify that the proper notification configuration has been passed in (with an ARN that is close # to being correct -- nothing too complex in the ARN logic): - the_notification = parsed_xml["NotificationConfiguration"].get("{}Configuration".format(name)) + the_notification = parsed_xml["NotificationConfiguration"].get( + "{}Configuration".format(name) + ) if the_notification: found_notifications += 1 if not isinstance(the_notification, list): - the_notification = parsed_xml["NotificationConfiguration"]["{}Configuration".format(name)] \ - = [the_notification] + the_notification = parsed_xml["NotificationConfiguration"][ + "{}Configuration".format(name) + ] = [the_notification] for n in the_notification: if not n[name].startswith("arn:aws:{}:".format(arn_string)): @@ -1224,7 +1363,9 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): raise KeyError() if not isinstance(n["Filter"]["S3Key"]["FilterRule"], list): - n["Filter"]["S3Key"]["FilterRule"] = [n["Filter"]["S3Key"]["FilterRule"]] + n["Filter"]["S3Key"]["FilterRule"] = [ + n["Filter"]["S3Key"]["FilterRule"] + ] for filter_rule in n["Filter"]["S3Key"]["FilterRule"]: assert filter_rule["Name"] in ["suffix", "prefix"] @@ -1237,61 +1378,55 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): def _accelerate_config_from_xml(self, xml): parsed_xml = xmltodict.parse(xml) - config = parsed_xml['AccelerateConfiguration'] - return config['Status'] + config = parsed_xml["AccelerateConfiguration"] + return config["Status"] def _key_response_delete(self, bucket_name, query, key_name): self._set_action("KEY", "DELETE", query) self._authenticate_and_authorize_s3_action() - if query.get('uploadId'): - upload_id = query['uploadId'][0] + if query.get("uploadId"): + upload_id = query["uploadId"][0] self.backend.cancel_multipart(bucket_name, upload_id) return 204, {}, "" - version_id = query.get('versionId', [None])[0] + version_id = query.get("versionId", [None])[0] self.backend.delete_key(bucket_name, key_name, version_id=version_id) template = self.response_template(S3_DELETE_OBJECT_SUCCESS) return 204, {}, template.render() def _complete_multipart_body(self, body): - ps = minidom.parseString(body).getElementsByTagName('Part') + ps = minidom.parseString(body).getElementsByTagName("Part") prev = 0 for p in ps: - pn = int(p.getElementsByTagName( - 'PartNumber')[0].firstChild.wholeText) + pn = int(p.getElementsByTagName("PartNumber")[0].firstChild.wholeText) if pn <= prev: raise InvalidPartOrder() - yield (pn, p.getElementsByTagName('ETag')[0].firstChild.wholeText) + yield (pn, p.getElementsByTagName("ETag")[0].firstChild.wholeText) def _key_response_post(self, request, body, bucket_name, query, key_name): self._set_action("KEY", "POST", query) self._authenticate_and_authorize_s3_action() - if body == b'' and 'uploads' in query: + if body == b"" and "uploads" in query: metadata = metadata_from_headers(request.headers) - multipart = self.backend.initiate_multipart( - bucket_name, key_name, metadata) + multipart = self.backend.initiate_multipart(bucket_name, key_name, metadata) template = self.response_template(S3_MULTIPART_INITIATE_RESPONSE) response = template.render( - bucket_name=bucket_name, - key_name=key_name, - upload_id=multipart.id, + bucket_name=bucket_name, key_name=key_name, upload_id=multipart.id ) return 200, {}, response - if query.get('uploadId'): + if query.get("uploadId"): body = self._complete_multipart_body(body) - upload_id = query['uploadId'][0] + upload_id = query["uploadId"][0] key = self.backend.complete_multipart(bucket_name, upload_id, body) template = self.response_template(S3_MULTIPART_COMPLETE_RESPONSE) return template.render( - bucket_name=bucket_name, - key_name=key.name, - etag=key.etag, + bucket_name=bucket_name, key_name=key.name, etag=key.etag ) - elif 'restore' in query: - es = minidom.parseString(body).getElementsByTagName('Days') + elif "restore" in query: + es = minidom.parseString(body).getElementsByTagName("Days") days = es[0].childNodes[0].wholeText key = self.backend.get_key(bucket_name, key_name) r = 202 @@ -1301,7 +1436,8 @@ class ResponseObject(_TemplateEnvironmentMixin, ActionAuthenticatorMixin): return r, {}, "" else: raise NotImplementedError( - "Method POST had only been implemented for multipart uploads and restore operations, so far") + "Method POST had only been implemented for multipart uploads and restore operations, so far" + ) S3ResponseInstance = ResponseObject(s3_backend) diff --git a/moto/s3/urls.py b/moto/s3/urls.py index 1388c81e5..7241dbef1 100644 --- a/moto/s3/urls.py +++ b/moto/s3/urls.py @@ -4,17 +4,16 @@ from .responses import S3ResponseInstance url_bases = [ "https?://s3(.*).amazonaws.com", - r"https?://(?P[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com" + r"https?://(?P[a-zA-Z0-9\-_.]*)\.?s3(.*).amazonaws.com", ] url_paths = { # subdomain bucket - '{0}/$': S3ResponseInstance.bucket_response, - + "{0}/$": S3ResponseInstance.bucket_response, # subdomain key of path-based bucket - '{0}/(?P[^/]+)/?$': S3ResponseInstance.ambiguous_response, + "{0}/(?P[^/]+)/?$": S3ResponseInstance.ambiguous_response, # path-based bucket + key - '{0}/(?P[^/]+)/(?P.+)': S3ResponseInstance.key_response, + "{0}/(?P[^/]+)/(?P.+)": S3ResponseInstance.key_response, # subdomain bucket + key with empty first part of path - '{0}//(?P.*)$': S3ResponseInstance.key_response, + "{0}//(?P.*)$": S3ResponseInstance.key_response, } diff --git a/moto/s3/utils.py b/moto/s3/utils.py index 3bdd24cc4..e7d9e5580 100644 --- a/moto/s3/utils.py +++ b/moto/s3/utils.py @@ -16,19 +16,19 @@ bucket_name_regex = re.compile("(.+).s3(.*).amazonaws.com") def bucket_name_from_url(url): - if os.environ.get('S3_IGNORE_SUBDOMAIN_BUCKETNAME', '') in ['1', 'true']: + if os.environ.get("S3_IGNORE_SUBDOMAIN_BUCKETNAME", "") in ["1", "true"]: return None domain = urlparse(url).netloc - if domain.startswith('www.'): + if domain.startswith("www."): domain = domain[4:] - if 'amazonaws.com' in domain: + if "amazonaws.com" in domain: bucket_result = bucket_name_regex.search(domain) if bucket_result: return bucket_result.groups()[0] else: - if '.' in domain: + if "." in domain: return domain.split(".")[0] else: # No subdomain found. @@ -36,23 +36,23 @@ def bucket_name_from_url(url): REGION_URL_REGEX = re.compile( - r'^https?://(s3[-\.](?P.+)\.amazonaws\.com/(.+)|' - r'(.+)\.s3-(?P.+)\.amazonaws\.com)/?') + r"^https?://(s3[-\.](?P.+)\.amazonaws\.com/(.+)|" + r"(.+)\.s3-(?P.+)\.amazonaws\.com)/?" +) def parse_region_from_url(url): match = REGION_URL_REGEX.search(url) if match: - region = match.group('region1') or match.group('region2') + region = match.group("region1") or match.group("region2") else: - region = 'us-east-1' + region = "us-east-1" return region def metadata_from_headers(headers): metadata = {} - meta_regex = re.compile( - '^x-amz-meta-([a-zA-Z0-9\-_]+)$', flags=re.IGNORECASE) + meta_regex = re.compile("^x-amz-meta-([a-zA-Z0-9\-_]+)$", flags=re.IGNORECASE) for header, value in headers.items(): if isinstance(header, six.string_types): result = meta_regex.match(header) @@ -70,13 +70,13 @@ def metadata_from_headers(headers): def clean_key_name(key_name): if six.PY2: - return unquote(key_name.encode('utf-8')).decode('utf-8') + return unquote(key_name.encode("utf-8")).decode("utf-8") return unquote(key_name) def undo_clean_key_name(key_name): if six.PY2: - return quote(key_name.encode('utf-8')).decode('utf-8') + return quote(key_name.encode("utf-8")).decode("utf-8") return quote(key_name) @@ -140,6 +140,7 @@ class _VersionedKeyStore(dict): values = itervalues = _itervalues if sys.version_info[0] < 3: + def items(self): return list(self.iteritems()) diff --git a/moto/s3bucket_path/utils.py b/moto/s3bucket_path/utils.py index 1b9a034f4..d514a1b35 100644 --- a/moto/s3bucket_path/utils.py +++ b/moto/s3bucket_path/utils.py @@ -17,8 +17,10 @@ def parse_key_name(path): def is_delete_keys(request, path, bucket_name): return ( - path == u'/' + bucket_name + u'/?delete' or - path == u'/' + bucket_name + u'?delete' or - (path == u'/' + bucket_name and - getattr(request, "query_string", "") == "delete") + path == "/" + bucket_name + "/?delete" + or path == "/" + bucket_name + "?delete" + or ( + path == "/" + bucket_name + and getattr(request, "query_string", "") == "delete" + ) ) diff --git a/moto/secretsmanager/__init__.py b/moto/secretsmanager/__init__.py index c7fbb2869..5d41d07ae 100644 --- a/moto/secretsmanager/__init__.py +++ b/moto/secretsmanager/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import secretsmanager_backends from ..core.models import base_decorator -secretsmanager_backend = secretsmanager_backends['us-east-1'] +secretsmanager_backend = secretsmanager_backends["us-east-1"] mock_secretsmanager = base_decorator(secretsmanager_backends) diff --git a/moto/secretsmanager/exceptions.py b/moto/secretsmanager/exceptions.py index 7ef1a9239..13f1f2766 100644 --- a/moto/secretsmanager/exceptions.py +++ b/moto/secretsmanager/exceptions.py @@ -10,8 +10,7 @@ class ResourceNotFoundException(SecretsManagerClientError): def __init__(self, message): self.code = 404 super(ResourceNotFoundException, self).__init__( - "ResourceNotFoundException", - message, + "ResourceNotFoundException", message ) @@ -21,7 +20,7 @@ class SecretNotFoundException(SecretsManagerClientError): self.code = 404 super(SecretNotFoundException, self).__init__( "ResourceNotFoundException", - message=u"Secrets Manager can\u2019t find the specified secret." + message="Secrets Manager can\u2019t find the specified secret.", ) @@ -31,35 +30,32 @@ class SecretHasNoValueException(SecretsManagerClientError): self.code = 404 super(SecretHasNoValueException, self).__init__( "ResourceNotFoundException", - message=u"Secrets Manager can\u2019t find the specified secret " - u"value for staging label: {}".format(version_stage) + message="Secrets Manager can\u2019t find the specified secret " + "value for staging label: {}".format(version_stage), ) class ClientError(SecretsManagerClientError): def __init__(self, message): - super(ClientError, self).__init__( - 'InvalidParameterValue', - message) + super(ClientError, self).__init__("InvalidParameterValue", message) class InvalidParameterException(SecretsManagerClientError): def __init__(self, message): super(InvalidParameterException, self).__init__( - 'InvalidParameterException', - message) + "InvalidParameterException", message + ) class ResourceExistsException(SecretsManagerClientError): def __init__(self, message): super(ResourceExistsException, self).__init__( - 'ResourceExistsException', - message + "ResourceExistsException", message ) class InvalidRequestException(SecretsManagerClientError): def __init__(self, message): super(InvalidRequestException, self).__init__( - 'InvalidRequestException', - message) + "InvalidRequestException", message + ) diff --git a/moto/secretsmanager/models.py b/moto/secretsmanager/models.py index e1a380c39..6f9a00b0e 100644 --- a/moto/secretsmanager/models.py +++ b/moto/secretsmanager/models.py @@ -15,19 +15,17 @@ from .exceptions import ( InvalidParameterException, ResourceExistsException, InvalidRequestException, - ClientError + ClientError, ) from .utils import random_password, secret_arn class SecretsManager(BaseModel): - def __init__(self, region_name, **kwargs): self.region = region_name class SecretsManagerBackend(BaseBackend): - def __init__(self, region_name=None, **kwargs): super(SecretsManagerBackend, self).__init__() self.region = region_name @@ -52,123 +50,148 @@ class SecretsManagerBackend(BaseBackend): if not version_id and version_stage: # set version_id to match version_stage - versions_dict = self.secrets[secret_id]['versions'] + versions_dict = self.secrets[secret_id]["versions"] for ver_id, ver_val in versions_dict.items(): - if version_stage in ver_val['version_stages']: + if version_stage in ver_val["version_stages"]: version_id = ver_id break if not version_id: raise SecretNotFoundException() # TODO check this part - if 'deleted_date' in self.secrets[secret_id]: + if "deleted_date" in self.secrets[secret_id]: raise InvalidRequestException( "An error occurred (InvalidRequestException) when calling the GetSecretValue operation: You tried to \ perform the operation on a secret that's currently marked deleted." ) secret = self.secrets[secret_id] - version_id = version_id or secret['default_version_id'] + version_id = version_id or secret["default_version_id"] - secret_version = secret['versions'][version_id] + secret_version = secret["versions"][version_id] response_data = { - "ARN": secret_arn(self.region, secret['secret_id']), - "Name": secret['name'], - "VersionId": secret_version['version_id'], - "VersionStages": secret_version['version_stages'], - "CreatedDate": secret_version['createdate'], + "ARN": secret_arn(self.region, secret["secret_id"]), + "Name": secret["name"], + "VersionId": secret_version["version_id"], + "VersionStages": secret_version["version_stages"], + "CreatedDate": secret_version["createdate"], } - if 'secret_string' in secret_version: - response_data["SecretString"] = secret_version['secret_string'] + if "secret_string" in secret_version: + response_data["SecretString"] = secret_version["secret_string"] - if 'secret_binary' in secret_version: - response_data["SecretBinary"] = secret_version['secret_binary'] + if "secret_binary" in secret_version: + response_data["SecretBinary"] = secret_version["secret_binary"] - if 'secret_string' not in secret_version and 'secret_binary' not in secret_version: - raise SecretHasNoValueException(version_stage or u"AWSCURRENT") + if ( + "secret_string" not in secret_version + and "secret_binary" not in secret_version + ): + raise SecretHasNoValueException(version_stage or "AWSCURRENT") response = json.dumps(response_data) return response - def create_secret(self, name, secret_string=None, secret_binary=None, tags=[], **kwargs): + def create_secret( + self, name, secret_string=None, secret_binary=None, tags=[], **kwargs + ): # error if secret exists if name in self.secrets.keys(): - raise ResourceExistsException('A resource with the ID you requested already exists.') + raise ResourceExistsException( + "A resource with the ID you requested already exists." + ) - version_id = self._add_secret(name, secret_string=secret_string, secret_binary=secret_binary, tags=tags) + version_id = self._add_secret( + name, secret_string=secret_string, secret_binary=secret_binary, tags=tags + ) - response = json.dumps({ - "ARN": secret_arn(self.region, name), - "Name": name, - "VersionId": version_id, - }) + response = json.dumps( + { + "ARN": secret_arn(self.region, name), + "Name": name, + "VersionId": version_id, + } + ) return response - def _add_secret(self, secret_id, secret_string=None, secret_binary=None, tags=[], version_id=None, version_stages=None): + def _add_secret( + self, + secret_id, + secret_string=None, + secret_binary=None, + tags=[], + version_id=None, + version_stages=None, + ): if version_stages is None: - version_stages = ['AWSCURRENT'] + version_stages = ["AWSCURRENT"] if not version_id: version_id = str(uuid.uuid4()) secret_version = { - 'createdate': int(time.time()), - 'version_id': version_id, - 'version_stages': version_stages, + "createdate": int(time.time()), + "version_id": version_id, + "version_stages": version_stages, } if secret_string is not None: - secret_version['secret_string'] = secret_string + secret_version["secret_string"] = secret_string if secret_binary is not None: - secret_version['secret_binary'] = secret_binary + secret_version["secret_binary"] = secret_binary if secret_id in self.secrets: # remove all old AWSPREVIOUS stages - for secret_verion_to_look_at in self.secrets[secret_id]['versions'].values(): - if 'AWSPREVIOUS' in secret_verion_to_look_at['version_stages']: - secret_verion_to_look_at['version_stages'].remove('AWSPREVIOUS') + for secret_verion_to_look_at in self.secrets[secret_id][ + "versions" + ].values(): + if "AWSPREVIOUS" in secret_verion_to_look_at["version_stages"]: + secret_verion_to_look_at["version_stages"].remove("AWSPREVIOUS") # set old AWSCURRENT secret to AWSPREVIOUS - previous_current_version_id = self.secrets[secret_id]['default_version_id'] - self.secrets[secret_id]['versions'][previous_current_version_id]['version_stages'] = ['AWSPREVIOUS'] + previous_current_version_id = self.secrets[secret_id]["default_version_id"] + self.secrets[secret_id]["versions"][previous_current_version_id][ + "version_stages" + ] = ["AWSPREVIOUS"] - self.secrets[secret_id]['versions'][version_id] = secret_version - self.secrets[secret_id]['default_version_id'] = version_id + self.secrets[secret_id]["versions"][version_id] = secret_version + self.secrets[secret_id]["default_version_id"] = version_id else: self.secrets[secret_id] = { - 'versions': { - version_id: secret_version - }, - 'default_version_id': version_id, + "versions": {version_id: secret_version}, + "default_version_id": version_id, } secret = self.secrets[secret_id] - secret['secret_id'] = secret_id - secret['name'] = secret_id - secret['rotation_enabled'] = False - secret['rotation_lambda_arn'] = '' - secret['auto_rotate_after_days'] = 0 - secret['tags'] = tags + secret["secret_id"] = secret_id + secret["name"] = secret_id + secret["rotation_enabled"] = False + secret["rotation_lambda_arn"] = "" + secret["auto_rotate_after_days"] = 0 + secret["tags"] = tags return version_id def put_secret_value(self, secret_id, secret_string, secret_binary, version_stages): - version_id = self._add_secret(secret_id, secret_string, secret_binary, version_stages=version_stages) + version_id = self._add_secret( + secret_id, secret_string, secret_binary, version_stages=version_stages + ) - response = json.dumps({ - 'ARN': secret_arn(self.region, secret_id), - 'Name': secret_id, - 'VersionId': version_id, - 'VersionStages': version_stages - }) + response = json.dumps( + { + "ARN": secret_arn(self.region, secret_id), + "Name": secret_id, + "VersionId": version_id, + "VersionStages": version_stages, + } + ) return response @@ -178,34 +201,41 @@ class SecretsManagerBackend(BaseBackend): secret = self.secrets[secret_id] - response = json.dumps({ - "ARN": secret_arn(self.region, secret['secret_id']), - "Name": secret['name'], - "Description": "", - "KmsKeyId": "", - "RotationEnabled": secret['rotation_enabled'], - "RotationLambdaARN": secret['rotation_lambda_arn'], - "RotationRules": { - "AutomaticallyAfterDays": secret['auto_rotate_after_days'] - }, - "LastRotatedDate": None, - "LastChangedDate": None, - "LastAccessedDate": None, - "DeletedDate": secret.get('deleted_date', None), - "Tags": secret['tags'] - }) + response = json.dumps( + { + "ARN": secret_arn(self.region, secret["secret_id"]), + "Name": secret["name"], + "Description": "", + "KmsKeyId": "", + "RotationEnabled": secret["rotation_enabled"], + "RotationLambdaARN": secret["rotation_lambda_arn"], + "RotationRules": { + "AutomaticallyAfterDays": secret["auto_rotate_after_days"] + }, + "LastRotatedDate": None, + "LastChangedDate": None, + "LastAccessedDate": None, + "DeletedDate": secret.get("deleted_date", None), + "Tags": secret["tags"], + } + ) return response - def rotate_secret(self, secret_id, client_request_token=None, - rotation_lambda_arn=None, rotation_rules=None): + def rotate_secret( + self, + secret_id, + client_request_token=None, + rotation_lambda_arn=None, + rotation_rules=None, + ): - rotation_days = 'AutomaticallyAfterDays' + rotation_days = "AutomaticallyAfterDays" if not self._is_valid_identifier(secret_id): raise SecretNotFoundException() - if 'deleted_date' in self.secrets[secret_id]: + if "deleted_date" in self.secrets[secret_id]: raise InvalidRequestException( "An error occurred (InvalidRequestException) when calling the RotateSecret operation: You tried to \ perform the operation on a secret that's currently marked deleted." @@ -214,18 +244,12 @@ class SecretsManagerBackend(BaseBackend): if client_request_token: token_length = len(client_request_token) if token_length < 32 or token_length > 64: - msg = ( - 'ClientRequestToken ' - 'must be 32-64 characters long.' - ) + msg = "ClientRequestToken " "must be 32-64 characters long." raise InvalidParameterException(msg) if rotation_lambda_arn: if len(rotation_lambda_arn) > 2048: - msg = ( - 'RotationLambdaARN ' - 'must <= 2048 characters long.' - ) + msg = "RotationLambdaARN " "must <= 2048 characters long." raise InvalidParameterException(msg) if rotation_rules: @@ -233,61 +257,82 @@ class SecretsManagerBackend(BaseBackend): rotation_period = rotation_rules[rotation_days] if rotation_period < 1 or rotation_period > 1000: msg = ( - 'RotationRules.AutomaticallyAfterDays ' - 'must be within 1-1000.' + "RotationRules.AutomaticallyAfterDays " "must be within 1-1000." ) raise InvalidParameterException(msg) secret = self.secrets[secret_id] - old_secret_version = secret['versions'][secret['default_version_id']] + old_secret_version = secret["versions"][secret["default_version_id"]] new_version_id = client_request_token or str(uuid.uuid4()) - self._add_secret(secret_id, old_secret_version['secret_string'], secret['tags'], version_id=new_version_id, version_stages=['AWSCURRENT']) + self._add_secret( + secret_id, + old_secret_version["secret_string"], + secret["tags"], + version_id=new_version_id, + version_stages=["AWSCURRENT"], + ) - secret['rotation_lambda_arn'] = rotation_lambda_arn or '' + secret["rotation_lambda_arn"] = rotation_lambda_arn or "" if rotation_rules: - secret['auto_rotate_after_days'] = rotation_rules.get(rotation_days, 0) - if secret['auto_rotate_after_days'] > 0: - secret['rotation_enabled'] = True + secret["auto_rotate_after_days"] = rotation_rules.get(rotation_days, 0) + if secret["auto_rotate_after_days"] > 0: + secret["rotation_enabled"] = True - if 'AWSCURRENT' in old_secret_version['version_stages']: - old_secret_version['version_stages'].remove('AWSCURRENT') + if "AWSCURRENT" in old_secret_version["version_stages"]: + old_secret_version["version_stages"].remove("AWSCURRENT") - response = json.dumps({ - "ARN": secret_arn(self.region, secret['secret_id']), - "Name": secret['name'], - "VersionId": new_version_id - }) + response = json.dumps( + { + "ARN": secret_arn(self.region, secret["secret_id"]), + "Name": secret["name"], + "VersionId": new_version_id, + } + ) return response - def get_random_password(self, password_length, - exclude_characters, exclude_numbers, - exclude_punctuation, exclude_uppercase, - exclude_lowercase, include_space, - require_each_included_type): + def get_random_password( + self, + password_length, + exclude_characters, + exclude_numbers, + exclude_punctuation, + exclude_uppercase, + exclude_lowercase, + include_space, + require_each_included_type, + ): # password size must have value less than or equal to 4096 if password_length > 4096: raise ClientError( "ClientError: An error occurred (ValidationException) \ when calling the GetRandomPassword operation: 1 validation error detected: Value '{}' at 'passwordLength' \ - failed to satisfy constraint: Member must have value less than or equal to 4096".format(password_length)) + failed to satisfy constraint: Member must have value less than or equal to 4096".format( + password_length + ) + ) if password_length < 4: raise InvalidParameterException( "InvalidParameterException: An error occurred (InvalidParameterException) \ - when calling the GetRandomPassword operation: Password length is too short based on the required types.") + when calling the GetRandomPassword operation: Password length is too short based on the required types." + ) - response = json.dumps({ - "RandomPassword": random_password(password_length, - exclude_characters, - exclude_numbers, - exclude_punctuation, - exclude_uppercase, - exclude_lowercase, - include_space, - require_each_included_type) - }) + response = json.dumps( + { + "RandomPassword": random_password( + password_length, + exclude_characters, + exclude_numbers, + exclude_punctuation, + exclude_uppercase, + exclude_lowercase, + include_space, + require_each_included_type, + ) + } + ) return response @@ -295,20 +340,24 @@ class SecretsManagerBackend(BaseBackend): secret = self.secrets[secret_id] version_list = [] - for version_id, version in secret['versions'].items(): - version_list.append({ - 'CreatedDate': int(time.time()), - 'LastAccessedDate': int(time.time()), - 'VersionId': version_id, - 'VersionStages': version['version_stages'], - }) + for version_id, version in secret["versions"].items(): + version_list.append( + { + "CreatedDate": int(time.time()), + "LastAccessedDate": int(time.time()), + "VersionId": version_id, + "VersionStages": version["version_stages"], + } + ) - response = json.dumps({ - 'ARN': secret['secret_id'], - 'Name': secret['name'], - 'NextToken': '', - 'Versions': version_list, - }) + response = json.dumps( + { + "ARN": secret["secret_id"], + "Name": secret["name"], + "NextToken": "", + "Versions": version_list, + } + ) return response @@ -319,35 +368,39 @@ class SecretsManagerBackend(BaseBackend): for secret in self.secrets.values(): versions_to_stages = {} - for version_id, version in secret['versions'].items(): - versions_to_stages[version_id] = version['version_stages'] + for version_id, version in secret["versions"].items(): + versions_to_stages[version_id] = version["version_stages"] - secret_list.append({ - "ARN": secret_arn(self.region, secret['secret_id']), - "DeletedDate": secret.get('deleted_date', None), - "Description": "", - "KmsKeyId": "", - "LastAccessedDate": None, - "LastChangedDate": None, - "LastRotatedDate": None, - "Name": secret['name'], - "RotationEnabled": secret['rotation_enabled'], - "RotationLambdaARN": secret['rotation_lambda_arn'], - "RotationRules": { - "AutomaticallyAfterDays": secret['auto_rotate_after_days'] - }, - "SecretVersionsToStages": versions_to_stages, - "Tags": secret['tags'] - }) + secret_list.append( + { + "ARN": secret_arn(self.region, secret["secret_id"]), + "DeletedDate": secret.get("deleted_date", None), + "Description": "", + "KmsKeyId": "", + "LastAccessedDate": None, + "LastChangedDate": None, + "LastRotatedDate": None, + "Name": secret["name"], + "RotationEnabled": secret["rotation_enabled"], + "RotationLambdaARN": secret["rotation_lambda_arn"], + "RotationRules": { + "AutomaticallyAfterDays": secret["auto_rotate_after_days"] + }, + "SecretVersionsToStages": versions_to_stages, + "Tags": secret["tags"], + } + ) return secret_list, None - def delete_secret(self, secret_id, recovery_window_in_days, force_delete_without_recovery): + def delete_secret( + self, secret_id, recovery_window_in_days, force_delete_without_recovery + ): if not self._is_valid_identifier(secret_id): raise SecretNotFoundException() - if 'deleted_date' in self.secrets[secret_id]: + if "deleted_date" in self.secrets[secret_id]: raise InvalidRequestException( "An error occurred (InvalidRequestException) when calling the DeleteSecret operation: You tried to \ perform the operation on a secret that's currently marked deleted." @@ -359,7 +412,9 @@ class SecretsManagerBackend(BaseBackend): use ForceDeleteWithoutRecovery in conjunction with RecoveryWindowInDays." ) - if recovery_window_in_days and (recovery_window_in_days < 7 or recovery_window_in_days > 30): + if recovery_window_in_days and ( + recovery_window_in_days < 7 or recovery_window_in_days > 30 + ): raise InvalidParameterException( "An error occurred (InvalidParameterException) when calling the DeleteSecret operation: The \ RecoveryWindowInDays value must be between 7 and 30 days (inclusive)." @@ -371,14 +426,16 @@ class SecretsManagerBackend(BaseBackend): secret = self.secrets.pop(secret_id, None) else: deletion_date += datetime.timedelta(days=recovery_window_in_days or 30) - self.secrets[secret_id]['deleted_date'] = self._unix_time_secs(deletion_date) + self.secrets[secret_id]["deleted_date"] = self._unix_time_secs( + deletion_date + ) secret = self.secrets.get(secret_id, None) if not secret: raise SecretNotFoundException() - arn = secret_arn(self.region, secret['secret_id']) - name = secret['name'] + arn = secret_arn(self.region, secret["secret_id"]) + name = secret["name"] return arn, name, self._unix_time_secs(deletion_date) @@ -387,18 +444,17 @@ class SecretsManagerBackend(BaseBackend): if not self._is_valid_identifier(secret_id): raise SecretNotFoundException() - self.secrets[secret_id].pop('deleted_date', None) + self.secrets[secret_id].pop("deleted_date", None) secret = self.secrets[secret_id] - arn = secret_arn(self.region, secret['secret_id']) - name = secret['name'] + arn = secret_arn(self.region, secret["secret_id"]) + name = secret["name"] return arn, name -available_regions = ( - boto3.session.Session().get_available_regions("secretsmanager") -) -secretsmanager_backends = {region: SecretsManagerBackend(region_name=region) - for region in available_regions} +available_regions = boto3.session.Session().get_available_regions("secretsmanager") +secretsmanager_backends = { + region: SecretsManagerBackend(region_name=region) for region in available_regions +} diff --git a/moto/secretsmanager/responses.py b/moto/secretsmanager/responses.py index 4995c4bc7..09df0fbbf 100644 --- a/moto/secretsmanager/responses.py +++ b/moto/secretsmanager/responses.py @@ -9,38 +9,37 @@ import json class SecretsManagerResponse(BaseResponse): - def get_secret_value(self): - secret_id = self._get_param('SecretId') - version_id = self._get_param('VersionId') - version_stage = self._get_param('VersionStage') + secret_id = self._get_param("SecretId") + version_id = self._get_param("VersionId") + version_stage = self._get_param("VersionStage") return secretsmanager_backends[self.region].get_secret_value( - secret_id=secret_id, - version_id=version_id, - version_stage=version_stage) + secret_id=secret_id, version_id=version_id, version_stage=version_stage + ) def create_secret(self): - name = self._get_param('Name') - secret_string = self._get_param('SecretString') - secret_binary = self._get_param('SecretBinary') - tags = self._get_param('Tags', if_none=[]) + name = self._get_param("Name") + secret_string = self._get_param("SecretString") + secret_binary = self._get_param("SecretBinary") + tags = self._get_param("Tags", if_none=[]) return secretsmanager_backends[self.region].create_secret( name=name, secret_string=secret_string, secret_binary=secret_binary, - tags=tags + tags=tags, ) def get_random_password(self): - password_length = self._get_param('PasswordLength', if_none=32) - exclude_characters = self._get_param('ExcludeCharacters', if_none='') - exclude_numbers = self._get_param('ExcludeNumbers', if_none=False) - exclude_punctuation = self._get_param('ExcludePunctuation', if_none=False) - exclude_uppercase = self._get_param('ExcludeUppercase', if_none=False) - exclude_lowercase = self._get_param('ExcludeLowercase', if_none=False) - include_space = self._get_param('IncludeSpace', if_none=False) + password_length = self._get_param("PasswordLength", if_none=32) + exclude_characters = self._get_param("ExcludeCharacters", if_none="") + exclude_numbers = self._get_param("ExcludeNumbers", if_none=False) + exclude_punctuation = self._get_param("ExcludePunctuation", if_none=False) + exclude_uppercase = self._get_param("ExcludeUppercase", if_none=False) + exclude_lowercase = self._get_param("ExcludeLowercase", if_none=False) + include_space = self._get_param("IncludeSpace", if_none=False) require_each_included_type = self._get_param( - 'RequireEachIncludedType', if_none=True) + "RequireEachIncludedType", if_none=True + ) return secretsmanager_backends[self.region].get_random_password( password_length=password_length, exclude_characters=exclude_characters, @@ -49,34 +48,34 @@ class SecretsManagerResponse(BaseResponse): exclude_uppercase=exclude_uppercase, exclude_lowercase=exclude_lowercase, include_space=include_space, - require_each_included_type=require_each_included_type + require_each_included_type=require_each_included_type, ) def describe_secret(self): - secret_id = self._get_param('SecretId') - return secretsmanager_backends[self.region].describe_secret( - secret_id=secret_id - ) + secret_id = self._get_param("SecretId") + return secretsmanager_backends[self.region].describe_secret(secret_id=secret_id) def rotate_secret(self): - client_request_token = self._get_param('ClientRequestToken') - rotation_lambda_arn = self._get_param('RotationLambdaARN') - rotation_rules = self._get_param('RotationRules') - secret_id = self._get_param('SecretId') + client_request_token = self._get_param("ClientRequestToken") + rotation_lambda_arn = self._get_param("RotationLambdaARN") + rotation_rules = self._get_param("RotationRules") + secret_id = self._get_param("SecretId") return secretsmanager_backends[self.region].rotate_secret( secret_id=secret_id, client_request_token=client_request_token, rotation_lambda_arn=rotation_lambda_arn, - rotation_rules=rotation_rules + rotation_rules=rotation_rules, ) def put_secret_value(self): - secret_id = self._get_param('SecretId', if_none='') - secret_string = self._get_param('SecretString') - secret_binary = self._get_param('SecretBinary') + secret_id = self._get_param("SecretId", if_none="") + secret_string = self._get_param("SecretString") + secret_binary = self._get_param("SecretBinary") if not secret_binary and not secret_string: - raise InvalidRequestException('You must provide either SecretString or SecretBinary.') - version_stages = self._get_param('VersionStages', if_none=['AWSCURRENT']) + raise InvalidRequestException( + "You must provide either SecretString or SecretBinary." + ) + version_stages = self._get_param("VersionStages", if_none=["AWSCURRENT"]) return secretsmanager_backends[self.region].put_secret_value( secret_id=secret_id, secret_binary=secret_binary, @@ -85,7 +84,7 @@ class SecretsManagerResponse(BaseResponse): ) def list_secret_version_ids(self): - secret_id = self._get_param('SecretId', if_none='') + secret_id = self._get_param("SecretId", if_none="") return secretsmanager_backends[self.region].list_secret_version_ids( secret_id=secret_id ) @@ -94,8 +93,7 @@ class SecretsManagerResponse(BaseResponse): max_results = self._get_int_param("MaxResults") next_token = self._get_param("NextToken") secret_list, next_token = secretsmanager_backends[self.region].list_secrets( - max_results=max_results, - next_token=next_token, + max_results=max_results, next_token=next_token ) return json.dumps(dict(SecretList=secret_list, NextToken=next_token)) @@ -113,6 +111,6 @@ class SecretsManagerResponse(BaseResponse): def restore_secret(self): secret_id = self._get_param("SecretId") arn, name = secretsmanager_backends[self.region].restore_secret( - secret_id=secret_id, + secret_id=secret_id ) return json.dumps(dict(ARN=arn, Name=name)) diff --git a/moto/secretsmanager/urls.py b/moto/secretsmanager/urls.py index 9e39e7263..57cbac0e4 100644 --- a/moto/secretsmanager/urls.py +++ b/moto/secretsmanager/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import SecretsManagerResponse -url_bases = [ - "https?://secretsmanager.(.+).amazonaws.com", -] +url_bases = ["https?://secretsmanager.(.+).amazonaws.com"] -url_paths = { - '{0}/$': SecretsManagerResponse.dispatch, -} +url_paths = {"{0}/$": SecretsManagerResponse.dispatch} diff --git a/moto/secretsmanager/utils.py b/moto/secretsmanager/utils.py index 231fea296..44385270c 100644 --- a/moto/secretsmanager/utils.py +++ b/moto/secretsmanager/utils.py @@ -6,55 +6,70 @@ import six import re -def random_password(password_length, exclude_characters, exclude_numbers, - exclude_punctuation, exclude_uppercase, exclude_lowercase, - include_space, require_each_included_type): +def random_password( + password_length, + exclude_characters, + exclude_numbers, + exclude_punctuation, + exclude_uppercase, + exclude_lowercase, + include_space, + require_each_included_type, +): - password = '' - required_characters = '' + password = "" + required_characters = "" if not exclude_lowercase and not exclude_uppercase: password += string.ascii_letters - required_characters += random.choice(_exclude_characters( - string.ascii_lowercase, exclude_characters)) - required_characters += random.choice(_exclude_characters( - string.ascii_uppercase, exclude_characters)) + required_characters += random.choice( + _exclude_characters(string.ascii_lowercase, exclude_characters) + ) + required_characters += random.choice( + _exclude_characters(string.ascii_uppercase, exclude_characters) + ) elif not exclude_lowercase: password += string.ascii_lowercase - required_characters += random.choice(_exclude_characters( - string.ascii_lowercase, exclude_characters)) + required_characters += random.choice( + _exclude_characters(string.ascii_lowercase, exclude_characters) + ) elif not exclude_uppercase: password += string.ascii_uppercase - required_characters += random.choice(_exclude_characters( - string.ascii_uppercase, exclude_characters)) + required_characters += random.choice( + _exclude_characters(string.ascii_uppercase, exclude_characters) + ) if not exclude_numbers: password += string.digits - required_characters += random.choice(_exclude_characters( - string.digits, exclude_characters)) + required_characters += random.choice( + _exclude_characters(string.digits, exclude_characters) + ) if not exclude_punctuation: password += string.punctuation - required_characters += random.choice(_exclude_characters( - string.punctuation, exclude_characters)) + required_characters += random.choice( + _exclude_characters(string.punctuation, exclude_characters) + ) if include_space: password += " " required_characters += " " - password = ''.join( - six.text_type(random.choice(password)) - for x in range(password_length)) + password = "".join( + six.text_type(random.choice(password)) for x in range(password_length) + ) if require_each_included_type: password = _add_password_require_each_included_type( - password, required_characters) + password, required_characters + ) password = _exclude_characters(password, exclude_characters) return password def secret_arn(region, secret_id): - id_string = ''.join(random.choice(string.ascii_letters) for _ in range(5)) + id_string = "".join(random.choice(string.ascii_letters) for _ in range(5)) return "arn:aws:secretsmanager:{0}:1234567890:secret:{1}-{2}".format( - region, secret_id, id_string) + region, secret_id, id_string + ) def _exclude_characters(password, exclude_characters): @@ -62,12 +77,12 @@ def _exclude_characters(password, exclude_characters): if c in string.punctuation: # Escape punctuation regex usage c = "\{0}".format(c) - password = re.sub(c, '', str(password)) + password = re.sub(c, "", str(password)) return password def _add_password_require_each_included_type(password, required_characters): - password_with_required_char = password[:-len(required_characters)] + password_with_required_char = password[: -len(required_characters)] password_with_required_char += required_characters return password_with_required_char diff --git a/moto/server.py b/moto/server.py index b245f6e6f..bbc309fe2 100644 --- a/moto/server.py +++ b/moto/server.py @@ -21,13 +21,13 @@ from moto.core.utils import convert_flask_to_httpretty_response HTTP_METHODS = ["GET", "POST", "PUT", "DELETE", "HEAD", "PATCH"] -DEFAULT_SERVICE_REGION = ('s3', 'us-east-1') +DEFAULT_SERVICE_REGION = ("s3", "us-east-1") # Map of unsigned calls to service-region as per AWS API docs # https://docs.aws.amazon.com/cognito/latest/developerguide/resource-permissions.html#amazon-cognito-signed-versus-unsigned-apis UNSIGNED_REQUESTS = { - 'AWSCognitoIdentityService': ('cognito-identity', 'us-east-1'), - 'AWSCognitoIdentityProviderService': ('cognito-idp', 'us-east-1'), + "AWSCognitoIdentityService": ("cognito-identity", "us-east-1"), + "AWSCognitoIdentityProviderService": ("cognito-idp", "us-east-1"), } @@ -44,7 +44,7 @@ class DomainDispatcherApplication(object): self.service = service def get_backend_for_host(self, host): - if host == 'moto_api': + if host == "moto_api": return host if self.service: @@ -55,11 +55,11 @@ class DomainDispatcherApplication(object): for backend_name, backend in BACKENDS.items(): for url_base in list(backend.values())[0].url_bases: - if re.match(url_base, 'http://%s' % host): + if re.match(url_base, "http://%s" % host): return backend_name def infer_service_region_host(self, environ): - auth = environ.get('HTTP_AUTHORIZATION') + auth = environ.get("HTTP_AUTHORIZATION") if auth: # Signed request # Parse auth header to find service assuming a SigV4 request @@ -76,43 +76,46 @@ class DomainDispatcherApplication(object): service, region = DEFAULT_SERVICE_REGION else: # Unsigned request - target = environ.get('HTTP_X_AMZ_TARGET') + target = environ.get("HTTP_X_AMZ_TARGET") if target: - service, _ = target.split('.', 1) + service, _ = target.split(".", 1) service, region = UNSIGNED_REQUESTS.get(service, DEFAULT_SERVICE_REGION) else: # S3 is the last resort when the target is also unknown service, region = DEFAULT_SERVICE_REGION - if service == 'dynamodb': - if environ['HTTP_X_AMZ_TARGET'].startswith('DynamoDBStreams'): - host = 'dynamodbstreams' + if service == "dynamodb": + if environ["HTTP_X_AMZ_TARGET"].startswith("DynamoDBStreams"): + host = "dynamodbstreams" else: - dynamo_api_version = environ['HTTP_X_AMZ_TARGET'].split("_")[1].split(".")[0] + dynamo_api_version = ( + environ["HTTP_X_AMZ_TARGET"].split("_")[1].split(".")[0] + ) # If Newer API version, use dynamodb2 if dynamo_api_version > "20111205": host = "dynamodb2" else: host = "{service}.{region}.amazonaws.com".format( - service=service, region=region) + service=service, region=region + ) return host def get_application(self, environ): - path_info = environ.get('PATH_INFO', '') + path_info = environ.get("PATH_INFO", "") # The URL path might contain non-ASCII text, for instance unicode S3 bucket names if six.PY2 and isinstance(path_info, str): path_info = six.u(path_info) if six.PY3 and isinstance(path_info, six.binary_type): - path_info = path_info.decode('utf-8') + path_info = path_info.decode("utf-8") if path_info.startswith("/moto-api") or path_info == "/favicon.ico": host = "moto_api" elif path_info.startswith("/latest/meta-data/"): host = "instance_metadata" else: - host = environ['HTTP_HOST'].split(':')[0] + host = environ["HTTP_HOST"].split(":")[0] with self.lock: backend = self.get_backend_for_host(host) @@ -141,15 +144,18 @@ class RegexConverter(BaseConverter): class AWSTestHelper(FlaskClient): - def action_data(self, action_name, **kwargs): """ Method calls resource with action_name and returns data of response. """ opts = {"Action": action_name} opts.update(kwargs) - res = self.get("/?{0}".format(urlencode(opts)), - headers={"Host": "{0}.us-east-1.amazonaws.com".format(self.application.service)}) + res = self.get( + "/?{0}".format(urlencode(opts)), + headers={ + "Host": "{0}.us-east-1.amazonaws.com".format(self.application.service) + }, + ) return res.data.decode("utf-8") def action_json(self, action_name, **kwargs): @@ -171,12 +177,12 @@ def create_backend_app(service): # Reset view functions to reset the app backend_app.view_functions = {} backend_app.url_map = Map() - backend_app.url_map.converters['regex'] = RegexConverter + backend_app.url_map.converters["regex"] = RegexConverter backend = list(BACKENDS[service].values())[0] for url_path, handler in backend.flask_paths.items(): view_func = convert_flask_to_httpretty_response(handler) - if handler.__name__ == 'dispatch': - endpoint = '{0}.dispatch'.format(handler.__self__.__name__) + if handler.__name__ == "dispatch": + endpoint = "{0}.dispatch".format(handler.__self__.__name__) else: endpoint = view_func.__name__ @@ -207,54 +213,57 @@ def main(argv=sys.argv[1:]): parser.add_argument( "service", type=str, - nargs='?', # http://stackoverflow.com/a/4480202/731592 - default=None) - parser.add_argument( - '-H', '--host', type=str, - help='Which host to bind', - default='127.0.0.1') - parser.add_argument( - '-p', '--port', type=int, - help='Port number to use for connection', - default=5000) - parser.add_argument( - '-r', '--reload', - action='store_true', - help='Reload server on a file change', - default=False + nargs="?", # http://stackoverflow.com/a/4480202/731592 + default=None, ) parser.add_argument( - '-s', '--ssl', - action='store_true', - help='Enable SSL encrypted connection with auto-generated certificate (use https://... URL)', - default=False + "-H", "--host", type=str, help="Which host to bind", default="127.0.0.1" ) parser.add_argument( - '-c', '--ssl-cert', type=str, - help='Path to SSL certificate', - default=None) + "-p", "--port", type=int, help="Port number to use for connection", default=5000 + ) parser.add_argument( - '-k', '--ssl-key', type=str, - help='Path to SSL private key', - default=None) + "-r", + "--reload", + action="store_true", + help="Reload server on a file change", + default=False, + ) + parser.add_argument( + "-s", + "--ssl", + action="store_true", + help="Enable SSL encrypted connection with auto-generated certificate (use https://... URL)", + default=False, + ) + parser.add_argument( + "-c", "--ssl-cert", type=str, help="Path to SSL certificate", default=None + ) + parser.add_argument( + "-k", "--ssl-key", type=str, help="Path to SSL private key", default=None + ) args = parser.parse_args(argv) # Wrap the main application - main_app = DomainDispatcherApplication( - create_backend_app, service=args.service) + main_app = DomainDispatcherApplication(create_backend_app, service=args.service) main_app.debug = True ssl_context = None if args.ssl_key and args.ssl_cert: ssl_context = (args.ssl_cert, args.ssl_key) elif args.ssl: - ssl_context = 'adhoc' + ssl_context = "adhoc" - run_simple(args.host, args.port, main_app, - threaded=True, use_reloader=args.reload, - ssl_context=ssl_context) + run_simple( + args.host, + args.port, + main_app, + threaded=True, + use_reloader=args.reload, + ssl_context=ssl_context, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/moto/ses/exceptions.py b/moto/ses/exceptions.py index f888af9f6..a905039e2 100644 --- a/moto/ses/exceptions.py +++ b/moto/ses/exceptions.py @@ -6,5 +6,4 @@ class MessageRejectedError(RESTError): code = 400 def __init__(self, message): - super(MessageRejectedError, self).__init__( - "MessageRejected", message) + super(MessageRejectedError, self).__init__("MessageRejected", message) diff --git a/moto/ses/feedback.py b/moto/ses/feedback.py index 2d32f9ce0..c3d630e59 100644 --- a/moto/ses/feedback.py +++ b/moto/ses/feedback.py @@ -11,32 +11,20 @@ COMMON_MAIL = { "sourceArn": "arn:aws:ses:us-west-2:888888888888:identity/example.com", "sourceIp": "127.0.3.0", "sendingAccountId": "123456789012", - "destination": [ - "recipient@example.com" - ], + "destination": ["recipient@example.com"], "headersTruncated": False, "headers": [ - { - "name": "From", - "value": "\"Sender Name\" " - }, - { - "name": "To", - "value": "\"Recipient Name\" " - } + {"name": "From", "value": '"Sender Name" '}, + {"name": "To", "value": '"Recipient Name" '}, ], "commonHeaders": { - "from": [ - "Sender Name " - ], + "from": ["Sender Name "], "date": "Mon, 08 Oct 2018 14:05:45 +0000", - "to": [ - "Recipient Name " - ], + "to": ["Recipient Name "], "messageId": " custom-message-ID", - "subject": "Message sent using Amazon SES" - } - } + "subject": "Message sent using Amazon SES", + }, + }, } BOUNCE = { "bounceType": "Permanent", @@ -46,30 +34,26 @@ BOUNCE = { "status": "5.0.0", "action": "failed", "diagnosticCode": "smtp; 550 user unknown", - "emailAddress": "recipient1@example.com" + "emailAddress": "recipient1@example.com", }, { "status": "4.0.0", "action": "delayed", - "emailAddress": "recipient2@example.com" - } + "emailAddress": "recipient2@example.com", + }, ], "reportingMTA": "example.com", "timestamp": "2012-05-25T14:59:38.605Z", "feedbackId": "000001378603176d-5a4b5ad9-6f30-4198-a8c3-b1eb0c270a1d-000000", - "remoteMtaIp": "127.0.2.0" + "remoteMtaIp": "127.0.2.0", } COMPLAINT = { "userAgent": "AnyCompany Feedback Loop (V0.01)", - "complainedRecipients": [ - { - "emailAddress": "recipient1@example.com" - } - ], + "complainedRecipients": [{"emailAddress": "recipient1@example.com"}], "complaintFeedbackType": "abuse", "arrivalDate": "2009-12-03T04:24:21.000-05:00", "timestamp": "2012-05-25T14:59:38.623Z", - "feedbackId": "000001378603177f-18c07c78-fa81-4a58-9dd1-fedc3cb8f49a-000000" + "feedbackId": "000001378603177f-18c07c78-fa81-4a58-9dd1-fedc3cb8f49a-000000", } DELIVERY = { "timestamp": "2014-05-28T22:41:01.184Z", @@ -77,5 +61,5 @@ DELIVERY = { "recipients": ["success@simulator.amazonses.com"], "smtpResponse": "250 ok: Message 64111812 accepted", "reportingMTA": "a8-70.smtp-out.amazonses.com", - "remoteMtaIp": "127.0.2.0" + "remoteMtaIp": "127.0.2.0", } diff --git a/moto/ses/models.py b/moto/ses/models.py index 22af15427..353d6f4b7 100644 --- a/moto/ses/models.py +++ b/moto/ses/models.py @@ -40,7 +40,6 @@ class SESFeedback(BaseModel): class Message(BaseModel): - def __init__(self, message_id, source, subject, body, destinations): self.id = message_id self.source = source @@ -50,13 +49,7 @@ class Message(BaseModel): class TemplateMessage(BaseModel): - - def __init__(self, - message_id, - source, - template, - template_data, - destinations): + def __init__(self, message_id, source, template, template_data, destinations): self.id = message_id self.source = source self.template = template @@ -65,7 +58,6 @@ class TemplateMessage(BaseModel): class RawMessage(BaseModel): - def __init__(self, message_id, source, destinations, raw_data): self.id = message_id self.source = source @@ -74,7 +66,6 @@ class RawMessage(BaseModel): class SESQuota(BaseModel): - def __init__(self, sent): self.sent = sent @@ -84,7 +75,6 @@ class SESQuota(BaseModel): class SESBackend(BaseBackend): - def __init__(self): self.addresses = [] self.email_addresses = [] @@ -97,7 +87,7 @@ class SESBackend(BaseBackend): _, address = parseaddr(source) if address in self.addresses: return True - user, host = address.split('@', 1) + user, host = address.split("@", 1) return host in self.domains def verify_email_identity(self, address): @@ -116,7 +106,7 @@ class SESBackend(BaseBackend): return self.email_addresses def delete_identity(self, identity): - if '@' in identity: + if "@" in identity: self.addresses.remove(identity) else: self.domains.remove(identity) @@ -124,11 +114,9 @@ class SESBackend(BaseBackend): def send_email(self, source, subject, body, destinations, region): recipient_count = sum(map(len, destinations.values())) if recipient_count > RECIPIENT_LIMIT: - raise MessageRejectedError('Too many recipients.') + raise MessageRejectedError("Too many recipients.") if not self._is_verified_address(source): - raise MessageRejectedError( - "Email address not verified %s" % source - ) + raise MessageRejectedError("Email address not verified %s" % source) self.__process_sns_feedback__(source, destinations, region) @@ -138,23 +126,21 @@ class SESBackend(BaseBackend): self.sent_message_count += recipient_count return message - def send_templated_email(self, source, template, template_data, destinations, region): + def send_templated_email( + self, source, template, template_data, destinations, region + ): recipient_count = sum(map(len, destinations.values())) if recipient_count > RECIPIENT_LIMIT: - raise MessageRejectedError('Too many recipients.') + raise MessageRejectedError("Too many recipients.") if not self._is_verified_address(source): - raise MessageRejectedError( - "Email address not verified %s" % source - ) + raise MessageRejectedError("Email address not verified %s" % source) self.__process_sns_feedback__(source, destinations, region) message_id = get_random_message_id() - message = TemplateMessage(message_id, - source, - template, - template_data, - destinations) + message = TemplateMessage( + message_id, source, template, template_data, destinations + ) self.sent_messages.append(message) self.sent_message_count += recipient_count return message @@ -162,10 +148,11 @@ class SESBackend(BaseBackend): def __type_of_message__(self, destinations): """Checks the destination for any special address that could indicate delivery, complaint or bounce like in SES simualtor""" - alladdress = destinations.get( - "ToAddresses", []) + destinations.get( - "CcAddresses", []) + destinations.get( - "BccAddresses", []) + alladdress = ( + destinations.get("ToAddresses", []) + + destinations.get("CcAddresses", []) + + destinations.get("BccAddresses", []) + ) for addr in alladdress: if SESFeedback.SUCCESS_ADDR in addr: return SESFeedback.DELIVERY @@ -198,30 +185,29 @@ class SESBackend(BaseBackend): _, source_email_address = parseaddr(source) if source_email_address not in self.addresses: raise MessageRejectedError( - "Did not have authority to send from email %s" % source_email_address + "Did not have authority to send from email %s" + % source_email_address ) recipient_count = len(destinations) message = email.message_from_string(raw_data) if source is None: - if message['from'] is None: - raise MessageRejectedError( - "Source not specified" - ) + if message["from"] is None: + raise MessageRejectedError("Source not specified") - _, source_email_address = parseaddr(message['from']) + _, source_email_address = parseaddr(message["from"]) if source_email_address not in self.addresses: raise MessageRejectedError( - "Did not have authority to send from email %s" % source_email_address + "Did not have authority to send from email %s" + % source_email_address ) - for header in 'TO', 'CC', 'BCC': + for header in "TO", "CC", "BCC": recipient_count += sum( - d.strip() and 1 or 0 - for d in message.get(header, '').split(',') + d.strip() and 1 or 0 for d in message.get(header, "").split(",") ) if recipient_count > RECIPIENT_LIMIT: - raise MessageRejectedError('Too many recipients.') + raise MessageRejectedError("Too many recipients.") self.__process_sns_feedback__(source, destinations, region) diff --git a/moto/ses/responses.py b/moto/ses/responses.py index d49e47d84..1034aeb0d 100644 --- a/moto/ses/responses.py +++ b/moto/ses/responses.py @@ -8,15 +8,14 @@ from .models import ses_backend class EmailResponse(BaseResponse): - def verify_email_identity(self): - address = self.querystring.get('EmailAddress')[0] + address = self.querystring.get("EmailAddress")[0] ses_backend.verify_email_identity(address) template = self.response_template(VERIFY_EMAIL_IDENTITY) return template.render() def verify_email_address(self): - address = self.querystring.get('EmailAddress')[0] + address = self.querystring.get("EmailAddress")[0] ses_backend.verify_email_address(address) template = self.response_template(VERIFY_EMAIL_ADDRESS) return template.render() @@ -32,94 +31,88 @@ class EmailResponse(BaseResponse): return template.render(email_addresses=email_addresses) def verify_domain_dkim(self): - domain = self.querystring.get('Domain')[0] + domain = self.querystring.get("Domain")[0] ses_backend.verify_domain(domain) template = self.response_template(VERIFY_DOMAIN_DKIM_RESPONSE) return template.render() def verify_domain_identity(self): - domain = self.querystring.get('Domain')[0] + domain = self.querystring.get("Domain")[0] ses_backend.verify_domain(domain) template = self.response_template(VERIFY_DOMAIN_IDENTITY_RESPONSE) return template.render() def delete_identity(self): - domain = self.querystring.get('Identity')[0] + domain = self.querystring.get("Identity")[0] ses_backend.delete_identity(domain) template = self.response_template(DELETE_IDENTITY_RESPONSE) return template.render() def send_email(self): - bodydatakey = 'Message.Body.Text.Data' - if 'Message.Body.Html.Data' in self.querystring: - bodydatakey = 'Message.Body.Html.Data' + bodydatakey = "Message.Body.Text.Data" + if "Message.Body.Html.Data" in self.querystring: + bodydatakey = "Message.Body.Html.Data" body = self.querystring.get(bodydatakey)[0] - source = self.querystring.get('Source')[0] - subject = self.querystring.get('Message.Subject.Data')[0] - destinations = { - 'ToAddresses': [], - 'CcAddresses': [], - 'BccAddresses': [], - } + source = self.querystring.get("Source")[0] + subject = self.querystring.get("Message.Subject.Data")[0] + destinations = {"ToAddresses": [], "CcAddresses": [], "BccAddresses": []} for dest_type in destinations: # consume up to 51 to allow exception for i in six.moves.range(1, 52): - field = 'Destination.%s.member.%s' % (dest_type, i) + field = "Destination.%s.member.%s" % (dest_type, i) address = self.querystring.get(field) if address is None: break destinations[dest_type].append(address[0]) - message = ses_backend.send_email(source, subject, body, destinations, self.region) + message = ses_backend.send_email( + source, subject, body, destinations, self.region + ) template = self.response_template(SEND_EMAIL_RESPONSE) return template.render(message=message) def send_templated_email(self): - source = self.querystring.get('Source')[0] - template = self.querystring.get('Template') - template_data = self.querystring.get('TemplateData') + source = self.querystring.get("Source")[0] + template = self.querystring.get("Template") + template_data = self.querystring.get("TemplateData") - destinations = { - 'ToAddresses': [], - 'CcAddresses': [], - 'BccAddresses': [], - } + destinations = {"ToAddresses": [], "CcAddresses": [], "BccAddresses": []} for dest_type in destinations: # consume up to 51 to allow exception for i in six.moves.range(1, 52): - field = 'Destination.%s.member.%s' % (dest_type, i) + field = "Destination.%s.member.%s" % (dest_type, i) address = self.querystring.get(field) if address is None: break destinations[dest_type].append(address[0]) - message = ses_backend.send_templated_email(source, - template, - template_data, - destinations, - self.region) + message = ses_backend.send_templated_email( + source, template, template_data, destinations, self.region + ) template = self.response_template(SEND_TEMPLATED_EMAIL_RESPONSE) return template.render(message=message) def send_raw_email(self): - source = self.querystring.get('Source') + source = self.querystring.get("Source") if source is not None: - source, = source + (source,) = source - raw_data = self.querystring.get('RawMessage.Data')[0] + raw_data = self.querystring.get("RawMessage.Data")[0] raw_data = base64.b64decode(raw_data) if six.PY3: - raw_data = raw_data.decode('utf-8') + raw_data = raw_data.decode("utf-8") destinations = [] # consume up to 51 to allow exception for i in six.moves.range(1, 52): - field = 'Destinations.member.%s' % i + field = "Destinations.member.%s" % i address = self.querystring.get(field) if address is None: break destinations.append(address[0]) - message = ses_backend.send_raw_email(source, destinations, raw_data, self.region) + message = ses_backend.send_raw_email( + source, destinations, raw_data, self.region + ) template = self.response_template(SEND_RAW_EMAIL_RESPONSE) return template.render(message=message) diff --git a/moto/ses/urls.py b/moto/ses/urls.py index adfb4c6e4..5c26d2152 100644 --- a/moto/ses/urls.py +++ b/moto/ses/urls.py @@ -1,11 +1,6 @@ from __future__ import unicode_literals from .responses import EmailResponse -url_bases = [ - "https?://email.(.+).amazonaws.com", - "https?://ses.(.+).amazonaws.com", -] +url_bases = ["https?://email.(.+).amazonaws.com", "https?://ses.(.+).amazonaws.com"] -url_paths = { - '{0}/$': EmailResponse.dispatch, -} +url_paths = {"{0}/$": EmailResponse.dispatch} diff --git a/moto/ses/utils.py b/moto/ses/utils.py index c674892d1..6d9151cea 100644 --- a/moto/ses/utils.py +++ b/moto/ses/utils.py @@ -4,16 +4,16 @@ import string def random_hex(length): - return ''.join(random.choice(string.ascii_lowercase) for x in range(length)) + return "".join(random.choice(string.ascii_lowercase) for x in range(length)) def get_random_message_id(): return "{0}-{1}-{2}-{3}-{4}-{5}-{6}".format( - random_hex(16), - random_hex(8), - random_hex(4), - random_hex(4), - random_hex(4), - random_hex(12), - random_hex(6), + random_hex(16), + random_hex(8), + random_hex(4), + random_hex(4), + random_hex(4), + random_hex(12), + random_hex(6), ) diff --git a/moto/settings.py b/moto/settings.py index 12402dc80..707c61397 100644 --- a/moto/settings.py +++ b/moto/settings.py @@ -1,4 +1,6 @@ import os -TEST_SERVER_MODE = os.environ.get('TEST_SERVER_MODE', '0').lower() == 'true' -INITIAL_NO_AUTH_ACTION_COUNT = float(os.environ.get('INITIAL_NO_AUTH_ACTION_COUNT', float('inf'))) +TEST_SERVER_MODE = os.environ.get("TEST_SERVER_MODE", "0").lower() == "true" +INITIAL_NO_AUTH_ACTION_COUNT = float( + os.environ.get("INITIAL_NO_AUTH_ACTION_COUNT", float("inf")) +) diff --git a/moto/sns/__init__.py b/moto/sns/__init__.py index bd36cb23d..896735b43 100644 --- a/moto/sns/__init__.py +++ b/moto/sns/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import sns_backends from ..core.models import base_decorator, deprecated_base_decorator -sns_backend = sns_backends['us-east-1'] +sns_backend = sns_backends["us-east-1"] mock_sns = base_decorator(sns_backends) mock_sns_deprecated = deprecated_base_decorator(sns_backends) diff --git a/moto/sns/exceptions.py b/moto/sns/exceptions.py index 6d29e7acb..187865220 100644 --- a/moto/sns/exceptions.py +++ b/moto/sns/exceptions.py @@ -6,8 +6,7 @@ class SNSNotFoundError(RESTError): code = 404 def __init__(self, message): - super(SNSNotFoundError, self).__init__( - "NotFound", message) + super(SNSNotFoundError, self).__init__("NotFound", message) class ResourceNotFoundError(RESTError): @@ -15,39 +14,36 @@ class ResourceNotFoundError(RESTError): def __init__(self): super(ResourceNotFoundError, self).__init__( - 'ResourceNotFound', 'Resource does not exist') + "ResourceNotFound", "Resource does not exist" + ) class DuplicateSnsEndpointError(RESTError): code = 400 def __init__(self, message): - super(DuplicateSnsEndpointError, self).__init__( - "DuplicateEndpoint", message) + super(DuplicateSnsEndpointError, self).__init__("DuplicateEndpoint", message) class SnsEndpointDisabled(RESTError): code = 400 def __init__(self, message): - super(SnsEndpointDisabled, self).__init__( - "EndpointDisabled", message) + super(SnsEndpointDisabled, self).__init__("EndpointDisabled", message) class SNSInvalidParameter(RESTError): code = 400 def __init__(self, message): - super(SNSInvalidParameter, self).__init__( - "InvalidParameter", message) + super(SNSInvalidParameter, self).__init__("InvalidParameter", message) class InvalidParameterValue(RESTError): code = 400 def __init__(self, message): - super(InvalidParameterValue, self).__init__( - "InvalidParameterValue", message) + super(InvalidParameterValue, self).__init__("InvalidParameterValue", message) class TagLimitExceededError(RESTError): @@ -55,12 +51,13 @@ class TagLimitExceededError(RESTError): def __init__(self): super(TagLimitExceededError, self).__init__( - 'TagLimitExceeded', 'Could not complete request: tag quota of per resource exceeded') + "TagLimitExceeded", + "Could not complete request: tag quota of per resource exceeded", + ) class InternalError(RESTError): code = 500 def __init__(self, message): - super(InternalError, self).__init__( - "InternalFailure", message) + super(InternalError, self).__init__("InternalFailure", message) diff --git a/moto/sns/models.py b/moto/sns/models.py index 4fcacb495..094fb820f 100644 --- a/moto/sns/models.py +++ b/moto/sns/models.py @@ -12,13 +12,22 @@ from boto3 import Session from moto.compat import OrderedDict from moto.core import BaseBackend, BaseModel -from moto.core.utils import iso_8601_datetime_with_milliseconds, camelcase_to_underscores +from moto.core.utils import ( + iso_8601_datetime_with_milliseconds, + camelcase_to_underscores, +) from moto.sqs import sqs_backends from moto.awslambda import lambda_backends from .exceptions import ( - SNSNotFoundError, DuplicateSnsEndpointError, SnsEndpointDisabled, SNSInvalidParameter, - InvalidParameterValue, InternalError, ResourceNotFoundError, TagLimitExceededError + SNSNotFoundError, + DuplicateSnsEndpointError, + SnsEndpointDisabled, + SNSInvalidParameter, + InvalidParameterValue, + InternalError, + ResourceNotFoundError, + TagLimitExceededError, ) from .utils import make_arn_for_topic, make_arn_for_subscription @@ -28,7 +37,6 @@ MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB class Topic(BaseModel): - def __init__(self, name, sns_backend): self.name = name self.sns_backend = sns_backend @@ -36,27 +44,33 @@ class Topic(BaseModel): self.display_name = "" self.delivery_policy = "" self.effective_delivery_policy = json.dumps(DEFAULT_EFFECTIVE_DELIVERY_POLICY) - self.arn = make_arn_for_topic( - self.account_id, name, sns_backend.region_name) + self.arn = make_arn_for_topic(self.account_id, name, sns_backend.region_name) self.subscriptions_pending = 0 self.subscriptions_confimed = 0 self.subscriptions_deleted = 0 - self._policy_json = self._create_default_topic_policy(sns_backend.region_name, self.account_id, name) + self._policy_json = self._create_default_topic_policy( + sns_backend.region_name, self.account_id, name + ) self._tags = {} def publish(self, message, subject=None, message_attributes=None): message_id = six.text_type(uuid.uuid4()) subscriptions, _ = self.sns_backend.list_subscriptions(self.arn) for subscription in subscriptions: - subscription.publish(message, message_id, subject=subject, - message_attributes=message_attributes) + subscription.publish( + message, + message_id, + subject=subject, + message_attributes=message_attributes, + ) return message_id def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'TopicName': + + if attribute_name == "TopicName": return self.name raise UnformattedGetAttTemplateException() @@ -73,52 +87,47 @@ class Topic(BaseModel): self._policy_json = json.loads(policy) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): sns_backend = sns_backends[region_name] - properties = cloudformation_json['Properties'] + properties = cloudformation_json["Properties"] - topic = sns_backend.create_topic( - properties.get("TopicName") - ) + topic = sns_backend.create_topic(properties.get("TopicName")) for subscription in properties.get("Subscription", []): - sns_backend.subscribe(topic.arn, subscription[ - 'Endpoint'], subscription['Protocol']) + sns_backend.subscribe( + topic.arn, subscription["Endpoint"], subscription["Protocol"] + ) return topic def _create_default_topic_policy(self, region_name, account_id, name): return { "Version": "2008-10-17", "Id": "__default_policy_ID", - "Statement": [{ - "Effect": "Allow", - "Sid": "__default_statement_ID", - "Principal": { - "AWS": "*" - }, - "Action": [ - "SNS:GetTopicAttributes", - "SNS:SetTopicAttributes", - "SNS:AddPermission", - "SNS:RemovePermission", - "SNS:DeleteTopic", - "SNS:Subscribe", - "SNS:ListSubscriptionsByTopic", - "SNS:Publish", - "SNS:Receive", - ], - "Resource": make_arn_for_topic( - self.account_id, name, region_name), - "Condition": { - "StringEquals": { - "AWS:SourceOwner": str(account_id) - } + "Statement": [ + { + "Effect": "Allow", + "Sid": "__default_statement_ID", + "Principal": {"AWS": "*"}, + "Action": [ + "SNS:GetTopicAttributes", + "SNS:SetTopicAttributes", + "SNS:AddPermission", + "SNS:RemovePermission", + "SNS:DeleteTopic", + "SNS:Subscribe", + "SNS:ListSubscriptionsByTopic", + "SNS:Publish", + "SNS:Receive", + ], + "Resource": make_arn_for_topic(self.account_id, name, region_name), + "Condition": {"StringEquals": {"AWS:SourceOwner": str(account_id)}}, } - }] + ], } class Subscription(BaseModel): - def __init__(self, topic, endpoint, protocol): self.topic = topic self.endpoint = endpoint @@ -128,39 +137,54 @@ class Subscription(BaseModel): self._filter_policy = None # filter policy as a dict, not json. self.confirmed = False - def publish(self, message, message_id, subject=None, - message_attributes=None): + def publish(self, message, message_id, subject=None, message_attributes=None): if not self._matches_filter_policy(message_attributes): return - if self.protocol == 'sqs': + if self.protocol == "sqs": queue_name = self.endpoint.split(":")[-1] region = self.endpoint.split(":")[3] - if self.attributes.get('RawMessageDelivery') != 'true': - enveloped_message = json.dumps(self.get_post_data(message, message_id, subject, message_attributes=message_attributes), sort_keys=True, indent=2, separators=(',', ': ')) + if self.attributes.get("RawMessageDelivery") != "true": + enveloped_message = json.dumps( + self.get_post_data( + message, + message_id, + subject, + message_attributes=message_attributes, + ), + sort_keys=True, + indent=2, + separators=(",", ": "), + ) else: enveloped_message = message sqs_backends[region].send_message(queue_name, enveloped_message) - elif self.protocol in ['http', 'https']: + elif self.protocol in ["http", "https"]: post_data = self.get_post_data(message, message_id, subject) - requests.post(self.endpoint, json=post_data, headers={'Content-Type': 'text/plain; charset=UTF-8'}) - elif self.protocol == 'lambda': + requests.post( + self.endpoint, + json=post_data, + headers={"Content-Type": "text/plain; charset=UTF-8"}, + ) + elif self.protocol == "lambda": # TODO: support bad function name # http://docs.aws.amazon.com/general/latest/gr/aws-arns-and-namespaces.html arr = self.endpoint.split(":") region = arr[3] qualifier = None if len(arr) == 7: - assert arr[5] == 'function' + assert arr[5] == "function" function_name = arr[-1] elif len(arr) == 8: - assert arr[5] == 'function' + assert arr[5] == "function" qualifier = arr[-1] function_name = arr[-2] else: assert False - lambda_backends[region].send_sns_message(function_name, message, subject=subject, qualifier=qualifier) + lambda_backends[region].send_sns_message( + function_name, message, subject=subject, qualifier=qualifier + ) def _matches_filter_policy(self, message_attributes): # TODO: support Anything-but matching, prefix matching and @@ -177,10 +201,10 @@ class Subscription(BaseModel): if isinstance(rule, six.string_types): if field not in message_attributes: return False - if message_attributes[field]['Value'] == rule: + if message_attributes[field]["Value"] == rule: return True try: - json_data = json.loads(message_attributes[field]['Value']) + json_data = json.loads(message_attributes[field]["Value"]) if rule in json_data: return True except (ValueError, TypeError): @@ -188,11 +212,13 @@ class Subscription(BaseModel): if isinstance(rule, (six.integer_types, float)): if field not in message_attributes: return False - if message_attributes[field]['Type'] == 'Number': - attribute_values = [message_attributes[field]['Value']] - elif message_attributes[field]['Type'] == 'String.Array': + if message_attributes[field]["Type"] == "Number": + attribute_values = [message_attributes[field]["Value"]] + elif message_attributes[field]["Type"] == "String.Array": try: - attribute_values = json.loads(message_attributes[field]['Value']) + attribute_values = json.loads( + message_attributes[field]["Value"] + ) if not isinstance(attribute_values, list): attribute_values = [attribute_values] except (ValueError, TypeError): @@ -208,29 +234,32 @@ class Subscription(BaseModel): if isinstance(rule, dict): keyword = list(rule.keys())[0] attributes = list(rule.values())[0] - if keyword == 'exists': + if keyword == "exists": if attributes and field in message_attributes: return True elif not attributes and field not in message_attributes: return True return False - return all(_field_match(field, rules, message_attributes) - for field, rules in six.iteritems(self._filter_policy)) + return all( + _field_match(field, rules, message_attributes) + for field, rules in six.iteritems(self._filter_policy) + ) - def get_post_data( - self, message, message_id, subject, message_attributes=None): + def get_post_data(self, message, message_id, subject, message_attributes=None): post_data = { "Type": "Notification", "MessageId": message_id, "TopicArn": self.topic.arn, "Subject": subject or "my subject", "Message": message, - "Timestamp": iso_8601_datetime_with_milliseconds(datetime.datetime.utcnow()), + "Timestamp": iso_8601_datetime_with_milliseconds( + datetime.datetime.utcnow() + ), "SignatureVersion": "1", "Signature": "EXAMPLElDMXvB8r9R83tGoNn0ecwd5UjllzsvSvbItzfaMpN2nk5HVSw7XnOn/49IkxDKz8YrlH2qJXj2iZB0Zo2O71c4qQk1fMUDi3LGpij7RCW7AW9vYYsSqIKRnFS94ilu7NFhUzLiieYr4BKHpdTmdD6c0esKEYBpabxDSc=", "SigningCertURL": "https://sns.us-east-1.amazonaws.com/SimpleNotificationService-f3ecfb7224c7233fe7bb5f59f96de52f.pem", - "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:123456789012:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55" + "UnsubscribeURL": "https://sns.us-east-1.amazonaws.com/?Action=Unsubscribe&SubscriptionArn=arn:aws:sns:us-east-1:123456789012:some-topic:2bcfbf39-05c3-41de-beaa-fcfcc21c8f55", } if message_attributes: post_data["MessageAttributes"] = message_attributes @@ -238,7 +267,6 @@ class Subscription(BaseModel): class PlatformApplication(BaseModel): - def __init__(self, region, name, platform, attributes): self.region = region self.name = name @@ -248,14 +276,11 @@ class PlatformApplication(BaseModel): @property def arn(self): return "arn:aws:sns:{region}:123456789012:app/{platform}/{name}".format( - region=self.region, - platform=self.platform, - name=self.name, + region=self.region, platform=self.platform, name=self.name ) class PlatformEndpoint(BaseModel): - def __init__(self, region, application, custom_user_data, token, attributes): self.region = region self.application = application @@ -269,14 +294,14 @@ class PlatformEndpoint(BaseModel): def __fixup_attributes(self): # When AWS returns the attributes dict, it always contains these two elements, so we need to # automatically ensure they exist as well. - if 'Token' not in self.attributes: - self.attributes['Token'] = self.token - if 'Enabled' not in self.attributes: - self.attributes['Enabled'] = 'True' + if "Token" not in self.attributes: + self.attributes["Token"] = self.token + if "Enabled" not in self.attributes: + self.attributes["Enabled"] = "True" @property def enabled(self): - return json.loads(self.attributes.get('Enabled', 'true').lower()) + return json.loads(self.attributes.get("Enabled", "true").lower()) @property def arn(self): @@ -298,7 +323,6 @@ class PlatformEndpoint(BaseModel): class SNSBackend(BaseBackend): - def __init__(self, region_name): super(SNSBackend, self).__init__() self.topics = OrderedDict() @@ -307,7 +331,16 @@ class SNSBackend(BaseBackend): self.platform_endpoints = {} self.region_name = region_name self.sms_attributes = {} - self.opt_out_numbers = ['+447420500600', '+447420505401', '+447632960543', '+447632960028', '+447700900149', '+447700900550', '+447700900545', '+447700900907'] + self.opt_out_numbers = [ + "+447420500600", + "+447420505401", + "+447632960543", + "+447632960028", + "+447700900149", + "+447700900550", + "+447700900545", + "+447700900907", + ] def reset(self): region_name = self.region_name @@ -318,13 +351,19 @@ class SNSBackend(BaseBackend): self.sms_attributes.update(attrs) def create_topic(self, name, attributes=None, tags=None): - fails_constraints = not re.match(r'^[a-zA-Z0-9_-]{1,256}$', name) + fails_constraints = not re.match(r"^[a-zA-Z0-9_-]{1,256}$", name) if fails_constraints: - raise InvalidParameterValue("Topic names must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long.") + raise InvalidParameterValue( + "Topic names must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long." + ) candidate_topic = Topic(name, self) if attributes: for attribute in attributes: - setattr(candidate_topic, camelcase_to_underscores(attribute), attributes[attribute]) + setattr( + candidate_topic, + camelcase_to_underscores(attribute), + attributes[attribute], + ) if tags: candidate_topic._tags = tags if candidate_topic.arn in self.topics: @@ -337,8 +376,7 @@ class SNSBackend(BaseBackend): if next_token is None or not next_token: next_token = 0 next_token = int(next_token) - values = list(values_map.values())[ - next_token: next_token + DEFAULT_PAGE_SIZE] + values = list(values_map.values())[next_token : next_token + DEFAULT_PAGE_SIZE] if len(values) == DEFAULT_PAGE_SIZE: next_token = next_token + DEFAULT_PAGE_SIZE else: @@ -366,9 +404,9 @@ class SNSBackend(BaseBackend): def get_topic_from_phone_number(self, number): for subscription in self.subscriptions.values(): - if subscription.protocol == 'sms' and subscription.endpoint == number: + if subscription.protocol == "sms" and subscription.endpoint == number: return subscription.topic.arn - raise SNSNotFoundError('Could not find valid subscription') + raise SNSNotFoundError("Could not find valid subscription") def set_topic_attribute(self, topic_arn, attribute_name, attribute_value): topic = self.get_topic(topic_arn) @@ -382,11 +420,11 @@ class SNSBackend(BaseBackend): topic = self.get_topic(topic_arn) subscription = Subscription(topic, endpoint, protocol) attributes = { - 'PendingConfirmation': 'false', - 'Endpoint': endpoint, - 'TopicArn': topic_arn, - 'Protocol': protocol, - 'SubscriptionArn': subscription.arn + "PendingConfirmation": "false", + "Endpoint": endpoint, + "TopicArn": topic_arn, + "Protocol": protocol, + "SubscriptionArn": subscription.arn, } subscription.attributes = attributes self.subscriptions[subscription.arn] = subscription @@ -394,7 +432,11 @@ class SNSBackend(BaseBackend): def _find_subscription(self, topic_arn, endpoint, protocol): for subscription in self.subscriptions.values(): - if subscription.topic.arn == topic_arn and subscription.endpoint == endpoint and subscription.protocol == protocol: + if ( + subscription.topic.arn == topic_arn + and subscription.endpoint == endpoint + and subscription.protocol == protocol + ): return subscription return None @@ -405,7 +447,8 @@ class SNSBackend(BaseBackend): if topic_arn: topic = self.get_topic(topic_arn) filtered = OrderedDict( - [(sub.arn, sub) for sub in self._get_topic_subscriptions(topic)]) + [(sub.arn, sub) for sub in self._get_topic_subscriptions(topic)] + ) return self._get_values_nexttoken(filtered, next_token) else: return self._get_values_nexttoken(self.subscriptions, next_token) @@ -413,15 +456,18 @@ class SNSBackend(BaseBackend): def publish(self, arn, message, subject=None, message_attributes=None): if subject is not None and len(subject) > 100: # Note that the AWS docs around length are wrong: https://github.com/spulec/moto/issues/1503 - raise ValueError('Subject must be less than 100 characters') + raise ValueError("Subject must be less than 100 characters") if len(message) > MAXIMUM_MESSAGE_LENGTH: - raise InvalidParameterValue("An error occurred (InvalidParameter) when calling the Publish operation: Invalid parameter: Message too long") + raise InvalidParameterValue( + "An error occurred (InvalidParameter) when calling the Publish operation: Invalid parameter: Message too long" + ) try: topic = self.get_topic(arn) - message_id = topic.publish(message, subject=subject, - message_attributes=message_attributes) + message_id = topic.publish( + message, subject=subject, message_attributes=message_attributes + ) except SNSNotFoundError: endpoint = self.get_endpoint(arn) message_id = endpoint.publish(message) @@ -436,8 +482,7 @@ class SNSBackend(BaseBackend): try: return self.applications[arn] except KeyError: - raise SNSNotFoundError( - "Application with arn {0} not found".format(arn)) + raise SNSNotFoundError("Application with arn {0} not found".format(arn)) def set_application_attributes(self, arn, attributes): application = self.get_application(arn) @@ -450,18 +495,23 @@ class SNSBackend(BaseBackend): def delete_platform_application(self, platform_arn): self.applications.pop(platform_arn) - def create_platform_endpoint(self, region, application, custom_user_data, token, attributes): - if any(token == endpoint.token for endpoint in self.platform_endpoints.values()): + def create_platform_endpoint( + self, region, application, custom_user_data, token, attributes + ): + if any( + token == endpoint.token for endpoint in self.platform_endpoints.values() + ): raise DuplicateSnsEndpointError("Duplicate endpoint token: %s" % token) platform_endpoint = PlatformEndpoint( - region, application, custom_user_data, token, attributes) + region, application, custom_user_data, token, attributes + ) self.platform_endpoints[platform_endpoint.arn] = platform_endpoint return platform_endpoint def list_endpoints_by_platform_application(self, application_arn): return [ - endpoint for endpoint - in self.platform_endpoints.values() + endpoint + for endpoint in self.platform_endpoints.values() if endpoint.application.arn == application_arn ] @@ -469,8 +519,7 @@ class SNSBackend(BaseBackend): try: return self.platform_endpoints[arn] except KeyError: - raise SNSNotFoundError( - "Endpoint with arn {0} not found".format(arn)) + raise SNSNotFoundError("Endpoint with arn {0} not found".format(arn)) def set_endpoint_attributes(self, arn, attributes): endpoint = self.get_endpoint(arn) @@ -481,8 +530,7 @@ class SNSBackend(BaseBackend): try: del self.platform_endpoints[arn] except KeyError: - raise SNSNotFoundError( - "Endpoint with arn {0} not found".format(arn)) + raise SNSNotFoundError("Endpoint with arn {0} not found".format(arn)) def get_subscription_attributes(self, arn): _subscription = [_ for _ in self.subscriptions.values() if _.arn == arn] @@ -493,8 +541,8 @@ class SNSBackend(BaseBackend): return subscription.attributes def set_subscription_attributes(self, arn, name, value): - if name not in ['RawMessageDelivery', 'DeliveryPolicy', 'FilterPolicy']: - raise SNSInvalidParameter('AttributeName') + if name not in ["RawMessageDelivery", "DeliveryPolicy", "FilterPolicy"]: + raise SNSInvalidParameter("AttributeName") # TODO: should do validation _subscription = [_ for _ in self.subscriptions.values() if _.arn == arn] @@ -504,7 +552,7 @@ class SNSBackend(BaseBackend): subscription.attributes[name] = value - if name == 'FilterPolicy': + if name == "FilterPolicy": filter_policy = json.loads(value) self._validate_filter_policy(filter_policy) subscription._filter_policy = filter_policy @@ -517,7 +565,9 @@ class SNSBackend(BaseBackend): # Even the offical documentation states the total combination of values must not exceed 100, in reality it is 150 # https://docs.aws.amazon.com/sns/latest/dg/sns-subscription-filter-policies.html#subscription-filter-policy-constraints if combinations > 150: - raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Filter policy is too complex") + raise SNSInvalidParameter( + "Invalid parameter: FilterPolicy: Filter policy is too complex" + ) for field, rules in six.iteritems(value): for rule in rules: @@ -534,57 +584,74 @@ class SNSBackend(BaseBackend): if isinstance(rule, dict): keyword = list(rule.keys())[0] attributes = list(rule.values())[0] - if keyword == 'anything-but': + if keyword == "anything-but": continue - elif keyword == 'exists': + elif keyword == "exists": if not isinstance(attributes, bool): - raise SNSInvalidParameter("Invalid parameter: FilterPolicy: exists match pattern must be either true or false.") + raise SNSInvalidParameter( + "Invalid parameter: FilterPolicy: exists match pattern must be either true or false." + ) continue - elif keyword == 'numeric': + elif keyword == "numeric": continue - elif keyword == 'prefix': + elif keyword == "prefix": continue else: - raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Unrecognized match type {type}".format(type=keyword)) + raise SNSInvalidParameter( + "Invalid parameter: FilterPolicy: Unrecognized match type {type}".format( + type=keyword + ) + ) - raise SNSInvalidParameter("Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null") + raise SNSInvalidParameter( + "Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null" + ) def add_permission(self, topic_arn, label, aws_account_ids, action_names): if topic_arn not in self.topics: - raise SNSNotFoundError('Topic does not exist') + raise SNSNotFoundError("Topic does not exist") policy = self.topics[topic_arn]._policy_json - statement = next((statement for statement in policy['Statement'] if statement['Sid'] == label), None) + statement = next( + ( + statement + for statement in policy["Statement"] + if statement["Sid"] == label + ), + None, + ) if statement: - raise SNSInvalidParameter('Statement already exists') + raise SNSInvalidParameter("Statement already exists") if any(action_name not in VALID_POLICY_ACTIONS for action_name in action_names): - raise SNSInvalidParameter('Policy statement action out of service scope!') + raise SNSInvalidParameter("Policy statement action out of service scope!") - principals = ['arn:aws:iam::{}:root'.format(account_id) for account_id in aws_account_ids] - actions = ['SNS:{}'.format(action_name) for action_name in action_names] + principals = [ + "arn:aws:iam::{}:root".format(account_id) for account_id in aws_account_ids + ] + actions = ["SNS:{}".format(action_name) for action_name in action_names] statement = { - 'Sid': label, - 'Effect': 'Allow', - 'Principal': { - 'AWS': principals[0] if len(principals) == 1 else principals - }, - 'Action': actions[0] if len(actions) == 1 else actions, - 'Resource': topic_arn + "Sid": label, + "Effect": "Allow", + "Principal": {"AWS": principals[0] if len(principals) == 1 else principals}, + "Action": actions[0] if len(actions) == 1 else actions, + "Resource": topic_arn, } - self.topics[topic_arn]._policy_json['Statement'].append(statement) + self.topics[topic_arn]._policy_json["Statement"].append(statement) def remove_permission(self, topic_arn, label): if topic_arn not in self.topics: - raise SNSNotFoundError('Topic does not exist') + raise SNSNotFoundError("Topic does not exist") - statements = self.topics[topic_arn]._policy_json['Statement'] - statements = [statement for statement in statements if statement['Sid'] != label] + statements = self.topics[topic_arn]._policy_json["Statement"] + statements = [ + statement for statement in statements if statement["Sid"] != label + ] - self.topics[topic_arn]._policy_json['Statement'] = statements + self.topics[topic_arn]._policy_json["Statement"] = statements def list_tags_for_resource(self, resource_arn): if resource_arn not in self.topics: @@ -613,34 +680,34 @@ class SNSBackend(BaseBackend): sns_backends = {} -for region in Session().get_available_regions('sns'): +for region in Session().get_available_regions("sns"): sns_backends[region] = SNSBackend(region) DEFAULT_EFFECTIVE_DELIVERY_POLICY = { - 'http': { - 'disableSubscriptionOverrides': False, - 'defaultHealthyRetryPolicy': { - 'numNoDelayRetries': 0, - 'numMinDelayRetries': 0, - 'minDelayTarget': 20, - 'maxDelayTarget': 20, - 'numMaxDelayRetries': 0, - 'numRetries': 3, - 'backoffFunction': 'linear' - } + "http": { + "disableSubscriptionOverrides": False, + "defaultHealthyRetryPolicy": { + "numNoDelayRetries": 0, + "numMinDelayRetries": 0, + "minDelayTarget": 20, + "maxDelayTarget": 20, + "numMaxDelayRetries": 0, + "numRetries": 3, + "backoffFunction": "linear", + }, } } VALID_POLICY_ACTIONS = [ - 'GetTopicAttributes', - 'SetTopicAttributes', - 'AddPermission', - 'RemovePermission', - 'DeleteTopic', - 'Subscribe', - 'ListSubscriptionsByTopic', - 'Publish', - 'Receive' + "GetTopicAttributes", + "SetTopicAttributes", + "AddPermission", + "RemovePermission", + "DeleteTopic", + "Subscribe", + "ListSubscriptionsByTopic", + "Publish", + "Receive", ] diff --git a/moto/sns/responses.py b/moto/sns/responses.py index ced2d68a1..d6470199e 100644 --- a/moto/sns/responses.py +++ b/moto/sns/responses.py @@ -11,562 +11,620 @@ from .utils import is_e164 class SNSResponse(BaseResponse): - SMS_ATTR_REGEX = re.compile(r'^attributes\.entry\.(?P\d+)\.(?Pkey|value)$') - OPT_OUT_PHONE_NUMBER_REGEX = re.compile(r'^\+?\d+$') + SMS_ATTR_REGEX = re.compile( + r"^attributes\.entry\.(?P\d+)\.(?Pkey|value)$" + ) + OPT_OUT_PHONE_NUMBER_REGEX = re.compile(r"^\+?\d+$") @property def backend(self): return sns_backends[self.region] - def _error(self, code, message, sender='Sender'): + def _error(self, code, message, sender="Sender"): template = self.response_template(ERROR_RESPONSE) return template.render(code=code, message=message, sender=sender) def _get_attributes(self): - attributes = self._get_list_prefix('Attributes.entry') - return dict( - (attribute['key'], attribute['value']) - for attribute - in attributes - ) + attributes = self._get_list_prefix("Attributes.entry") + return dict((attribute["key"], attribute["value"]) for attribute in attributes) def _get_tags(self): - tags = self._get_list_prefix('Tags.member') - return {tag['key']: tag['value'] for tag in tags} + tags = self._get_list_prefix("Tags.member") + return {tag["key"]: tag["value"] for tag in tags} - def _parse_message_attributes(self, prefix='', value_namespace='Value.'): + def _parse_message_attributes(self, prefix="", value_namespace="Value."): message_attributes = self._get_object_map( - 'MessageAttributes.entry', - name='Name', - value='Value' + "MessageAttributes.entry", name="Name", value="Value" ) # SNS converts some key names before forwarding messages # DataType -> Type, StringValue -> Value, BinaryValue -> Value transformed_message_attributes = {} for name, value in message_attributes.items(): # validation - data_type = value['DataType'] + data_type = value["DataType"] if not data_type: raise InvalidParameterValue( "The message attribute '{0}' must contain non-empty " - "message attribute value.".format(name)) + "message attribute value.".format(name) + ) - data_type_parts = data_type.split('.') - if (len(data_type_parts) > 2 or - data_type_parts[0] not in ['String', 'Binary', 'Number']): + data_type_parts = data_type.split(".") + if len(data_type_parts) > 2 or data_type_parts[0] not in [ + "String", + "Binary", + "Number", + ]: raise InvalidParameterValue( "The message attribute '{0}' has an invalid message " "attribute type, the set of supported type prefixes is " - "Binary, Number, and String.".format(name)) + "Binary, Number, and String.".format(name) + ) transform_value = None - if 'StringValue' in value: - if data_type == 'Number': + if "StringValue" in value: + if data_type == "Number": try: - transform_value = float(value['StringValue']) + transform_value = float(value["StringValue"]) except ValueError: raise InvalidParameterValue( "An error occurred (ParameterValueInvalid) " "when calling the Publish operation: " - "Could not cast message attribute '{0}' value to number.".format(name)) + "Could not cast message attribute '{0}' value to number.".format( + name + ) + ) else: - transform_value = value['StringValue'] - elif 'BinaryValue' in value: - transform_value = value['BinaryValue'] + transform_value = value["StringValue"] + elif "BinaryValue" in value: + transform_value = value["BinaryValue"] if not transform_value: raise InvalidParameterValue( "The message attribute '{0}' must contain non-empty " "message attribute value for message attribute " - "type '{1}'.".format(name, data_type[0])) + "type '{1}'.".format(name, data_type[0]) + ) # transformation transformed_message_attributes[name] = { - 'Type': data_type, 'Value': transform_value + "Type": data_type, + "Value": transform_value, } return transformed_message_attributes def create_topic(self): - name = self._get_param('Name') + name = self._get_param("Name") attributes = self._get_attributes() tags = self._get_tags() topic = self.backend.create_topic(name, attributes, tags) if self.request_json: - return json.dumps({ - 'CreateTopicResponse': { - 'CreateTopicResult': { - 'TopicArn': topic.arn, - }, - 'ResponseMetadata': { - 'RequestId': 'a8dec8b3-33a4-11df-8963-01868b7c937a', + return json.dumps( + { + "CreateTopicResponse": { + "CreateTopicResult": {"TopicArn": topic.arn}, + "ResponseMetadata": { + "RequestId": "a8dec8b3-33a4-11df-8963-01868b7c937a" + }, } } - }) + ) template = self.response_template(CREATE_TOPIC_TEMPLATE) return template.render(topic=topic) def list_topics(self): - next_token = self._get_param('NextToken') + next_token = self._get_param("NextToken") topics, next_token = self.backend.list_topics(next_token=next_token) if self.request_json: - return json.dumps({ - 'ListTopicsResponse': { - 'ListTopicsResult': { - 'Topics': [{'TopicArn': topic.arn} for topic in topics], - 'NextToken': next_token, - } - }, - 'ResponseMetadata': { - 'RequestId': 'a8dec8b3-33a4-11df-8963-01868b7c937a', + return json.dumps( + { + "ListTopicsResponse": { + "ListTopicsResult": { + "Topics": [{"TopicArn": topic.arn} for topic in topics], + "NextToken": next_token, + } + }, + "ResponseMetadata": { + "RequestId": "a8dec8b3-33a4-11df-8963-01868b7c937a" + }, } - }) + ) template = self.response_template(LIST_TOPICS_TEMPLATE) return template.render(topics=topics, next_token=next_token) def delete_topic(self): - topic_arn = self._get_param('TopicArn') + topic_arn = self._get_param("TopicArn") self.backend.delete_topic(topic_arn) if self.request_json: - return json.dumps({ - 'DeleteTopicResponse': { - 'ResponseMetadata': { - 'RequestId': 'a8dec8b3-33a4-11df-8963-01868b7c937a', + return json.dumps( + { + "DeleteTopicResponse": { + "ResponseMetadata": { + "RequestId": "a8dec8b3-33a4-11df-8963-01868b7c937a" + } } } - }) + ) template = self.response_template(DELETE_TOPIC_TEMPLATE) return template.render() def get_topic_attributes(self): - topic_arn = self._get_param('TopicArn') + topic_arn = self._get_param("TopicArn") topic = self.backend.get_topic(topic_arn) if self.request_json: - return json.dumps({ - "GetTopicAttributesResponse": { - "GetTopicAttributesResult": { - "Attributes": { - "Owner": topic.account_id, - "Policy": topic.policy, - "TopicArn": topic.arn, - "DisplayName": topic.display_name, - "SubscriptionsPending": topic.subscriptions_pending, - "SubscriptionsConfirmed": topic.subscriptions_confimed, - "SubscriptionsDeleted": topic.subscriptions_deleted, - "DeliveryPolicy": topic.delivery_policy, - "EffectiveDeliveryPolicy": topic.effective_delivery_policy, - } - }, - "ResponseMetadata": { - "RequestId": "057f074c-33a7-11df-9540-99d0768312d3" + return json.dumps( + { + "GetTopicAttributesResponse": { + "GetTopicAttributesResult": { + "Attributes": { + "Owner": topic.account_id, + "Policy": topic.policy, + "TopicArn": topic.arn, + "DisplayName": topic.display_name, + "SubscriptionsPending": topic.subscriptions_pending, + "SubscriptionsConfirmed": topic.subscriptions_confimed, + "SubscriptionsDeleted": topic.subscriptions_deleted, + "DeliveryPolicy": topic.delivery_policy, + "EffectiveDeliveryPolicy": topic.effective_delivery_policy, + } + }, + "ResponseMetadata": { + "RequestId": "057f074c-33a7-11df-9540-99d0768312d3" + }, } } - }) + ) template = self.response_template(GET_TOPIC_ATTRIBUTES_TEMPLATE) return template.render(topic=topic) def set_topic_attributes(self): - topic_arn = self._get_param('TopicArn') - attribute_name = self._get_param('AttributeName') + topic_arn = self._get_param("TopicArn") + attribute_name = self._get_param("AttributeName") attribute_name = camelcase_to_underscores(attribute_name) - attribute_value = self._get_param('AttributeValue') - self.backend.set_topic_attribute( - topic_arn, attribute_name, attribute_value) + attribute_value = self._get_param("AttributeValue") + self.backend.set_topic_attribute(topic_arn, attribute_name, attribute_value) if self.request_json: - return json.dumps({ - "SetTopicAttributesResponse": { - "ResponseMetadata": { - "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + return json.dumps( + { + "SetTopicAttributesResponse": { + "ResponseMetadata": { + "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + } } } - }) + ) template = self.response_template(SET_TOPIC_ATTRIBUTES_TEMPLATE) return template.render() def subscribe(self): - topic_arn = self._get_param('TopicArn') - endpoint = self._get_param('Endpoint') - protocol = self._get_param('Protocol') + topic_arn = self._get_param("TopicArn") + endpoint = self._get_param("Endpoint") + protocol = self._get_param("Protocol") attributes = self._get_attributes() - if protocol == 'sms' and not is_e164(endpoint): - return self._error( - 'InvalidParameter', - 'Phone number does not meet the E164 format' - ), dict(status=400) + if protocol == "sms" and not is_e164(endpoint): + return ( + self._error( + "InvalidParameter", "Phone number does not meet the E164 format" + ), + dict(status=400), + ) subscription = self.backend.subscribe(topic_arn, endpoint, protocol) if attributes is not None: for attr_name, attr_value in attributes.items(): - self.backend.set_subscription_attributes(subscription.arn, attr_name, attr_value) + self.backend.set_subscription_attributes( + subscription.arn, attr_name, attr_value + ) if self.request_json: - return json.dumps({ - "SubscribeResponse": { - "SubscribeResult": { - "SubscriptionArn": subscription.arn, - }, - "ResponseMetadata": { - "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + return json.dumps( + { + "SubscribeResponse": { + "SubscribeResult": {"SubscriptionArn": subscription.arn}, + "ResponseMetadata": { + "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + }, } } - }) + ) template = self.response_template(SUBSCRIBE_TEMPLATE) return template.render(subscription=subscription) def unsubscribe(self): - subscription_arn = self._get_param('SubscriptionArn') + subscription_arn = self._get_param("SubscriptionArn") self.backend.unsubscribe(subscription_arn) if self.request_json: - return json.dumps({ - "UnsubscribeResponse": { - "ResponseMetadata": { - "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + return json.dumps( + { + "UnsubscribeResponse": { + "ResponseMetadata": { + "RequestId": "a8763b99-33a7-11df-a9b7-05d48da6f042" + } } } - }) + ) template = self.response_template(UNSUBSCRIBE_TEMPLATE) return template.render() def list_subscriptions(self): - next_token = self._get_param('NextToken') + next_token = self._get_param("NextToken") subscriptions, next_token = self.backend.list_subscriptions( - next_token=next_token) + next_token=next_token + ) if self.request_json: - return json.dumps({ - "ListSubscriptionsResponse": { - "ListSubscriptionsResult": { - "Subscriptions": [{ - "TopicArn": subscription.topic.arn, - "Protocol": subscription.protocol, - "SubscriptionArn": subscription.arn, - "Owner": subscription.topic.account_id, - "Endpoint": subscription.endpoint, - } for subscription in subscriptions], - 'NextToken': next_token, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return json.dumps( + { + "ListSubscriptionsResponse": { + "ListSubscriptionsResult": { + "Subscriptions": [ + { + "TopicArn": subscription.topic.arn, + "Protocol": subscription.protocol, + "SubscriptionArn": subscription.arn, + "Owner": subscription.topic.account_id, + "Endpoint": subscription.endpoint, + } + for subscription in subscriptions + ], + "NextToken": next_token, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) template = self.response_template(LIST_SUBSCRIPTIONS_TEMPLATE) - return template.render(subscriptions=subscriptions, - next_token=next_token) + return template.render(subscriptions=subscriptions, next_token=next_token) def list_subscriptions_by_topic(self): - topic_arn = self._get_param('TopicArn') - next_token = self._get_param('NextToken') + topic_arn = self._get_param("TopicArn") + next_token = self._get_param("NextToken") subscriptions, next_token = self.backend.list_subscriptions( - topic_arn, next_token=next_token) + topic_arn, next_token=next_token + ) if self.request_json: - return json.dumps({ - "ListSubscriptionsByTopicResponse": { - "ListSubscriptionsByTopicResult": { - "Subscriptions": [{ - "TopicArn": subscription.topic.arn, - "Protocol": subscription.protocol, - "SubscriptionArn": subscription.arn, - "Owner": subscription.topic.account_id, - "Endpoint": subscription.endpoint, - } for subscription in subscriptions], - 'NextToken': next_token, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return json.dumps( + { + "ListSubscriptionsByTopicResponse": { + "ListSubscriptionsByTopicResult": { + "Subscriptions": [ + { + "TopicArn": subscription.topic.arn, + "Protocol": subscription.protocol, + "SubscriptionArn": subscription.arn, + "Owner": subscription.topic.account_id, + "Endpoint": subscription.endpoint, + } + for subscription in subscriptions + ], + "NextToken": next_token, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) template = self.response_template(LIST_SUBSCRIPTIONS_BY_TOPIC_TEMPLATE) - return template.render(subscriptions=subscriptions, - next_token=next_token) + return template.render(subscriptions=subscriptions, next_token=next_token) def publish(self): - target_arn = self._get_param('TargetArn') - topic_arn = self._get_param('TopicArn') - phone_number = self._get_param('PhoneNumber') - subject = self._get_param('Subject') + target_arn = self._get_param("TargetArn") + topic_arn = self._get_param("TopicArn") + phone_number = self._get_param("PhoneNumber") + subject = self._get_param("Subject") message_attributes = self._parse_message_attributes() if phone_number is not None: # Check phone is correct syntax (e164) if not is_e164(phone_number): - return self._error( - 'InvalidParameter', - 'Phone number does not meet the E164 format' - ), dict(status=400) + return ( + self._error( + "InvalidParameter", "Phone number does not meet the E164 format" + ), + dict(status=400), + ) # Look up topic arn by phone number try: arn = self.backend.get_topic_from_phone_number(phone_number) except SNSNotFoundError: - return self._error( - 'ParameterValueInvalid', - 'Could not find topic associated with phone number' - ), dict(status=400) + return ( + self._error( + "ParameterValueInvalid", + "Could not find topic associated with phone number", + ), + dict(status=400), + ) elif target_arn is not None: arn = target_arn else: arn = topic_arn - message = self._get_param('Message') + message = self._get_param("Message") try: message_id = self.backend.publish( - arn, message, subject=subject, - message_attributes=message_attributes) + arn, message, subject=subject, message_attributes=message_attributes + ) except ValueError as err: - error_response = self._error('InvalidParameter', str(err)) + error_response = self._error("InvalidParameter", str(err)) return error_response, dict(status=400) if self.request_json: - return json.dumps({ - "PublishResponse": { - "PublishResult": { - "MessageId": message_id, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return json.dumps( + { + "PublishResponse": { + "PublishResult": {"MessageId": message_id}, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) template = self.response_template(PUBLISH_TEMPLATE) return template.render(message_id=message_id) def create_platform_application(self): - name = self._get_param('Name') - platform = self._get_param('Platform') + name = self._get_param("Name") + platform = self._get_param("Platform") attributes = self._get_attributes() platform_application = self.backend.create_platform_application( - self.region, name, platform, attributes) + self.region, name, platform, attributes + ) if self.request_json: - return json.dumps({ - "CreatePlatformApplicationResponse": { - "CreatePlatformApplicationResult": { - "PlatformApplicationArn": platform_application.arn, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937b", + return json.dumps( + { + "CreatePlatformApplicationResponse": { + "CreatePlatformApplicationResult": { + "PlatformApplicationArn": platform_application.arn + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937b" + }, } } - }) + ) template = self.response_template(CREATE_PLATFORM_APPLICATION_TEMPLATE) return template.render(platform_application=platform_application) def get_platform_application_attributes(self): - arn = self._get_param('PlatformApplicationArn') + arn = self._get_param("PlatformApplicationArn") application = self.backend.get_application(arn) if self.request_json: - return json.dumps({ - "GetPlatformApplicationAttributesResponse": { - "GetPlatformApplicationAttributesResult": { - "Attributes": application.attributes, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937f", + return json.dumps( + { + "GetPlatformApplicationAttributesResponse": { + "GetPlatformApplicationAttributesResult": { + "Attributes": application.attributes + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937f" + }, } } - }) + ) - template = self.response_template( - GET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) + template = self.response_template(GET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) return template.render(application=application) def set_platform_application_attributes(self): - arn = self._get_param('PlatformApplicationArn') + arn = self._get_param("PlatformApplicationArn") attributes = self._get_attributes() self.backend.set_application_attributes(arn, attributes) if self.request_json: - return json.dumps({ - "SetPlatformApplicationAttributesResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-12df-8963-01868b7c937f", + return json.dumps( + { + "SetPlatformApplicationAttributesResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-12df-8963-01868b7c937f" + } } } - }) + ) - template = self.response_template( - SET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) + template = self.response_template(SET_PLATFORM_APPLICATION_ATTRIBUTES_TEMPLATE) return template.render() def list_platform_applications(self): applications = self.backend.list_platform_applications() if self.request_json: - return json.dumps({ - "ListPlatformApplicationsResponse": { - "ListPlatformApplicationsResult": { - "PlatformApplications": [{ - "PlatformApplicationArn": application.arn, - "attributes": application.attributes, - } for application in applications], - "NextToken": None - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937c", + return json.dumps( + { + "ListPlatformApplicationsResponse": { + "ListPlatformApplicationsResult": { + "PlatformApplications": [ + { + "PlatformApplicationArn": application.arn, + "attributes": application.attributes, + } + for application in applications + ], + "NextToken": None, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937c" + }, } } - }) + ) template = self.response_template(LIST_PLATFORM_APPLICATIONS_TEMPLATE) return template.render(applications=applications) def delete_platform_application(self): - platform_arn = self._get_param('PlatformApplicationArn') + platform_arn = self._get_param("PlatformApplicationArn") self.backend.delete_platform_application(platform_arn) if self.request_json: - return json.dumps({ - "DeletePlatformApplicationResponse": { - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937e", + return json.dumps( + { + "DeletePlatformApplicationResponse": { + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937e" + } } } - }) + ) template = self.response_template(DELETE_PLATFORM_APPLICATION_TEMPLATE) return template.render() def create_platform_endpoint(self): - application_arn = self._get_param('PlatformApplicationArn') + application_arn = self._get_param("PlatformApplicationArn") application = self.backend.get_application(application_arn) - custom_user_data = self._get_param('CustomUserData') - token = self._get_param('Token') + custom_user_data = self._get_param("CustomUserData") + token = self._get_param("Token") attributes = self._get_attributes() platform_endpoint = self.backend.create_platform_endpoint( - self.region, application, custom_user_data, token, attributes) + self.region, application, custom_user_data, token, attributes + ) if self.request_json: - return json.dumps({ - "CreatePlatformEndpointResponse": { - "CreatePlatformEndpointResult": { - "EndpointArn": platform_endpoint.arn, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3779-11df-8963-01868b7c937b", + return json.dumps( + { + "CreatePlatformEndpointResponse": { + "CreatePlatformEndpointResult": { + "EndpointArn": platform_endpoint.arn + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3779-11df-8963-01868b7c937b" + }, } } - }) + ) template = self.response_template(CREATE_PLATFORM_ENDPOINT_TEMPLATE) return template.render(platform_endpoint=platform_endpoint) def list_endpoints_by_platform_application(self): - application_arn = self._get_param('PlatformApplicationArn') - endpoints = self.backend.list_endpoints_by_platform_application( - application_arn) + application_arn = self._get_param("PlatformApplicationArn") + endpoints = self.backend.list_endpoints_by_platform_application(application_arn) if self.request_json: - return json.dumps({ - "ListEndpointsByPlatformApplicationResponse": { - "ListEndpointsByPlatformApplicationResult": { - "Endpoints": [ - { - "Attributes": endpoint.attributes, - "EndpointArn": endpoint.arn, - } for endpoint in endpoints - ], - "NextToken": None - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937a", + return json.dumps( + { + "ListEndpointsByPlatformApplicationResponse": { + "ListEndpointsByPlatformApplicationResult": { + "Endpoints": [ + { + "Attributes": endpoint.attributes, + "EndpointArn": endpoint.arn, + } + for endpoint in endpoints + ], + "NextToken": None, + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937a" + }, } } - }) + ) template = self.response_template( - LIST_ENDPOINTS_BY_PLATFORM_APPLICATION_TEMPLATE) + LIST_ENDPOINTS_BY_PLATFORM_APPLICATION_TEMPLATE + ) return template.render(endpoints=endpoints) def get_endpoint_attributes(self): - arn = self._get_param('EndpointArn') + arn = self._get_param("EndpointArn") endpoint = self.backend.get_endpoint(arn) if self.request_json: - return json.dumps({ - "GetEndpointAttributesResponse": { - "GetEndpointAttributesResult": { - "Attributes": endpoint.attributes, - }, - "ResponseMetadata": { - "RequestId": "384ac68d-3775-11df-8963-01868b7c937f", + return json.dumps( + { + "GetEndpointAttributesResponse": { + "GetEndpointAttributesResult": { + "Attributes": endpoint.attributes + }, + "ResponseMetadata": { + "RequestId": "384ac68d-3775-11df-8963-01868b7c937f" + }, } } - }) + ) template = self.response_template(GET_ENDPOINT_ATTRIBUTES_TEMPLATE) return template.render(endpoint=endpoint) def set_endpoint_attributes(self): - arn = self._get_param('EndpointArn') + arn = self._get_param("EndpointArn") attributes = self._get_attributes() self.backend.set_endpoint_attributes(arn, attributes) if self.request_json: - return json.dumps({ - "SetEndpointAttributesResponse": { - "ResponseMetadata": { - "RequestId": "384bc68d-3775-12df-8963-01868b7c937f", + return json.dumps( + { + "SetEndpointAttributesResponse": { + "ResponseMetadata": { + "RequestId": "384bc68d-3775-12df-8963-01868b7c937f" + } } } - }) + ) template = self.response_template(SET_ENDPOINT_ATTRIBUTES_TEMPLATE) return template.render() def delete_endpoint(self): - arn = self._get_param('EndpointArn') + arn = self._get_param("EndpointArn") self.backend.delete_endpoint(arn) if self.request_json: - return json.dumps({ - "DeleteEndpointResponse": { - "ResponseMetadata": { - "RequestId": "384bc68d-3775-12df-8963-01868b7c937f", + return json.dumps( + { + "DeleteEndpointResponse": { + "ResponseMetadata": { + "RequestId": "384bc68d-3775-12df-8963-01868b7c937f" + } } } - }) + ) template = self.response_template(DELETE_ENDPOINT_TEMPLATE) return template.render() def get_subscription_attributes(self): - arn = self._get_param('SubscriptionArn') + arn = self._get_param("SubscriptionArn") attributes = self.backend.get_subscription_attributes(arn) template = self.response_template(GET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE) return template.render(attributes=attributes) def set_subscription_attributes(self): - arn = self._get_param('SubscriptionArn') - attr_name = self._get_param('AttributeName') - attr_value = self._get_param('AttributeValue') + arn = self._get_param("SubscriptionArn") + attr_name = self._get_param("AttributeName") + attr_value = self._get_param("AttributeValue") self.backend.set_subscription_attributes(arn, attr_name, attr_value) template = self.response_template(SET_SUBSCRIPTION_ATTRIBUTES_TEMPLATE) return template.render() @@ -580,7 +638,7 @@ class SNSResponse(BaseResponse): for key, value in self.querystring.items(): match = self.SMS_ATTR_REGEX.match(key) if match is not None: - temp_dict[match.group('index')][match.group('type')] = value[0] + temp_dict[match.group("index")][match.group("type")] = value[0] # 1: {key:X, value:Y} # to @@ -588,8 +646,8 @@ class SNSResponse(BaseResponse): # All of this, just to take into account when people provide invalid stuff. result = {} for item in temp_dict.values(): - if 'key' in item and 'value' in item: - result[item['key']] = item['value'] + if "key" in item and "value" in item: + result[item["key"]] = item["value"] self.backend.update_sms_attributes(result) @@ -599,11 +657,13 @@ class SNSResponse(BaseResponse): def get_sms_attributes(self): filter_list = set() for key, value in self.querystring.items(): - if key.startswith('attributes.member.1'): + if key.startswith("attributes.member.1"): filter_list.add(value[0]) if len(filter_list) > 0: - result = {k: v for k, v in self.backend.sms_attributes.items() if k in filter_list} + result = { + k: v for k, v in self.backend.sms_attributes.items() if k in filter_list + } else: result = self.backend.sms_attributes @@ -611,24 +671,24 @@ class SNSResponse(BaseResponse): return template.render(attributes=result) def check_if_phone_number_is_opted_out(self): - number = self._get_param('phoneNumber') + number = self._get_param("phoneNumber") if self.OPT_OUT_PHONE_NUMBER_REGEX.match(number) is None: error_response = self._error( - code='InvalidParameter', - message='Invalid parameter: PhoneNumber Reason: input incorrectly formatted' + code="InvalidParameter", + message="Invalid parameter: PhoneNumber Reason: input incorrectly formatted", ) return error_response, dict(status=400) # There should be a nicer way to set if a nubmer has opted out template = self.response_template(CHECK_IF_OPTED_OUT_TEMPLATE) - return template.render(opt_out=str(number.endswith('99')).lower()) + return template.render(opt_out=str(number.endswith("99")).lower()) def list_phone_numbers_opted_out(self): template = self.response_template(LIST_OPTOUT_TEMPLATE) return template.render(opt_outs=self.backend.opt_out_numbers) def opt_in_phone_number(self): - number = self._get_param('phoneNumber') + number = self._get_param("phoneNumber") try: self.backend.opt_out_numbers.remove(number) @@ -639,10 +699,10 @@ class SNSResponse(BaseResponse): return template.render() def add_permission(self): - topic_arn = self._get_param('TopicArn') - label = self._get_param('Label') - aws_account_ids = self._get_multi_param('AWSAccountId.member.') - action_names = self._get_multi_param('ActionName.member.') + topic_arn = self._get_param("TopicArn") + label = self._get_param("Label") + aws_account_ids = self._get_multi_param("AWSAccountId.member.") + action_names = self._get_multi_param("ActionName.member.") self.backend.add_permission(topic_arn, label, aws_account_ids, action_names) @@ -650,8 +710,8 @@ class SNSResponse(BaseResponse): return template.render() def remove_permission(self): - topic_arn = self._get_param('TopicArn') - label = self._get_param('Label') + topic_arn = self._get_param("TopicArn") + label = self._get_param("Label") self.backend.remove_permission(topic_arn, label) @@ -659,10 +719,10 @@ class SNSResponse(BaseResponse): return template.render() def confirm_subscription(self): - arn = self._get_param('TopicArn') + arn = self._get_param("TopicArn") if arn not in self.backend.topics: - error_response = self._error('NotFound', 'Topic does not exist') + error_response = self._error("NotFound", "Topic does not exist") return error_response, dict(status=404) # Once Tokens are stored by the `subscribe` endpoint and distributed @@ -681,10 +741,12 @@ class SNSResponse(BaseResponse): # return error_response, dict(status=400) template = self.response_template(CONFIRM_SUBSCRIPTION_TEMPLATE) - return template.render(sub_arn='{0}:68762e72-e9b1-410a-8b3b-903da69ee1d5'.format(arn)) + return template.render( + sub_arn="{0}:68762e72-e9b1-410a-8b3b-903da69ee1d5".format(arn) + ) def list_tags_for_resource(self): - arn = self._get_param('ResourceArn') + arn = self._get_param("ResourceArn") result = self.backend.list_tags_for_resource(arn) @@ -692,7 +754,7 @@ class SNSResponse(BaseResponse): return template.render(tags=result) def tag_resource(self): - arn = self._get_param('ResourceArn') + arn = self._get_param("ResourceArn") tags = self._get_tags() self.backend.tag_resource(arn, tags) @@ -700,8 +762,8 @@ class SNSResponse(BaseResponse): return self.response_template(TAG_RESOURCE_TEMPLATE).render() def untag_resource(self): - arn = self._get_param('ResourceArn') - tag_keys = self._get_multi_param('TagKeys.member') + arn = self._get_param("ResourceArn") + tag_keys = self._get_multi_param("TagKeys.member") self.backend.untag_resource(arn, tag_keys) diff --git a/moto/sns/urls.py b/moto/sns/urls.py index 518531c55..8c38bb12c 100644 --- a/moto/sns/urls.py +++ b/moto/sns/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import SNSResponse -url_bases = [ - "https?://sns.(.+).amazonaws.com", -] +url_bases = ["https?://sns.(.+).amazonaws.com"] -url_paths = { - '{0}/$': SNSResponse.dispatch, -} +url_paths = {"{0}/$": SNSResponse.dispatch} diff --git a/moto/sns/utils.py b/moto/sns/utils.py index 7793b0f6d..a46b84ac2 100644 --- a/moto/sns/utils.py +++ b/moto/sns/utils.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals import re import uuid -E164_REGEX = re.compile(r'^\+?[1-9]\d{1,14}$') +E164_REGEX = re.compile(r"^\+?[1-9]\d{1,14}$") def make_arn_for_topic(account_id, name, region_name): diff --git a/moto/sqs/__init__.py b/moto/sqs/__init__.py index 46c83133f..b2617b4e4 100644 --- a/moto/sqs/__init__.py +++ b/moto/sqs/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import sqs_backends from ..core.models import base_decorator, deprecated_base_decorator -sqs_backend = sqs_backends['us-east-1'] +sqs_backend = sqs_backends["us-east-1"] mock_sqs = base_decorator(sqs_backends) mock_sqs_deprecated = deprecated_base_decorator(sqs_backends) diff --git a/moto/sqs/exceptions.py b/moto/sqs/exceptions.py index 68c4abaae..01123d777 100644 --- a/moto/sqs/exceptions.py +++ b/moto/sqs/exceptions.py @@ -12,8 +12,7 @@ class ReceiptHandleIsInvalid(RESTError): def __init__(self): super(ReceiptHandleIsInvalid, self).__init__( - 'ReceiptHandleIsInvalid', - 'The input receipt handle is invalid.' + "ReceiptHandleIsInvalid", "The input receipt handle is invalid." ) @@ -29,15 +28,16 @@ class QueueDoesNotExist(RESTError): def __init__(self): super(QueueDoesNotExist, self).__init__( - "QueueDoesNotExist", "The specified queue does not exist for this wsdl version.") + "QueueDoesNotExist", + "The specified queue does not exist for this wsdl version.", + ) class QueueAlreadyExists(RESTError): code = 400 def __init__(self, message): - super(QueueAlreadyExists, self).__init__( - "QueueAlreadyExists", message) + super(QueueAlreadyExists, self).__init__("QueueAlreadyExists", message) class EmptyBatchRequest(RESTError): @@ -45,8 +45,8 @@ class EmptyBatchRequest(RESTError): def __init__(self): super(EmptyBatchRequest, self).__init__( - 'EmptyBatchRequest', - 'There should be at least one SendMessageBatchRequestEntry in the request.' + "EmptyBatchRequest", + "There should be at least one SendMessageBatchRequestEntry in the request.", ) @@ -55,9 +55,9 @@ class InvalidBatchEntryId(RESTError): def __init__(self): super(InvalidBatchEntryId, self).__init__( - 'InvalidBatchEntryId', - 'A batch entry id can only contain alphanumeric characters, ' - 'hyphens and underscores. It can be at most 80 letters long.' + "InvalidBatchEntryId", + "A batch entry id can only contain alphanumeric characters, " + "hyphens and underscores. It can be at most 80 letters long.", ) @@ -66,9 +66,9 @@ class BatchRequestTooLong(RESTError): def __init__(self, length): super(BatchRequestTooLong, self).__init__( - 'BatchRequestTooLong', - 'Batch requests cannot be longer than 262144 bytes. ' - 'You have sent {} bytes.'.format(length) + "BatchRequestTooLong", + "Batch requests cannot be longer than 262144 bytes. " + "You have sent {} bytes.".format(length), ) @@ -77,8 +77,7 @@ class BatchEntryIdsNotDistinct(RESTError): def __init__(self, entry_id): super(BatchEntryIdsNotDistinct, self).__init__( - 'BatchEntryIdsNotDistinct', - 'Id {} repeated.'.format(entry_id) + "BatchEntryIdsNotDistinct", "Id {} repeated.".format(entry_id) ) @@ -87,9 +86,9 @@ class TooManyEntriesInBatchRequest(RESTError): def __init__(self, number): super(TooManyEntriesInBatchRequest, self).__init__( - 'TooManyEntriesInBatchRequest', - 'Maximum number of entries per request are 10. ' - 'You have sent {}.'.format(number) + "TooManyEntriesInBatchRequest", + "Maximum number of entries per request are 10. " + "You have sent {}.".format(number), ) @@ -98,6 +97,5 @@ class InvalidAttributeName(RESTError): def __init__(self, attribute_name): super(InvalidAttributeName, self).__init__( - 'InvalidAttributeName', - 'Unknown Attribute {}.'.format(attribute_name) + "InvalidAttributeName", "Unknown Attribute {}.".format(attribute_name) ) diff --git a/moto/sqs/models.py b/moto/sqs/models.py index 8d02fe529..e975c1bae 100644 --- a/moto/sqs/models.py +++ b/moto/sqs/models.py @@ -12,7 +12,12 @@ import boto.sqs from moto.core.exceptions import RESTError from moto.core import BaseBackend, BaseModel -from moto.core.utils import camelcase_to_underscores, get_random_message_id, unix_time, unix_time_millis +from moto.core.utils import ( + camelcase_to_underscores, + get_random_message_id, + unix_time, + unix_time_millis, +) from .utils import generate_receipt_handle from .exceptions import ( MessageAttributesInvalid, @@ -24,7 +29,7 @@ from .exceptions import ( BatchRequestTooLong, BatchEntryIdsNotDistinct, TooManyEntriesInBatchRequest, - InvalidAttributeName + InvalidAttributeName, ) DEFAULT_ACCOUNT_ID = 123456789012 @@ -32,11 +37,10 @@ DEFAULT_SENDER_ID = "AIDAIT2UOQQY3AUEKVGXU" MAXIMUM_MESSAGE_LENGTH = 262144 # 256 KiB -TRANSPORT_TYPE_ENCODINGS = {'String': b'\x01', 'Binary': b'\x02', 'Number': b'\x01'} +TRANSPORT_TYPE_ENCODINGS = {"String": b"\x01", "Binary": b"\x02", "Number": b"\x01"} class Message(BaseModel): - def __init__(self, message_id, body): self.id = message_id self._body = body @@ -54,7 +58,7 @@ class Message(BaseModel): @property def body_md5(self): md5 = hashlib.md5() - md5.update(self._body.encode('utf-8')) + md5.update(self._body.encode("utf-8")) return md5.hexdigest() @property @@ -68,17 +72,19 @@ class Message(BaseModel): Not yet implemented: List types (https://github.com/aws/aws-sdk-java/blob/7844c64cf248aed889811bf2e871ad6b276a89ca/aws-java-sdk-sqs/src/main/java/com/amazonaws/services/sqs/MessageMD5ChecksumHandler.java#L58k) """ + def utf8(str): if isinstance(str, six.string_types): - return str.encode('utf-8') + return str.encode("utf-8") return str + md5 = hashlib.md5() - struct_format = "!I".encode('ascii') # ensure it's a bytestring + struct_format = "!I".encode("ascii") # ensure it's a bytestring for name in sorted(self.message_attributes.keys()): attr = self.message_attributes[name] - data_type = attr['data_type'] + data_type = attr["data_type"] - encoded = utf8('') + encoded = utf8("") # Each part of each attribute is encoded right after it's # own length is packed into a 4-byte integer # 'timestamp' -> b'\x00\x00\x00\t' @@ -88,18 +94,22 @@ class Message(BaseModel): encoded += struct.pack(struct_format, len(data_type)) + utf8(data_type) encoded += TRANSPORT_TYPE_ENCODINGS[data_type] - if data_type == 'String' or data_type == 'Number': - value = attr['string_value'] - elif data_type == 'Binary': - print(data_type, attr['binary_value'], type(attr['binary_value'])) - value = base64.b64decode(attr['binary_value']) + if data_type == "String" or data_type == "Number": + value = attr["string_value"] + elif data_type == "Binary": + print(data_type, attr["binary_value"], type(attr["binary_value"])) + value = base64.b64decode(attr["binary_value"]) else: - print("Moto hasn't implemented MD5 hashing for {} attributes".format(data_type)) + print( + "Moto hasn't implemented MD5 hashing for {} attributes".format( + data_type + ) + ) # The following should be enough of a clue to users that # they are not, in fact, looking at a correct MD5 while # also following the character and length constraints of # MD5 so as not to break client softwre - return('deadbeefdeadbeefdeadbeefdeadbeef') + return "deadbeefdeadbeefdeadbeefdeadbeef" encoded += struct.pack(struct_format, len(utf8(value))) + utf8(value) @@ -162,24 +172,30 @@ class Message(BaseModel): class Queue(BaseModel): - BASE_ATTRIBUTES = ['ApproximateNumberOfMessages', - 'ApproximateNumberOfMessagesDelayed', - 'ApproximateNumberOfMessagesNotVisible', - 'CreatedTimestamp', - 'DelaySeconds', - 'LastModifiedTimestamp', - 'MaximumMessageSize', - 'MessageRetentionPeriod', - 'QueueArn', - 'ReceiveMessageWaitTimeSeconds', - 'VisibilityTimeout'] - FIFO_ATTRIBUTES = ['FifoQueue', - 'ContentBasedDeduplication'] - KMS_ATTRIBUTES = ['KmsDataKeyReusePeriodSeconds', - 'KmsMasterKeyId'] - ALLOWED_PERMISSIONS = ('*', 'ChangeMessageVisibility', 'DeleteMessage', - 'GetQueueAttributes', 'GetQueueUrl', - 'ReceiveMessage', 'SendMessage') + BASE_ATTRIBUTES = [ + "ApproximateNumberOfMessages", + "ApproximateNumberOfMessagesDelayed", + "ApproximateNumberOfMessagesNotVisible", + "CreatedTimestamp", + "DelaySeconds", + "LastModifiedTimestamp", + "MaximumMessageSize", + "MessageRetentionPeriod", + "QueueArn", + "ReceiveMessageWaitTimeSeconds", + "VisibilityTimeout", + ] + FIFO_ATTRIBUTES = ["FifoQueue", "ContentBasedDeduplication"] + KMS_ATTRIBUTES = ["KmsDataKeyReusePeriodSeconds", "KmsMasterKeyId"] + ALLOWED_PERMISSIONS = ( + "*", + "ChangeMessageVisibility", + "DeleteMessage", + "GetQueueAttributes", + "GetQueueUrl", + "ReceiveMessage", + "SendMessage", + ) def __init__(self, name, region, **kwargs): self.name = name @@ -192,34 +208,36 @@ class Queue(BaseModel): now = unix_time() self.created_timestamp = now - self.queue_arn = 'arn:aws:sqs:{0}:{1}:{2}'.format(self.region, - DEFAULT_ACCOUNT_ID, - self.name) + self.queue_arn = "arn:aws:sqs:{0}:{1}:{2}".format( + self.region, DEFAULT_ACCOUNT_ID, self.name + ) self.dead_letter_queue = None self.lambda_event_source_mappings = {} # default settings for a non fifo queue defaults = { - 'ContentBasedDeduplication': 'false', - 'DelaySeconds': 0, - 'FifoQueue': 'false', - 'KmsDataKeyReusePeriodSeconds': 300, # five minutes - 'KmsMasterKeyId': None, - 'MaximumMessageSize': int(64 << 10), - 'MessageRetentionPeriod': 86400 * 4, # four days - 'Policy': None, - 'ReceiveMessageWaitTimeSeconds': 0, - 'RedrivePolicy': None, - 'VisibilityTimeout': 30, + "ContentBasedDeduplication": "false", + "DelaySeconds": 0, + "FifoQueue": "false", + "KmsDataKeyReusePeriodSeconds": 300, # five minutes + "KmsMasterKeyId": None, + "MaximumMessageSize": int(64 << 10), + "MessageRetentionPeriod": 86400 * 4, # four days + "Policy": None, + "ReceiveMessageWaitTimeSeconds": 0, + "RedrivePolicy": None, + "VisibilityTimeout": 30, } defaults.update(kwargs) self._set_attributes(defaults, now) # Check some conditions - if self.fifo_queue and not self.name.endswith('.fifo'): - raise MessageAttributesInvalid('Queue name must end in .fifo for FIFO queues') + if self.fifo_queue and not self.name.endswith(".fifo"): + raise MessageAttributesInvalid( + "Queue name must end in .fifo for FIFO queues" + ) @property def pending_messages(self): @@ -227,18 +245,25 @@ class Queue(BaseModel): @property def pending_message_groups(self): - return set(message.group_id - for message in self._pending_messages - if message.group_id is not None) + return set( + message.group_id + for message in self._pending_messages + if message.group_id is not None + ) def _set_attributes(self, attributes, now=None): if not now: now = unix_time() - integer_fields = ('DelaySeconds', 'KmsDataKeyreusePeriodSeconds', - 'MaximumMessageSize', 'MessageRetentionPeriod', - 'ReceiveMessageWaitTime', 'VisibilityTimeout') - bool_fields = ('ContentBasedDeduplication', 'FifoQueue') + integer_fields = ( + "DelaySeconds", + "KmsDataKeyreusePeriodSeconds", + "MaximumMessageSize", + "MessageRetentionPeriod", + "ReceiveMessageWaitTime", + "VisibilityTimeout", + ) + bool_fields = ("ContentBasedDeduplication", "FifoQueue") for key, value in six.iteritems(attributes): if key in integer_fields: @@ -246,13 +271,13 @@ class Queue(BaseModel): if key in bool_fields: value = value == "true" - if key == 'RedrivePolicy' and value is not None: + if key == "RedrivePolicy" and value is not None: continue setattr(self, camelcase_to_underscores(key), value) - if attributes.get('RedrivePolicy', None): - self._setup_dlq(attributes['RedrivePolicy']) + if attributes.get("RedrivePolicy", None): + self._setup_dlq(attributes["RedrivePolicy"]) self.last_modified_timestamp = now @@ -262,59 +287,86 @@ class Queue(BaseModel): try: self.redrive_policy = json.loads(policy) except ValueError: - raise RESTError('InvalidParameterValue', 'Redrive policy is not a dict or valid json') + raise RESTError( + "InvalidParameterValue", + "Redrive policy is not a dict or valid json", + ) elif isinstance(policy, dict): self.redrive_policy = policy else: - raise RESTError('InvalidParameterValue', 'Redrive policy is not a dict or valid json') + raise RESTError( + "InvalidParameterValue", "Redrive policy is not a dict or valid json" + ) - if 'deadLetterTargetArn' not in self.redrive_policy: - raise RESTError('InvalidParameterValue', 'Redrive policy does not contain deadLetterTargetArn') - if 'maxReceiveCount' not in self.redrive_policy: - raise RESTError('InvalidParameterValue', 'Redrive policy does not contain maxReceiveCount') + if "deadLetterTargetArn" not in self.redrive_policy: + raise RESTError( + "InvalidParameterValue", + "Redrive policy does not contain deadLetterTargetArn", + ) + if "maxReceiveCount" not in self.redrive_policy: + raise RESTError( + "InvalidParameterValue", + "Redrive policy does not contain maxReceiveCount", + ) # 'maxReceiveCount' is stored as int - self.redrive_policy['maxReceiveCount'] = int(self.redrive_policy['maxReceiveCount']) + self.redrive_policy["maxReceiveCount"] = int( + self.redrive_policy["maxReceiveCount"] + ) for queue in sqs_backends[self.region].queues.values(): - if queue.queue_arn == self.redrive_policy['deadLetterTargetArn']: + if queue.queue_arn == self.redrive_policy["deadLetterTargetArn"]: self.dead_letter_queue = queue if self.fifo_queue and not queue.fifo_queue: - raise RESTError('InvalidParameterCombination', 'Fifo queues cannot use non fifo dead letter queues') + raise RESTError( + "InvalidParameterCombination", + "Fifo queues cannot use non fifo dead letter queues", + ) break else: - raise RESTError('AWS.SimpleQueueService.NonExistentQueue', 'Could not find DLQ for {0}'.format(self.redrive_policy['deadLetterTargetArn'])) + raise RESTError( + "AWS.SimpleQueueService.NonExistentQueue", + "Could not find DLQ for {0}".format( + self.redrive_policy["deadLetterTargetArn"] + ), + ) @classmethod - def create_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] + def create_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] sqs_backend = sqs_backends[region_name] return sqs_backend.create_queue( - name=properties['QueueName'], - region=region_name, - **properties + name=properties["QueueName"], region=region_name, **properties ) @classmethod - def update_from_cloudformation_json(cls, original_resource, new_resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - queue_name = properties['QueueName'] + def update_from_cloudformation_json( + cls, original_resource, new_resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + queue_name = properties["QueueName"] sqs_backend = sqs_backends[region_name] queue = sqs_backend.get_queue(queue_name) - if 'VisibilityTimeout' in properties: - queue.visibility_timeout = int(properties['VisibilityTimeout']) + if "VisibilityTimeout" in properties: + queue.visibility_timeout = int(properties["VisibilityTimeout"]) - if 'ReceiveMessageWaitTimeSeconds' in properties: - queue.receive_message_wait_time_seconds = int(properties['ReceiveMessageWaitTimeSeconds']) + if "ReceiveMessageWaitTimeSeconds" in properties: + queue.receive_message_wait_time_seconds = int( + properties["ReceiveMessageWaitTimeSeconds"] + ) return queue @classmethod - def delete_from_cloudformation_json(cls, resource_name, cloudformation_json, region_name): - properties = cloudformation_json['Properties'] - queue_name = properties['QueueName'] + def delete_from_cloudformation_json( + cls, resource_name, cloudformation_json, region_name + ): + properties = cloudformation_json["Properties"] + queue_name = properties["QueueName"] sqs_backend = sqs_backends[region_name] sqs_backend.delete_queue(queue_name) @@ -353,10 +405,10 @@ class Queue(BaseModel): result[attribute] = attr if self.policy: - result['Policy'] = self.policy + result["Policy"] = self.policy if self.redrive_policy: - result['RedrivePolicy'] = json.dumps(self.redrive_policy) + result["RedrivePolicy"] = json.dumps(self.redrive_policy) for key in result: if isinstance(result[key], bool): @@ -365,15 +417,22 @@ class Queue(BaseModel): return result def url(self, request_url): - return "{0}://{1}/123456789012/{2}".format(request_url.scheme, request_url.netloc, self.name) + return "{0}://{1}/123456789012/{2}".format( + request_url.scheme, request_url.netloc, self.name + ) @property def messages(self): - return [message for message in self._messages if message.visible and not message.delayed] + return [ + message + for message in self._messages + if message.visible and not message.delayed + ] def add_message(self, message): self._messages.append(message) from moto.awslambda import lambda_backends + for arn, esm in self.lambda_event_source_mappings.items(): backend = sqs_backends[self.region] @@ -391,27 +450,28 @@ class Queue(BaseModel): ) result = lambda_backends[self.region].send_sqs_batch( - arn, - messages, - self.queue_arn, + arn, messages, self.queue_arn ) if result: [backend.delete_message(self.name, m.receipt_handle) for m in messages] else: - [backend.change_message_visibility(self.name, m.receipt_handle, 0) for m in messages] + [ + backend.change_message_visibility(self.name, m.receipt_handle, 0) + for m in messages + ] def get_cfn_attribute(self, attribute_name): from moto.cloudformation.exceptions import UnformattedGetAttTemplateException - if attribute_name == 'Arn': + + if attribute_name == "Arn": return self.queue_arn - elif attribute_name == 'QueueName': + elif attribute_name == "QueueName": return self.name raise UnformattedGetAttTemplateException() class SQSBackend(BaseBackend): - def __init__(self, region_name): self.region_name = region_name self.queues = {} @@ -427,7 +487,7 @@ class SQSBackend(BaseBackend): queue = self.queues.get(name) if queue: try: - kwargs.pop('region') + kwargs.pop("region") except KeyError: pass @@ -436,28 +496,26 @@ class SQSBackend(BaseBackend): queue_attributes = queue.attributes new_queue_attributes = new_queue.attributes static_attributes = ( - 'DelaySeconds', - 'MaximumMessageSize', - 'MessageRetentionPeriod', - 'Policy', - 'QueueArn', - 'ReceiveMessageWaitTimeSeconds', - 'RedrivePolicy', - 'VisibilityTimeout', - 'KmsMasterKeyId', - 'KmsDataKeyReusePeriodSeconds', - 'FifoQueue', - 'ContentBasedDeduplication', + "DelaySeconds", + "MaximumMessageSize", + "MessageRetentionPeriod", + "Policy", + "QueueArn", + "ReceiveMessageWaitTimeSeconds", + "RedrivePolicy", + "VisibilityTimeout", + "KmsMasterKeyId", + "KmsDataKeyReusePeriodSeconds", + "FifoQueue", + "ContentBasedDeduplication", ) for key in static_attributes: if queue_attributes.get(key) != new_queue_attributes.get(key): - raise QueueAlreadyExists( - "The specified queue already exists.", - ) + raise QueueAlreadyExists("The specified queue already exists.") else: try: - kwargs.pop('region') + kwargs.pop("region") except KeyError: pass queue = Queue(name, region=self.region_name, **kwargs) @@ -472,9 +530,9 @@ class SQSBackend(BaseBackend): return self.get_queue(queue_name) def list_queues(self, queue_name_prefix): - re_str = '.*' + re_str = ".*" if queue_name_prefix: - re_str = '^{0}.*'.format(queue_name_prefix) + re_str = "^{0}.*".format(queue_name_prefix) prefix_re = re.compile(re_str) qs = [] for name, q in self.queues.items(): @@ -497,17 +555,24 @@ class SQSBackend(BaseBackend): queue = self.get_queue(queue_name) if not len(attribute_names): - attribute_names.append('All') + attribute_names.append("All") - valid_names = ['All'] + queue.BASE_ATTRIBUTES + queue.FIFO_ATTRIBUTES + queue.KMS_ATTRIBUTES - invalid_name = next((name for name in attribute_names if name not in valid_names), None) + valid_names = ( + ["All"] + + queue.BASE_ATTRIBUTES + + queue.FIFO_ATTRIBUTES + + queue.KMS_ATTRIBUTES + ) + invalid_name = next( + (name for name in attribute_names if name not in valid_names), None + ) - if invalid_name or invalid_name == '': + if invalid_name or invalid_name == "": raise InvalidAttributeName(invalid_name) attributes = {} - if 'All' in attribute_names: + if "All" in attribute_names: attributes = queue.attributes else: for name in (name for name in attribute_names if name in queue.attributes): @@ -520,7 +585,15 @@ class SQSBackend(BaseBackend): queue._set_attributes(attributes) return queue - def send_message(self, queue_name, message_body, message_attributes=None, delay_seconds=None, deduplication_id=None, group_id=None): + def send_message( + self, + queue_name, + message_body, + message_attributes=None, + delay_seconds=None, + deduplication_id=None, + group_id=None, + ): queue = self.get_queue(queue_name) @@ -541,9 +614,7 @@ class SQSBackend(BaseBackend): if message_attributes: message.message_attributes = message_attributes - message.mark_sent( - delay_seconds=delay_seconds - ) + message.mark_sent(delay_seconds=delay_seconds) queue.add_message(message) @@ -552,17 +623,25 @@ class SQSBackend(BaseBackend): def send_message_batch(self, queue_name, entries): self.get_queue(queue_name) - if any(not re.match(r'^[\w-]{1,80}$', entry['Id']) for entry in entries.values()): + if any( + not re.match(r"^[\w-]{1,80}$", entry["Id"]) for entry in entries.values() + ): raise InvalidBatchEntryId() body_length = next( - (len(entry['MessageBody']) for entry in entries.values() if len(entry['MessageBody']) > MAXIMUM_MESSAGE_LENGTH), - False + ( + len(entry["MessageBody"]) + for entry in entries.values() + if len(entry["MessageBody"]) > MAXIMUM_MESSAGE_LENGTH + ), + False, ) if body_length: raise BatchRequestTooLong(body_length) - duplicate_id = self._get_first_duplicate_id([entry['Id'] for entry in entries.values()]) + duplicate_id = self._get_first_duplicate_id( + [entry["Id"] for entry in entries.values()] + ) if duplicate_id: raise BatchEntryIdsNotDistinct(duplicate_id) @@ -574,11 +653,11 @@ class SQSBackend(BaseBackend): # Loop through looking for messages message = self.send_message( queue_name, - entry['MessageBody'], - message_attributes=entry['MessageAttributes'], - delay_seconds=entry['DelaySeconds'] + entry["MessageBody"], + message_attributes=entry["MessageAttributes"], + delay_seconds=entry["DelaySeconds"], ) - message.user_id = entry['Id'] + message.user_id = entry["Id"] messages.append(message) @@ -592,7 +671,9 @@ class SQSBackend(BaseBackend): unique_ids.add(id) return None - def receive_messages(self, queue_name, count, wait_seconds_timeout, visibility_timeout): + def receive_messages( + self, queue_name, count, wait_seconds_timeout, visibility_timeout + ): """ Attempt to retrieve visible messages from a queue. @@ -638,13 +719,15 @@ class SQSBackend(BaseBackend): queue.pending_messages.add(message) - if queue.dead_letter_queue is not None and message.approximate_receive_count >= queue.redrive_policy['maxReceiveCount']: + if ( + queue.dead_letter_queue is not None + and message.approximate_receive_count + >= queue.redrive_policy["maxReceiveCount"] + ): messages_to_dlq.append(message) continue - message.mark_received( - visibility_timeout=visibility_timeout - ) + message.mark_received(visibility_timeout=visibility_timeout) result.append(message) if len(result) >= count: break @@ -660,6 +743,7 @@ class SQSBackend(BaseBackend): break import time + time.sleep(0.01) continue @@ -670,7 +754,9 @@ class SQSBackend(BaseBackend): def delete_message(self, queue_name, receipt_handle): queue = self.get_queue(queue_name) - if not any(message.receipt_handle == receipt_handle for message in queue._messages): + if not any( + message.receipt_handle == receipt_handle for message in queue._messages + ): raise ReceiptHandleIsInvalid() new_messages = [] @@ -715,12 +801,12 @@ class SQSBackend(BaseBackend): queue = self.get_queue(queue_name) if actions is None or len(actions) == 0: - raise RESTError('InvalidParameterValue', 'Need at least one Action') + raise RESTError("InvalidParameterValue", "Need at least one Action") if account_ids is None or len(account_ids) == 0: - raise RESTError('InvalidParameterValue', 'Need at least one Account ID') + raise RESTError("InvalidParameterValue", "Need at least one Account ID") if not all([item in Queue.ALLOWED_PERMISSIONS for item in actions]): - raise RESTError('InvalidParameterValue', 'Invalid permissions') + raise RESTError("InvalidParameterValue", "Invalid permissions") queue.permissions[label] = (account_ids, actions) @@ -728,7 +814,9 @@ class SQSBackend(BaseBackend): queue = self.get_queue(queue_name) if label not in queue.permissions: - raise RESTError('InvalidParameterValue', 'Permission doesnt exist for the given label') + raise RESTError( + "InvalidParameterValue", "Permission doesnt exist for the given label" + ) del queue.permissions[label] @@ -736,12 +824,15 @@ class SQSBackend(BaseBackend): queue = self.get_queue(queue_name) if not len(tags): - raise RESTError('MissingParameter', - 'The request must contain the parameter Tags.') + raise RESTError( + "MissingParameter", "The request must contain the parameter Tags." + ) if len(tags) > 50: - raise RESTError('InvalidParameterValue', - 'Too many tags added for queue {}.'.format(queue_name)) + raise RESTError( + "InvalidParameterValue", + "Too many tags added for queue {}.".format(queue_name), + ) queue.tags.update(tags) @@ -749,7 +840,10 @@ class SQSBackend(BaseBackend): queue = self.get_queue(queue_name) if not len(tag_keys): - raise RESTError('InvalidParameterValue', 'Tag keys must be between 1 and 128 characters in length.') + raise RESTError( + "InvalidParameterValue", + "Tag keys must be between 1 and 128 characters in length.", + ) for key in tag_keys: try: diff --git a/moto/sqs/responses.py b/moto/sqs/responses.py index ad46df723..2a8c06eaf 100644 --- a/moto/sqs/responses.py +++ b/moto/sqs/responses.py @@ -12,7 +12,7 @@ from .exceptions import ( MessageNotInflight, ReceiptHandleIsInvalid, EmptyBatchRequest, - InvalidAttributeName + InvalidAttributeName, ) MAXIMUM_VISIBILTY_TIMEOUT = 43200 @@ -22,7 +22,7 @@ DEFAULT_RECEIVED_MESSAGES = 1 class SQSResponse(BaseResponse): - region_regex = re.compile(r'://(.+?)\.queue\.amazonaws\.com') + region_regex = re.compile(r"://(.+?)\.queue\.amazonaws\.com") @property def sqs_backend(self): @@ -30,19 +30,21 @@ class SQSResponse(BaseResponse): @property def attribute(self): - if not hasattr(self, '_attribute'): - self._attribute = self._get_map_prefix('Attribute', key_end='.Name', value_end='.Value') + if not hasattr(self, "_attribute"): + self._attribute = self._get_map_prefix( + "Attribute", key_end=".Name", value_end=".Value" + ) return self._attribute @property def tags(self): - if not hasattr(self, '_tags'): - self._tags = self._get_map_prefix('Tag', key_end='.Key', value_end='.Value') + if not hasattr(self, "_tags"): + self._tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") return self._tags def _get_queue_name(self): try: - queue_name = self.querystring.get('QueueUrl')[0].split("/")[-1] + queue_name = self.querystring.get("QueueUrl")[0].split("/")[-1] except TypeError: # Fallback to reading from the URL queue_name = self.path.split("/")[-1] @@ -80,9 +82,11 @@ class SQSResponse(BaseResponse): queue_name = self._get_param("QueueName") try: - queue = self.sqs_backend.create_queue(queue_name, self.tags, **self.attribute) + queue = self.sqs_backend.create_queue( + queue_name, self.tags, **self.attribute + ) except MessageAttributesInvalid as e: - return self._error('InvalidParameterValue', e.description) + return self._error("InvalidParameterValue", e.description) template = self.response_template(CREATE_QUEUE_RESPONSE) return template.render(queue_url=queue.url(request_url)) @@ -98,14 +102,14 @@ class SQSResponse(BaseResponse): def list_queues(self): request_url = urlparse(self.uri) - queue_name_prefix = self._get_param('QueueNamePrefix') + queue_name_prefix = self._get_param("QueueNamePrefix") queues = self.sqs_backend.list_queues(queue_name_prefix) template = self.response_template(LIST_QUEUES_RESPONSE) return template.render(queues=queues, request_url=request_url) def change_message_visibility(self): queue_name = self._get_queue_name() - receipt_handle = self._get_param('ReceiptHandle') + receipt_handle = self._get_param("ReceiptHandle") try: visibility_timeout = self._get_validated_visibility_timeout() @@ -116,53 +120,64 @@ class SQSResponse(BaseResponse): self.sqs_backend.change_message_visibility( queue_name=queue_name, receipt_handle=receipt_handle, - visibility_timeout=visibility_timeout + visibility_timeout=visibility_timeout, ) except MessageNotInflight as e: - return "Invalid request: {0}".format(e.description), dict(status=e.status_code) + return ( + "Invalid request: {0}".format(e.description), + dict(status=e.status_code), + ) template = self.response_template(CHANGE_MESSAGE_VISIBILITY_RESPONSE) return template.render() def change_message_visibility_batch(self): queue_name = self._get_queue_name() - entries = self._get_list_prefix('ChangeMessageVisibilityBatchRequestEntry') + entries = self._get_list_prefix("ChangeMessageVisibilityBatchRequestEntry") success = [] error = [] for entry in entries: try: - visibility_timeout = self._get_validated_visibility_timeout(entry['visibility_timeout']) + visibility_timeout = self._get_validated_visibility_timeout( + entry["visibility_timeout"] + ) except ValueError: - error.append({ - 'Id': entry['id'], - 'SenderFault': 'true', - 'Code': 'InvalidParameterValue', - 'Message': 'Visibility timeout invalid' - }) + error.append( + { + "Id": entry["id"], + "SenderFault": "true", + "Code": "InvalidParameterValue", + "Message": "Visibility timeout invalid", + } + ) continue try: self.sqs_backend.change_message_visibility( queue_name=queue_name, - receipt_handle=entry['receipt_handle'], - visibility_timeout=visibility_timeout + receipt_handle=entry["receipt_handle"], + visibility_timeout=visibility_timeout, ) - success.append(entry['id']) + success.append(entry["id"]) except ReceiptHandleIsInvalid as e: - error.append({ - 'Id': entry['id'], - 'SenderFault': 'true', - 'Code': 'ReceiptHandleIsInvalid', - 'Message': e.description - }) + error.append( + { + "Id": entry["id"], + "SenderFault": "true", + "Code": "ReceiptHandleIsInvalid", + "Message": e.description, + } + ) except MessageNotInflight as e: - error.append({ - 'Id': entry['id'], - 'SenderFault': 'false', - 'Code': 'AWS.SimpleQueueService.MessageNotInflight', - 'Message': e.description - }) + error.append( + { + "Id": entry["id"], + "SenderFault": "false", + "Code": "AWS.SimpleQueueService.MessageNotInflight", + "Message": e.description, + } + ) template = self.response_template(CHANGE_MESSAGE_VISIBILITY_BATCH_RESPONSE) return template.render(success=success, errors=error) @@ -170,10 +185,10 @@ class SQSResponse(BaseResponse): def get_queue_attributes(self): queue_name = self._get_queue_name() - if self.querystring.get('AttributeNames'): - raise InvalidAttributeName('') + if self.querystring.get("AttributeNames"): + raise InvalidAttributeName("") - attribute_names = self._get_multi_param('AttributeName') + attribute_names = self._get_multi_param("AttributeName") attributes = self.sqs_backend.get_queue_attributes(queue_name, attribute_names) @@ -192,14 +207,17 @@ class SQSResponse(BaseResponse): queue_name = self._get_queue_name() queue = self.sqs_backend.delete_queue(queue_name) if not queue: - return "A queue with name {0} does not exist".format(queue_name), dict(status=404) + return ( + "A queue with name {0} does not exist".format(queue_name), + dict(status=404), + ) template = self.response_template(DELETE_QUEUE_RESPONSE) return template.render(queue=queue) def send_message(self): - message = self._get_param('MessageBody') - delay_seconds = int(self._get_param('DelaySeconds', 0)) + message = self._get_param("MessageBody") + delay_seconds = int(self._get_param("DelaySeconds", 0)) message_group_id = self._get_param("MessageGroupId") message_dedupe_id = self._get_param("MessageDeduplicationId") @@ -219,7 +237,7 @@ class SQSResponse(BaseResponse): message_attributes=message_attributes, delay_seconds=delay_seconds, deduplication_id=message_dedupe_id, - group_id=message_group_id + group_id=message_group_id, ) template = self.response_template(SEND_MESSAGE_RESPONSE) return template.render(message=message, message_attributes=message_attributes) @@ -240,25 +258,30 @@ class SQSResponse(BaseResponse): self.sqs_backend.get_queue(queue_name) - if self.querystring.get('Entries'): + if self.querystring.get("Entries"): raise EmptyBatchRequest() entries = {} for key, value in self.querystring.items(): - match = re.match(r'^SendMessageBatchRequestEntry\.(\d+)\.Id', key) + match = re.match(r"^SendMessageBatchRequestEntry\.(\d+)\.Id", key) if match: index = match.group(1) message_attributes = parse_message_attributes( - self.querystring, base='SendMessageBatchRequestEntry.{}.'.format(index)) + self.querystring, + base="SendMessageBatchRequestEntry.{}.".format(index), + ) entries[index] = { - 'Id': value[0], - 'MessageBody': self.querystring.get( - 'SendMessageBatchRequestEntry.{}.MessageBody'.format(index))[0], - 'DelaySeconds': self.querystring.get( - 'SendMessageBatchRequestEntry.{}.DelaySeconds'.format(index), [None])[0], - 'MessageAttributes': message_attributes + "Id": value[0], + "MessageBody": self.querystring.get( + "SendMessageBatchRequestEntry.{}.MessageBody".format(index) + )[0], + "DelaySeconds": self.querystring.get( + "SendMessageBatchRequestEntry.{}.DelaySeconds".format(index), + [None], + )[0], + "MessageAttributes": message_attributes, } messages = self.sqs_backend.send_message_batch(queue_name, entries) @@ -288,8 +311,9 @@ class SQSResponse(BaseResponse): message_ids = [] for index in range(1, 11): # Loop through looking for messages - receipt_key = 'DeleteMessageBatchRequestEntry.{0}.ReceiptHandle'.format( - index) + receipt_key = "DeleteMessageBatchRequestEntry.{0}.ReceiptHandle".format( + index + ) receipt_handle = self.querystring.get(receipt_key) if not receipt_handle: # Found all messages @@ -297,8 +321,7 @@ class SQSResponse(BaseResponse): self.sqs_backend.delete_message(queue_name, receipt_handle[0]) - message_user_id_key = 'DeleteMessageBatchRequestEntry.{0}.Id'.format( - index) + message_user_id_key = "DeleteMessageBatchRequestEntry.{0}.Id".format(index) message_user_id = self.querystring.get(message_user_id_key)[0] message_ids.append(message_user_id) @@ -327,7 +350,8 @@ class SQSResponse(BaseResponse): "An error occurred (InvalidParameterValue) when calling " "the ReceiveMessage operation: Value %s for parameter " "MaxNumberOfMessages is invalid. Reason: must be between " - "1 and 10, if provided." % message_count) + "1 and 10, if provided." % message_count, + ) try: wait_time = int(self.querystring.get("WaitTimeSeconds")[0]) @@ -340,7 +364,8 @@ class SQSResponse(BaseResponse): "An error occurred (InvalidParameterValue) when calling " "the ReceiveMessage operation: Value %s for parameter " "WaitTimeSeconds is invalid. Reason: must be <= 0 and " - ">= 20 if provided." % wait_time) + ">= 20 if provided." % wait_time, + ) try: visibility_timeout = self._get_validated_visibility_timeout() @@ -350,7 +375,8 @@ class SQSResponse(BaseResponse): return ERROR_MAX_VISIBILITY_TIMEOUT_RESPONSE, dict(status=400) messages = self.sqs_backend.receive_messages( - queue_name, message_count, wait_time, visibility_timeout) + queue_name, message_count, wait_time, visibility_timeout + ) template = self.response_template(RECEIVE_MESSAGE_RESPONSE) return template.render(messages=messages) @@ -365,9 +391,9 @@ class SQSResponse(BaseResponse): def add_permission(self): queue_name = self._get_queue_name() - actions = self._get_multi_param('ActionName') - account_ids = self._get_multi_param('AWSAccountId') - label = self._get_param('Label') + actions = self._get_multi_param("ActionName") + account_ids = self._get_multi_param("AWSAccountId") + label = self._get_param("Label") self.sqs_backend.add_permission(queue_name, actions, account_ids, label) @@ -376,7 +402,7 @@ class SQSResponse(BaseResponse): def remove_permission(self): queue_name = self._get_queue_name() - label = self._get_param('Label') + label = self._get_param("Label") self.sqs_backend.remove_permission(queue_name, label) @@ -385,7 +411,7 @@ class SQSResponse(BaseResponse): def tag_queue(self): queue_name = self._get_queue_name() - tags = self._get_map_prefix('Tag', key_end='.Key', value_end='.Value') + tags = self._get_map_prefix("Tag", key_end=".Key", value_end=".Value") self.sqs_backend.tag_queue(queue_name, tags) @@ -394,7 +420,7 @@ class SQSResponse(BaseResponse): def untag_queue(self): queue_name = self._get_queue_name() - tag_keys = self._get_multi_param('TagKey') + tag_keys = self._get_multi_param("TagKey") self.sqs_backend.untag_queue(queue_name, tag_keys) @@ -672,7 +698,8 @@ ERROR_TOO_LONG_RESPONSE = """ diff --git a/moto/sqs/urls.py b/moto/sqs/urls.py index 9ec014a80..3acf8591a 100644 --- a/moto/sqs/urls.py +++ b/moto/sqs/urls.py @@ -1,13 +1,11 @@ from __future__ import unicode_literals from .responses import SQSResponse -url_bases = [ - "https?://(.*?)(queue|sqs)(.*?).amazonaws.com" -] +url_bases = ["https?://(.*?)(queue|sqs)(.*?).amazonaws.com"] dispatch = SQSResponse().dispatch url_paths = { - '{0}/$': dispatch, - '{0}/(?P\d+)/(?P[a-zA-Z0-9\-_\.]+)': dispatch, + "{0}/$": dispatch, + "{0}/(?P\d+)/(?P[a-zA-Z0-9\-_\.]+)": dispatch, } diff --git a/moto/sqs/utils.py b/moto/sqs/utils.py index 78be5f629..f3b8bbfe8 100644 --- a/moto/sqs/utils.py +++ b/moto/sqs/utils.py @@ -8,46 +8,62 @@ from .exceptions import MessageAttributesInvalid def generate_receipt_handle(): # http://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/ImportantIdentifiers.html#ImportantIdentifiers-receipt-handles length = 185 - return ''.join(random.choice(string.ascii_lowercase) for x in range(length)) + return "".join(random.choice(string.ascii_lowercase) for x in range(length)) -def parse_message_attributes(querystring, base='', value_namespace='Value.'): +def parse_message_attributes(querystring, base="", value_namespace="Value."): message_attributes = {} index = 1 while True: # Loop through looking for message attributes - name_key = base + 'MessageAttribute.{0}.Name'.format(index) + name_key = base + "MessageAttribute.{0}.Name".format(index) name = querystring.get(name_key) if not name: # Found all attributes break - data_type_key = base + \ - 'MessageAttribute.{0}.{1}DataType'.format(index, value_namespace) + data_type_key = base + "MessageAttribute.{0}.{1}DataType".format( + index, value_namespace + ) data_type = querystring.get(data_type_key) if not data_type: raise MessageAttributesInvalid( - "The message attribute '{0}' must contain non-empty message attribute value.".format(name[0])) + "The message attribute '{0}' must contain non-empty message attribute value.".format( + name[0] + ) + ) - data_type_parts = data_type[0].split('.') - if len(data_type_parts) > 2 or data_type_parts[0] not in ['String', 'Binary', 'Number']: + data_type_parts = data_type[0].split(".") + if len(data_type_parts) > 2 or data_type_parts[0] not in [ + "String", + "Binary", + "Number", + ]: raise MessageAttributesInvalid( - "The message attribute '{0}' has an invalid message attribute type, the set of supported type prefixes is Binary, Number, and String.".format(name[0])) + "The message attribute '{0}' has an invalid message attribute type, the set of supported type prefixes is Binary, Number, and String.".format( + name[0] + ) + ) - type_prefix = 'String' - if data_type_parts[0] == 'Binary': - type_prefix = 'Binary' + type_prefix = "String" + if data_type_parts[0] == "Binary": + type_prefix = "Binary" - value_key = base + \ - 'MessageAttribute.{0}.{1}{2}Value'.format( - index, value_namespace, type_prefix) + value_key = base + "MessageAttribute.{0}.{1}{2}Value".format( + index, value_namespace, type_prefix + ) value = querystring.get(value_key) if not value: raise MessageAttributesInvalid( - "The message attribute '{0}' must contain non-empty message attribute value for message attribute type '{1}'.".format(name[0], data_type[0])) + "The message attribute '{0}' must contain non-empty message attribute value for message attribute type '{1}'.".format( + name[0], data_type[0] + ) + ) - message_attributes[name[0]] = {'data_type': data_type[ - 0], type_prefix.lower() + '_value': value[0]} + message_attributes[name[0]] = { + "data_type": data_type[0], + type_prefix.lower() + "_value": value[0], + } index += 1 diff --git a/moto/ssm/__init__.py b/moto/ssm/__init__.py index c42f3b780..18112544a 100644 --- a/moto/ssm/__init__.py +++ b/moto/ssm/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import ssm_backends from ..core.models import base_decorator -ssm_backend = ssm_backends['us-east-1'] +ssm_backend = ssm_backends["us-east-1"] mock_ssm = base_decorator(ssm_backends) diff --git a/moto/ssm/exceptions.py b/moto/ssm/exceptions.py index 4e01843fb..3458fe7d3 100644 --- a/moto/ssm/exceptions.py +++ b/moto/ssm/exceptions.py @@ -6,29 +6,25 @@ class InvalidFilterKey(JsonRESTError): code = 400 def __init__(self, message): - super(InvalidFilterKey, self).__init__( - "InvalidFilterKey", message) + super(InvalidFilterKey, self).__init__("InvalidFilterKey", message) class InvalidFilterOption(JsonRESTError): code = 400 def __init__(self, message): - super(InvalidFilterOption, self).__init__( - "InvalidFilterOption", message) + super(InvalidFilterOption, self).__init__("InvalidFilterOption", message) class InvalidFilterValue(JsonRESTError): code = 400 def __init__(self, message): - super(InvalidFilterValue, self).__init__( - "InvalidFilterValue", message) + super(InvalidFilterValue, self).__init__("InvalidFilterValue", message) class ValidationException(JsonRESTError): code = 400 def __init__(self, message): - super(ValidationException, self).__init__( - "ValidationException", message) + super(ValidationException, self).__init__("ValidationException", message) diff --git a/moto/ssm/models.py b/moto/ssm/models.py index 39bd63ede..0e0f8d353 100644 --- a/moto/ssm/models.py +++ b/moto/ssm/models.py @@ -13,12 +13,26 @@ import time import uuid import itertools -from .exceptions import ValidationException, InvalidFilterValue, InvalidFilterOption, InvalidFilterKey +from .exceptions import ( + ValidationException, + InvalidFilterValue, + InvalidFilterOption, + InvalidFilterKey, +) class Parameter(BaseModel): - def __init__(self, name, value, type, description, allowed_pattern, keyid, - last_modified_date, version): + def __init__( + self, + name, + value, + type, + description, + allowed_pattern, + keyid, + last_modified_date, + version, + ): self.name = name self.type = type self.description = description @@ -27,48 +41,48 @@ class Parameter(BaseModel): self.last_modified_date = last_modified_date self.version = version - if self.type == 'SecureString': + if self.type == "SecureString": if not self.keyid: - self.keyid = 'alias/aws/ssm' + self.keyid = "alias/aws/ssm" self.value = self.encrypt(value) else: self.value = value def encrypt(self, value): - return 'kms:{}:'.format(self.keyid) + value + return "kms:{}:".format(self.keyid) + value def decrypt(self, value): - if self.type != 'SecureString': + if self.type != "SecureString": return value - prefix = 'kms:{}:'.format(self.keyid or 'default') + prefix = "kms:{}:".format(self.keyid or "default") if value.startswith(prefix): - return value[len(prefix):] + return value[len(prefix) :] def response_object(self, decrypt=False): r = { - 'Name': self.name, - 'Type': self.type, - 'Value': self.decrypt(self.value) if decrypt else self.value, - 'Version': self.version, + "Name": self.name, + "Type": self.type, + "Value": self.decrypt(self.value) if decrypt else self.value, + "Version": self.version, } return r def describe_response_object(self, decrypt=False): r = self.response_object(decrypt) - r['LastModifiedDate'] = int(self.last_modified_date) - r['LastModifiedUser'] = 'N/A' + r["LastModifiedDate"] = int(self.last_modified_date) + r["LastModifiedUser"] = "N/A" if self.description: - r['Description'] = self.description + r["Description"] = self.description if self.keyid: - r['KeyId'] = self.keyid + r["KeyId"] = self.keyid if self.allowed_pattern: - r['AllowedPattern'] = self.allowed_pattern + r["AllowedPattern"] = self.allowed_pattern return r @@ -77,11 +91,23 @@ MAX_TIMEOUT_SECONDS = 3600 class Command(BaseModel): - def __init__(self, comment='', document_name='', timeout_seconds=MAX_TIMEOUT_SECONDS, - instance_ids=None, max_concurrency='', max_errors='', - notification_config=None, output_s3_bucket_name='', - output_s3_key_prefix='', output_s3_region='', parameters=None, - service_role_arn='', targets=None, backend_region='us-east-1'): + def __init__( + self, + comment="", + document_name="", + timeout_seconds=MAX_TIMEOUT_SECONDS, + instance_ids=None, + max_concurrency="", + max_errors="", + notification_config=None, + output_s3_bucket_name="", + output_s3_key_prefix="", + output_s3_region="", + parameters=None, + service_role_arn="", + targets=None, + backend_region="us-east-1", + ): if instance_ids is None: instance_ids = [] @@ -99,12 +125,14 @@ class Command(BaseModel): self.completed_count = len(instance_ids) self.target_count = len(instance_ids) self.command_id = str(uuid.uuid4()) - self.status = 'Success' - self.status_details = 'Details placeholder' + self.status = "Success" + self.status_details = "Details placeholder" self.requested_date_time = datetime.datetime.now() self.requested_date_time_iso = self.requested_date_time.isoformat() - expires_after = self.requested_date_time + datetime.timedelta(0, timeout_seconds) + expires_after = self.requested_date_time + datetime.timedelta( + 0, timeout_seconds + ) self.expires_after = expires_after.isoformat() self.comment = comment @@ -122,9 +150,11 @@ class Command(BaseModel): self.backend_region = backend_region # Get instance ids from a cloud formation stack target. - stack_instance_ids = [self.get_instance_ids_by_stack_ids(target['Values']) for - target in self.targets if - target['Key'] == 'tag:aws:cloudformation:stack-name'] + stack_instance_ids = [ + self.get_instance_ids_by_stack_ids(target["Values"]) + for target in self.targets + if target["Key"] == "tag:aws:cloudformation:stack-name" + ] self.instance_ids += list(itertools.chain.from_iterable(stack_instance_ids)) @@ -132,7 +162,8 @@ class Command(BaseModel): self.invocations = [] for instance_id in self.instance_ids: self.invocations.append( - self.invocation_response(instance_id, "aws:runShellScript")) + self.invocation_response(instance_id, "aws:runShellScript") + ) def get_instance_ids_by_stack_ids(self, stack_ids): instance_ids = [] @@ -140,34 +171,36 @@ class Command(BaseModel): for stack_id in stack_ids: stack_resources = cloudformation_backend.list_stack_resources(stack_id) instance_resources = [ - instance.id for instance in stack_resources - if instance.type == "AWS::EC2::Instance"] + instance.id + for instance in stack_resources + if instance.type == "AWS::EC2::Instance" + ] instance_ids.extend(instance_resources) return instance_ids def response_object(self): r = { - 'CommandId': self.command_id, - 'Comment': self.comment, - 'CompletedCount': self.completed_count, - 'DocumentName': self.document_name, - 'ErrorCount': self.error_count, - 'ExpiresAfter': self.expires_after, - 'InstanceIds': self.instance_ids, - 'MaxConcurrency': self.max_concurrency, - 'MaxErrors': self.max_errors, - 'NotificationConfig': self.notification_config, - 'OutputS3Region': self.output_s3_region, - 'OutputS3BucketName': self.output_s3_bucket_name, - 'OutputS3KeyPrefix': self.output_s3_key_prefix, - 'Parameters': self.parameters, - 'RequestedDateTime': self.requested_date_time_iso, - 'ServiceRole': self.service_role_arn, - 'Status': self.status, - 'StatusDetails': self.status_details, - 'TargetCount': self.target_count, - 'Targets': self.targets, + "CommandId": self.command_id, + "Comment": self.comment, + "CompletedCount": self.completed_count, + "DocumentName": self.document_name, + "ErrorCount": self.error_count, + "ExpiresAfter": self.expires_after, + "InstanceIds": self.instance_ids, + "MaxConcurrency": self.max_concurrency, + "MaxErrors": self.max_errors, + "NotificationConfig": self.notification_config, + "OutputS3Region": self.output_s3_region, + "OutputS3BucketName": self.output_s3_bucket_name, + "OutputS3KeyPrefix": self.output_s3_key_prefix, + "Parameters": self.parameters, + "RequestedDateTime": self.requested_date_time_iso, + "ServiceRole": self.service_role_arn, + "Status": self.status, + "StatusDetails": self.status_details, + "TargetCount": self.target_count, + "Targets": self.targets, } return r @@ -181,44 +214,50 @@ class Command(BaseModel): end_time = self.requested_date_time + elapsed_time_delta r = { - 'CommandId': self.command_id, - 'InstanceId': instance_id, - 'Comment': self.comment, - 'DocumentName': self.document_name, - 'PluginName': plugin_name, - 'ResponseCode': 0, - 'ExecutionStartDateTime': self.requested_date_time_iso, - 'ExecutionElapsedTime': elapsed_time_iso, - 'ExecutionEndDateTime': end_time.isoformat(), - 'Status': 'Success', - 'StatusDetails': 'Success', - 'StandardOutputContent': '', - 'StandardOutputUrl': '', - 'StandardErrorContent': '', + "CommandId": self.command_id, + "InstanceId": instance_id, + "Comment": self.comment, + "DocumentName": self.document_name, + "PluginName": plugin_name, + "ResponseCode": 0, + "ExecutionStartDateTime": self.requested_date_time_iso, + "ExecutionElapsedTime": elapsed_time_iso, + "ExecutionEndDateTime": end_time.isoformat(), + "Status": "Success", + "StatusDetails": "Success", + "StandardOutputContent": "", + "StandardOutputUrl": "", + "StandardErrorContent": "", } return r def get_invocation(self, instance_id, plugin_name): invocation = next( - (invocation for invocation in self.invocations - if invocation['InstanceId'] == instance_id), None) + ( + invocation + for invocation in self.invocations + if invocation["InstanceId"] == instance_id + ), + None, + ) if invocation is None: raise RESTError( - 'InvocationDoesNotExist', - 'An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation') + "InvocationDoesNotExist", + "An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation", + ) - if plugin_name is not None and invocation['PluginName'] != plugin_name: - raise RESTError( - 'InvocationDoesNotExist', - 'An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation') + if plugin_name is not None and invocation["PluginName"] != plugin_name: + raise RESTError( + "InvocationDoesNotExist", + "An error occurred (InvocationDoesNotExist) when calling the GetCommandInvocation operation", + ) return invocation class SimpleSystemManagerBackend(BaseBackend): - def __init__(self): self._parameters = {} self._resource_tags = defaultdict(lambda: defaultdict(dict)) @@ -248,7 +287,9 @@ class SimpleSystemManagerBackend(BaseBackend): def describe_parameters(self, filters, parameter_filters): if filters and parameter_filters: - raise ValidationException('You can use either Filters or ParameterFilters in a single request.') + raise ValidationException( + "You can use either Filters or ParameterFilters in a single request." + ) self._validate_parameter_filters(parameter_filters, by_path=False) @@ -260,22 +301,22 @@ class SimpleSystemManagerBackend(BaseBackend): if filters: for filter in filters: - if filter['Key'] == 'Name': + if filter["Key"] == "Name": k = ssm_parameter.name - for v in filter['Values']: + for v in filter["Values"]: if k.startswith(v): result.append(ssm_parameter) break - elif filter['Key'] == 'Type': + elif filter["Key"] == "Type": k = ssm_parameter.type - for v in filter['Values']: + for v in filter["Values"]: if k == v: result.append(ssm_parameter) break - elif filter['Key'] == 'KeyId': + elif filter["Key"] == "KeyId": k = ssm_parameter.keyid if k: - for v in filter['Values']: + for v in filter["Values"]: if k == v: result.append(ssm_parameter) break @@ -287,125 +328,157 @@ class SimpleSystemManagerBackend(BaseBackend): def _validate_parameter_filters(self, parameter_filters, by_path): for index, filter_obj in enumerate(parameter_filters or []): - key = filter_obj['Key'] - values = filter_obj.get('Values', []) + key = filter_obj["Key"] + values = filter_obj.get("Values", []) - if key == 'Path': - option = filter_obj.get('Option', 'OneLevel') + if key == "Path": + option = filter_obj.get("Option", "OneLevel") else: - option = filter_obj.get('Option', 'Equals') + option = filter_obj.get("Option", "Equals") - if not re.match(r'^tag:.+|Name|Type|KeyId|Path|Label|Tier$', key): - self._errors.append(self._format_error( - key='parameterFilters.{index}.member.key'.format(index=(index + 1)), - value=key, - constraint='Member must satisfy regular expression pattern: tag:.+|Name|Type|KeyId|Path|Label|Tier', - )) + if not re.match(r"^tag:.+|Name|Type|KeyId|Path|Label|Tier$", key): + self._errors.append( + self._format_error( + key="parameterFilters.{index}.member.key".format( + index=(index + 1) + ), + value=key, + constraint="Member must satisfy regular expression pattern: tag:.+|Name|Type|KeyId|Path|Label|Tier", + ) + ) if len(key) > 132: - self._errors.append(self._format_error( - key='parameterFilters.{index}.member.key'.format(index=(index + 1)), - value=key, - constraint='Member must have length less than or equal to 132', - )) + self._errors.append( + self._format_error( + key="parameterFilters.{index}.member.key".format( + index=(index + 1) + ), + value=key, + constraint="Member must have length less than or equal to 132", + ) + ) if len(option) > 10: - self._errors.append(self._format_error( - key='parameterFilters.{index}.member.option'.format(index=(index + 1)), - value='over 10 chars', - constraint='Member must have length less than or equal to 10', - )) + self._errors.append( + self._format_error( + key="parameterFilters.{index}.member.option".format( + index=(index + 1) + ), + value="over 10 chars", + constraint="Member must have length less than or equal to 10", + ) + ) if len(values) > 50: - self._errors.append(self._format_error( - key='parameterFilters.{index}.member.values'.format(index=(index + 1)), - value=values, - constraint='Member must have length less than or equal to 50', - )) + self._errors.append( + self._format_error( + key="parameterFilters.{index}.member.values".format( + index=(index + 1) + ), + value=values, + constraint="Member must have length less than or equal to 50", + ) + ) if any(len(value) > 1024 for value in values): - self._errors.append(self._format_error( - key='parameterFilters.{index}.member.values'.format(index=(index + 1)), - value=values, - constraint='[Member must have length less than or equal to 1024, Member must have length greater than or equal to 1]', - )) + self._errors.append( + self._format_error( + key="parameterFilters.{index}.member.values".format( + index=(index + 1) + ), + value=values, + constraint="[Member must have length less than or equal to 1024, Member must have length greater than or equal to 1]", + ) + ) self._raise_errors() filter_keys = [] - for filter_obj in (parameter_filters or []): - key = filter_obj['Key'] - values = filter_obj.get('Values') + for filter_obj in parameter_filters or []: + key = filter_obj["Key"] + values = filter_obj.get("Values") - if key == 'Path': - option = filter_obj.get('Option', 'OneLevel') + if key == "Path": + option = filter_obj.get("Option", "OneLevel") else: - option = filter_obj.get('Option', 'Equals') + option = filter_obj.get("Option", "Equals") - if not by_path and key == 'Label': - raise InvalidFilterKey('The following filter key is not valid: Label. Valid filter keys include: [Path, Name, Type, KeyId, Tier].') + if not by_path and key == "Label": + raise InvalidFilterKey( + "The following filter key is not valid: Label. Valid filter keys include: [Path, Name, Type, KeyId, Tier]." + ) if not values: - raise InvalidFilterValue('The following filter values are missing : null for filter key Name.') + raise InvalidFilterValue( + "The following filter values are missing : null for filter key Name." + ) if key in filter_keys: raise InvalidFilterKey( - 'The following filter is duplicated in the request: Name. A request can contain only one occurrence of a specific filter.' + "The following filter is duplicated in the request: Name. A request can contain only one occurrence of a specific filter." ) - if key == 'Path': - if option not in ['Recursive', 'OneLevel']: + if key == "Path": + if option not in ["Recursive", "OneLevel"]: raise InvalidFilterOption( - 'The following filter option is not valid: {option}. Valid options include: [Recursive, OneLevel].'.format(option=option) + "The following filter option is not valid: {option}. Valid options include: [Recursive, OneLevel].".format( + option=option + ) ) - if any(value.lower().startswith(('/aws', '/ssm')) for value in values): + if any(value.lower().startswith(("/aws", "/ssm")) for value in values): raise ValidationException( 'Filters for common parameters can\'t be prefixed with "aws" or "ssm" (case-insensitive). ' - 'When using global parameters, please specify within a global namespace.' + "When using global parameters, please specify within a global namespace." ) for value in values: - if value.lower().startswith(('/aws', '/ssm')): + if value.lower().startswith(("/aws", "/ssm")): raise ValidationException( 'Filters for common parameters can\'t be prefixed with "aws" or "ssm" (case-insensitive). ' - 'When using global parameters, please specify within a global namespace.' + "When using global parameters, please specify within a global namespace." ) - if ('//' in value or - not value.startswith('/') or - not re.match('^[a-zA-Z0-9_.-/]*$', value)): + if ( + "//" in value + or not value.startswith("/") + or not re.match("^[a-zA-Z0-9_.-/]*$", value) + ): raise ValidationException( 'The parameter doesn\'t meet the parameter name requirements. The parameter name must begin with a forward slash "/". ' - 'It can\'t be prefixed with \"aws\" or \"ssm\" (case-insensitive). ' - 'It must use only letters, numbers, or the following symbols: . (period), - (hyphen), _ (underscore). ' + 'It can\'t be prefixed with "aws" or "ssm" (case-insensitive). ' + "It must use only letters, numbers, or the following symbols: . (period), - (hyphen), _ (underscore). " 'Special characters are not allowed. All sub-paths, if specified, must use the forward slash symbol "/". ' - 'Valid example: /get/parameters2-/by1./path0_.' + "Valid example: /get/parameters2-/by1./path0_." ) - if key == 'Tier': + if key == "Tier": for value in values: - if value not in ['Standard', 'Advanced', 'Intelligent-Tiering']: + if value not in ["Standard", "Advanced", "Intelligent-Tiering"]: raise InvalidFilterOption( - 'The following filter value is not valid: {value}. Valid values include: [Standard, Advanced, Intelligent-Tiering].'.format(value=value) + "The following filter value is not valid: {value}. Valid values include: [Standard, Advanced, Intelligent-Tiering].".format( + value=value + ) ) - if key == 'Type': + if key == "Type": for value in values: - if value not in ['String', 'StringList', 'SecureString']: + if value not in ["String", "StringList", "SecureString"]: raise InvalidFilterOption( - 'The following filter value is not valid: {value}. Valid values include: [String, StringList, SecureString].'.format(value=value) + "The following filter value is not valid: {value}. Valid values include: [String, StringList, SecureString].".format( + value=value + ) ) - if key != 'Path' and option not in ['Equals', 'BeginsWith']: + if key != "Path" and option not in ["Equals", "BeginsWith"]: raise InvalidFilterOption( - 'The following filter option is not valid: {option}. Valid options include: [BeginsWith, Equals].'.format(option=option) + "The following filter option is not valid: {option}. Valid options include: [BeginsWith, Equals].".format( + option=option + ) ) filter_keys.append(key) def _format_error(self, key, value, constraint): return 'Value "{value}" at "{key}" failed to satisfy constraint: {constraint}'.format( - constraint=constraint, - key=key, - value=value, + constraint=constraint, key=key, value=value ) def _raise_errors(self): @@ -415,9 +488,11 @@ class SimpleSystemManagerBackend(BaseBackend): errors = "; ".join(self._errors) self._errors = [] # reset collected errors - raise ValidationException('{count} validation error{plural} detected: {errors}'.format( - count=count, plural=plural, errors=errors, - )) + raise ValidationException( + "{count} validation error{plural} detected: {errors}".format( + count=count, plural=plural, errors=errors + ) + ) def get_all_parameters(self): result = [] @@ -437,11 +512,11 @@ class SimpleSystemManagerBackend(BaseBackend): result = [] # path could be with or without a trailing /. we handle this # difference here. - path = path.rstrip('/') + '/' + path = path.rstrip("/") + "/" for param in self._parameters: - if path != '/' and not param.startswith(path): + if path != "/" and not param.startswith(path): continue - if '/' in param[len(path) + 1:] and not recursive: + if "/" in param[len(path) + 1 :] and not recursive: continue if not self._match_filters(self._parameters[param], filters): continue @@ -451,48 +526,51 @@ class SimpleSystemManagerBackend(BaseBackend): def _match_filters(self, parameter, filters=None): """Return True if the given parameter matches all the filters""" - for filter_obj in (filters or []): - key = filter_obj['Key'] - values = filter_obj.get('Values', []) + for filter_obj in filters or []: + key = filter_obj["Key"] + values = filter_obj.get("Values", []) - if key == 'Path': - option = filter_obj.get('Option', 'OneLevel') + if key == "Path": + option = filter_obj.get("Option", "OneLevel") else: - option = filter_obj.get('Option', 'Equals') + option = filter_obj.get("Option", "Equals") what = None - if key == 'KeyId': + if key == "KeyId": what = parameter.keyid - elif key == 'Name': - what = '/' + parameter.name.lstrip('/') - values = ['/' + value.lstrip('/') for value in values] - elif key == 'Path': - what = '/' + parameter.name.lstrip('/') - values = ['/' + value.strip('/') for value in values] - elif key == 'Type': + elif key == "Name": + what = "/" + parameter.name.lstrip("/") + values = ["/" + value.lstrip("/") for value in values] + elif key == "Path": + what = "/" + parameter.name.lstrip("/") + values = ["/" + value.strip("/") for value in values] + elif key == "Type": what = parameter.type if what is None: return False - elif (option == 'BeginsWith' and - not any(what.startswith(value) for value in values)): + elif option == "BeginsWith" and not any( + what.startswith(value) for value in values + ): return False - elif (option == 'Equals' and - not any(what == value for value in values)): + elif option == "Equals" and not any(what == value for value in values): return False - elif option == 'OneLevel': - if any(value == '/' and len(what.split('/')) == 2 for value in values): + elif option == "OneLevel": + if any(value == "/" and len(what.split("/")) == 2 for value in values): continue - elif any(value != '/' and - what.startswith(value + '/') and - len(what.split('/')) - 1 == len(value.split('/')) for value in values): + elif any( + value != "/" + and what.startswith(value + "/") + and len(what.split("/")) - 1 == len(value.split("/")) + for value in values + ): continue else: return False - elif option == 'Recursive': - if any(value == '/' for value in values): + elif option == "Recursive": + if any(value == "/" for value in values): continue - elif any(what.startswith(value + '/') for value in values): + elif any(what.startswith(value + "/") for value in values): continue else: return False @@ -504,8 +582,9 @@ class SimpleSystemManagerBackend(BaseBackend): return self._parameters[name] return None - def put_parameter(self, name, description, value, type, allowed_pattern, - keyid, overwrite): + def put_parameter( + self, name, description, value, type, allowed_pattern, keyid, overwrite + ): previous_parameter = self._parameters.get(name) version = 1 @@ -516,8 +595,16 @@ class SimpleSystemManagerBackend(BaseBackend): return last_modified_date = time.time() - self._parameters[name] = Parameter(name, value, type, description, - allowed_pattern, keyid, last_modified_date, version) + self._parameters[name] = Parameter( + name, + value, + type, + description, + allowed_pattern, + keyid, + last_modified_date, + version, + ) return version def add_tags_to_resource(self, resource_type, resource_id, tags): @@ -535,29 +622,31 @@ class SimpleSystemManagerBackend(BaseBackend): def send_command(self, **kwargs): command = Command( - comment=kwargs.get('Comment', ''), - document_name=kwargs.get('DocumentName'), - timeout_seconds=kwargs.get('TimeoutSeconds', 3600), - instance_ids=kwargs.get('InstanceIds', []), - max_concurrency=kwargs.get('MaxConcurrency', '50'), - max_errors=kwargs.get('MaxErrors', '0'), - notification_config=kwargs.get('NotificationConfig', { - 'NotificationArn': 'string', - 'NotificationEvents': ['Success'], - 'NotificationType': 'Command' - }), - output_s3_bucket_name=kwargs.get('OutputS3BucketName', ''), - output_s3_key_prefix=kwargs.get('OutputS3KeyPrefix', ''), - output_s3_region=kwargs.get('OutputS3Region', ''), - parameters=kwargs.get('Parameters', {}), - service_role_arn=kwargs.get('ServiceRoleArn', ''), - targets=kwargs.get('Targets', []), - backend_region=self._region) + comment=kwargs.get("Comment", ""), + document_name=kwargs.get("DocumentName"), + timeout_seconds=kwargs.get("TimeoutSeconds", 3600), + instance_ids=kwargs.get("InstanceIds", []), + max_concurrency=kwargs.get("MaxConcurrency", "50"), + max_errors=kwargs.get("MaxErrors", "0"), + notification_config=kwargs.get( + "NotificationConfig", + { + "NotificationArn": "string", + "NotificationEvents": ["Success"], + "NotificationType": "Command", + }, + ), + output_s3_bucket_name=kwargs.get("OutputS3BucketName", ""), + output_s3_key_prefix=kwargs.get("OutputS3KeyPrefix", ""), + output_s3_region=kwargs.get("OutputS3Region", ""), + parameters=kwargs.get("Parameters", {}), + service_role_arn=kwargs.get("ServiceRoleArn", ""), + targets=kwargs.get("Targets", []), + backend_region=self._region, + ) self._commands.append(command) - return { - 'Command': command.response_object() - } + return {"Command": command.response_object()} def list_commands(self, **kwargs): """ @@ -565,39 +654,38 @@ class SimpleSystemManagerBackend(BaseBackend): """ commands = self._commands - command_id = kwargs.get('CommandId', None) + command_id = kwargs.get("CommandId", None) if command_id: commands = [self.get_command_by_id(command_id)] - instance_id = kwargs.get('InstanceId', None) + instance_id = kwargs.get("InstanceId", None) if instance_id: commands = self.get_commands_by_instance_id(instance_id) - return { - 'Commands': [command.response_object() for command in commands] - } + return {"Commands": [command.response_object() for command in commands]} def get_command_by_id(self, id): command = next( - (command for command in self._commands if command.command_id == id), None) + (command for command in self._commands if command.command_id == id), None + ) if command is None: - raise RESTError('InvalidCommandId', 'Invalid command id.') + raise RESTError("InvalidCommandId", "Invalid command id.") return command def get_commands_by_instance_id(self, instance_id): return [ - command for command in self._commands - if instance_id in command.instance_ids] + command for command in self._commands if instance_id in command.instance_ids + ] def get_command_invocation(self, **kwargs): """ https://docs.aws.amazon.com/systems-manager/latest/APIReference/API_GetCommandInvocation.html """ - command_id = kwargs.get('CommandId') - instance_id = kwargs.get('InstanceId') - plugin_name = kwargs.get('PluginName', None) + command_id = kwargs.get("CommandId") + instance_id = kwargs.get("InstanceId") + plugin_name = kwargs.get("PluginName", None) command = self.get_command_by_id(command_id) return command.get_invocation(instance_id, plugin_name) diff --git a/moto/ssm/responses.py b/moto/ssm/responses.py index 27a5f8e35..0bb034428 100644 --- a/moto/ssm/responses.py +++ b/moto/ssm/responses.py @@ -6,7 +6,6 @@ from .models import ssm_backends class SimpleSystemManagerResponse(BaseResponse): - @property def ssm_backend(self): return ssm_backends[self.region] @@ -22,171 +21,151 @@ class SimpleSystemManagerResponse(BaseResponse): return self.request_params.get(param, default) def delete_parameter(self): - name = self._get_param('Name') + name = self._get_param("Name") self.ssm_backend.delete_parameter(name) return json.dumps({}) def delete_parameters(self): - names = self._get_param('Names') + names = self._get_param("Names") result = self.ssm_backend.delete_parameters(names) - response = { - 'DeletedParameters': [], - 'InvalidParameters': [] - } + response = {"DeletedParameters": [], "InvalidParameters": []} for name in names: if name in result: - response['DeletedParameters'].append(name) + response["DeletedParameters"].append(name) else: - response['InvalidParameters'].append(name) + response["InvalidParameters"].append(name) return json.dumps(response) def get_parameter(self): - name = self._get_param('Name') - with_decryption = self._get_param('WithDecryption') + name = self._get_param("Name") + with_decryption = self._get_param("WithDecryption") result = self.ssm_backend.get_parameter(name, with_decryption) if result is None: error = { - '__type': 'ParameterNotFound', - 'message': 'Parameter {0} not found.'.format(name) + "__type": "ParameterNotFound", + "message": "Parameter {0} not found.".format(name), } return json.dumps(error), dict(status=400) - response = { - 'Parameter': result.response_object(with_decryption) - } + response = {"Parameter": result.response_object(with_decryption)} return json.dumps(response) def get_parameters(self): - names = self._get_param('Names') - with_decryption = self._get_param('WithDecryption') + names = self._get_param("Names") + with_decryption = self._get_param("WithDecryption") result = self.ssm_backend.get_parameters(names, with_decryption) - response = { - 'Parameters': [], - 'InvalidParameters': [], - } + response = {"Parameters": [], "InvalidParameters": []} for parameter in result: param_data = parameter.response_object(with_decryption) - response['Parameters'].append(param_data) + response["Parameters"].append(param_data) param_names = [param.name for param in result] for name in names: if name not in param_names: - response['InvalidParameters'].append(name) + response["InvalidParameters"].append(name) return json.dumps(response) def get_parameters_by_path(self): - path = self._get_param('Path') - with_decryption = self._get_param('WithDecryption') - recursive = self._get_param('Recursive', False) - filters = self._get_param('ParameterFilters') + path = self._get_param("Path") + with_decryption = self._get_param("WithDecryption") + recursive = self._get_param("Recursive", False) + filters = self._get_param("ParameterFilters") result = self.ssm_backend.get_parameters_by_path( path, with_decryption, recursive, filters ) - response = { - 'Parameters': [], - } + response = {"Parameters": []} for parameter in result: param_data = parameter.response_object(with_decryption) - response['Parameters'].append(param_data) + response["Parameters"].append(param_data) return json.dumps(response) def describe_parameters(self): page_size = 10 - filters = self._get_param('Filters') - parameter_filters = self._get_param('ParameterFilters') - token = self._get_param('NextToken') - if hasattr(token, 'strip'): + filters = self._get_param("Filters") + parameter_filters = self._get_param("ParameterFilters") + token = self._get_param("NextToken") + if hasattr(token, "strip"): token = token.strip() if not token: - token = '0' + token = "0" token = int(token) - result = self.ssm_backend.describe_parameters( - filters, parameter_filters - ) + result = self.ssm_backend.describe_parameters(filters, parameter_filters) - response = { - 'Parameters': [], - } + response = {"Parameters": []} end = token + page_size for parameter in result[token:]: - response['Parameters'].append(parameter.describe_response_object(False)) + response["Parameters"].append(parameter.describe_response_object(False)) token = token + 1 - if len(response['Parameters']) == page_size: - response['NextToken'] = str(end) + if len(response["Parameters"]) == page_size: + response["NextToken"] = str(end) break return json.dumps(response) def put_parameter(self): - name = self._get_param('Name') - description = self._get_param('Description') - value = self._get_param('Value') - type_ = self._get_param('Type') - allowed_pattern = self._get_param('AllowedPattern') - keyid = self._get_param('KeyId') - overwrite = self._get_param('Overwrite', False) + name = self._get_param("Name") + description = self._get_param("Description") + value = self._get_param("Value") + type_ = self._get_param("Type") + allowed_pattern = self._get_param("AllowedPattern") + keyid = self._get_param("KeyId") + overwrite = self._get_param("Overwrite", False) result = self.ssm_backend.put_parameter( - name, description, value, type_, allowed_pattern, keyid, overwrite) + name, description, value, type_, allowed_pattern, keyid, overwrite + ) if result is None: error = { - '__type': 'ParameterAlreadyExists', - 'message': 'Parameter {0} already exists.'.format(name) + "__type": "ParameterAlreadyExists", + "message": "Parameter {0} already exists.".format(name), } return json.dumps(error), dict(status=400) - response = {'Version': result} + response = {"Version": result} return json.dumps(response) def add_tags_to_resource(self): - resource_id = self._get_param('ResourceId') - resource_type = self._get_param('ResourceType') - tags = {t['Key']: t['Value'] for t in self._get_param('Tags')} - self.ssm_backend.add_tags_to_resource( - resource_id, resource_type, tags) + resource_id = self._get_param("ResourceId") + resource_type = self._get_param("ResourceType") + tags = {t["Key"]: t["Value"] for t in self._get_param("Tags")} + self.ssm_backend.add_tags_to_resource(resource_id, resource_type, tags) return json.dumps({}) def remove_tags_from_resource(self): - resource_id = self._get_param('ResourceId') - resource_type = self._get_param('ResourceType') - keys = self._get_param('TagKeys') - self.ssm_backend.remove_tags_from_resource( - resource_id, resource_type, keys) + resource_id = self._get_param("ResourceId") + resource_type = self._get_param("ResourceType") + keys = self._get_param("TagKeys") + self.ssm_backend.remove_tags_from_resource(resource_id, resource_type, keys) return json.dumps({}) def list_tags_for_resource(self): - resource_id = self._get_param('ResourceId') - resource_type = self._get_param('ResourceType') - tags = self.ssm_backend.list_tags_for_resource( - resource_id, resource_type) - tag_list = [{'Key': k, 'Value': v} for (k, v) in tags.items()] - response = {'TagList': tag_list} + resource_id = self._get_param("ResourceId") + resource_type = self._get_param("ResourceType") + tags = self.ssm_backend.list_tags_for_resource(resource_id, resource_type) + tag_list = [{"Key": k, "Value": v} for (k, v) in tags.items()] + response = {"TagList": tag_list} return json.dumps(response) def send_command(self): - return json.dumps( - self.ssm_backend.send_command(**self.request_params) - ) + return json.dumps(self.ssm_backend.send_command(**self.request_params)) def list_commands(self): - return json.dumps( - self.ssm_backend.list_commands(**self.request_params) - ) + return json.dumps(self.ssm_backend.list_commands(**self.request_params)) def get_command_invocation(self): return json.dumps( diff --git a/moto/ssm/urls.py b/moto/ssm/urls.py index 9ac327325..bd6706dfa 100644 --- a/moto/ssm/urls.py +++ b/moto/ssm/urls.py @@ -1,11 +1,6 @@ from __future__ import unicode_literals from .responses import SimpleSystemManagerResponse -url_bases = [ - "https?://ssm.(.+).amazonaws.com", - "https?://ssm.(.+).amazonaws.com.cn", -] +url_bases = ["https?://ssm.(.+).amazonaws.com", "https?://ssm.(.+).amazonaws.com.cn"] -url_paths = { - '{0}/$': SimpleSystemManagerResponse.dispatch, -} +url_paths = {"{0}/$": SimpleSystemManagerResponse.dispatch} diff --git a/moto/stepfunctions/__init__.py b/moto/stepfunctions/__init__.py index dc2b0ba13..6dd50c9dc 100644 --- a/moto/stepfunctions/__init__.py +++ b/moto/stepfunctions/__init__.py @@ -2,5 +2,5 @@ from __future__ import unicode_literals from .models import stepfunction_backends from ..core.models import base_decorator -stepfunction_backend = stepfunction_backends['us-east-1'] +stepfunction_backend = stepfunction_backends["us-east-1"] mock_stepfunctions = base_decorator(stepfunction_backends) diff --git a/moto/stepfunctions/exceptions.py b/moto/stepfunctions/exceptions.py index 8af4686c7..704e4ea83 100644 --- a/moto/stepfunctions/exceptions.py +++ b/moto/stepfunctions/exceptions.py @@ -12,24 +12,27 @@ class AWSError(Exception): self.status = status if status is not None else self.STATUS def response(self): - return json.dumps({'__type': self.type, 'message': self.message}), dict(status=self.status) + return ( + json.dumps({"__type": self.type, "message": self.message}), + dict(status=self.status), + ) class ExecutionDoesNotExist(AWSError): - TYPE = 'ExecutionDoesNotExist' + TYPE = "ExecutionDoesNotExist" STATUS = 400 class InvalidArn(AWSError): - TYPE = 'InvalidArn' + TYPE = "InvalidArn" STATUS = 400 class InvalidName(AWSError): - TYPE = 'InvalidName' + TYPE = "InvalidName" STATUS = 400 class StateMachineDoesNotExist(AWSError): - TYPE = 'StateMachineDoesNotExist' + TYPE = "StateMachineDoesNotExist" STATUS = 400 diff --git a/moto/stepfunctions/models.py b/moto/stepfunctions/models.py index fedcb8e77..665f3b777 100644 --- a/moto/stepfunctions/models.py +++ b/moto/stepfunctions/models.py @@ -5,10 +5,15 @@ from moto.core import BaseBackend from moto.core.utils import iso_8601_datetime_without_milliseconds from moto.sts.models import ACCOUNT_ID from uuid import uuid4 -from .exceptions import ExecutionDoesNotExist, InvalidArn, InvalidName, StateMachineDoesNotExist +from .exceptions import ( + ExecutionDoesNotExist, + InvalidArn, + InvalidName, + StateMachineDoesNotExist, +) -class StateMachine(): +class StateMachine: def __init__(self, arn, name, definition, roleArn, tags=None): self.creation_date = iso_8601_datetime_without_milliseconds(datetime.now()) self.arn = arn @@ -18,19 +23,28 @@ class StateMachine(): self.tags = tags -class Execution(): - def __init__(self, region_name, account_id, state_machine_name, execution_name, state_machine_arn): - execution_arn = 'arn:aws:states:{}:{}:execution:{}:{}' - execution_arn = execution_arn.format(region_name, account_id, state_machine_name, execution_name) +class Execution: + def __init__( + self, + region_name, + account_id, + state_machine_name, + execution_name, + state_machine_arn, + ): + execution_arn = "arn:aws:states:{}:{}:execution:{}:{}" + execution_arn = execution_arn.format( + region_name, account_id, state_machine_name, execution_name + ) self.execution_arn = execution_arn self.name = execution_name self.start_date = iso_8601_datetime_without_milliseconds(datetime.now()) self.state_machine_arn = state_machine_arn - self.status = 'RUNNING' + self.status = "RUNNING" self.stop_date = None def stop(self): - self.status = 'SUCCEEDED' + self.status = "SUCCEEDED" self.stop_date = iso_8601_datetime_without_milliseconds(datetime.now()) @@ -42,26 +56,108 @@ class StepFunctionBackend(BaseBackend): # brackets < > { } [ ] # wildcard characters ? * # special characters " # % \ ^ | ~ ` $ & , ; : / - invalid_chars_for_name = [' ', '{', '}', '[', ']', '<', '>', - '?', '*', - '"', '#', '%', '\\', '^', '|', '~', '`', '$', '&', ',', ';', ':', '/'] + invalid_chars_for_name = [ + " ", + "{", + "}", + "[", + "]", + "<", + ">", + "?", + "*", + '"', + "#", + "%", + "\\", + "^", + "|", + "~", + "`", + "$", + "&", + ",", + ";", + ":", + "/", + ] # control characters (U+0000-001F , U+007F-009F ) - invalid_unicodes_for_name = [u'\u0000', u'\u0001', u'\u0002', u'\u0003', u'\u0004', - u'\u0005', u'\u0006', u'\u0007', u'\u0008', u'\u0009', - u'\u000A', u'\u000B', u'\u000C', u'\u000D', u'\u000E', u'\u000F', - u'\u0010', u'\u0011', u'\u0012', u'\u0013', u'\u0014', - u'\u0015', u'\u0016', u'\u0017', u'\u0018', u'\u0019', - u'\u001A', u'\u001B', u'\u001C', u'\u001D', u'\u001E', u'\u001F', - u'\u007F', - u'\u0080', u'\u0081', u'\u0082', u'\u0083', u'\u0084', u'\u0085', - u'\u0086', u'\u0087', u'\u0088', u'\u0089', - u'\u008A', u'\u008B', u'\u008C', u'\u008D', u'\u008E', u'\u008F', - u'\u0090', u'\u0091', u'\u0092', u'\u0093', u'\u0094', u'\u0095', - u'\u0096', u'\u0097', u'\u0098', u'\u0099', - u'\u009A', u'\u009B', u'\u009C', u'\u009D', u'\u009E', u'\u009F'] - accepted_role_arn_format = re.compile('arn:aws:iam::(?P[0-9]{12}):role/.+') - accepted_mchn_arn_format = re.compile('arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):stateMachine:.+') - accepted_exec_arn_format = re.compile('arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):execution:.+') + invalid_unicodes_for_name = [ + u"\u0000", + u"\u0001", + u"\u0002", + u"\u0003", + u"\u0004", + u"\u0005", + u"\u0006", + u"\u0007", + u"\u0008", + u"\u0009", + u"\u000A", + u"\u000B", + u"\u000C", + u"\u000D", + u"\u000E", + u"\u000F", + u"\u0010", + u"\u0011", + u"\u0012", + u"\u0013", + u"\u0014", + u"\u0015", + u"\u0016", + u"\u0017", + u"\u0018", + u"\u0019", + u"\u001A", + u"\u001B", + u"\u001C", + u"\u001D", + u"\u001E", + u"\u001F", + u"\u007F", + u"\u0080", + u"\u0081", + u"\u0082", + u"\u0083", + u"\u0084", + u"\u0085", + u"\u0086", + u"\u0087", + u"\u0088", + u"\u0089", + u"\u008A", + u"\u008B", + u"\u008C", + u"\u008D", + u"\u008E", + u"\u008F", + u"\u0090", + u"\u0091", + u"\u0092", + u"\u0093", + u"\u0094", + u"\u0095", + u"\u0096", + u"\u0097", + u"\u0098", + u"\u0099", + u"\u009A", + u"\u009B", + u"\u009C", + u"\u009D", + u"\u009E", + u"\u009F", + ] + accepted_role_arn_format = re.compile( + "arn:aws:iam::(?P[0-9]{12}):role/.+" + ) + accepted_mchn_arn_format = re.compile( + "arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):stateMachine:.+" + ) + accepted_exec_arn_format = re.compile( + "arn:aws:states:[-0-9a-zA-Z]+:(?P[0-9]{12}):execution:.+" + ) def __init__(self, region_name): self.state_machines = [] @@ -72,7 +168,14 @@ class StepFunctionBackend(BaseBackend): def create_state_machine(self, name, definition, roleArn, tags=None): self._validate_name(name) self._validate_role_arn(roleArn) - arn = 'arn:aws:states:' + self.region_name + ':' + str(self._get_account_id()) + ':stateMachine:' + name + arn = ( + "arn:aws:states:" + + self.region_name + + ":" + + str(self._get_account_id()) + + ":stateMachine:" + + name + ) try: return self.describe_state_machine(arn) except StateMachineDoesNotExist: @@ -87,7 +190,9 @@ class StepFunctionBackend(BaseBackend): self._validate_machine_arn(arn) sm = next((x for x in self.state_machines if x.arn == arn), None) if not sm: - raise StateMachineDoesNotExist("State Machine Does Not Exist: '" + arn + "'") + raise StateMachineDoesNotExist( + "State Machine Does Not Exist: '" + arn + "'" + ) return sm def delete_state_machine(self, arn): @@ -98,23 +203,33 @@ class StepFunctionBackend(BaseBackend): def start_execution(self, state_machine_arn, name=None): state_machine_name = self.describe_state_machine(state_machine_arn).name - execution = Execution(region_name=self.region_name, - account_id=self._get_account_id(), - state_machine_name=state_machine_name, - execution_name=name or str(uuid4()), - state_machine_arn=state_machine_arn) + execution = Execution( + region_name=self.region_name, + account_id=self._get_account_id(), + state_machine_name=state_machine_name, + execution_name=name or str(uuid4()), + state_machine_arn=state_machine_arn, + ) self.executions.append(execution) return execution def stop_execution(self, execution_arn): - execution = next((x for x in self.executions if x.execution_arn == execution_arn), None) + execution = next( + (x for x in self.executions if x.execution_arn == execution_arn), None + ) if not execution: - raise ExecutionDoesNotExist("Execution Does Not Exist: '" + execution_arn + "'") + raise ExecutionDoesNotExist( + "Execution Does Not Exist: '" + execution_arn + "'" + ) execution.stop() return execution def list_executions(self, state_machine_arn): - return [execution for execution in self.executions if execution.state_machine_arn == state_machine_arn] + return [ + execution + for execution in self.executions + if execution.state_machine_arn == state_machine_arn + ] def describe_execution(self, arn): self._validate_execution_arn(arn) @@ -136,19 +251,25 @@ class StepFunctionBackend(BaseBackend): raise InvalidName("Invalid Name: '" + name + "'") def _validate_role_arn(self, role_arn): - self._validate_arn(arn=role_arn, - regex=self.accepted_role_arn_format, - invalid_msg="Invalid Role Arn: '" + role_arn + "'") + self._validate_arn( + arn=role_arn, + regex=self.accepted_role_arn_format, + invalid_msg="Invalid Role Arn: '" + role_arn + "'", + ) def _validate_machine_arn(self, machine_arn): - self._validate_arn(arn=machine_arn, - regex=self.accepted_mchn_arn_format, - invalid_msg="Invalid State Machine Arn: '" + machine_arn + "'") + self._validate_arn( + arn=machine_arn, + regex=self.accepted_mchn_arn_format, + invalid_msg="Invalid State Machine Arn: '" + machine_arn + "'", + ) def _validate_execution_arn(self, execution_arn): - self._validate_arn(arn=execution_arn, - regex=self.accepted_exec_arn_format, - invalid_msg="Execution Does Not Exist: '" + execution_arn + "'") + self._validate_arn( + arn=execution_arn, + regex=self.accepted_exec_arn_format, + invalid_msg="Execution Does Not Exist: '" + execution_arn + "'", + ) def _validate_arn(self, arn, regex, invalid_msg): match = regex.match(arn) @@ -159,4 +280,7 @@ class StepFunctionBackend(BaseBackend): return ACCOUNT_ID -stepfunction_backends = {_region.name: StepFunctionBackend(_region.name) for _region in boto.awslambda.regions()} +stepfunction_backends = { + _region.name: StepFunctionBackend(_region.name) + for _region in boto.awslambda.regions() +} diff --git a/moto/stepfunctions/responses.py b/moto/stepfunctions/responses.py index 902a860e5..689961d5a 100644 --- a/moto/stepfunctions/responses.py +++ b/moto/stepfunctions/responses.py @@ -9,24 +9,23 @@ from .models import stepfunction_backends class StepFunctionResponse(BaseResponse): - @property def stepfunction_backend(self): return stepfunction_backends[self.region] @amzn_request_id def create_state_machine(self): - name = self._get_param('name') - definition = self._get_param('definition') - roleArn = self._get_param('roleArn') - tags = self._get_param('tags') + name = self._get_param("name") + definition = self._get_param("definition") + roleArn = self._get_param("roleArn") + tags = self._get_param("tags") try: - state_machine = self.stepfunction_backend.create_state_machine(name=name, definition=definition, - roleArn=roleArn, - tags=tags) + state_machine = self.stepfunction_backend.create_state_machine( + name=name, definition=definition, roleArn=roleArn, tags=tags + ) response = { - 'creationDate': state_machine.creation_date, - 'stateMachineArn': state_machine.arn + "creationDate": state_machine.creation_date, + "stateMachineArn": state_machine.arn, } return 200, {}, json.dumps(response) except AWSError as err: @@ -35,29 +34,38 @@ class StepFunctionResponse(BaseResponse): @amzn_request_id def list_state_machines(self): list_all = self.stepfunction_backend.list_state_machines() - list_all = sorted([{'creationDate': sm.creation_date, - 'name': sm.name, - 'stateMachineArn': sm.arn} for sm in list_all], - key=lambda x: x['name']) - response = {'stateMachines': list_all} + list_all = sorted( + [ + { + "creationDate": sm.creation_date, + "name": sm.name, + "stateMachineArn": sm.arn, + } + for sm in list_all + ], + key=lambda x: x["name"], + ) + response = {"stateMachines": list_all} return 200, {}, json.dumps(response) @amzn_request_id def describe_state_machine(self): - arn = self._get_param('stateMachineArn') + arn = self._get_param("stateMachineArn") return self._describe_state_machine(arn) @amzn_request_id def _describe_state_machine(self, state_machine_arn): try: - state_machine = self.stepfunction_backend.describe_state_machine(state_machine_arn) + state_machine = self.stepfunction_backend.describe_state_machine( + state_machine_arn + ) response = { - 'creationDate': state_machine.creation_date, - 'stateMachineArn': state_machine.arn, - 'definition': state_machine.definition, - 'name': state_machine.name, - 'roleArn': state_machine.roleArn, - 'status': 'ACTIVE' + "creationDate": state_machine.creation_date, + "stateMachineArn": state_machine.arn, + "definition": state_machine.definition, + "name": state_machine.name, + "roleArn": state_machine.roleArn, + "status": "ACTIVE", } return 200, {}, json.dumps(response) except AWSError as err: @@ -65,58 +73,65 @@ class StepFunctionResponse(BaseResponse): @amzn_request_id def delete_state_machine(self): - arn = self._get_param('stateMachineArn') + arn = self._get_param("stateMachineArn") try: self.stepfunction_backend.delete_state_machine(arn) - return 200, {}, json.dumps('{}') + return 200, {}, json.dumps("{}") except AWSError as err: return err.response() @amzn_request_id def list_tags_for_resource(self): - arn = self._get_param('resourceArn') + arn = self._get_param("resourceArn") try: state_machine = self.stepfunction_backend.describe_state_machine(arn) tags = state_machine.tags or [] except AWSError: tags = [] - response = {'tags': tags} + response = {"tags": tags} return 200, {}, json.dumps(response) @amzn_request_id def start_execution(self): - arn = self._get_param('stateMachineArn') - name = self._get_param('name') + arn = self._get_param("stateMachineArn") + name = self._get_param("name") execution = self.stepfunction_backend.start_execution(arn, name) - response = {'executionArn': execution.execution_arn, - 'startDate': execution.start_date} + response = { + "executionArn": execution.execution_arn, + "startDate": execution.start_date, + } return 200, {}, json.dumps(response) @amzn_request_id def list_executions(self): - arn = self._get_param('stateMachineArn') + arn = self._get_param("stateMachineArn") state_machine = self.stepfunction_backend.describe_state_machine(arn) executions = self.stepfunction_backend.list_executions(arn) - executions = [{'executionArn': execution.execution_arn, - 'name': execution.name, - 'startDate': execution.start_date, - 'stateMachineArn': state_machine.arn, - 'status': execution.status} for execution in executions] - return 200, {}, json.dumps({'executions': executions}) + executions = [ + { + "executionArn": execution.execution_arn, + "name": execution.name, + "startDate": execution.start_date, + "stateMachineArn": state_machine.arn, + "status": execution.status, + } + for execution in executions + ] + return 200, {}, json.dumps({"executions": executions}) @amzn_request_id def describe_execution(self): - arn = self._get_param('executionArn') + arn = self._get_param("executionArn") try: execution = self.stepfunction_backend.describe_execution(arn) response = { - 'executionArn': arn, - 'input': '{}', - 'name': execution.name, - 'startDate': execution.start_date, - 'stateMachineArn': execution.state_machine_arn, - 'status': execution.status, - 'stopDate': execution.stop_date + "executionArn": arn, + "input": "{}", + "name": execution.name, + "startDate": execution.start_date, + "stateMachineArn": execution.state_machine_arn, + "status": execution.status, + "stopDate": execution.stop_date, } return 200, {}, json.dumps(response) except AWSError as err: @@ -124,7 +139,7 @@ class StepFunctionResponse(BaseResponse): @amzn_request_id def describe_state_machine_for_execution(self): - arn = self._get_param('executionArn') + arn = self._get_param("executionArn") try: execution = self.stepfunction_backend.describe_execution(arn) return self._describe_state_machine(execution.state_machine_arn) @@ -133,7 +148,7 @@ class StepFunctionResponse(BaseResponse): @amzn_request_id def stop_execution(self): - arn = self._get_param('executionArn') + arn = self._get_param("executionArn") execution = self.stepfunction_backend.stop_execution(arn) - response = {'stopDate': execution.stop_date} + response = {"stopDate": execution.stop_date} return 200, {}, json.dumps(response) diff --git a/moto/stepfunctions/urls.py b/moto/stepfunctions/urls.py index f8d5fb1e8..46dfd4e24 100644 --- a/moto/stepfunctions/urls.py +++ b/moto/stepfunctions/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import StepFunctionResponse -url_bases = [ - "https?://states.(.+).amazonaws.com", -] +url_bases = ["https?://states.(.+).amazonaws.com"] -url_paths = { - '{0}/$': StepFunctionResponse.dispatch, -} +url_paths = {"{0}/$": StepFunctionResponse.dispatch} diff --git a/moto/sts/exceptions.py b/moto/sts/exceptions.py index bddb56e3f..1acda9288 100644 --- a/moto/sts/exceptions.py +++ b/moto/sts/exceptions.py @@ -7,9 +7,5 @@ class STSClientError(RESTError): class STSValidationError(STSClientError): - def __init__(self, *args, **kwargs): - super(STSValidationError, self).__init__( - "ValidationError", - *args, **kwargs - ) + super(STSValidationError, self).__init__("ValidationError", *args, **kwargs) diff --git a/moto/sts/models.py b/moto/sts/models.py index c2ff7a8d3..d3afc9904 100644 --- a/moto/sts/models.py +++ b/moto/sts/models.py @@ -3,11 +3,15 @@ import datetime from moto.core import BaseBackend, BaseModel from moto.core.utils import iso_8601_datetime_with_milliseconds from moto.iam.models import ACCOUNT_ID -from moto.sts.utils import random_access_key_id, random_secret_access_key, random_session_token, random_assumed_role_id +from moto.sts.utils import ( + random_access_key_id, + random_secret_access_key, + random_session_token, + random_assumed_role_id, +) class Token(BaseModel): - def __init__(self, duration, name=None, policy=None): now = datetime.datetime.utcnow() self.expiration = now + datetime.timedelta(seconds=duration) @@ -20,7 +24,6 @@ class Token(BaseModel): class AssumedRole(BaseModel): - def __init__(self, role_session_name, role_arn, policy, duration, external_id): self.session_name = role_session_name self.role_arn = role_arn @@ -46,12 +49,11 @@ class AssumedRole(BaseModel): return "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( account_id=ACCOUNT_ID, role_name=self.role_arn.split("/")[-1], - session_name=self.session_name + session_name=self.session_name, ) class STSBackend(BaseBackend): - def __init__(self): self.assumed_roles = [] diff --git a/moto/sts/responses.py b/moto/sts/responses.py index 496b81682..f6d8647c4 100644 --- a/moto/sts/responses.py +++ b/moto/sts/responses.py @@ -10,38 +10,38 @@ MAX_FEDERATION_TOKEN_POLICY_LENGTH = 2048 class TokenResponse(BaseResponse): - def get_session_token(self): - duration = int(self.querystring.get('DurationSeconds', [43200])[0]) + duration = int(self.querystring.get("DurationSeconds", [43200])[0]) token = sts_backend.get_session_token(duration=duration) template = self.response_template(GET_SESSION_TOKEN_RESPONSE) return template.render(token=token) def get_federation_token(self): - duration = int(self.querystring.get('DurationSeconds', [43200])[0]) - policy = self.querystring.get('Policy', [None])[0] + duration = int(self.querystring.get("DurationSeconds", [43200])[0]) + policy = self.querystring.get("Policy", [None])[0] if policy is not None and len(policy) > MAX_FEDERATION_TOKEN_POLICY_LENGTH: raise STSValidationError( "1 validation error detected: Value " - "'{\"Version\": \"2012-10-17\", \"Statement\": [...]}' " + '\'{"Version": "2012-10-17", "Statement": [...]}\' ' "at 'policy' failed to satisfy constraint: Member must have length less than or " " equal to %s" % MAX_FEDERATION_TOKEN_POLICY_LENGTH ) - name = self.querystring.get('Name')[0] + name = self.querystring.get("Name")[0] token = sts_backend.get_federation_token( - duration=duration, name=name, policy=policy) + duration=duration, name=name, policy=policy + ) template = self.response_template(GET_FEDERATION_TOKEN_RESPONSE) return template.render(token=token, account_id=ACCOUNT_ID) def assume_role(self): - role_session_name = self.querystring.get('RoleSessionName')[0] - role_arn = self.querystring.get('RoleArn')[0] + role_session_name = self.querystring.get("RoleSessionName")[0] + role_arn = self.querystring.get("RoleArn")[0] - policy = self.querystring.get('Policy', [None])[0] - duration = int(self.querystring.get('DurationSeconds', [3600])[0]) - external_id = self.querystring.get('ExternalId', [None])[0] + policy = self.querystring.get("Policy", [None])[0] + duration = int(self.querystring.get("DurationSeconds", [3600])[0]) + external_id = self.querystring.get("ExternalId", [None])[0] role = sts_backend.assume_role( role_session_name=role_session_name, @@ -54,12 +54,12 @@ class TokenResponse(BaseResponse): return template.render(role=role) def assume_role_with_web_identity(self): - role_session_name = self.querystring.get('RoleSessionName')[0] - role_arn = self.querystring.get('RoleArn')[0] + role_session_name = self.querystring.get("RoleSessionName")[0] + role_arn = self.querystring.get("RoleArn")[0] - policy = self.querystring.get('Policy', [None])[0] - duration = int(self.querystring.get('DurationSeconds', [3600])[0]) - external_id = self.querystring.get('ExternalId', [None])[0] + policy = self.querystring.get("Policy", [None])[0] + duration = int(self.querystring.get("DurationSeconds", [3600])[0]) + external_id = self.querystring.get("ExternalId", [None])[0] role = sts_backend.assume_role_with_web_identity( role_session_name=role_session_name, diff --git a/moto/sts/urls.py b/moto/sts/urls.py index 2078e0b2c..e110f39df 100644 --- a/moto/sts/urls.py +++ b/moto/sts/urls.py @@ -1,10 +1,6 @@ from __future__ import unicode_literals from .responses import TokenResponse -url_bases = [ - "https?://sts(.*).amazonaws.com" -] +url_bases = ["https?://sts(.*).amazonaws.com"] -url_paths = { - '{0}/$': TokenResponse.dispatch, -} +url_paths = {"{0}/$": TokenResponse.dispatch} diff --git a/moto/sts/utils.py b/moto/sts/utils.py index 50767729f..1e8a13569 100644 --- a/moto/sts/utils.py +++ b/moto/sts/utils.py @@ -19,17 +19,20 @@ def random_secret_access_key(): def random_session_token(): - return SESSION_TOKEN_PREFIX + base64.b64encode(os.urandom(266))[len(SESSION_TOKEN_PREFIX):].decode() + return ( + SESSION_TOKEN_PREFIX + + base64.b64encode(os.urandom(266))[len(SESSION_TOKEN_PREFIX) :].decode() + ) def random_assumed_role_id(): - return ACCOUNT_SPECIFIC_ASSUMED_ROLE_ID_PREFIX + _random_uppercase_or_digit_sequence(9) + return ( + ACCOUNT_SPECIFIC_ASSUMED_ROLE_ID_PREFIX + _random_uppercase_or_digit_sequence(9) + ) def _random_uppercase_or_digit_sequence(length): - return ''.join( - six.text_type( - random.choice( - string.ascii_uppercase + string.digits - )) for _ in range(length) + return "".join( + six.text_type(random.choice(string.ascii_uppercase + string.digits)) + for _ in range(length) ) diff --git a/moto/swf/__init__.py b/moto/swf/__init__.py index 0d626690a..2a500458e 100644 --- a/moto/swf/__init__.py +++ b/moto/swf/__init__.py @@ -2,6 +2,6 @@ from __future__ import unicode_literals from .models import swf_backends from ..core.models import base_decorator, deprecated_base_decorator -swf_backend = swf_backends['us-east-1'] +swf_backend = swf_backends["us-east-1"] mock_swf = base_decorator(swf_backends) mock_swf_deprecated = deprecated_base_decorator(swf_backends) diff --git a/moto/swf/constants.py b/moto/swf/constants.py index b9f680d39..80e384d3c 100644 --- a/moto/swf/constants.py +++ b/moto/swf/constants.py @@ -3,9 +3,7 @@ # See http://docs.aws.amazon.com/amazonswf/latest/apireference/API_RespondDecisionTaskCompleted.html # and subsequent docs for each decision type. DECISIONS_FIELDS = { - "cancelTimerDecisionAttributes": { - "timerId": {"type": "string", "required": True} - }, + "cancelTimerDecisionAttributes": {"timerId": {"type": "string", "required": True}}, "cancelWorkflowExecutionDecisionAttributes": { "details": {"type": "string", "required": False} }, @@ -21,15 +19,15 @@ DECISIONS_FIELDS = { "taskList": {"type": "TaskList", "required": False}, "taskPriority": {"type": "string", "required": False}, "taskStartToCloseTimeout": {"type": "string", "required": False}, - "workflowTypeVersion": {"type": "string", "required": False} + "workflowTypeVersion": {"type": "string", "required": False}, }, "failWorkflowExecutionDecisionAttributes": { "details": {"type": "string", "required": False}, - "reason": {"type": "string", "required": False} + "reason": {"type": "string", "required": False}, }, "recordMarkerDecisionAttributes": { "details": {"type": "string", "required": False}, - "markerName": {"type": "string", "required": True} + "markerName": {"type": "string", "required": True}, }, "requestCancelActivityTaskDecisionAttributes": { "activityId": {"type": "string", "required": True} @@ -37,7 +35,7 @@ DECISIONS_FIELDS = { "requestCancelExternalWorkflowExecutionDecisionAttributes": { "control": {"type": "string", "required": False}, "runId": {"type": "string", "required": False}, - "workflowId": {"type": "string", "required": True} + "workflowId": {"type": "string", "required": True}, }, "scheduleActivityTaskDecisionAttributes": { "activityId": {"type": "string", "required": True}, @@ -49,20 +47,20 @@ DECISIONS_FIELDS = { "scheduleToStartTimeout": {"type": "string", "required": False}, "startToCloseTimeout": {"type": "string", "required": False}, "taskList": {"type": "TaskList", "required": False}, - "taskPriority": {"type": "string", "required": False} + "taskPriority": {"type": "string", "required": False}, }, "scheduleLambdaFunctionDecisionAttributes": { "id": {"type": "string", "required": True}, "input": {"type": "string", "required": False}, "name": {"type": "string", "required": True}, - "startToCloseTimeout": {"type": "string", "required": False} + "startToCloseTimeout": {"type": "string", "required": False}, }, "signalExternalWorkflowExecutionDecisionAttributes": { "control": {"type": "string", "required": False}, "input": {"type": "string", "required": False}, "runId": {"type": "string", "required": False}, "signalName": {"type": "string", "required": True}, - "workflowId": {"type": "string", "required": True} + "workflowId": {"type": "string", "required": True}, }, "startChildWorkflowExecutionDecisionAttributes": { "childPolicy": {"type": "string", "required": False}, @@ -75,11 +73,11 @@ DECISIONS_FIELDS = { "taskPriority": {"type": "string", "required": False}, "taskStartToCloseTimeout": {"type": "string", "required": False}, "workflowId": {"type": "string", "required": True}, - "workflowType": {"type": "WorkflowType", "required": True} + "workflowType": {"type": "WorkflowType", "required": True}, }, "startTimerDecisionAttributes": { "control": {"type": "string", "required": False}, "startToFireTimeout": {"type": "string", "required": True}, - "timerId": {"type": "string", "required": True} - } + "timerId": {"type": "string", "required": True}, + }, } diff --git a/moto/swf/exceptions.py b/moto/swf/exceptions.py index 232b1f237..def30b313 100644 --- a/moto/swf/exceptions.py +++ b/moto/swf/exceptions.py @@ -8,71 +8,59 @@ class SWFClientError(JsonRESTError): class SWFUnknownResourceFault(SWFClientError): - def __init__(self, resource_type, resource_name=None): if resource_name: message = "Unknown {0}: {1}".format(resource_type, resource_name) else: message = "Unknown {0}".format(resource_type) super(SWFUnknownResourceFault, self).__init__( - "com.amazonaws.swf.base.model#UnknownResourceFault", - message, + "com.amazonaws.swf.base.model#UnknownResourceFault", message ) class SWFDomainAlreadyExistsFault(SWFClientError): - def __init__(self, domain_name): super(SWFDomainAlreadyExistsFault, self).__init__( - "com.amazonaws.swf.base.model#DomainAlreadyExistsFault", - domain_name, + "com.amazonaws.swf.base.model#DomainAlreadyExistsFault", domain_name ) class SWFDomainDeprecatedFault(SWFClientError): - def __init__(self, domain_name): super(SWFDomainDeprecatedFault, self).__init__( - "com.amazonaws.swf.base.model#DomainDeprecatedFault", - domain_name, + "com.amazonaws.swf.base.model#DomainDeprecatedFault", domain_name ) class SWFSerializationException(SWFClientError): - def __init__(self, value): message = "class java.lang.Foo can not be converted to an String " - message += " (not a real SWF exception ; happened on: {0})".format( - value) + message += " (not a real SWF exception ; happened on: {0})".format(value) __type = "com.amazonaws.swf.base.model#SerializationException" - super(SWFSerializationException, self).__init__( - __type, - message, - ) + super(SWFSerializationException, self).__init__(__type, message) class SWFTypeAlreadyExistsFault(SWFClientError): - def __init__(self, _type): super(SWFTypeAlreadyExistsFault, self).__init__( "com.amazonaws.swf.base.model#TypeAlreadyExistsFault", "{0}=[name={1}, version={2}]".format( - _type.__class__.__name__, _type.name, _type.version), + _type.__class__.__name__, _type.name, _type.version + ), ) class SWFTypeDeprecatedFault(SWFClientError): - def __init__(self, _type): super(SWFTypeDeprecatedFault, self).__init__( "com.amazonaws.swf.base.model#TypeDeprecatedFault", "{0}=[name={1}, version={2}]".format( - _type.__class__.__name__, _type.name, _type.version), + _type.__class__.__name__, _type.name, _type.version + ), ) class SWFWorkflowExecutionAlreadyStartedFault(SWFClientError): - def __init__(self): super(SWFWorkflowExecutionAlreadyStartedFault, self).__init__( "com.amazonaws.swf.base.model#WorkflowExecutionAlreadyStartedFault", @@ -81,7 +69,6 @@ class SWFWorkflowExecutionAlreadyStartedFault(SWFClientError): class SWFDefaultUndefinedFault(SWFClientError): - def __init__(self, key): # TODO: move that into moto.core.utils maybe? words = key.split("_") @@ -89,22 +76,18 @@ class SWFDefaultUndefinedFault(SWFClientError): for word in words: key_camel_case += word.capitalize() super(SWFDefaultUndefinedFault, self).__init__( - "com.amazonaws.swf.base.model#DefaultUndefinedFault", - key_camel_case, + "com.amazonaws.swf.base.model#DefaultUndefinedFault", key_camel_case ) class SWFValidationException(SWFClientError): - def __init__(self, message): super(SWFValidationException, self).__init__( - "com.amazon.coral.validate#ValidationException", - message, + "com.amazon.coral.validate#ValidationException", message ) class SWFDecisionValidationException(SWFClientError): - def __init__(self, problems): # messages messages = [] @@ -122,8 +105,7 @@ class SWFDecisionValidationException(SWFClientError): ) else: raise ValueError( - "Unhandled decision constraint type: {0}".format(pb[ - "type"]) + "Unhandled decision constraint type: {0}".format(pb["type"]) ) # prefix count = len(problems) @@ -138,6 +120,5 @@ class SWFDecisionValidationException(SWFClientError): class SWFWorkflowExecutionClosedError(Exception): - def __str__(self): return repr("Cannot change this object because the WorkflowExecution is closed") diff --git a/moto/swf/models/__init__.py b/moto/swf/models/__init__.py index a8bc57f40..50cc29bb3 100644 --- a/moto/swf/models/__init__.py +++ b/moto/swf/models/__init__.py @@ -12,25 +12,21 @@ from ..exceptions import ( SWFTypeDeprecatedFault, SWFValidationException, ) -from .activity_task import ActivityTask # flake8: noqa -from .activity_type import ActivityType # flake8: noqa -from .decision_task import DecisionTask # flake8: noqa -from .domain import Domain # flake8: noqa -from .generic_type import GenericType # flake8: noqa -from .history_event import HistoryEvent # flake8: noqa -from .timeout import Timeout # flake8: noqa -from .workflow_type import WorkflowType # flake8: noqa -from .workflow_execution import WorkflowExecution # flake8: noqa +from .activity_task import ActivityTask # noqa +from .activity_type import ActivityType # noqa +from .decision_task import DecisionTask # noqa +from .domain import Domain # noqa +from .generic_type import GenericType # noqa +from .history_event import HistoryEvent # noqa +from .timeout import Timeout # noqa +from .workflow_type import WorkflowType # noqa +from .workflow_execution import WorkflowExecution # noqa from time import sleep -KNOWN_SWF_TYPES = { - "activity": ActivityType, - "workflow": WorkflowType, -} +KNOWN_SWF_TYPES = {"activity": ActivityType, "workflow": WorkflowType} class SWFBackend(BaseBackend): - def __init__(self, region_name): self.region_name = region_name self.domains = [] @@ -55,46 +51,53 @@ class SWFBackend(BaseBackend): wfe._process_timeouts() def list_domains(self, status, reverse_order=None): - domains = [domain for domain in self.domains - if domain.status == status] + domains = [domain for domain in self.domains if domain.status == status] domains = sorted(domains, key=lambda domain: domain.name) if reverse_order: domains = reversed(domains) return domains - def list_open_workflow_executions(self, domain_name, maximum_page_size, - tag_filter, reverse_order, **kwargs): + def list_open_workflow_executions( + self, domain_name, maximum_page_size, tag_filter, reverse_order, **kwargs + ): self._process_timeouts() domain = self._get_domain(domain_name) if domain.status == "DEPRECATED": raise SWFDomainDeprecatedFault(domain_name) open_wfes = [ - wfe for wfe in domain.workflow_executions - if wfe.execution_status == 'OPEN' + wfe for wfe in domain.workflow_executions if wfe.execution_status == "OPEN" ] if tag_filter: for open_wfe in open_wfes: - if tag_filter['tag'] not in open_wfe.tag_list: + if tag_filter["tag"] not in open_wfe.tag_list: open_wfes.remove(open_wfe) if reverse_order: open_wfes = reversed(open_wfes) return open_wfes[0:maximum_page_size] - def list_closed_workflow_executions(self, domain_name, close_time_filter, - tag_filter, close_status_filter, maximum_page_size, reverse_order, - **kwargs): + def list_closed_workflow_executions( + self, + domain_name, + close_time_filter, + tag_filter, + close_status_filter, + maximum_page_size, + reverse_order, + **kwargs + ): self._process_timeouts() domain = self._get_domain(domain_name) if domain.status == "DEPRECATED": raise SWFDomainDeprecatedFault(domain_name) closed_wfes = [ - wfe for wfe in domain.workflow_executions - if wfe.execution_status == 'CLOSED' + wfe + for wfe in domain.workflow_executions + if wfe.execution_status == "CLOSED" ] if tag_filter: for closed_wfe in closed_wfes: - if tag_filter['tag'] not in closed_wfe.tag_list: + if tag_filter["tag"] not in closed_wfe.tag_list: closed_wfes.remove(closed_wfe) if close_status_filter: for closed_wfe in closed_wfes: @@ -104,12 +107,12 @@ class SWFBackend(BaseBackend): closed_wfes = reversed(closed_wfes) return closed_wfes[0:maximum_page_size] - def register_domain(self, name, workflow_execution_retention_period_in_days, - description=None): + def register_domain( + self, name, workflow_execution_retention_period_in_days, description=None + ): if self._get_domain(name, ignore_empty=True): raise SWFDomainAlreadyExistsFault(name) - domain = Domain(name, workflow_execution_retention_period_in_days, - description) + domain = Domain(name, workflow_execution_retention_period_in_days, description) self.domains.append(domain) def deprecate_domain(self, name): @@ -149,15 +152,23 @@ class SWFBackend(BaseBackend): domain = self._get_domain(domain_name) return domain.get_type(kind, name, version) - def start_workflow_execution(self, domain_name, workflow_id, - workflow_name, workflow_version, - tag_list=None, input=None, **kwargs): + def start_workflow_execution( + self, + domain_name, + workflow_id, + workflow_name, + workflow_version, + tag_list=None, + input=None, + **kwargs + ): domain = self._get_domain(domain_name) wf_type = domain.get_type("workflow", workflow_name, workflow_version) if wf_type.status == "DEPRECATED": raise SWFTypeDeprecatedFault(wf_type) - wfe = WorkflowExecution(domain, wf_type, workflow_id, - tag_list=tag_list, input=input, **kwargs) + wfe = WorkflowExecution( + domain, wf_type, workflow_id, tag_list=tag_list, input=input, **kwargs + ) domain.add_workflow_execution(wfe) wfe.start() @@ -213,9 +224,9 @@ class SWFBackend(BaseBackend): count += wfe.open_counts["openDecisionTasks"] return count - def respond_decision_task_completed(self, task_token, - decisions=None, - execution_context=None): + def respond_decision_task_completed( + self, task_token, decisions=None, execution_context=None + ): # process timeouts on all objects self._process_timeouts() # let's find decision task @@ -244,14 +255,15 @@ class SWFBackend(BaseBackend): "execution", "WorkflowExecution=[workflowId={0}, runId={1}]".format( wfe.workflow_id, wfe.run_id - ) + ), ) # decision task found, but already completed if decision_task.state != "STARTED": if decision_task.state == "COMPLETED": raise SWFUnknownResourceFault( "decision task, scheduledEventId = {0}".format( - decision_task.scheduled_event_id) + decision_task.scheduled_event_id + ) ) else: raise ValueError( @@ -263,9 +275,11 @@ class SWFBackend(BaseBackend): # everything's good if decision_task: wfe = decision_task.workflow_execution - wfe.complete_decision_task(decision_task.task_token, - decisions=decisions, - execution_context=execution_context) + wfe.complete_decision_task( + decision_task.task_token, + decisions=decisions, + execution_context=execution_context, + ) def poll_for_activity_task(self, domain_name, task_list, identity=None): # process timeouts on all objects @@ -308,8 +322,7 @@ class SWFBackend(BaseBackend): count = 0 for _task_list, tasks in domain.activity_task_lists.items(): if _task_list == task_list: - pending = [t for t in tasks if t.state in [ - "SCHEDULED", "STARTED"]] + pending = [t for t in tasks if t.state in ["SCHEDULED", "STARTED"]] count += len(pending) return count @@ -333,14 +346,15 @@ class SWFBackend(BaseBackend): "execution", "WorkflowExecution=[workflowId={0}, runId={1}]".format( wfe.workflow_id, wfe.run_id - ) + ), ) # activity task found, but already completed if activity_task.state != "STARTED": if activity_task.state == "COMPLETED": raise SWFUnknownResourceFault( "activity, scheduledEventId = {0}".format( - activity_task.scheduled_event_id) + activity_task.scheduled_event_id + ) ) else: raise ValueError( @@ -364,18 +378,24 @@ class SWFBackend(BaseBackend): self._process_timeouts() activity_task = self._find_activity_task_from_token(task_token) wfe = activity_task.workflow_execution - wfe.fail_activity_task(activity_task.task_token, - reason=reason, details=details) + wfe.fail_activity_task(activity_task.task_token, reason=reason, details=details) - def terminate_workflow_execution(self, domain_name, workflow_id, child_policy=None, - details=None, reason=None, run_id=None): + def terminate_workflow_execution( + self, + domain_name, + workflow_id, + child_policy=None, + details=None, + reason=None, + run_id=None, + ): # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) wfe = domain.get_workflow_execution( - workflow_id, run_id=run_id, raise_if_closed=True) - wfe.terminate(child_policy=child_policy, - details=details, reason=reason) + workflow_id, run_id=run_id, raise_if_closed=True + ) + wfe.terminate(child_policy=child_policy, details=details, reason=reason) def record_activity_task_heartbeat(self, task_token, details=None): # process timeouts on all objects @@ -385,12 +405,15 @@ class SWFBackend(BaseBackend): if details: activity_task.details = details - def signal_workflow_execution(self, domain_name, signal_name, workflow_id, input=None, run_id=None): + def signal_workflow_execution( + self, domain_name, signal_name, workflow_id, input=None, run_id=None + ): # process timeouts on all objects self._process_timeouts() domain = self._get_domain(domain_name) wfe = domain.get_workflow_execution( - workflow_id, run_id=run_id, raise_if_closed=True) + workflow_id, run_id=run_id, raise_if_closed=True + ) wfe.signal(signal_name, input) diff --git a/moto/swf/models/activity_task.py b/moto/swf/models/activity_task.py index 0c1f283ca..93e300ae5 100644 --- a/moto/swf/models/activity_task.py +++ b/moto/swf/models/activity_task.py @@ -10,9 +10,15 @@ from .timeout import Timeout class ActivityTask(BaseModel): - - def __init__(self, activity_id, activity_type, scheduled_event_id, - workflow_execution, timeouts, input=None): + def __init__( + self, + activity_id, + activity_type, + scheduled_event_id, + workflow_execution, + timeouts, + input=None, + ): self.activity_id = activity_id self.activity_type = activity_type self.details = None @@ -68,8 +74,9 @@ class ActivityTask(BaseModel): if not self.open or not self.workflow_execution.open: return None # TODO: handle the "NONE" case - heartbeat_timeout_at = (self.last_heartbeat_timestamp + - int(self.timeouts["heartbeatTimeout"])) + heartbeat_timeout_at = self.last_heartbeat_timestamp + int( + self.timeouts["heartbeatTimeout"] + ) _timeout = Timeout(self, heartbeat_timeout_at, "HEARTBEAT") if _timeout.reached: return _timeout diff --git a/moto/swf/models/activity_type.py b/moto/swf/models/activity_type.py index eb1bbfa68..95a83ca7a 100644 --- a/moto/swf/models/activity_type.py +++ b/moto/swf/models/activity_type.py @@ -2,7 +2,6 @@ from .generic_type import GenericType class ActivityType(GenericType): - @property def _configuration_keys(self): return [ diff --git a/moto/swf/models/decision_task.py b/moto/swf/models/decision_task.py index 9255dd6f2..c8c9824a2 100644 --- a/moto/swf/models/decision_task.py +++ b/moto/swf/models/decision_task.py @@ -10,7 +10,6 @@ from .timeout import Timeout class DecisionTask(BaseModel): - def __init__(self, workflow_execution, scheduled_event_id): self.workflow_execution = workflow_execution self.workflow_type = workflow_execution.workflow_type @@ -19,7 +18,9 @@ class DecisionTask(BaseModel): self.previous_started_event_id = 0 self.started_event_id = None self.started_timestamp = None - self.start_to_close_timeout = self.workflow_execution.task_start_to_close_timeout + self.start_to_close_timeout = ( + self.workflow_execution.task_start_to_close_timeout + ) self.state = "SCHEDULED" # this is *not* necessarily coherent with workflow execution history, # but that shouldn't be a problem for tests @@ -37,9 +38,7 @@ class DecisionTask(BaseModel): def to_full_dict(self, reverse_order=False): events = self.workflow_execution.events(reverse_order=reverse_order) hsh = { - "events": [ - evt.to_dict() for evt in events - ], + "events": [evt.to_dict() for evt in events], "taskToken": self.task_token, "previousStartedEventId": self.previous_started_event_id, "workflowExecution": self.workflow_execution.to_short_dict(), @@ -62,8 +61,7 @@ class DecisionTask(BaseModel): if not self.started or not self.workflow_execution.open: return None # TODO: handle the "NONE" case - start_to_close_at = self.started_timestamp + \ - int(self.start_to_close_timeout) + start_to_close_at = self.started_timestamp + int(self.start_to_close_timeout) _timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE") if _timeout.reached: return _timeout diff --git a/moto/swf/models/domain.py b/moto/swf/models/domain.py index 0aa62f4f0..54347b22b 100644 --- a/moto/swf/models/domain.py +++ b/moto/swf/models/domain.py @@ -9,16 +9,12 @@ from ..exceptions import ( class Domain(BaseModel): - def __init__(self, name, retention, description=None): self.name = name self.retention = retention self.description = description self.status = "REGISTERED" - self.types = { - "activity": defaultdict(dict), - "workflow": defaultdict(dict), - } + self.types = {"activity": defaultdict(dict), "workflow": defaultdict(dict)} # Workflow executions have an id, which unicity is guaranteed # at domain level (not super clear in the docs, but I checked # that against SWF API) ; hence the storage method as a dict @@ -32,10 +28,7 @@ class Domain(BaseModel): return "Domain(name: %(name)s, status: %(status)s)" % self.__dict__ def to_short_dict(self): - hsh = { - "name": self.name, - "status": self.status, - } + hsh = {"name": self.name, "status": self.status} if self.description: hsh["description"] = self.description return hsh @@ -43,9 +36,7 @@ class Domain(BaseModel): def to_full_dict(self): return { "domainInfo": self.to_short_dict(), - "configuration": { - "workflowExecutionRetentionPeriodInDays": self.retention, - } + "configuration": {"workflowExecutionRetentionPeriodInDays": self.retention}, } def get_type(self, kind, name, version, ignore_empty=False): @@ -57,7 +48,7 @@ class Domain(BaseModel): "type", "{0}Type=[name={1}, version={2}]".format( kind.capitalize(), name, version - ) + ), ) def add_type(self, _type): @@ -77,15 +68,22 @@ class Domain(BaseModel): raise SWFWorkflowExecutionAlreadyStartedFault() self.workflow_executions.append(workflow_execution) - def get_workflow_execution(self, workflow_id, run_id=None, - raise_if_none=True, raise_if_closed=False): + def get_workflow_execution( + self, workflow_id, run_id=None, raise_if_none=True, raise_if_closed=False + ): # query if run_id: - _all = [w for w in self.workflow_executions - if w.workflow_id == workflow_id and w.run_id == run_id] + _all = [ + w + for w in self.workflow_executions + if w.workflow_id == workflow_id and w.run_id == run_id + ] else: - _all = [w for w in self.workflow_executions - if w.workflow_id == workflow_id and w.open] + _all = [ + w + for w in self.workflow_executions + if w.workflow_id == workflow_id and w.open + ] # reduce wfe = _all[0] if _all else None # raise if closed / none @@ -93,8 +91,12 @@ class Domain(BaseModel): wfe = None if not wfe and raise_if_none: if run_id: - args = ["execution", "WorkflowExecution=[workflowId={0}, runId={1}]".format( - workflow_id, run_id)] + args = [ + "execution", + "WorkflowExecution=[workflowId={0}, runId={1}]".format( + workflow_id, run_id + ), + ] else: args = ["execution, workflowId = {0}".format(workflow_id)] raise SWFUnknownResourceFault(*args) diff --git a/moto/swf/models/generic_type.py b/moto/swf/models/generic_type.py index a56220ed6..8ae6ebc08 100644 --- a/moto/swf/models/generic_type.py +++ b/moto/swf/models/generic_type.py @@ -5,7 +5,6 @@ from moto.core.utils import camelcase_to_underscores class GenericType(BaseModel): - def __init__(self, name, version, **kwargs): self.name = name self.version = version @@ -24,7 +23,9 @@ class GenericType(BaseModel): def __repr__(self): cls = self.__class__.__name__ - attrs = "name: %(name)s, version: %(version)s, status: %(status)s" % self.__dict__ + attrs = ( + "name: %(name)s, version: %(version)s, status: %(status)s" % self.__dict__ + ) return "{0}({1})".format(cls, attrs) @property @@ -36,10 +37,7 @@ class GenericType(BaseModel): raise NotImplementedError() def to_short_dict(self): - return { - "name": self.name, - "version": self.version, - } + return {"name": self.name, "version": self.version} def to_medium_dict(self): hsh = { @@ -54,10 +52,7 @@ class GenericType(BaseModel): return hsh def to_full_dict(self): - hsh = { - "typeInfo": self.to_medium_dict(), - "configuration": {} - } + hsh = {"typeInfo": self.to_medium_dict(), "configuration": {}} if self.task_list: hsh["configuration"]["defaultTaskList"] = {"name": self.task_list} for key in self._configuration_keys: diff --git a/moto/swf/models/history_event.py b/moto/swf/models/history_event.py index e7ddfd924..f259ea94e 100644 --- a/moto/swf/models/history_event.py +++ b/moto/swf/models/history_event.py @@ -25,17 +25,17 @@ SUPPORTED_HISTORY_EVENT_TYPES = ( "ActivityTaskTimedOut", "DecisionTaskTimedOut", "WorkflowExecutionTimedOut", - "WorkflowExecutionSignaled" + "WorkflowExecutionSignaled", ) class HistoryEvent(BaseModel): - def __init__(self, event_id, event_type, event_timestamp=None, **kwargs): if event_type not in SUPPORTED_HISTORY_EVENT_TYPES: raise NotImplementedError( "HistoryEvent does not implement attributes for type '{0}'".format( - event_type) + event_type + ) ) self.event_id = event_id self.event_type = event_type @@ -61,7 +61,7 @@ class HistoryEvent(BaseModel): "eventId": self.event_id, "eventType": self.event_type, "eventTimestamp": self.event_timestamp, - self._attributes_key(): self.event_attributes + self._attributes_key(): self.event_attributes, } def _attributes_key(self): diff --git a/moto/swf/models/timeout.py b/moto/swf/models/timeout.py index f26c8a38b..bc576bb64 100644 --- a/moto/swf/models/timeout.py +++ b/moto/swf/models/timeout.py @@ -3,7 +3,6 @@ from moto.core.utils import unix_time class Timeout(BaseModel): - def __init__(self, obj, timestamp, kind): self.obj = obj self.timestamp = timestamp diff --git a/moto/swf/models/workflow_execution.py b/moto/swf/models/workflow_execution.py index 3d01f9192..fca780a41 100644 --- a/moto/swf/models/workflow_execution.py +++ b/moto/swf/models/workflow_execution.py @@ -4,9 +4,7 @@ import uuid from moto.core import BaseModel from moto.core.utils import camelcase_to_underscores, unix_time -from ..constants import ( - DECISIONS_FIELDS, -) +from ..constants import DECISIONS_FIELDS from ..exceptions import ( SWFDefaultUndefinedFault, SWFValidationException, @@ -38,7 +36,7 @@ class WorkflowExecution(BaseModel): "FailWorkflowExecution", "RequestCancelActivityTask", "StartChildWorkflowExecution", - "CancelWorkflowExecution" + "CancelWorkflowExecution", ] def __init__(self, domain, workflow_type, workflow_id, **kwargs): @@ -66,11 +64,10 @@ class WorkflowExecution(BaseModel): # param is set, # SWF will raise DefaultUndefinedFault errors in the # same order as the few lines that follow) self._set_from_kwargs_or_workflow_type( - kwargs, "execution_start_to_close_timeout") - self._set_from_kwargs_or_workflow_type( - kwargs, "task_list", "task_list") - self._set_from_kwargs_or_workflow_type( - kwargs, "task_start_to_close_timeout") + kwargs, "execution_start_to_close_timeout" + ) + self._set_from_kwargs_or_workflow_type(kwargs, "task_list", "task_list") + self._set_from_kwargs_or_workflow_type(kwargs, "task_start_to_close_timeout") self._set_from_kwargs_or_workflow_type(kwargs, "child_policy") self.input = kwargs.get("input") # counters @@ -89,7 +86,9 @@ class WorkflowExecution(BaseModel): def __repr__(self): return "WorkflowExecution(run_id: {0})".format(self.run_id) - def _set_from_kwargs_or_workflow_type(self, kwargs, local_key, workflow_type_key=None): + def _set_from_kwargs_or_workflow_type( + self, kwargs, local_key, workflow_type_key=None + ): if workflow_type_key is None: workflow_type_key = "default_" + local_key value = kwargs.get(local_key) @@ -109,10 +108,7 @@ class WorkflowExecution(BaseModel): ] def to_short_dict(self): - return { - "workflowId": self.workflow_id, - "runId": self.run_id - } + return {"workflowId": self.workflow_id, "runId": self.run_id} def to_medium_dict(self): hsh = { @@ -129,9 +125,7 @@ class WorkflowExecution(BaseModel): def to_full_dict(self): hsh = { "executionInfo": self.to_medium_dict(), - "executionConfiguration": { - "taskList": {"name": self.task_list} - } + "executionConfiguration": {"taskList": {"name": self.task_list}}, } # configuration for key in self._configuration_keys: @@ -152,23 +146,20 @@ class WorkflowExecution(BaseModel): def to_list_dict(self): hsh = { - 'execution': { - 'workflowId': self.workflow_id, - 'runId': self.run_id, - }, - 'workflowType': self.workflow_type.to_short_dict(), - 'startTimestamp': self.start_timestamp, - 'executionStatus': self.execution_status, - 'cancelRequested': self.cancel_requested, + "execution": {"workflowId": self.workflow_id, "runId": self.run_id}, + "workflowType": self.workflow_type.to_short_dict(), + "startTimestamp": self.start_timestamp, + "executionStatus": self.execution_status, + "cancelRequested": self.cancel_requested, } if self.tag_list: - hsh['tagList'] = self.tag_list + hsh["tagList"] = self.tag_list if self.parent: - hsh['parent'] = self.parent + hsh["parent"] = self.parent if self.close_status: - hsh['closeStatus'] = self.close_status + hsh["closeStatus"] = self.close_status if self.close_timestamp: - hsh['closeTimestamp'] = self.close_timestamp + hsh["closeTimestamp"] = self.close_timestamp return hsh def _process_timeouts(self): @@ -206,10 +197,7 @@ class WorkflowExecution(BaseModel): # now find the first timeout to process first_timeout = None if timeout_candidates: - first_timeout = min( - timeout_candidates, - key=lambda t: t.timestamp - ) + first_timeout = min(timeout_candidates, key=lambda t: t.timestamp) if first_timeout: should_schedule_decision_next = False @@ -258,7 +246,7 @@ class WorkflowExecution(BaseModel): task_list=self.task_list, task_start_to_close_timeout=self.task_start_to_close_timeout, workflow_type=self.workflow_type, - input=self.input + input=self.input, ) self.schedule_decision_task() @@ -269,8 +257,7 @@ class WorkflowExecution(BaseModel): task_list=self.task_list, ) self.domain.add_to_decision_task_list( - self.task_list, - DecisionTask(self, evt.event_id), + self.task_list, DecisionTask(self, evt.event_id) ) self.open_counts["openDecisionTasks"] += 1 @@ -285,32 +272,30 @@ class WorkflowExecution(BaseModel): @property def decision_tasks(self): - return [t for t in self.domain.decision_tasks - if t.workflow_execution == self] + return [t for t in self.domain.decision_tasks if t.workflow_execution == self] @property def activity_tasks(self): - return [t for t in self.domain.activity_tasks - if t.workflow_execution == self] + return [t for t in self.domain.activity_tasks if t.workflow_execution == self] def _find_decision_task(self, task_token): for dt in self.decision_tasks: if dt.task_token == task_token: return dt - raise ValueError( - "No decision task with token: {0}".format(task_token) - ) + raise ValueError("No decision task with token: {0}".format(task_token)) def start_decision_task(self, task_token, identity=None): dt = self._find_decision_task(task_token) evt = self._add_event( "DecisionTaskStarted", scheduled_event_id=dt.scheduled_event_id, - identity=identity + identity=identity, ) dt.start(evt.event_id) - def complete_decision_task(self, task_token, decisions=None, execution_context=None): + def complete_decision_task( + self, task_token, decisions=None, execution_context=None + ): # 'decisions' can be None per boto.swf defaults, so replace it with something iterable if not decisions: decisions = [] @@ -336,12 +321,14 @@ class WorkflowExecution(BaseModel): constraints = DECISIONS_FIELDS.get(kind, {}) for key, constraint in constraints.items(): if constraint["required"] and not value.get(key): - problems.append({ - "type": "null_value", - "where": "decisions.{0}.member.{1}.{2}".format( - decision_id, kind, key - ) - }) + problems.append( + { + "type": "null_value", + "where": "decisions.{0}.member.{1}.{2}".format( + decision_id, kind, key + ), + } + ) return problems def validate_decisions(self, decisions): @@ -362,9 +349,7 @@ class WorkflowExecution(BaseModel): "CancelWorkflowExecution", ] if dcs["decisionType"] in close_decision_types: - raise SWFValidationException( - "Close must be last decision in list" - ) + raise SWFValidationException("Close must be last decision in list") decision_number = 0 for dcs in decisions: @@ -372,24 +357,29 @@ class WorkflowExecution(BaseModel): # check decision types mandatory attributes # NB: the real SWF service seems to check attributes even for attributes list # that are not in line with the decisionType, so we do the same - attrs_to_check = [ - d for d in dcs.keys() if d.endswith("DecisionAttributes")] + attrs_to_check = [d for d in dcs.keys() if d.endswith("DecisionAttributes")] if dcs["decisionType"] in self.KNOWN_DECISION_TYPES: decision_type = dcs["decisionType"] decision_attr = "{0}DecisionAttributes".format( - decapitalize(decision_type)) + decapitalize(decision_type) + ) attrs_to_check.append(decision_attr) for attr in attrs_to_check: problems += self._check_decision_attributes( - attr, dcs.get(attr, {}), decision_number) + attr, dcs.get(attr, {}), decision_number + ) # check decision type is correct if dcs["decisionType"] not in self.KNOWN_DECISION_TYPES: - problems.append({ - "type": "bad_decision_type", - "value": dcs["decisionType"], - "where": "decisions.{0}.member.decisionType".format(decision_number), - "possible_values": ", ".join(self.KNOWN_DECISION_TYPES), - }) + problems.append( + { + "type": "bad_decision_type", + "value": dcs["decisionType"], + "where": "decisions.{0}.member.decisionType".format( + decision_number + ), + "possible_values": ", ".join(self.KNOWN_DECISION_TYPES), + } + ) # raise if any problem if any(problems): @@ -403,14 +393,12 @@ class WorkflowExecution(BaseModel): # handle each decision separately, in order for decision in decisions: decision_type = decision["decisionType"] - attributes_key = "{0}DecisionAttributes".format( - decapitalize(decision_type)) + attributes_key = "{0}DecisionAttributes".format(decapitalize(decision_type)) attributes = decision.get(attributes_key, {}) if decision_type == "CompleteWorkflowExecution": self.complete(event_id, attributes.get("result")) elif decision_type == "FailWorkflowExecution": - self.fail(event_id, attributes.get( - "details"), attributes.get("reason")) + self.fail(event_id, attributes.get("details"), attributes.get("reason")) elif decision_type == "ScheduleActivityTask": self.schedule_activity_task(event_id, attributes) else: @@ -425,7 +413,8 @@ class WorkflowExecution(BaseModel): # TODO: implement Decision type: StartChildWorkflowExecution # TODO: implement Decision type: StartTimer raise NotImplementedError( - "Cannot handle decision: {0}".format(decision_type)) + "Cannot handle decision: {0}".format(decision_type) + ) # finally decrement counter if and only if everything went well self.open_counts["openDecisionTasks"] -= 1 @@ -475,18 +464,21 @@ class WorkflowExecution(BaseModel): ignore_empty=True, ) if not activity_type: - fake_type = ActivityType(attributes["activityType"]["name"], - attributes["activityType"]["version"]) - fail_schedule_activity_task(fake_type, - "ACTIVITY_TYPE_DOES_NOT_EXIST") + fake_type = ActivityType( + attributes["activityType"]["name"], + attributes["activityType"]["version"], + ) + fail_schedule_activity_task(fake_type, "ACTIVITY_TYPE_DOES_NOT_EXIST") return if activity_type.status == "DEPRECATED": - fail_schedule_activity_task(activity_type, - "ACTIVITY_TYPE_DEPRECATED") + fail_schedule_activity_task(activity_type, "ACTIVITY_TYPE_DEPRECATED") return - if any(at for at in self.activity_tasks if at.activity_id == attributes["activityId"]): - fail_schedule_activity_task(activity_type, - "ACTIVITY_ID_ALREADY_IN_USE") + if any( + at + for at in self.activity_tasks + if at.activity_id == attributes["activityId"] + ): + fail_schedule_activity_task(activity_type, "ACTIVITY_ID_ALREADY_IN_USE") return # find task list or default task list, else fail @@ -494,20 +486,25 @@ class WorkflowExecution(BaseModel): if not task_list and activity_type.task_list: task_list = activity_type.task_list if not task_list: - fail_schedule_activity_task(activity_type, - "DEFAULT_TASK_LIST_UNDEFINED") + fail_schedule_activity_task(activity_type, "DEFAULT_TASK_LIST_UNDEFINED") return # find timeouts or default timeout, else fail timeouts = {} - for _type in ["scheduleToStartTimeout", "scheduleToCloseTimeout", "startToCloseTimeout", "heartbeatTimeout"]: + for _type in [ + "scheduleToStartTimeout", + "scheduleToCloseTimeout", + "startToCloseTimeout", + "heartbeatTimeout", + ]: default_key = "default_task_" + camelcase_to_underscores(_type) default_value = getattr(activity_type, default_key) timeouts[_type] = attributes.get(_type, default_value) if not timeouts[_type]: error_key = default_key.replace("default_task_", "default_") - fail_schedule_activity_task(activity_type, - "{0}_UNDEFINED".format(error_key.upper())) + fail_schedule_activity_task( + activity_type, "{0}_UNDEFINED".format(error_key.upper()) + ) return # Only add event and increment counters now that nothing went wrong @@ -541,16 +538,14 @@ class WorkflowExecution(BaseModel): for task in self.activity_tasks: if task.task_token == task_token: return task - raise ValueError( - "No activity task with token: {0}".format(task_token) - ) + raise ValueError("No activity task with token: {0}".format(task_token)) def start_activity_task(self, task_token, identity=None): task = self._find_activity_task(task_token) evt = self._add_event( "ActivityTaskStarted", scheduled_event_id=task.scheduled_event_id, - identity=identity + identity=identity, ) task.start(evt.event_id) @@ -601,17 +596,16 @@ class WorkflowExecution(BaseModel): def signal(self, signal_name, input): self._add_event( - "WorkflowExecutionSignaled", - signal_name=signal_name, - input=input, + "WorkflowExecutionSignaled", signal_name=signal_name, input=input ) self.schedule_decision_task() def first_timeout(self): if not self.open or not self.start_timestamp: return None - start_to_close_at = self.start_timestamp + \ - int(self.execution_start_to_close_timeout) + start_to_close_at = self.start_timestamp + int( + self.execution_start_to_close_timeout + ) _timeout = Timeout(self, start_to_close_at, "START_TO_CLOSE") if _timeout.reached: return _timeout diff --git a/moto/swf/models/workflow_type.py b/moto/swf/models/workflow_type.py index 18d18d415..ddb2475b2 100644 --- a/moto/swf/models/workflow_type.py +++ b/moto/swf/models/workflow_type.py @@ -2,7 +2,6 @@ from .generic_type import GenericType class WorkflowType(GenericType): - @property def _configuration_keys(self): return [ diff --git a/moto/swf/responses.py b/moto/swf/responses.py index 6f002d3d4..98b736cda 100644 --- a/moto/swf/responses.py +++ b/moto/swf/responses.py @@ -8,7 +8,6 @@ from .models import swf_backends class SWFResponse(BaseResponse): - @property def swf_backend(self): return swf_backends[self.region] @@ -51,11 +50,12 @@ class SWFResponse(BaseResponse): return keys = kwargs.keys() if len(keys) == 2: - message = 'Cannot specify both a {0} and a {1}'.format(keys[0], - keys[1]) + message = "Cannot specify both a {0} and a {1}".format(keys[0], keys[1]) else: - message = 'Cannot specify more than one exclusive filters in the' \ - ' same query: {0}'.format(keys) + message = ( + "Cannot specify more than one exclusive filters in the" + " same query: {0}".format(keys) + ) raise SWFValidationException(message) def _list_types(self, kind): @@ -65,10 +65,9 @@ class SWFResponse(BaseResponse): self._check_string(domain_name) self._check_string(status) types = self.swf_backend.list_types( - kind, domain_name, status, reverse_order=reverse_order) - return json.dumps({ - "typeInfos": [_type.to_medium_dict() for _type in types] - }) + kind, domain_name, status, reverse_order=reverse_order + ) + return json.dumps({"typeInfos": [_type.to_medium_dict() for _type in types]}) def _describe_type(self, kind): domain = self._params["domain"] @@ -98,50 +97,51 @@ class SWFResponse(BaseResponse): status = self._params["registrationStatus"] self._check_string(status) reverse_order = self._params.get("reverseOrder", None) - domains = self.swf_backend.list_domains( - status, reverse_order=reverse_order) - return json.dumps({ - "domainInfos": [domain.to_short_dict() for domain in domains] - }) + domains = self.swf_backend.list_domains(status, reverse_order=reverse_order) + return json.dumps( + {"domainInfos": [domain.to_short_dict() for domain in domains]} + ) def list_closed_workflow_executions(self): - domain = self._params['domain'] - start_time_filter = self._params.get('startTimeFilter', None) - close_time_filter = self._params.get('closeTimeFilter', None) - execution_filter = self._params.get('executionFilter', None) - workflow_id = execution_filter[ - 'workflowId'] if execution_filter else None - maximum_page_size = self._params.get('maximumPageSize', 1000) - reverse_order = self._params.get('reverseOrder', None) - tag_filter = self._params.get('tagFilter', None) - type_filter = self._params.get('typeFilter', None) - close_status_filter = self._params.get('closeStatusFilter', None) + domain = self._params["domain"] + start_time_filter = self._params.get("startTimeFilter", None) + close_time_filter = self._params.get("closeTimeFilter", None) + execution_filter = self._params.get("executionFilter", None) + workflow_id = execution_filter["workflowId"] if execution_filter else None + maximum_page_size = self._params.get("maximumPageSize", 1000) + reverse_order = self._params.get("reverseOrder", None) + tag_filter = self._params.get("tagFilter", None) + type_filter = self._params.get("typeFilter", None) + close_status_filter = self._params.get("closeStatusFilter", None) self._check_string(domain) self._check_none_or_string(workflow_id) - self._check_exclusivity(executionFilter=execution_filter, - typeFilter=type_filter, - tagFilter=tag_filter, - closeStatusFilter=close_status_filter) - self._check_exclusivity(startTimeFilter=start_time_filter, - closeTimeFilter=close_time_filter) + self._check_exclusivity( + executionFilter=execution_filter, + typeFilter=type_filter, + tagFilter=tag_filter, + closeStatusFilter=close_status_filter, + ) + self._check_exclusivity( + startTimeFilter=start_time_filter, closeTimeFilter=close_time_filter + ) if start_time_filter is None and close_time_filter is None: - raise SWFValidationException('Must specify time filter') + raise SWFValidationException("Must specify time filter") if start_time_filter: - self._check_float_or_int(start_time_filter['oldestDate']) - if 'latestDate' in start_time_filter: - self._check_float_or_int(start_time_filter['latestDate']) + self._check_float_or_int(start_time_filter["oldestDate"]) + if "latestDate" in start_time_filter: + self._check_float_or_int(start_time_filter["latestDate"]) if close_time_filter: - self._check_float_or_int(close_time_filter['oldestDate']) - if 'latestDate' in close_time_filter: - self._check_float_or_int(close_time_filter['latestDate']) + self._check_float_or_int(close_time_filter["oldestDate"]) + if "latestDate" in close_time_filter: + self._check_float_or_int(close_time_filter["latestDate"]) if tag_filter: - self._check_string(tag_filter['tag']) + self._check_string(tag_filter["tag"]) if type_filter: - self._check_string(type_filter['name']) - self._check_string(type_filter['version']) + self._check_string(type_filter["name"]) + self._check_string(type_filter["version"]) if close_status_filter: - self._check_string(close_status_filter['status']) + self._check_string(close_status_filter["status"]) self._check_int(maximum_page_size) workflow_executions = self.swf_backend.list_closed_workflow_executions( @@ -154,37 +154,38 @@ class SWFResponse(BaseResponse): maximum_page_size=maximum_page_size, reverse_order=reverse_order, workflow_id=workflow_id, - close_status_filter=close_status_filter + close_status_filter=close_status_filter, ) - return json.dumps({ - 'executionInfos': [wfe.to_list_dict() for wfe in workflow_executions] - }) + return json.dumps( + {"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]} + ) def list_open_workflow_executions(self): - domain = self._params['domain'] - start_time_filter = self._params['startTimeFilter'] - execution_filter = self._params.get('executionFilter', None) - workflow_id = execution_filter[ - 'workflowId'] if execution_filter else None - maximum_page_size = self._params.get('maximumPageSize', 1000) - reverse_order = self._params.get('reverseOrder', None) - tag_filter = self._params.get('tagFilter', None) - type_filter = self._params.get('typeFilter', None) + domain = self._params["domain"] + start_time_filter = self._params["startTimeFilter"] + execution_filter = self._params.get("executionFilter", None) + workflow_id = execution_filter["workflowId"] if execution_filter else None + maximum_page_size = self._params.get("maximumPageSize", 1000) + reverse_order = self._params.get("reverseOrder", None) + tag_filter = self._params.get("tagFilter", None) + type_filter = self._params.get("typeFilter", None) self._check_string(domain) self._check_none_or_string(workflow_id) - self._check_exclusivity(executionFilter=execution_filter, - typeFilter=type_filter, - tagFilter=tag_filter) - self._check_float_or_int(start_time_filter['oldestDate']) - if 'latestDate' in start_time_filter: - self._check_float_or_int(start_time_filter['latestDate']) + self._check_exclusivity( + executionFilter=execution_filter, + typeFilter=type_filter, + tagFilter=tag_filter, + ) + self._check_float_or_int(start_time_filter["oldestDate"]) + if "latestDate" in start_time_filter: + self._check_float_or_int(start_time_filter["latestDate"]) if tag_filter: - self._check_string(tag_filter['tag']) + self._check_string(tag_filter["tag"]) if type_filter: - self._check_string(type_filter['name']) - self._check_string(type_filter['version']) + self._check_string(type_filter["name"]) + self._check_string(type_filter["version"]) self._check_int(maximum_page_size) workflow_executions = self.swf_backend.list_open_workflow_executions( @@ -195,12 +196,12 @@ class SWFResponse(BaseResponse): type_filter=type_filter, maximum_page_size=maximum_page_size, reverse_order=reverse_order, - workflow_id=workflow_id + workflow_id=workflow_id, ) - return json.dumps({ - 'executionInfos': [wfe.to_list_dict() for wfe in workflow_executions] - }) + return json.dumps( + {"executionInfos": [wfe.to_list_dict() for wfe in workflow_executions]} + ) def register_domain(self): name = self._params["name"] @@ -209,8 +210,7 @@ class SWFResponse(BaseResponse): self._check_string(retention) self._check_string(name) self._check_none_or_string(description) - self.swf_backend.register_domain(name, retention, - description=description) + self.swf_backend.register_domain(name, retention, description=description) return "" def deprecate_domain(self): @@ -238,14 +238,16 @@ class SWFResponse(BaseResponse): task_list = default_task_list.get("name") else: task_list = None - default_task_heartbeat_timeout = self._params.get( - "defaultTaskHeartbeatTimeout") + default_task_heartbeat_timeout = self._params.get("defaultTaskHeartbeatTimeout") default_task_schedule_to_close_timeout = self._params.get( - "defaultTaskScheduleToCloseTimeout") + "defaultTaskScheduleToCloseTimeout" + ) default_task_schedule_to_start_timeout = self._params.get( - "defaultTaskScheduleToStartTimeout") + "defaultTaskScheduleToStartTimeout" + ) default_task_start_to_close_timeout = self._params.get( - "defaultTaskStartToCloseTimeout") + "defaultTaskStartToCloseTimeout" + ) description = self._params.get("description") self._check_string(domain) @@ -260,7 +262,11 @@ class SWFResponse(BaseResponse): # TODO: add defaultTaskPriority when boto gets to support it self.swf_backend.register_type( - "activity", domain, name, version, task_list=task_list, + "activity", + domain, + name, + version, + task_list=task_list, default_task_heartbeat_timeout=default_task_heartbeat_timeout, default_task_schedule_to_close_timeout=default_task_schedule_to_close_timeout, default_task_schedule_to_start_timeout=default_task_schedule_to_start_timeout, @@ -289,9 +295,11 @@ class SWFResponse(BaseResponse): task_list = None default_child_policy = self._params.get("defaultChildPolicy") default_task_start_to_close_timeout = self._params.get( - "defaultTaskStartToCloseTimeout") + "defaultTaskStartToCloseTimeout" + ) default_execution_start_to_close_timeout = self._params.get( - "defaultExecutionStartToCloseTimeout") + "defaultExecutionStartToCloseTimeout" + ) description = self._params.get("description") self._check_string(domain) @@ -306,7 +314,11 @@ class SWFResponse(BaseResponse): # TODO: add defaultTaskPriority when boto gets to support it # TODO: add defaultLambdaRole when boto gets to support it self.swf_backend.register_type( - "workflow", domain, name, version, task_list=task_list, + "workflow", + domain, + name, + version, + task_list=task_list, default_child_policy=default_child_policy, default_task_start_to_close_timeout=default_task_start_to_close_timeout, default_execution_start_to_close_timeout=default_execution_start_to_close_timeout, @@ -333,11 +345,11 @@ class SWFResponse(BaseResponse): task_list = None child_policy = self._params.get("childPolicy") execution_start_to_close_timeout = self._params.get( - "executionStartToCloseTimeout") + "executionStartToCloseTimeout" + ) input_ = self._params.get("input") tag_list = self._params.get("tagList") - task_start_to_close_timeout = self._params.get( - "taskStartToCloseTimeout") + task_start_to_close_timeout = self._params.get("taskStartToCloseTimeout") self._check_string(domain) self._check_string(workflow_id) @@ -351,16 +363,19 @@ class SWFResponse(BaseResponse): self._check_none_or_string(task_start_to_close_timeout) wfe = self.swf_backend.start_workflow_execution( - domain, workflow_id, workflow_name, workflow_version, - task_list=task_list, child_policy=child_policy, + domain, + workflow_id, + workflow_name, + workflow_version, + task_list=task_list, + child_policy=child_policy, execution_start_to_close_timeout=execution_start_to_close_timeout, - input=input_, tag_list=tag_list, - task_start_to_close_timeout=task_start_to_close_timeout + input=input_, + tag_list=tag_list, + task_start_to_close_timeout=task_start_to_close_timeout, ) - return json.dumps({ - "runId": wfe.run_id - }) + return json.dumps({"runId": wfe.run_id}) def describe_workflow_execution(self): domain_name = self._params["domain"] @@ -373,7 +388,8 @@ class SWFResponse(BaseResponse): self._check_string(workflow_id) wfe = self.swf_backend.describe_workflow_execution( - domain_name, run_id, workflow_id) + domain_name, run_id, workflow_id + ) return json.dumps(wfe.to_full_dict()) def get_workflow_execution_history(self): @@ -383,11 +399,10 @@ class SWFResponse(BaseResponse): workflow_id = _workflow_execution["workflowId"] reverse_order = self._params.get("reverseOrder", None) wfe = self.swf_backend.describe_workflow_execution( - domain_name, run_id, workflow_id) + domain_name, run_id, workflow_id + ) events = wfe.events(reverse_order=reverse_order) - return json.dumps({ - "events": [evt.to_dict() for evt in events] - }) + return json.dumps({"events": [evt.to_dict() for evt in events]}) def poll_for_decision_task(self): domain_name = self._params["domain"] @@ -402,9 +417,7 @@ class SWFResponse(BaseResponse): domain_name, task_list, identity=identity ) if decision: - return json.dumps( - decision.to_full_dict(reverse_order=reverse_order) - ) + return json.dumps(decision.to_full_dict(reverse_order=reverse_order)) else: return json.dumps({"previousStartedEventId": 0, "startedEventId": 0}) @@ -413,8 +426,7 @@ class SWFResponse(BaseResponse): task_list = self._params["taskList"]["name"] self._check_string(domain_name) self._check_string(task_list) - count = self.swf_backend.count_pending_decision_tasks( - domain_name, task_list) + count = self.swf_backend.count_pending_decision_tasks(domain_name, task_list) return json.dumps({"count": count, "truncated": False}) def respond_decision_task_completed(self): @@ -439,9 +451,7 @@ class SWFResponse(BaseResponse): domain_name, task_list, identity=identity ) if activity_task: - return json.dumps( - activity_task.to_full_dict() - ) + return json.dumps(activity_task.to_full_dict()) else: return json.dumps({"startedEventId": 0}) @@ -450,8 +460,7 @@ class SWFResponse(BaseResponse): task_list = self._params["taskList"]["name"] self._check_string(domain_name) self._check_string(task_list) - count = self.swf_backend.count_pending_activity_tasks( - domain_name, task_list) + count = self.swf_backend.count_pending_activity_tasks(domain_name, task_list) return json.dumps({"count": count, "truncated": False}) def respond_activity_task_completed(self): @@ -459,9 +468,7 @@ class SWFResponse(BaseResponse): result = self._params.get("result") self._check_string(task_token) self._check_none_or_string(result) - self.swf_backend.respond_activity_task_completed( - task_token, result=result - ) + self.swf_backend.respond_activity_task_completed(task_token, result=result) return "" def respond_activity_task_failed(self): @@ -492,8 +499,12 @@ class SWFResponse(BaseResponse): self._check_none_or_string(reason) self._check_none_or_string(run_id) self.swf_backend.terminate_workflow_execution( - domain_name, workflow_id, child_policy=child_policy, - details=details, reason=reason, run_id=run_id + domain_name, + workflow_id, + child_policy=child_policy, + details=details, + reason=reason, + run_id=run_id, ) return "" @@ -502,9 +513,7 @@ class SWFResponse(BaseResponse): details = self._params.get("details") self._check_string(task_token) self._check_none_or_string(details) - self.swf_backend.record_activity_task_heartbeat( - task_token, details=details - ) + self.swf_backend.record_activity_task_heartbeat(task_token, details=details) # TODO: make it dynamic when we implement activity tasks cancellation return json.dumps({"cancelRequested": False}) @@ -522,5 +531,6 @@ class SWFResponse(BaseResponse): self._check_none_or_string(run_id) self.swf_backend.signal_workflow_execution( - domain_name, signal_name, workflow_id, _input, run_id) + domain_name, signal_name, workflow_id, _input, run_id + ) return "" diff --git a/moto/swf/urls.py b/moto/swf/urls.py index 582c874fc..cafc39ad3 100644 --- a/moto/swf/urls.py +++ b/moto/swf/urls.py @@ -1,9 +1,5 @@ from .responses import SWFResponse -url_bases = [ - "https?://swf.(.+).amazonaws.com", -] +url_bases = ["https?://swf.(.+).amazonaws.com"] -url_paths = { - '{0}/$': SWFResponse.dispatch, -} +url_paths = {"{0}/$": SWFResponse.dispatch} diff --git a/moto/swf/utils.py b/moto/swf/utils.py index de628ce50..1b85f4ca9 100644 --- a/moto/swf/utils.py +++ b/moto/swf/utils.py @@ -1,3 +1,2 @@ - def decapitalize(key): return key[0].lower() + key[1:] diff --git a/moto/xray/__init__.py b/moto/xray/__init__.py index 41f00af58..c6c612250 100644 --- a/moto/xray/__init__.py +++ b/moto/xray/__init__.py @@ -3,5 +3,5 @@ from .models import xray_backends from ..core.models import base_decorator from .mock_client import mock_xray_client, XRaySegment # noqa -xray_backend = xray_backends['us-east-1'] +xray_backend = xray_backends["us-east-1"] mock_xray = base_decorator(xray_backends) diff --git a/moto/xray/exceptions.py b/moto/xray/exceptions.py index 24f700178..8b5c87e36 100644 --- a/moto/xray/exceptions.py +++ b/moto/xray/exceptions.py @@ -11,11 +11,14 @@ class AWSError(Exception): self.status = status if status is not None else self.STATUS def response(self): - return json.dumps({'__type': self.code, 'message': self.message}), dict(status=self.status) + return ( + json.dumps({"__type": self.code, "message": self.message}), + dict(status=self.status), + ) class InvalidRequestException(AWSError): - CODE = 'InvalidRequestException' + CODE = "InvalidRequestException" class BadSegmentException(Exception): @@ -25,15 +28,15 @@ class BadSegmentException(Exception): self.message = message def __repr__(self): - return ''.format('-'.join([self.id, self.code, self.message])) + return "".format("-".join([self.id, self.code, self.message])) def to_dict(self): result = {} if self.id is not None: - result['Id'] = self.id + result["Id"] = self.id if self.code is not None: - result['ErrorCode'] = self.code + result["ErrorCode"] = self.code if self.message is not None: - result['Message'] = self.message + result["Message"] = self.message return result diff --git a/moto/xray/mock_client.py b/moto/xray/mock_client.py index 135796054..9e042c594 100644 --- a/moto/xray/mock_client.py +++ b/moto/xray/mock_client.py @@ -10,8 +10,11 @@ class MockEmitter(UDPEmitter): """ Replaces the code that sends UDP to local X-Ray daemon """ - def __init__(self, daemon_address='127.0.0.1:2000'): - address = os.getenv('AWS_XRAY_DAEMON_ADDRESS_YEAH_NOT_TODAY_MATE', daemon_address) + + def __init__(self, daemon_address="127.0.0.1:2000"): + address = os.getenv( + "AWS_XRAY_DAEMON_ADDRESS_YEAH_NOT_TODAY_MATE", daemon_address + ) self._ip, self._port = self._parse_address(address) def _xray_backend(self, region): @@ -26,7 +29,7 @@ class MockEmitter(UDPEmitter): pass def _send_data(self, data): - raise RuntimeError('Should not be running this') + raise RuntimeError("Should not be running this") def mock_xray_client(f): @@ -39,12 +42,13 @@ def mock_xray_client(f): We also patch the Emitter by subclassing the UDPEmitter class replacing its methods and pushing that itno the recorder instance. """ + @wraps(f) def _wrapped(*args, **kwargs): print("Starting X-Ray Patch") - old_xray_context_var = os.environ.get('AWS_XRAY_CONTEXT_MISSING') - os.environ['AWS_XRAY_CONTEXT_MISSING'] = 'LOG_ERROR' + old_xray_context_var = os.environ.get("AWS_XRAY_CONTEXT_MISSING") + os.environ["AWS_XRAY_CONTEXT_MISSING"] = "LOG_ERROR" old_xray_context = aws_xray_sdk.core.xray_recorder._context old_xray_emitter = aws_xray_sdk.core.xray_recorder._emitter aws_xray_sdk.core.xray_recorder._context = AWSContext() @@ -55,9 +59,9 @@ def mock_xray_client(f): finally: if old_xray_context_var is None: - del os.environ['AWS_XRAY_CONTEXT_MISSING'] + del os.environ["AWS_XRAY_CONTEXT_MISSING"] else: - os.environ['AWS_XRAY_CONTEXT_MISSING'] = old_xray_context_var + os.environ["AWS_XRAY_CONTEXT_MISSING"] = old_xray_context_var aws_xray_sdk.core.xray_recorder._emitter = old_xray_emitter aws_xray_sdk.core.xray_recorder._context = old_xray_context @@ -74,8 +78,11 @@ class XRaySegment(object): During testing we're going to have to control the start and end of a segment via context managers. """ + def __enter__(self): - aws_xray_sdk.core.xray_recorder.begin_segment(name='moto_mock', traceid=None, parent_id=None, sampling=1) + aws_xray_sdk.core.xray_recorder.begin_segment( + name="moto_mock", traceid=None, parent_id=None, sampling=1 + ) return self diff --git a/moto/xray/models.py b/moto/xray/models.py index b2d418232..33a271f9b 100644 --- a/moto/xray/models.py +++ b/moto/xray/models.py @@ -18,18 +18,36 @@ class TelemetryRecords(BaseModel): @classmethod def from_json(cls, json): - instance_id = json.get('EC2InstanceId', None) - hostname = json.get('Hostname') - resource_arn = json.get('ResourceARN') - telemetry_records = json['TelemetryRecords'] + instance_id = json.get("EC2InstanceId", None) + hostname = json.get("Hostname") + resource_arn = json.get("ResourceARN") + telemetry_records = json["TelemetryRecords"] return cls(instance_id, hostname, resource_arn, telemetry_records) # https://docs.aws.amazon.com/xray/latest/devguide/xray-api-segmentdocuments.html class TraceSegment(BaseModel): - def __init__(self, name, segment_id, trace_id, start_time, raw, end_time=None, in_progress=False, service=None, user=None, - origin=None, parent_id=None, http=None, aws=None, metadata=None, annotations=None, subsegments=None, **kwargs): + def __init__( + self, + name, + segment_id, + trace_id, + start_time, + raw, + end_time=None, + in_progress=False, + service=None, + user=None, + origin=None, + parent_id=None, + http=None, + aws=None, + metadata=None, + annotations=None, + subsegments=None, + **kwargs + ): self.name = name self.id = segment_id self.trace_id = trace_id @@ -61,14 +79,16 @@ class TraceSegment(BaseModel): @property def trace_version(self): if self._trace_version is None: - self._trace_version = int(self.trace_id.split('-', 1)[0]) + self._trace_version = int(self.trace_id.split("-", 1)[0]) return self._trace_version @property def request_start_date(self): if self._original_request_start_time is None: - start_time = int(self.trace_id.split('-')[1], 16) - self._original_request_start_time = datetime.datetime.fromtimestamp(start_time) + start_time = int(self.trace_id.split("-")[1], 16) + self._original_request_start_time = datetime.datetime.fromtimestamp( + start_time + ) return self._original_request_start_time @property @@ -86,19 +106,27 @@ class TraceSegment(BaseModel): @classmethod def from_dict(cls, data, raw): # Check manditory args - if 'id' not in data: - raise BadSegmentException(code='MissingParam', message='Missing segment ID') - seg_id = data['id'] - data['segment_id'] = seg_id # Just adding this key for future convenience + if "id" not in data: + raise BadSegmentException(code="MissingParam", message="Missing segment ID") + seg_id = data["id"] + data["segment_id"] = seg_id # Just adding this key for future convenience - for arg in ('name', 'trace_id', 'start_time'): + for arg in ("name", "trace_id", "start_time"): if arg not in data: - raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing segment ID') + raise BadSegmentException( + seg_id=seg_id, code="MissingParam", message="Missing segment ID" + ) - if 'end_time' not in data and 'in_progress' not in data: - raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing end_time or in_progress') - if 'end_time' not in data and data['in_progress'] == 'false': - raise BadSegmentException(seg_id=seg_id, code='MissingParam', message='Missing end_time') + if "end_time" not in data and "in_progress" not in data: + raise BadSegmentException( + seg_id=seg_id, + code="MissingParam", + message="Missing end_time or in_progress", + ) + if "end_time" not in data and data["in_progress"] == "false": + raise BadSegmentException( + seg_id=seg_id, code="MissingParam", message="Missing end_time" + ) return cls(raw=raw, **data) @@ -110,65 +138,79 @@ class SegmentCollection(object): @staticmethod def _new_trace_item(): return { - 'start_date': datetime.datetime(1970, 1, 1), - 'end_date': datetime.datetime(1970, 1, 1), - 'finished': False, - 'trace_id': None, - 'segments': [] + "start_date": datetime.datetime(1970, 1, 1), + "end_date": datetime.datetime(1970, 1, 1), + "finished": False, + "trace_id": None, + "segments": [], } def put_segment(self, segment): # insert into a sorted list - bisect.insort_left(self._traces[segment.trace_id]['segments'], segment) + bisect.insort_left(self._traces[segment.trace_id]["segments"], segment) # Get the last segment (takes into account incorrect ordering) # and if its the last one, mark trace as complete - if self._traces[segment.trace_id]['segments'][-1].end_time is not None: - self._traces[segment.trace_id]['finished'] = True + if self._traces[segment.trace_id]["segments"][-1].end_time is not None: + self._traces[segment.trace_id]["finished"] = True - start_time = self._traces[segment.trace_id]['segments'][0].start_date - end_time = self._traces[segment.trace_id]['segments'][-1].end_date - self._traces[segment.trace_id]['start_date'] = start_time - self._traces[segment.trace_id]['end_date'] = end_time - self._traces[segment.trace_id]['trace_id'] = segment.trace_id + start_time = self._traces[segment.trace_id]["segments"][0].start_date + end_time = self._traces[segment.trace_id]["segments"][-1].end_date + self._traces[segment.trace_id]["start_date"] = start_time + self._traces[segment.trace_id]["end_date"] = end_time + self._traces[segment.trace_id]["trace_id"] = segment.trace_id # Todo consolidate trace segments into a trace. # not enough working knowledge of xray to do this def summary(self, start_time, end_time, filter_expression=None, sampling=False): # This beast https://docs.aws.amazon.com/xray/latest/api/API_GetTraceSummaries.html#API_GetTraceSummaries_ResponseSyntax if filter_expression is not None: - raise AWSError('Not implemented yet - moto', code='InternalFailure', status=500) + raise AWSError( + "Not implemented yet - moto", code="InternalFailure", status=500 + ) summaries = [] for tid, trace in self._traces.items(): - if trace['finished'] and start_time < trace['start_date'] and trace['end_date'] < end_time: - duration = int((trace['end_date'] - trace['start_date']).total_seconds()) + if ( + trace["finished"] + and start_time < trace["start_date"] + and trace["end_date"] < end_time + ): + duration = int( + (trace["end_date"] - trace["start_date"]).total_seconds() + ) # this stuff is mostly guesses, refer to TODO above - has_error = any(['error' in seg.misc for seg in trace['segments']]) - has_fault = any(['fault' in seg.misc for seg in trace['segments']]) - has_throttle = any(['throttle' in seg.misc for seg in trace['segments']]) + has_error = any(["error" in seg.misc for seg in trace["segments"]]) + has_fault = any(["fault" in seg.misc for seg in trace["segments"]]) + has_throttle = any( + ["throttle" in seg.misc for seg in trace["segments"]] + ) # Apparently all of these options are optional summary_part = { - 'Annotations': {}, # Not implemented yet - 'Duration': duration, - 'HasError': has_error, - 'HasFault': has_fault, - 'HasThrottle': has_throttle, - 'Http': {}, # Not implemented yet - 'Id': tid, - 'IsParital': False, # needs lots more work to work on partials - 'ResponseTime': 1, # definitely 1ms resposnetime - 'ServiceIds': [], # Not implemented yet - 'Users': {} # Not implemented yet + "Annotations": {}, # Not implemented yet + "Duration": duration, + "HasError": has_error, + "HasFault": has_fault, + "HasThrottle": has_throttle, + "Http": {}, # Not implemented yet + "Id": tid, + "IsParital": False, # needs lots more work to work on partials + "ResponseTime": 1, # definitely 1ms resposnetime + "ServiceIds": [], # Not implemented yet + "Users": {}, # Not implemented yet } summaries.append(summary_part) result = { - "ApproximateTime": int((datetime.datetime.now() - datetime.datetime(1970, 1, 1)).total_seconds()), + "ApproximateTime": int( + ( + datetime.datetime.now() - datetime.datetime(1970, 1, 1) + ).total_seconds() + ), "TracesProcessedCount": len(summaries), - "TraceSummaries": summaries + "TraceSummaries": summaries, } return result @@ -189,59 +231,57 @@ class SegmentCollection(object): class XRayBackend(BaseBackend): - def __init__(self): self._telemetry_records = [] self._segment_collection = SegmentCollection() def add_telemetry_records(self, json): - self._telemetry_records.append( - TelemetryRecords.from_json(json) - ) + self._telemetry_records.append(TelemetryRecords.from_json(json)) def process_segment(self, doc): try: data = json.loads(doc) except ValueError: - raise BadSegmentException(code='JSONFormatError', message='Bad JSON data') + raise BadSegmentException(code="JSONFormatError", message="Bad JSON data") try: # Get Segment Object segment = TraceSegment.from_dict(data, raw=doc) except ValueError: - raise BadSegmentException(code='JSONFormatError', message='Bad JSON data') + raise BadSegmentException(code="JSONFormatError", message="Bad JSON data") try: # Store Segment Object self._segment_collection.put_segment(segment) except Exception as err: - raise BadSegmentException(seg_id=segment.id, code='InternalFailure', message=str(err)) + raise BadSegmentException( + seg_id=segment.id, code="InternalFailure", message=str(err) + ) def get_trace_summary(self, start_time, end_time, filter_expression, summaries): - return self._segment_collection.summary(start_time, end_time, filter_expression, summaries) + return self._segment_collection.summary( + start_time, end_time, filter_expression, summaries + ) def get_trace_ids(self, trace_ids, next_token): traces, unprocessed_ids = self._segment_collection.get_trace_ids(trace_ids) - result = { - 'Traces': [], - 'UnprocessedTraceIds': unprocessed_ids - - } + result = {"Traces": [], "UnprocessedTraceIds": unprocessed_ids} for trace in traces: segments = [] - for segment in trace['segments']: - segments.append({ - 'Id': segment.id, - 'Document': segment.raw - }) + for segment in trace["segments"]: + segments.append({"Id": segment.id, "Document": segment.raw}) - result['Traces'].append({ - 'Duration': int((trace['end_date'] - trace['start_date']).total_seconds()), - 'Id': trace['trace_id'], - 'Segments': segments - }) + result["Traces"].append( + { + "Duration": int( + (trace["end_date"] - trace["start_date"]).total_seconds() + ), + "Id": trace["trace_id"], + "Segments": segments, + } + ) return result diff --git a/moto/xray/responses.py b/moto/xray/responses.py index 328a266bf..118f2de2f 100644 --- a/moto/xray/responses.py +++ b/moto/xray/responses.py @@ -10,9 +10,8 @@ from .exceptions import AWSError, BadSegmentException class XRayResponse(BaseResponse): - def _error(self, code, message): - return json.dumps({'__type': code, 'message': message}), dict(status=400) + return json.dumps({"__type": code, "message": message}), dict(status=400) @property def xray_backend(self): @@ -32,7 +31,7 @@ class XRayResponse(BaseResponse): # Amazon is just calling urls like /TelemetryRecords etc... # This uses the value after / as the camalcase action, which then # gets converted in call_action to find the following methods - return urlsplit(self.uri).path.lstrip('/') + return urlsplit(self.uri).path.lstrip("/") # PutTelemetryRecords def telemetry_records(self): @@ -41,15 +40,18 @@ class XRayResponse(BaseResponse): except AWSError as err: return err.response() - return '' + return "" # PutTraceSegments def trace_segments(self): - docs = self._get_param('TraceSegmentDocuments') + docs = self._get_param("TraceSegmentDocuments") if docs is None: - msg = 'Parameter TraceSegmentDocuments is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter TraceSegmentDocuments is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) # Raises an exception that contains info about a bad segment, # the object also has a to_dict() method @@ -60,91 +62,120 @@ class XRayResponse(BaseResponse): except BadSegmentException as bad_seg: bad_segments.append(bad_seg) except Exception as err: - return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + return ( + json.dumps({"__type": "InternalFailure", "message": str(err)}), + dict(status=500), + ) - result = {'UnprocessedTraceSegments': [x.to_dict() for x in bad_segments]} + result = {"UnprocessedTraceSegments": [x.to_dict() for x in bad_segments]} return json.dumps(result) # GetTraceSummaries def trace_summaries(self): - start_time = self._get_param('StartTime') - end_time = self._get_param('EndTime') + start_time = self._get_param("StartTime") + end_time = self._get_param("EndTime") if start_time is None: - msg = 'Parameter StartTime is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter StartTime is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) if end_time is None: - msg = 'Parameter EndTime is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter EndTime is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) - filter_expression = self._get_param('FilterExpression') - sampling = self._get_param('Sampling', 'false') == 'true' + filter_expression = self._get_param("FilterExpression") + sampling = self._get_param("Sampling", "false") == "true" try: start_time = datetime.datetime.fromtimestamp(int(start_time)) end_time = datetime.datetime.fromtimestamp(int(end_time)) except ValueError: - msg = 'start_time and end_time are not integers' - return json.dumps({'__type': 'InvalidParameterValue', 'message': msg}), dict(status=400) + msg = "start_time and end_time are not integers" + return ( + json.dumps({"__type": "InvalidParameterValue", "message": msg}), + dict(status=400), + ) except Exception as err: - return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + return ( + json.dumps({"__type": "InternalFailure", "message": str(err)}), + dict(status=500), + ) try: - result = self.xray_backend.get_trace_summary(start_time, end_time, filter_expression, sampling) + result = self.xray_backend.get_trace_summary( + start_time, end_time, filter_expression, sampling + ) except AWSError as err: return err.response() except Exception as err: - return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + return ( + json.dumps({"__type": "InternalFailure", "message": str(err)}), + dict(status=500), + ) return json.dumps(result) # BatchGetTraces def traces(self): - trace_ids = self._get_param('TraceIds') - next_token = self._get_param('NextToken') # not implemented yet + trace_ids = self._get_param("TraceIds") + next_token = self._get_param("NextToken") # not implemented yet if trace_ids is None: - msg = 'Parameter TraceIds is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter TraceIds is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) try: result = self.xray_backend.get_trace_ids(trace_ids, next_token) except AWSError as err: return err.response() except Exception as err: - return json.dumps({'__type': 'InternalFailure', 'message': str(err)}), dict(status=500) + return ( + json.dumps({"__type": "InternalFailure", "message": str(err)}), + dict(status=500), + ) return json.dumps(result) # GetServiceGraph - just a dummy response for now def service_graph(self): - start_time = self._get_param('StartTime') - end_time = self._get_param('EndTime') + start_time = self._get_param("StartTime") + end_time = self._get_param("EndTime") # next_token = self._get_param('NextToken') # not implemented yet if start_time is None: - msg = 'Parameter StartTime is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter StartTime is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) if end_time is None: - msg = 'Parameter EndTime is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter EndTime is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) - result = { - 'StartTime': start_time, - 'EndTime': end_time, - 'Services': [] - } + result = {"StartTime": start_time, "EndTime": end_time, "Services": []} return json.dumps(result) # GetTraceGraph - just a dummy response for now def trace_graph(self): - trace_ids = self._get_param('TraceIds') + trace_ids = self._get_param("TraceIds") # next_token = self._get_param('NextToken') # not implemented yet if trace_ids is None: - msg = 'Parameter TraceIds is missing' - return json.dumps({'__type': 'MissingParameter', 'message': msg}), dict(status=400) + msg = "Parameter TraceIds is missing" + return ( + json.dumps({"__type": "MissingParameter", "message": msg}), + dict(status=400), + ) - result = { - 'Services': [] - } + result = {"Services": []} return json.dumps(result) diff --git a/moto/xray/urls.py b/moto/xray/urls.py index b0f13a980..4a3d4b253 100644 --- a/moto/xray/urls.py +++ b/moto/xray/urls.py @@ -1,15 +1,13 @@ from __future__ import unicode_literals from .responses import XRayResponse -url_bases = [ - "https?://xray.(.+).amazonaws.com", -] +url_bases = ["https?://xray.(.+).amazonaws.com"] url_paths = { - '{0}/TelemetryRecords$': XRayResponse.dispatch, - '{0}/TraceSegments$': XRayResponse.dispatch, - '{0}/Traces$': XRayResponse.dispatch, - '{0}/ServiceGraph$': XRayResponse.dispatch, - '{0}/TraceGraph$': XRayResponse.dispatch, - '{0}/TraceSummaries$': XRayResponse.dispatch, + "{0}/TelemetryRecords$": XRayResponse.dispatch, + "{0}/TraceSegments$": XRayResponse.dispatch, + "{0}/Traces$": XRayResponse.dispatch, + "{0}/ServiceGraph$": XRayResponse.dispatch, + "{0}/TraceGraph$": XRayResponse.dispatch, + "{0}/TraceSummaries$": XRayResponse.dispatch, } diff --git a/requirements-dev.txt b/requirements-dev.txt index 1dd8ef1f8..436a7a51b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,10 @@ -r requirements.txt mock nose +black; python_version >= '3.6' sure==1.4.11 coverage -flake8==3.5.0 +flake8==3.7.8 freezegun flask boto>=2.45.0 diff --git a/scripts/scaffold.py b/scripts/scaffold.py index 6c83eeb50..be154f103 100755 --- a/scripts/scaffold.py +++ b/scripts/scaffold.py @@ -119,7 +119,7 @@ def append_mock_to_init_py(service): filtered_lines = [_ for _ in lines if re.match('^from.*mock.*$', _)] last_import_line_index = lines.index(filtered_lines[-1]) - new_line = 'from .{} import mock_{} # flake8: noqa'.format(get_escaped_service(service), get_escaped_service(service)) + new_line = 'from .{} import mock_{} # noqa'.format(get_escaped_service(service), get_escaped_service(service)) lines.insert(last_import_line_index + 1, new_line) body = '\n'.join(lines) + '\n' diff --git a/tests/__init__.py b/tests/__init__.py index bf582e0b3..05b1d476b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,8 +1,9 @@ from __future__ import unicode_literals import logging + # Disable extra logging for tests -logging.getLogger('boto').setLevel(logging.CRITICAL) -logging.getLogger('boto3').setLevel(logging.CRITICAL) -logging.getLogger('botocore').setLevel(logging.CRITICAL) -logging.getLogger('nose').setLevel(logging.CRITICAL) +logging.getLogger("boto").setLevel(logging.CRITICAL) +logging.getLogger("boto3").setLevel(logging.CRITICAL) +logging.getLogger("botocore").setLevel(logging.CRITICAL) +logging.getLogger("nose").setLevel(logging.CRITICAL) diff --git a/tests/backport_assert_raises.py b/tests/backport_assert_raises.py index 9b20edf9d..bfed51308 100644 --- a/tests/backport_assert_raises.py +++ b/tests/backport_assert_raises.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + """ Patch courtesy of: https://marmida.com/blog/index.php/2012/08/08/monkey-patching-assert_raises/ @@ -19,7 +20,6 @@ try: except TypeError: # this version of assert_raises doesn't support the 1-arg version class AssertRaisesContext(object): - def __init__(self, expected): self.expected = expected diff --git a/tests/helpers.py b/tests/helpers.py index 50615b094..ffe27103d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -29,7 +29,6 @@ class requires_boto_gte(object): class disable_on_py3(object): - def __call__(self, test): if not six.PY3: return test diff --git a/tests/test_acm/test_acm.py b/tests/test_acm/test_acm.py index cdd8682e1..6f879e55e 100644 --- a/tests/test_acm/test_acm.py +++ b/tests/test_acm/test_acm.py @@ -11,110 +11,104 @@ from botocore.exceptions import ClientError from moto import mock_acm -RESOURCE_FOLDER = os.path.join(os.path.dirname(__file__), 'resources') -_GET_RESOURCE = lambda x: open(os.path.join(RESOURCE_FOLDER, x), 'rb').read() -CA_CRT = _GET_RESOURCE('ca.pem') -CA_KEY = _GET_RESOURCE('ca.key') -SERVER_CRT = _GET_RESOURCE('star_moto_com.pem') -SERVER_COMMON_NAME = '*.moto.com' -SERVER_CRT_BAD = _GET_RESOURCE('star_moto_com-bad.pem') -SERVER_KEY = _GET_RESOURCE('star_moto_com.key') -BAD_ARN = 'arn:aws:acm:us-east-2:123456789012:certificate/_0000000-0000-0000-0000-000000000000' +RESOURCE_FOLDER = os.path.join(os.path.dirname(__file__), "resources") +_GET_RESOURCE = lambda x: open(os.path.join(RESOURCE_FOLDER, x), "rb").read() +CA_CRT = _GET_RESOURCE("ca.pem") +CA_KEY = _GET_RESOURCE("ca.key") +SERVER_CRT = _GET_RESOURCE("star_moto_com.pem") +SERVER_COMMON_NAME = "*.moto.com" +SERVER_CRT_BAD = _GET_RESOURCE("star_moto_com-bad.pem") +SERVER_KEY = _GET_RESOURCE("star_moto_com.key") +BAD_ARN = "arn:aws:acm:us-east-2:123456789012:certificate/_0000000-0000-0000-0000-000000000000" def _import_cert(client): response = client.import_certificate( - Certificate=SERVER_CRT, - PrivateKey=SERVER_KEY, - CertificateChain=CA_CRT + Certificate=SERVER_CRT, PrivateKey=SERVER_KEY, CertificateChain=CA_CRT ) - return response['CertificateArn'] + return response["CertificateArn"] # Also tests GetCertificate @mock_acm def test_import_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") resp = client.import_certificate( - Certificate=SERVER_CRT, - PrivateKey=SERVER_KEY, - CertificateChain=CA_CRT + Certificate=SERVER_CRT, PrivateKey=SERVER_KEY, CertificateChain=CA_CRT ) - resp = client.get_certificate(CertificateArn=resp['CertificateArn']) + resp = client.get_certificate(CertificateArn=resp["CertificateArn"]) - resp['Certificate'].should.equal(SERVER_CRT.decode()) - resp.should.contain('CertificateChain') + resp["Certificate"].should.equal(SERVER_CRT.decode()) + resp.should.contain("CertificateChain") @mock_acm def test_import_bad_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: - client.import_certificate( - Certificate=SERVER_CRT_BAD, - PrivateKey=SERVER_KEY, - ) + client.import_certificate(Certificate=SERVER_CRT_BAD, PrivateKey=SERVER_KEY) except ClientError as err: - err.response['Error']['Code'].should.equal('ValidationException') + err.response["Error"]["Code"].should.equal("ValidationException") else: - raise RuntimeError('Should of raised ValidationException') + raise RuntimeError("Should of raised ValidationException") @mock_acm def test_list_certificates(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) resp = client.list_certificates() - len(resp['CertificateSummaryList']).should.equal(1) + len(resp["CertificateSummaryList"]).should.equal(1) - resp['CertificateSummaryList'][0]['CertificateArn'].should.equal(arn) - resp['CertificateSummaryList'][0]['DomainName'].should.equal(SERVER_COMMON_NAME) + resp["CertificateSummaryList"][0]["CertificateArn"].should.equal(arn) + resp["CertificateSummaryList"][0]["DomainName"].should.equal(SERVER_COMMON_NAME) @mock_acm def test_list_certificates_by_status(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") issued_arn = _import_cert(client) - pending_arn = client.request_certificate(DomainName='google.com')['CertificateArn'] + pending_arn = client.request_certificate(DomainName="google.com")["CertificateArn"] resp = client.list_certificates() - len(resp['CertificateSummaryList']).should.equal(2) - resp = client.list_certificates(CertificateStatuses=['EXPIRED', 'INACTIVE']) - len(resp['CertificateSummaryList']).should.equal(0) - resp = client.list_certificates(CertificateStatuses=['PENDING_VALIDATION']) - len(resp['CertificateSummaryList']).should.equal(1) - resp['CertificateSummaryList'][0]['CertificateArn'].should.equal(pending_arn) + len(resp["CertificateSummaryList"]).should.equal(2) + resp = client.list_certificates(CertificateStatuses=["EXPIRED", "INACTIVE"]) + len(resp["CertificateSummaryList"]).should.equal(0) + resp = client.list_certificates(CertificateStatuses=["PENDING_VALIDATION"]) + len(resp["CertificateSummaryList"]).should.equal(1) + resp["CertificateSummaryList"][0]["CertificateArn"].should.equal(pending_arn) - resp = client.list_certificates(CertificateStatuses=['ISSUED']) - len(resp['CertificateSummaryList']).should.equal(1) - resp['CertificateSummaryList'][0]['CertificateArn'].should.equal(issued_arn) - resp = client.list_certificates(CertificateStatuses=['ISSUED', 'PENDING_VALIDATION']) - len(resp['CertificateSummaryList']).should.equal(2) - arns = {cert['CertificateArn'] for cert in resp['CertificateSummaryList']} + resp = client.list_certificates(CertificateStatuses=["ISSUED"]) + len(resp["CertificateSummaryList"]).should.equal(1) + resp["CertificateSummaryList"][0]["CertificateArn"].should.equal(issued_arn) + resp = client.list_certificates( + CertificateStatuses=["ISSUED", "PENDING_VALIDATION"] + ) + len(resp["CertificateSummaryList"]).should.equal(2) + arns = {cert["CertificateArn"] for cert in resp["CertificateSummaryList"]} arns.should.contain(issued_arn) arns.should.contain(pending_arn) - @mock_acm def test_get_invalid_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: client.get_certificate(CertificateArn=BAD_ARN) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") # Also tests deleting invalid certificate @mock_acm def test_delete_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) # If it does not raise an error and the next call does, all is fine @@ -123,222 +117,209 @@ def test_delete_certificate(): try: client.delete_certificate(CertificateArn=arn) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") @mock_acm def test_describe_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) resp = client.describe_certificate(CertificateArn=arn) - resp['Certificate']['CertificateArn'].should.equal(arn) - resp['Certificate']['DomainName'].should.equal(SERVER_COMMON_NAME) - resp['Certificate']['Issuer'].should.equal('Moto') - resp['Certificate']['KeyAlgorithm'].should.equal('RSA_2048') - resp['Certificate']['Status'].should.equal('ISSUED') - resp['Certificate']['Type'].should.equal('IMPORTED') + resp["Certificate"]["CertificateArn"].should.equal(arn) + resp["Certificate"]["DomainName"].should.equal(SERVER_COMMON_NAME) + resp["Certificate"]["Issuer"].should.equal("Moto") + resp["Certificate"]["KeyAlgorithm"].should.equal("RSA_2048") + resp["Certificate"]["Status"].should.equal("ISSUED") + resp["Certificate"]["Type"].should.equal("IMPORTED") @mock_acm def test_describe_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: client.describe_certificate(CertificateArn=BAD_ARN) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") # Also tests ListTagsForCertificate @mock_acm def test_add_tags_to_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) client.add_tags_to_certificate( - CertificateArn=arn, - Tags=[ - {'Key': 'key1', 'Value': 'value1'}, - {'Key': 'key2'}, - ] + CertificateArn=arn, Tags=[{"Key": "key1", "Value": "value1"}, {"Key": "key2"}] ) resp = client.list_tags_for_certificate(CertificateArn=arn) - tags = {item['Key']: item.get('Value', '__NONE__') for item in resp['Tags']} + tags = {item["Key"]: item.get("Value", "__NONE__") for item in resp["Tags"]} - tags.should.contain('key1') - tags.should.contain('key2') - tags['key1'].should.equal('value1') + tags.should.contain("key1") + tags.should.contain("key2") + tags["key1"].should.equal("value1") # This way, it ensures that we can detect if None is passed back when it shouldnt, # as we store keys without values with a value of None, but it shouldnt be passed back - tags['key2'].should.equal('__NONE__') + tags["key2"].should.equal("__NONE__") @mock_acm def test_add_tags_to_invalid_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: client.add_tags_to_certificate( CertificateArn=BAD_ARN, - Tags=[ - {'Key': 'key1', 'Value': 'value1'}, - {'Key': 'key2'}, - ] + Tags=[{"Key": "key1", "Value": "value1"}, {"Key": "key2"}], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") @mock_acm def test_list_tags_for_invalid_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: client.list_tags_for_certificate(CertificateArn=BAD_ARN) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") @mock_acm def test_remove_tags_from_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) client.add_tags_to_certificate( CertificateArn=arn, Tags=[ - {'Key': 'key1', 'Value': 'value1'}, - {'Key': 'key2'}, - {'Key': 'key3', 'Value': 'value3'}, - {'Key': 'key4', 'Value': 'value4'}, - ] + {"Key": "key1", "Value": "value1"}, + {"Key": "key2"}, + {"Key": "key3", "Value": "value3"}, + {"Key": "key4", "Value": "value4"}, + ], ) client.remove_tags_from_certificate( CertificateArn=arn, Tags=[ - {'Key': 'key1', 'Value': 'value2'}, # Should not remove as doesnt match - {'Key': 'key2'}, # Single key removal - {'Key': 'key3', 'Value': 'value3'}, # Exact match removal - {'Key': 'key4'} # Partial match removal - ] + {"Key": "key1", "Value": "value2"}, # Should not remove as doesnt match + {"Key": "key2"}, # Single key removal + {"Key": "key3", "Value": "value3"}, # Exact match removal + {"Key": "key4"}, # Partial match removal + ], ) resp = client.list_tags_for_certificate(CertificateArn=arn) - tags = {item['Key']: item.get('Value', '__NONE__') for item in resp['Tags']} + tags = {item["Key"]: item.get("Value", "__NONE__") for item in resp["Tags"]} - for key in ('key2', 'key3', 'key4'): + for key in ("key2", "key3", "key4"): tags.should_not.contain(key) - tags.should.contain('key1') + tags.should.contain("key1") @mock_acm def test_remove_tags_from_invalid_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") try: client.remove_tags_from_certificate( CertificateArn=BAD_ARN, - Tags=[ - {'Key': 'key1', 'Value': 'value1'}, - {'Key': 'key2'}, - ] + Tags=[{"Key": "key1", "Value": "value1"}, {"Key": "key2"}], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") @mock_acm def test_resend_validation_email(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) client.resend_validation_email( - CertificateArn=arn, - Domain='*.moto.com', - ValidationDomain='NOTUSEDYET' + CertificateArn=arn, Domain="*.moto.com", ValidationDomain="NOTUSEDYET" ) # Returns nothing, boto would raise Exceptions otherwise @mock_acm def test_resend_validation_email_invalid(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") arn = _import_cert(client) try: client.resend_validation_email( CertificateArn=arn, - Domain='no-match.moto.com', - ValidationDomain='NOTUSEDYET' + Domain="no-match.moto.com", + ValidationDomain="NOTUSEDYET", ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidDomainValidationOptionsException') + err.response["Error"]["Code"].should.equal( + "InvalidDomainValidationOptionsException" + ) else: - raise RuntimeError('Should of raised InvalidDomainValidationOptionsException') + raise RuntimeError("Should of raised InvalidDomainValidationOptionsException") try: client.resend_validation_email( CertificateArn=BAD_ARN, - Domain='no-match.moto.com', - ValidationDomain='NOTUSEDYET' + Domain="no-match.moto.com", + ValidationDomain="NOTUSEDYET", ) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should of raised ResourceNotFoundException') + raise RuntimeError("Should of raised ResourceNotFoundException") @mock_acm def test_request_certificate(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") token = str(uuid.uuid4()) resp = client.request_certificate( - DomainName='google.com', + DomainName="google.com", IdempotencyToken=token, - SubjectAlternativeNames=['google.com', 'www.google.com', 'mail.google.com'], + SubjectAlternativeNames=["google.com", "www.google.com", "mail.google.com"], ) - resp.should.contain('CertificateArn') - arn = resp['CertificateArn'] + resp.should.contain("CertificateArn") + arn = resp["CertificateArn"] arn.should.match(r"arn:aws:acm:eu-central-1:\d{12}:certificate/") resp = client.request_certificate( - DomainName='google.com', + DomainName="google.com", IdempotencyToken=token, - SubjectAlternativeNames=['google.com', 'www.google.com', 'mail.google.com'], + SubjectAlternativeNames=["google.com", "www.google.com", "mail.google.com"], ) - resp['CertificateArn'].should.equal(arn) + resp["CertificateArn"].should.equal(arn) @mock_acm def test_request_certificate_no_san(): - client = boto3.client('acm', region_name='eu-central-1') + client = boto3.client("acm", region_name="eu-central-1") - resp = client.request_certificate( - DomainName='google.com' - ) - resp.should.contain('CertificateArn') + resp = client.request_certificate(DomainName="google.com") + resp.should.contain("CertificateArn") + + resp2 = client.describe_certificate(CertificateArn=resp["CertificateArn"]) + resp2.should.contain("Certificate") - resp2 = client.describe_certificate( - CertificateArn=resp['CertificateArn'] - ) - resp2.should.contain('Certificate') # # Also tests the SAN code # # requires Pull: https://github.com/spulec/freezegun/pull/210 diff --git a/tests/test_apigateway/test_apigateway.py b/tests/test_apigateway/test_apigateway.py index 42e9d7254..7016ae867 100644 --- a/tests/test_apigateway/test_apigateway.py +++ b/tests/test_apigateway/test_apigateway.py @@ -14,644 +14,503 @@ from moto import mock_apigateway, settings @freeze_time("2015-01-01") @mock_apigateway def test_create_and_get_rest_api(): - client = boto3.client('apigateway', region_name='us-west-2') + client = boto3.client("apigateway", region_name="us-west-2") - response = client.create_rest_api( - name='my_api', - description='this is my api', + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + + response = client.get_rest_api(restApiId=api_id) + + response.pop("ResponseMetadata") + response.pop("createdDate") + response.should.equal( + {"id": api_id, "name": "my_api", "description": "this is my api"} ) - api_id = response['id'] - - response = client.get_rest_api( - restApiId=api_id - ) - - response.pop('ResponseMetadata') - response.pop('createdDate') - response.should.equal({ - 'id': api_id, - 'name': 'my_api', - 'description': 'this is my api', - }) @mock_apigateway def test_list_and_delete_apis(): - client = boto3.client('apigateway', region_name='us-west-2') + client = boto3.client("apigateway", region_name="us-west-2") - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] - client.create_rest_api( - name='my_api2', - description='this is my api2', - ) + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] + client.create_rest_api(name="my_api2", description="this is my api2") response = client.get_rest_apis() - len(response['items']).should.equal(2) + len(response["items"]).should.equal(2) - client.delete_rest_api( - restApiId=api_id - ) + client.delete_rest_api(restApiId=api_id) response = client.get_rest_apis() - len(response['items']).should.equal(1) + len(response["items"]).should.equal(1) @mock_apigateway def test_create_resource(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] - root_resource = client.get_resource( - restApiId=api_id, - resourceId=root_id, - ) + root_resource = client.get_resource(restApiId=api_id, resourceId=root_id) # this is hard to match against, so remove it - root_resource['ResponseMetadata'].pop('HTTPHeaders', None) - root_resource['ResponseMetadata'].pop('RetryAttempts', None) - root_resource.should.equal({ - 'path': '/', - 'id': root_id, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'resourceMethods': { - 'GET': {} + root_resource["ResponseMetadata"].pop("HTTPHeaders", None) + root_resource["ResponseMetadata"].pop("RetryAttempts", None) + root_resource.should.equal( + { + "path": "/", + "id": root_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "resourceMethods": {"GET": {}}, } - }) + ) response = client.create_resource( - restApiId=api_id, - parentId=root_id, - pathPart='/users', + restApiId=api_id, parentId=root_id, pathPart="/users" ) - resources = client.get_resources(restApiId=api_id)['items'] + resources = client.get_resources(restApiId=api_id)["items"] len(resources).should.equal(2) - non_root_resource = [ - resource for resource in resources if resource['path'] != '/'][0] + non_root_resource = [resource for resource in resources if resource["path"] != "/"][ + 0 + ] response = client.delete_resource( - restApiId=api_id, - resourceId=non_root_resource['id'] + restApiId=api_id, resourceId=non_root_resource["id"] ) - len(client.get_resources(restApiId=api_id)['items']).should.equal(1) + len(client.get_resources(restApiId=api_id)["items"]).should.equal(1) @mock_apigateway def test_child_resource(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] response = client.create_resource( - restApiId=api_id, - parentId=root_id, - pathPart='users', + restApiId=api_id, parentId=root_id, pathPart="users" ) - users_id = response['id'] + users_id = response["id"] response = client.create_resource( - restApiId=api_id, - parentId=users_id, - pathPart='tags', + restApiId=api_id, parentId=users_id, pathPart="tags" ) - tags_id = response['id'] + tags_id = response["id"] - child_resource = client.get_resource( - restApiId=api_id, - resourceId=tags_id, - ) + child_resource = client.get_resource(restApiId=api_id, resourceId=tags_id) # this is hard to match against, so remove it - child_resource['ResponseMetadata'].pop('HTTPHeaders', None) - child_resource['ResponseMetadata'].pop('RetryAttempts', None) - child_resource.should.equal({ - 'path': '/users/tags', - 'pathPart': 'tags', - 'parentId': users_id, - 'id': tags_id, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'resourceMethods': {'GET': {}}, - }) + child_resource["ResponseMetadata"].pop("HTTPHeaders", None) + child_resource["ResponseMetadata"].pop("RetryAttempts", None) + child_resource.should.equal( + { + "path": "/users/tags", + "pathPart": "tags", + "parentId": users_id, + "id": tags_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "resourceMethods": {"GET": {}}, + } + ) @mock_apigateway def test_create_method(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] client.put_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - authorizationType='none', + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="none" ) - response = client.get_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET' - ) + response = client.get_method(restApiId=api_id, resourceId=root_id, httpMethod="GET") # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'httpMethod': 'GET', - 'authorizationType': 'none', - 'ResponseMetadata': {'HTTPStatusCode': 200} - }) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "httpMethod": "GET", + "authorizationType": "none", + "ResponseMetadata": {"HTTPStatusCode": 200}, + } + ) @mock_apigateway def test_create_method_response(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] client.put_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - authorizationType='none', + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="none" ) - response = client.get_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET' - ) + response = client.get_method(restApiId=api_id, resourceId=root_id, httpMethod="GET") response = client.put_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'statusCode': '200' - }) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + {"ResponseMetadata": {"HTTPStatusCode": 200}, "statusCode": "200"} + ) response = client.get_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'statusCode': '200' - }) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + {"ResponseMetadata": {"HTTPStatusCode": 200}, "statusCode": "200"} + ) response = client.delete_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({'ResponseMetadata': {'HTTPStatusCode': 200}}) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal({"ResponseMetadata": {"HTTPStatusCode": 200}}) @mock_apigateway def test_integrations(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] client.put_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - authorizationType='none', + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="none" ) client.put_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) response = client.put_integration( restApiId=api_id, resourceId=root_id, - httpMethod='GET', - type='HTTP', - uri='http://httpbin.org/robots.txt', + httpMethod="GET", + type="HTTP", + uri="http://httpbin.org/robots.txt", ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'httpMethod': 'GET', - 'integrationResponses': { - '200': { - 'responseTemplates': { - 'application/json': None - }, - 'statusCode': 200 - } - }, - 'type': 'HTTP', - 'uri': 'http://httpbin.org/robots.txt' - }) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "httpMethod": "GET", + "integrationResponses": { + "200": { + "responseTemplates": {"application/json": None}, + "statusCode": 200, + } + }, + "type": "HTTP", + "uri": "http://httpbin.org/robots.txt", + } + ) response = client.get_integration( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET' + restApiId=api_id, resourceId=root_id, httpMethod="GET" ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'httpMethod': 'GET', - 'integrationResponses': { - '200': { - 'responseTemplates': { - 'application/json': None - }, - 'statusCode': 200 - } - }, - 'type': 'HTTP', - 'uri': 'http://httpbin.org/robots.txt' - }) - - response = client.get_resource( - restApiId=api_id, - resourceId=root_id, + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "httpMethod": "GET", + "integrationResponses": { + "200": { + "responseTemplates": {"application/json": None}, + "statusCode": 200, + } + }, + "type": "HTTP", + "uri": "http://httpbin.org/robots.txt", + } ) + + response = client.get_resource(restApiId=api_id, resourceId=root_id) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response['resourceMethods']['GET']['methodIntegration'].should.equal({ - 'httpMethod': 'GET', - 'integrationResponses': { - '200': { - 'responseTemplates': { - 'application/json': None - }, - 'statusCode': 200 - } - }, - 'type': 'HTTP', - 'uri': 'http://httpbin.org/robots.txt' - }) - - client.delete_integration( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET' + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response["resourceMethods"]["GET"]["methodIntegration"].should.equal( + { + "httpMethod": "GET", + "integrationResponses": { + "200": { + "responseTemplates": {"application/json": None}, + "statusCode": 200, + } + }, + "type": "HTTP", + "uri": "http://httpbin.org/robots.txt", + } ) - response = client.get_resource( - restApiId=api_id, - resourceId=root_id, - ) - response['resourceMethods']['GET'].shouldnt.contain("methodIntegration") + client.delete_integration(restApiId=api_id, resourceId=root_id, httpMethod="GET") + + response = client.get_resource(restApiId=api_id, resourceId=root_id) + response["resourceMethods"]["GET"].shouldnt.contain("methodIntegration") # Create a new integration with a requestTemplates config client.put_method( restApiId=api_id, resourceId=root_id, - httpMethod='POST', - authorizationType='none', + httpMethod="POST", + authorizationType="none", ) templates = { # example based on # http://docs.aws.amazon.com/apigateway/latest/developerguide/api-as-kinesis-proxy-export-swagger-with-extensions.html - 'application/json': "{\n \"StreamName\": \"$input.params('stream-name')\",\n \"Records\": []\n}" + "application/json": '{\n "StreamName": "$input.params(\'stream-name\')",\n "Records": []\n}' } - test_uri = 'http://example.com/foobar.txt' + test_uri = "http://example.com/foobar.txt" response = client.put_integration( restApiId=api_id, resourceId=root_id, - httpMethod='POST', - type='HTTP', + httpMethod="POST", + type="HTTP", uri=test_uri, - requestTemplates=templates + requestTemplates=templates, ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response['ResponseMetadata'].should.equal({'HTTPStatusCode': 200}) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response["ResponseMetadata"].should.equal({"HTTPStatusCode": 200}) response = client.get_integration( - restApiId=api_id, - resourceId=root_id, - httpMethod='POST' + restApiId=api_id, resourceId=root_id, httpMethod="POST" ) - response['uri'].should.equal(test_uri) - response['requestTemplates'].should.equal(templates) + response["uri"].should.equal(test_uri) + response["requestTemplates"].should.equal(templates) @mock_apigateway def test_integration_response(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] client.put_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - authorizationType='none', + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="none" ) client.put_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) response = client.put_integration( restApiId=api_id, resourceId=root_id, - httpMethod='GET', - type='HTTP', - uri='http://httpbin.org/robots.txt', + httpMethod="GET", + type="HTTP", + uri="http://httpbin.org/robots.txt", ) response = client.put_integration_response( restApiId=api_id, resourceId=root_id, - httpMethod='GET', - statusCode='200', - selectionPattern='foobar', + httpMethod="GET", + statusCode="200", + selectionPattern="foobar", ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'statusCode': '200', - 'selectionPattern': 'foobar', - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'responseTemplates': { - 'application/json': None + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "statusCode": "200", + "selectionPattern": "foobar", + "ResponseMetadata": {"HTTPStatusCode": 200}, + "responseTemplates": {"application/json": None}, } - }) + ) response = client.get_integration_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'statusCode': '200', - 'selectionPattern': 'foobar', - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'responseTemplates': { - 'application/json': None + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "statusCode": "200", + "selectionPattern": "foobar", + "ResponseMetadata": {"HTTPStatusCode": 200}, + "responseTemplates": {"application/json": None}, } - }) + ) - response = client.get_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - ) + response = client.get_method(restApiId=api_id, resourceId=root_id, httpMethod="GET") # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response['methodIntegration']['integrationResponses'].should.equal({ - '200': { - 'responseTemplates': { - 'application/json': None - }, - 'selectionPattern': 'foobar', - 'statusCode': '200' + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response["methodIntegration"]["integrationResponses"].should.equal( + { + "200": { + "responseTemplates": {"application/json": None}, + "selectionPattern": "foobar", + "statusCode": "200", + } } - }) + ) response = client.delete_integration_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) - response = client.get_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - ) - response['methodIntegration']['integrationResponses'].should.equal({}) + response = client.get_method(restApiId=api_id, resourceId=root_id, httpMethod="GET") + response["methodIntegration"]["integrationResponses"].should.equal({}) @mock_apigateway def test_update_stage_configuration(): - client = boto3.client('apigateway', region_name='us-west-2') - stage_name = 'staging' - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + stage_name = "staging" + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] response = client.create_deployment( - restApiId=api_id, - stageName=stage_name, - description="1.0.1" + restApiId=api_id, stageName=stage_name, description="1.0.1" ) - deployment_id = response['id'] + deployment_id = response["id"] - response = client.get_deployment( - restApiId=api_id, - deploymentId=deployment_id, - ) + response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id) # createdDate is hard to match against, remove it - response.pop('createdDate', None) + response.pop("createdDate", None) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'id': deployment_id, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': '1.0.1' - }) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "id": deployment_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "1.0.1", + } + ) response = client.create_deployment( + restApiId=api_id, stageName=stage_name, description="1.0.2" + ) + deployment_id2 = response["id"] + + stage = client.get_stage(restApiId=api_id, stageName=stage_name) + stage["stageName"].should.equal(stage_name) + stage["deploymentId"].should.equal(deployment_id2) + stage.shouldnt.have.key("cacheClusterSize") + + client.update_stage( restApiId=api_id, stageName=stage_name, - description="1.0.2" + patchOperations=[ + {"op": "replace", "path": "/cacheClusterEnabled", "value": "True"} + ], ) - deployment_id2 = response['id'] - - stage = client.get_stage( - restApiId=api_id, - stageName=stage_name - ) - stage['stageName'].should.equal(stage_name) - stage['deploymentId'].should.equal(deployment_id2) - stage.shouldnt.have.key('cacheClusterSize') - - client.update_stage(restApiId=api_id, stageName=stage_name, - patchOperations=[ - { - "op": "replace", - "path": "/cacheClusterEnabled", - "value": "True" - } - ]) - - stage = client.get_stage( - restApiId=api_id, - stageName=stage_name - ) - - stage.should.have.key('cacheClusterSize').which.should.equal("0.5") - - client.update_stage(restApiId=api_id, stageName=stage_name, - patchOperations=[ - { - "op": "replace", - "path": "/cacheClusterSize", - "value": "1.6" - } - ]) - - stage = client.get_stage( - restApiId=api_id, - stageName=stage_name - ) - - stage.should.have.key('cacheClusterSize').which.should.equal("1.6") - - client.update_stage(restApiId=api_id, stageName=stage_name, - patchOperations=[ - { - "op": "replace", - "path": "/deploymentId", - "value": deployment_id - }, - { - "op": "replace", - "path": "/variables/environment", - "value": "dev" - }, - { - "op": "replace", - "path": "/variables/region", - "value": "eu-west-1" - }, - { - "op": "replace", - "path": "/*/*/caching/dataEncrypted", - "value": "True" - }, - { - "op": "replace", - "path": "/cacheClusterEnabled", - "value": "True" - }, - { - "op": "replace", - "path": "/description", - "value": "stage description update" - }, - { - "op": "replace", - "path": "/cacheClusterSize", - "value": "1.6" - } - ]) - - client.update_stage(restApiId=api_id, stageName=stage_name, - patchOperations=[ - { - "op": "remove", - "path": "/variables/region", - "value": "eu-west-1" - } - ]) stage = client.get_stage(restApiId=api_id, stageName=stage_name) - stage['description'].should.match('stage description update') - stage['cacheClusterSize'].should.equal("1.6") - stage['variables']['environment'].should.match('dev') - stage['variables'].should_not.have.key('region') - stage['cacheClusterEnabled'].should.be.true - stage['deploymentId'].should.match(deployment_id) - stage['methodSettings'].should.have.key('*/*') - stage['methodSettings'][ - '*/*'].should.have.key('cacheDataEncrypted').which.should.be.true + stage.should.have.key("cacheClusterSize").which.should.equal("0.5") + + client.update_stage( + restApiId=api_id, + stageName=stage_name, + patchOperations=[ + {"op": "replace", "path": "/cacheClusterSize", "value": "1.6"} + ], + ) + + stage = client.get_stage(restApiId=api_id, stageName=stage_name) + + stage.should.have.key("cacheClusterSize").which.should.equal("1.6") + + client.update_stage( + restApiId=api_id, + stageName=stage_name, + patchOperations=[ + {"op": "replace", "path": "/deploymentId", "value": deployment_id}, + {"op": "replace", "path": "/variables/environment", "value": "dev"}, + {"op": "replace", "path": "/variables/region", "value": "eu-west-1"}, + {"op": "replace", "path": "/*/*/caching/dataEncrypted", "value": "True"}, + {"op": "replace", "path": "/cacheClusterEnabled", "value": "True"}, + { + "op": "replace", + "path": "/description", + "value": "stage description update", + }, + {"op": "replace", "path": "/cacheClusterSize", "value": "1.6"}, + ], + ) + + client.update_stage( + restApiId=api_id, + stageName=stage_name, + patchOperations=[ + {"op": "remove", "path": "/variables/region", "value": "eu-west-1"} + ], + ) + + stage = client.get_stage(restApiId=api_id, stageName=stage_name) + + stage["description"].should.match("stage description update") + stage["cacheClusterSize"].should.equal("1.6") + stage["variables"]["environment"].should.match("dev") + stage["variables"].should_not.have.key("region") + stage["cacheClusterEnabled"].should.be.true + stage["deploymentId"].should.match(deployment_id) + stage["methodSettings"].should.have.key("*/*") + stage["methodSettings"]["*/*"].should.have.key( + "cacheDataEncrypted" + ).which.should.be.true try: - client.update_stage(restApiId=api_id, stageName=stage_name, - patchOperations=[ - { - "op": "add", - "path": "/notasetting", - "value": "eu-west-1" - } - ]) + client.update_stage( + restApiId=api_id, + stageName=stage_name, + patchOperations=[ + {"op": "add", "path": "/notasetting", "value": "eu-west-1"} + ], + ) assert False.should.be.ok # Fail, should not be here except Exception: assert True.should.be.ok @@ -659,307 +518,270 @@ def test_update_stage_configuration(): @mock_apigateway def test_non_existent_stage(): - client = boto3.client('apigateway', region_name='us-west-2') - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] - client.get_stage.when.called_with( - restApiId=api_id, stageName='xxx').should.throw(ClientError) + client.get_stage.when.called_with(restApiId=api_id, stageName="xxx").should.throw( + ClientError + ) @mock_apigateway def test_create_stage(): - client = boto3.client('apigateway', region_name='us-west-2') - stage_name = 'staging' - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + stage_name = "staging" + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] - response = client.create_deployment( - restApiId=api_id, - stageName=stage_name, - ) - deployment_id = response['id'] + response = client.create_deployment(restApiId=api_id, stageName=stage_name) + deployment_id = response["id"] - response = client.get_deployment( - restApiId=api_id, - deploymentId=deployment_id, - ) + response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id) # createdDate is hard to match against, remove it - response.pop('createdDate', None) + response.pop("createdDate", None) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'id': deployment_id, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': '' - }) - - response = client.create_deployment( - restApiId=api_id, - stageName=stage_name, + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "id": deployment_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "", + } ) - deployment_id2 = response['id'] + response = client.create_deployment(restApiId=api_id, stageName=stage_name) - response = client.get_deployments( - restApiId=api_id, - ) + deployment_id2 = response["id"] + + response = client.get_deployments(restApiId=api_id) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) - response['items'][0].pop('createdDate') - response['items'][1].pop('createdDate') - response['items'][0]['id'].should.match( - r"{0}|{1}".format(deployment_id2, deployment_id)) - response['items'][1]['id'].should.match( - r"{0}|{1}".format(deployment_id2, deployment_id)) + response["items"][0].pop("createdDate") + response["items"][1].pop("createdDate") + response["items"][0]["id"].should.match( + r"{0}|{1}".format(deployment_id2, deployment_id) + ) + response["items"][1]["id"].should.match( + r"{0}|{1}".format(deployment_id2, deployment_id) + ) - new_stage_name = 'current' + new_stage_name = "current" response = client.create_stage( - restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id2) - - # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - - response.should.equal({ - 'stageName': new_stage_name, - 'deploymentId': deployment_id2, - 'methodSettings': {}, - 'variables': {}, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': '', - 'cacheClusterEnabled': False - }) - - stage = client.get_stage( - restApiId=api_id, - stageName=new_stage_name - ) - stage['stageName'].should.equal(new_stage_name) - stage['deploymentId'].should.equal(deployment_id2) - - new_stage_name_with_vars = 'stage_with_vars' - response = client.create_stage(restApiId=api_id, stageName=new_stage_name_with_vars, deploymentId=deployment_id2, variables={ - "env": "dev" - }) - - # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - - response.should.equal({ - 'stageName': new_stage_name_with_vars, - 'deploymentId': deployment_id2, - 'methodSettings': {}, - 'variables': {"env": "dev"}, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': '', - 'cacheClusterEnabled': False - }) - - stage = client.get_stage( - restApiId=api_id, - stageName=new_stage_name_with_vars - ) - stage['stageName'].should.equal(new_stage_name_with_vars) - stage['deploymentId'].should.equal(deployment_id2) - stage['variables'].should.have.key('env').which.should.match("dev") - - new_stage_name = 'stage_with_vars_and_cache_settings' - response = client.create_stage(restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id2, variables={ - "env": "dev" - }, cacheClusterEnabled=True, description="hello moto") - - # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - - response.should.equal({ - 'stageName': new_stage_name, - 'deploymentId': deployment_id2, - 'methodSettings': {}, - 'variables': {"env": "dev"}, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': 'hello moto', - 'cacheClusterEnabled': True, - 'cacheClusterSize': "0.5" - }) - - stage = client.get_stage( - restApiId=api_id, - stageName=new_stage_name + restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id2 ) - stage['cacheClusterSize'].should.equal("0.5") + # this is hard to match against, so remove it + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) - new_stage_name = 'stage_with_vars_and_cache_settings_and_size' - response = client.create_stage(restApiId=api_id, stageName=new_stage_name, deploymentId=deployment_id2, variables={ - "env": "dev" - }, cacheClusterEnabled=True, cacheClusterSize="1.6", description="hello moto") + response.should.equal( + { + "stageName": new_stage_name, + "deploymentId": deployment_id2, + "methodSettings": {}, + "variables": {}, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "", + "cacheClusterEnabled": False, + } + ) + + stage = client.get_stage(restApiId=api_id, stageName=new_stage_name) + stage["stageName"].should.equal(new_stage_name) + stage["deploymentId"].should.equal(deployment_id2) + + new_stage_name_with_vars = "stage_with_vars" + response = client.create_stage( + restApiId=api_id, + stageName=new_stage_name_with_vars, + deploymentId=deployment_id2, + variables={"env": "dev"}, + ) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) - response.should.equal({ - 'stageName': new_stage_name, - 'deploymentId': deployment_id2, - 'methodSettings': {}, - 'variables': {"env": "dev"}, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': 'hello moto', - 'cacheClusterEnabled': True, - 'cacheClusterSize': "1.6" - }) - - stage = client.get_stage( - restApiId=api_id, - stageName=new_stage_name + response.should.equal( + { + "stageName": new_stage_name_with_vars, + "deploymentId": deployment_id2, + "methodSettings": {}, + "variables": {"env": "dev"}, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "", + "cacheClusterEnabled": False, + } ) - stage['stageName'].should.equal(new_stage_name) - stage['deploymentId'].should.equal(deployment_id2) - stage['variables'].should.have.key('env').which.should.match("dev") - stage['cacheClusterSize'].should.equal("1.6") + + stage = client.get_stage(restApiId=api_id, stageName=new_stage_name_with_vars) + stage["stageName"].should.equal(new_stage_name_with_vars) + stage["deploymentId"].should.equal(deployment_id2) + stage["variables"].should.have.key("env").which.should.match("dev") + + new_stage_name = "stage_with_vars_and_cache_settings" + response = client.create_stage( + restApiId=api_id, + stageName=new_stage_name, + deploymentId=deployment_id2, + variables={"env": "dev"}, + cacheClusterEnabled=True, + description="hello moto", + ) + + # this is hard to match against, so remove it + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + + response.should.equal( + { + "stageName": new_stage_name, + "deploymentId": deployment_id2, + "methodSettings": {}, + "variables": {"env": "dev"}, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "hello moto", + "cacheClusterEnabled": True, + "cacheClusterSize": "0.5", + } + ) + + stage = client.get_stage(restApiId=api_id, stageName=new_stage_name) + + stage["cacheClusterSize"].should.equal("0.5") + + new_stage_name = "stage_with_vars_and_cache_settings_and_size" + response = client.create_stage( + restApiId=api_id, + stageName=new_stage_name, + deploymentId=deployment_id2, + variables={"env": "dev"}, + cacheClusterEnabled=True, + cacheClusterSize="1.6", + description="hello moto", + ) + + # this is hard to match against, so remove it + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + + response.should.equal( + { + "stageName": new_stage_name, + "deploymentId": deployment_id2, + "methodSettings": {}, + "variables": {"env": "dev"}, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "hello moto", + "cacheClusterEnabled": True, + "cacheClusterSize": "1.6", + } + ) + + stage = client.get_stage(restApiId=api_id, stageName=new_stage_name) + stage["stageName"].should.equal(new_stage_name) + stage["deploymentId"].should.equal(deployment_id2) + stage["variables"].should.have.key("env").which.should.match("dev") + stage["cacheClusterSize"].should.equal("1.6") @mock_apigateway def test_deployment(): - client = boto3.client('apigateway', region_name='us-west-2') - stage_name = 'staging' - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + client = boto3.client("apigateway", region_name="us-west-2") + stage_name = "staging" + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] - response = client.create_deployment( - restApiId=api_id, - stageName=stage_name, - ) - deployment_id = response['id'] + response = client.create_deployment(restApiId=api_id, stageName=stage_name) + deployment_id = response["id"] - response = client.get_deployment( - restApiId=api_id, - deploymentId=deployment_id, - ) + response = client.get_deployment(restApiId=api_id, deploymentId=deployment_id) # createdDate is hard to match against, remove it - response.pop('createdDate', None) + response.pop("createdDate", None) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'id': deployment_id, - 'ResponseMetadata': {'HTTPStatusCode': 200}, - 'description': '' - }) - - response = client.get_deployments( - restApiId=api_id, + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "id": deployment_id, + "ResponseMetadata": {"HTTPStatusCode": 200}, + "description": "", + } ) - response['items'][0].pop('createdDate') - response['items'].should.equal([ - {'id': deployment_id, 'description': ''} - ]) + response = client.get_deployments(restApiId=api_id) - response = client.delete_deployment( - restApiId=api_id, - deploymentId=deployment_id, - ) + response["items"][0].pop("createdDate") + response["items"].should.equal([{"id": deployment_id, "description": ""}]) - response = client.get_deployments( - restApiId=api_id, - ) - len(response['items']).should.equal(0) + response = client.delete_deployment(restApiId=api_id, deploymentId=deployment_id) + + response = client.get_deployments(restApiId=api_id) + len(response["items"]).should.equal(0) # test deployment stages - stage = client.get_stage( - restApiId=api_id, - stageName=stage_name - ) - stage['stageName'].should.equal(stage_name) - stage['deploymentId'].should.equal(deployment_id) + stage = client.get_stage(restApiId=api_id, stageName=stage_name) + stage["stageName"].should.equal(stage_name) + stage["deploymentId"].should.equal(deployment_id) stage = client.update_stage( restApiId=api_id, stageName=stage_name, patchOperations=[ - { - 'op': 'replace', - 'path': '/description', - 'value': '_new_description_' - }, - ] + {"op": "replace", "path": "/description", "value": "_new_description_"} + ], ) - stage = client.get_stage( - restApiId=api_id, - stageName=stage_name - ) - stage['stageName'].should.equal(stage_name) - stage['deploymentId'].should.equal(deployment_id) - stage['description'].should.equal('_new_description_') + stage = client.get_stage(restApiId=api_id, stageName=stage_name) + stage["stageName"].should.equal(stage_name) + stage["deploymentId"].should.equal(deployment_id) + stage["description"].should.equal("_new_description_") @mock_apigateway def test_http_proxying_integration(): responses.add( - responses.GET, "http://httpbin.org/robots.txt", body='a fake response' + responses.GET, "http://httpbin.org/robots.txt", body="a fake response" ) - region_name = 'us-west-2' - client = boto3.client('apigateway', region_name=region_name) - response = client.create_rest_api( - name='my_api', - description='this is my api', - ) - api_id = response['id'] + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) + response = client.create_rest_api(name="my_api", description="this is my api") + api_id = response["id"] resources = client.get_resources(restApiId=api_id) - root_id = [resource for resource in resources[ - 'items'] if resource['path'] == '/'][0]['id'] + root_id = [resource for resource in resources["items"] if resource["path"] == "/"][ + 0 + ]["id"] client.put_method( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - authorizationType='none', + restApiId=api_id, resourceId=root_id, httpMethod="GET", authorizationType="none" ) client.put_method_response( - restApiId=api_id, - resourceId=root_id, - httpMethod='GET', - statusCode='200', + restApiId=api_id, resourceId=root_id, httpMethod="GET", statusCode="200" ) response = client.put_integration( restApiId=api_id, resourceId=root_id, - httpMethod='GET', - type='HTTP', - uri='http://httpbin.org/robots.txt', + httpMethod="GET", + type="HTTP", + uri="http://httpbin.org/robots.txt", ) - stage_name = 'staging' - client.create_deployment( - restApiId=api_id, - stageName=stage_name, - ) + stage_name = "staging" + client.create_deployment(restApiId=api_id, stageName=stage_name) deploy_url = "https://{api_id}.execute-api.{region_name}.amazonaws.com/{stage_name}".format( - api_id=api_id, region_name=region_name, stage_name=stage_name) + api_id=api_id, region_name=region_name, stage_name=stage_name + ) if not settings.TEST_SERVER_MODE: requests.get(deploy_url).content.should.equal(b"a fake response") @@ -967,78 +789,79 @@ def test_http_proxying_integration(): @mock_apigateway def test_api_keys(): - region_name = 'us-west-2' - client = boto3.client('apigateway', region_name=region_name) + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) response = client.get_api_keys() - len(response['items']).should.equal(0) + len(response["items"]).should.equal(0) - apikey_value = '12345' - apikey_name = 'TESTKEY1' - payload = {'value': apikey_value, 'name': apikey_name} + apikey_value = "12345" + apikey_name = "TESTKEY1" + payload = {"value": apikey_value, "name": apikey_name} response = client.create_api_key(**payload) - apikey = client.get_api_key(apiKey=response['id']) - apikey['name'].should.equal(apikey_name) - apikey['value'].should.equal(apikey_value) + apikey = client.get_api_key(apiKey=response["id"]) + apikey["name"].should.equal(apikey_name) + apikey["value"].should.equal(apikey_value) - apikey_name = 'TESTKEY2' - payload = {'name': apikey_name, 'tags': {'tag1': 'test_tag1', 'tag2': '1'}} + apikey_name = "TESTKEY2" + payload = {"name": apikey_name, "tags": {"tag1": "test_tag1", "tag2": "1"}} response = client.create_api_key(**payload) - apikey_id = response['id'] + apikey_id = response["id"] apikey = client.get_api_key(apiKey=apikey_id) - apikey['name'].should.equal(apikey_name) - apikey['tags']['tag1'].should.equal('test_tag1') - apikey['tags']['tag2'].should.equal('1') - len(apikey['value']).should.equal(40) + apikey["name"].should.equal(apikey_name) + apikey["tags"]["tag1"].should.equal("test_tag1") + apikey["tags"]["tag2"].should.equal("1") + len(apikey["value"]).should.equal(40) - apikey_name = 'TESTKEY3' - payload = {'name': apikey_name } + apikey_name = "TESTKEY3" + payload = {"name": apikey_name} response = client.create_api_key(**payload) - apikey_id = response['id'] + apikey_id = response["id"] patch_operations = [ - {'op': 'replace', 'path': '/name', 'value': 'TESTKEY3_CHANGE'}, - {'op': 'replace', 'path': '/customerId', 'value': '12345'}, - {'op': 'replace', 'path': '/description', 'value': 'APIKEY UPDATE TEST'}, - {'op': 'replace', 'path': '/enabled', 'value': 'false'}, + {"op": "replace", "path": "/name", "value": "TESTKEY3_CHANGE"}, + {"op": "replace", "path": "/customerId", "value": "12345"}, + {"op": "replace", "path": "/description", "value": "APIKEY UPDATE TEST"}, + {"op": "replace", "path": "/enabled", "value": "false"}, ] response = client.update_api_key(apiKey=apikey_id, patchOperations=patch_operations) - response['name'].should.equal('TESTKEY3_CHANGE') - response['customerId'].should.equal('12345') - response['description'].should.equal('APIKEY UPDATE TEST') - response['enabled'].should.equal(False) + response["name"].should.equal("TESTKEY3_CHANGE") + response["customerId"].should.equal("12345") + response["description"].should.equal("APIKEY UPDATE TEST") + response["enabled"].should.equal(False) response = client.get_api_keys() - len(response['items']).should.equal(3) + len(response["items"]).should.equal(3) client.delete_api_key(apiKey=apikey_id) response = client.get_api_keys() - len(response['items']).should.equal(2) + len(response["items"]).should.equal(2) + @mock_apigateway def test_usage_plans(): - region_name = 'us-west-2' - client = boto3.client('apigateway', region_name=region_name) + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) response = client.get_usage_plans() - len(response['items']).should.equal(0) + len(response["items"]).should.equal(0) - usage_plan_name = 'TEST-PLAN' - payload = {'name': usage_plan_name} + usage_plan_name = "TEST-PLAN" + payload = {"name": usage_plan_name} response = client.create_usage_plan(**payload) - usage_plan = client.get_usage_plan(usagePlanId=response['id']) - usage_plan['name'].should.equal(usage_plan_name) - usage_plan['apiStages'].should.equal([]) + usage_plan = client.get_usage_plan(usagePlanId=response["id"]) + usage_plan["name"].should.equal(usage_plan_name) + usage_plan["apiStages"].should.equal([]) payload = { - 'name': 'TEST-PLAN-2', - 'description': 'Description', - 'quota': {'limit': 10, 'period': 'DAY', 'offset': 0}, - 'throttle': {'rateLimit': 2, 'burstLimit': 1}, - 'apiStages': [{'apiId': 'foo', 'stage': 'bar'}], - 'tags': {'tag_key': 'tag_value'}, + "name": "TEST-PLAN-2", + "description": "Description", + "quota": {"limit": 10, "period": "DAY", "offset": 0}, + "throttle": {"rateLimit": 2, "burstLimit": 1}, + "apiStages": [{"apiId": "foo", "stage": "bar"}], + "tags": {"tag_key": "tag_value"}, } response = client.create_usage_plan(**payload) - usage_plan_id = response['id'] + usage_plan_id = response["id"] usage_plan = client.get_usage_plan(usagePlanId=usage_plan_id) # The payload should remain unchanged @@ -1046,100 +869,110 @@ def test_usage_plans(): usage_plan.should.have.key(key).which.should.equal(value) # Status code should be 200 - usage_plan['ResponseMetadata'].should.have.key('HTTPStatusCode').which.should.equal(200) + usage_plan["ResponseMetadata"].should.have.key("HTTPStatusCode").which.should.equal( + 200 + ) # An Id should've been generated - usage_plan.should.have.key('id').which.should_not.be.none + usage_plan.should.have.key("id").which.should_not.be.none response = client.get_usage_plans() - len(response['items']).should.equal(2) + len(response["items"]).should.equal(2) client.delete_usage_plan(usagePlanId=usage_plan_id) response = client.get_usage_plans() - len(response['items']).should.equal(1) + len(response["items"]).should.equal(1) + @mock_apigateway def test_usage_plan_keys(): - region_name = 'us-west-2' - usage_plan_id = 'test_usage_plan_id' - client = boto3.client('apigateway', region_name=region_name) + region_name = "us-west-2" + usage_plan_id = "test_usage_plan_id" + client = boto3.client("apigateway", region_name=region_name) usage_plan_id = "test" # Create an API key so we can use it - key_name = 'test-api-key' + key_name = "test-api-key" response = client.create_api_key(name=key_name) key_id = response["id"] key_value = response["value"] # Get current plan keys (expect none) response = client.get_usage_plan_keys(usagePlanId=usage_plan_id) - len(response['items']).should.equal(0) + len(response["items"]).should.equal(0) # Create usage plan key - key_type = 'API_KEY' - payload = {'usagePlanId': usage_plan_id, 'keyId': key_id, 'keyType': key_type } + key_type = "API_KEY" + payload = {"usagePlanId": usage_plan_id, "keyId": key_id, "keyType": key_type} response = client.create_usage_plan_key(**payload) usage_plan_key_id = response["id"] # Get current plan keys (expect 1) response = client.get_usage_plan_keys(usagePlanId=usage_plan_id) - len(response['items']).should.equal(1) + len(response["items"]).should.equal(1) # Get a single usage plan key and check it matches the created one - usage_plan_key = client.get_usage_plan_key(usagePlanId=usage_plan_id, keyId=usage_plan_key_id) - usage_plan_key['name'].should.equal(key_name) - usage_plan_key['id'].should.equal(key_id) - usage_plan_key['type'].should.equal(key_type) - usage_plan_key['value'].should.equal(key_value) + usage_plan_key = client.get_usage_plan_key( + usagePlanId=usage_plan_id, keyId=usage_plan_key_id + ) + usage_plan_key["name"].should.equal(key_name) + usage_plan_key["id"].should.equal(key_id) + usage_plan_key["type"].should.equal(key_type) + usage_plan_key["value"].should.equal(key_value) # Delete usage plan key client.delete_usage_plan_key(usagePlanId=usage_plan_id, keyId=key_id) # Get current plan keys (expect none) response = client.get_usage_plan_keys(usagePlanId=usage_plan_id) - len(response['items']).should.equal(0) + len(response["items"]).should.equal(0) + @mock_apigateway def test_create_usage_plan_key_non_existent_api_key(): - region_name = 'us-west-2' - usage_plan_id = 'test_usage_plan_id' - client = boto3.client('apigateway', region_name=region_name) + region_name = "us-west-2" + usage_plan_id = "test_usage_plan_id" + client = boto3.client("apigateway", region_name=region_name) usage_plan_id = "test" # Attempt to create a usage plan key for a API key that doesn't exists - payload = {'usagePlanId': usage_plan_id, 'keyId': 'non-existent', 'keyType': 'API_KEY' } + payload = { + "usagePlanId": usage_plan_id, + "keyId": "non-existent", + "keyType": "API_KEY", + } client.create_usage_plan_key.when.called_with(**payload).should.throw(ClientError) @mock_apigateway def test_get_usage_plans_using_key_id(): - region_name = 'us-west-2' - client = boto3.client('apigateway', region_name=region_name) + region_name = "us-west-2" + client = boto3.client("apigateway", region_name=region_name) # Create 2 Usage Plans # one will be attached to an API Key, the other will remain unattached - attached_plan = client.create_usage_plan(name='Attached') - unattached_plan = client.create_usage_plan(name='Unattached') + attached_plan = client.create_usage_plan(name="Attached") + unattached_plan = client.create_usage_plan(name="Unattached") # Create an API key # to attach to the usage plan - key_name = 'test-api-key' + key_name = "test-api-key" response = client.create_api_key(name=key_name) key_id = response["id"] # Create a Usage Plan Key # Attached the Usage Plan and API Key - key_type = 'API_KEY' - payload = {'usagePlanId': attached_plan['id'], 'keyId': key_id, 'keyType': key_type} + key_type = "API_KEY" + payload = {"usagePlanId": attached_plan["id"], "keyId": key_id, "keyType": key_type} response = client.create_usage_plan_key(**payload) # All usage plans should be returned when keyId is not included all_plans = client.get_usage_plans() - len(all_plans['items']).should.equal(2) + len(all_plans["items"]).should.equal(2) # Only the usage plan attached to the given api key are included only_plans_with_key = client.get_usage_plans(keyId=key_id) - len(only_plans_with_key['items']).should.equal(1) - only_plans_with_key['items'][0]['name'].should.equal(attached_plan['name']) - only_plans_with_key['items'][0]['id'].should.equal(attached_plan['id']) + len(only_plans_with_key["items"]).should.equal(1) + only_plans_with_key["items"][0]["name"].should.equal(attached_plan["name"]) + only_plans_with_key["items"][0]["id"].should.equal(attached_plan["id"]) diff --git a/tests/test_apigateway/test_server.py b/tests/test_apigateway/test_server.py index 953d942cc..08b20cc61 100644 --- a/tests/test_apigateway/test_server.py +++ b/tests/test_apigateway/test_server.py @@ -4,88 +4,100 @@ import json import moto.server as server -''' +""" Test the different server responses -''' +""" def test_list_apis(): - backend = server.create_backend_app('apigateway') + backend = server.create_backend_app("apigateway") test_client = backend.test_client() - res = test_client.get('/restapis') + res = test_client.get("/restapis") res.data.should.equal(b'{"item": []}') + def test_usage_plans_apis(): - backend = server.create_backend_app('apigateway') + backend = server.create_backend_app("apigateway") test_client = backend.test_client() # List usage plans (expect empty) - res = test_client.get('/usageplans') + res = test_client.get("/usageplans") json.loads(res.data)["item"].should.have.length_of(0) # Create usage plan - res = test_client.post('/usageplans', data=json.dumps({'name': 'test'})) + res = test_client.post("/usageplans", data=json.dumps({"name": "test"})) created_plan = json.loads(res.data) - created_plan['name'].should.equal('test') + created_plan["name"].should.equal("test") # List usage plans (expect 1 plan) - res = test_client.get('/usageplans') + res = test_client.get("/usageplans") json.loads(res.data)["item"].should.have.length_of(1) # Get single usage plan - res = test_client.get('/usageplans/{0}'.format(created_plan["id"])) + res = test_client.get("/usageplans/{0}".format(created_plan["id"])) fetched_plan = json.loads(res.data) fetched_plan.should.equal(created_plan) # Delete usage plan - res = test_client.delete('/usageplans/{0}'.format(created_plan["id"])) - res.data.should.equal(b'{}') + res = test_client.delete("/usageplans/{0}".format(created_plan["id"])) + res.data.should.equal(b"{}") # List usage plans (expect empty again) - res = test_client.get('/usageplans') + res = test_client.get("/usageplans") json.loads(res.data)["item"].should.have.length_of(0) + def test_usage_plans_keys(): - backend = server.create_backend_app('apigateway') + backend = server.create_backend_app("apigateway") test_client = backend.test_client() - usage_plan_id = 'test_plan_id' + usage_plan_id = "test_plan_id" # Create API key to be used in tests - res = test_client.post('/apikeys', data=json.dumps({'name': 'test'})) + res = test_client.post("/apikeys", data=json.dumps({"name": "test"})) created_api_key = json.loads(res.data) # List usage plans keys (expect empty) - res = test_client.get('/usageplans/{0}/keys'.format(usage_plan_id)) + res = test_client.get("/usageplans/{0}/keys".format(usage_plan_id)) json.loads(res.data)["item"].should.have.length_of(0) # Create usage plan key - res = test_client.post('/usageplans/{0}/keys'.format(usage_plan_id), data=json.dumps({'keyId': created_api_key["id"], 'keyType': 'API_KEY'})) + res = test_client.post( + "/usageplans/{0}/keys".format(usage_plan_id), + data=json.dumps({"keyId": created_api_key["id"], "keyType": "API_KEY"}), + ) created_usage_plan_key = json.loads(res.data) # List usage plans keys (expect 1 key) - res = test_client.get('/usageplans/{0}/keys'.format(usage_plan_id)) + res = test_client.get("/usageplans/{0}/keys".format(usage_plan_id)) json.loads(res.data)["item"].should.have.length_of(1) # Get single usage plan key - res = test_client.get('/usageplans/{0}/keys/{1}'.format(usage_plan_id, created_api_key["id"])) + res = test_client.get( + "/usageplans/{0}/keys/{1}".format(usage_plan_id, created_api_key["id"]) + ) fetched_plan_key = json.loads(res.data) fetched_plan_key.should.equal(created_usage_plan_key) # Delete usage plan key - res = test_client.delete('/usageplans/{0}/keys/{1}'.format(usage_plan_id, created_api_key["id"])) - res.data.should.equal(b'{}') + res = test_client.delete( + "/usageplans/{0}/keys/{1}".format(usage_plan_id, created_api_key["id"]) + ) + res.data.should.equal(b"{}") # List usage plans keys (expect to be empty again) - res = test_client.get('/usageplans/{0}/keys'.format(usage_plan_id)) + res = test_client.get("/usageplans/{0}/keys".format(usage_plan_id)) json.loads(res.data)["item"].should.have.length_of(0) + def test_create_usage_plans_key_non_existent_api_key(): - backend = server.create_backend_app('apigateway') + backend = server.create_backend_app("apigateway") test_client = backend.test_client() - usage_plan_id = 'test_plan_id' + usage_plan_id = "test_plan_id" # Create usage plan key with non-existent api key - res = test_client.post('/usageplans/{0}/keys'.format(usage_plan_id), data=json.dumps({'keyId': 'non-existent', 'keyType': 'API_KEY'})) + res = test_client.post( + "/usageplans/{0}/keys".format(usage_plan_id), + data=json.dumps({"keyId": "non-existent", "keyType": "API_KEY"}), + ) res.status_code.should.equal(404) - diff --git a/tests/test_athena/test_athena.py b/tests/test_athena/test_athena.py index ad6ef908f..d36653910 100644 --- a/tests/test_athena/test_athena.py +++ b/tests/test_athena/test_athena.py @@ -11,19 +11,19 @@ from moto import mock_athena @mock_athena def test_create_work_group(): - client = boto3.client('athena', region_name='us-east-1') + client = boto3.client("athena", region_name="us-east-1") response = client.create_work_group( - Name='athena_workgroup', - Description='Test work group', + Name="athena_workgroup", + Description="Test work group", Configuration={ - 'ResultConfiguration': { - 'OutputLocation': 's3://bucket-name/prefix/', - 'EncryptionConfiguration': { - 'EncryptionOption': 'SSE_KMS', - 'KmsKey': 'aws:arn:kms:1233456789:us-east-1:key/number-1', + "ResultConfiguration": { + "OutputLocation": "s3://bucket-name/prefix/", + "EncryptionConfiguration": { + "EncryptionOption": "SSE_KMS", + "KmsKey": "aws:arn:kms:1233456789:us-east-1:key/number-1", }, - }, + } }, Tags=[], ) @@ -31,29 +31,29 @@ def test_create_work_group(): try: # The second time should throw an error response = client.create_work_group( - Name='athena_workgroup', - Description='duplicate', + Name="athena_workgroup", + Description="duplicate", Configuration={ - 'ResultConfiguration': { - 'OutputLocation': 's3://bucket-name/prefix/', - 'EncryptionConfiguration': { - 'EncryptionOption': 'SSE_KMS', - 'KmsKey': 'aws:arn:kms:1233456789:us-east-1:key/number-1', + "ResultConfiguration": { + "OutputLocation": "s3://bucket-name/prefix/", + "EncryptionConfiguration": { + "EncryptionOption": "SSE_KMS", + "KmsKey": "aws:arn:kms:1233456789:us-east-1:key/number-1", }, - }, + } }, ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidRequestException') - err.response['Error']['Message'].should.equal('WorkGroup already exists') + err.response["Error"]["Code"].should.equal("InvalidRequestException") + err.response["Error"]["Message"].should.equal("WorkGroup already exists") else: - raise RuntimeError('Should have raised ResourceNotFoundException') + raise RuntimeError("Should have raised ResourceNotFoundException") # Then test the work group appears in the work group list response = client.list_work_groups() - response['WorkGroups'].should.have.length_of(1) - work_group = response['WorkGroups'][0] - work_group['Name'].should.equal('athena_workgroup') - work_group['Description'].should.equal('Test work group') - work_group['State'].should.equal('ENABLED') + response["WorkGroups"].should.have.length_of(1) + work_group = response["WorkGroups"][0] + work_group["Name"].should.equal("athena_workgroup") + work_group["Description"].should.equal("Test work group") + work_group["State"].should.equal("ENABLED") diff --git a/tests/test_autoscaling/test_autoscaling.py b/tests/test_autoscaling/test_autoscaling.py index 2df7bf30f..c46bc7219 100644 --- a/tests/test_autoscaling/test_autoscaling.py +++ b/tests/test_autoscaling/test_autoscaling.py @@ -10,31 +10,39 @@ import sure # noqa from botocore.exceptions import ClientError from nose.tools import assert_raises -from moto import mock_autoscaling, mock_ec2_deprecated, mock_elb_deprecated, mock_elb, mock_autoscaling_deprecated, mock_ec2 +from moto import ( + mock_autoscaling, + mock_ec2_deprecated, + mock_elb_deprecated, + mock_elb, + mock_autoscaling_deprecated, + mock_ec2, +) from tests.helpers import requires_boto_gte -from utils import setup_networking, setup_networking_deprecated, setup_instance_with_networking +from utils import ( + setup_networking, + setup_networking_deprecated, + setup_instance_with_networking, +) @mock_autoscaling_deprecated @mock_elb_deprecated def test_create_autoscaling_group(): mocked_networking = setup_networking_deprecated() - elb_conn = boto.ec2.elb.connect_to_region('us-east-1') - elb_conn.create_load_balancer( - 'test_lb', zones=[], listeners=[(80, 8080, 'http')]) + elb_conn = boto.ec2.elb.connect_to_region("us-east-1") + elb_conn.create_load_balancer("test_lb", zones=[], listeners=[(80, 8080, "http")]) - conn = boto.ec2.autoscale.connect_to_region('us-east-1') + conn = boto.ec2.autoscale.connect_to_region("us-east-1") config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', - availability_zones=['us-east-1a', 'us-east-1b'], + name="tester_group", + availability_zones=["us-east-1a", "us-east-1b"], default_cooldown=60, desired_capacity=2, health_check_period=100, @@ -45,45 +53,44 @@ def test_create_autoscaling_group(): load_balancers=["test_lb"], placement_group="test_placement", vpc_zone_identifier="{subnet1},{subnet2}".format( - subnet1=mocked_networking['subnet1'], - subnet2=mocked_networking['subnet2'], + subnet1=mocked_networking["subnet1"], subnet2=mocked_networking["subnet2"] ), termination_policies=["OldestInstance", "NewestInstance"], - tags=[Tag( - resource_id='tester_group', - key='test_key', - value='test_value', - propagate_at_launch=True - ) + tags=[ + Tag( + resource_id="tester_group", + key="test_key", + value="test_value", + propagate_at_launch=True, + ) ], ) conn.create_auto_scaling_group(group) group = conn.get_all_groups()[0] - group.name.should.equal('tester_group') - set(group.availability_zones).should.equal( - set(['us-east-1a', 'us-east-1b'])) + group.name.should.equal("tester_group") + set(group.availability_zones).should.equal(set(["us-east-1a", "us-east-1b"])) group.desired_capacity.should.equal(2) group.max_size.should.equal(2) group.min_size.should.equal(2) group.instances.should.have.length_of(2) - group.vpc_zone_identifier.should.equal("{subnet1},{subnet2}".format( - subnet1=mocked_networking['subnet1'], - subnet2=mocked_networking['subnet2'], - )) - group.launch_config_name.should.equal('tester') + group.vpc_zone_identifier.should.equal( + "{subnet1},{subnet2}".format( + subnet1=mocked_networking["subnet1"], subnet2=mocked_networking["subnet2"] + ) + ) + group.launch_config_name.should.equal("tester") group.default_cooldown.should.equal(60) group.health_check_period.should.equal(100) group.health_check_type.should.equal("EC2") list(group.load_balancers).should.equal(["test_lb"]) group.placement_group.should.equal("test_placement") - list(group.termination_policies).should.equal( - ["OldestInstance", "NewestInstance"]) + list(group.termination_policies).should.equal(["OldestInstance", "NewestInstance"]) len(list(group.tags)).should.equal(1) tag = list(group.tags)[0] - tag.resource_id.should.equal('tester_group') - tag.key.should.equal('test_key') - tag.value.should.equal('test_value') + tag.resource_id.should.equal("tester_group") + tag.key.should.equal("test_key") + tag.value.should.equal("test_value") tag.propagate_at_launch.should.equal(True) @@ -95,31 +102,29 @@ def test_create_autoscaling_groups_defaults(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) group = conn.get_all_groups()[0] - group.name.should.equal('tester_group') + group.name.should.equal("tester_group") group.max_size.should.equal(2) group.min_size.should.equal(2) - group.launch_config_name.should.equal('tester') + group.launch_config_name.should.equal("tester") # Defaults - list(group.availability_zones).should.equal(['us-east-1a']) # subnet1 + list(group.availability_zones).should.equal(["us-east-1a"]) # subnet1 group.desired_capacity.should.equal(2) - group.vpc_zone_identifier.should.equal(mocked_networking['subnet1']) + group.vpc_zone_identifier.should.equal(mocked_networking["subnet1"]) group.default_cooldown.should.equal(300) group.health_check_period.should.equal(300) group.health_check_type.should.equal("EC2") @@ -132,55 +137,61 @@ def test_create_autoscaling_groups_defaults(): @mock_autoscaling def test_list_many_autoscaling_groups(): mocked_networking = setup_networking() - conn = boto3.client('autoscaling', region_name='us-east-1') - conn.create_launch_configuration(LaunchConfigurationName='TestLC') + conn = boto3.client("autoscaling", region_name="us-east-1") + conn.create_launch_configuration(LaunchConfigurationName="TestLC") for i in range(51): - conn.create_auto_scaling_group(AutoScalingGroupName='TestGroup%d' % i, - MinSize=1, - MaxSize=2, - LaunchConfigurationName='TestLC', - VPCZoneIdentifier=mocked_networking['subnet1']) + conn.create_auto_scaling_group( + AutoScalingGroupName="TestGroup%d" % i, + MinSize=1, + MaxSize=2, + LaunchConfigurationName="TestLC", + VPCZoneIdentifier=mocked_networking["subnet1"], + ) response = conn.describe_auto_scaling_groups() groups = response["AutoScalingGroups"] marker = response["NextToken"] groups.should.have.length_of(50) - marker.should.equal(groups[-1]['AutoScalingGroupName']) + marker.should.equal(groups[-1]["AutoScalingGroupName"]) response2 = conn.describe_auto_scaling_groups(NextToken=marker) groups.extend(response2["AutoScalingGroups"]) groups.should.have.length_of(51) - assert 'NextToken' not in response2.keys() + assert "NextToken" not in response2.keys() @mock_autoscaling @mock_ec2 def test_list_many_autoscaling_groups(): mocked_networking = setup_networking() - conn = boto3.client('autoscaling', region_name='us-east-1') - conn.create_launch_configuration(LaunchConfigurationName='TestLC') + conn = boto3.client("autoscaling", region_name="us-east-1") + conn.create_launch_configuration(LaunchConfigurationName="TestLC") - conn.create_auto_scaling_group(AutoScalingGroupName='TestGroup1', - MinSize=1, - MaxSize=2, - LaunchConfigurationName='TestLC', - Tags=[{ - "ResourceId": 'TestGroup1', - "ResourceType": "auto-scaling-group", - "PropagateAtLaunch": True, - "Key": 'TestTagKey1', - "Value": 'TestTagValue1' - }], - VPCZoneIdentifier=mocked_networking['subnet1']) + conn.create_auto_scaling_group( + AutoScalingGroupName="TestGroup1", + MinSize=1, + MaxSize=2, + LaunchConfigurationName="TestLC", + Tags=[ + { + "ResourceId": "TestGroup1", + "ResourceType": "auto-scaling-group", + "PropagateAtLaunch": True, + "Key": "TestTagKey1", + "Value": "TestTagValue1", + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], + ) - ec2 = boto3.client('ec2', region_name='us-east-1') + ec2 = boto3.client("ec2", region_name="us-east-1") instances = ec2.describe_instances() - tags = instances['Reservations'][0]['Instances'][0]['Tags'] - tags.should.contain({u'Value': 'TestTagValue1', u'Key': 'TestTagKey1'}) - tags.should.contain({u'Value': 'TestGroup1', u'Key': 'aws:autoscaling:groupName'}) + tags = instances["Reservations"][0]["Instances"][0]["Tags"] + tags.should.contain({"Value": "TestTagValue1", "Key": "TestTagKey1"}) + tags.should.contain({"Value": "TestGroup1", "Key": "aws:autoscaling:groupName"}) @mock_autoscaling_deprecated @@ -188,27 +199,26 @@ def test_autoscaling_group_describe_filter(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) - group.name = 'tester_group2' + group.name = "tester_group2" conn.create_auto_scaling_group(group) - group.name = 'tester_group3' + group.name = "tester_group3" conn.create_auto_scaling_group(group) - conn.get_all_groups( - names=['tester_group', 'tester_group2']).should.have.length_of(2) + conn.get_all_groups(names=["tester_group", "tester_group2"]).should.have.length_of( + 2 + ) conn.get_all_groups().should.have.length_of(3) @@ -217,33 +227,31 @@ def test_autoscaling_update(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", desired_capacity=2, max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) group = conn.get_all_groups()[0] - group.availability_zones.should.equal(['us-east-1a']) - group.vpc_zone_identifier.should.equal(mocked_networking['subnet1']) + group.availability_zones.should.equal(["us-east-1a"]) + group.vpc_zone_identifier.should.equal(mocked_networking["subnet1"]) - group.availability_zones = ['us-east-1b'] - group.vpc_zone_identifier = mocked_networking['subnet2'] + group.availability_zones = ["us-east-1b"] + group.vpc_zone_identifier = mocked_networking["subnet2"] group.update() group = conn.get_all_groups()[0] - group.availability_zones.should.equal(['us-east-1b']) - group.vpc_zone_identifier.should.equal(mocked_networking['subnet2']) + group.availability_zones.should.equal(["us-east-1b"]) + group.vpc_zone_identifier.should.equal(mocked_networking["subnet2"]) @mock_autoscaling_deprecated @@ -251,40 +259,45 @@ def test_autoscaling_tags_update(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', - availability_zones=['us-east-1a'], + name="tester_group", + availability_zones=["us-east-1a"], desired_capacity=2, max_size=2, min_size=2, launch_config=config, - tags=[Tag( - resource_id='tester_group', - key='test_key', - value='test_value', - propagate_at_launch=True - )], - vpc_zone_identifier=mocked_networking['subnet1'], + tags=[ + Tag( + resource_id="tester_group", + key="test_key", + value="test_value", + propagate_at_launch=True, + ) + ], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) - conn.create_or_update_tags(tags=[Tag( - resource_id='tester_group', - key='test_key', - value='new_test_value', - propagate_at_launch=True - ), Tag( - resource_id='tester_group', - key='test_key2', - value='test_value2', - propagate_at_launch=True - )]) + conn.create_or_update_tags( + tags=[ + Tag( + resource_id="tester_group", + key="test_key", + value="new_test_value", + propagate_at_launch=True, + ), + Tag( + resource_id="tester_group", + key="test_key2", + value="test_value2", + propagate_at_launch=True, + ), + ] + ) group = conn.get_all_groups()[0] group.tags.should.have.length_of(2) @@ -294,24 +307,22 @@ def test_autoscaling_group_delete(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) conn.get_all_groups().should.have.length_of(1) - conn.delete_auto_scaling_group('tester_group') + conn.delete_auto_scaling_group("tester_group") conn.get_all_groups().should.have.length_of(0) @@ -319,30 +330,28 @@ def test_autoscaling_group_delete(): @mock_autoscaling_deprecated def test_autoscaling_group_describe_instances(): mocked_networking = setup_networking_deprecated() - conn = boto.ec2.autoscale.connect_to_region('us-east-1') + conn = boto.ec2.autoscale.connect_to_region("us-east-1") config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) instances = list(conn.get_all_autoscaling_instances()) instances.should.have.length_of(2) - instances[0].launch_config_name.should.equal('tester') - instances[0].health_status.should.equal('Healthy') + instances[0].launch_config_name.should.equal("tester") + instances[0].health_status.should.equal("Healthy") autoscale_instance_ids = [instance.instance_id for instance in instances] - ec2_conn = boto.ec2.connect_to_region('us-east-1') + ec2_conn = boto.ec2.connect_to_region("us-east-1") reservations = ec2_conn.get_all_instances() instances = reservations[0].instances instances.should.have.length_of(2) @@ -357,20 +366,18 @@ def test_set_desired_capacity_up(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', - availability_zones=['us-east-1a'], + name="tester_group", + availability_zones=["us-east-1a"], desired_capacity=2, max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) @@ -393,20 +400,18 @@ def test_set_desired_capacity_down(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', - availability_zones=['us-east-1a'], + name="tester_group", + availability_zones=["us-east-1a"], desired_capacity=2, max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) @@ -429,20 +434,18 @@ def test_set_desired_capacity_the_same(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', - availability_zones=['us-east-1a'], + name="tester_group", + availability_zones=["us-east-1a"], desired_capacity=2, max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) @@ -464,26 +467,24 @@ def test_set_desired_capacity_the_same(): def test_autoscaling_group_with_elb(): mocked_networking = setup_networking_deprecated() elb_conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = elb_conn.create_load_balancer('my-lb', zones, ports) - instances_health = elb_conn.describe_instance_health('my-lb') + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = elb_conn.create_load_balancer("my-lb", zones, ports) + instances_health = elb_conn.describe_instance_health("my-lb") instances_health.should.be.empty conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t2.medium', + name="tester", image_id="ami-abcd1234", instance_type="t2.medium" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, load_balancers=["my-lb"], - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) group = conn.get_all_groups()[0] @@ -491,8 +492,7 @@ def test_autoscaling_group_with_elb(): group.desired_capacity.should.equal(2) elb.instances.should.have.length_of(2) - autoscale_instance_ids = set( - instance.instance_id for instance in group.instances) + autoscale_instance_ids = set(instance.instance_id for instance in group.instances) elb_instace_ids = set(instance.id for instance in elb.instances) autoscale_instance_ids.should.equal(elb_instace_ids) @@ -502,20 +502,19 @@ def test_autoscaling_group_with_elb(): group.desired_capacity.should.equal(3) elb.instances.should.have.length_of(3) - autoscale_instance_ids = set( - instance.instance_id for instance in group.instances) + autoscale_instance_ids = set(instance.instance_id for instance in group.instances) elb_instace_ids = set(instance.id for instance in elb.instances) autoscale_instance_ids.should.equal(elb_instace_ids) - conn.delete_auto_scaling_group('tester_group') + conn.delete_auto_scaling_group("tester_group") conn.get_all_groups().should.have.length_of(0) elb = elb_conn.get_all_load_balancers()[0] elb.instances.should.have.length_of(0) -''' +""" Boto3 -''' +""" @mock_autoscaling @@ -524,77 +523,74 @@ def test_describe_load_balancers(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - elb_client = boto3.client('elb', region_name='us-east-1') + elb_client = boto3.client("elb", region_name="us-east-1") elb_client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', - LoadBalancerNames=['my-lb'], + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", + LoadBalancerNames=["my-lb"], MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, - Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'test_value', - "PropagateAtLaunch": True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "test_value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - response = client.describe_load_balancers(AutoScalingGroupName='test_asg') - assert response['ResponseMetadata']['RequestId'] - list(response['LoadBalancers']).should.have.length_of(1) - response['LoadBalancers'][0]['LoadBalancerName'].should.equal('my-lb') + response = client.describe_load_balancers(AutoScalingGroupName="test_asg") + assert response["ResponseMetadata"]["RequestId"] + list(response["LoadBalancers"]).should.have.length_of(1) + response["LoadBalancers"][0]["LoadBalancerName"].should.equal("my-lb") + @mock_autoscaling @mock_elb def test_create_elb_and_autoscaling_group_no_relationship(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - ELB_NAME = 'my-elb' + ELB_NAME = "my-elb" - elb_client = boto3.client('elb', region_name='us-east-1') + elb_client = boto3.client("elb", region_name="us-east-1") elb_client.create_load_balancer( LoadBalancerName=ELB_NAME, - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) # autoscaling group and elb should have no relationship - response = client.describe_load_balancers( - AutoScalingGroupName='test_asg' - ) - list(response['LoadBalancers']).should.have.length_of(0) - response = elb_client.describe_load_balancers( - LoadBalancerNames=[ELB_NAME] - ) - list(response['LoadBalancerDescriptions'][0]['Instances']).should.have.length_of(0) + response = client.describe_load_balancers(AutoScalingGroupName="test_asg") + list(response["LoadBalancers"]).should.have.length_of(0) + response = elb_client.describe_load_balancers(LoadBalancerNames=[ELB_NAME]) + list(response["LoadBalancerDescriptions"][0]["Instances"]).should.have.length_of(0) @mock_autoscaling @@ -603,47 +599,46 @@ def test_attach_load_balancer(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - elb_client = boto3.client('elb', region_name='us-east-1') + elb_client = boto3.client("elb", region_name="us-east-1") elb_client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, - Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'test_value', - "PropagateAtLaunch": True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "test_value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) response = client.attach_load_balancers( - AutoScalingGroupName='test_asg', - LoadBalancerNames=['my-lb']) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - - response = elb_client.describe_load_balancers( - LoadBalancerNames=['my-lb'] + AutoScalingGroupName="test_asg", LoadBalancerNames=["my-lb"] ) - list(response['LoadBalancerDescriptions'][0]['Instances']).should.have.length_of(INSTANCE_COUNT) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=["test_asg"] + response = elb_client.describe_load_balancers(LoadBalancerNames=["my-lb"]) + list(response["LoadBalancerDescriptions"][0]["Instances"]).should.have.length_of( + INSTANCE_COUNT ) - list(response['AutoScalingGroups'][0]['LoadBalancerNames']).should.have.length_of(1) + + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + list(response["AutoScalingGroups"][0]["LoadBalancerNames"]).should.have.length_of(1) @mock_autoscaling @@ -652,740 +647,736 @@ def test_detach_load_balancer(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - elb_client = boto3.client('elb', region_name='us-east-1') + elb_client = boto3.client("elb", region_name="us-east-1") elb_client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', - LoadBalancerNames=['my-lb'], + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", + LoadBalancerNames=["my-lb"], MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, - Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'test_value', - "PropagateAtLaunch": True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "test_value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) response = client.detach_load_balancers( - AutoScalingGroupName='test_asg', - LoadBalancerNames=['my-lb']) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - - response = elb_client.describe_load_balancers( - LoadBalancerNames=['my-lb'] + AutoScalingGroupName="test_asg", LoadBalancerNames=["my-lb"] ) - list(response['LoadBalancerDescriptions'][0]['Instances']).should.have.length_of(0) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) - response = client.describe_load_balancers(AutoScalingGroupName='test_asg') - list(response['LoadBalancers']).should.have.length_of(0) + response = elb_client.describe_load_balancers(LoadBalancerNames=["my-lb"]) + list(response["LoadBalancerDescriptions"][0]["Instances"]).should.have.length_of(0) + + response = client.describe_load_balancers(AutoScalingGroupName="test_asg") + list(response["LoadBalancers"]).should.have.length_of(0) @mock_autoscaling def test_create_autoscaling_group_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) response = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, Tags=[ - {'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'propogated-tag-key', - 'Value': 'propogate-tag-value', - 'PropagateAtLaunch': True - }, - {'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'not-propogated-tag-key', - 'Value': 'not-propogate-tag-value', - 'PropagateAtLaunch': False - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "propogated-tag-key", + "Value": "propogate-tag-value", + "PropagateAtLaunch": True, + }, + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "not-propogated-tag-key", + "Value": "not-propogate-tag-value", + "PropagateAtLaunch": False, + }, + ], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=False, ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) @mock_autoscaling def test_create_autoscaling_group_from_instance(): - autoscaling_group_name = 'test_asg' - image_id = 'ami-0cc293023f983ed53' - instance_type = 't2.micro' + autoscaling_group_name = "test_asg" + image_id = "ami-0cc293023f983ed53" + instance_type = "t2.micro" - mocked_instance_with_networking = setup_instance_with_networking(image_id, instance_type) - client = boto3.client('autoscaling', region_name='us-east-1') + mocked_instance_with_networking = setup_instance_with_networking( + image_id, instance_type + ) + client = boto3.client("autoscaling", region_name="us-east-1") response = client.create_auto_scaling_group( AutoScalingGroupName=autoscaling_group_name, - InstanceId=mocked_instance_with_networking['instance'], + InstanceId=mocked_instance_with_networking["instance"], MinSize=1, MaxSize=3, DesiredCapacity=2, Tags=[ - {'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'propogated-tag-key', - 'Value': 'propogate-tag-value', - 'PropagateAtLaunch': True - }, - {'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'not-propogated-tag-key', - 'Value': 'not-propogate-tag-value', - 'PropagateAtLaunch': False - }], - VPCZoneIdentifier=mocked_instance_with_networking['subnet1'], + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "propogated-tag-key", + "Value": "propogate-tag-value", + "PropagateAtLaunch": True, + }, + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "not-propogated-tag-key", + "Value": "not-propogate-tag-value", + "PropagateAtLaunch": False, + }, + ], + VPCZoneIdentifier=mocked_instance_with_networking["subnet1"], NewInstancesProtectedFromScaleIn=False, ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) describe_launch_configurations_response = client.describe_launch_configurations() - describe_launch_configurations_response['LaunchConfigurations'].should.have.length_of(1) - launch_configuration_from_instance = describe_launch_configurations_response['LaunchConfigurations'][0] - launch_configuration_from_instance['LaunchConfigurationName'].should.equal('test_asg') - launch_configuration_from_instance['ImageId'].should.equal(image_id) - launch_configuration_from_instance['InstanceType'].should.equal(instance_type) + describe_launch_configurations_response[ + "LaunchConfigurations" + ].should.have.length_of(1) + launch_configuration_from_instance = describe_launch_configurations_response[ + "LaunchConfigurations" + ][0] + launch_configuration_from_instance["LaunchConfigurationName"].should.equal( + "test_asg" + ) + launch_configuration_from_instance["ImageId"].should.equal(image_id) + launch_configuration_from_instance["InstanceType"].should.equal(instance_type) @mock_autoscaling def test_create_autoscaling_group_from_invalid_instance_id(): - invalid_instance_id = 'invalid_instance' + invalid_instance_id = "invalid_instance" mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") with assert_raises(ClientError) as ex: client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", InstanceId=invalid_instance_id, MinSize=9, MaxSize=15, DesiredCapacity=12, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=False, ) - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Code'].should.equal('ValidationError') - ex.exception.response['Error']['Message'].should.equal('Instance [{0}] is invalid.'.format(invalid_instance_id)) + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Code"].should.equal("ValidationError") + ex.exception.response["Error"]["Message"].should.equal( + "Instance [{0}] is invalid.".format(invalid_instance_id) + ) @mock_autoscaling def test_describe_autoscaling_groups_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=["test_asg"] - ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - group = response['AutoScalingGroups'][0] - group['AutoScalingGroupName'].should.equal('test_asg') - group['AvailabilityZones'].should.equal(['us-east-1a']) - group['VPCZoneIdentifier'].should.equal(mocked_networking['subnet1']) - group['NewInstancesProtectedFromScaleIn'].should.equal(True) - for instance in group['Instances']: - instance['AvailabilityZone'].should.equal('us-east-1a') - instance['ProtectedFromScaleIn'].should.equal(True) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + group = response["AutoScalingGroups"][0] + group["AutoScalingGroupName"].should.equal("test_asg") + group["AvailabilityZones"].should.equal(["us-east-1a"]) + group["VPCZoneIdentifier"].should.equal(mocked_networking["subnet1"]) + group["NewInstancesProtectedFromScaleIn"].should.equal(True) + for instance in group["Instances"]: + instance["AvailabilityZone"].should.equal("us-east-1a") + instance["ProtectedFromScaleIn"].should.equal(True) @mock_autoscaling def test_describe_autoscaling_instances_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=["test_asg"] - ) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) instance_ids = [ - instance['InstanceId'] - for instance in response['AutoScalingGroups'][0]['Instances'] + instance["InstanceId"] + for instance in response["AutoScalingGroups"][0]["Instances"] ] response = client.describe_auto_scaling_instances(InstanceIds=instance_ids) - for instance in response['AutoScalingInstances']: - instance['AutoScalingGroupName'].should.equal('test_asg') - instance['AvailabilityZone'].should.equal('us-east-1a') - instance['ProtectedFromScaleIn'].should.equal(True) + for instance in response["AutoScalingInstances"]: + instance["AutoScalingGroupName"].should.equal("test_asg") + instance["AvailabilityZone"].should.equal("us-east-1a") + instance["ProtectedFromScaleIn"].should.equal(True) @mock_autoscaling def test_update_autoscaling_group_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) _ = client.update_auto_scaling_group( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", MinSize=1, VPCZoneIdentifier="{subnet1},{subnet2}".format( - subnet1=mocked_networking['subnet1'], - subnet2=mocked_networking['subnet2'], + subnet1=mocked_networking["subnet1"], subnet2=mocked_networking["subnet2"] ), NewInstancesProtectedFromScaleIn=False, ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=["test_asg"] - ) - group = response['AutoScalingGroups'][0] - group['MinSize'].should.equal(1) - set(group['AvailabilityZones']).should.equal({'us-east-1a', 'us-east-1b'}) - group['NewInstancesProtectedFromScaleIn'].should.equal(False) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + group = response["AutoScalingGroups"][0] + group["MinSize"].should.equal(1) + set(group["AvailabilityZones"]).should.equal({"us-east-1a", "us-east-1b"}) + group["NewInstancesProtectedFromScaleIn"].should.equal(False) @mock_autoscaling def test_update_autoscaling_group_min_size_desired_capacity_change(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=2, MaxSize=20, DesiredCapacity=3, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - client.update_auto_scaling_group( - AutoScalingGroupName='test_asg', - MinSize=5, - ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg']) - group = response['AutoScalingGroups'][0] - group['DesiredCapacity'].should.equal(5) - group['MinSize'].should.equal(5) - group['Instances'].should.have.length_of(5) + client.update_auto_scaling_group(AutoScalingGroupName="test_asg", MinSize=5) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + group = response["AutoScalingGroups"][0] + group["DesiredCapacity"].should.equal(5) + group["MinSize"].should.equal(5) + group["Instances"].should.have.length_of(5) @mock_autoscaling def test_update_autoscaling_group_max_size_desired_capacity_change(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=2, MaxSize=20, DesiredCapacity=10, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - client.update_auto_scaling_group( - AutoScalingGroupName='test_asg', - MaxSize=5, - ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg']) - group = response['AutoScalingGroups'][0] - group['DesiredCapacity'].should.equal(5) - group['MaxSize'].should.equal(5) - group['Instances'].should.have.length_of(5) + client.update_auto_scaling_group(AutoScalingGroupName="test_asg", MaxSize=5) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + group = response["AutoScalingGroups"][0] + group["DesiredCapacity"].should.equal(5) + group["MaxSize"].should.equal(5) + group["Instances"].should.have.length_of(5) @mock_autoscaling def test_autoscaling_taqs_update_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'test_value', - "PropagateAtLaunch": True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "test_value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - client.create_or_update_tags(Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'updated_test_value', - "PropagateAtLaunch": True - }, { - "ResourceId": 'test_asg', - "Key": 'test_key2', - "Value": 'test_value2', - "PropagateAtLaunch": False - }]) - - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=["test_asg"] + client.create_or_update_tags( + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "updated_test_value", + "PropagateAtLaunch": True, + }, + { + "ResourceId": "test_asg", + "Key": "test_key2", + "Value": "test_value2", + "PropagateAtLaunch": False, + }, + ] ) - response['AutoScalingGroups'][0]['Tags'].should.have.length_of(2) + + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + response["AutoScalingGroups"][0]["Tags"].should.have.length_of(2) @mock_autoscaling def test_autoscaling_describe_policies_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - Tags=[{ - "ResourceId": 'test_asg', - "Key": 'test_key', - "Value": 'test_value', - "PropagateAtLaunch": True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "Key": "test_key", + "Value": "test_value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) client.put_scaling_policy( - AutoScalingGroupName='test_asg', - PolicyName='test_policy_down', - PolicyType='SimpleScaling', - AdjustmentType='PercentChangeInCapacity', + AutoScalingGroupName="test_asg", + PolicyName="test_policy_down", + PolicyType="SimpleScaling", + AdjustmentType="PercentChangeInCapacity", ScalingAdjustment=-10, Cooldown=60, - MinAdjustmentMagnitude=1) + MinAdjustmentMagnitude=1, + ) client.put_scaling_policy( - AutoScalingGroupName='test_asg', - PolicyName='test_policy_up', - PolicyType='SimpleScaling', - AdjustmentType='PercentChangeInCapacity', + AutoScalingGroupName="test_asg", + PolicyName="test_policy_up", + PolicyType="SimpleScaling", + AdjustmentType="PercentChangeInCapacity", ScalingAdjustment=10, Cooldown=60, - MinAdjustmentMagnitude=1) + MinAdjustmentMagnitude=1, + ) response = client.describe_policies() - response['ScalingPolicies'].should.have.length_of(2) + response["ScalingPolicies"].should.have.length_of(2) - response = client.describe_policies(AutoScalingGroupName='test_asg') - response['ScalingPolicies'].should.have.length_of(2) + response = client.describe_policies(AutoScalingGroupName="test_asg") + response["ScalingPolicies"].should.have.length_of(2) - response = client.describe_policies(PolicyTypes=['StepScaling']) - response['ScalingPolicies'].should.have.length_of(0) + response = client.describe_policies(PolicyTypes=["StepScaling"]) + response["ScalingPolicies"].should.have.length_of(0) response = client.describe_policies( - AutoScalingGroupName='test_asg', - PolicyNames=['test_policy_down'], - PolicyTypes=['SimpleScaling'] + AutoScalingGroupName="test_asg", + PolicyNames=["test_policy_down"], + PolicyTypes=["SimpleScaling"], ) - response['ScalingPolicies'].should.have.length_of(1) - response['ScalingPolicies'][0][ - 'PolicyName'].should.equal('test_policy_down') + response["ScalingPolicies"].should.have.length_of(1) + response["ScalingPolicies"][0]["PolicyName"].should.equal("test_policy_down") + @mock_autoscaling @mock_ec2 def test_detach_one_instance_decrement(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=2, DesiredCapacity=2, - Tags=[{ - 'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'propogated-tag-key', - 'Value': 'propogate-tag-value', - 'PropagateAtLaunch': True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "propogated-tag-key", + "Value": "propogate-tag-value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) - instance_to_detach = response['AutoScalingGroups'][0]['Instances'][0]['InstanceId'] - instance_to_keep = response['AutoScalingGroups'][0]['Instances'][1]['InstanceId'] + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + instance_to_detach = response["AutoScalingGroups"][0]["Instances"][0]["InstanceId"] + instance_to_keep = response["AutoScalingGroups"][0]["Instances"][1]["InstanceId"] - ec2_client = boto3.client('ec2', region_name='us-east-1') + ec2_client = boto3.client("ec2", region_name="us-east-1") response = ec2_client.describe_instances(InstanceIds=[instance_to_detach]) response = client.detach_instances( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", InstanceIds=[instance_to_detach], - ShouldDecrementDesiredCapacity=True + ShouldDecrementDesiredCapacity=True, ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) - response['AutoScalingGroups'][0]['Instances'].should.have.length_of(1) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + response["AutoScalingGroups"][0]["Instances"].should.have.length_of(1) # test to ensure tag has been removed response = ec2_client.describe_instances(InstanceIds=[instance_to_detach]) - tags = response['Reservations'][0]['Instances'][0]['Tags'] + tags = response["Reservations"][0]["Instances"][0]["Tags"] tags.should.have.length_of(1) # test to ensure tag is present on other instance response = ec2_client.describe_instances(InstanceIds=[instance_to_keep]) - tags = response['Reservations'][0]['Instances'][0]['Tags'] + tags = response["Reservations"][0]["Instances"][0]["Tags"] tags.should.have.length_of(2) + @mock_autoscaling @mock_ec2 def test_detach_one_instance(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=2, DesiredCapacity=2, - Tags=[{ - 'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'propogated-tag-key', - 'Value': 'propogate-tag-value', - 'PropagateAtLaunch': True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "propogated-tag-key", + "Value": "propogate-tag-value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) - instance_to_detach = response['AutoScalingGroups'][0]['Instances'][0]['InstanceId'] - instance_to_keep = response['AutoScalingGroups'][0]['Instances'][1]['InstanceId'] + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + instance_to_detach = response["AutoScalingGroups"][0]["Instances"][0]["InstanceId"] + instance_to_keep = response["AutoScalingGroups"][0]["Instances"][1]["InstanceId"] - ec2_client = boto3.client('ec2', region_name='us-east-1') + ec2_client = boto3.client("ec2", region_name="us-east-1") response = ec2_client.describe_instances(InstanceIds=[instance_to_detach]) response = client.detach_instances( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", InstanceIds=[instance_to_detach], - ShouldDecrementDesiredCapacity=False + ShouldDecrementDesiredCapacity=False, ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) # test to ensure instance was replaced - response['AutoScalingGroups'][0]['Instances'].should.have.length_of(2) + response["AutoScalingGroups"][0]["Instances"].should.have.length_of(2) response = ec2_client.describe_instances(InstanceIds=[instance_to_detach]) - tags = response['Reservations'][0]['Instances'][0]['Tags'] + tags = response["Reservations"][0]["Instances"][0]["Tags"] tags.should.have.length_of(1) response = ec2_client.describe_instances(InstanceIds=[instance_to_keep]) - tags = response['Reservations'][0]['Instances'][0]['Tags'] + tags = response["Reservations"][0]["Instances"][0]["Tags"] tags.should.have.length_of(2) + @mock_autoscaling @mock_ec2 def test_attach_one_instance(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=4, DesiredCapacity=2, - Tags=[{ - 'ResourceId': 'test_asg', - 'ResourceType': 'auto-scaling-group', - 'Key': 'propogated-tag-key', - 'Value': 'propogate-tag-value', - 'PropagateAtLaunch': True - }], - VPCZoneIdentifier=mocked_networking['subnet1'], + Tags=[ + { + "ResourceId": "test_asg", + "ResourceType": "auto-scaling-group", + "Key": "propogated-tag-key", + "Value": "propogate-tag-value", + "PropagateAtLaunch": True, + } + ], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) - ec2 = boto3.resource('ec2', 'us-east-1') - instances_to_add = [x.id for x in ec2.create_instances(ImageId='', MinCount=1, MaxCount=1)] + ec2 = boto3.resource("ec2", "us-east-1") + instances_to_add = [ + x.id for x in ec2.create_instances(ImageId="", MinCount=1, MaxCount=1) + ] response = client.attach_instances( - AutoScalingGroupName='test_asg', - InstanceIds=instances_to_add + AutoScalingGroupName="test_asg", InstanceIds=instances_to_add ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) - instances = response['AutoScalingGroups'][0]['Instances'] + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + instances = response["AutoScalingGroups"][0]["Instances"] instances.should.have.length_of(3) for instance in instances: - instance['ProtectedFromScaleIn'].should.equal(True) + instance["ProtectedFromScaleIn"].should.equal(True) @mock_autoscaling @mock_ec2 def test_describe_instance_health(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=2, MaxSize=4, DesiredCapacity=2, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + + instance1 = response["AutoScalingGroups"][0]["Instances"][0] + instance1["HealthStatus"].should.equal("Healthy") - instance1 = response['AutoScalingGroups'][0]['Instances'][0] - instance1['HealthStatus'].should.equal('Healthy') @mock_autoscaling @mock_ec2 def test_set_instance_health(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=2, MaxSize=4, DesiredCapacity=2, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + + instance1 = response["AutoScalingGroups"][0]["Instances"][0] + instance1["HealthStatus"].should.equal("Healthy") + + client.set_instance_health( + InstanceId=instance1["InstanceId"], HealthStatus="Unhealthy" ) - instance1 = response['AutoScalingGroups'][0]['Instances'][0] - instance1['HealthStatus'].should.equal('Healthy') + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) - client.set_instance_health(InstanceId=instance1['InstanceId'], HealthStatus='Unhealthy') + instance1 = response["AutoScalingGroups"][0]["Instances"][0] + instance1["HealthStatus"].should.equal("Unhealthy") - response = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test_asg'] - ) - - instance1 = response['AutoScalingGroups'][0]['Instances'][0] - instance1['HealthStatus'].should.equal('Unhealthy') @mock_autoscaling def test_suspend_processes(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') - client.create_launch_configuration( - LaunchConfigurationName='lc', - ) + client = boto3.client("autoscaling", region_name="us-east-1") + client.create_launch_configuration(LaunchConfigurationName="lc") client.create_auto_scaling_group( - LaunchConfigurationName='lc', - AutoScalingGroupName='test-asg', + LaunchConfigurationName="lc", + AutoScalingGroupName="test-asg", MinSize=1, MaxSize=1, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], ) # When we suspend the 'Launch' process on the ASG client client.suspend_processes( - AutoScalingGroupName='test-asg', - ScalingProcesses=['Launch'] + AutoScalingGroupName="test-asg", ScalingProcesses=["Launch"] ) - res = client.describe_auto_scaling_groups( - AutoScalingGroupNames=['test-asg'] - ) + res = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test-asg"]) # The 'Launch' process should, in fact, be suspended launch_suspended = False - for proc in res['AutoScalingGroups'][0]['SuspendedProcesses']: - if proc.get('ProcessName') == 'Launch': + for proc in res["AutoScalingGroups"][0]["SuspendedProcesses"]: + if proc.get("ProcessName") == "Launch": launch_suspended = True assert launch_suspended is True + @mock_autoscaling def test_set_instance_protection(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=False, ) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) instance_ids = [ - instance['InstanceId'] - for instance in response['AutoScalingGroups'][0]['Instances'] + instance["InstanceId"] + for instance in response["AutoScalingGroups"][0]["Instances"] ] protected = instance_ids[:3] _ = client.set_instance_protection( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", InstanceIds=protected, ProtectedFromScaleIn=True, ) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) - for instance in response['AutoScalingGroups'][0]['Instances']: - instance['ProtectedFromScaleIn'].should.equal( - instance['InstanceId'] in protected + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + for instance in response["AutoScalingGroups"][0]["Instances"]: + instance["ProtectedFromScaleIn"].should.equal( + instance["InstanceId"] in protected ) @mock_autoscaling def test_set_desired_capacity_up_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) - _ = client.set_desired_capacity( - AutoScalingGroupName='test_asg', - DesiredCapacity=10, - ) + _ = client.set_desired_capacity(AutoScalingGroupName="test_asg", DesiredCapacity=10) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) - instances = response['AutoScalingGroups'][0]['Instances'] + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + instances = response["AutoScalingGroups"][0]["Instances"] instances.should.have.length_of(10) for instance in instances: - instance['ProtectedFromScaleIn'].should.equal(True) + instance["ProtectedFromScaleIn"].should.equal(True) @mock_autoscaling def test_set_desired_capacity_down_boto3(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=20, DesiredCapacity=5, - VPCZoneIdentifier=mocked_networking['subnet1'], + VPCZoneIdentifier=mocked_networking["subnet1"], NewInstancesProtectedFromScaleIn=True, ) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) instance_ids = [ - instance['InstanceId'] - for instance in response['AutoScalingGroups'][0]['Instances'] + instance["InstanceId"] + for instance in response["AutoScalingGroups"][0]["Instances"] ] unprotected, protected = instance_ids[:2], instance_ids[2:] _ = client.set_instance_protection( - AutoScalingGroupName='test_asg', + AutoScalingGroupName="test_asg", InstanceIds=unprotected, ProtectedFromScaleIn=False, ) - _ = client.set_desired_capacity( - AutoScalingGroupName='test_asg', - DesiredCapacity=1, - ) + _ = client.set_desired_capacity(AutoScalingGroupName="test_asg", DesiredCapacity=1) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) - group = response['AutoScalingGroups'][0] - group['DesiredCapacity'].should.equal(1) - instance_ids = {instance['InstanceId'] for instance in group['Instances']} + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) + group = response["AutoScalingGroups"][0] + group["DesiredCapacity"].should.equal(1) + instance_ids = {instance["InstanceId"] for instance in group["Instances"]} set(protected).should.equal(instance_ids) set(unprotected).should_not.be.within(instance_ids) # only unprotected killed @@ -1394,30 +1385,30 @@ def test_set_desired_capacity_down_boto3(): @mock_ec2 def test_terminate_instance_in_autoscaling_group(): mocked_networking = setup_networking() - client = boto3.client('autoscaling', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") _ = client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration' + LaunchConfigurationName="test_launch_configuration" ) _ = client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=1, MaxSize=20, - VPCZoneIdentifier=mocked_networking['subnet1'], - NewInstancesProtectedFromScaleIn=False + VPCZoneIdentifier=mocked_networking["subnet1"], + NewInstancesProtectedFromScaleIn=False, ) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) original_instance_id = next( - instance['InstanceId'] - for instance in response['AutoScalingGroups'][0]['Instances'] + instance["InstanceId"] + for instance in response["AutoScalingGroups"][0]["Instances"] ) - ec2_client = boto3.client('ec2', region_name='us-east-1') + ec2_client = boto3.client("ec2", region_name="us-east-1") ec2_client.terminate_instances(InstanceIds=[original_instance_id]) - response = client.describe_auto_scaling_groups(AutoScalingGroupNames=['test_asg']) + response = client.describe_auto_scaling_groups(AutoScalingGroupNames=["test_asg"]) replaced_instance_id = next( - instance['InstanceId'] - for instance in response['AutoScalingGroups'][0]['Instances'] + instance["InstanceId"] + for instance in response["AutoScalingGroups"][0]["Instances"] ) replaced_instance_id.should_not.equal(original_instance_id) diff --git a/tests/test_autoscaling/test_elbv2.py b/tests/test_autoscaling/test_elbv2.py index a142fd133..a3d3dba9f 100644 --- a/tests/test_autoscaling/test_elbv2.py +++ b/tests/test_autoscaling/test_elbv2.py @@ -2,127 +2,134 @@ from __future__ import unicode_literals import boto3 import sure # noqa -from moto import mock_autoscaling, mock_ec2, mock_elbv2 +from moto import mock_autoscaling, mock_ec2, mock_elbv2 from utils import setup_networking + @mock_elbv2 @mock_autoscaling def test_attach_detach_target_groups(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - client = boto3.client('autoscaling', region_name='us-east-1') - elbv2_client = boto3.client('elbv2', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") response = elbv2_client.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, - VpcId=mocked_networking['vpc'], - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + VpcId=mocked_networking["vpc"], + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group_arn = response['TargetGroups'][0]['TargetGroupArn'] + Matcher={"HttpCode": "200"}, + ) + target_group_arn = response["TargetGroups"][0]["TargetGroupArn"] client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration') + LaunchConfigurationName="test_launch_configuration" + ) # create asg, attach to target group on create client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, TargetGroupARNs=[target_group_arn], - VPCZoneIdentifier=mocked_networking['subnet1']) + VPCZoneIdentifier=mocked_networking["subnet1"], + ) # create asg without attaching to target group client.create_auto_scaling_group( - AutoScalingGroupName='test_asg2', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg2", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, - VPCZoneIdentifier=mocked_networking['subnet2']) + VPCZoneIdentifier=mocked_networking["subnet2"], + ) response = client.describe_load_balancer_target_groups( - AutoScalingGroupName='test_asg') - list(response['LoadBalancerTargetGroups']).should.have.length_of(1) + AutoScalingGroupName="test_asg" + ) + list(response["LoadBalancerTargetGroups"]).should.have.length_of(1) - response = elbv2_client.describe_target_health( - TargetGroupArn=target_group_arn) - list(response['TargetHealthDescriptions']).should.have.length_of(INSTANCE_COUNT) + response = elbv2_client.describe_target_health(TargetGroupArn=target_group_arn) + list(response["TargetHealthDescriptions"]).should.have.length_of(INSTANCE_COUNT) client.attach_load_balancer_target_groups( - AutoScalingGroupName='test_asg2', - TargetGroupARNs=[target_group_arn]) + AutoScalingGroupName="test_asg2", TargetGroupARNs=[target_group_arn] + ) - response = elbv2_client.describe_target_health( - TargetGroupArn=target_group_arn) - list(response['TargetHealthDescriptions']).should.have.length_of(INSTANCE_COUNT * 2) + response = elbv2_client.describe_target_health(TargetGroupArn=target_group_arn) + list(response["TargetHealthDescriptions"]).should.have.length_of(INSTANCE_COUNT * 2) response = client.detach_load_balancer_target_groups( - AutoScalingGroupName='test_asg2', - TargetGroupARNs=[target_group_arn]) - response = elbv2_client.describe_target_health( - TargetGroupArn=target_group_arn) - list(response['TargetHealthDescriptions']).should.have.length_of(INSTANCE_COUNT) + AutoScalingGroupName="test_asg2", TargetGroupARNs=[target_group_arn] + ) + response = elbv2_client.describe_target_health(TargetGroupArn=target_group_arn) + list(response["TargetHealthDescriptions"]).should.have.length_of(INSTANCE_COUNT) + @mock_elbv2 @mock_autoscaling def test_detach_all_target_groups(): mocked_networking = setup_networking() INSTANCE_COUNT = 2 - client = boto3.client('autoscaling', region_name='us-east-1') - elbv2_client = boto3.client('elbv2', region_name='us-east-1') + client = boto3.client("autoscaling", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") response = elbv2_client.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, - VpcId=mocked_networking['vpc'], - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + VpcId=mocked_networking["vpc"], + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group_arn = response['TargetGroups'][0]['TargetGroupArn'] + Matcher={"HttpCode": "200"}, + ) + target_group_arn = response["TargetGroups"][0]["TargetGroupArn"] client.create_launch_configuration( - LaunchConfigurationName='test_launch_configuration') + LaunchConfigurationName="test_launch_configuration" + ) client.create_auto_scaling_group( - AutoScalingGroupName='test_asg', - LaunchConfigurationName='test_launch_configuration', + AutoScalingGroupName="test_asg", + LaunchConfigurationName="test_launch_configuration", MinSize=0, MaxSize=INSTANCE_COUNT, DesiredCapacity=INSTANCE_COUNT, TargetGroupARNs=[target_group_arn], - VPCZoneIdentifier=mocked_networking['subnet1']) + VPCZoneIdentifier=mocked_networking["subnet1"], + ) response = client.describe_load_balancer_target_groups( - AutoScalingGroupName='test_asg') - list(response['LoadBalancerTargetGroups']).should.have.length_of(1) + AutoScalingGroupName="test_asg" + ) + list(response["LoadBalancerTargetGroups"]).should.have.length_of(1) - response = elbv2_client.describe_target_health( - TargetGroupArn=target_group_arn) - list(response['TargetHealthDescriptions']).should.have.length_of(INSTANCE_COUNT) + response = elbv2_client.describe_target_health(TargetGroupArn=target_group_arn) + list(response["TargetHealthDescriptions"]).should.have.length_of(INSTANCE_COUNT) response = client.detach_load_balancer_target_groups( - AutoScalingGroupName='test_asg', - TargetGroupARNs=[target_group_arn]) + AutoScalingGroupName="test_asg", TargetGroupARNs=[target_group_arn] + ) - response = elbv2_client.describe_target_health( - TargetGroupArn=target_group_arn) - list(response['TargetHealthDescriptions']).should.have.length_of(0) + response = elbv2_client.describe_target_health(TargetGroupArn=target_group_arn) + list(response["TargetHealthDescriptions"]).should.have.length_of(0) response = client.describe_load_balancer_target_groups( - AutoScalingGroupName='test_asg') - list(response['LoadBalancerTargetGroups']).should.have.length_of(0) + AutoScalingGroupName="test_asg" + ) + list(response["LoadBalancerTargetGroups"]).should.have.length_of(0) diff --git a/tests/test_autoscaling/test_launch_configurations.py b/tests/test_autoscaling/test_launch_configurations.py index 931fc8a7e..8cd596ee7 100644 --- a/tests/test_autoscaling/test_launch_configurations.py +++ b/tests/test_autoscaling/test_launch_configurations.py @@ -15,29 +15,29 @@ from tests.helpers import requires_boto_gte def test_create_launch_configuration(): conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='t1.micro', - key_name='the_keys', + name="tester", + image_id="ami-abcd1234", + instance_type="t1.micro", + key_name="the_keys", security_groups=["default", "default2"], user_data=b"This is some user_data", instance_monitoring=True, - instance_profile_name='arn:aws:iam::123456789012:instance-profile/testing', + instance_profile_name="arn:aws:iam::123456789012:instance-profile/testing", spot_price=0.1, ) conn.create_launch_configuration(config) launch_config = conn.get_all_launch_configurations()[0] - launch_config.name.should.equal('tester') - launch_config.image_id.should.equal('ami-abcd1234') - launch_config.instance_type.should.equal('t1.micro') - launch_config.key_name.should.equal('the_keys') - set(launch_config.security_groups).should.equal( - set(['default', 'default2'])) + launch_config.name.should.equal("tester") + launch_config.image_id.should.equal("ami-abcd1234") + launch_config.instance_type.should.equal("t1.micro") + launch_config.key_name.should.equal("the_keys") + set(launch_config.security_groups).should.equal(set(["default", "default2"])) launch_config.user_data.should.equal(b"This is some user_data") - launch_config.instance_monitoring.enabled.should.equal('true') + launch_config.instance_monitoring.enabled.should.equal("true") launch_config.instance_profile_name.should.equal( - 'arn:aws:iam::123456789012:instance-profile/testing') + "arn:aws:iam::123456789012:instance-profile/testing" + ) launch_config.spot_price.should.equal(0.1) @@ -47,64 +47,65 @@ def test_create_launch_configuration_with_block_device_mappings(): block_device_mapping = BlockDeviceMapping() ephemeral_drive = BlockDeviceType() - ephemeral_drive.ephemeral_name = 'ephemeral0' - block_device_mapping['/dev/xvdb'] = ephemeral_drive + ephemeral_drive.ephemeral_name = "ephemeral0" + block_device_mapping["/dev/xvdb"] = ephemeral_drive snapshot_drive = BlockDeviceType() snapshot_drive.snapshot_id = "snap-1234abcd" snapshot_drive.volume_type = "standard" - block_device_mapping['/dev/xvdp'] = snapshot_drive + block_device_mapping["/dev/xvdp"] = snapshot_drive ebs_drive = BlockDeviceType() ebs_drive.volume_type = "io1" ebs_drive.size = 100 ebs_drive.iops = 1000 ebs_drive.delete_on_termination = False - block_device_mapping['/dev/xvdh'] = ebs_drive + block_device_mapping["/dev/xvdh"] = ebs_drive conn = boto.connect_autoscale(use_block_device_types=True) config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='m1.small', - key_name='the_keys', + name="tester", + image_id="ami-abcd1234", + instance_type="m1.small", + key_name="the_keys", security_groups=["default", "default2"], user_data=b"This is some user_data", instance_monitoring=True, - instance_profile_name='arn:aws:iam::123456789012:instance-profile/testing', + instance_profile_name="arn:aws:iam::123456789012:instance-profile/testing", spot_price=0.1, - block_device_mappings=[block_device_mapping] + block_device_mappings=[block_device_mapping], ) conn.create_launch_configuration(config) launch_config = conn.get_all_launch_configurations()[0] - launch_config.name.should.equal('tester') - launch_config.image_id.should.equal('ami-abcd1234') - launch_config.instance_type.should.equal('m1.small') - launch_config.key_name.should.equal('the_keys') - set(launch_config.security_groups).should.equal( - set(['default', 'default2'])) + launch_config.name.should.equal("tester") + launch_config.image_id.should.equal("ami-abcd1234") + launch_config.instance_type.should.equal("m1.small") + launch_config.key_name.should.equal("the_keys") + set(launch_config.security_groups).should.equal(set(["default", "default2"])) launch_config.user_data.should.equal(b"This is some user_data") - launch_config.instance_monitoring.enabled.should.equal('true') + launch_config.instance_monitoring.enabled.should.equal("true") launch_config.instance_profile_name.should.equal( - 'arn:aws:iam::123456789012:instance-profile/testing') + "arn:aws:iam::123456789012:instance-profile/testing" + ) launch_config.spot_price.should.equal(0.1) len(launch_config.block_device_mappings).should.equal(3) returned_mapping = launch_config.block_device_mappings set(returned_mapping.keys()).should.equal( - set(['/dev/xvdb', '/dev/xvdp', '/dev/xvdh'])) + set(["/dev/xvdb", "/dev/xvdp", "/dev/xvdh"]) + ) - returned_mapping['/dev/xvdh'].iops.should.equal(1000) - returned_mapping['/dev/xvdh'].size.should.equal(100) - returned_mapping['/dev/xvdh'].volume_type.should.equal("io1") - returned_mapping['/dev/xvdh'].delete_on_termination.should.be.false + returned_mapping["/dev/xvdh"].iops.should.equal(1000) + returned_mapping["/dev/xvdh"].size.should.equal(100) + returned_mapping["/dev/xvdh"].volume_type.should.equal("io1") + returned_mapping["/dev/xvdh"].delete_on_termination.should.be.false - returned_mapping['/dev/xvdp'].snapshot_id.should.equal("snap-1234abcd") - returned_mapping['/dev/xvdp'].volume_type.should.equal("standard") + returned_mapping["/dev/xvdp"].snapshot_id.should.equal("snap-1234abcd") + returned_mapping["/dev/xvdp"].volume_type.should.equal("standard") - returned_mapping['/dev/xvdb'].ephemeral_name.should.equal('ephemeral0') + returned_mapping["/dev/xvdb"].ephemeral_name.should.equal("ephemeral0") @requires_boto_gte("2.12") @@ -112,9 +113,7 @@ def test_create_launch_configuration_with_block_device_mappings(): def test_create_launch_configuration_for_2_12(): conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - ebs_optimized=True, + name="tester", image_id="ami-abcd1234", ebs_optimized=True ) conn.create_launch_configuration(config) @@ -127,9 +126,7 @@ def test_create_launch_configuration_for_2_12(): def test_create_launch_configuration_using_ip_association(): conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - associate_public_ip_address=True, + name="tester", image_id="ami-abcd1234", associate_public_ip_address=True ) conn.create_launch_configuration(config) @@ -141,10 +138,7 @@ def test_create_launch_configuration_using_ip_association(): @mock_autoscaling_deprecated def test_create_launch_configuration_using_ip_association_should_default_to_false(): conn = boto.connect_autoscale() - config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - ) + config = LaunchConfiguration(name="tester", image_id="ami-abcd1234") conn.create_launch_configuration(config) launch_config = conn.get_all_launch_configurations()[0] @@ -157,22 +151,20 @@ def test_create_launch_configuration_defaults(): are assigned for the other attributes """ conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='m1.small', + name="tester", image_id="ami-abcd1234", instance_type="m1.small" ) conn.create_launch_configuration(config) launch_config = conn.get_all_launch_configurations()[0] - launch_config.name.should.equal('tester') - launch_config.image_id.should.equal('ami-abcd1234') - launch_config.instance_type.should.equal('m1.small') + launch_config.name.should.equal("tester") + launch_config.image_id.should.equal("ami-abcd1234") + launch_config.instance_type.should.equal("m1.small") # Defaults - launch_config.key_name.should.equal('') + launch_config.key_name.should.equal("") list(launch_config.security_groups).should.equal([]) launch_config.user_data.should.equal(b"") - launch_config.instance_monitoring.enabled.should.equal('false') + launch_config.instance_monitoring.enabled.should.equal("false") launch_config.instance_profile_name.should.equal(None) launch_config.spot_price.should.equal(None) @@ -181,10 +173,7 @@ def test_create_launch_configuration_defaults(): @mock_autoscaling_deprecated def test_create_launch_configuration_defaults_for_2_12(): conn = boto.connect_autoscale() - config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - ) + config = LaunchConfiguration(name="tester", image_id="ami-abcd1234") conn.create_launch_configuration(config) launch_config = conn.get_all_launch_configurations()[0] @@ -195,51 +184,48 @@ def test_create_launch_configuration_defaults_for_2_12(): def test_launch_configuration_describe_filter(): conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='m1.small', + name="tester", image_id="ami-abcd1234", instance_type="m1.small" ) conn.create_launch_configuration(config) - config.name = 'tester2' + config.name = "tester2" conn.create_launch_configuration(config) - config.name = 'tester3' + config.name = "tester3" conn.create_launch_configuration(config) conn.get_all_launch_configurations( - names=['tester', 'tester2']).should.have.length_of(2) + names=["tester", "tester2"] + ).should.have.length_of(2) conn.get_all_launch_configurations().should.have.length_of(3) @mock_autoscaling def test_launch_configuration_describe_paginated(): - conn = boto3.client('autoscaling', region_name='us-east-1') + conn = boto3.client("autoscaling", region_name="us-east-1") for i in range(51): - conn.create_launch_configuration(LaunchConfigurationName='TestLC%d' % i) + conn.create_launch_configuration(LaunchConfigurationName="TestLC%d" % i) response = conn.describe_launch_configurations() lcs = response["LaunchConfigurations"] marker = response["NextToken"] lcs.should.have.length_of(50) - marker.should.equal(lcs[-1]['LaunchConfigurationName']) + marker.should.equal(lcs[-1]["LaunchConfigurationName"]) response2 = conn.describe_launch_configurations(NextToken=marker) lcs.extend(response2["LaunchConfigurations"]) lcs.should.have.length_of(51) - assert 'NextToken' not in response2.keys() + assert "NextToken" not in response2.keys() @mock_autoscaling_deprecated def test_launch_configuration_delete(): conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='m1.small', + name="tester", image_id="ami-abcd1234", instance_type="m1.small" ) conn.create_launch_configuration(config) conn.get_all_launch_configurations().should.have.length_of(1) - conn.delete_launch_configuration('tester') + conn.delete_launch_configuration("tester") conn.get_all_launch_configurations().should.have.length_of(0) diff --git a/tests/test_autoscaling/test_policies.py b/tests/test_autoscaling/test_policies.py index e6b01163f..f44938eea 100644 --- a/tests/test_autoscaling/test_policies.py +++ b/tests/test_autoscaling/test_policies.py @@ -14,18 +14,16 @@ def setup_autoscale_group(): mocked_networking = setup_networking_deprecated() conn = boto.connect_autoscale() config = LaunchConfiguration( - name='tester', - image_id='ami-abcd1234', - instance_type='m1.small', + name="tester", image_id="ami-abcd1234", instance_type="m1.small" ) conn.create_launch_configuration(config) group = AutoScalingGroup( - name='tester_group', + name="tester_group", max_size=2, min_size=2, launch_config=config, - vpc_zone_identifier=mocked_networking['subnet1'], + vpc_zone_identifier=mocked_networking["subnet1"], ) conn.create_auto_scaling_group(group) return group @@ -36,18 +34,18 @@ def test_create_policy(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=3, cooldown=60, ) conn.create_scaling_policy(policy) policy = conn.get_all_policies()[0] - policy.name.should.equal('ScaleUp') - policy.adjustment_type.should.equal('ExactCapacity') - policy.as_name.should.equal('tester_group') + policy.name.should.equal("ScaleUp") + policy.adjustment_type.should.equal("ExactCapacity") + policy.as_name.should.equal("tester_group") policy.scaling_adjustment.should.equal(3) policy.cooldown.should.equal(60) @@ -57,15 +55,15 @@ def test_create_policy_default_values(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=3, ) conn.create_scaling_policy(policy) policy = conn.get_all_policies()[0] - policy.name.should.equal('ScaleUp') + policy.name.should.equal("ScaleUp") # Defaults policy.cooldown.should.equal(300) @@ -76,9 +74,9 @@ def test_update_policy(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=3, ) conn.create_scaling_policy(policy) @@ -88,9 +86,9 @@ def test_update_policy(): # Now update it by creating another with the same name policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=2, ) conn.create_scaling_policy(policy) @@ -103,16 +101,16 @@ def test_delete_policy(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=3, ) conn.create_scaling_policy(policy) conn.get_all_policies().should.have.length_of(1) - conn.delete_policy('ScaleUp') + conn.delete_policy("ScaleUp") conn.get_all_policies().should.have.length_of(0) @@ -121,9 +119,9 @@ def test_execute_policy_exact_capacity(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ExactCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ExactCapacity", + as_name="tester_group", scaling_adjustment=3, ) conn.create_scaling_policy(policy) @@ -139,9 +137,9 @@ def test_execute_policy_positive_change_in_capacity(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='ChangeInCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="ChangeInCapacity", + as_name="tester_group", scaling_adjustment=3, ) conn.create_scaling_policy(policy) @@ -157,9 +155,9 @@ def test_execute_policy_percent_change_in_capacity(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='PercentChangeInCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="PercentChangeInCapacity", + as_name="tester_group", scaling_adjustment=50, ) conn.create_scaling_policy(policy) @@ -178,9 +176,9 @@ def test_execute_policy_small_percent_change_in_capacity(): setup_autoscale_group() conn = boto.connect_autoscale() policy = ScalingPolicy( - name='ScaleUp', - adjustment_type='PercentChangeInCapacity', - as_name='tester_group', + name="ScaleUp", + adjustment_type="PercentChangeInCapacity", + as_name="tester_group", scaling_adjustment=1, ) conn.create_scaling_policy(policy) diff --git a/tests/test_autoscaling/test_server.py b/tests/test_autoscaling/test_server.py index 2025694cd..17263af44 100644 --- a/tests/test_autoscaling/test_server.py +++ b/tests/test_autoscaling/test_server.py @@ -3,16 +3,16 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_describe_autoscaling_groups(): backend = server.create_backend_app("autoscaling") test_client = backend.test_client() - res = test_client.get('/?Action=DescribeLaunchConfigurations') + res = test_client.get("/?Action=DescribeLaunchConfigurations") - res.data.should.contain(b'') + res.data.should.contain(b"") diff --git a/tests/test_autoscaling/utils.py b/tests/test_autoscaling/utils.py index dc38aba3d..8827d2693 100644 --- a/tests/test_autoscaling/utils.py +++ b/tests/test_autoscaling/utils.py @@ -6,43 +6,36 @@ from moto import mock_ec2, mock_ec2_deprecated @mock_ec2 def setup_networking(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - vpc = ec2.create_vpc(CidrBlock='10.11.0.0/16') + ec2 = boto3.resource("ec2", region_name="us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.11.0.0/16") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='10.11.1.0/24', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="10.11.1.0/24", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='10.11.2.0/24', - AvailabilityZone='us-east-1b') - return {'vpc': vpc.id, 'subnet1': subnet1.id, 'subnet2': subnet2.id} + VpcId=vpc.id, CidrBlock="10.11.2.0/24", AvailabilityZone="us-east-1b" + ) + return {"vpc": vpc.id, "subnet1": subnet1.id, "subnet2": subnet2.id} + @mock_ec2_deprecated def setup_networking_deprecated(): - conn = boto_vpc.connect_to_region('us-east-1') + conn = boto_vpc.connect_to_region("us-east-1") vpc = conn.create_vpc("10.11.0.0/16") - subnet1 = conn.create_subnet( - vpc.id, - "10.11.1.0/24", - availability_zone='us-east-1a') - subnet2 = conn.create_subnet( - vpc.id, - "10.11.2.0/24", - availability_zone='us-east-1b') - return {'vpc': vpc.id, 'subnet1': subnet1.id, 'subnet2': subnet2.id} + subnet1 = conn.create_subnet(vpc.id, "10.11.1.0/24", availability_zone="us-east-1a") + subnet2 = conn.create_subnet(vpc.id, "10.11.2.0/24", availability_zone="us-east-1b") + return {"vpc": vpc.id, "subnet1": subnet1.id, "subnet2": subnet2.id} @mock_ec2 def setup_instance_with_networking(image_id, instance_type): mock_data = setup_networking() - ec2 = boto3.resource('ec2', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") instances = ec2.create_instances( ImageId=image_id, InstanceType=instance_type, MaxCount=1, MinCount=1, - SubnetId=mock_data['subnet1'] + SubnetId=mock_data["subnet1"], ) - mock_data['instance'] = instances[0].id + mock_data["instance"] = instances[0].id return mock_data diff --git a/tests/test_awslambda/test_lambda.py b/tests/test_awslambda/test_lambda.py index 1ac1cb2a6..c1aff468d 100644 --- a/tests/test_awslambda/test_lambda.py +++ b/tests/test_awslambda/test_lambda.py @@ -12,18 +12,27 @@ import zipfile import sure # noqa from freezegun import freeze_time -from moto import mock_dynamodb2, mock_lambda, mock_s3, mock_ec2, mock_sns, mock_logs, settings, mock_sqs +from moto import ( + mock_dynamodb2, + mock_lambda, + mock_s3, + mock_ec2, + mock_sns, + mock_logs, + settings, + mock_sqs, +) from nose.tools import assert_raises from botocore.exceptions import ClientError -_lambda_region = 'us-west-2' +_lambda_region = "us-west-2" boto3.setup_default_session(region_name=_lambda_region) def _process_lambda(func_str): zip_output = io.BytesIO() - zip_file = zipfile.ZipFile(zip_output, 'w', zipfile.ZIP_DEFLATED) - zip_file.writestr('lambda_function.py', func_str) + zip_file = zipfile.ZipFile(zip_output, "w", zipfile.ZIP_DEFLATED) + zip_file.writestr("lambda_function.py", func_str) zip_file.close() zip_output.seek(0) return zip_output.read() @@ -49,7 +58,11 @@ def lambda_handler(event, context): print('get volume details for %s\\nVolume - %s state=%s, size=%s' % (volume_id, volume_id, vol.state, vol.size)) return event -""".format(base_url="motoserver:5000" if settings.TEST_SERVER_MODE else "ec2.us-west-2.amazonaws.com") +""".format( + base_url="motoserver:5000" + if settings.TEST_SERVER_MODE + else "ec2.us-west-2.amazonaws.com" + ) return _process_lambda(func_str) @@ -61,6 +74,7 @@ def lambda_handler(event, context): """ return _process_lambda(pfunc) + def get_test_zip_file4(): pfunc = """ def lambda_handler(event, context): @@ -71,113 +85,118 @@ def lambda_handler(event, context): @mock_lambda def test_list_functions(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") result = conn.list_functions() - result['Functions'].should.have.length_of(0) + result["Functions"].should.have.length_of(0) @mock_lambda def test_invoke_requestresponse_function(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file1(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file1()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - in_data = {'msg': 'So long and thanks for all the fish'} - success_result = conn.invoke(FunctionName='testFunction', InvocationType='RequestResponse', - Payload=json.dumps(in_data)) + in_data = {"msg": "So long and thanks for all the fish"} + success_result = conn.invoke( + FunctionName="testFunction", + InvocationType="RequestResponse", + Payload=json.dumps(in_data), + ) success_result["StatusCode"].should.equal(202) result_obj = json.loads( - base64.b64decode(success_result["LogResult"]).decode('utf-8')) + base64.b64decode(success_result["LogResult"]).decode("utf-8") + ) result_obj.should.equal(in_data) - payload = success_result["Payload"].read().decode('utf-8') + payload = success_result["Payload"].read().decode("utf-8") json.loads(payload).should.equal(in_data) @mock_lambda def test_invoke_event_function(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file1(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file1()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) conn.invoke.when.called_with( - FunctionName='notAFunction', - InvocationType='Event', - Payload='{}' + FunctionName="notAFunction", InvocationType="Event", Payload="{}" ).should.throw(botocore.client.ClientError) - in_data = {'msg': 'So long and thanks for all the fish'} + in_data = {"msg": "So long and thanks for all the fish"} success_result = conn.invoke( - FunctionName='testFunction', InvocationType='Event', Payload=json.dumps(in_data)) + FunctionName="testFunction", InvocationType="Event", Payload=json.dumps(in_data) + ) success_result["StatusCode"].should.equal(202) - json.loads(success_result['Payload'].read().decode( - 'utf-8')).should.equal({}) + json.loads(success_result["Payload"].read().decode("utf-8")).should.equal({}) if settings.TEST_SERVER_MODE: + @mock_ec2 @mock_lambda def test_invoke_function_get_ec2_volume(): conn = boto3.resource("ec2", "us-west-2") - vol = conn.create_volume(Size=99, AvailabilityZone='us-west-2') + vol = conn.create_volume(Size=99, AvailabilityZone="us-west-2") vol = conn.Volume(vol.id) - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file2(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file2()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - in_data = {'volume_id': vol.id} - result = conn.invoke(FunctionName='testFunction', - InvocationType='RequestResponse', Payload=json.dumps(in_data)) + in_data = {"volume_id": vol.id} + result = conn.invoke( + FunctionName="testFunction", + InvocationType="RequestResponse", + Payload=json.dumps(in_data), + ) result["StatusCode"].should.equal(202) - msg = 'get volume details for %s\nVolume - %s state=%s, size=%s\n%s' % ( - vol.id, vol.id, vol.state, vol.size, json.dumps(in_data)) + msg = "get volume details for %s\nVolume - %s state=%s, size=%s\n%s" % ( + vol.id, + vol.id, + vol.state, + vol.size, + json.dumps(in_data), + ) - log_result = base64.b64decode(result["LogResult"]).decode('utf-8') + log_result = base64.b64decode(result["LogResult"]).decode("utf-8") # fix for running under travis (TODO: investigate why it has an extra newline) - log_result = log_result.replace('\n\n', '\n') + log_result = log_result.replace("\n\n", "\n") log_result.should.equal(msg) - payload = result['Payload'].read().decode('utf-8') + payload = result["Payload"].read().decode("utf-8") # fix for running under travis (TODO: investigate why it has an extra newline) - payload = payload.replace('\n\n', '\n') + payload = payload.replace("\n\n", "\n") payload.should.equal(msg) @@ -191,39 +210,42 @@ def test_invoke_function_from_sns(): sns_conn.create_topic(Name="some-topic") topics_json = sns_conn.list_topics() topics = topics_json["Topics"] - topic_arn = topics[0]['TopicArn'] + topic_arn = topics[0]["TopicArn"] - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") result = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - sns_conn.subscribe(TopicArn=topic_arn, Protocol="lambda", Endpoint=result['FunctionArn']) + sns_conn.subscribe( + TopicArn=topic_arn, Protocol="lambda", Endpoint=result["FunctionArn"] + ) result = sns_conn.publish(TopicArn=topic_arn, Message=json.dumps({})) start = time.time() while (time.time() - start) < 30: - result = logs_conn.describe_log_streams(logGroupName='/aws/lambda/testFunction') - log_streams = result.get('logStreams') + result = logs_conn.describe_log_streams(logGroupName="/aws/lambda/testFunction") + log_streams = result.get("logStreams") if not log_streams: time.sleep(1) continue assert len(log_streams) == 1 - result = logs_conn.get_log_events(logGroupName='/aws/lambda/testFunction', logStreamName=log_streams[0]['logStreamName']) - for event in result.get('events'): - if event['message'] == 'get_test_zip_file3 success': + result = logs_conn.get_log_events( + logGroupName="/aws/lambda/testFunction", + logStreamName=log_streams[0]["logStreamName"], + ) + for event in result.get("events"): + if event["message"] == "get_test_zip_file3 success": return time.sleep(1) @@ -233,190 +255,182 @@ def test_invoke_function_from_sns(): @mock_lambda def test_create_based_on_s3_with_missing_bucket(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.create_function.when.called_with( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'this-bucket-does-not-exist', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "this-bucket-does-not-exist", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, - VpcConfig={ - "SecurityGroupIds": ["sg-123abc"], - "SubnetIds": ["subnet-123abc"], - }, + VpcConfig={"SecurityGroupIds": ["sg-123abc"], "SubnetIds": ["subnet-123abc"]}, ).should.throw(botocore.client.ClientError) @mock_lambda @mock_s3 -@freeze_time('2015-01-01 00:00:00') +@freeze_time("2015-01-01 00:00:00") def test_create_function_from_aws_bucket(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") result = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, - VpcConfig={ - "SecurityGroupIds": ["sg-123abc"], - "SubnetIds": ["subnet-123abc"], - }, + VpcConfig={"SecurityGroupIds": ["sg-123abc"], "SubnetIds": ["subnet-123abc"]}, ) # this is hard to match against, so remove it - result['ResponseMetadata'].pop('HTTPHeaders', None) + result["ResponseMetadata"].pop("HTTPHeaders", None) # Botocore inserts retry attempts not seen in Python27 - result['ResponseMetadata'].pop('RetryAttempts', None) - result.pop('LastModified') - result.should.equal({ - 'FunctionName': 'testFunction', - 'FunctionArn': 'arn:aws:lambda:{}:123456789012:function:testFunction'.format(_lambda_region), - 'Runtime': 'python2.7', - 'Role': 'test-iam-role', - 'Handler': 'lambda_function.lambda_handler', - "CodeSha256": hashlib.sha256(zip_content).hexdigest(), - "CodeSize": len(zip_content), - 'Description': 'test lambda function', - 'Timeout': 3, - 'MemorySize': 128, - 'Version': '1', - 'VpcConfig': { - "SecurityGroupIds": ["sg-123abc"], - "SubnetIds": ["subnet-123abc"], - "VpcId": "vpc-123abc" - }, - 'ResponseMetadata': {'HTTPStatusCode': 201}, - }) + result["ResponseMetadata"].pop("RetryAttempts", None) + result.pop("LastModified") + result.should.equal( + { + "FunctionName": "testFunction", + "FunctionArn": "arn:aws:lambda:{}:123456789012:function:testFunction".format( + _lambda_region + ), + "Runtime": "python2.7", + "Role": "test-iam-role", + "Handler": "lambda_function.lambda_handler", + "CodeSha256": hashlib.sha256(zip_content).hexdigest(), + "CodeSize": len(zip_content), + "Description": "test lambda function", + "Timeout": 3, + "MemorySize": 128, + "Version": "1", + "VpcConfig": { + "SecurityGroupIds": ["sg-123abc"], + "SubnetIds": ["subnet-123abc"], + "VpcId": "vpc-123abc", + }, + "ResponseMetadata": {"HTTPStatusCode": 201}, + } + ) @mock_lambda -@freeze_time('2015-01-01 00:00:00') +@freeze_time("2015-01-01 00:00:00") def test_create_function_from_zipfile(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") zip_content = get_test_zip_file1() result = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': zip_content, - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": zip_content}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) # this is hard to match against, so remove it - result['ResponseMetadata'].pop('HTTPHeaders', None) + result["ResponseMetadata"].pop("HTTPHeaders", None) # Botocore inserts retry attempts not seen in Python27 - result['ResponseMetadata'].pop('RetryAttempts', None) - result.pop('LastModified') + result["ResponseMetadata"].pop("RetryAttempts", None) + result.pop("LastModified") - result.should.equal({ - 'FunctionName': 'testFunction', - 'FunctionArn': 'arn:aws:lambda:{}:123456789012:function:testFunction'.format(_lambda_region), - 'Runtime': 'python2.7', - 'Role': 'test-iam-role', - 'Handler': 'lambda_function.lambda_handler', - 'CodeSize': len(zip_content), - 'Description': 'test lambda function', - 'Timeout': 3, - 'MemorySize': 128, - 'CodeSha256': hashlib.sha256(zip_content).hexdigest(), - 'Version': '1', - 'VpcConfig': { - "SecurityGroupIds": [], - "SubnetIds": [], - }, - - 'ResponseMetadata': {'HTTPStatusCode': 201}, - }) + result.should.equal( + { + "FunctionName": "testFunction", + "FunctionArn": "arn:aws:lambda:{}:123456789012:function:testFunction".format( + _lambda_region + ), + "Runtime": "python2.7", + "Role": "test-iam-role", + "Handler": "lambda_function.lambda_handler", + "CodeSize": len(zip_content), + "Description": "test lambda function", + "Timeout": 3, + "MemorySize": 128, + "CodeSha256": hashlib.sha256(zip_content).hexdigest(), + "Version": "1", + "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + "ResponseMetadata": {"HTTPStatusCode": 201}, + } + ) @mock_lambda @mock_s3 -@freeze_time('2015-01-01 00:00:00') +@freeze_time("2015-01-01 00:00:00") def test_get_function(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file1() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, Environment={"Variables": {"test_variable": "test_value"}} ) - result = conn.get_function(FunctionName='testFunction') + result = conn.get_function(FunctionName="testFunction") # this is hard to match against, so remove it - result['ResponseMetadata'].pop('HTTPHeaders', None) + result["ResponseMetadata"].pop("HTTPHeaders", None) # Botocore inserts retry attempts not seen in Python27 - result['ResponseMetadata'].pop('RetryAttempts', None) - result['Configuration'].pop('LastModified') + result["ResponseMetadata"].pop("RetryAttempts", None) + result["Configuration"].pop("LastModified") - result['Code']['Location'].should.equal('s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/test.zip'.format(_lambda_region)) - result['Code']['RepositoryType'].should.equal('S3') + result["Code"]["Location"].should.equal( + "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/test.zip".format(_lambda_region) + ) + result["Code"]["RepositoryType"].should.equal("S3") - result['Configuration']['CodeSha256'].should.equal(hashlib.sha256(zip_content).hexdigest()) - result['Configuration']['CodeSize'].should.equal(len(zip_content)) - result['Configuration']['Description'].should.equal('test lambda function') - result['Configuration'].should.contain('FunctionArn') - result['Configuration']['FunctionName'].should.equal('testFunction') - result['Configuration']['Handler'].should.equal('lambda_function.lambda_handler') - result['Configuration']['MemorySize'].should.equal(128) - result['Configuration']['Role'].should.equal('test-iam-role') - result['Configuration']['Runtime'].should.equal('python2.7') - result['Configuration']['Timeout'].should.equal(3) - result['Configuration']['Version'].should.equal('$LATEST') - result['Configuration'].should.contain('VpcConfig') + result["Configuration"]["CodeSha256"].should.equal( + hashlib.sha256(zip_content).hexdigest() + ) + result["Configuration"]["CodeSize"].should.equal(len(zip_content)) + result["Configuration"]["Description"].should.equal("test lambda function") + result["Configuration"].should.contain("FunctionArn") + result["Configuration"]["FunctionName"].should.equal("testFunction") + result["Configuration"]["Handler"].should.equal("lambda_function.lambda_handler") + result["Configuration"]["MemorySize"].should.equal(128) + result["Configuration"]["Role"].should.equal("test-iam-role") + result["Configuration"]["Runtime"].should.equal("python2.7") + result["Configuration"]["Timeout"].should.equal(3) + result["Configuration"]["Version"].should.equal("$LATEST") + result["Configuration"].should.contain("VpcConfig") result['Configuration'].should.contain('Environment') result['Configuration']['Environment'].should.contain('Variables') result['Configuration']['Environment']["Variables"].should.equal({"test_variable": "test_value"}) # Test get function with - result = conn.get_function(FunctionName='testFunction', Qualifier='$LATEST') - result['Configuration']['Version'].should.equal('$LATEST') - result['Configuration']['FunctionArn'].should.equal('arn:aws:lambda:us-west-2:123456789012:function:testFunction:$LATEST') - + result = conn.get_function(FunctionName="testFunction", Qualifier="$LATEST") + result["Configuration"]["Version"].should.equal("$LATEST") + result["Configuration"]["FunctionArn"].should.equal( + "arn:aws:lambda:us-west-2:123456789012:function:testFunction:$LATEST" + ) # Test get function when can't find function name with assert_raises(ClientError): - conn.get_function(FunctionName='junk', Qualifier='$LATEST') + conn.get_function(FunctionName="junk", Qualifier="$LATEST") @mock_lambda @mock_s3 @@ -443,186 +457,186 @@ def test_get_function_by_arn(): @mock_lambda @mock_s3 def test_delete_function(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - success_result = conn.delete_function(FunctionName='testFunction') + success_result = conn.delete_function(FunctionName="testFunction") # this is hard to match against, so remove it - success_result['ResponseMetadata'].pop('HTTPHeaders', None) + success_result["ResponseMetadata"].pop("HTTPHeaders", None) # Botocore inserts retry attempts not seen in Python27 - success_result['ResponseMetadata'].pop('RetryAttempts', None) + success_result["ResponseMetadata"].pop("RetryAttempts", None) - success_result.should.equal({'ResponseMetadata': {'HTTPStatusCode': 204}}) + success_result.should.equal({"ResponseMetadata": {"HTTPStatusCode": 204}}) function_list = conn.list_functions() - function_list['Functions'].should.have.length_of(0) + function_list["Functions"].should.have.length_of(0) @mock_lambda @mock_s3 def test_delete_function_by_arn(): - bucket_name = 'test-bucket' - s3_conn = boto3.client('s3', 'us-east-1') + bucket_name = "test-bucket" + s3_conn = boto3.client("s3", "us-east-1") s3_conn.create_bucket(Bucket=bucket_name) zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket=bucket_name, Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-east-1') + s3_conn.put_object(Bucket=bucket_name, Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-east-1") - fnc = conn.create_function(FunctionName='testFunction', - Runtime='python2.7', Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={'S3Bucket': bucket_name, 'S3Key': 'test.zip'}, - Description='test lambda function', - Timeout=3, MemorySize=128, Publish=True) + fnc = conn.create_function( + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": bucket_name, "S3Key": "test.zip"}, + Description="test lambda function", + Timeout=3, + MemorySize=128, + Publish=True, + ) - conn.delete_function(FunctionName=fnc['FunctionArn']) + conn.delete_function(FunctionName=fnc["FunctionArn"]) function_list = conn.list_functions() - function_list['Functions'].should.have.length_of(0) + function_list["Functions"].should.have.length_of(0) @mock_lambda def test_delete_unknown_function(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.delete_function.when.called_with( - FunctionName='testFunctionThatDoesntExist').should.throw(botocore.client.ClientError) + FunctionName="testFunctionThatDoesntExist" + ).should.throw(botocore.client.ClientError) @mock_lambda @mock_s3 def test_publish(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=False, ) function_list = conn.list_functions() - function_list['Functions'].should.have.length_of(1) - latest_arn = function_list['Functions'][0]['FunctionArn'] + function_list["Functions"].should.have.length_of(1) + latest_arn = function_list["Functions"][0]["FunctionArn"] - res = conn.publish_version(FunctionName='testFunction') - assert res['ResponseMetadata']['HTTPStatusCode'] == 201 + res = conn.publish_version(FunctionName="testFunction") + assert res["ResponseMetadata"]["HTTPStatusCode"] == 201 function_list = conn.list_functions() - function_list['Functions'].should.have.length_of(2) + function_list["Functions"].should.have.length_of(2) # #SetComprehension ;-) - published_arn = list({f['FunctionArn'] for f in function_list['Functions']} - {latest_arn})[0] - published_arn.should.contain('testFunction:1') + published_arn = list( + {f["FunctionArn"] for f in function_list["Functions"]} - {latest_arn} + )[0] + published_arn.should.contain("testFunction:1") - conn.delete_function(FunctionName='testFunction', Qualifier='1') + conn.delete_function(FunctionName="testFunction", Qualifier="1") function_list = conn.list_functions() - function_list['Functions'].should.have.length_of(1) - function_list['Functions'][0]['FunctionArn'].should.contain('testFunction') + function_list["Functions"].should.have.length_of(1) + function_list["Functions"][0]["FunctionArn"].should.contain("testFunction") @mock_lambda @mock_s3 -@freeze_time('2015-01-01 00:00:00') +@freeze_time("2015-01-01 00:00:00") def test_list_create_list_get_delete_list(): """ test `list -> create -> list -> get -> delete -> list` integration """ - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") - conn.list_functions()['Functions'].should.have.length_of(0) + conn.list_functions()["Functions"].should.have.length_of(0) conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) expected_function_result = { "Code": { - "Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/test.zip".format(_lambda_region), - "RepositoryType": "S3" + "Location": "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com/test.zip".format( + _lambda_region + ), + "RepositoryType": "S3", }, "Configuration": { "CodeSha256": hashlib.sha256(zip_content).hexdigest(), "CodeSize": len(zip_content), "Description": "test lambda function", - "FunctionArn": 'arn:aws:lambda:{}:123456789012:function:testFunction'.format(_lambda_region), + "FunctionArn": "arn:aws:lambda:{}:123456789012:function:testFunction".format( + _lambda_region + ), "FunctionName": "testFunction", "Handler": "lambda_function.lambda_handler", "MemorySize": 128, "Role": "test-iam-role", "Runtime": "python2.7", "Timeout": 3, - "Version": '$LATEST', - "VpcConfig": { - "SecurityGroupIds": [], - "SubnetIds": [], - } + "Version": "$LATEST", + "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, }, - 'ResponseMetadata': {'HTTPStatusCode': 200}, + "ResponseMetadata": {"HTTPStatusCode": 200}, } - func = conn.list_functions()['Functions'][0] - func.pop('LastModified') - func.should.equal(expected_function_result['Configuration']) + func = conn.list_functions()["Functions"][0] + func.pop("LastModified") + func.should.equal(expected_function_result["Configuration"]) - func = conn.get_function(FunctionName='testFunction') + func = conn.get_function(FunctionName="testFunction") # this is hard to match against, so remove it - func['ResponseMetadata'].pop('HTTPHeaders', None) + func["ResponseMetadata"].pop("HTTPHeaders", None) # Botocore inserts retry attempts not seen in Python27 - func['ResponseMetadata'].pop('RetryAttempts', None) - func['Configuration'].pop('LastModified') + func["ResponseMetadata"].pop("RetryAttempts", None) + func["Configuration"].pop("LastModified") func.should.equal(expected_function_result) - conn.delete_function(FunctionName='testFunction') + conn.delete_function(FunctionName="testFunction") - conn.list_functions()['Functions'].should.have.length_of(0) + conn.list_functions()["Functions"].should.have.length_of(0) @mock_lambda @@ -632,34 +646,30 @@ def lambda_handler(event, context): raise Exception('failsauce') """ zip_output = io.BytesIO() - zip_file = zipfile.ZipFile(zip_output, 'w', zipfile.ZIP_DEFLATED) - zip_file.writestr('lambda_function.py', lambda_fx) + zip_file = zipfile.ZipFile(zip_output, "w", zipfile.ZIP_DEFLATED) + zip_file.writestr("lambda_function.py", lambda_fx) zip_file.close() zip_output.seek(0) - client = boto3.client('lambda', region_name='us-east-1') + client = boto3.client("lambda", region_name="us-east-1") client.create_function( - FunctionName='test-lambda-fx', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Description='test lambda function', + FunctionName="test-lambda-fx", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, - Code={ - 'ZipFile': zip_output.read() - }, + Code={"ZipFile": zip_output.read()}, ) result = client.invoke( - FunctionName='test-lambda-fx', - InvocationType='RequestResponse', - LogType='Tail' + FunctionName="test-lambda-fx", InvocationType="RequestResponse", LogType="Tail" ) - assert 'FunctionError' in result - assert result['FunctionError'] == 'Handled' + assert "FunctionError" in result + assert result["FunctionError"] == "Handled" @mock_lambda @@ -668,65 +678,56 @@ def test_tags(): """ test list_tags -> tag_resource -> list_tags -> tag_resource -> list_tags -> untag_resource -> list_tags integration """ - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") function = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) # List tags when there are none - conn.list_tags( - Resource=function['FunctionArn'] - )['Tags'].should.equal(dict()) + conn.list_tags(Resource=function["FunctionArn"])["Tags"].should.equal(dict()) # List tags when there is one - conn.tag_resource( - Resource=function['FunctionArn'], - Tags=dict(spam='eggs') - )['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - conn.list_tags( - Resource=function['FunctionArn'] - )['Tags'].should.equal(dict(spam='eggs')) + conn.tag_resource(Resource=function["FunctionArn"], Tags=dict(spam="eggs"))[ + "ResponseMetadata" + ]["HTTPStatusCode"].should.equal(200) + conn.list_tags(Resource=function["FunctionArn"])["Tags"].should.equal( + dict(spam="eggs") + ) # List tags when another has been added - conn.tag_resource( - Resource=function['FunctionArn'], - Tags=dict(foo='bar') - )['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - conn.list_tags( - Resource=function['FunctionArn'] - )['Tags'].should.equal(dict(spam='eggs', foo='bar')) + conn.tag_resource(Resource=function["FunctionArn"], Tags=dict(foo="bar"))[ + "ResponseMetadata" + ]["HTTPStatusCode"].should.equal(200) + conn.list_tags(Resource=function["FunctionArn"])["Tags"].should.equal( + dict(spam="eggs", foo="bar") + ) # Untag resource - conn.untag_resource( - Resource=function['FunctionArn'], - TagKeys=['spam', 'trolls'] - )['ResponseMetadata']['HTTPStatusCode'].should.equal(204) - conn.list_tags( - Resource=function['FunctionArn'] - )['Tags'].should.equal(dict(foo='bar')) + conn.untag_resource(Resource=function["FunctionArn"], TagKeys=["spam", "trolls"])[ + "ResponseMetadata" + ]["HTTPStatusCode"].should.equal(204) + conn.list_tags(Resource=function["FunctionArn"])["Tags"].should.equal( + dict(foo="bar") + ) # Untag a tag that does not exist (no error and no change) - conn.untag_resource( - Resource=function['FunctionArn'], - TagKeys=['spam'] - )['ResponseMetadata']['HTTPStatusCode'].should.equal(204) + conn.untag_resource(Resource=function["FunctionArn"], TagKeys=["spam"])[ + "ResponseMetadata" + ]["HTTPStatusCode"].should.equal(204) @mock_lambda @@ -734,299 +735,285 @@ def test_tags_not_found(): """ Test list_tags and tag_resource when the lambda with the given arn does not exist """ - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.list_tags.when.called_with( - Resource='arn:aws:lambda:123456789012:function:not-found' + Resource="arn:aws:lambda:123456789012:function:not-found" ).should.throw(botocore.client.ClientError) conn.tag_resource.when.called_with( - Resource='arn:aws:lambda:123456789012:function:not-found', - Tags=dict(spam='eggs') + Resource="arn:aws:lambda:123456789012:function:not-found", + Tags=dict(spam="eggs"), ).should.throw(botocore.client.ClientError) conn.untag_resource.when.called_with( - Resource='arn:aws:lambda:123456789012:function:not-found', - TagKeys=['spam'] + Resource="arn:aws:lambda:123456789012:function:not-found", TagKeys=["spam"] ).should.throw(botocore.client.ClientError) @mock_lambda def test_invoke_async_function(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={'ZipFile': get_test_zip_file1()}, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file1()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) success_result = conn.invoke_async( - FunctionName='testFunction', - InvokeArgs=json.dumps({'test': 'event'}) - ) + FunctionName="testFunction", InvokeArgs=json.dumps({"test": "event"}) + ) - success_result['Status'].should.equal(202) + success_result["Status"].should.equal(202) @mock_lambda -@freeze_time('2015-01-01 00:00:00') +@freeze_time("2015-01-01 00:00:00") def test_get_function_created_with_zipfile(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") zip_content = get_test_zip_file1() result = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.handler', - Code={ - 'ZipFile': zip_content, - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.handler", + Code={"ZipFile": zip_content}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - response = conn.get_function( - FunctionName='testFunction' - ) - response['Configuration'].pop('LastModified') + response = conn.get_function(FunctionName="testFunction") + response["Configuration"].pop("LastModified") - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - assert len(response['Code']) == 2 - assert response['Code']['RepositoryType'] == 'S3' - assert response['Code']['Location'].startswith('s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com'.format(_lambda_region)) - response['Configuration'].should.equal( + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + assert len(response["Code"]) == 2 + assert response["Code"]["RepositoryType"] == "S3" + assert response["Code"]["Location"].startswith( + "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com".format(_lambda_region) + ) + response["Configuration"].should.equal( { "CodeSha256": hashlib.sha256(zip_content).hexdigest(), "CodeSize": len(zip_content), "Description": "test lambda function", - "FunctionArn": 'arn:aws:lambda:{}:123456789012:function:testFunction'.format(_lambda_region), + "FunctionArn": "arn:aws:lambda:{}:123456789012:function:testFunction".format( + _lambda_region + ), "FunctionName": "testFunction", "Handler": "lambda_function.handler", "MemorySize": 128, "Role": "test-iam-role", "Runtime": "python2.7", "Timeout": 3, - "Version": '$LATEST', - "VpcConfig": { - "SecurityGroupIds": [], - "SubnetIds": [], - } - }, + "Version": "$LATEST", + "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + } ) @mock_lambda def test_add_function_permission(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") zip_content = get_test_zip_file1() conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.handler', - Code={ - 'ZipFile': zip_content, - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.handler", + Code={"ZipFile": zip_content}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.add_permission( - FunctionName='testFunction', - StatementId='1', + FunctionName="testFunction", + StatementId="1", Action="lambda:InvokeFunction", - Principal='432143214321', + Principal="432143214321", SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld", - SourceAccount='123412341234', - EventSourceToken='blah', - Qualifier='2' + SourceAccount="123412341234", + EventSourceToken="blah", + Qualifier="2", ) - assert u'Statement' in response - res = json.loads(response[u'Statement']) - assert res[u'Action'] == u'lambda:InvokeFunction' + assert "Statement" in response + res = json.loads(response["Statement"]) + assert res["Action"] == "lambda:InvokeFunction" @mock_lambda def test_get_function_policy(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") zip_content = get_test_zip_file1() conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.handler', - Code={ - 'ZipFile': zip_content, - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.handler", + Code={"ZipFile": zip_content}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.add_permission( - FunctionName='testFunction', - StatementId='1', + FunctionName="testFunction", + StatementId="1", Action="lambda:InvokeFunction", - Principal='432143214321', + Principal="432143214321", SourceArn="arn:aws:lambda:us-west-2:account-id:function:helloworld", - SourceAccount='123412341234', - EventSourceToken='blah', - Qualifier='2' + SourceAccount="123412341234", + EventSourceToken="blah", + Qualifier="2", ) - response = conn.get_policy( - FunctionName='testFunction' - ) + response = conn.get_policy(FunctionName="testFunction") - assert u'Policy' in response - res = json.loads(response[u'Policy']) - assert res[u'Statement'][0][u'Action'] == u'lambda:InvokeFunction' + assert "Policy" in response + res = json.loads(response["Policy"]) + assert res["Statement"][0]["Action"] == "lambda:InvokeFunction" @mock_lambda @mock_s3 def test_list_versions_by_function(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='arn:aws:iam::123456789012:role/test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="arn:aws:iam::123456789012:role/test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - res = conn.publish_version(FunctionName='testFunction') - assert res['ResponseMetadata']['HTTPStatusCode'] == 201 - versions = conn.list_versions_by_function(FunctionName='testFunction') - assert len(versions['Versions']) == 3 - assert versions['Versions'][0]['FunctionArn'] == 'arn:aws:lambda:us-west-2:123456789012:function:testFunction:$LATEST' - assert versions['Versions'][1]['FunctionArn'] == 'arn:aws:lambda:us-west-2:123456789012:function:testFunction:1' - assert versions['Versions'][2]['FunctionArn'] == 'arn:aws:lambda:us-west-2:123456789012:function:testFunction:2' + res = conn.publish_version(FunctionName="testFunction") + assert res["ResponseMetadata"]["HTTPStatusCode"] == 201 + versions = conn.list_versions_by_function(FunctionName="testFunction") + assert len(versions["Versions"]) == 3 + assert ( + versions["Versions"][0]["FunctionArn"] + == "arn:aws:lambda:us-west-2:123456789012:function:testFunction:$LATEST" + ) + assert ( + versions["Versions"][1]["FunctionArn"] + == "arn:aws:lambda:us-west-2:123456789012:function:testFunction:1" + ) + assert ( + versions["Versions"][2]["FunctionArn"] + == "arn:aws:lambda:us-west-2:123456789012:function:testFunction:2" + ) conn.create_function( - FunctionName='testFunction_2', - Runtime='python2.7', - Role='arn:aws:iam::123456789012:role/test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction_2", + Runtime="python2.7", + Role="arn:aws:iam::123456789012:role/test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=False, ) - versions = conn.list_versions_by_function(FunctionName='testFunction_2') - assert len(versions['Versions']) == 1 - assert versions['Versions'][0]['FunctionArn'] == 'arn:aws:lambda:us-west-2:123456789012:function:testFunction_2:$LATEST' + versions = conn.list_versions_by_function(FunctionName="testFunction_2") + assert len(versions["Versions"]) == 1 + assert ( + versions["Versions"][0]["FunctionArn"] + == "arn:aws:lambda:us-west-2:123456789012:function:testFunction_2:$LATEST" + ) @mock_lambda @mock_s3 def test_create_function_with_already_exists(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) - assert response['FunctionName'] == 'testFunction' + assert response["FunctionName"] == "testFunction" @mock_lambda @mock_s3 def test_list_versions_by_function_for_nonexistent_function(): - conn = boto3.client('lambda', 'us-west-2') - versions = conn.list_versions_by_function(FunctionName='testFunction') + conn = boto3.client("lambda", "us-west-2") + versions = conn.list_versions_by_function(FunctionName="testFunction") - assert len(versions['Versions']) == 0 + assert len(versions["Versions"]) == 0 @mock_logs @mock_lambda @mock_sqs def test_create_event_source_mapping(): - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda") func = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"] ) - assert response['EventSourceArn'] == queue.attributes['QueueArn'] - assert response['FunctionArn'] == func['FunctionArn'] - assert response['State'] == 'Enabled' + assert response["EventSourceArn"] == queue.attributes["QueueArn"] + assert response["FunctionArn"] == func["FunctionArn"] + assert response["State"] == "Enabled" @mock_logs @@ -1034,46 +1021,46 @@ def test_create_event_source_mapping(): @mock_sqs def test_invoke_function_from_sqs(): logs_conn = boto3.client("logs") - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda") func = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"] ) - assert response['EventSourceArn'] == queue.attributes['QueueArn'] - assert response['State'] == 'Enabled' + assert response["EventSourceArn"] == queue.attributes["QueueArn"] + assert response["State"] == "Enabled" - sqs_client = boto3.client('sqs') - sqs_client.send_message(QueueUrl=queue.url, MessageBody='test') + sqs_client = boto3.client("sqs") + sqs_client.send_message(QueueUrl=queue.url, MessageBody="test") start = time.time() while (time.time() - start) < 30: - result = logs_conn.describe_log_streams(logGroupName='/aws/lambda/testFunction') - log_streams = result.get('logStreams') + result = logs_conn.describe_log_streams(logGroupName="/aws/lambda/testFunction") + log_streams = result.get("logStreams") if not log_streams: time.sleep(1) continue assert len(log_streams) == 1 - result = logs_conn.get_log_events(logGroupName='/aws/lambda/testFunction', logStreamName=log_streams[0]['logStreamName']) - for event in result.get('events'): - if event['message'] == 'get_test_zip_file3 success': + result = logs_conn.get_log_events( + logGroupName="/aws/lambda/testFunction", + logStreamName=log_streams[0]["logStreamName"], + ) + for event in result.get("events"): + if event["message"] == "get_test_zip_file3 success": return time.sleep(1) @@ -1085,43 +1072,55 @@ def test_invoke_function_from_sqs(): @mock_dynamodb2 def test_invoke_function_from_dynamodb(): logs_conn = boto3.client("logs") - dynamodb = boto3.client('dynamodb') - table_name = 'table_with_stream' - table = dynamodb.create_table(TableName=table_name, - KeySchema=[{'AttributeName':'id','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'id','AttributeType':'S'}], - StreamSpecification={'StreamEnabled': True, - 'StreamViewType': 'NEW_AND_OLD_IMAGES'}) - - conn = boto3.client('lambda') - func = conn.create_function(FunctionName='testFunction', Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={'ZipFile': get_test_zip_file3()}, - Description='test lambda function executed after a DynamoDB table is updated', - Timeout=3, MemorySize=128, Publish=True) - - response = conn.create_event_source_mapping( - EventSourceArn=table['TableDescription']['LatestStreamArn'], - FunctionName=func['FunctionArn'] + dynamodb = boto3.client("dynamodb") + table_name = "table_with_stream" + table = dynamodb.create_table( + TableName=table_name, + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + StreamSpecification={ + "StreamEnabled": True, + "StreamViewType": "NEW_AND_OLD_IMAGES", + }, ) - assert response['EventSourceArn'] == table['TableDescription']['LatestStreamArn'] - assert response['State'] == 'Enabled' + conn = boto3.client("lambda") + func = conn.create_function( + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function executed after a DynamoDB table is updated", + Timeout=3, + MemorySize=128, + Publish=True, + ) - dynamodb.put_item(TableName=table_name, Item={'id': { 'S': 'item 1' }}) + response = conn.create_event_source_mapping( + EventSourceArn=table["TableDescription"]["LatestStreamArn"], + FunctionName=func["FunctionArn"], + ) + + assert response["EventSourceArn"] == table["TableDescription"]["LatestStreamArn"] + assert response["State"] == "Enabled" + + dynamodb.put_item(TableName=table_name, Item={"id": {"S": "item 1"}}) start = time.time() while (time.time() - start) < 30: - result = logs_conn.describe_log_streams(logGroupName='/aws/lambda/testFunction') - log_streams = result.get('logStreams') + result = logs_conn.describe_log_streams(logGroupName="/aws/lambda/testFunction") + log_streams = result.get("logStreams") if not log_streams: time.sleep(1) continue assert len(log_streams) == 1 - result = logs_conn.get_log_events(logGroupName='/aws/lambda/testFunction', logStreamName=log_streams[0]['logStreamName']) - for event in result.get('events'): - if event['message'] == 'get_test_zip_file3 success': + result = logs_conn.get_log_events( + logGroupName="/aws/lambda/testFunction", + logStreamName=log_streams[0]["logStreamName"], + ) + for event in result.get("events"): + if event["message"] == "get_test_zip_file3 success": return time.sleep(1) @@ -1133,58 +1132,52 @@ def test_invoke_function_from_dynamodb(): @mock_sqs def test_invoke_function_from_sqs_exception(): logs_conn = boto3.client("logs") - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda") func = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file4(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file4()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"] ) - assert response['EventSourceArn'] == queue.attributes['QueueArn'] - assert response['State'] == 'Enabled' + assert response["EventSourceArn"] == queue.attributes["QueueArn"] + assert response["State"] == "Enabled" entries = [] for i in range(3): - body = { - "uuid": str(uuid.uuid4()), - "test": "test_{}".format(i), - } - entry = { - 'Id': str(i), - 'MessageBody': json.dumps(body) - } + body = {"uuid": str(uuid.uuid4()), "test": "test_{}".format(i)} + entry = {"Id": str(i), "MessageBody": json.dumps(body)} entries.append(entry) queue.send_messages(Entries=entries) start = time.time() while (time.time() - start) < 30: - result = logs_conn.describe_log_streams(logGroupName='/aws/lambda/testFunction') - log_streams = result.get('logStreams') + result = logs_conn.describe_log_streams(logGroupName="/aws/lambda/testFunction") + log_streams = result.get("logStreams") if not log_streams: time.sleep(1) continue assert len(log_streams) >= 1 - result = logs_conn.get_log_events(logGroupName='/aws/lambda/testFunction', logStreamName=log_streams[0]['logStreamName']) - for event in result.get('events'): - if 'I failed!' in event['message']: + result = logs_conn.get_log_events( + logGroupName="/aws/lambda/testFunction", + logStreamName=log_streams[0]["logStreamName"], + ) + for event in result.get("events"): + if "I failed!" in event["message"]: messages = queue.receive_messages(MaxNumberOfMessages=10) # Verify messages are still visible and unprocessed assert len(messages) == 3 @@ -1198,221 +1191,202 @@ def test_invoke_function_from_sqs_exception(): @mock_lambda @mock_sqs def test_list_event_source_mappings(): - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda") func = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"] ) - mappings = conn.list_event_source_mappings(EventSourceArn='123') - assert len(mappings['EventSourceMappings']) == 0 + mappings = conn.list_event_source_mappings(EventSourceArn="123") + assert len(mappings["EventSourceMappings"]) == 0 - mappings = conn.list_event_source_mappings(EventSourceArn=queue.attributes['QueueArn']) - assert len(mappings['EventSourceMappings']) == 1 - assert mappings['EventSourceMappings'][0]['UUID'] == response['UUID'] - assert mappings['EventSourceMappings'][0]['FunctionArn'] == func['FunctionArn'] + mappings = conn.list_event_source_mappings( + EventSourceArn=queue.attributes["QueueArn"] + ) + assert len(mappings["EventSourceMappings"]) == 1 + assert mappings["EventSourceMappings"][0]["UUID"] == response["UUID"] + assert mappings["EventSourceMappings"][0]["FunctionArn"] == func["FunctionArn"] @mock_lambda @mock_sqs def test_get_event_source_mapping(): - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda") func = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func["FunctionArn"] ) - mapping = conn.get_event_source_mapping(UUID=response['UUID']) - assert mapping['UUID'] == response['UUID'] - assert mapping['FunctionArn'] == func['FunctionArn'] + mapping = conn.get_event_source_mapping(UUID=response["UUID"]) + assert mapping["UUID"] == response["UUID"] + assert mapping["FunctionArn"] == func["FunctionArn"] - conn.get_event_source_mapping.when.called_with(UUID='1')\ - .should.throw(botocore.client.ClientError) + conn.get_event_source_mapping.when.called_with(UUID="1").should.throw( + botocore.client.ClientError + ) @mock_lambda @mock_sqs def test_update_event_source_mapping(): - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda") func1 = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) func2 = conn.create_function( - FunctionName='testFunction2', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction2", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func1['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func1["FunctionArn"] ) - assert response['FunctionArn'] == func1['FunctionArn'] - assert response['BatchSize'] == 10 - assert response['State'] == 'Enabled' + assert response["FunctionArn"] == func1["FunctionArn"] + assert response["BatchSize"] == 10 + assert response["State"] == "Enabled" mapping = conn.update_event_source_mapping( - UUID=response['UUID'], - Enabled=False, - BatchSize=15, - FunctionName='testFunction2' - + UUID=response["UUID"], Enabled=False, BatchSize=15, FunctionName="testFunction2" ) - assert mapping['UUID'] == response['UUID'] - assert mapping['FunctionArn'] == func2['FunctionArn'] - assert mapping['State'] == 'Disabled' + assert mapping["UUID"] == response["UUID"] + assert mapping["FunctionArn"] == func2["FunctionArn"] + assert mapping["State"] == "Disabled" @mock_lambda @mock_sqs def test_delete_event_source_mapping(): - sqs = boto3.resource('sqs') + sqs = boto3.resource("sqs") queue = sqs.create_queue(QueueName="test-sqs-queue1") - conn = boto3.client('lambda') + conn = boto3.client("lambda") func1 = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': get_test_zip_file3(), - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": get_test_zip_file3()}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) response = conn.create_event_source_mapping( - EventSourceArn=queue.attributes['QueueArn'], - FunctionName=func1['FunctionArn'], + EventSourceArn=queue.attributes["QueueArn"], FunctionName=func1["FunctionArn"] ) - assert response['FunctionArn'] == func1['FunctionArn'] - assert response['BatchSize'] == 10 - assert response['State'] == 'Enabled' + assert response["FunctionArn"] == func1["FunctionArn"] + assert response["BatchSize"] == 10 + assert response["State"] == "Enabled" - response = conn.delete_event_source_mapping(UUID=response['UUID']) + response = conn.delete_event_source_mapping(UUID=response["UUID"]) - assert response['State'] == 'Deleting' - conn.get_event_source_mapping.when.called_with(UUID=response['UUID'])\ - .should.throw(botocore.client.ClientError) + assert response["State"] == "Deleting" + conn.get_event_source_mapping.when.called_with(UUID=response["UUID"]).should.throw( + botocore.client.ClientError + ) @mock_lambda @mock_s3 def test_update_configuration(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) + conn = boto3.client("lambda", "us-west-2") fxn = conn.create_function( - FunctionName='testFunction', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunction", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, Environment={'Variables': {"test_old_environment": "test_old_value"}} ) - assert fxn['Description'] == 'test lambda function' - assert fxn['Handler'] == 'lambda_function.lambda_handler' - assert fxn['MemorySize'] == 128 - assert fxn['Runtime'] == 'python2.7' - assert fxn['Timeout'] == 3 + assert fxn["Description"] == "test lambda function" + assert fxn["Handler"] == "lambda_function.lambda_handler" + assert fxn["MemorySize"] == 128 + assert fxn["Runtime"] == "python2.7" + assert fxn["Timeout"] == 3 updated_config = conn.update_function_configuration( - FunctionName='testFunction', - Description='updated test lambda function', - Handler='lambda_function.new_lambda_handler', - Runtime='python3.6', + FunctionName="testFunction", + Description="updated test lambda function", + Handler="lambda_function.new_lambda_handler", + Runtime="python3.6", Timeout=7, - Environment={'Variables': {"test_environment": "test_value"}} + Environment={"Variables": {"test_environment": "test_value"}} ) - assert updated_config['ResponseMetadata']['HTTPStatusCode'] == 200 - assert updated_config['Description'] == 'updated test lambda function' - assert updated_config['Handler'] == 'lambda_function.new_lambda_handler' - assert updated_config['MemorySize'] == 128 - assert updated_config['Runtime'] == 'python3.6' - assert updated_config['Timeout'] == 7 - assert updated_config['Environment']['Variables'] == {"test_environment": "test_value"} + assert updated_config["ResponseMetadata"]["HTTPStatusCode"] == 200 + assert updated_config["Description"] == "updated test lambda function" + assert updated_config["Handler"] == "lambda_function.new_lambda_handler" + assert updated_config["MemorySize"] == 128 + assert updated_config["Runtime"] == "python3.6" + assert updated_config["Timeout"] == 7 + assert updated_config["Environment"]["Variables"] == {"test_environment": "test_value"} @mock_lambda def test_update_function_zip(): - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") zip_content_one = get_test_zip_file1() fxn = conn.create_function( - FunctionName='testFunctionZip', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'ZipFile': zip_content_one, - }, - Description='test lambda function', + FunctionName="testFunctionZip", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"ZipFile": zip_content_one}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, @@ -1421,103 +1395,95 @@ def test_update_function_zip(): zip_content_two = get_test_zip_file2() fxn_updated = conn.update_function_code( - FunctionName='testFunctionZip', - ZipFile=zip_content_two, - Publish=True + FunctionName="testFunctionZip", ZipFile=zip_content_two, Publish=True ) - response = conn.get_function( - FunctionName='testFunctionZip', - Qualifier='2' - ) - response['Configuration'].pop('LastModified') + response = conn.get_function(FunctionName="testFunctionZip", Qualifier="2") + response["Configuration"].pop("LastModified") - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - assert len(response['Code']) == 2 - assert response['Code']['RepositoryType'] == 'S3' - assert response['Code']['Location'].startswith('s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com'.format(_lambda_region)) - response['Configuration'].should.equal( + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + assert len(response["Code"]) == 2 + assert response["Code"]["RepositoryType"] == "S3" + assert response["Code"]["Location"].startswith( + "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com".format(_lambda_region) + ) + response["Configuration"].should.equal( { "CodeSha256": hashlib.sha256(zip_content_two).hexdigest(), "CodeSize": len(zip_content_two), "Description": "test lambda function", - "FunctionArn": 'arn:aws:lambda:{}:123456789012:function:testFunctionZip:2'.format(_lambda_region), + "FunctionArn": "arn:aws:lambda:{}:123456789012:function:testFunctionZip:2".format( + _lambda_region + ), "FunctionName": "testFunctionZip", "Handler": "lambda_function.lambda_handler", "MemorySize": 128, "Role": "test-iam-role", "Runtime": "python2.7", "Timeout": 3, - "Version": '2', - "VpcConfig": { - "SecurityGroupIds": [], - "SubnetIds": [], - } - }, + "Version": "2", + "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + } ) + @mock_lambda @mock_s3 def test_update_function_s3(): - s3_conn = boto3.client('s3', 'us-west-2') - s3_conn.create_bucket(Bucket='test-bucket') + s3_conn = boto3.client("s3", "us-west-2") + s3_conn.create_bucket(Bucket="test-bucket") zip_content = get_test_zip_file1() - s3_conn.put_object(Bucket='test-bucket', Key='test.zip', Body=zip_content) + s3_conn.put_object(Bucket="test-bucket", Key="test.zip", Body=zip_content) - conn = boto3.client('lambda', 'us-west-2') + conn = boto3.client("lambda", "us-west-2") fxn = conn.create_function( - FunctionName='testFunctionS3', - Runtime='python2.7', - Role='test-iam-role', - Handler='lambda_function.lambda_handler', - Code={ - 'S3Bucket': 'test-bucket', - 'S3Key': 'test.zip', - }, - Description='test lambda function', + FunctionName="testFunctionS3", + Runtime="python2.7", + Role="test-iam-role", + Handler="lambda_function.lambda_handler", + Code={"S3Bucket": "test-bucket", "S3Key": "test.zip"}, + Description="test lambda function", Timeout=3, MemorySize=128, Publish=True, ) zip_content_two = get_test_zip_file2() - s3_conn.put_object(Bucket='test-bucket', Key='test2.zip', Body=zip_content_two) + s3_conn.put_object(Bucket="test-bucket", Key="test2.zip", Body=zip_content_two) fxn_updated = conn.update_function_code( - FunctionName='testFunctionS3', - S3Bucket='test-bucket', - S3Key='test2.zip', - Publish=True + FunctionName="testFunctionS3", + S3Bucket="test-bucket", + S3Key="test2.zip", + Publish=True, ) - response = conn.get_function( - FunctionName='testFunctionS3', - Qualifier='2' - ) - response['Configuration'].pop('LastModified') + response = conn.get_function(FunctionName="testFunctionS3", Qualifier="2") + response["Configuration"].pop("LastModified") - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - assert len(response['Code']) == 2 - assert response['Code']['RepositoryType'] == 'S3' - assert response['Code']['Location'].startswith('s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com'.format(_lambda_region)) - response['Configuration'].should.equal( + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + assert len(response["Code"]) == 2 + assert response["Code"]["RepositoryType"] == "S3" + assert response["Code"]["Location"].startswith( + "s3://awslambda-{0}-tasks.s3-{0}.amazonaws.com".format(_lambda_region) + ) + response["Configuration"].should.equal( { "CodeSha256": hashlib.sha256(zip_content_two).hexdigest(), "CodeSize": len(zip_content_two), "Description": "test lambda function", - "FunctionArn": 'arn:aws:lambda:{}:123456789012:function:testFunctionS3:2'.format(_lambda_region), + "FunctionArn": "arn:aws:lambda:{}:123456789012:function:testFunctionS3:2".format( + _lambda_region + ), "FunctionName": "testFunctionS3", "Handler": "lambda_function.lambda_handler", "MemorySize": 128, "Role": "test-iam-role", "Runtime": "python2.7", "Timeout": 3, - "Version": '2', - "VpcConfig": { - "SecurityGroupIds": [], - "SubnetIds": [], - } - }, + "Version": "2", + "VpcConfig": {"SecurityGroupIds": [], "SubnetIds": []}, + } ) diff --git a/tests/test_batch/test_batch.py b/tests/test_batch/test_batch.py index 5487cfb91..691d90b6d 100644 --- a/tests/test_batch/test_batch.py +++ b/tests/test_batch/test_batch.py @@ -17,17 +17,21 @@ def expected_failure(test): test(*args, **kwargs) except Exception as err: raise nose.SkipTest + return inner -DEFAULT_REGION = 'eu-central-1' + +DEFAULT_REGION = "eu-central-1" def _get_clients(): - return boto3.client('ec2', region_name=DEFAULT_REGION), \ - boto3.client('iam', region_name=DEFAULT_REGION), \ - boto3.client('ecs', region_name=DEFAULT_REGION), \ - boto3.client('logs', region_name=DEFAULT_REGION), \ - boto3.client('batch', region_name=DEFAULT_REGION) + return ( + boto3.client("ec2", region_name=DEFAULT_REGION), + boto3.client("iam", region_name=DEFAULT_REGION), + boto3.client("ecs", region_name=DEFAULT_REGION), + boto3.client("logs", region_name=DEFAULT_REGION), + boto3.client("batch", region_name=DEFAULT_REGION), + ) def _setup(ec2_client, iam_client): @@ -36,26 +40,21 @@ def _setup(ec2_client, iam_client): :return: VPC ID, Subnet ID, Security group ID, IAM Role ARN :rtype: tuple """ - resp = ec2_client.create_vpc(CidrBlock='172.30.0.0/24') - vpc_id = resp['Vpc']['VpcId'] + resp = ec2_client.create_vpc(CidrBlock="172.30.0.0/24") + vpc_id = resp["Vpc"]["VpcId"] resp = ec2_client.create_subnet( - AvailabilityZone='eu-central-1a', - CidrBlock='172.30.0.0/25', - VpcId=vpc_id + AvailabilityZone="eu-central-1a", CidrBlock="172.30.0.0/25", VpcId=vpc_id ) - subnet_id = resp['Subnet']['SubnetId'] + subnet_id = resp["Subnet"]["SubnetId"] resp = ec2_client.create_security_group( - Description='test_sg_desc', - GroupName='test_sg', - VpcId=vpc_id + Description="test_sg_desc", GroupName="test_sg", VpcId=vpc_id ) - sg_id = resp['GroupId'] + sg_id = resp["GroupId"] resp = iam_client.create_role( - RoleName='TestRole', - AssumeRolePolicyDocument='some_policy' + RoleName="TestRole", AssumeRolePolicyDocument="some_policy" ) - iam_arn = resp['Role']['Arn'] + iam_arn = resp["Role"]["Arn"] return vpc_id, subnet_id, sg_id, iam_arn @@ -69,49 +68,40 @@ def test_create_managed_compute_environment(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='MANAGED', - state='ENABLED', + type="MANAGED", + state="ENABLED", computeResources={ - 'type': 'EC2', - 'minvCpus': 5, - 'maxvCpus': 10, - 'desiredvCpus': 5, - 'instanceTypes': [ - 't2.small', - 't2.medium' - ], - 'imageId': 'some_image_id', - 'subnets': [ - subnet_id, - ], - 'securityGroupIds': [ - sg_id, - ], - 'ec2KeyPair': 'string', - 'instanceRole': iam_arn, - 'tags': { - 'string': 'string' - }, - 'bidPercentage': 123, - 'spotIamFleetRole': 'string' + "type": "EC2", + "minvCpus": 5, + "maxvCpus": 10, + "desiredvCpus": 5, + "instanceTypes": ["t2.small", "t2.medium"], + "imageId": "some_image_id", + "subnets": [subnet_id], + "securityGroupIds": [sg_id], + "ec2KeyPair": "string", + "instanceRole": iam_arn, + "tags": {"string": "string"}, + "bidPercentage": 123, + "spotIamFleetRole": "string", }, - serviceRole=iam_arn + serviceRole=iam_arn, ) - resp.should.contain('computeEnvironmentArn') - resp['computeEnvironmentName'].should.equal(compute_name) + resp.should.contain("computeEnvironmentArn") + resp["computeEnvironmentName"].should.equal(compute_name) # Given a t2.medium is 2 vcpu and t2.small is 1, therefore 2 mediums and 1 small should be created resp = ec2_client.describe_instances() - resp.should.contain('Reservations') - len(resp['Reservations']).should.equal(3) + resp.should.contain("Reservations") + len(resp["Reservations"]).should.equal(3) # Should have created 1 ECS cluster resp = ecs_client.list_clusters() - resp.should.contain('clusterArns') - len(resp['clusterArns']).should.equal(1) + resp.should.contain("clusterArns") + len(resp["clusterArns"]).should.equal(1) @mock_ec2 @@ -122,25 +112,26 @@ def test_create_unmanaged_compute_environment(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - resp.should.contain('computeEnvironmentArn') - resp['computeEnvironmentName'].should.equal(compute_name) + resp.should.contain("computeEnvironmentArn") + resp["computeEnvironmentName"].should.equal(compute_name) # Its unmanaged so no instances should be created resp = ec2_client.describe_instances() - resp.should.contain('Reservations') - len(resp['Reservations']).should.equal(0) + resp.should.contain("Reservations") + len(resp["Reservations"]).should.equal(0) # Should have created 1 ECS cluster resp = ecs_client.list_clusters() - resp.should.contain('clusterArns') - len(resp['clusterArns']).should.equal(1) + resp.should.contain("clusterArns") + len(resp["clusterArns"]).should.equal(1) + # TODO create 1000s of tests to test complex option combinations of create environment @@ -153,23 +144,21 @@ def test_describe_compute_environment(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) resp = batch_client.describe_compute_environments() - len(resp['computeEnvironments']).should.equal(1) - resp['computeEnvironments'][0]['computeEnvironmentName'].should.equal(compute_name) + len(resp["computeEnvironments"]).should.equal(1) + resp["computeEnvironments"][0]["computeEnvironmentName"].should.equal(compute_name) # Test filtering - resp = batch_client.describe_compute_environments( - computeEnvironments=['test1'] - ) - len(resp['computeEnvironments']).should.equal(0) + resp = batch_client.describe_compute_environments(computeEnvironments=["test1"]) + len(resp["computeEnvironments"]).should.equal(0) @mock_ec2 @@ -180,23 +169,21 @@ def test_delete_unmanaged_compute_environment(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - batch_client.delete_compute_environment( - computeEnvironment=compute_name, - ) + batch_client.delete_compute_environment(computeEnvironment=compute_name) resp = batch_client.describe_compute_environments() - len(resp['computeEnvironments']).should.equal(0) + len(resp["computeEnvironments"]).should.equal(0) resp = ecs_client.list_clusters() - len(resp.get('clusterArns', [])).should.equal(0) + len(resp.get("clusterArns", [])).should.equal(0) @mock_ec2 @@ -207,53 +194,42 @@ def test_delete_managed_compute_environment(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='MANAGED', - state='ENABLED', + type="MANAGED", + state="ENABLED", computeResources={ - 'type': 'EC2', - 'minvCpus': 5, - 'maxvCpus': 10, - 'desiredvCpus': 5, - 'instanceTypes': [ - 't2.small', - 't2.medium' - ], - 'imageId': 'some_image_id', - 'subnets': [ - subnet_id, - ], - 'securityGroupIds': [ - sg_id, - ], - 'ec2KeyPair': 'string', - 'instanceRole': iam_arn, - 'tags': { - 'string': 'string' - }, - 'bidPercentage': 123, - 'spotIamFleetRole': 'string' + "type": "EC2", + "minvCpus": 5, + "maxvCpus": 10, + "desiredvCpus": 5, + "instanceTypes": ["t2.small", "t2.medium"], + "imageId": "some_image_id", + "subnets": [subnet_id], + "securityGroupIds": [sg_id], + "ec2KeyPair": "string", + "instanceRole": iam_arn, + "tags": {"string": "string"}, + "bidPercentage": 123, + "spotIamFleetRole": "string", }, - serviceRole=iam_arn + serviceRole=iam_arn, ) - batch_client.delete_compute_environment( - computeEnvironment=compute_name, - ) + batch_client.delete_compute_environment(computeEnvironment=compute_name) resp = batch_client.describe_compute_environments() - len(resp['computeEnvironments']).should.equal(0) + len(resp["computeEnvironments"]).should.equal(0) resp = ec2_client.describe_instances() - resp.should.contain('Reservations') - len(resp['Reservations']).should.equal(3) - for reservation in resp['Reservations']: - reservation['Instances'][0]['State']['Name'].should.equal('terminated') + resp.should.contain("Reservations") + len(resp["Reservations"]).should.equal(3) + for reservation in resp["Reservations"]: + reservation["Instances"][0]["State"]["Name"].should.equal("terminated") resp = ecs_client.list_clusters() - len(resp.get('clusterArns', [])).should.equal(0) + len(resp.get("clusterArns", [])).should.equal(0) @mock_ec2 @@ -264,22 +240,21 @@ def test_update_unmanaged_compute_environment_state(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) batch_client.update_compute_environment( - computeEnvironment=compute_name, - state='DISABLED' + computeEnvironment=compute_name, state="DISABLED" ) resp = batch_client.describe_compute_environments() - len(resp['computeEnvironments']).should.equal(1) - resp['computeEnvironments'][0]['state'].should.equal('DISABLED') + len(resp["computeEnvironments"]).should.equal(1) + resp["computeEnvironments"][0]["state"].should.equal("DISABLED") @mock_ec2 @@ -290,87 +265,70 @@ def test_create_job_queue(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - resp.should.contain('jobQueueArn') - resp.should.contain('jobQueueName') - queue_arn = resp['jobQueueArn'] + resp.should.contain("jobQueueArn") + resp.should.contain("jobQueueName") + queue_arn = resp["jobQueueArn"] resp = batch_client.describe_job_queues() - resp.should.contain('jobQueues') - len(resp['jobQueues']).should.equal(1) - resp['jobQueues'][0]['jobQueueArn'].should.equal(queue_arn) + resp.should.contain("jobQueues") + len(resp["jobQueues"]).should.equal(1) + resp["jobQueues"][0]["jobQueueArn"].should.equal(queue_arn) - resp = batch_client.describe_job_queues(jobQueues=['test_invalid_queue']) - resp.should.contain('jobQueues') - len(resp['jobQueues']).should.equal(0) + resp = batch_client.describe_job_queues(jobQueues=["test_invalid_queue"]) + resp.should.contain("jobQueues") + len(resp["jobQueues"]).should.equal(0) # Create job queue which already exists try: resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ClientException') - + err.response["Error"]["Code"].should.equal("ClientException") # Create job queue with incorrect state try: resp = batch_client.create_job_queue( - jobQueueName='test_job_queue2', - state='JUNK', + jobQueueName="test_job_queue2", + state="JUNK", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ClientException') + err.response["Error"]["Code"].should.equal("ClientException") # Create job queue with no compute env try: resp = batch_client.create_job_queue( - jobQueueName='test_job_queue3', - state='JUNK', + jobQueueName="test_job_queue3", + state="JUNK", priority=123, - computeEnvironmentOrder=[ - - ] + computeEnvironmentOrder=[], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ClientException') + err.response["Error"]["Code"].should.equal("ClientException") + @mock_ec2 @mock_ecs @@ -380,29 +338,26 @@ def test_job_queue_bad_arn(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] try: batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn + 'LALALA' - }, - ] + {"order": 123, "computeEnvironment": arn + "LALALA"} + ], ) except ClientError as err: - err.response['Error']['Code'].should.equal('ClientException') + err.response["Error"]["Code"].should.equal("ClientException") @mock_ec2 @@ -413,48 +368,36 @@ def test_update_job_queue(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] - batch_client.update_job_queue( - jobQueue=queue_arn, - priority=5 - ) + batch_client.update_job_queue(jobQueue=queue_arn, priority=5) resp = batch_client.describe_job_queues() - resp.should.contain('jobQueues') - len(resp['jobQueues']).should.equal(1) - resp['jobQueues'][0]['priority'].should.equal(5) + resp.should.contain("jobQueues") + len(resp["jobQueues"]).should.equal(1) + resp["jobQueues"][0]["priority"].should.equal(5) - batch_client.update_job_queue( - jobQueue='test_job_queue', - priority=5 - ) + batch_client.update_job_queue(jobQueue="test_job_queue", priority=5) resp = batch_client.describe_job_queues() - resp.should.contain('jobQueues') - len(resp['jobQueues']).should.equal(1) - resp['jobQueues'][0]['priority'].should.equal(5) - + resp.should.contain("jobQueues") + len(resp["jobQueues"]).should.equal(1) + resp["jobQueues"][0]["priority"].should.equal(5) @mock_ec2 @@ -465,35 +408,28 @@ def test_update_job_queue(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] - batch_client.delete_job_queue( - jobQueue=queue_arn - ) + batch_client.delete_job_queue(jobQueue=queue_arn) resp = batch_client.describe_job_queues() - resp.should.contain('jobQueues') - len(resp['jobQueues']).should.equal(0) + resp.should.contain("jobQueues") + len(resp["jobQueues"]).should.equal(0) @mock_ec2 @@ -505,21 +441,23 @@ def test_register_task_definition(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) resp = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - resp.should.contain('jobDefinitionArn') - resp.should.contain('jobDefinitionName') - resp.should.contain('revision') + resp.should.contain("jobDefinitionArn") + resp.should.contain("jobDefinitionName") + resp.should.contain("revision") - assert resp['jobDefinitionArn'].endswith('{0}:{1}'.format(resp['jobDefinitionName'], resp['revision'])) + assert resp["jobDefinitionArn"].endswith( + "{0}:{1}".format(resp["jobDefinitionName"], resp["revision"]) + ) @mock_ec2 @@ -532,68 +470,69 @@ def test_reregister_task_definition(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) resp1 = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - resp1.should.contain('jobDefinitionArn') - resp1.should.contain('jobDefinitionName') - resp1.should.contain('revision') + resp1.should.contain("jobDefinitionArn") + resp1.should.contain("jobDefinitionName") + resp1.should.contain("revision") - assert resp1['jobDefinitionArn'].endswith('{0}:{1}'.format(resp1['jobDefinitionName'], resp1['revision'])) - resp1['revision'].should.equal(1) + assert resp1["jobDefinitionArn"].endswith( + "{0}:{1}".format(resp1["jobDefinitionName"], resp1["revision"]) + ) + resp1["revision"].should.equal(1) resp2 = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 68, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 68, + "command": ["sleep", "10"], + }, ) - resp2['revision'].should.equal(2) + resp2["revision"].should.equal(2) - resp2['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn']) + resp2["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"]) resp3 = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 42, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 42, + "command": ["sleep", "10"], + }, ) - resp3['revision'].should.equal(3) + resp3["revision"].should.equal(3) - resp3['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn']) - resp3['jobDefinitionArn'].should_not.equal(resp2['jobDefinitionArn']) + resp3["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"]) + resp3["jobDefinitionArn"].should_not.equal(resp2["jobDefinitionArn"]) resp4 = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 41, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 41, + "command": ["sleep", "10"], + }, ) - resp4['revision'].should.equal(4) + resp4["revision"].should.equal(4) - resp4['jobDefinitionArn'].should_not.equal(resp1['jobDefinitionArn']) - resp4['jobDefinitionArn'].should_not.equal(resp2['jobDefinitionArn']) - resp4['jobDefinitionArn'].should_not.equal(resp3['jobDefinitionArn']) - + resp4["jobDefinitionArn"].should_not.equal(resp1["jobDefinitionArn"]) + resp4["jobDefinitionArn"].should_not.equal(resp2["jobDefinitionArn"]) + resp4["jobDefinitionArn"].should_not.equal(resp3["jobDefinitionArn"]) @mock_ec2 @@ -605,20 +544,20 @@ def test_delete_task_definition(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) resp = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - batch_client.deregister_job_definition(jobDefinition=resp['jobDefinitionArn']) + batch_client.deregister_job_definition(jobDefinition=resp["jobDefinitionArn"]) resp = batch_client.describe_job_definitions() - len(resp['jobDefinitions']).should.equal(0) + len(resp["jobDefinitions"]).should.equal(0) @mock_ec2 @@ -630,48 +569,44 @@ def test_describe_task_definition(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 64, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 64, + "command": ["sleep", "10"], + }, ) batch_client.register_job_definition( - jobDefinitionName='test1', - type='container', + jobDefinitionName="test1", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 64, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 64, + "command": ["sleep", "10"], + }, ) - resp = batch_client.describe_job_definitions( - jobDefinitionName='sleep10' - ) - len(resp['jobDefinitions']).should.equal(2) + resp = batch_client.describe_job_definitions(jobDefinitionName="sleep10") + len(resp["jobDefinitions"]).should.equal(2) resp = batch_client.describe_job_definitions() - len(resp['jobDefinitions']).should.equal(3) + len(resp["jobDefinitions"]).should.equal(3) - resp = batch_client.describe_job_definitions( - jobDefinitions=['sleep10', 'test1'] - ) - len(resp['jobDefinitions']).should.equal(3) + resp = batch_client.describe_job_definitions(jobDefinitions=["sleep10", "test1"]) + len(resp["jobDefinitions"]).should.equal(3) @mock_logs @@ -683,77 +618,71 @@ def test_submit_job_by_name(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] - job_definition_name = 'sleep10' + job_definition_name = "sleep10" batch_client.register_job_definition( jobDefinitionName=job_definition_name, - type='container', + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) batch_client.register_job_definition( jobDefinitionName=job_definition_name, - type='container', + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 256, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 256, + "command": ["sleep", "10"], + }, ) resp = batch_client.register_job_definition( jobDefinitionName=job_definition_name, - type='container', + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 512, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 512, + "command": ["sleep", "10"], + }, ) - job_definition_arn = resp['jobDefinitionArn'] + job_definition_arn = resp["jobDefinitionArn"] resp = batch_client.submit_job( - jobName='test1', - jobQueue=queue_arn, - jobDefinition=job_definition_name + jobName="test1", jobQueue=queue_arn, jobDefinition=job_definition_name ) - job_id = resp['jobId'] + job_id = resp["jobId"] resp_jobs = batch_client.describe_jobs(jobs=[job_id]) # batch_client.terminate_job(jobId=job_id) - len(resp_jobs['jobs']).should.equal(1) - resp_jobs['jobs'][0]['jobId'].should.equal(job_id) - resp_jobs['jobs'][0]['jobQueue'].should.equal(queue_arn) - resp_jobs['jobs'][0]['jobDefinition'].should.equal(job_definition_arn) + len(resp_jobs["jobs"]).should.equal(1) + resp_jobs["jobs"][0]["jobId"].should.equal(job_id) + resp_jobs["jobs"][0]["jobQueue"].should.equal(queue_arn) + resp_jobs["jobs"][0]["jobDefinition"].should.equal(job_definition_arn) + # SLOW TESTS @expected_failure @@ -766,67 +695,68 @@ def test_submit_job(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] resp = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - job_def_arn = resp['jobDefinitionArn'] + job_def_arn = resp["jobDefinitionArn"] resp = batch_client.submit_job( - jobName='test1', - jobQueue=queue_arn, - jobDefinition=job_def_arn + jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn ) - job_id = resp['jobId'] + job_id = resp["jobId"] future = datetime.datetime.now() + datetime.timedelta(seconds=30) while datetime.datetime.now() < future: resp = batch_client.describe_jobs(jobs=[job_id]) - print("{0}:{1} {2}".format(resp['jobs'][0]['jobName'], resp['jobs'][0]['jobId'], resp['jobs'][0]['status'])) + print( + "{0}:{1} {2}".format( + resp["jobs"][0]["jobName"], + resp["jobs"][0]["jobId"], + resp["jobs"][0]["status"], + ) + ) - if resp['jobs'][0]['status'] == 'FAILED': - raise RuntimeError('Batch job failed') - if resp['jobs'][0]['status'] == 'SUCCEEDED': + if resp["jobs"][0]["status"] == "FAILED": + raise RuntimeError("Batch job failed") + if resp["jobs"][0]["status"] == "SUCCEEDED": break time.sleep(0.5) else: - raise RuntimeError('Batch job timed out') + raise RuntimeError("Batch job timed out") - resp = logs_client.describe_log_streams(logGroupName='/aws/batch/job') - len(resp['logStreams']).should.equal(1) - ls_name = resp['logStreams'][0]['logStreamName'] + resp = logs_client.describe_log_streams(logGroupName="/aws/batch/job") + len(resp["logStreams"]).should.equal(1) + ls_name = resp["logStreams"][0]["logStreamName"] - resp = logs_client.get_log_events(logGroupName='/aws/batch/job', logStreamName=ls_name) - len(resp['events']).should.be.greater_than(5) + resp = logs_client.get_log_events( + logGroupName="/aws/batch/job", logStreamName=ls_name + ) + len(resp["events"]).should.be.greater_than(5) @expected_failure @@ -839,82 +769,71 @@ def test_list_jobs(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] resp = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - job_def_arn = resp['jobDefinitionArn'] + job_def_arn = resp["jobDefinitionArn"] resp = batch_client.submit_job( - jobName='test1', - jobQueue=queue_arn, - jobDefinition=job_def_arn + jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn ) - job_id1 = resp['jobId'] + job_id1 = resp["jobId"] resp = batch_client.submit_job( - jobName='test2', - jobQueue=queue_arn, - jobDefinition=job_def_arn + jobName="test2", jobQueue=queue_arn, jobDefinition=job_def_arn ) - job_id2 = resp['jobId'] + job_id2 = resp["jobId"] future = datetime.datetime.now() + datetime.timedelta(seconds=30) resp_finished_jobs = batch_client.list_jobs( - jobQueue=queue_arn, - jobStatus='SUCCEEDED' + jobQueue=queue_arn, jobStatus="SUCCEEDED" ) # Wait only as long as it takes to run the jobs while datetime.datetime.now() < future: resp = batch_client.describe_jobs(jobs=[job_id1, job_id2]) - any_failed_jobs = any([job['status'] == 'FAILED' for job in resp['jobs']]) - succeeded_jobs = all([job['status'] == 'SUCCEEDED' for job in resp['jobs']]) + any_failed_jobs = any([job["status"] == "FAILED" for job in resp["jobs"]]) + succeeded_jobs = all([job["status"] == "SUCCEEDED" for job in resp["jobs"]]) if any_failed_jobs: - raise RuntimeError('A Batch job failed') + raise RuntimeError("A Batch job failed") if succeeded_jobs: break time.sleep(0.5) else: - raise RuntimeError('Batch jobs timed out') + raise RuntimeError("Batch jobs timed out") resp_finished_jobs2 = batch_client.list_jobs( - jobQueue=queue_arn, - jobStatus='SUCCEEDED' + jobQueue=queue_arn, jobStatus="SUCCEEDED" ) - len(resp_finished_jobs['jobSummaryList']).should.equal(0) - len(resp_finished_jobs2['jobSummaryList']).should.equal(2) + len(resp_finished_jobs["jobSummaryList"]).should.equal(0) + len(resp_finished_jobs2["jobSummaryList"]).should.equal(2) @expected_failure @@ -927,55 +846,47 @@ def test_terminate_job(): ec2_client, iam_client, ecs_client, logs_client, batch_client = _get_clients() vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) - compute_name = 'test_compute_env' + compute_name = "test_compute_env" resp = batch_client.create_compute_environment( computeEnvironmentName=compute_name, - type='UNMANAGED', - state='ENABLED', - serviceRole=iam_arn + type="UNMANAGED", + state="ENABLED", + serviceRole=iam_arn, ) - arn = resp['computeEnvironmentArn'] + arn = resp["computeEnvironmentArn"] resp = batch_client.create_job_queue( - jobQueueName='test_job_queue', - state='ENABLED', + jobQueueName="test_job_queue", + state="ENABLED", priority=123, - computeEnvironmentOrder=[ - { - 'order': 123, - 'computeEnvironment': arn - }, - ] + computeEnvironmentOrder=[{"order": 123, "computeEnvironment": arn}], ) - queue_arn = resp['jobQueueArn'] + queue_arn = resp["jobQueueArn"] resp = batch_client.register_job_definition( - jobDefinitionName='sleep10', - type='container', + jobDefinitionName="sleep10", + type="container", containerProperties={ - 'image': 'busybox', - 'vcpus': 1, - 'memory': 128, - 'command': ['sleep', '10'] - } + "image": "busybox", + "vcpus": 1, + "memory": 128, + "command": ["sleep", "10"], + }, ) - job_def_arn = resp['jobDefinitionArn'] + job_def_arn = resp["jobDefinitionArn"] resp = batch_client.submit_job( - jobName='test1', - jobQueue=queue_arn, - jobDefinition=job_def_arn + jobName="test1", jobQueue=queue_arn, jobDefinition=job_def_arn ) - job_id = resp['jobId'] + job_id = resp["jobId"] time.sleep(2) - batch_client.terminate_job(jobId=job_id, reason='test_terminate') + batch_client.terminate_job(jobId=job_id, reason="test_terminate") time.sleep(1) resp = batch_client.describe_jobs(jobs=[job_id]) - resp['jobs'][0]['jobName'].should.equal('test1') - resp['jobs'][0]['status'].should.equal('FAILED') - resp['jobs'][0]['statusReason'].should.equal('test_terminate') - + resp["jobs"][0]["jobName"].should.equal("test1") + resp["jobs"][0]["status"].should.equal("FAILED") + resp["jobs"][0]["statusReason"].should.equal("test_terminate") diff --git a/tests/test_batch/test_cloudformation.py b/tests/test_batch/test_cloudformation.py index 1e37aa3a6..a8b94f3a3 100644 --- a/tests/test_batch/test_cloudformation.py +++ b/tests/test_batch/test_cloudformation.py @@ -5,20 +5,29 @@ import datetime import boto3 from botocore.exceptions import ClientError import sure # noqa -from moto import mock_batch, mock_iam, mock_ec2, mock_ecs, mock_logs, mock_cloudformation +from moto import ( + mock_batch, + mock_iam, + mock_ec2, + mock_ecs, + mock_logs, + mock_cloudformation, +) import functools import nose import json -DEFAULT_REGION = 'eu-central-1' +DEFAULT_REGION = "eu-central-1" def _get_clients(): - return boto3.client('ec2', region_name=DEFAULT_REGION), \ - boto3.client('iam', region_name=DEFAULT_REGION), \ - boto3.client('ecs', region_name=DEFAULT_REGION), \ - boto3.client('logs', region_name=DEFAULT_REGION), \ - boto3.client('batch', region_name=DEFAULT_REGION) + return ( + boto3.client("ec2", region_name=DEFAULT_REGION), + boto3.client("iam", region_name=DEFAULT_REGION), + boto3.client("ecs", region_name=DEFAULT_REGION), + boto3.client("logs", region_name=DEFAULT_REGION), + boto3.client("batch", region_name=DEFAULT_REGION), + ) def _setup(ec2_client, iam_client): @@ -27,26 +36,21 @@ def _setup(ec2_client, iam_client): :return: VPC ID, Subnet ID, Security group ID, IAM Role ARN :rtype: tuple """ - resp = ec2_client.create_vpc(CidrBlock='172.30.0.0/24') - vpc_id = resp['Vpc']['VpcId'] + resp = ec2_client.create_vpc(CidrBlock="172.30.0.0/24") + vpc_id = resp["Vpc"]["VpcId"] resp = ec2_client.create_subnet( - AvailabilityZone='eu-central-1a', - CidrBlock='172.30.0.0/25', - VpcId=vpc_id + AvailabilityZone="eu-central-1a", CidrBlock="172.30.0.0/25", VpcId=vpc_id ) - subnet_id = resp['Subnet']['SubnetId'] + subnet_id = resp["Subnet"]["SubnetId"] resp = ec2_client.create_security_group( - Description='test_sg_desc', - GroupName='test_sg', - VpcId=vpc_id + Description="test_sg_desc", GroupName="test_sg", VpcId=vpc_id ) - sg_id = resp['GroupId'] + sg_id = resp["GroupId"] resp = iam_client.create_role( - RoleName='TestRole', - AssumeRolePolicyDocument='some_policy' + RoleName="TestRole", AssumeRolePolicyDocument="some_policy" ) - iam_arn = resp['Role']['Arn'] + iam_arn = resp["Role"]["Arn"] return vpc_id, subnet_id, sg_id, iam_arn @@ -61,7 +65,7 @@ def test_create_env_cf(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) create_environment_template = { - 'Resources': { + "Resources": { "ComputeEnvironment": { "Type": "AWS::Batch::ComputeEnvironment", "Properties": { @@ -71,32 +75,35 @@ def test_create_env_cf(): "MinvCpus": 0, "DesiredvCpus": 0, "MaxvCpus": 64, - "InstanceTypes": [ - "optimal" - ], + "InstanceTypes": ["optimal"], "Subnets": [subnet_id], "SecurityGroupIds": [sg_id], - "InstanceRole": iam_arn + "InstanceRole": iam_arn, }, - "ServiceRole": iam_arn - } + "ServiceRole": iam_arn, + }, } } } cf_json = json.dumps(create_environment_template) - cf_conn = boto3.client('cloudformation', DEFAULT_REGION) - stack_id = cf_conn.create_stack( - StackName='test_stack', - TemplateBody=cf_json, - )['StackId'] + cf_conn = boto3.client("cloudformation", DEFAULT_REGION) + stack_id = cf_conn.create_stack(StackName="test_stack", TemplateBody=cf_json)[ + "StackId" + ] stack_resources = cf_conn.list_stack_resources(StackName=stack_id) - stack_resources['StackResourceSummaries'][0]['ResourceStatus'].should.equal('CREATE_COMPLETE') + stack_resources["StackResourceSummaries"][0]["ResourceStatus"].should.equal( + "CREATE_COMPLETE" + ) # Spot checks on the ARN - stack_resources['StackResourceSummaries'][0]['PhysicalResourceId'].startswith('arn:aws:batch:') - stack_resources['StackResourceSummaries'][0]['PhysicalResourceId'].should.contain('test_stack') + stack_resources["StackResourceSummaries"][0]["PhysicalResourceId"].startswith( + "arn:aws:batch:" + ) + stack_resources["StackResourceSummaries"][0]["PhysicalResourceId"].should.contain( + "test_stack" + ) @mock_cloudformation() @@ -109,7 +116,7 @@ def test_create_job_queue_cf(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) create_environment_template = { - 'Resources': { + "Resources": { "ComputeEnvironment": { "Type": "AWS::Batch::ComputeEnvironment", "Properties": { @@ -119,17 +126,14 @@ def test_create_job_queue_cf(): "MinvCpus": 0, "DesiredvCpus": 0, "MaxvCpus": 64, - "InstanceTypes": [ - "optimal" - ], + "InstanceTypes": ["optimal"], "Subnets": [subnet_id], "SecurityGroupIds": [sg_id], - "InstanceRole": iam_arn + "InstanceRole": iam_arn, }, - "ServiceRole": iam_arn - } + "ServiceRole": iam_arn, + }, }, - "JobQueue": { "Type": "AWS::Batch::JobQueue", "Properties": { @@ -137,31 +141,35 @@ def test_create_job_queue_cf(): "ComputeEnvironmentOrder": [ { "Order": 1, - "ComputeEnvironment": {"Ref": "ComputeEnvironment"} + "ComputeEnvironment": {"Ref": "ComputeEnvironment"}, } - ] - } + ], + }, }, } } cf_json = json.dumps(create_environment_template) - cf_conn = boto3.client('cloudformation', DEFAULT_REGION) - stack_id = cf_conn.create_stack( - StackName='test_stack', - TemplateBody=cf_json, - )['StackId'] + cf_conn = boto3.client("cloudformation", DEFAULT_REGION) + stack_id = cf_conn.create_stack(StackName="test_stack", TemplateBody=cf_json)[ + "StackId" + ] stack_resources = cf_conn.list_stack_resources(StackName=stack_id) - len(stack_resources['StackResourceSummaries']).should.equal(2) + len(stack_resources["StackResourceSummaries"]).should.equal(2) - job_queue_resource = list(filter(lambda item: item['ResourceType'] == 'AWS::Batch::JobQueue', stack_resources['StackResourceSummaries']))[0] + job_queue_resource = list( + filter( + lambda item: item["ResourceType"] == "AWS::Batch::JobQueue", + stack_resources["StackResourceSummaries"], + ) + )[0] - job_queue_resource['ResourceStatus'].should.equal('CREATE_COMPLETE') + job_queue_resource["ResourceStatus"].should.equal("CREATE_COMPLETE") # Spot checks on the ARN - job_queue_resource['PhysicalResourceId'].startswith('arn:aws:batch:') - job_queue_resource['PhysicalResourceId'].should.contain('test_stack') - job_queue_resource['PhysicalResourceId'].should.contain('job-queue/') + job_queue_resource["PhysicalResourceId"].startswith("arn:aws:batch:") + job_queue_resource["PhysicalResourceId"].should.contain("test_stack") + job_queue_resource["PhysicalResourceId"].should.contain("job-queue/") @mock_cloudformation() @@ -174,7 +182,7 @@ def test_create_job_def_cf(): vpc_id, subnet_id, sg_id, iam_arn = _setup(ec2_client, iam_client) create_environment_template = { - 'Resources': { + "Resources": { "ComputeEnvironment": { "Type": "AWS::Batch::ComputeEnvironment", "Properties": { @@ -184,17 +192,14 @@ def test_create_job_def_cf(): "MinvCpus": 0, "DesiredvCpus": 0, "MaxvCpus": 64, - "InstanceTypes": [ - "optimal" - ], + "InstanceTypes": ["optimal"], "Subnets": [subnet_id], "SecurityGroupIds": [sg_id], - "InstanceRole": iam_arn + "InstanceRole": iam_arn, }, - "ServiceRole": iam_arn - } + "ServiceRole": iam_arn, + }, }, - "JobQueue": { "Type": "AWS::Batch::JobQueue", "Properties": { @@ -202,46 +207,54 @@ def test_create_job_def_cf(): "ComputeEnvironmentOrder": [ { "Order": 1, - "ComputeEnvironment": {"Ref": "ComputeEnvironment"} + "ComputeEnvironment": {"Ref": "ComputeEnvironment"}, } - ] - } + ], + }, }, - "JobDefinition": { "Type": "AWS::Batch::JobDefinition", "Properties": { "Type": "container", "ContainerProperties": { "Image": { - "Fn::Join": ["", ["137112412989.dkr.ecr.", {"Ref": "AWS::Region"}, ".amazonaws.com/amazonlinux:latest"]] + "Fn::Join": [ + "", + [ + "137112412989.dkr.ecr.", + {"Ref": "AWS::Region"}, + ".amazonaws.com/amazonlinux:latest", + ], + ] }, "Vcpus": 2, "Memory": 2000, - "Command": ["echo", "Hello world"] + "Command": ["echo", "Hello world"], }, - "RetryStrategy": { - "Attempts": 1 - } - } + "RetryStrategy": {"Attempts": 1}, + }, }, } } cf_json = json.dumps(create_environment_template) - cf_conn = boto3.client('cloudformation', DEFAULT_REGION) - stack_id = cf_conn.create_stack( - StackName='test_stack', - TemplateBody=cf_json, - )['StackId'] + cf_conn = boto3.client("cloudformation", DEFAULT_REGION) + stack_id = cf_conn.create_stack(StackName="test_stack", TemplateBody=cf_json)[ + "StackId" + ] stack_resources = cf_conn.list_stack_resources(StackName=stack_id) - len(stack_resources['StackResourceSummaries']).should.equal(3) + len(stack_resources["StackResourceSummaries"]).should.equal(3) - job_def_resource = list(filter(lambda item: item['ResourceType'] == 'AWS::Batch::JobDefinition', stack_resources['StackResourceSummaries']))[0] + job_def_resource = list( + filter( + lambda item: item["ResourceType"] == "AWS::Batch::JobDefinition", + stack_resources["StackResourceSummaries"], + ) + )[0] - job_def_resource['ResourceStatus'].should.equal('CREATE_COMPLETE') + job_def_resource["ResourceStatus"].should.equal("CREATE_COMPLETE") # Spot checks on the ARN - job_def_resource['PhysicalResourceId'].startswith('arn:aws:batch:') - job_def_resource['PhysicalResourceId'].should.contain('test_stack-JobDef') - job_def_resource['PhysicalResourceId'].should.contain('job-definition/') + job_def_resource["PhysicalResourceId"].startswith("arn:aws:batch:") + job_def_resource["PhysicalResourceId"].should.contain("test_stack-JobDef") + job_def_resource["PhysicalResourceId"].should.contain("job-definition/") diff --git a/tests/test_batch/test_server.py b/tests/test_batch/test_server.py index 4a74260a8..91b5f0c47 100644 --- a/tests/test_batch/test_server.py +++ b/tests/test_batch/test_server.py @@ -5,9 +5,9 @@ import sure # noqa import moto.server as server from moto import mock_batch -''' +""" Test the different server responses -''' +""" @mock_batch @@ -15,5 +15,5 @@ def test_batch_list(): backend = server.create_backend_app("batch") test_client = backend.test_client() - res = test_client.get('/v1/describecomputeenvironments') + res = test_client.get("/v1/describecomputeenvironments") res.status_code.should.equal(200) diff --git a/tests/test_cloudformation/fixtures/ec2_classic_eip.py b/tests/test_cloudformation/fixtures/ec2_classic_eip.py index 626e90ada..fd7758300 100644 --- a/tests/test_cloudformation/fixtures/ec2_classic_eip.py +++ b/tests/test_cloudformation/fixtures/ec2_classic_eip.py @@ -1,9 +1,3 @@ from __future__ import unicode_literals -template = { - "Resources": { - "EC2EIP": { - "Type": "AWS::EC2::EIP" - } - } -} +template = {"Resources": {"EC2EIP": {"Type": "AWS::EC2::EIP"}}} diff --git a/tests/test_cloudformation/fixtures/fn_join.py b/tests/test_cloudformation/fixtures/fn_join.py index 79b62d01e..ac73e3cd2 100644 --- a/tests/test_cloudformation/fixtures/fn_join.py +++ b/tests/test_cloudformation/fixtures/fn_join.py @@ -1,23 +1,11 @@ from __future__ import unicode_literals template = { - "Resources": { - "EC2EIP": { - "Type": "AWS::EC2::EIP" - } - }, + "Resources": {"EC2EIP": {"Type": "AWS::EC2::EIP"}}, "Outputs": { "EIP": { "Description": "EIP for joining", - "Value": { - "Fn::Join": [ - ":", - [ - "test eip", - {"Ref": "EC2EIP"} - ] - ] - } + "Value": {"Fn::Join": [":", ["test eip", {"Ref": "EC2EIP"}]]}, } - } + }, } diff --git a/tests/test_cloudformation/fixtures/kms_key.py b/tests/test_cloudformation/fixtures/kms_key.py index 366dbfcf5..af6a535d1 100644 --- a/tests/test_cloudformation/fixtures/kms_key.py +++ b/tests/test_cloudformation/fixtures/kms_key.py @@ -2,38 +2,45 @@ from __future__ import unicode_literals template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "AWS CloudFormation Sample Template to create a KMS Key. The Fn::GetAtt is used to retrieve the ARN", - - "Resources" : { - "myKey" : { - "Type" : "AWS::KMS::Key", - "Properties" : { + "Resources": { + "myKey": { + "Type": "AWS::KMS::Key", + "Properties": { "Description": "Sample KmsKey", "EnableKeyRotation": False, "Enabled": True, - "KeyPolicy" : { + "KeyPolicy": { "Version": "2012-10-17", "Id": "key-default-1", "Statement": [ { - "Sid": "Enable IAM User Permissions", - "Effect": "Allow", - "Principal": { - "AWS": { "Fn::Join" : ["" , ["arn:aws:iam::", {"Ref" : "AWS::AccountId"} ,":root" ]] } - }, - "Action": "kms:*", - "Resource": "*" + "Sid": "Enable IAM User Permissions", + "Effect": "Allow", + "Principal": { + "AWS": { + "Fn::Join": [ + "", + [ + "arn:aws:iam::", + {"Ref": "AWS::AccountId"}, + ":root", + ], + ] + } + }, + "Action": "kms:*", + "Resource": "*", } - ] - } - } + ], + }, + }, } }, - "Outputs" : { - "KeyArn" : { + "Outputs": { + "KeyArn": { "Description": "Generated Key Arn", - "Value" : { "Fn::GetAtt" : [ "myKey", "Arn" ] } + "Value": {"Fn::GetAtt": ["myKey", "Arn"]}, } - } -} \ No newline at end of file + }, +} diff --git a/tests/test_cloudformation/fixtures/rds_mysql_with_db_parameter_group.py b/tests/test_cloudformation/fixtures/rds_mysql_with_db_parameter_group.py index 6f379daa6..d58516d3d 100644 --- a/tests/test_cloudformation/fixtures/rds_mysql_with_db_parameter_group.py +++ b/tests/test_cloudformation/fixtures/rds_mysql_with_db_parameter_group.py @@ -2,9 +2,7 @@ from __future__ import unicode_literals template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "AWS CloudFormation Sample Template RDS_MySQL_With_Read_Replica: Sample template showing how to create a highly-available, RDS DBInstance with a read replica. **WARNING** This template creates an Amazon Relational Database Service database instance and Amazon CloudWatch alarms. You will be billed for the AWS resources used if you create a stack from this template.", - "Parameters": { "DBName": { "Default": "MyDatabase", @@ -13,13 +11,9 @@ template = { "MinLength": "1", "MaxLength": "64", "AllowedPattern": "[a-zA-Z][a-zA-Z0-9]*", - "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters." + "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters.", }, - - "DBInstanceIdentifier": { - "Type": "String" - }, - + "DBInstanceIdentifier": {"Type": "String"}, "DBUser": { "NoEcho": "true", "Description": "The database admin account username", @@ -27,9 +21,8 @@ template = { "MinLength": "1", "MaxLength": "16", "AllowedPattern": "[a-zA-Z][a-zA-Z0-9]*", - "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters." + "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters.", }, - "DBPassword": { "NoEcho": "true", "Description": "The database admin account password", @@ -37,112 +30,121 @@ template = { "MinLength": "1", "MaxLength": "41", "AllowedPattern": "[a-zA-Z0-9]+", - "ConstraintDescription": "must contain only alphanumeric characters." + "ConstraintDescription": "must contain only alphanumeric characters.", }, - "DBAllocatedStorage": { "Default": "5", "Description": "The size of the database (Gb)", "Type": "Number", "MinValue": "5", "MaxValue": "1024", - "ConstraintDescription": "must be between 5 and 1024Gb." + "ConstraintDescription": "must be between 5 and 1024Gb.", }, - "DBInstanceClass": { "Description": "The database instance type", "Type": "String", "Default": "db.m1.small", - "AllowedValues": ["db.t1.micro", "db.m1.small", "db.m1.medium", "db.m1.large", "db.m1.xlarge", "db.m2.xlarge", "db.m2.2xlarge", "db.m2.4xlarge", "db.m3.medium", "db.m3.large", "db.m3.xlarge", "db.m3.2xlarge", "db.r3.large", "db.r3.xlarge", "db.r3.2xlarge", "db.r3.4xlarge", "db.r3.8xlarge", "db.m2.xlarge", "db.m2.2xlarge", "db.m2.4xlarge", "db.cr1.8xlarge"], - "ConstraintDescription": "must select a valid database instance type." + "AllowedValues": [ + "db.t1.micro", + "db.m1.small", + "db.m1.medium", + "db.m1.large", + "db.m1.xlarge", + "db.m2.xlarge", + "db.m2.2xlarge", + "db.m2.4xlarge", + "db.m3.medium", + "db.m3.large", + "db.m3.xlarge", + "db.m3.2xlarge", + "db.r3.large", + "db.r3.xlarge", + "db.r3.2xlarge", + "db.r3.4xlarge", + "db.r3.8xlarge", + "db.m2.xlarge", + "db.m2.2xlarge", + "db.m2.4xlarge", + "db.cr1.8xlarge", + ], + "ConstraintDescription": "must select a valid database instance type.", }, - "EC2SecurityGroup": { "Description": "The EC2 security group that contains instances that need access to the database", "Default": "default", "Type": "String", "AllowedPattern": "[a-zA-Z0-9\\-]+", - "ConstraintDescription": "must be a valid security group name." + "ConstraintDescription": "must be a valid security group name.", }, - "MultiAZ": { "Description": "Multi-AZ master database", "Type": "String", "Default": "false", "AllowedValues": ["true", "false"], - "ConstraintDescription": "must be true or false." - } + "ConstraintDescription": "must be true or false.", + }, }, - "Conditions": { - "Is-EC2-VPC": {"Fn::Or": [{"Fn::Equals": [{"Ref": "AWS::Region"}, "eu-central-1"]}, - {"Fn::Equals": [{"Ref": "AWS::Region"}, "cn-north-1"]}]}, - "Is-EC2-Classic": {"Fn::Not": [{"Condition": "Is-EC2-VPC"}]} + "Is-EC2-VPC": { + "Fn::Or": [ + {"Fn::Equals": [{"Ref": "AWS::Region"}, "eu-central-1"]}, + {"Fn::Equals": [{"Ref": "AWS::Region"}, "cn-north-1"]}, + ] + }, + "Is-EC2-Classic": {"Fn::Not": [{"Condition": "Is-EC2-VPC"}]}, }, - "Resources": { "DBParameterGroup": { "Type": "AWS::RDS::DBParameterGroup", "Properties": { "Description": "DB Parameter Goup", "Family": "MySQL5.1", - "Parameters": { - "BACKLOG_QUEUE_LIMIT": "2048" - } - } + "Parameters": {"BACKLOG_QUEUE_LIMIT": "2048"}, + }, }, - "DBEC2SecurityGroup": { "Type": "AWS::EC2::SecurityGroup", "Condition": "Is-EC2-VPC", "Properties": { "GroupDescription": "Open database for access", - "SecurityGroupIngress": [{ - "IpProtocol": "tcp", - "FromPort": "3306", - "ToPort": "3306", - "SourceSecurityGroupName": {"Ref": "EC2SecurityGroup"} - }] - } + "SecurityGroupIngress": [ + { + "IpProtocol": "tcp", + "FromPort": "3306", + "ToPort": "3306", + "SourceSecurityGroupName": {"Ref": "EC2SecurityGroup"}, + } + ], + }, }, - "DBSecurityGroup": { "Type": "AWS::RDS::DBSecurityGroup", "Condition": "Is-EC2-Classic", "Properties": { - "DBSecurityGroupIngress": [{ - "EC2SecurityGroupName": {"Ref": "EC2SecurityGroup"} - }], - "GroupDescription": "database access" - } + "DBSecurityGroupIngress": [ + {"EC2SecurityGroupName": {"Ref": "EC2SecurityGroup"}} + ], + "GroupDescription": "database access", + }, }, - - "my_vpc": { - "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - } - }, - + "my_vpc": {"Type": "AWS::EC2::VPC", "Properties": {"CidrBlock": "10.0.0.0/16"}}, "EC2Subnet": { "Type": "AWS::EC2::Subnet", "Condition": "Is-EC2-VPC", "Properties": { "AvailabilityZone": "eu-central-1a", "CidrBlock": "10.0.1.0/24", - "VpcId": {"Ref": "my_vpc"} - } + "VpcId": {"Ref": "my_vpc"}, + }, }, - "DBSubnet": { "Type": "AWS::RDS::DBSubnetGroup", "Condition": "Is-EC2-VPC", "Properties": { "DBSubnetGroupDescription": "my db subnet group", "SubnetIds": [{"Ref": "EC2Subnet"}], - } + }, }, - "MasterDB": { "Type": "AWS::RDS::DBInstance", "Properties": { @@ -151,54 +153,79 @@ template = { "AllocatedStorage": {"Ref": "DBAllocatedStorage"}, "DBInstanceClass": {"Ref": "DBInstanceClass"}, "Engine": "MySQL", - "DBSubnetGroupName": {"Fn::If": ["Is-EC2-VPC", {"Ref": "DBSubnet"}, {"Ref": "AWS::NoValue"}]}, + "DBSubnetGroupName": { + "Fn::If": [ + "Is-EC2-VPC", + {"Ref": "DBSubnet"}, + {"Ref": "AWS::NoValue"}, + ] + }, "MasterUsername": {"Ref": "DBUser"}, "MasterUserPassword": {"Ref": "DBPassword"}, "MultiAZ": {"Ref": "MultiAZ"}, "Tags": [{"Key": "Name", "Value": "Master Database"}], - "VPCSecurityGroups": {"Fn::If": ["Is-EC2-VPC", [{"Fn::GetAtt": ["DBEC2SecurityGroup", "GroupId"]}], {"Ref": "AWS::NoValue"}]}, - "DBSecurityGroups": {"Fn::If": ["Is-EC2-Classic", [{"Ref": "DBSecurityGroup"}], {"Ref": "AWS::NoValue"}]} + "VPCSecurityGroups": { + "Fn::If": [ + "Is-EC2-VPC", + [{"Fn::GetAtt": ["DBEC2SecurityGroup", "GroupId"]}], + {"Ref": "AWS::NoValue"}, + ] + }, + "DBSecurityGroups": { + "Fn::If": [ + "Is-EC2-Classic", + [{"Ref": "DBSecurityGroup"}], + {"Ref": "AWS::NoValue"}, + ] + }, }, - "DeletionPolicy": "Snapshot" + "DeletionPolicy": "Snapshot", }, - "ReplicaDB": { "Type": "AWS::RDS::DBInstance", "Properties": { "SourceDBInstanceIdentifier": {"Ref": "MasterDB"}, "DBInstanceClass": {"Ref": "DBInstanceClass"}, - "Tags": [{"Key": "Name", "Value": "Read Replica Database"}] - } - } + "Tags": [{"Key": "Name", "Value": "Read Replica Database"}], + }, + }, }, - "Outputs": { "EC2Platform": { "Description": "Platform in which this stack is deployed", - "Value": {"Fn::If": ["Is-EC2-VPC", "EC2-VPC", "EC2-Classic"]} + "Value": {"Fn::If": ["Is-EC2-VPC", "EC2-VPC", "EC2-Classic"]}, }, - "MasterJDBCConnectionString": { "Description": "JDBC connection string for the master database", - "Value": {"Fn::Join": ["", ["jdbc:mysql://", - {"Fn::GetAtt": [ - "MasterDB", "Endpoint.Address"]}, - ":", - {"Fn::GetAtt": [ - "MasterDB", "Endpoint.Port"]}, - "/", - {"Ref": "DBName"}]]} + "Value": { + "Fn::Join": [ + "", + [ + "jdbc:mysql://", + {"Fn::GetAtt": ["MasterDB", "Endpoint.Address"]}, + ":", + {"Fn::GetAtt": ["MasterDB", "Endpoint.Port"]}, + "/", + {"Ref": "DBName"}, + ], + ] + }, }, "ReplicaJDBCConnectionString": { "Description": "JDBC connection string for the replica database", - "Value": {"Fn::Join": ["", ["jdbc:mysql://", - {"Fn::GetAtt": [ - "ReplicaDB", "Endpoint.Address"]}, - ":", - {"Fn::GetAtt": [ - "ReplicaDB", "Endpoint.Port"]}, - "/", - {"Ref": "DBName"}]]} - } - } + "Value": { + "Fn::Join": [ + "", + [ + "jdbc:mysql://", + {"Fn::GetAtt": ["ReplicaDB", "Endpoint.Address"]}, + ":", + {"Fn::GetAtt": ["ReplicaDB", "Endpoint.Port"]}, + "/", + {"Ref": "DBName"}, + ], + ] + }, + }, + }, } diff --git a/tests/test_cloudformation/fixtures/rds_mysql_with_read_replica.py b/tests/test_cloudformation/fixtures/rds_mysql_with_read_replica.py index 2fbfb4cad..30f2210fc 100644 --- a/tests/test_cloudformation/fixtures/rds_mysql_with_read_replica.py +++ b/tests/test_cloudformation/fixtures/rds_mysql_with_read_replica.py @@ -2,9 +2,7 @@ from __future__ import unicode_literals template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "AWS CloudFormation Sample Template RDS_MySQL_With_Read_Replica: Sample template showing how to create a highly-available, RDS DBInstance with a read replica. **WARNING** This template creates an Amazon Relational Database Service database instance and Amazon CloudWatch alarms. You will be billed for the AWS resources used if you create a stack from this template.", - "Parameters": { "DBName": { "Default": "MyDatabase", @@ -13,13 +11,9 @@ template = { "MinLength": "1", "MaxLength": "64", "AllowedPattern": "[a-zA-Z][a-zA-Z0-9]*", - "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters." + "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters.", }, - - "DBInstanceIdentifier": { - "Type": "String" - }, - + "DBInstanceIdentifier": {"Type": "String"}, "DBUser": { "NoEcho": "true", "Description": "The database admin account username", @@ -27,9 +21,8 @@ template = { "MinLength": "1", "MaxLength": "16", "AllowedPattern": "[a-zA-Z][a-zA-Z0-9]*", - "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters." + "ConstraintDescription": "must begin with a letter and contain only alphanumeric characters.", }, - "DBPassword": { "NoEcho": "true", "Description": "The database admin account password", @@ -37,101 +30,113 @@ template = { "MinLength": "1", "MaxLength": "41", "AllowedPattern": "[a-zA-Z0-9]+", - "ConstraintDescription": "must contain only alphanumeric characters." + "ConstraintDescription": "must contain only alphanumeric characters.", }, - "DBAllocatedStorage": { "Default": "5", "Description": "The size of the database (Gb)", "Type": "Number", "MinValue": "5", "MaxValue": "1024", - "ConstraintDescription": "must be between 5 and 1024Gb." + "ConstraintDescription": "must be between 5 and 1024Gb.", }, - "DBInstanceClass": { "Description": "The database instance type", "Type": "String", "Default": "db.m1.small", - "AllowedValues": ["db.t1.micro", "db.m1.small", "db.m1.medium", "db.m1.large", "db.m1.xlarge", "db.m2.xlarge", "db.m2.2xlarge", "db.m2.4xlarge", "db.m3.medium", "db.m3.large", "db.m3.xlarge", "db.m3.2xlarge", "db.r3.large", "db.r3.xlarge", "db.r3.2xlarge", "db.r3.4xlarge", "db.r3.8xlarge", "db.m2.xlarge", "db.m2.2xlarge", "db.m2.4xlarge", "db.cr1.8xlarge"], - "ConstraintDescription": "must select a valid database instance type." + "AllowedValues": [ + "db.t1.micro", + "db.m1.small", + "db.m1.medium", + "db.m1.large", + "db.m1.xlarge", + "db.m2.xlarge", + "db.m2.2xlarge", + "db.m2.4xlarge", + "db.m3.medium", + "db.m3.large", + "db.m3.xlarge", + "db.m3.2xlarge", + "db.r3.large", + "db.r3.xlarge", + "db.r3.2xlarge", + "db.r3.4xlarge", + "db.r3.8xlarge", + "db.m2.xlarge", + "db.m2.2xlarge", + "db.m2.4xlarge", + "db.cr1.8xlarge", + ], + "ConstraintDescription": "must select a valid database instance type.", }, - "EC2SecurityGroup": { "Description": "The EC2 security group that contains instances that need access to the database", "Default": "default", "Type": "String", "AllowedPattern": "[a-zA-Z0-9\\-]+", - "ConstraintDescription": "must be a valid security group name." + "ConstraintDescription": "must be a valid security group name.", }, - "MultiAZ": { "Description": "Multi-AZ master database", "Type": "String", "Default": "false", "AllowedValues": ["true", "false"], - "ConstraintDescription": "must be true or false." - } + "ConstraintDescription": "must be true or false.", + }, }, - "Conditions": { - "Is-EC2-VPC": {"Fn::Or": [{"Fn::Equals": [{"Ref": "AWS::Region"}, "eu-central-1"]}, - {"Fn::Equals": [{"Ref": "AWS::Region"}, "cn-north-1"]}]}, - "Is-EC2-Classic": {"Fn::Not": [{"Condition": "Is-EC2-VPC"}]} + "Is-EC2-VPC": { + "Fn::Or": [ + {"Fn::Equals": [{"Ref": "AWS::Region"}, "eu-central-1"]}, + {"Fn::Equals": [{"Ref": "AWS::Region"}, "cn-north-1"]}, + ] + }, + "Is-EC2-Classic": {"Fn::Not": [{"Condition": "Is-EC2-VPC"}]}, }, - "Resources": { "DBEC2SecurityGroup": { "Type": "AWS::EC2::SecurityGroup", "Condition": "Is-EC2-VPC", "Properties": { "GroupDescription": "Open database for access", - "SecurityGroupIngress": [{ - "IpProtocol": "tcp", - "FromPort": "3306", - "ToPort": "3306", - "SourceSecurityGroupName": {"Ref": "EC2SecurityGroup"} - }] - } + "SecurityGroupIngress": [ + { + "IpProtocol": "tcp", + "FromPort": "3306", + "ToPort": "3306", + "SourceSecurityGroupName": {"Ref": "EC2SecurityGroup"}, + } + ], + }, }, - "DBSecurityGroup": { "Type": "AWS::RDS::DBSecurityGroup", "Condition": "Is-EC2-Classic", "Properties": { - "DBSecurityGroupIngress": [{ - "EC2SecurityGroupName": {"Ref": "EC2SecurityGroup"} - }], - "GroupDescription": "database access" - } + "DBSecurityGroupIngress": [ + {"EC2SecurityGroupName": {"Ref": "EC2SecurityGroup"}} + ], + "GroupDescription": "database access", + }, }, - - "my_vpc": { - "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - } - }, - + "my_vpc": {"Type": "AWS::EC2::VPC", "Properties": {"CidrBlock": "10.0.0.0/16"}}, "EC2Subnet": { "Type": "AWS::EC2::Subnet", "Condition": "Is-EC2-VPC", "Properties": { "AvailabilityZone": "eu-central-1a", "CidrBlock": "10.0.1.0/24", - "VpcId": {"Ref": "my_vpc"} - } + "VpcId": {"Ref": "my_vpc"}, + }, }, - "DBSubnet": { "Type": "AWS::RDS::DBSubnetGroup", "Condition": "Is-EC2-VPC", "Properties": { "DBSubnetGroupDescription": "my db subnet group", "SubnetIds": [{"Ref": "EC2Subnet"}], - } + }, }, - "MasterDB": { "Type": "AWS::RDS::DBInstance", "Properties": { @@ -140,54 +145,79 @@ template = { "AllocatedStorage": {"Ref": "DBAllocatedStorage"}, "DBInstanceClass": {"Ref": "DBInstanceClass"}, "Engine": "MySQL", - "DBSubnetGroupName": {"Fn::If": ["Is-EC2-VPC", {"Ref": "DBSubnet"}, {"Ref": "AWS::NoValue"}]}, + "DBSubnetGroupName": { + "Fn::If": [ + "Is-EC2-VPC", + {"Ref": "DBSubnet"}, + {"Ref": "AWS::NoValue"}, + ] + }, "MasterUsername": {"Ref": "DBUser"}, "MasterUserPassword": {"Ref": "DBPassword"}, "MultiAZ": {"Ref": "MultiAZ"}, "Tags": [{"Key": "Name", "Value": "Master Database"}], - "VPCSecurityGroups": {"Fn::If": ["Is-EC2-VPC", [{"Fn::GetAtt": ["DBEC2SecurityGroup", "GroupId"]}], {"Ref": "AWS::NoValue"}]}, - "DBSecurityGroups": {"Fn::If": ["Is-EC2-Classic", [{"Ref": "DBSecurityGroup"}], {"Ref": "AWS::NoValue"}]} + "VPCSecurityGroups": { + "Fn::If": [ + "Is-EC2-VPC", + [{"Fn::GetAtt": ["DBEC2SecurityGroup", "GroupId"]}], + {"Ref": "AWS::NoValue"}, + ] + }, + "DBSecurityGroups": { + "Fn::If": [ + "Is-EC2-Classic", + [{"Ref": "DBSecurityGroup"}], + {"Ref": "AWS::NoValue"}, + ] + }, }, - "DeletionPolicy": "Snapshot" + "DeletionPolicy": "Snapshot", }, - "ReplicaDB": { "Type": "AWS::RDS::DBInstance", "Properties": { "SourceDBInstanceIdentifier": {"Ref": "MasterDB"}, "DBInstanceClass": {"Ref": "DBInstanceClass"}, - "Tags": [{"Key": "Name", "Value": "Read Replica Database"}] - } - } + "Tags": [{"Key": "Name", "Value": "Read Replica Database"}], + }, + }, }, - "Outputs": { "EC2Platform": { "Description": "Platform in which this stack is deployed", - "Value": {"Fn::If": ["Is-EC2-VPC", "EC2-VPC", "EC2-Classic"]} + "Value": {"Fn::If": ["Is-EC2-VPC", "EC2-VPC", "EC2-Classic"]}, }, - "MasterJDBCConnectionString": { "Description": "JDBC connection string for the master database", - "Value": {"Fn::Join": ["", ["jdbc:mysql://", - {"Fn::GetAtt": [ - "MasterDB", "Endpoint.Address"]}, - ":", - {"Fn::GetAtt": [ - "MasterDB", "Endpoint.Port"]}, - "/", - {"Ref": "DBName"}]]} + "Value": { + "Fn::Join": [ + "", + [ + "jdbc:mysql://", + {"Fn::GetAtt": ["MasterDB", "Endpoint.Address"]}, + ":", + {"Fn::GetAtt": ["MasterDB", "Endpoint.Port"]}, + "/", + {"Ref": "DBName"}, + ], + ] + }, }, "ReplicaJDBCConnectionString": { "Description": "JDBC connection string for the replica database", - "Value": {"Fn::Join": ["", ["jdbc:mysql://", - {"Fn::GetAtt": [ - "ReplicaDB", "Endpoint.Address"]}, - ":", - {"Fn::GetAtt": [ - "ReplicaDB", "Endpoint.Port"]}, - "/", - {"Ref": "DBName"}]]} - } - } + "Value": { + "Fn::Join": [ + "", + [ + "jdbc:mysql://", + {"Fn::GetAtt": ["ReplicaDB", "Endpoint.Address"]}, + ":", + {"Fn::GetAtt": ["ReplicaDB", "Endpoint.Port"]}, + "/", + {"Ref": "DBName"}, + ], + ] + }, + }, + }, } diff --git a/tests/test_cloudformation/fixtures/redshift.py b/tests/test_cloudformation/fixtures/redshift.py index 317e213bc..6da5c30db 100644 --- a/tests/test_cloudformation/fixtures/redshift.py +++ b/tests/test_cloudformation/fixtures/redshift.py @@ -7,35 +7,35 @@ template = { "Description": "The name of the first database to be created when the cluster is created", "Type": "String", "Default": "dev", - "AllowedPattern": "([a-z]|[0-9])+" + "AllowedPattern": "([a-z]|[0-9])+", }, "ClusterType": { "Description": "The type of cluster", "Type": "String", "Default": "single-node", - "AllowedValues": ["single-node", "multi-node"] + "AllowedValues": ["single-node", "multi-node"], }, "NumberOfNodes": { "Description": "The number of compute nodes in the cluster. For multi-node clusters, the NumberOfNodes parameter must be greater than 1", "Type": "Number", - "Default": "1" + "Default": "1", }, "NodeType": { "Description": "The type of node to be provisioned", "Type": "String", "Default": "dw1.xlarge", - "AllowedValues": ["dw1.xlarge", "dw1.8xlarge", "dw2.large", "dw2.8xlarge"] + "AllowedValues": ["dw1.xlarge", "dw1.8xlarge", "dw2.large", "dw2.8xlarge"], }, "MasterUsername": { "Description": "The user name that is associated with the master user account for the cluster that is being created", "Type": "String", "Default": "defaultuser", - "AllowedPattern": "([a-z])([a-z]|[0-9])*" + "AllowedPattern": "([a-z])([a-z]|[0-9])*", }, - "MasterUserPassword": { + "MasterUserPassword": { "Description": "The password that is associated with the master user account for the cluster that is being created.", "Type": "String", - "NoEcho": "true" + "NoEcho": "true", }, "InboundTraffic": { "Description": "Allow inbound traffic to the cluster from this CIDR range.", @@ -44,18 +44,16 @@ template = { "MaxLength": "18", "Default": "0.0.0.0/0", "AllowedPattern": "(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})/(\\d{1,2})", - "ConstraintDescription": "must be a valid CIDR range of the form x.x.x.x/x." + "ConstraintDescription": "must be a valid CIDR range of the form x.x.x.x/x.", }, "PortNumber": { "Description": "The port number on which the cluster accepts incoming connections.", "Type": "Number", - "Default": "5439" - } + "Default": "5439", + }, }, "Conditions": { - "IsMultiNodeCluster": { - "Fn::Equals": [{"Ref": "ClusterType"}, "multi-node"] - } + "IsMultiNodeCluster": {"Fn::Equals": [{"Ref": "ClusterType"}, "multi-node"]} }, "Resources": { "RedshiftCluster": { @@ -63,7 +61,13 @@ template = { "DependsOn": "AttachGateway", "Properties": { "ClusterType": {"Ref": "ClusterType"}, - "NumberOfNodes": {"Fn::If": ["IsMultiNodeCluster", {"Ref": "NumberOfNodes"}, {"Ref": "AWS::NoValue"}]}, + "NumberOfNodes": { + "Fn::If": [ + "IsMultiNodeCluster", + {"Ref": "NumberOfNodes"}, + {"Ref": "AWS::NoValue"}, + ] + }, "NodeType": {"Ref": "NodeType"}, "DBName": {"Ref": "DatabaseName"}, "MasterUsername": {"Ref": "MasterUsername"}, @@ -72,116 +76,106 @@ template = { "VpcSecurityGroupIds": [{"Ref": "SecurityGroup"}], "ClusterSubnetGroupName": {"Ref": "RedshiftClusterSubnetGroup"}, "PubliclyAccessible": "true", - "Port": {"Ref": "PortNumber"} - } + "Port": {"Ref": "PortNumber"}, + }, }, "RedshiftClusterParameterGroup": { "Type": "AWS::Redshift::ClusterParameterGroup", "Properties": { "Description": "Cluster parameter group", "ParameterGroupFamily": "redshift-1.0", - "Parameters": [{ - "ParameterName": "enable_user_activity_logging", - "ParameterValue": "true" - }] - } + "Parameters": [ + { + "ParameterName": "enable_user_activity_logging", + "ParameterValue": "true", + } + ], + }, }, "RedshiftClusterSubnetGroup": { "Type": "AWS::Redshift::ClusterSubnetGroup", "Properties": { "Description": "Cluster subnet group", - "SubnetIds": [{"Ref": "PublicSubnet"}] - } - }, - "VPC": { - "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16" - } + "SubnetIds": [{"Ref": "PublicSubnet"}], + }, }, + "VPC": {"Type": "AWS::EC2::VPC", "Properties": {"CidrBlock": "10.0.0.0/16"}}, "PublicSubnet": { "Type": "AWS::EC2::Subnet", - "Properties": { - "CidrBlock": "10.0.0.0/24", - "VpcId": {"Ref": "VPC"} - } + "Properties": {"CidrBlock": "10.0.0.0/24", "VpcId": {"Ref": "VPC"}}, }, "SecurityGroup": { "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupDescription": "Security group", - "SecurityGroupIngress": [{ - "CidrIp": {"Ref": "InboundTraffic"}, - "FromPort": {"Ref": "PortNumber"}, - "ToPort": {"Ref": "PortNumber"}, - "IpProtocol": "tcp" - }], - "VpcId": {"Ref": "VPC"} - } - }, - "myInternetGateway": { - "Type": "AWS::EC2::InternetGateway" + "SecurityGroupIngress": [ + { + "CidrIp": {"Ref": "InboundTraffic"}, + "FromPort": {"Ref": "PortNumber"}, + "ToPort": {"Ref": "PortNumber"}, + "IpProtocol": "tcp", + } + ], + "VpcId": {"Ref": "VPC"}, + }, }, + "myInternetGateway": {"Type": "AWS::EC2::InternetGateway"}, "AttachGateway": { "Type": "AWS::EC2::VPCGatewayAttachment", "Properties": { "VpcId": {"Ref": "VPC"}, - "InternetGatewayId": {"Ref": "myInternetGateway"} - } + "InternetGatewayId": {"Ref": "myInternetGateway"}, + }, }, "PublicRouteTable": { "Type": "AWS::EC2::RouteTable", - "Properties": { - "VpcId": { - "Ref": "VPC" - } - } + "Properties": {"VpcId": {"Ref": "VPC"}}, }, "PublicRoute": { "Type": "AWS::EC2::Route", "DependsOn": "AttachGateway", "Properties": { - "RouteTableId": { - "Ref": "PublicRouteTable" - }, + "RouteTableId": {"Ref": "PublicRouteTable"}, "DestinationCidrBlock": "0.0.0.0/0", - "GatewayId": { - "Ref": "myInternetGateway" - } - } + "GatewayId": {"Ref": "myInternetGateway"}, + }, }, "PublicSubnetRouteTableAssociation": { "Type": "AWS::EC2::SubnetRouteTableAssociation", "Properties": { - "SubnetId": { - "Ref": "PublicSubnet" - }, - "RouteTableId": { - "Ref": "PublicRouteTable" - } - } - } + "SubnetId": {"Ref": "PublicSubnet"}, + "RouteTableId": {"Ref": "PublicRouteTable"}, + }, + }, }, "Outputs": { "ClusterEndpoint": { "Description": "Cluster endpoint", - "Value": {"Fn::Join": [":", [{"Fn::GetAtt": ["RedshiftCluster", "Endpoint.Address"]}, {"Fn::GetAtt": ["RedshiftCluster", "Endpoint.Port"]}]]} + "Value": { + "Fn::Join": [ + ":", + [ + {"Fn::GetAtt": ["RedshiftCluster", "Endpoint.Address"]}, + {"Fn::GetAtt": ["RedshiftCluster", "Endpoint.Port"]}, + ], + ] + }, }, "ClusterName": { "Description": "Name of cluster", - "Value": {"Ref": "RedshiftCluster"} + "Value": {"Ref": "RedshiftCluster"}, }, "ParameterGroupName": { "Description": "Name of parameter group", - "Value": {"Ref": "RedshiftClusterParameterGroup"} + "Value": {"Ref": "RedshiftClusterParameterGroup"}, }, "RedshiftClusterSubnetGroupName": { "Description": "Name of cluster subnet group", - "Value": {"Ref": "RedshiftClusterSubnetGroup"} + "Value": {"Ref": "RedshiftClusterSubnetGroup"}, }, "RedshiftClusterSecurityGroupName": { "Description": "Name of cluster security group", - "Value": {"Ref": "SecurityGroup"} - } - } + "Value": {"Ref": "SecurityGroup"}, + }, + }, } diff --git a/tests/test_cloudformation/fixtures/route53_ec2_instance_with_public_ip.py b/tests/test_cloudformation/fixtures/route53_ec2_instance_with_public_ip.py index 43a11104b..3f5735bba 100644 --- a/tests/test_cloudformation/fixtures/route53_ec2_instance_with_public_ip.py +++ b/tests/test_cloudformation/fixtures/route53_ec2_instance_with_public_ip.py @@ -1,47 +1,38 @@ from __future__ import unicode_literals template = { - "Parameters": { - "R53ZoneName": { - "Type": "String", - "Default": "my_zone" - } - }, - + "Parameters": {"R53ZoneName": {"Type": "String", "Default": "my_zone"}}, "Resources": { "Ec2Instance": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-1234abcd", - "PrivateIpAddress": "10.0.0.25", - } + "Properties": {"ImageId": "ami-1234abcd", "PrivateIpAddress": "10.0.0.25"}, }, - "HostedZone": { "Type": "AWS::Route53::HostedZone", - "Properties": { - "Name": {"Ref": "R53ZoneName"} - } + "Properties": {"Name": {"Ref": "R53ZoneName"}}, }, - "myDNSRecord": { "Type": "AWS::Route53::RecordSet", "Properties": { "HostedZoneId": {"Ref": "HostedZone"}, "Comment": "DNS name for my instance.", "Name": { - "Fn::Join": ["", [ - {"Ref": "Ec2Instance"}, ".", - {"Ref": "AWS::Region"}, ".", - {"Ref": "R53ZoneName"}, "." - ]] + "Fn::Join": [ + "", + [ + {"Ref": "Ec2Instance"}, + ".", + {"Ref": "AWS::Region"}, + ".", + {"Ref": "R53ZoneName"}, + ".", + ], + ] }, "Type": "A", "TTL": "900", - "ResourceRecords": [ - {"Fn::GetAtt": ["Ec2Instance", "PrivateIp"]} - ] - } - } + "ResourceRecords": [{"Fn::GetAtt": ["Ec2Instance", "PrivateIp"]}], + }, + }, }, } diff --git a/tests/test_cloudformation/fixtures/route53_health_check.py b/tests/test_cloudformation/fixtures/route53_health_check.py index 420cd38ba..876caf299 100644 --- a/tests/test_cloudformation/fixtures/route53_health_check.py +++ b/tests/test_cloudformation/fixtures/route53_health_check.py @@ -4,11 +4,8 @@ template = { "Resources": { "HostedZone": { "Type": "AWS::Route53::HostedZone", - "Properties": { - "Name": "my_zone" - } + "Properties": {"Name": "my_zone"}, }, - "my_health_check": { "Type": "AWS::Route53::HealthCheck", "Properties": { @@ -20,9 +17,8 @@ template = { "ResourcePath": "/", "Type": "HTTP", } - } + }, }, - "myDNSRecord": { "Type": "AWS::Route53::RecordSet", "Properties": { @@ -33,7 +29,7 @@ template = { "TTL": "900", "ResourceRecords": ["my.example.com"], "HealthCheckId": {"Ref": "my_health_check"}, - } - } - }, + }, + }, + } } diff --git a/tests/test_cloudformation/fixtures/route53_roundrobin.py b/tests/test_cloudformation/fixtures/route53_roundrobin.py index 199e3e088..9c9f8a6f9 100644 --- a/tests/test_cloudformation/fixtures/route53_roundrobin.py +++ b/tests/test_cloudformation/fixtures/route53_roundrobin.py @@ -2,53 +2,71 @@ from __future__ import unicode_literals template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "AWS CloudFormation Sample Template Route53_RoundRobin: Sample template showing how to use weighted round robin (WRR) DNS entried via Amazon Route 53. This contrived sample uses weighted CNAME records to illustrate that the weighting influences the return records. It assumes that you already have a Hosted Zone registered with Amazon Route 53. **WARNING** This template creates one or more AWS resources. You will be billed for the AWS resources used if you create a stack from this template.", - - "Parameters": { - "R53ZoneName": { - "Type": "String", - "Default": "my_zone" - } - }, - + "Parameters": {"R53ZoneName": {"Type": "String", "Default": "my_zone"}}, "Resources": { - "MyZone": { "Type": "AWS::Route53::HostedZone", - "Properties": { - "Name": {"Ref": "R53ZoneName"} - } + "Properties": {"Name": {"Ref": "R53ZoneName"}}, }, - "MyDNSRecord": { "Type": "AWS::Route53::RecordSetGroup", "Properties": { "HostedZoneId": {"Ref": "MyZone"}, "Comment": "Contrived example to redirect to aws.amazon.com 75% of the time and www.amazon.com 25% of the time.", - "RecordSets": [{ - "SetIdentifier": {"Fn::Join": [" ", [{"Ref": "AWS::StackName"}, "AWS"]]}, - "Name": {"Fn::Join": ["", [{"Ref": "AWS::StackName"}, ".", {"Ref": "AWS::Region"}, ".", {"Ref": "R53ZoneName"}, "."]]}, - "Type": "CNAME", - "TTL": "900", - "ResourceRecords": ["aws.amazon.com"], - "Weight": "3" - }, { - "SetIdentifier": {"Fn::Join": [" ", [{"Ref": "AWS::StackName"}, "Amazon"]]}, - "Name": {"Fn::Join": ["", [{"Ref": "AWS::StackName"}, ".", {"Ref": "AWS::Region"}, ".", {"Ref": "R53ZoneName"}, "."]]}, - "Type": "CNAME", - "TTL": "900", - "ResourceRecords": ["www.amazon.com"], - "Weight": "1" - }] - } - } + "RecordSets": [ + { + "SetIdentifier": { + "Fn::Join": [" ", [{"Ref": "AWS::StackName"}, "AWS"]] + }, + "Name": { + "Fn::Join": [ + "", + [ + {"Ref": "AWS::StackName"}, + ".", + {"Ref": "AWS::Region"}, + ".", + {"Ref": "R53ZoneName"}, + ".", + ], + ] + }, + "Type": "CNAME", + "TTL": "900", + "ResourceRecords": ["aws.amazon.com"], + "Weight": "3", + }, + { + "SetIdentifier": { + "Fn::Join": [" ", [{"Ref": "AWS::StackName"}, "Amazon"]] + }, + "Name": { + "Fn::Join": [ + "", + [ + {"Ref": "AWS::StackName"}, + ".", + {"Ref": "AWS::Region"}, + ".", + {"Ref": "R53ZoneName"}, + ".", + ], + ] + }, + "Type": "CNAME", + "TTL": "900", + "ResourceRecords": ["www.amazon.com"], + "Weight": "1", + }, + ], + }, + }, }, - "Outputs": { "DomainName": { "Description": "Fully qualified domain name", - "Value": {"Ref": "MyDNSRecord"} + "Value": {"Ref": "MyDNSRecord"}, } - } + }, } diff --git a/tests/test_cloudformation/fixtures/single_instance_with_ebs_volume.py b/tests/test_cloudformation/fixtures/single_instance_with_ebs_volume.py index 37c7ca4f3..7962d2c56 100644 --- a/tests/test_cloudformation/fixtures/single_instance_with_ebs_volume.py +++ b/tests/test_cloudformation/fixtures/single_instance_with_ebs_volume.py @@ -10,7 +10,7 @@ template = { "MinLength": "9", "AllowedPattern": "(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})/(\\d{1,2})", "MaxLength": "18", - "Type": "String" + "Type": "String", }, "KeyName": { "Type": "String", @@ -18,7 +18,7 @@ template = { "MinLength": "1", "AllowedPattern": "[\\x20-\\x7E]*", "MaxLength": "255", - "ConstraintDescription": "can contain only ASCII characters." + "ConstraintDescription": "can contain only ASCII characters.", }, "InstanceType": { "Default": "m1.small", @@ -40,8 +40,8 @@ template = { "c1.xlarge", "cc1.4xlarge", "cc2.8xlarge", - "cg1.4xlarge" - ] + "cg1.4xlarge", + ], }, "VolumeSize": { "Description": "WebServer EC2 instance type", @@ -49,8 +49,8 @@ template = { "Type": "Number", "MaxValue": "1024", "MinValue": "5", - "ConstraintDescription": "must be between 5 and 1024 Gb." - } + "ConstraintDescription": "must be between 5 and 1024 Gb.", + }, }, "AWSTemplateFormatVersion": "2010-09-09", "Outputs": { @@ -59,17 +59,9 @@ template = { "Value": { "Fn::Join": [ "", - [ - "http://", - { - "Fn::GetAtt": [ - "WebServer", - "PublicDnsName" - ] - } - ] + ["http://", {"Fn::GetAtt": ["WebServer", "PublicDnsName"]}], ] - } + }, } }, "Resources": { @@ -81,19 +73,17 @@ template = { "ToPort": "80", "IpProtocol": "tcp", "CidrIp": "0.0.0.0/0", - "FromPort": "80" + "FromPort": "80", }, { "ToPort": "22", "IpProtocol": "tcp", - "CidrIp": { - "Ref": "SSHLocation" - }, - "FromPort": "22" - } + "CidrIp": {"Ref": "SSHLocation"}, + "FromPort": "22", + }, ], - "GroupDescription": "Enable SSH access and HTTP access on the inbound port" - } + "GroupDescription": "Enable SSH access and HTTP access on the inbound port", + }, }, "WebServer": { "Type": "AWS::EC2::Instance", @@ -108,23 +98,17 @@ template = { "# Helper function\n", "function error_exit\n", "{\n", - " /opt/aws/bin/cfn-signal -e 1 -r \"$1\" '", - { - "Ref": "WaitHandle" - }, + ' /opt/aws/bin/cfn-signal -e 1 -r "$1" \'', + {"Ref": "WaitHandle"}, "'\n", " exit 1\n", "}\n", "# Install Rails packages\n", "/opt/aws/bin/cfn-init -s ", - { - "Ref": "AWS::StackId" - }, + {"Ref": "AWS::StackId"}, " -r WebServer ", " --region ", - { - "Ref": "AWS::Region" - }, + {"Ref": "AWS::Region"}, " || error_exit 'Failed to run cfn-init'\n", "# Wait for the EBS volume to show up\n", "while [ ! -e /dev/sdh ]; do echo Waiting for EBS volume to attach; sleep 5; done\n", @@ -137,56 +121,38 @@ template = { "git init\n", "gollum --port 80 --host 0.0.0.0 &\n", "# If all is well so signal success\n", - "/opt/aws/bin/cfn-signal -e $? -r \"Rails application setup complete\" '", - { - "Ref": "WaitHandle" - }, - "'\n" - ] + '/opt/aws/bin/cfn-signal -e $? -r "Rails application setup complete" \'', + {"Ref": "WaitHandle"}, + "'\n", + ], ] } }, - "KeyName": { - "Ref": "KeyName" - }, - "SecurityGroups": [ - { - "Ref": "WebServerSecurityGroup" - } - ], - "InstanceType": { - "Ref": "InstanceType" - }, + "KeyName": {"Ref": "KeyName"}, + "SecurityGroups": [{"Ref": "WebServerSecurityGroup"}], + "InstanceType": {"Ref": "InstanceType"}, "ImageId": { "Fn::FindInMap": [ "AWSRegionArch2AMI", - { - "Ref": "AWS::Region" - }, + {"Ref": "AWS::Region"}, { "Fn::FindInMap": [ "AWSInstanceType2Arch", - { - "Ref": "InstanceType" - }, - "Arch" + {"Ref": "InstanceType"}, + "Arch", ] - } + }, ] - } + }, }, "Metadata": { "AWS::CloudFormation::Init": { "config": { "packages": { "rubygems": { - "nokogiri": [ - "1.5.10" - ], + "nokogiri": ["1.5.10"], "rdiscount": [], - "gollum": [ - "1.1.1" - ] + "gollum": ["1.1.1"], }, "yum": { "libxslt-devel": [], @@ -196,150 +162,99 @@ template = { "ruby-devel": [], "ruby-rdoc": [], "make": [], - "libxml2-devel": [] - } + "libxml2-devel": [], + }, } } } - } + }, }, "DataVolume": { "Type": "AWS::EC2::Volume", "Properties": { - "Tags": [ - { - "Value": "Gollum Data Volume", - "Key": "Usage" - } - ], - "AvailabilityZone": { - "Fn::GetAtt": [ - "WebServer", - "AvailabilityZone" - ] - }, + "Tags": [{"Value": "Gollum Data Volume", "Key": "Usage"}], + "AvailabilityZone": {"Fn::GetAtt": ["WebServer", "AvailabilityZone"]}, "Size": "100", - } + }, }, "MountPoint": { "Type": "AWS::EC2::VolumeAttachment", "Properties": { - "InstanceId": { - "Ref": "WebServer" - }, + "InstanceId": {"Ref": "WebServer"}, "Device": "/dev/sdh", - "VolumeId": { - "Ref": "DataVolume" - } - } + "VolumeId": {"Ref": "DataVolume"}, + }, }, "WaitCondition": { "DependsOn": "MountPoint", "Type": "AWS::CloudFormation::WaitCondition", - "Properties": { - "Handle": { - "Ref": "WaitHandle" - }, - "Timeout": "300" - }, + "Properties": {"Handle": {"Ref": "WaitHandle"}, "Timeout": "300"}, "Metadata": { "Comment1": "Note that the WaitCondition is dependent on the volume mount point allowing the volume to be created and attached to the EC2 instance", - "Comment2": "The instance bootstrap script waits for the volume to be attached to the instance prior to installing Gollum and signalling completion" - } + "Comment2": "The instance bootstrap script waits for the volume to be attached to the instance prior to installing Gollum and signalling completion", + }, }, - "WaitHandle": { - "Type": "AWS::CloudFormation::WaitConditionHandle" - } + "WaitHandle": {"Type": "AWS::CloudFormation::WaitConditionHandle"}, }, "Mappings": { "AWSInstanceType2Arch": { - "m3.2xlarge": { - "Arch": "64" - }, - "m2.2xlarge": { - "Arch": "64" - }, - "m1.small": { - "Arch": "64" - }, - "c1.medium": { - "Arch": "64" - }, - "cg1.4xlarge": { - "Arch": "64HVM" - }, - "m2.xlarge": { - "Arch": "64" - }, - "t1.micro": { - "Arch": "64" - }, - "cc1.4xlarge": { - "Arch": "64HVM" - }, - "m1.medium": { - "Arch": "64" - }, - "cc2.8xlarge": { - "Arch": "64HVM" - }, - "m1.large": { - "Arch": "64" - }, - "m1.xlarge": { - "Arch": "64" - }, - "m2.4xlarge": { - "Arch": "64" - }, - "c1.xlarge": { - "Arch": "64" - }, - "m3.xlarge": { - "Arch": "64" - } + "m3.2xlarge": {"Arch": "64"}, + "m2.2xlarge": {"Arch": "64"}, + "m1.small": {"Arch": "64"}, + "c1.medium": {"Arch": "64"}, + "cg1.4xlarge": {"Arch": "64HVM"}, + "m2.xlarge": {"Arch": "64"}, + "t1.micro": {"Arch": "64"}, + "cc1.4xlarge": {"Arch": "64HVM"}, + "m1.medium": {"Arch": "64"}, + "cc2.8xlarge": {"Arch": "64HVM"}, + "m1.large": {"Arch": "64"}, + "m1.xlarge": {"Arch": "64"}, + "m2.4xlarge": {"Arch": "64"}, + "c1.xlarge": {"Arch": "64"}, + "m3.xlarge": {"Arch": "64"}, }, "AWSRegionArch2AMI": { "ap-southeast-1": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-b4b0cae6", - "64": "ami-beb0caec" + "64": "ami-beb0caec", }, "ap-southeast-2": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-b3990e89", - "64": "ami-bd990e87" + "64": "ami-bd990e87", }, "us-west-2": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-38fe7308", - "64": "ami-30fe7300" + "64": "ami-30fe7300", }, "us-east-1": { "64HVM": "ami-0da96764", "32": "ami-31814f58", - "64": "ami-1b814f72" + "64": "ami-1b814f72", }, "ap-northeast-1": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-0644f007", - "64": "ami-0a44f00b" + "64": "ami-0a44f00b", }, "us-west-1": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-11d68a54", - "64": "ami-1bd68a5e" + "64": "ami-1bd68a5e", }, "eu-west-1": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-973b06e3", - "64": "ami-953b06e1" + "64": "ami-953b06e1", }, "sa-east-1": { "64HVM": "NOT_YET_SUPPORTED", "32": "ami-3e3be423", - "64": "ami-3c3be421" - } - } - } + "64": "ami-3c3be421", + }, + }, + }, } diff --git a/tests/test_cloudformation/fixtures/vpc_eip.py b/tests/test_cloudformation/fixtures/vpc_eip.py index c7a46c830..b5bd48c01 100644 --- a/tests/test_cloudformation/fixtures/vpc_eip.py +++ b/tests/test_cloudformation/fixtures/vpc_eip.py @@ -1,12 +1,5 @@ from __future__ import unicode_literals template = { - "Resources": { - "VPCEIP": { - "Type": "AWS::EC2::EIP", - "Properties": { - "Domain": "vpc" - } - } - } + "Resources": {"VPCEIP": {"Type": "AWS::EC2::EIP", "Properties": {"Domain": "vpc"}}} } diff --git a/tests/test_cloudformation/fixtures/vpc_eni.py b/tests/test_cloudformation/fixtures/vpc_eni.py index 3f8eb2d03..fc2d7d61b 100644 --- a/tests/test_cloudformation/fixtures/vpc_eni.py +++ b/tests/test_cloudformation/fixtures/vpc_eni.py @@ -6,33 +6,26 @@ template = { "Resources": { "ENI": { "Type": "AWS::EC2::NetworkInterface", - "Properties": { - "SubnetId": {"Ref": "Subnet"} - } + "Properties": {"SubnetId": {"Ref": "Subnet"}}, }, "Subnet": { "Type": "AWS::EC2::Subnet", "Properties": { "AvailabilityZone": "us-east-1a", "VpcId": {"Ref": "VPC"}, - "CidrBlock": "10.0.0.0/24" - } + "CidrBlock": "10.0.0.0/24", + }, }, - "VPC": { - "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16" - } - } + "VPC": {"Type": "AWS::EC2::VPC", "Properties": {"CidrBlock": "10.0.0.0/16"}}, }, "Outputs": { "NinjaENI": { "Description": "Elastic IP mapping to Auto-Scaling Group", - "Value": {"Ref": "ENI"} + "Value": {"Ref": "ENI"}, }, "ENIIpAddress": { "Description": "ENI's Private IP address", - "Value": {"Fn::GetAtt": ["ENI", "PrimaryPrivateIpAddress"]} - } - } + "Value": {"Fn::GetAtt": ["ENI", "PrimaryPrivateIpAddress"]}, + }, + }, } diff --git a/tests/test_cloudformation/fixtures/vpc_single_instance_in_subnet.py b/tests/test_cloudformation/fixtures/vpc_single_instance_in_subnet.py index 177da884e..ff7b75518 100644 --- a/tests/test_cloudformation/fixtures/vpc_single_instance_in_subnet.py +++ b/tests/test_cloudformation/fixtures/vpc_single_instance_in_subnet.py @@ -10,7 +10,7 @@ template = { "MinLength": "9", "AllowedPattern": "(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})\\.(\\d{1,3})/(\\d{1,2})", "MaxLength": "18", - "Type": "String" + "Type": "String", }, "KeyName": { "Type": "String", @@ -18,7 +18,7 @@ template = { "MinLength": "1", "AllowedPattern": "[\\x20-\\x7E]*", "MaxLength": "255", - "ConstraintDescription": "can contain only ASCII characters." + "ConstraintDescription": "can contain only ASCII characters.", }, "InstanceType": { "Default": "m1.small", @@ -40,9 +40,9 @@ template = { "c1.xlarge", "cc1.4xlarge", "cc2.8xlarge", - "cg1.4xlarge" - ] - } + "cg1.4xlarge", + ], + }, }, "AWSTemplateFormatVersion": "2010-09-09", "Outputs": { @@ -51,116 +51,61 @@ template = { "Value": { "Fn::Join": [ "", - [ - "http://", - { - "Fn::GetAtt": [ - "WebServerInstance", - "PublicIp" - ] - } - ] + ["http://", {"Fn::GetAtt": ["WebServerInstance", "PublicIp"]}], ] - } + }, } }, "Resources": { "Subnet": { "Type": "AWS::EC2::Subnet", "Properties": { - "VpcId": { - "Ref": "VPC" - }, + "VpcId": {"Ref": "VPC"}, "CidrBlock": "10.0.0.0/24", - "Tags": [ - { - "Value": { - "Ref": "AWS::StackId" - }, - "Key": "Application" - } - ] - } - }, - "WebServerWaitHandle": { - "Type": "AWS::CloudFormation::WaitConditionHandle" + "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}], + }, }, + "WebServerWaitHandle": {"Type": "AWS::CloudFormation::WaitConditionHandle"}, "Route": { "Type": "AWS::EC2::Route", "Properties": { - "GatewayId": { - "Ref": "InternetGateway" - }, + "GatewayId": {"Ref": "InternetGateway"}, "DestinationCidrBlock": "0.0.0.0/0", - "RouteTableId": { - "Ref": "RouteTable" - } + "RouteTableId": {"Ref": "RouteTable"}, }, - "DependsOn": "AttachGateway" + "DependsOn": "AttachGateway", }, "SubnetRouteTableAssociation": { "Type": "AWS::EC2::SubnetRouteTableAssociation", "Properties": { - "SubnetId": { - "Ref": "Subnet" - }, - "RouteTableId": { - "Ref": "RouteTable" - } - } + "SubnetId": {"Ref": "Subnet"}, + "RouteTableId": {"Ref": "RouteTable"}, + }, }, "InternetGateway": { "Type": "AWS::EC2::InternetGateway", "Properties": { - "Tags": [ - { - "Value": { - "Ref": "AWS::StackId" - }, - "Key": "Application" - } - ] - } + "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}] + }, }, "RouteTable": { "Type": "AWS::EC2::RouteTable", "Properties": { - "VpcId": { - "Ref": "VPC" - }, - "Tags": [ - { - "Value": { - "Ref": "AWS::StackId" - }, - "Key": "Application" - } - ] - } + "VpcId": {"Ref": "VPC"}, + "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}], + }, }, "WebServerWaitCondition": { "Type": "AWS::CloudFormation::WaitCondition", - "Properties": { - "Handle": { - "Ref": "WebServerWaitHandle" - }, - "Timeout": "300" - }, - "DependsOn": "WebServerInstance" + "Properties": {"Handle": {"Ref": "WebServerWaitHandle"}, "Timeout": "300"}, + "DependsOn": "WebServerInstance", }, "VPC": { "Type": "AWS::EC2::VPC", "Properties": { "CidrBlock": "10.0.0.0/16", - "Tags": [ - { - "Value": { - "Ref": "AWS::StackId" - }, - "Key": "Application" - } - ] - } + "Tags": [{"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}], + }, }, "InstanceSecurityGroup": { "Type": "AWS::EC2::SecurityGroup", @@ -169,23 +114,19 @@ template = { { "ToPort": "22", "IpProtocol": "tcp", - "CidrIp": { - "Ref": "SSHLocation" - }, - "FromPort": "22" + "CidrIp": {"Ref": "SSHLocation"}, + "FromPort": "22", }, { "ToPort": "80", "IpProtocol": "tcp", "CidrIp": "0.0.0.0/0", - "FromPort": "80" - } + "FromPort": "80", + }, ], - "VpcId": { - "Ref": "VPC" - }, - "GroupDescription": "Enable SSH access via port 22" - } + "VpcId": {"Ref": "VPC"}, + "GroupDescription": "Enable SSH access via port 22", + }, }, "WebServerInstance": { "Type": "AWS::EC2::Instance", @@ -200,71 +141,39 @@ template = { "# Helper function\n", "function error_exit\n", "{\n", - " /opt/aws/bin/cfn-signal -e 1 -r \"$1\" '", - { - "Ref": "WebServerWaitHandle" - }, + ' /opt/aws/bin/cfn-signal -e 1 -r "$1" \'', + {"Ref": "WebServerWaitHandle"}, "'\n", " exit 1\n", "}\n", "# Install the simple web page\n", "/opt/aws/bin/cfn-init -s ", - { - "Ref": "AWS::StackId" - }, + {"Ref": "AWS::StackId"}, " -r WebServerInstance ", " --region ", - { - "Ref": "AWS::Region" - }, + {"Ref": "AWS::Region"}, " || error_exit 'Failed to run cfn-init'\n", "# Start up the cfn-hup daemon to listen for changes to the Web Server metadata\n", "/opt/aws/bin/cfn-hup || error_exit 'Failed to start cfn-hup'\n", "# All done so signal success\n", - "/opt/aws/bin/cfn-signal -e 0 -r \"WebServer setup complete\" '", - { - "Ref": "WebServerWaitHandle" - }, - "'\n" - ] + '/opt/aws/bin/cfn-signal -e 0 -r "WebServer setup complete" \'', + {"Ref": "WebServerWaitHandle"}, + "'\n", + ], ] } }, "Tags": [ - { - "Value": { - "Ref": "AWS::StackId" - }, - "Key": "Application" - }, - { - "Value": "Bar", - "Key": "Foo" - } + {"Value": {"Ref": "AWS::StackId"}, "Key": "Application"}, + {"Value": "Bar", "Key": "Foo"}, ], - "SecurityGroupIds": [ - { - "Ref": "InstanceSecurityGroup" - } - ], - "KeyName": { - "Ref": "KeyName" - }, - "SubnetId": { - "Ref": "Subnet" - }, + "SecurityGroupIds": [{"Ref": "InstanceSecurityGroup"}], + "KeyName": {"Ref": "KeyName"}, + "SubnetId": {"Ref": "Subnet"}, "ImageId": { - "Fn::FindInMap": [ - "RegionMap", - { - "Ref": "AWS::Region" - }, - "AMI" - ] + "Fn::FindInMap": ["RegionMap", {"Ref": "AWS::Region"}, "AMI"] }, - "InstanceType": { - "Ref": "InstanceType" - } + "InstanceType": {"Ref": "InstanceType"}, }, "Metadata": { "Comment": "Install a simple PHP application", @@ -278,21 +187,17 @@ template = { [ "[main]\n", "stack=", - { - "Ref": "AWS::StackId" - }, + {"Ref": "AWS::StackId"}, "\n", "region=", - { - "Ref": "AWS::Region" - }, - "\n" - ] + {"Ref": "AWS::Region"}, + "\n", + ], ] }, "owner": "root", "group": "root", - "mode": "000400" + "mode": "000400", }, "/etc/cfn/hooks.d/cfn-auto-reloader.conf": { "content": { @@ -303,17 +208,13 @@ template = { "triggers=post.update\n", "path=Resources.WebServerInstance.Metadata.AWS::CloudFormation::Init\n", "action=/opt/aws/bin/cfn-init -s ", - { - "Ref": "AWS::StackId" - }, + {"Ref": "AWS::StackId"}, " -r WebServerInstance ", " --region ", - { - "Ref": "AWS::Region" - }, + {"Ref": "AWS::Region"}, "\n", - "runas=root\n" - ] + "runas=root\n", + ], ] } }, @@ -324,85 +225,52 @@ template = { [ "AWS CloudFormation sample PHP application';\n", - "?>\n" - ] + "?>\n", + ], ] }, "owner": "apache", "group": "apache", - "mode": "000644" - } + "mode": "000644", + }, }, "services": { "sysvinit": { - "httpd": { - "ensureRunning": "true", - "enabled": "true" - }, + "httpd": {"ensureRunning": "true", "enabled": "true"}, "sendmail": { "ensureRunning": "false", - "enabled": "false" - } + "enabled": "false", + }, } }, - "packages": { - "yum": { - "httpd": [], - "php": [] - } - } + "packages": {"yum": {"httpd": [], "php": []}}, } - } - } + }, + }, }, "IPAddress": { "Type": "AWS::EC2::EIP", - "Properties": { - "InstanceId": { - "Ref": "WebServerInstance" - }, - "Domain": "vpc" - }, - "DependsOn": "AttachGateway" + "Properties": {"InstanceId": {"Ref": "WebServerInstance"}, "Domain": "vpc"}, + "DependsOn": "AttachGateway", }, "AttachGateway": { "Type": "AWS::EC2::VPCGatewayAttachment", "Properties": { - "VpcId": { - "Ref": "VPC" - }, - "InternetGatewayId": { - "Ref": "InternetGateway" - } - } - } + "VpcId": {"Ref": "VPC"}, + "InternetGatewayId": {"Ref": "InternetGateway"}, + }, + }, }, "Mappings": { "RegionMap": { - "ap-southeast-1": { - "AMI": "ami-74dda626" - }, - "ap-southeast-2": { - "AMI": "ami-b3990e89" - }, - "us-west-2": { - "AMI": "ami-16fd7026" - }, - "us-east-1": { - "AMI": "ami-7f418316" - }, - "ap-northeast-1": { - "AMI": "ami-dcfa4edd" - }, - "us-west-1": { - "AMI": "ami-951945d0" - }, - "eu-west-1": { - "AMI": "ami-24506250" - }, - "sa-east-1": { - "AMI": "ami-3e3be423" - } + "ap-southeast-1": {"AMI": "ami-74dda626"}, + "ap-southeast-2": {"AMI": "ami-b3990e89"}, + "us-west-2": {"AMI": "ami-16fd7026"}, + "us-east-1": {"AMI": "ami-7f418316"}, + "ap-northeast-1": {"AMI": "ami-dcfa4edd"}, + "us-west-1": {"AMI": "ami-951945d0"}, + "eu-west-1": {"AMI": "ami-24506250"}, + "sa-east-1": {"AMI": "ami-3e3be423"}, } - } + }, } diff --git a/tests/test_cloudformation/test_cloudformation_stack_crud.py b/tests/test_cloudformation/test_cloudformation_stack_crud.py index 27424bf8c..bfd66935a 100644 --- a/tests/test_cloudformation/test_cloudformation_stack_crud.py +++ b/tests/test_cloudformation/test_cloudformation_stack_crud.py @@ -9,11 +9,16 @@ import boto.s3.key import boto.cloudformation from boto.exception import BotoServerError import sure # noqa + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises # noqa from nose.tools import assert_raises -from moto import mock_cloudformation_deprecated, mock_s3_deprecated, mock_route53_deprecated +from moto import ( + mock_cloudformation_deprecated, + mock_s3_deprecated, + mock_route53_deprecated, +) from moto.cloudformation import cloudformation_backends dummy_template = { @@ -33,12 +38,7 @@ dummy_template3 = { "AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack 3", "Resources": { - "VPC": { - "Properties": { - "CidrBlock": "192.168.0.0/16", - }, - "Type": "AWS::EC2::VPC" - } + "VPC": {"Properties": {"CidrBlock": "192.168.0.0/16"}, "Type": "AWS::EC2::VPC"} }, } @@ -50,24 +50,22 @@ dummy_template_json3 = json.dumps(dummy_template3) @mock_cloudformation_deprecated def test_create_stack(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) stack = conn.describe_stacks()[0] - stack.stack_name.should.equal('test_stack') - stack.get_template().should.equal({ - 'GetTemplateResponse': { - 'GetTemplateResult': { - 'TemplateBody': dummy_template_json, - 'ResponseMetadata': { - 'RequestId': '2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE' + stack.stack_name.should.equal("test_stack") + stack.get_template().should.equal( + { + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": dummy_template_json, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } } - - }) + ) @mock_cloudformation_deprecated @@ -77,44 +75,34 @@ def test_create_stack_hosted_zone_by_id(): dummy_template = { "AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack 1", - "Parameters": { - }, + "Parameters": {}, "Resources": { "Bar": { - "Type" : "AWS::Route53::HostedZone", - "Properties" : { - "Name" : "foo.bar.baz", - } - }, + "Type": "AWS::Route53::HostedZone", + "Properties": {"Name": "foo.bar.baz"}, + } }, } dummy_template2 = { "AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack 2", - "Parameters": { - "ZoneId": { "Type": "String" } - }, + "Parameters": {"ZoneId": {"Type": "String"}}, "Resources": { "Foo": { - "Properties": { - "HostedZoneId": {"Ref": "ZoneId"}, - "RecordSets": [] - }, - "Type": "AWS::Route53::RecordSetGroup" + "Properties": {"HostedZoneId": {"Ref": "ZoneId"}, "RecordSets": []}, + "Type": "AWS::Route53::RecordSetGroup", } }, } conn.create_stack( - "test_stack", - template_body=json.dumps(dummy_template), - parameters={}.items() + "test_stack", template_body=json.dumps(dummy_template), parameters={}.items() ) r53_conn = boto.connect_route53() zone_id = r53_conn.get_zones()[0].id conn.create_stack( "test_stack", template_body=json.dumps(dummy_template2), - parameters={"ZoneId": zone_id}.items() + parameters={"ZoneId": zone_id}.items(), ) stack = conn.describe_stacks()[0] @@ -139,62 +127,57 @@ def test_create_stack_with_notification_arn(): conn.create_stack( "test_stack_with_notifications", template_body=dummy_template_json, - notification_arns='arn:aws:sns:us-east-1:123456789012:fake-queue' + notification_arns="arn:aws:sns:us-east-1:123456789012:fake-queue", ) stack = conn.describe_stacks()[0] [n.value for n in stack.notification_arns].should.contain( - 'arn:aws:sns:us-east-1:123456789012:fake-queue') + "arn:aws:sns:us-east-1:123456789012:fake-queue" + ) @mock_cloudformation_deprecated @mock_s3_deprecated def test_create_stack_from_s3_url(): - s3_conn = boto.s3.connect_to_region('us-west-1') + s3_conn = boto.s3.connect_to_region("us-west-1") bucket = s3_conn.create_bucket("foobar") key = boto.s3.key.Key(bucket) key.key = "template-key" key.set_contents_from_string(dummy_template_json) key_url = key.generate_url(expires_in=0, query_auth=False) - conn = boto.cloudformation.connect_to_region('us-west-1') - conn.create_stack('new-stack', template_url=key_url) + conn = boto.cloudformation.connect_to_region("us-west-1") + conn.create_stack("new-stack", template_url=key_url) stack = conn.describe_stacks()[0] - stack.stack_name.should.equal('new-stack') + stack.stack_name.should.equal("new-stack") stack.get_template().should.equal( { - 'GetTemplateResponse': { - 'GetTemplateResult': { - 'TemplateBody': dummy_template_json, - 'ResponseMetadata': { - 'RequestId': '2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE' - } + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": dummy_template_json, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } - - }) + } + ) @mock_cloudformation_deprecated def test_describe_stack_by_name(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) stack = conn.describe_stacks("test_stack")[0] - stack.stack_name.should.equal('test_stack') + stack.stack_name.should.equal("test_stack") @mock_cloudformation_deprecated def test_describe_stack_by_stack_id(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) stack = conn.describe_stacks("test_stack")[0] stack_by_id = conn.describe_stacks(stack.stack_id)[0] @@ -205,10 +188,7 @@ def test_describe_stack_by_stack_id(): @mock_cloudformation_deprecated def test_describe_deleted_stack(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) stack = conn.describe_stacks("test_stack")[0] stack_id = stack.stack_id @@ -222,36 +202,28 @@ def test_describe_deleted_stack(): @mock_cloudformation_deprecated def test_get_template_by_name(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) template = conn.get_template("test_stack") - template.should.equal({ - 'GetTemplateResponse': { - 'GetTemplateResult': { - 'TemplateBody': dummy_template_json, - 'ResponseMetadata': { - 'RequestId': '2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE' + template.should.equal( + { + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": dummy_template_json, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } } - - }) + ) @mock_cloudformation_deprecated def test_list_stacks(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) - conn.create_stack( - "test_stack2", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) + conn.create_stack("test_stack2", template_body=dummy_template_json) stacks = conn.list_stacks() stacks.should.have.length_of(2) @@ -261,10 +233,7 @@ def test_list_stacks(): @mock_cloudformation_deprecated def test_delete_stack_by_name(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) conn.describe_stacks().should.have.length_of(1) conn.delete_stack("test_stack") @@ -274,10 +243,7 @@ def test_delete_stack_by_name(): @mock_cloudformation_deprecated def test_delete_stack_by_id(): conn = boto.connect_cloudformation() - stack_id = conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + stack_id = conn.create_stack("test_stack", template_body=dummy_template_json) conn.describe_stacks().should.have.length_of(1) conn.delete_stack(stack_id) @@ -291,10 +257,7 @@ def test_delete_stack_by_id(): @mock_cloudformation_deprecated def test_delete_stack_with_resource_missing_delete_attr(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json3, - ) + conn.create_stack("test_stack", template_body=dummy_template_json3) conn.describe_stacks().should.have.length_of(1) conn.delete_stack("test_stack") @@ -318,19 +281,22 @@ def test_cloudformation_params(): "APPNAME": { "Default": "app-name", "Description": "The name of the app", - "Type": "String" + "Type": "String", } - } + }, } dummy_template_json = json.dumps(dummy_template) cfn = boto.connect_cloudformation() - cfn.create_stack('test_stack1', template_body=dummy_template_json, parameters=[ - ('APPNAME', 'testing123')]) - stack = cfn.describe_stacks('test_stack1')[0] + cfn.create_stack( + "test_stack1", + template_body=dummy_template_json, + parameters=[("APPNAME", "testing123")], + ) + stack = cfn.describe_stacks("test_stack1")[0] stack.parameters.should.have.length_of(1) param = stack.parameters[0] - param.key.should.equal('APPNAME') - param.value.should.equal('testing123') + param.key.should.equal("APPNAME") + param.value.should.equal("testing123") @mock_cloudformation_deprecated @@ -339,52 +305,34 @@ def test_cloudformation_params_conditions_and_resources_are_distinct(): "AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack 1", "Conditions": { - "FooEnabled": { - "Fn::Equals": [ - { - "Ref": "FooEnabled" - }, - "true" - ] - }, + "FooEnabled": {"Fn::Equals": [{"Ref": "FooEnabled"}, "true"]}, "FooDisabled": { - "Fn::Not": [ - { - "Fn::Equals": [ - { - "Ref": "FooEnabled" - }, - "true" - ] - } - ] - } + "Fn::Not": [{"Fn::Equals": [{"Ref": "FooEnabled"}, "true"]}] + }, }, "Parameters": { - "FooEnabled": { - "Type": "String", - "AllowedValues": [ - "true", - "false" - ] - } + "FooEnabled": {"Type": "String", "AllowedValues": ["true", "false"]} }, "Resources": { "Bar": { - "Properties": { - "CidrBlock": "192.168.0.0/16", - }, + "Properties": {"CidrBlock": "192.168.0.0/16"}, "Condition": "FooDisabled", - "Type": "AWS::EC2::VPC" + "Type": "AWS::EC2::VPC", } - } + }, } dummy_template_json = json.dumps(dummy_template) cfn = boto.connect_cloudformation() - cfn.create_stack('test_stack1', template_body=dummy_template_json, parameters=[('FooEnabled', 'true')]) - stack = cfn.describe_stacks('test_stack1')[0] + cfn.create_stack( + "test_stack1", + template_body=dummy_template_json, + parameters=[("FooEnabled", "true")], + ) + stack = cfn.describe_stacks("test_stack1")[0] resources = stack.list_resources() - assert not [resource for resource in resources if resource.logical_resource_id == 'Bar'] + assert not [ + resource for resource in resources if resource.logical_resource_id == "Bar" + ] @mock_cloudformation_deprecated @@ -403,48 +351,46 @@ def test_stack_tags(): @mock_cloudformation_deprecated def test_update_stack(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) conn.update_stack("test_stack", dummy_template_json2) stack = conn.describe_stacks()[0] stack.stack_status.should.equal("UPDATE_COMPLETE") - stack.get_template().should.equal({ - 'GetTemplateResponse': { - 'GetTemplateResult': { - 'TemplateBody': dummy_template_json2, - 'ResponseMetadata': { - 'RequestId': '2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE' + stack.get_template().should.equal( + { + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": dummy_template_json2, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } } - }) + ) @mock_cloudformation_deprecated def test_update_stack_with_previous_template(): conn = boto.connect_cloudformation() - conn.create_stack( - "test_stack", - template_body=dummy_template_json, - ) + conn.create_stack("test_stack", template_body=dummy_template_json) conn.update_stack("test_stack", use_previous_template=True) stack = conn.describe_stacks()[0] stack.stack_status.should.equal("UPDATE_COMPLETE") - stack.get_template().should.equal({ - 'GetTemplateResponse': { - 'GetTemplateResult': { - 'TemplateBody': dummy_template_json, - 'ResponseMetadata': { - 'RequestId': '2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE' + stack.get_template().should.equal( + { + "GetTemplateResponse": { + "GetTemplateResult": { + "TemplateBody": dummy_template_json, + "ResponseMetadata": { + "RequestId": "2d06e36c-ac1d-11e0-a958-f9382b6eb86bEXAMPLE" + }, } } } - }) + ) @mock_cloudformation_deprecated @@ -454,29 +400,23 @@ def test_update_stack_with_parameters(): "Description": "Stack", "Resources": { "VPC": { - "Properties": { - "CidrBlock": {"Ref": "Bar"} - }, - "Type": "AWS::EC2::VPC" + "Properties": {"CidrBlock": {"Ref": "Bar"}}, + "Type": "AWS::EC2::VPC", } }, - "Parameters": { - "Bar": { - "Type": "String" - } - } + "Parameters": {"Bar": {"Type": "String"}}, } dummy_template_json = json.dumps(dummy_template) conn = boto.connect_cloudformation() conn.create_stack( "test_stack", template_body=dummy_template_json, - parameters=[("Bar", "192.168.0.0/16")] + parameters=[("Bar", "192.168.0.0/16")], ) conn.update_stack( "test_stack", template_body=dummy_template_json, - parameters=[("Bar", "192.168.0.1/16")] + parameters=[("Bar", "192.168.0.1/16")], ) stack = conn.describe_stacks()[0] @@ -487,14 +427,10 @@ def test_update_stack_with_parameters(): def test_update_stack_replace_tags(): conn = boto.connect_cloudformation() conn.create_stack( - "test_stack", - template_body=dummy_template_json, - tags={"foo": "bar"}, + "test_stack", template_body=dummy_template_json, tags={"foo": "bar"} ) conn.update_stack( - "test_stack", - template_body=dummy_template_json, - tags={"foo": "baz"}, + "test_stack", template_body=dummy_template_json, tags={"foo": "baz"} ) stack = conn.describe_stacks()[0] @@ -506,28 +442,26 @@ def test_update_stack_replace_tags(): @mock_cloudformation_deprecated def test_update_stack_when_rolled_back(): conn = boto.connect_cloudformation() - stack_id = conn.create_stack( - "test_stack", template_body=dummy_template_json) + stack_id = conn.create_stack("test_stack", template_body=dummy_template_json) cloudformation_backends[conn.region.name].stacks[ - stack_id].status = 'ROLLBACK_COMPLETE' + stack_id + ].status = "ROLLBACK_COMPLETE" with assert_raises(BotoServerError) as err: conn.update_stack("test_stack", dummy_template_json) ex = err.exception - ex.body.should.match( - r'is in ROLLBACK_COMPLETE state and can not be updated') - ex.error_code.should.equal('ValidationError') - ex.reason.should.equal('Bad Request') + ex.body.should.match(r"is in ROLLBACK_COMPLETE state and can not be updated") + ex.error_code.should.equal("ValidationError") + ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @mock_cloudformation_deprecated def test_describe_stack_events_shows_create_update_and_delete(): conn = boto.connect_cloudformation() - stack_id = conn.create_stack( - "test_stack", template_body=dummy_template_json) + stack_id = conn.create_stack("test_stack", template_body=dummy_template_json) conn.update_stack(stack_id, template_body=dummy_template_json2) conn.delete_stack(stack_id) @@ -538,14 +472,16 @@ def test_describe_stack_events_shows_create_update_and_delete(): # testing ordering of stack events without assuming resource events will not exist # the AWS API returns events in reverse chronological order - stack_events_to_look_for = iter([ - ("DELETE_COMPLETE", None), - ("DELETE_IN_PROGRESS", "User Initiated"), - ("UPDATE_COMPLETE", None), - ("UPDATE_IN_PROGRESS", "User Initiated"), - ("CREATE_COMPLETE", None), - ("CREATE_IN_PROGRESS", "User Initiated"), - ]) + stack_events_to_look_for = iter( + [ + ("DELETE_COMPLETE", None), + ("DELETE_IN_PROGRESS", "User Initiated"), + ("UPDATE_COMPLETE", None), + ("UPDATE_IN_PROGRESS", "User Initiated"), + ("CREATE_COMPLETE", None), + ("CREATE_IN_PROGRESS", "User Initiated"), + ] + ) try: for event in events: event.stack_id.should.equal(stack_id) @@ -556,12 +492,10 @@ def test_describe_stack_events_shows_create_update_and_delete(): event.logical_resource_id.should.equal("test_stack") event.physical_resource_id.should.equal(stack_id) - status_to_look_for, reason_to_look_for = next( - stack_events_to_look_for) + status_to_look_for, reason_to_look_for = next(stack_events_to_look_for) event.resource_status.should.equal(status_to_look_for) if reason_to_look_for is not None: - event.resource_status_reason.should.equal( - reason_to_look_for) + event.resource_status_reason.should.equal(reason_to_look_for) except StopIteration: assert False, "Too many stack events" @@ -574,74 +508,60 @@ def test_create_stack_lambda_and_dynamodb(): dummy_template = { "AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack Lambda Test 1", - "Parameters": { - }, + "Parameters": {}, "Resources": { "func1": { - "Type" : "AWS::Lambda::Function", - "Properties" : { - "Code": { - "S3Bucket": "bucket_123", - "S3Key": "key_123" - }, + "Type": "AWS::Lambda::Function", + "Properties": { + "Code": {"S3Bucket": "bucket_123", "S3Key": "key_123"}, "FunctionName": "func1", "Handler": "handler.handler", "Role": "role1", "Runtime": "python2.7", "Description": "descr", "MemorySize": 12345, - } + }, }, "func1version": { "Type": "AWS::Lambda::Version", - "Properties": { - "FunctionName": { - "Ref": "func1" - } - } + "Properties": {"FunctionName": {"Ref": "func1"}}, }, "tab1": { - "Type" : "AWS::DynamoDB::Table", - "Properties" : { + "Type": "AWS::DynamoDB::Table", + "Properties": { "TableName": "tab1", - "KeySchema": [{ - "AttributeName": "attr1", - "KeyType": "HASH" - }], - "AttributeDefinitions": [{ - "AttributeName": "attr1", - "AttributeType": "string" - }], + "KeySchema": [{"AttributeName": "attr1", "KeyType": "HASH"}], + "AttributeDefinitions": [ + {"AttributeName": "attr1", "AttributeType": "string"} + ], "ProvisionedThroughput": { "ReadCapacityUnits": 10, - "WriteCapacityUnits": 10 - } - } + "WriteCapacityUnits": 10, + }, + }, }, "func1mapping": { "Type": "AWS::Lambda::EventSourceMapping", "Properties": { - "FunctionName": { - "Ref": "func1" - }, + "FunctionName": {"Ref": "func1"}, "EventSourceArn": "arn:aws:dynamodb:region:XXXXXX:table/tab1/stream/2000T00:00:00.000", "StartingPosition": "0", "BatchSize": 100, - "Enabled": True - } - } + "Enabled": True, + }, + }, }, } - validate_s3_before = os.environ.get('VALIDATE_LAMBDA_S3', '') + validate_s3_before = os.environ.get("VALIDATE_LAMBDA_S3", "") try: - os.environ['VALIDATE_LAMBDA_S3'] = 'false' + os.environ["VALIDATE_LAMBDA_S3"] = "false" conn.create_stack( "test_stack_lambda_1", template_body=json.dumps(dummy_template), - parameters={}.items() + parameters={}.items(), ) finally: - os.environ['VALIDATE_LAMBDA_S3'] = validate_s3_before + os.environ["VALIDATE_LAMBDA_S3"] = validate_s3_before stack = conn.describe_stacks()[0] resources = stack.list_resources() @@ -657,18 +577,15 @@ def test_create_stack_kinesis(): "Parameters": {}, "Resources": { "stream1": { - "Type" : "AWS::Kinesis::Stream", - "Properties" : { - "Name": "stream1", - "ShardCount": 2 - } + "Type": "AWS::Kinesis::Stream", + "Properties": {"Name": "stream1", "ShardCount": 2}, } - } + }, } conn.create_stack( "test_stack_kinesis_1", template_body=json.dumps(dummy_template), - parameters={}.items() + parameters={}.items(), ) stack = conn.describe_stacks()[0] diff --git a/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py b/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py index d05bc1b53..28d68fa20 100644 --- a/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py +++ b/tests/test_cloudformation/test_cloudformation_stack_crud_boto3.py @@ -6,6 +6,7 @@ from collections import OrderedDict import boto3 from botocore.exceptions import ClientError import sure # noqa + # Ensure 'assert_raises' context manager support for Python 2.6 from nose.tools import assert_raises @@ -22,18 +23,12 @@ dummy_template = { "KeyName": "dummy", "InstanceType": "t2.micro", "Tags": [ - { - "Key": "Description", - "Value": "Test tag" - }, - { - "Key": "Name", - "Value": "Name tag for tests" - } - ] - } + {"Key": "Description", "Value": "Test tag"}, + {"Key": "Name", "Value": "Name tag for tests"}, + ], + }, } - } + }, } dummy_template_yaml = """--- @@ -100,17 +95,15 @@ dummy_update_template = { "KeyName": { "Description": "Name of an existing EC2 KeyPair", "Type": "AWS::EC2::KeyPair::KeyName", - "ConstraintDescription": "must be the name of an existing EC2 KeyPair." + "ConstraintDescription": "must be the name of an existing EC2 KeyPair.", } }, "Resources": { "Instance": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-08111162" - } + "Properties": {"ImageId": "ami-08111162"}, } - } + }, } dummy_output_template = { @@ -119,20 +112,16 @@ dummy_output_template = { "Resources": { "Instance": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-08111162" - } + "Properties": {"ImageId": "ami-08111162"}, } }, "Outputs": { "StackVPC": { "Description": "The ID of the VPC", "Value": "VPCID", - "Export": { - "Name": "My VPC ID" - } + "Export": {"Name": "My VPC ID"}, } - } + }, } dummy_import_template = { @@ -141,11 +130,11 @@ dummy_import_template = { "Queue": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::ImportValue": 'My VPC ID'}, + "QueueName": {"Fn::ImportValue": "My VPC ID"}, "VisibilityTimeout": 60, - } + }, } - } + }, } dummy_redrive_template = { @@ -158,23 +147,16 @@ dummy_redrive_template = { "FifoQueue": True, "ContentBasedDeduplication": False, "RedrivePolicy": { - "deadLetterTargetArn": { - "Fn::GetAtt": [ - "DeadLetterQueue", - "Arn" - ] - }, - "maxReceiveCount": 5 - } - } + "deadLetterTargetArn": {"Fn::GetAtt": ["DeadLetterQueue", "Arn"]}, + "maxReceiveCount": 5, + }, + }, }, "DeadLetterQueue": { "Type": "AWS::SQS::Queue", - "Properties": { - "FifoQueue": True - } + "Properties": {"FifoQueue": True}, }, - } + }, } dummy_template_json = json.dumps(dummy_template) @@ -186,43 +168,48 @@ dummy_redrive_template_json = json.dumps(dummy_redrive_template) @mock_cloudformation def test_boto3_describe_stack_instances(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=["123456789012"], + Regions=["us-east-1", "us-west-2"], ) usw2_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-west-2', + StackInstanceAccount="123456789012", + StackInstanceRegion="us-west-2", ) use1_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-east-1', + StackInstanceAccount="123456789012", + StackInstanceRegion="us-east-1", ) - usw2_instance['StackInstance'].should.have.key('Region').which.should.equal('us-west-2') - usw2_instance['StackInstance'].should.have.key('Account').which.should.equal('123456789012') - use1_instance['StackInstance'].should.have.key('Region').which.should.equal('us-east-1') - use1_instance['StackInstance'].should.have.key('Account').which.should.equal('123456789012') + usw2_instance["StackInstance"].should.have.key("Region").which.should.equal( + "us-west-2" + ) + usw2_instance["StackInstance"].should.have.key("Account").which.should.equal( + "123456789012" + ) + use1_instance["StackInstance"].should.have.key("Region").which.should.equal( + "us-east-1" + ) + use1_instance["StackInstance"].should.have.key("Account").which.should.equal( + "123456789012" + ) @mock_cloudformation def test_boto3_list_stacksets_length(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_set( - StackSetName="test_stack_set2", - TemplateBody=dummy_template_yaml, + StackSetName="test_stack_set2", TemplateBody=dummy_template_yaml ) stacksets = cf_conn.list_stack_sets() stacksets.should.have.length_of(2) @@ -230,106 +217,102 @@ def test_boto3_list_stacksets_length(): @mock_cloudformation def test_boto3_list_stacksets_contents(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) stacksets = cf_conn.list_stack_sets() - stacksets['Summaries'][0].should.have.key('StackSetName').which.should.equal('test_stack_set') - stacksets['Summaries'][0].should.have.key('Status').which.should.equal('ACTIVE') + stacksets["Summaries"][0].should.have.key("StackSetName").which.should.equal( + "test_stack_set" + ) + stacksets["Summaries"][0].should.have.key("Status").which.should.equal("ACTIVE") @mock_cloudformation def test_boto3_stop_stack_set_operation(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-1', 'us-west-2'], + Accounts=["123456789012"], + Regions=["us-east-1", "us-west-1", "us-west-2"], ) - operation_id = cf_conn.list_stack_set_operations( - StackSetName="test_stack_set")['Summaries'][-1]['OperationId'] + operation_id = cf_conn.list_stack_set_operations(StackSetName="test_stack_set")[ + "Summaries" + ][-1]["OperationId"] cf_conn.stop_stack_set_operation( - StackSetName="test_stack_set", - OperationId=operation_id + StackSetName="test_stack_set", OperationId=operation_id ) - list_operation = cf_conn.list_stack_set_operations( - StackSetName="test_stack_set" - ) - list_operation['Summaries'][-1]['Status'].should.equal('STOPPED') + list_operation = cf_conn.list_stack_set_operations(StackSetName="test_stack_set") + list_operation["Summaries"][-1]["Status"].should.equal("STOPPED") @mock_cloudformation def test_boto3_describe_stack_set_operation(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-1', 'us-west-2'], + Accounts=["123456789012"], + Regions=["us-east-1", "us-west-1", "us-west-2"], ) - operation_id = cf_conn.list_stack_set_operations( - StackSetName="test_stack_set")['Summaries'][-1]['OperationId'] + operation_id = cf_conn.list_stack_set_operations(StackSetName="test_stack_set")[ + "Summaries" + ][-1]["OperationId"] cf_conn.stop_stack_set_operation( - StackSetName="test_stack_set", - OperationId=operation_id + StackSetName="test_stack_set", OperationId=operation_id ) response = cf_conn.describe_stack_set_operation( - StackSetName="test_stack_set", - OperationId=operation_id, + StackSetName="test_stack_set", OperationId=operation_id ) - response['StackSetOperation']['Status'].should.equal('STOPPED') - response['StackSetOperation']['Action'].should.equal('CREATE') + response["StackSetOperation"]["Status"].should.equal("STOPPED") + response["StackSetOperation"]["Action"].should.equal("CREATE") @mock_cloudformation def test_boto3_list_stack_set_operation_results(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-1', 'us-west-2'], + Accounts=["123456789012"], + Regions=["us-east-1", "us-west-1", "us-west-2"], ) - operation_id = cf_conn.list_stack_set_operations( - StackSetName="test_stack_set")['Summaries'][-1]['OperationId'] + operation_id = cf_conn.list_stack_set_operations(StackSetName="test_stack_set")[ + "Summaries" + ][-1]["OperationId"] cf_conn.stop_stack_set_operation( - StackSetName="test_stack_set", - OperationId=operation_id + StackSetName="test_stack_set", OperationId=operation_id ) response = cf_conn.list_stack_set_operation_results( - StackSetName="test_stack_set", - OperationId=operation_id, + StackSetName="test_stack_set", OperationId=operation_id ) - response['Summaries'].should.have.length_of(3) - response['Summaries'][0].should.have.key('Account').which.should.equal('123456789012') - response['Summaries'][1].should.have.key('Status').which.should.equal('STOPPED') + response["Summaries"].should.have.length_of(3) + response["Summaries"][0].should.have.key("Account").which.should.equal( + "123456789012" + ) + response["Summaries"][1].should.have.key("Status").which.should.equal("STOPPED") @mock_cloudformation def test_boto3_update_stack_instances(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") param = [ - {'ParameterKey': 'SomeParam', 'ParameterValue': 'StackSetValue'}, - {'ParameterKey': 'AnotherParam', 'ParameterValue': 'StackSetValue2'}, + {"ParameterKey": "SomeParam", "ParameterValue": "StackSetValue"}, + {"ParameterKey": "AnotherParam", "ParameterValue": "StackSetValue2"}, ] param_overrides = [ - {'ParameterKey': 'SomeParam', 'ParameterValue': 'OverrideValue'}, - {'ParameterKey': 'AnotherParam', 'ParameterValue': 'OverrideValue2'} + {"ParameterKey": "SomeParam", "ParameterValue": "OverrideValue"}, + {"ParameterKey": "AnotherParam", "ParameterValue": "OverrideValue2"}, ] cf_conn.create_stack_set( StackSetName="test_stack_set", @@ -338,97 +321,117 @@ def test_boto3_update_stack_instances(): ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-1', 'us-west-2'], + Accounts=["123456789012"], + Regions=["us-east-1", "us-west-1", "us-west-2"], ) cf_conn.update_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-west-1', 'us-west-2'], + Accounts=["123456789012"], + Regions=["us-west-1", "us-west-2"], ParameterOverrides=param_overrides, ) usw2_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-west-2', + StackInstanceAccount="123456789012", + StackInstanceRegion="us-west-2", ) usw1_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-west-1', + StackInstanceAccount="123456789012", + StackInstanceRegion="us-west-1", ) use1_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-east-1', + StackInstanceAccount="123456789012", + StackInstanceRegion="us-east-1", ) - usw2_instance['StackInstance']['ParameterOverrides'][0]['ParameterKey'].should.equal(param_overrides[0]['ParameterKey']) - usw2_instance['StackInstance']['ParameterOverrides'][0]['ParameterValue'].should.equal(param_overrides[0]['ParameterValue']) - usw2_instance['StackInstance']['ParameterOverrides'][1]['ParameterKey'].should.equal(param_overrides[1]['ParameterKey']) - usw2_instance['StackInstance']['ParameterOverrides'][1]['ParameterValue'].should.equal(param_overrides[1]['ParameterValue']) + usw2_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterKey" + ].should.equal(param_overrides[0]["ParameterKey"]) + usw2_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterValue" + ].should.equal(param_overrides[0]["ParameterValue"]) + usw2_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterKey" + ].should.equal(param_overrides[1]["ParameterKey"]) + usw2_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterValue" + ].should.equal(param_overrides[1]["ParameterValue"]) - usw1_instance['StackInstance']['ParameterOverrides'][0]['ParameterKey'].should.equal(param_overrides[0]['ParameterKey']) - usw1_instance['StackInstance']['ParameterOverrides'][0]['ParameterValue'].should.equal(param_overrides[0]['ParameterValue']) - usw1_instance['StackInstance']['ParameterOverrides'][1]['ParameterKey'].should.equal(param_overrides[1]['ParameterKey']) - usw1_instance['StackInstance']['ParameterOverrides'][1]['ParameterValue'].should.equal(param_overrides[1]['ParameterValue']) + usw1_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterKey" + ].should.equal(param_overrides[0]["ParameterKey"]) + usw1_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterValue" + ].should.equal(param_overrides[0]["ParameterValue"]) + usw1_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterKey" + ].should.equal(param_overrides[1]["ParameterKey"]) + usw1_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterValue" + ].should.equal(param_overrides[1]["ParameterValue"]) - use1_instance['StackInstance']['ParameterOverrides'].should.be.empty + use1_instance["StackInstance"]["ParameterOverrides"].should.be.empty @mock_cloudformation def test_boto3_delete_stack_instances(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=["123456789012"], + Regions=["us-east-1", "us-west-2"], ) cf_conn.delete_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1'], + Accounts=["123456789012"], + Regions=["us-east-1"], RetainStacks=False, ) - cf_conn.list_stack_instances(StackSetName="test_stack_set")['Summaries'].should.have.length_of(1) - cf_conn.list_stack_instances(StackSetName="test_stack_set")['Summaries'][0]['Region'].should.equal( - 'us-west-2') + cf_conn.list_stack_instances(StackSetName="test_stack_set")[ + "Summaries" + ].should.have.length_of(1) + cf_conn.list_stack_instances(StackSetName="test_stack_set")["Summaries"][0][ + "Region" + ].should.equal("us-west-2") @mock_cloudformation def test_boto3_create_stack_instances(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=["123456789012"], + Regions=["us-east-1", "us-west-2"], ) - cf_conn.list_stack_instances(StackSetName="test_stack_set")['Summaries'].should.have.length_of(2) - cf_conn.list_stack_instances(StackSetName="test_stack_set")['Summaries'][0]['Account'].should.equal( - '123456789012') + cf_conn.list_stack_instances(StackSetName="test_stack_set")[ + "Summaries" + ].should.have.length_of(2) + cf_conn.list_stack_instances(StackSetName="test_stack_set")["Summaries"][0][ + "Account" + ].should.equal("123456789012") @mock_cloudformation def test_boto3_create_stack_instances_with_param_overrides(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") param = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'StackSetValue'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'StackSetValue2'}, + {"ParameterKey": "TagDescription", "ParameterValue": "StackSetValue"}, + {"ParameterKey": "TagName", "ParameterValue": "StackSetValue2"}, ] param_overrides = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'OverrideValue'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'OverrideValue2'} + {"ParameterKey": "TagDescription", "ParameterValue": "OverrideValue"}, + {"ParameterKey": "TagName", "ParameterValue": "OverrideValue2"}, ] cf_conn.create_stack_set( StackSetName="test_stack_set", @@ -437,32 +440,40 @@ def test_boto3_create_stack_instances_with_param_overrides(): ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=["123456789012"], + Regions=["us-east-1", "us-west-2"], ParameterOverrides=param_overrides, ) usw2_instance = cf_conn.describe_stack_instance( StackSetName="test_stack_set", - StackInstanceAccount='123456789012', - StackInstanceRegion='us-west-2', + StackInstanceAccount="123456789012", + StackInstanceRegion="us-west-2", ) - usw2_instance['StackInstance']['ParameterOverrides'][0]['ParameterKey'].should.equal(param_overrides[0]['ParameterKey']) - usw2_instance['StackInstance']['ParameterOverrides'][1]['ParameterKey'].should.equal(param_overrides[1]['ParameterKey']) - usw2_instance['StackInstance']['ParameterOverrides'][0]['ParameterValue'].should.equal(param_overrides[0]['ParameterValue']) - usw2_instance['StackInstance']['ParameterOverrides'][1]['ParameterValue'].should.equal(param_overrides[1]['ParameterValue']) + usw2_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterKey" + ].should.equal(param_overrides[0]["ParameterKey"]) + usw2_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterKey" + ].should.equal(param_overrides[1]["ParameterKey"]) + usw2_instance["StackInstance"]["ParameterOverrides"][0][ + "ParameterValue" + ].should.equal(param_overrides[0]["ParameterValue"]) + usw2_instance["StackInstance"]["ParameterOverrides"][1][ + "ParameterValue" + ].should.equal(param_overrides[1]["ParameterValue"]) @mock_cloudformation def test_update_stack_set(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") param = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'StackSetValue'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'StackSetValue2'}, + {"ParameterKey": "TagDescription", "ParameterValue": "StackSetValue"}, + {"ParameterKey": "TagName", "ParameterValue": "StackSetValue2"}, ] param_overrides = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'OverrideValue'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'OverrideValue2'} + {"ParameterKey": "TagDescription", "ParameterValue": "OverrideValue"}, + {"ParameterKey": "TagName", "ParameterValue": "OverrideValue2"}, ] cf_conn.create_stack_set( StackSetName="test_stack_set", @@ -470,203 +481,196 @@ def test_update_stack_set(): Parameters=param, ) cf_conn.update_stack_set( - StackSetName='test_stack_set', + StackSetName="test_stack_set", TemplateBody=dummy_template_yaml_with_ref, Parameters=param_overrides, ) - stackset = cf_conn.describe_stack_set(StackSetName='test_stack_set') + stackset = cf_conn.describe_stack_set(StackSetName="test_stack_set") - stackset['StackSet']['Parameters'][0]['ParameterValue'].should.equal(param_overrides[0]['ParameterValue']) - stackset['StackSet']['Parameters'][1]['ParameterValue'].should.equal(param_overrides[1]['ParameterValue']) - stackset['StackSet']['Parameters'][0]['ParameterKey'].should.equal(param_overrides[0]['ParameterKey']) - stackset['StackSet']['Parameters'][1]['ParameterKey'].should.equal(param_overrides[1]['ParameterKey']) + stackset["StackSet"]["Parameters"][0]["ParameterValue"].should.equal( + param_overrides[0]["ParameterValue"] + ) + stackset["StackSet"]["Parameters"][1]["ParameterValue"].should.equal( + param_overrides[1]["ParameterValue"] + ) + stackset["StackSet"]["Parameters"][0]["ParameterKey"].should.equal( + param_overrides[0]["ParameterKey"] + ) + stackset["StackSet"]["Parameters"][1]["ParameterKey"].should.equal( + param_overrides[1]["ParameterKey"] + ) @mock_cloudformation def test_boto3_list_stack_set_operations(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) cf_conn.create_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=["123456789012"], + Regions=["us-east-1", "us-west-2"], ) cf_conn.update_stack_instances( StackSetName="test_stack_set", - Accounts=['123456789012'], - Regions=['us-east-1', 'us-west-2'], + Accounts=["123456789012"], + Regions=["us-east-1", "us-west-2"], ) list_operation = cf_conn.list_stack_set_operations(StackSetName="test_stack_set") - list_operation['Summaries'].should.have.length_of(2) - list_operation['Summaries'][-1]['Action'].should.equal('UPDATE') + list_operation["Summaries"].should.have.length_of(2) + list_operation["Summaries"][-1]["Action"].should.equal("UPDATE") @mock_cloudformation def test_boto3_delete_stack_set(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) - cf_conn.delete_stack_set(StackSetName='test_stack_set') + cf_conn.delete_stack_set(StackSetName="test_stack_set") - cf_conn.describe_stack_set(StackSetName="test_stack_set")['StackSet']['Status'].should.equal( - 'DELETED') + cf_conn.describe_stack_set(StackSetName="test_stack_set")["StackSet"][ + "Status" + ].should.equal("DELETED") @mock_cloudformation def test_boto3_create_stack_set(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_json, + StackSetName="test_stack_set", TemplateBody=dummy_template_json ) - cf_conn.describe_stack_set(StackSetName="test_stack_set")['StackSet']['TemplateBody'].should.equal( - dummy_template_json) + cf_conn.describe_stack_set(StackSetName="test_stack_set")["StackSet"][ + "TemplateBody" + ].should.equal(dummy_template_json) @mock_cloudformation def test_boto3_create_stack_set_with_yaml(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack_set( - StackSetName="test_stack_set", - TemplateBody=dummy_template_yaml, + StackSetName="test_stack_set", TemplateBody=dummy_template_yaml ) - cf_conn.describe_stack_set(StackSetName="test_stack_set")['StackSet']['TemplateBody'].should.equal( - dummy_template_yaml) + cf_conn.describe_stack_set(StackSetName="test_stack_set")["StackSet"][ + "TemplateBody" + ].should.equal(dummy_template_yaml) @mock_cloudformation @mock_s3 def test_create_stack_set_from_s3_url(): - s3 = boto3.client('s3') - s3_conn = boto3.resource('s3') + s3 = boto3.client("s3") + s3_conn = boto3.resource("s3") bucket = s3_conn.create_bucket(Bucket="foobar") - key = s3_conn.Object( - 'foobar', 'template-key').put(Body=dummy_template_json) + key = s3_conn.Object("foobar", "template-key").put(Body=dummy_template_json) key_url = s3.generate_presigned_url( - ClientMethod='get_object', - Params={ - 'Bucket': 'foobar', - 'Key': 'template-key' - } + ClientMethod="get_object", Params={"Bucket": "foobar", "Key": "template-key"} ) - cf_conn = boto3.client('cloudformation', region_name='us-west-1') - cf_conn.create_stack_set( - StackSetName='stack_from_url', - TemplateURL=key_url, - ) - cf_conn.describe_stack_set(StackSetName="stack_from_url")['StackSet']['TemplateBody'].should.equal( - dummy_template_json) + cf_conn = boto3.client("cloudformation", region_name="us-west-1") + cf_conn.create_stack_set(StackSetName="stack_from_url", TemplateURL=key_url) + cf_conn.describe_stack_set(StackSetName="stack_from_url")["StackSet"][ + "TemplateBody" + ].should.equal(dummy_template_json) @mock_cloudformation def test_boto3_create_stack_set_with_ref_yaml(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") params = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'desc_ref'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'name_ref'}, + {"ParameterKey": "TagDescription", "ParameterValue": "desc_ref"}, + {"ParameterKey": "TagName", "ParameterValue": "name_ref"}, ] cf_conn.create_stack_set( StackSetName="test_stack", TemplateBody=dummy_template_yaml_with_ref, - Parameters=params + Parameters=params, ) - cf_conn.describe_stack_set(StackSetName="test_stack")['StackSet']['TemplateBody'].should.equal( - dummy_template_yaml_with_ref) + cf_conn.describe_stack_set(StackSetName="test_stack")["StackSet"][ + "TemplateBody" + ].should.equal(dummy_template_yaml_with_ref) @mock_cloudformation def test_boto3_describe_stack_set_params(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") params = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'desc_ref'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'name_ref'}, + {"ParameterKey": "TagDescription", "ParameterValue": "desc_ref"}, + {"ParameterKey": "TagName", "ParameterValue": "name_ref"}, ] cf_conn.create_stack_set( StackSetName="test_stack", TemplateBody=dummy_template_yaml_with_ref, - Parameters=params + Parameters=params, ) - cf_conn.describe_stack_set(StackSetName="test_stack")['StackSet']['Parameters'].should.equal( - params) + cf_conn.describe_stack_set(StackSetName="test_stack")["StackSet"][ + "Parameters" + ].should.equal(params) @mock_cloudformation def test_boto3_create_stack(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - cf_conn.get_template(StackName="test_stack")['TemplateBody'].should.equal( - json.loads(dummy_template_json, object_pairs_hook=OrderedDict)) + cf_conn.get_template(StackName="test_stack")["TemplateBody"].should.equal( + json.loads(dummy_template_json, object_pairs_hook=OrderedDict) + ) @mock_cloudformation def test_boto3_create_stack_with_yaml(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_yaml, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_yaml) - cf_conn.get_template(StackName="test_stack")['TemplateBody'].should.equal( - dummy_template_yaml) + cf_conn.get_template(StackName="test_stack")["TemplateBody"].should.equal( + dummy_template_yaml + ) @mock_cloudformation def test_boto3_create_stack_with_short_form_func_yaml(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_yaml_with_short_form_func, + StackName="test_stack", TemplateBody=dummy_template_yaml_with_short_form_func ) - cf_conn.get_template(StackName="test_stack")['TemplateBody'].should.equal( - dummy_template_yaml_with_short_form_func) + cf_conn.get_template(StackName="test_stack")["TemplateBody"].should.equal( + dummy_template_yaml_with_short_form_func + ) @mock_cloudformation def test_boto3_create_stack_with_ref_yaml(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") params = [ - {'ParameterKey': 'TagDescription', 'ParameterValue': 'desc_ref'}, - {'ParameterKey': 'TagName', 'ParameterValue': 'name_ref'}, + {"ParameterKey": "TagDescription", "ParameterValue": "desc_ref"}, + {"ParameterKey": "TagName", "ParameterValue": "name_ref"}, ] cf_conn.create_stack( StackName="test_stack", TemplateBody=dummy_template_yaml_with_ref, - Parameters=params + Parameters=params, ) - cf_conn.get_template(StackName="test_stack")['TemplateBody'].should.equal( - dummy_template_yaml_with_ref) + cf_conn.get_template(StackName="test_stack")["TemplateBody"].should.equal( + dummy_template_yaml_with_ref + ) @mock_cloudformation def test_creating_stacks_across_regions(): - west1_cf = boto3.resource('cloudformation', region_name='us-west-1') - west2_cf = boto3.resource('cloudformation', region_name='us-west-2') - west1_cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) - west2_cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + west1_cf = boto3.resource("cloudformation", region_name="us-west-1") + west2_cf = boto3.resource("cloudformation", region_name="us-west-2") + west1_cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) + west2_cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) list(west1_cf.stacks.all()).should.have.length_of(1) list(west2_cf.stacks.all()).should.have.length_of(1) @@ -674,289 +678,266 @@ def test_creating_stacks_across_regions(): @mock_cloudformation def test_create_stack_with_notification_arn(): - cf = boto3.resource('cloudformation', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") cf.create_stack( StackName="test_stack_with_notifications", TemplateBody=dummy_template_json, - NotificationARNs=['arn:aws:sns:us-east-1:123456789012:fake-queue'], + NotificationARNs=["arn:aws:sns:us-east-1:123456789012:fake-queue"], ) stack = list(cf.stacks.all())[0] stack.notification_arns.should.contain( - 'arn:aws:sns:us-east-1:123456789012:fake-queue') + "arn:aws:sns:us-east-1:123456789012:fake-queue" + ) @mock_cloudformation def test_create_stack_with_role_arn(): - cf = boto3.resource('cloudformation', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") cf.create_stack( StackName="test_stack_with_notifications", TemplateBody=dummy_template_json, - RoleARN='arn:aws:iam::123456789012:role/moto', + RoleARN="arn:aws:iam::123456789012:role/moto", ) stack = list(cf.stacks.all())[0] - stack.role_arn.should.equal('arn:aws:iam::123456789012:role/moto') + stack.role_arn.should.equal("arn:aws:iam::123456789012:role/moto") @mock_cloudformation @mock_s3 def test_create_stack_from_s3_url(): - s3 = boto3.client('s3') - s3_conn = boto3.resource('s3') + s3 = boto3.client("s3") + s3_conn = boto3.resource("s3") bucket = s3_conn.create_bucket(Bucket="foobar") - key = s3_conn.Object( - 'foobar', 'template-key').put(Body=dummy_template_json) + key = s3_conn.Object("foobar", "template-key").put(Body=dummy_template_json) key_url = s3.generate_presigned_url( - ClientMethod='get_object', - Params={ - 'Bucket': 'foobar', - 'Key': 'template-key' - } + ClientMethod="get_object", Params={"Bucket": "foobar", "Key": "template-key"} ) - cf_conn = boto3.client('cloudformation', region_name='us-west-1') - cf_conn.create_stack( - StackName='stack_from_url', - TemplateURL=key_url, + cf_conn = boto3.client("cloudformation", region_name="us-west-1") + cf_conn.create_stack(StackName="stack_from_url", TemplateURL=key_url) + cf_conn.get_template(StackName="stack_from_url")["TemplateBody"].should.equal( + json.loads(dummy_template_json, object_pairs_hook=OrderedDict) ) - cf_conn.get_template(StackName="stack_from_url")['TemplateBody'].should.equal( - json.loads(dummy_template_json, object_pairs_hook=OrderedDict)) @mock_cloudformation def test_update_stack_with_previous_value(): - name = 'update_stack_with_previous_value' - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + name = "update_stack_with_previous_value" + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack( - StackName=name, TemplateBody=dummy_template_yaml_with_ref, + StackName=name, + TemplateBody=dummy_template_yaml_with_ref, Parameters=[ - {'ParameterKey': 'TagName', 'ParameterValue': 'foo'}, - {'ParameterKey': 'TagDescription', 'ParameterValue': 'bar'}, - ] + {"ParameterKey": "TagName", "ParameterValue": "foo"}, + {"ParameterKey": "TagDescription", "ParameterValue": "bar"}, + ], ) cf_conn.update_stack( - StackName=name, UsePreviousTemplate=True, + StackName=name, + UsePreviousTemplate=True, Parameters=[ - {'ParameterKey': 'TagName', 'UsePreviousValue': True}, - {'ParameterKey': 'TagDescription', 'ParameterValue': 'not bar'}, - ] + {"ParameterKey": "TagName", "UsePreviousValue": True}, + {"ParameterKey": "TagDescription", "ParameterValue": "not bar"}, + ], ) - stack = cf_conn.describe_stacks(StackName=name)['Stacks'][0] - tag_name = [x['ParameterValue'] for x in stack['Parameters'] - if x['ParameterKey'] == 'TagName'][0] - tag_desc = [x['ParameterValue'] for x in stack['Parameters'] - if x['ParameterKey'] == 'TagDescription'][0] - assert tag_name == 'foo' - assert tag_desc == 'not bar' + stack = cf_conn.describe_stacks(StackName=name)["Stacks"][0] + tag_name = [ + x["ParameterValue"] + for x in stack["Parameters"] + if x["ParameterKey"] == "TagName" + ][0] + tag_desc = [ + x["ParameterValue"] + for x in stack["Parameters"] + if x["ParameterKey"] == "TagDescription" + ][0] + assert tag_name == "foo" + assert tag_desc == "not bar" @mock_cloudformation @mock_s3 @mock_ec2 def test_update_stack_from_s3_url(): - s3 = boto3.client('s3') - s3_conn = boto3.resource('s3') + s3 = boto3.client("s3") + s3_conn = boto3.resource("s3") - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack( StackName="update_stack_from_url", TemplateBody=dummy_template_json, - Tags=[{'Key': 'foo', 'Value': 'bar'}], + Tags=[{"Key": "foo", "Value": "bar"}], ) s3_conn.create_bucket(Bucket="foobar") - s3_conn.Object( - 'foobar', 'template-key').put(Body=dummy_update_template_json) + s3_conn.Object("foobar", "template-key").put(Body=dummy_update_template_json) key_url = s3.generate_presigned_url( - ClientMethod='get_object', - Params={ - 'Bucket': 'foobar', - 'Key': 'template-key' - } + ClientMethod="get_object", Params={"Bucket": "foobar", "Key": "template-key"} ) - cf_conn.update_stack( - StackName="update_stack_from_url", - TemplateURL=key_url, - ) + cf_conn.update_stack(StackName="update_stack_from_url", TemplateURL=key_url) - cf_conn.get_template(StackName="update_stack_from_url")[ 'TemplateBody'].should.equal( - json.loads(dummy_update_template_json, object_pairs_hook=OrderedDict)) + cf_conn.get_template(StackName="update_stack_from_url")[ + "TemplateBody" + ].should.equal( + json.loads(dummy_update_template_json, object_pairs_hook=OrderedDict) + ) @mock_cloudformation @mock_s3 def test_create_change_set_from_s3_url(): - s3 = boto3.client('s3') - s3_conn = boto3.resource('s3') + s3 = boto3.client("s3") + s3_conn = boto3.resource("s3") bucket = s3_conn.create_bucket(Bucket="foobar") - key = s3_conn.Object( - 'foobar', 'template-key').put(Body=dummy_template_json) + key = s3_conn.Object("foobar", "template-key").put(Body=dummy_template_json) key_url = s3.generate_presigned_url( - ClientMethod='get_object', - Params={ - 'Bucket': 'foobar', - 'Key': 'template-key' - } + ClientMethod="get_object", Params={"Bucket": "foobar", "Key": "template-key"} ) - cf_conn = boto3.client('cloudformation', region_name='us-west-1') + cf_conn = boto3.client("cloudformation", region_name="us-west-1") response = cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateURL=key_url, - ChangeSetName='NewChangeSet', - ChangeSetType='CREATE', - Tags=[ - {'Key': 'tag-key', 'Value': 'tag-value'} - ], + ChangeSetName="NewChangeSet", + ChangeSetType="CREATE", + Tags=[{"Key": "tag-key", "Value": "tag-value"}], + ) + assert ( + "arn:aws:cloudformation:us-west-1:123456789:changeSet/NewChangeSet/" + in response["Id"] + ) + assert ( + "arn:aws:cloudformation:us-east-1:123456789:stack/NewStack" + in response["StackId"] ) - assert 'arn:aws:cloudformation:us-west-1:123456789:changeSet/NewChangeSet/' in response['Id'] - assert 'arn:aws:cloudformation:us-east-1:123456789:stack/NewStack' in response['StackId'] @mock_cloudformation def test_describe_change_set(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateBody=dummy_template_json, - ChangeSetName='NewChangeSet', - ChangeSetType='CREATE', + ChangeSetName="NewChangeSet", + ChangeSetType="CREATE", ) stack = cf_conn.describe_change_set(ChangeSetName="NewChangeSet") - stack['ChangeSetName'].should.equal('NewChangeSet') - stack['StackName'].should.equal('NewStack') + stack["ChangeSetName"].should.equal("NewChangeSet") + stack["StackName"].should.equal("NewStack") cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateBody=dummy_update_template_json, - ChangeSetName='NewChangeSet2', - ChangeSetType='UPDATE', + ChangeSetName="NewChangeSet2", + ChangeSetType="UPDATE", ) stack = cf_conn.describe_change_set(ChangeSetName="NewChangeSet2") - stack['ChangeSetName'].should.equal('NewChangeSet2') - stack['StackName'].should.equal('NewStack') - stack['Changes'].should.have.length_of(2) + stack["ChangeSetName"].should.equal("NewChangeSet2") + stack["StackName"].should.equal("NewStack") + stack["Changes"].should.have.length_of(2) @mock_cloudformation def test_execute_change_set_w_arn(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") change_set = cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateBody=dummy_template_json, - ChangeSetName='NewChangeSet', - ChangeSetType='CREATE', + ChangeSetName="NewChangeSet", + ChangeSetType="CREATE", ) - cf_conn.execute_change_set(ChangeSetName=change_set['Id']) + cf_conn.execute_change_set(ChangeSetName=change_set["Id"]) @mock_cloudformation def test_execute_change_set_w_name(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") change_set = cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateBody=dummy_template_json, - ChangeSetName='NewChangeSet', - ChangeSetType='CREATE', + ChangeSetName="NewChangeSet", + ChangeSetType="CREATE", ) - cf_conn.execute_change_set(ChangeSetName='NewChangeSet', StackName='NewStack') + cf_conn.execute_change_set(ChangeSetName="NewChangeSet", StackName="NewStack") @mock_cloudformation def test_describe_stack_pagination(): - conn = boto3.client('cloudformation', region_name='us-east-1') + conn = boto3.client("cloudformation", region_name="us-east-1") for i in range(100): - conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) resp = conn.describe_stacks() - stacks = resp['Stacks'] + stacks = resp["Stacks"] stacks.should.have.length_of(50) - next_token = resp['NextToken'] + next_token = resp["NextToken"] next_token.should_not.be.none resp2 = conn.describe_stacks(NextToken=next_token) - stacks.extend(resp2['Stacks']) + stacks.extend(resp2["Stacks"]) stacks.should.have.length_of(100) - assert 'NextToken' not in resp2.keys() + assert "NextToken" not in resp2.keys() @mock_cloudformation def test_describe_stack_resources(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - stack = cf_conn.describe_stacks(StackName="test_stack")['Stacks'][0] + stack = cf_conn.describe_stacks(StackName="test_stack")["Stacks"][0] - response = cf_conn.describe_stack_resources(StackName=stack['StackName']) - resource = response['StackResources'][0] - resource['LogicalResourceId'].should.equal('EC2Instance1') - resource['ResourceStatus'].should.equal('CREATE_COMPLETE') - resource['ResourceType'].should.equal('AWS::EC2::Instance') - resource['StackId'].should.equal(stack['StackId']) + response = cf_conn.describe_stack_resources(StackName=stack["StackName"]) + resource = response["StackResources"][0] + resource["LogicalResourceId"].should.equal("EC2Instance1") + resource["ResourceStatus"].should.equal("CREATE_COMPLETE") + resource["ResourceType"].should.equal("AWS::EC2::Instance") + resource["StackId"].should.equal(stack["StackId"]) @mock_cloudformation def test_describe_stack_by_name(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - stack = cf_conn.describe_stacks(StackName="test_stack")['Stacks'][0] - stack['StackName'].should.equal('test_stack') + stack = cf_conn.describe_stacks(StackName="test_stack")["Stacks"][0] + stack["StackName"].should.equal("test_stack") @mock_cloudformation def test_describe_stack_by_stack_id(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - stack = cf_conn.describe_stacks(StackName="test_stack")['Stacks'][0] - stack_by_id = cf_conn.describe_stacks(StackName=stack['StackId'])['Stacks'][ - 0] + stack = cf_conn.describe_stacks(StackName="test_stack")["Stacks"][0] + stack_by_id = cf_conn.describe_stacks(StackName=stack["StackId"])["Stacks"][0] - stack_by_id['StackId'].should.equal(stack['StackId']) - stack_by_id['StackName'].should.equal("test_stack") + stack_by_id["StackId"].should.equal(stack["StackId"]) + stack_by_id["StackName"].should.equal("test_stack") @mock_cloudformation def test_list_change_sets(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_change_set( - StackName='NewStack2', + StackName="NewStack2", TemplateBody=dummy_template_json, - ChangeSetName='NewChangeSet2', - ChangeSetType='CREATE', + ChangeSetName="NewChangeSet2", + ChangeSetType="CREATE", ) - change_set = cf_conn.list_change_sets(StackName='NewStack2')['Summaries'][0] - change_set['StackName'].should.equal('NewStack2') - change_set['ChangeSetName'].should.equal('NewChangeSet2') + change_set = cf_conn.list_change_sets(StackName="NewStack2")["Summaries"][0] + change_set["StackName"].should.equal("NewStack2") + change_set["ChangeSetName"].should.equal("NewChangeSet2") @mock_cloudformation def test_list_stacks(): - cf = boto3.resource('cloudformation', region_name='us-east-1') - cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) - cf.create_stack( - StackName="test_stack2", - TemplateBody=dummy_template_json, - ) + cf = boto3.resource("cloudformation", region_name="us-east-1") + cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) + cf.create_stack(StackName="test_stack2", TemplateBody=dummy_template_json) stacks = list(cf.stacks.all()) stacks.should.have.length_of(2) @@ -967,11 +948,8 @@ def test_list_stacks(): @mock_cloudformation def test_delete_stack_from_resource(): - cf = boto3.resource('cloudformation', region_name='us-east-1') - stack = cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf = boto3.resource("cloudformation", region_name="us-east-1") + stack = cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) list(cf.stacks.all()).should.have.length_of(1) stack.delete() @@ -981,95 +959,84 @@ def test_delete_stack_from_resource(): @mock_cloudformation @mock_ec2 def test_delete_change_set(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_change_set( - StackName='NewStack', + StackName="NewStack", TemplateBody=dummy_template_json, - ChangeSetName='NewChangeSet', - ChangeSetType='CREATE', + ChangeSetName="NewChangeSet", + ChangeSetType="CREATE", ) - cf_conn.list_change_sets(StackName='NewStack')['Summaries'].should.have.length_of(1) - cf_conn.delete_change_set(ChangeSetName='NewChangeSet', StackName='NewStack') - cf_conn.list_change_sets(StackName='NewStack')['Summaries'].should.have.length_of(0) + cf_conn.list_change_sets(StackName="NewStack")["Summaries"].should.have.length_of(1) + cf_conn.delete_change_set(ChangeSetName="NewChangeSet", StackName="NewStack") + cf_conn.list_change_sets(StackName="NewStack")["Summaries"].should.have.length_of(0) @mock_cloudformation @mock_ec2 def test_delete_stack_by_name(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - cf_conn.describe_stacks()['Stacks'].should.have.length_of(1) + cf_conn.describe_stacks()["Stacks"].should.have.length_of(1) cf_conn.delete_stack(StackName="test_stack") - cf_conn.describe_stacks()['Stacks'].should.have.length_of(0) + cf_conn.describe_stacks()["Stacks"].should.have.length_of(0) @mock_cloudformation def test_delete_stack(): - cf = boto3.client('cloudformation', region_name='us-east-1') - cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf = boto3.client("cloudformation", region_name="us-east-1") + cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - cf.delete_stack( - StackName="test_stack", - ) + cf.delete_stack(StackName="test_stack") stacks = cf.list_stacks() - assert stacks['StackSummaries'][0]['StackStatus'] == 'DELETE_COMPLETE' + assert stacks["StackSummaries"][0]["StackStatus"] == "DELETE_COMPLETE" @mock_cloudformation def test_describe_deleted_stack(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) - stack = cf_conn.describe_stacks(StackName="test_stack")['Stacks'][0] - stack_id = stack['StackId'] - cf_conn.delete_stack(StackName=stack['StackId']) - stack_by_id = cf_conn.describe_stacks(StackName=stack_id)['Stacks'][0] - stack_by_id['StackId'].should.equal(stack['StackId']) - stack_by_id['StackName'].should.equal("test_stack") - stack_by_id['StackStatus'].should.equal("DELETE_COMPLETE") + stack = cf_conn.describe_stacks(StackName="test_stack")["Stacks"][0] + stack_id = stack["StackId"] + cf_conn.delete_stack(StackName=stack["StackId"]) + stack_by_id = cf_conn.describe_stacks(StackName=stack_id)["Stacks"][0] + stack_by_id["StackId"].should.equal(stack["StackId"]) + stack_by_id["StackName"].should.equal("test_stack") + stack_by_id["StackStatus"].should.equal("DELETE_COMPLETE") @mock_cloudformation @mock_ec2 def test_describe_updated_stack(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") cf_conn.create_stack( StackName="test_stack", TemplateBody=dummy_template_json, - Tags=[{'Key': 'foo', 'Value': 'bar'}], + Tags=[{"Key": "foo", "Value": "bar"}], ) cf_conn.update_stack( StackName="test_stack", - RoleARN='arn:aws:iam::123456789012:role/moto', + RoleARN="arn:aws:iam::123456789012:role/moto", TemplateBody=dummy_update_template_json, - Tags=[{'Key': 'foo', 'Value': 'baz'}], + Tags=[{"Key": "foo", "Value": "baz"}], ) - stack = cf_conn.describe_stacks(StackName="test_stack")['Stacks'][0] - stack_id = stack['StackId'] - stack_by_id = cf_conn.describe_stacks(StackName=stack_id)['Stacks'][0] - stack_by_id['StackId'].should.equal(stack['StackId']) - stack_by_id['StackName'].should.equal("test_stack") - stack_by_id['StackStatus'].should.equal("UPDATE_COMPLETE") - stack_by_id['RoleARN'].should.equal('arn:aws:iam::123456789012:role/moto') - stack_by_id['Tags'].should.equal([{'Key': 'foo', 'Value': 'baz'}]) + stack = cf_conn.describe_stacks(StackName="test_stack")["Stacks"][0] + stack_id = stack["StackId"] + stack_by_id = cf_conn.describe_stacks(StackName=stack_id)["Stacks"][0] + stack_by_id["StackId"].should.equal(stack["StackId"]) + stack_by_id["StackName"].should.equal("test_stack") + stack_by_id["StackStatus"].should.equal("UPDATE_COMPLETE") + stack_by_id["RoleARN"].should.equal("arn:aws:iam::123456789012:role/moto") + stack_by_id["Tags"].should.equal([{"Key": "foo", "Value": "baz"}]) @mock_cloudformation def test_bad_describe_stack(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") with assert_raises(ClientError): cf_conn.describe_stacks(StackName="non_existent_stack") @@ -1084,61 +1051,46 @@ def test_cloudformation_params(): "APPNAME": { "Default": "app-name", "Description": "The name of the app", - "Type": "String" + "Type": "String", } - } + }, } dummy_template_with_params_json = json.dumps(dummy_template_with_params) - cf = boto3.resource('cloudformation', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") stack = cf.create_stack( - StackName='test_stack', + StackName="test_stack", TemplateBody=dummy_template_with_params_json, - Parameters=[{ - "ParameterKey": "APPNAME", - "ParameterValue": "testing123", - }], + Parameters=[{"ParameterKey": "APPNAME", "ParameterValue": "testing123"}], ) stack.parameters.should.have.length_of(1) param = stack.parameters[0] - param['ParameterKey'].should.equal('APPNAME') - param['ParameterValue'].should.equal('testing123') + param["ParameterKey"].should.equal("APPNAME") + param["ParameterValue"].should.equal("testing123") @mock_cloudformation def test_stack_tags(): - tags = [ - { - "Key": "foo", - "Value": "bar" - }, - { - "Key": "baz", - "Value": "bleh" - } - ] - cf = boto3.resource('cloudformation', region_name='us-east-1') + tags = [{"Key": "foo", "Value": "bar"}, {"Key": "baz", "Value": "bleh"}] + cf = boto3.resource("cloudformation", region_name="us-east-1") stack = cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - Tags=tags, + StackName="test_stack", TemplateBody=dummy_template_json, Tags=tags ) observed_tag_items = set( - item for items in [tag.items() for tag in stack.tags] for item in items) + item for items in [tag.items() for tag in stack.tags] for item in items + ) expected_tag_items = set( - item for items in [tag.items() for tag in tags] for item in items) + item for items in [tag.items() for tag in tags] for item in items + ) observed_tag_items.should.equal(expected_tag_items) @mock_cloudformation @mock_ec2 def test_stack_events(): - cf = boto3.resource('cloudformation', region_name='us-east-1') - stack = cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_template_json, - ) + cf = boto3.resource("cloudformation", region_name="us-east-1") + stack = cf.create_stack(StackName="test_stack", TemplateBody=dummy_template_json) stack.update(TemplateBody=dummy_update_template_json) stack = cf.Stack(stack.stack_id) stack.delete() @@ -1150,14 +1102,16 @@ def test_stack_events(): # testing ordering of stack events without assuming resource events will not exist # the AWS API returns events in reverse chronological order - stack_events_to_look_for = iter([ - ("DELETE_COMPLETE", None), - ("DELETE_IN_PROGRESS", "User Initiated"), - ("UPDATE_COMPLETE", None), - ("UPDATE_IN_PROGRESS", "User Initiated"), - ("CREATE_COMPLETE", None), - ("CREATE_IN_PROGRESS", "User Initiated"), - ]) + stack_events_to_look_for = iter( + [ + ("DELETE_COMPLETE", None), + ("DELETE_IN_PROGRESS", "User Initiated"), + ("UPDATE_COMPLETE", None), + ("UPDATE_IN_PROGRESS", "User Initiated"), + ("CREATE_COMPLETE", None), + ("CREATE_IN_PROGRESS", "User Initiated"), + ] + ) try: for event in events: event.stack_id.should.equal(stack.stack_id) @@ -1168,12 +1122,10 @@ def test_stack_events(): event.logical_resource_id.should.equal("test_stack") event.physical_resource_id.should.equal(stack.stack_id) - status_to_look_for, reason_to_look_for = next( - stack_events_to_look_for) + status_to_look_for, reason_to_look_for = next(stack_events_to_look_for) event.resource_status.should.equal(status_to_look_for) if reason_to_look_for is not None: - event.resource_status_reason.should.equal( - reason_to_look_for) + event.resource_status_reason.should.equal(reason_to_look_for) except StopIteration: assert False, "Too many stack events" @@ -1182,90 +1134,81 @@ def test_stack_events(): @mock_cloudformation def test_list_exports(): - cf_client = boto3.client('cloudformation', region_name='us-east-1') - cf_resource = boto3.resource('cloudformation', region_name='us-east-1') + cf_client = boto3.client("cloudformation", region_name="us-east-1") + cf_resource = boto3.resource("cloudformation", region_name="us-east-1") stack = cf_resource.create_stack( - StackName="test_stack", - TemplateBody=dummy_output_template_json, + StackName="test_stack", TemplateBody=dummy_output_template_json ) - output_value = 'VPCID' - exports = cf_client.list_exports()['Exports'] + output_value = "VPCID" + exports = cf_client.list_exports()["Exports"] stack.outputs.should.have.length_of(1) - stack.outputs[0]['OutputValue'].should.equal(output_value) + stack.outputs[0]["OutputValue"].should.equal(output_value) exports.should.have.length_of(1) - exports[0]['ExportingStackId'].should.equal(stack.stack_id) - exports[0]['Name'].should.equal('My VPC ID') - exports[0]['Value'].should.equal(output_value) + exports[0]["ExportingStackId"].should.equal(stack.stack_id) + exports[0]["Name"].should.equal("My VPC ID") + exports[0]["Value"].should.equal(output_value) @mock_cloudformation def test_list_exports_with_token(): - cf = boto3.client('cloudformation', region_name='us-east-1') + cf = boto3.client("cloudformation", region_name="us-east-1") for i in range(101): # Add index to ensure name is unique - dummy_output_template['Outputs']['StackVPC']['Export']['Name'] += str(i) + dummy_output_template["Outputs"]["StackVPC"]["Export"]["Name"] += str(i) cf.create_stack( - StackName="test_stack", - TemplateBody=json.dumps(dummy_output_template), + StackName="test_stack", TemplateBody=json.dumps(dummy_output_template) ) exports = cf.list_exports() - exports['Exports'].should.have.length_of(100) - exports.get('NextToken').should_not.be.none + exports["Exports"].should.have.length_of(100) + exports.get("NextToken").should_not.be.none - more_exports = cf.list_exports(NextToken=exports['NextToken']) - more_exports['Exports'].should.have.length_of(1) - more_exports.get('NextToken').should.be.none + more_exports = cf.list_exports(NextToken=exports["NextToken"]) + more_exports["Exports"].should.have.length_of(1) + more_exports.get("NextToken").should.be.none @mock_cloudformation def test_delete_stack_with_export(): - cf = boto3.client('cloudformation', region_name='us-east-1') + cf = boto3.client("cloudformation", region_name="us-east-1") stack = cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_output_template_json, + StackName="test_stack", TemplateBody=dummy_output_template_json ) - stack_id = stack['StackId'] - exports = cf.list_exports()['Exports'] + stack_id = stack["StackId"] + exports = cf.list_exports()["Exports"] exports.should.have.length_of(1) cf.delete_stack(StackName=stack_id) - cf.list_exports()['Exports'].should.have.length_of(0) + cf.list_exports()["Exports"].should.have.length_of(0) @mock_cloudformation def test_export_names_must_be_unique(): - cf = boto3.resource('cloudformation', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") first_stack = cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_output_template_json, + StackName="test_stack", TemplateBody=dummy_output_template_json ) with assert_raises(ClientError): - cf.create_stack( - StackName="test_stack", - TemplateBody=dummy_output_template_json, - ) + cf.create_stack(StackName="test_stack", TemplateBody=dummy_output_template_json) @mock_sqs @mock_cloudformation def test_stack_with_imports(): - cf = boto3.resource('cloudformation', region_name='us-east-1') - ec2_resource = boto3.resource('sqs', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") + ec2_resource = boto3.resource("sqs", region_name="us-east-1") output_stack = cf.create_stack( - StackName="test_stack1", - TemplateBody=dummy_output_template_json, + StackName="test_stack1", TemplateBody=dummy_output_template_json ) import_stack = cf.create_stack( - StackName="test_stack2", - TemplateBody=dummy_import_template_json + StackName="test_stack2", TemplateBody=dummy_import_template_json ) output_stack.outputs.should.have.length_of(1) - output = output_stack.outputs[0]['OutputValue'] + output = output_stack.outputs[0]["OutputValue"] queue = ec2_resource.get_queue_by_name(QueueName=output) queue.should_not.be.none @@ -1273,14 +1216,11 @@ def test_stack_with_imports(): @mock_sqs @mock_cloudformation def test_non_json_redrive_policy(): - cf = boto3.resource('cloudformation', region_name='us-east-1') + cf = boto3.resource("cloudformation", region_name="us-east-1") stack = cf.create_stack( - StackName="test_stack1", - TemplateBody=dummy_redrive_template_json + StackName="test_stack1", TemplateBody=dummy_redrive_template_json ) - stack.Resource('MainQueue').resource_status\ - .should.equal("CREATE_COMPLETE") - stack.Resource('DeadLetterQueue').resource_status\ - .should.equal("CREATE_COMPLETE") + stack.Resource("MainQueue").resource_status.should.equal("CREATE_COMPLETE") + stack.Resource("DeadLetterQueue").resource_status.should.equal("CREATE_COMPLETE") diff --git a/tests/test_cloudformation/test_cloudformation_stack_integration.py b/tests/test_cloudformation/test_cloudformation_stack_integration.py index 42ddd2351..ced6b2005 100644 --- a/tests/test_cloudformation/test_cloudformation_stack_integration.py +++ b/tests/test_cloudformation/test_cloudformation_stack_integration.py @@ -41,7 +41,8 @@ from moto import ( mock_sns_deprecated, mock_sqs, mock_sqs_deprecated, - mock_elbv2) + mock_elbv2, +) from moto.dynamodb2.models import Table from .fixtures import ( @@ -65,26 +66,19 @@ def test_stack_sqs_integration(): "AWSTemplateFormatVersion": "2010-09-09", "Resources": { "QueueGroup": { - "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=sqs_template_json, - ) + conn.create_stack("test_stack", template_body=sqs_template_json) stack = conn.describe_stacks()[0] queue = stack.describe_resources()[0] - queue.resource_type.should.equal('AWS::SQS::Queue') + queue.resource_type.should.equal("AWS::SQS::Queue") queue.logical_resource_id.should.equal("QueueGroup") queue.physical_resource_id.should.equal("my-queue") @@ -95,27 +89,20 @@ def test_stack_list_resources(): "AWSTemplateFormatVersion": "2010-09-09", "Resources": { "QueueGroup": { - "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=sqs_template_json, - ) + conn.create_stack("test_stack", template_body=sqs_template_json) resources = conn.list_stack_resources("test_stack") assert len(resources) == 1 queue = resources[0] - queue.resource_type.should.equal('AWS::SQS::Queue') + queue.resource_type.should.equal("AWS::SQS::Queue") queue.logical_resource_id.should.equal("QueueGroup") queue.physical_resource_id.should.equal("my-queue") @@ -127,38 +114,32 @@ def test_update_stack(): "AWSTemplateFormatVersion": "2010-09-09", "Resources": { "QueueGroup": { - "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=sqs_template_json, - ) + conn.create_stack("test_stack", template_body=sqs_template_json) sqs_conn = boto.sqs.connect_to_region("us-west-1") queues = sqs_conn.get_all_queues() queues.should.have.length_of(1) - queues[0].get_attributes('VisibilityTimeout')[ - 'VisibilityTimeout'].should.equal('60') + queues[0].get_attributes("VisibilityTimeout")["VisibilityTimeout"].should.equal( + "60" + ) - sqs_template['Resources']['QueueGroup'][ - 'Properties']['VisibilityTimeout'] = 100 + sqs_template["Resources"]["QueueGroup"]["Properties"]["VisibilityTimeout"] = 100 sqs_template_json = json.dumps(sqs_template) conn.update_stack("test_stack", sqs_template_json) queues = sqs_conn.get_all_queues() queues.should.have.length_of(1) - queues[0].get_attributes('VisibilityTimeout')[ - 'VisibilityTimeout'].should.equal('100') + queues[0].get_attributes("VisibilityTimeout")["VisibilityTimeout"].should.equal( + "100" + ) @mock_cloudformation_deprecated() @@ -168,28 +149,21 @@ def test_update_stack_and_remove_resource(): "AWSTemplateFormatVersion": "2010-09-09", "Resources": { "QueueGroup": { - "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=sqs_template_json, - ) + conn.create_stack("test_stack", template_body=sqs_template_json) sqs_conn = boto.sqs.connect_to_region("us-west-1") queues = sqs_conn.get_all_queues() queues.should.have.length_of(1) - sqs_template['Resources'].pop('QueueGroup') + sqs_template["Resources"].pop("QueueGroup") sqs_template_json = json.dumps(sqs_template) conn.update_stack("test_stack", sqs_template_json) @@ -200,17 +174,11 @@ def test_update_stack_and_remove_resource(): @mock_cloudformation_deprecated() @mock_sqs_deprecated() def test_update_stack_and_add_resource(): - sqs_template = { - "AWSTemplateFormatVersion": "2010-09-09", - "Resources": {}, - } + sqs_template = {"AWSTemplateFormatVersion": "2010-09-09", "Resources": {}} sqs_template_json = json.dumps(sqs_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=sqs_template_json, - ) + conn.create_stack("test_stack", template_body=sqs_template_json) sqs_conn = boto.sqs.connect_to_region("us-west-1") queues = sqs_conn.get_all_queues() @@ -220,13 +188,9 @@ def test_update_stack_and_add_resource(): "AWSTemplateFormatVersion": "2010-09-09", "Resources": { "QueueGroup": { - "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) @@ -244,20 +208,14 @@ def test_stack_ec2_integration(): "Resources": { "WebServerGroup": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-1234abcd", - "UserData": "some user data", - } - }, + "Properties": {"ImageId": "ami-1234abcd", "UserData": "some user data"}, + } }, } ec2_template_json = json.dumps(ec2_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "ec2_stack", - template_body=ec2_template_json, - ) + conn.create_stack("ec2_stack", template_body=ec2_template_json) ec2_conn = boto.ec2.connect_to_region("us-west-1") reservation = ec2_conn.get_all_instances()[0] @@ -265,7 +223,7 @@ def test_stack_ec2_integration(): stack = conn.describe_stacks()[0] instance = stack.describe_resources()[0] - instance.resource_type.should.equal('AWS::EC2::Instance') + instance.resource_type.should.equal("AWS::EC2::Instance") instance.logical_resource_id.should.contain("WebServerGroup") instance.physical_resource_id.should.equal(ec2_instance.id) @@ -282,7 +240,7 @@ def test_stack_elb_integration_with_attached_ec2_instances(): "Properties": { "Instances": [{"Ref": "Ec2Instance1"}], "LoadBalancerName": "test-elb", - "AvailabilityZones": ['us-east-1'], + "AvailabilityZones": ["us-east-1"], "Listeners": [ { "InstancePort": "80", @@ -290,24 +248,18 @@ def test_stack_elb_integration_with_attached_ec2_instances(): "Protocol": "HTTP", } ], - } + }, }, "Ec2Instance1": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-1234abcd", - "UserData": "some user data", - } + "Properties": {"ImageId": "ami-1234abcd", "UserData": "some user data"}, }, }, } elb_template_json = json.dumps(elb_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "elb_stack", - template_body=elb_template_json, - ) + conn.create_stack("elb_stack", template_body=elb_template_json) elb_conn = boto.ec2.elb.connect_to_region("us-west-1") load_balancer = elb_conn.get_all_load_balancers()[0] @@ -317,7 +269,7 @@ def test_stack_elb_integration_with_attached_ec2_instances(): ec2_instance = reservation.instances[0] load_balancer.instances[0].id.should.equal(ec2_instance.id) - list(load_balancer.availability_zones).should.equal(['us-east-1']) + list(load_balancer.availability_zones).should.equal(["us-east-1"]) @mock_elb_deprecated() @@ -330,7 +282,7 @@ def test_stack_elb_integration_with_health_check(): "Type": "AWS::ElasticLoadBalancing::LoadBalancer", "Properties": { "LoadBalancerName": "test-elb", - "AvailabilityZones": ['us-west-1'], + "AvailabilityZones": ["us-west-1"], "HealthCheck": { "HealthyThreshold": "3", "Interval": "5", @@ -345,17 +297,14 @@ def test_stack_elb_integration_with_health_check(): "Protocol": "HTTP", } ], - } - }, + }, + } }, } elb_template_json = json.dumps(elb_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "elb_stack", - template_body=elb_template_json, - ) + conn.create_stack("elb_stack", template_body=elb_template_json) elb_conn = boto.ec2.elb.connect_to_region("us-west-1") load_balancer = elb_conn.get_all_load_balancers()[0] @@ -378,7 +327,7 @@ def test_stack_elb_integration_with_update(): "Type": "AWS::ElasticLoadBalancing::LoadBalancer", "Properties": { "LoadBalancerName": "test-elb", - "AvailabilityZones": ['us-west-1a'], + "AvailabilityZones": ["us-west-1a"], "Listeners": [ { "InstancePort": "80", @@ -387,31 +336,26 @@ def test_stack_elb_integration_with_update(): } ], "Policies": {"Ref": "AWS::NoValue"}, - } - }, + }, + } }, } elb_template_json = json.dumps(elb_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "elb_stack", - template_body=elb_template_json, - ) + conn.create_stack("elb_stack", template_body=elb_template_json) elb_conn = boto.ec2.elb.connect_to_region("us-west-1") load_balancer = elb_conn.get_all_load_balancers()[0] - load_balancer.availability_zones[0].should.equal('us-west-1a') + load_balancer.availability_zones[0].should.equal("us-west-1a") - elb_template['Resources']['MyELB']['Properties'][ - 'AvailabilityZones'] = ['us-west-1b'] + elb_template["Resources"]["MyELB"]["Properties"]["AvailabilityZones"] = [ + "us-west-1b" + ] elb_template_json = json.dumps(elb_template) - conn.update_stack( - "elb_stack", - template_body=elb_template_json, - ) + conn.update_stack("elb_stack", template_body=elb_template_json) load_balancer = elb_conn.get_all_load_balancers()[0] - load_balancer.availability_zones[0].should.equal('us-west-1b') + load_balancer.availability_zones[0].should.equal("us-west-1b") @mock_ec2_deprecated() @@ -434,23 +378,24 @@ def test_redshift_stack(): ("MasterUserPassword", "mypass"), ("InboundTraffic", "10.0.0.1/16"), ("PortNumber", 5439), - ] + ], ) redshift_conn = boto.redshift.connect_to_region("us-west-2") cluster_res = redshift_conn.describe_clusters() - clusters = cluster_res['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'] + clusters = cluster_res["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ] clusters.should.have.length_of(1) cluster = clusters[0] - cluster['DBName'].should.equal("mydb") - cluster['NumberOfNodes'].should.equal(2) - cluster['NodeType'].should.equal("dw1.xlarge") - cluster['MasterUsername'].should.equal("myuser") - cluster['Port'].should.equal(5439) - cluster['VpcSecurityGroups'].should.have.length_of(1) - security_group_id = cluster['VpcSecurityGroups'][0]['VpcSecurityGroupId'] + cluster["DBName"].should.equal("mydb") + cluster["NumberOfNodes"].should.equal(2) + cluster["NodeType"].should.equal("dw1.xlarge") + cluster["MasterUsername"].should.equal("myuser") + cluster["Port"].should.equal(5439) + cluster["VpcSecurityGroups"].should.have.length_of(1) + security_group_id = cluster["VpcSecurityGroups"][0]["VpcSecurityGroupId"] groups = vpc_conn.get_all_security_groups(group_ids=[security_group_id]) groups.should.have.length_of(1) @@ -467,40 +412,36 @@ def test_stack_security_groups(): "Resources": { "my-security-group": { "Type": "AWS::EC2::SecurityGroup", - "Properties": { - "GroupDescription": "My other group", - }, + "Properties": {"GroupDescription": "My other group"}, }, "Ec2Instance2": { "Type": "AWS::EC2::Instance", "Properties": { "SecurityGroups": [{"Ref": "InstanceSecurityGroup"}], "ImageId": "ami-1234abcd", - } + }, }, "InstanceSecurityGroup": { "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupDescription": "My security group", - "Tags": [ + "Tags": [{"Key": "bar", "Value": "baz"}], + "SecurityGroupIngress": [ { - "Key": "bar", - "Value": "baz" - } + "IpProtocol": "tcp", + "FromPort": "22", + "ToPort": "22", + "CidrIp": "123.123.123.123/32", + }, + { + "IpProtocol": "tcp", + "FromPort": "80", + "ToPort": "8000", + "SourceSecurityGroupId": {"Ref": "my-security-group"}, + }, ], - "SecurityGroupIngress": [{ - "IpProtocol": "tcp", - "FromPort": "22", - "ToPort": "22", - "CidrIp": "123.123.123.123/32", - }, { - "IpProtocol": "tcp", - "FromPort": "80", - "ToPort": "8000", - "SourceSecurityGroupId": {"Ref": "my-security-group"}, - }] - } - } + }, + }, }, } security_group_template_json = json.dumps(security_group_template) @@ -509,31 +450,33 @@ def test_stack_security_groups(): conn.create_stack( "security_group_stack", template_body=security_group_template_json, - tags={"foo": "bar"} + tags={"foo": "bar"}, ) ec2_conn = boto.ec2.connect_to_region("us-west-1") instance_group = ec2_conn.get_all_security_groups( - filters={'description': ['My security group']})[0] + filters={"description": ["My security group"]} + )[0] other_group = ec2_conn.get_all_security_groups( - filters={'description': ['My other group']})[0] + filters={"description": ["My other group"]} + )[0] reservation = ec2_conn.get_all_instances()[0] ec2_instance = reservation.instances[0] ec2_instance.groups[0].id.should.equal(instance_group.id) instance_group.description.should.equal("My security group") - instance_group.tags.should.have.key('foo').which.should.equal('bar') - instance_group.tags.should.have.key('bar').which.should.equal('baz') + instance_group.tags.should.have.key("foo").which.should.equal("bar") + instance_group.tags.should.have.key("bar").which.should.equal("baz") rule1, rule2 = instance_group.rules int(rule1.to_port).should.equal(22) int(rule1.from_port).should.equal(22) rule1.grants[0].cidr_ip.should.equal("123.123.123.123/32") - rule1.ip_protocol.should.equal('tcp') + rule1.ip_protocol.should.equal("tcp") int(rule2.to_port).should.equal(8000) int(rule2.from_port).should.equal(80) - rule2.ip_protocol.should.equal('tcp') + rule2.ip_protocol.should.equal("tcp") rule2.grants[0].group_id.should.equal(other_group.id) @@ -544,12 +487,11 @@ def test_stack_security_groups(): def test_autoscaling_group_with_elb(): web_setup_template = { "AWSTemplateFormatVersion": "2010-09-09", - "Resources": { "my-as-group": { "Type": "AWS::AutoScaling::AutoScalingGroup", "Properties": { - "AvailabilityZones": ['us-east1'], + "AvailabilityZones": ["us-east1"], "LaunchConfigurationName": {"Ref": "my-launch-config"}, "MinSize": "2", "MaxSize": "2", @@ -557,34 +499,33 @@ def test_autoscaling_group_with_elb(): "LoadBalancerNames": [{"Ref": "my-elb"}], "Tags": [ { - "Key": "propagated-test-tag", "Value": "propagated-test-tag-value", - "PropagateAtLaunch": True}, + "Key": "propagated-test-tag", + "Value": "propagated-test-tag-value", + "PropagateAtLaunch": True, + }, { "Key": "not-propagated-test-tag", "Value": "not-propagated-test-tag-value", - "PropagateAtLaunch": False - } - ] + "PropagateAtLaunch": False, + }, + ], }, }, - "my-launch-config": { "Type": "AWS::AutoScaling::LaunchConfiguration", - "Properties": { - "ImageId": "ami-1234abcd", - "UserData": "some user data", - } + "Properties": {"ImageId": "ami-1234abcd", "UserData": "some user data"}, }, - "my-elb": { "Type": "AWS::ElasticLoadBalancing::LoadBalancer", "Properties": { - "AvailabilityZones": ['us-east1'], - "Listeners": [{ - "LoadBalancerPort": "80", - "InstancePort": "80", - "Protocol": "HTTP", - }], + "AvailabilityZones": ["us-east1"], + "Listeners": [ + { + "LoadBalancerPort": "80", + "InstancePort": "80", + "Protocol": "HTTP", + } + ], "LoadBalancerName": "my-elb", "HealthCheck": { "Target": "HTTP:80", @@ -595,21 +536,18 @@ def test_autoscaling_group_with_elb(): }, }, }, - } + }, } web_setup_template_json = json.dumps(web_setup_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "web_stack", - template_body=web_setup_template_json, - ) + conn.create_stack("web_stack", template_body=web_setup_template_json) autoscale_conn = boto.ec2.autoscale.connect_to_region("us-west-1") autoscale_group = autoscale_conn.get_all_groups()[0] autoscale_group.launch_config_name.should.contain("my-launch-config") - autoscale_group.load_balancers[0].should.equal('my-elb') + autoscale_group.load_balancers[0].should.equal("my-elb") # Confirm the Launch config was actually created autoscale_conn.get_all_launch_configurations().should.have.length_of(1) @@ -620,29 +558,36 @@ def test_autoscaling_group_with_elb(): stack = conn.describe_stacks()[0] resources = stack.describe_resources() - as_group_resource = [resource for resource in resources if resource.resource_type == - 'AWS::AutoScaling::AutoScalingGroup'][0] + as_group_resource = [ + resource + for resource in resources + if resource.resource_type == "AWS::AutoScaling::AutoScalingGroup" + ][0] as_group_resource.physical_resource_id.should.contain("my-as-group") launch_config_resource = [ - resource for resource in resources if - resource.resource_type == 'AWS::AutoScaling::LaunchConfiguration'][0] - launch_config_resource.physical_resource_id.should.contain( - "my-launch-config") + resource + for resource in resources + if resource.resource_type == "AWS::AutoScaling::LaunchConfiguration" + ][0] + launch_config_resource.physical_resource_id.should.contain("my-launch-config") - elb_resource = [resource for resource in resources if resource.resource_type == - 'AWS::ElasticLoadBalancing::LoadBalancer'][0] + elb_resource = [ + resource + for resource in resources + if resource.resource_type == "AWS::ElasticLoadBalancing::LoadBalancer" + ][0] elb_resource.physical_resource_id.should.contain("my-elb") # confirm the instances were created with the right tags - ec2_conn = boto.ec2.connect_to_region('us-west-1') + ec2_conn = boto.ec2.connect_to_region("us-west-1") reservations = ec2_conn.get_all_reservations() len(reservations).should.equal(1) reservation = reservations[0] len(reservation.instances).should.equal(2) for instance in reservation.instances: - instance.tags['propagated-test-tag'].should.equal('propagated-test-tag-value') - instance.tags.keys().should_not.contain('not-propagated-test-tag') + instance.tags["propagated-test-tag"].should.equal("propagated-test-tag-value") + instance.tags.keys().should_not.contain("not-propagated-test-tag") @mock_autoscaling_deprecated() @@ -655,30 +600,23 @@ def test_autoscaling_group_update(): "my-as-group": { "Type": "AWS::AutoScaling::AutoScalingGroup", "Properties": { - "AvailabilityZones": ['us-west-1'], + "AvailabilityZones": ["us-west-1"], "LaunchConfigurationName": {"Ref": "my-launch-config"}, "MinSize": "2", "MaxSize": "2", - "DesiredCapacity": "2" + "DesiredCapacity": "2", }, }, - "my-launch-config": { "Type": "AWS::AutoScaling::LaunchConfiguration", - "Properties": { - "ImageId": "ami-1234abcd", - "UserData": "some user data", - } + "Properties": {"ImageId": "ami-1234abcd", "UserData": "some user data"}, }, }, } asg_template_json = json.dumps(asg_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "asg_stack", - template_body=asg_template_json, - ) + conn.create_stack("asg_stack", template_body=asg_template_json) autoscale_conn = boto.ec2.autoscale.connect_to_region("us-west-1") asg = autoscale_conn.get_all_groups()[0] @@ -686,37 +624,38 @@ def test_autoscaling_group_update(): asg.max_size.should.equal(2) asg.desired_capacity.should.equal(2) - asg_template['Resources']['my-as-group']['Properties']['MaxSize'] = 3 - asg_template['Resources']['my-as-group']['Properties']['Tags'] = [ + asg_template["Resources"]["my-as-group"]["Properties"]["MaxSize"] = 3 + asg_template["Resources"]["my-as-group"]["Properties"]["Tags"] = [ { - "Key": "propagated-test-tag", "Value": "propagated-test-tag-value", - "PropagateAtLaunch": True}, + "Key": "propagated-test-tag", + "Value": "propagated-test-tag-value", + "PropagateAtLaunch": True, + }, { "Key": "not-propagated-test-tag", "Value": "not-propagated-test-tag-value", - "PropagateAtLaunch": False - } + "PropagateAtLaunch": False, + }, ] asg_template_json = json.dumps(asg_template) - conn.update_stack( - "asg_stack", - template_body=asg_template_json, - ) + conn.update_stack("asg_stack", template_body=asg_template_json) asg = autoscale_conn.get_all_groups()[0] asg.min_size.should.equal(2) asg.max_size.should.equal(3) asg.desired_capacity.should.equal(2) # confirm the instances were created with the right tags - ec2_conn = boto.ec2.connect_to_region('us-west-1') + ec2_conn = boto.ec2.connect_to_region("us-west-1") reservations = ec2_conn.get_all_reservations() running_instance_count = 0 for res in reservations: for instance in res.instances: - if instance.state == 'running': + if instance.state == "running": running_instance_count += 1 - instance.tags['propagated-test-tag'].should.equal('propagated-test-tag-value') - instance.tags.keys().should_not.contain('not-propagated-test-tag') + instance.tags["propagated-test-tag"].should.equal( + "propagated-test-tag-value" + ) + instance.tags.keys().should_not.contain("not-propagated-test-tag") running_instance_count.should.equal(2) @@ -726,20 +665,18 @@ def test_vpc_single_instance_in_subnet(): template_json = json.dumps(vpc_single_instance_in_subnet.template) conn = boto.cloudformation.connect_to_region("us-west-1") conn.create_stack( - "test_stack", - template_body=template_json, - parameters=[("KeyName", "my_key")], + "test_stack", template_body=template_json, parameters=[("KeyName", "my_key")] ) vpc_conn = boto.vpc.connect_to_region("us-west-1") - vpc = vpc_conn.get_all_vpcs(filters={'cidrBlock': '10.0.0.0/16'})[0] + vpc = vpc_conn.get_all_vpcs(filters={"cidrBlock": "10.0.0.0/16"})[0] vpc.cidr_block.should.equal("10.0.0.0/16") # Add this once we implement the endpoint # vpc_conn.get_all_internet_gateways().should.have.length_of(1) - subnet = vpc_conn.get_all_subnets(filters={'vpcId': vpc.id})[0] + subnet = vpc_conn.get_all_subnets(filters={"vpcId": vpc.id})[0] subnet.vpc_id.should.equal(vpc.id) ec2_conn = boto.ec2.connect_to_region("us-west-1") @@ -748,28 +685,32 @@ def test_vpc_single_instance_in_subnet(): instance.tags["Foo"].should.equal("Bar") # Check that the EIP is attached the the EC2 instance eip = ec2_conn.get_all_addresses()[0] - eip.domain.should.equal('vpc') + eip.domain.should.equal("vpc") eip.instance_id.should.equal(instance.id) - security_group = ec2_conn.get_all_security_groups( - filters={'vpc_id': [vpc.id]})[0] + security_group = ec2_conn.get_all_security_groups(filters={"vpc_id": [vpc.id]})[0] security_group.vpc_id.should.equal(vpc.id) stack = conn.describe_stacks()[0] - vpc.tags.should.have.key('Application').which.should.equal(stack.stack_id) + vpc.tags.should.have.key("Application").which.should.equal(stack.stack_id) resources = stack.describe_resources() vpc_resource = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::VPC'][0] + resource for resource in resources if resource.resource_type == "AWS::EC2::VPC" + ][0] vpc_resource.physical_resource_id.should.equal(vpc.id) subnet_resource = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::Subnet'][0] + resource + for resource in resources + if resource.resource_type == "AWS::EC2::Subnet" + ][0] subnet_resource.physical_resource_id.should.equal(subnet.id) eip_resource = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::EIP'][0] + resource for resource in resources if resource.resource_type == "AWS::EC2::EIP" + ][0] eip_resource.physical_resource_id.should.equal(eip.public_ip) @@ -779,39 +720,45 @@ def test_vpc_single_instance_in_subnet(): def test_rds_db_parameter_groups(): ec2_conn = boto3.client("ec2", region_name="us-west-1") ec2_conn.create_security_group( - GroupName='application', Description='Our Application Group') + GroupName="application", Description="Our Application Group" + ) template_json = json.dumps(rds_mysql_with_db_parameter_group.template) - cf_conn = boto3.client('cloudformation', 'us-west-1') + cf_conn = boto3.client("cloudformation", "us-west-1") cf_conn.create_stack( StackName="test_stack", TemplateBody=template_json, - Parameters=[{'ParameterKey': key, 'ParameterValue': value} for - key, value in [ - ("DBInstanceIdentifier", "master_db"), - ("DBName", "my_db"), - ("DBUser", "my_user"), - ("DBPassword", "my_password"), - ("DBAllocatedStorage", "20"), - ("DBInstanceClass", "db.m1.medium"), - ("EC2SecurityGroup", "application"), - ("MultiAZ", "true"), - ] - ], + Parameters=[ + {"ParameterKey": key, "ParameterValue": value} + for key, value in [ + ("DBInstanceIdentifier", "master_db"), + ("DBName", "my_db"), + ("DBUser", "my_user"), + ("DBPassword", "my_password"), + ("DBAllocatedStorage", "20"), + ("DBInstanceClass", "db.m1.medium"), + ("EC2SecurityGroup", "application"), + ("MultiAZ", "true"), + ] + ], ) - rds_conn = boto3.client('rds', region_name="us-west-1") + rds_conn = boto3.client("rds", region_name="us-west-1") db_parameter_groups = rds_conn.describe_db_parameter_groups() - len(db_parameter_groups['DBParameterGroups']).should.equal(1) - db_parameter_group_name = db_parameter_groups[ - 'DBParameterGroups'][0]['DBParameterGroupName'] + len(db_parameter_groups["DBParameterGroups"]).should.equal(1) + db_parameter_group_name = db_parameter_groups["DBParameterGroups"][0][ + "DBParameterGroupName" + ] found_cloudformation_set_parameter = False - for db_parameter in rds_conn.describe_db_parameters(DBParameterGroupName=db_parameter_group_name)[ - 'Parameters']: - if db_parameter['ParameterName'] == 'BACKLOG_QUEUE_LIMIT' and db_parameter[ - 'ParameterValue'] == '2048': + for db_parameter in rds_conn.describe_db_parameters( + DBParameterGroupName=db_parameter_group_name + )["Parameters"]: + if ( + db_parameter["ParameterName"] == "BACKLOG_QUEUE_LIMIT" + and db_parameter["ParameterValue"] == "2048" + ): found_cloudformation_set_parameter = True found_cloudformation_set_parameter.should.equal(True) @@ -822,7 +769,7 @@ def test_rds_db_parameter_groups(): @mock_rds_deprecated() def test_rds_mysql_with_read_replica(): ec2_conn = boto.ec2.connect_to_region("us-west-1") - ec2_conn.create_security_group('application', 'Our Application Group') + ec2_conn.create_security_group("application", "Our Application Group") template_json = json.dumps(rds_mysql_with_read_replica.template) conn = boto.cloudformation.connect_to_region("us-west-1") @@ -893,43 +840,33 @@ def test_rds_mysql_with_read_replica_in_vpc(): def test_iam_roles(): iam_template = { "AWSTemplateFormatVersion": "2010-09-09", - "Resources": { - "my-launch-config": { "Properties": { "IamInstanceProfile": {"Ref": "my-instance-profile-with-path"}, "ImageId": "ami-1234abcd", }, - "Type": "AWS::AutoScaling::LaunchConfiguration" + "Type": "AWS::AutoScaling::LaunchConfiguration", }, "my-instance-profile-with-path": { "Properties": { "Path": "my-path", "Roles": [{"Ref": "my-role-with-path"}], }, - "Type": "AWS::IAM::InstanceProfile" + "Type": "AWS::IAM::InstanceProfile", }, "my-instance-profile-no-path": { - "Properties": { - "Roles": [{"Ref": "my-role-no-path"}], - }, - "Type": "AWS::IAM::InstanceProfile" + "Properties": {"Roles": [{"Ref": "my-role-no-path"}]}, + "Type": "AWS::IAM::InstanceProfile", }, "my-role-with-path": { "Properties": { "AssumeRolePolicyDocument": { "Statement": [ { - "Action": [ - "sts:AssumeRole" - ], + "Action": ["sts:AssumeRole"], "Effect": "Allow", - "Principal": { - "Service": [ - "ec2.amazonaws.com" - ] - } + "Principal": {"Service": ["ec2.amazonaws.com"]}, } ] }, @@ -942,102 +879,90 @@ def test_iam_roles(): "Action": [ "ec2:CreateTags", "ec2:DescribeInstances", - "ec2:DescribeTags" + "ec2:DescribeTags", ], "Effect": "Allow", - "Resource": [ - "*" - ] + "Resource": ["*"], } ], - "Version": "2012-10-17" + "Version": "2012-10-17", }, - "PolicyName": "EC2_Tags" + "PolicyName": "EC2_Tags", }, { "PolicyDocument": { "Statement": [ { - "Action": [ - "sqs:*" - ], + "Action": ["sqs:*"], "Effect": "Allow", - "Resource": [ - "*" - ] + "Resource": ["*"], } ], - "Version": "2012-10-17" + "Version": "2012-10-17", }, - "PolicyName": "SQS" + "PolicyName": "SQS", }, - ] + ], }, - "Type": "AWS::IAM::Role" + "Type": "AWS::IAM::Role", }, "my-role-no-path": { "Properties": { "AssumeRolePolicyDocument": { "Statement": [ { - "Action": [ - "sts:AssumeRole" - ], + "Action": ["sts:AssumeRole"], "Effect": "Allow", - "Principal": { - "Service": [ - "ec2.amazonaws.com" - ] - } + "Principal": {"Service": ["ec2.amazonaws.com"]}, } ] - }, + } }, - "Type": "AWS::IAM::Role" - } - } + "Type": "AWS::IAM::Role", + }, + }, } iam_template_json = json.dumps(iam_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=iam_template_json, - ) + conn.create_stack("test_stack", template_body=iam_template_json) iam_conn = boto.iam.connect_to_region("us-west-1") - role_results = iam_conn.list_roles()['list_roles_response'][ - 'list_roles_result']['roles'] + role_results = iam_conn.list_roles()["list_roles_response"]["list_roles_result"][ + "roles" + ] role_name_to_id = {} for role_result in role_results: role = iam_conn.get_role(role_result.role_name) role.role_name.should.contain("my-role") - if 'with-path' in role.role_name: - role_name_to_id['with-path'] = role.role_id + if "with-path" in role.role_name: + role_name_to_id["with-path"] = role.role_id role.path.should.equal("my-path") else: - role_name_to_id['no-path'] = role.role_id - role.role_name.should.contain('no-path') - role.path.should.equal('/') + role_name_to_id["no-path"] = role.role_id + role.role_name.should.contain("no-path") + role.path.should.equal("/") instance_profile_responses = iam_conn.list_instance_profiles()[ - 'list_instance_profiles_response']['list_instance_profiles_result']['instance_profiles'] + "list_instance_profiles_response" + ]["list_instance_profiles_result"]["instance_profiles"] instance_profile_responses.should.have.length_of(2) instance_profile_names = [] for instance_profile_response in instance_profile_responses: - instance_profile = iam_conn.get_instance_profile(instance_profile_response.instance_profile_name) + instance_profile = iam_conn.get_instance_profile( + instance_profile_response.instance_profile_name + ) instance_profile_names.append(instance_profile.instance_profile_name) - instance_profile.instance_profile_name.should.contain( - "my-instance-profile") + instance_profile.instance_profile_name.should.contain("my-instance-profile") if "with-path" in instance_profile.instance_profile_name: instance_profile.path.should.equal("my-path") - instance_profile.role_id.should.equal(role_name_to_id['with-path']) + instance_profile.role_id.should.equal(role_name_to_id["with-path"]) else: - instance_profile.instance_profile_name.should.contain('no-path') - instance_profile.role_id.should.equal(role_name_to_id['no-path']) - instance_profile.path.should.equal('/') + instance_profile.instance_profile_name.should.contain("no-path") + instance_profile.role_id.should.equal(role_name_to_id["no-path"]) + instance_profile.path.should.equal("/") autoscale_conn = boto.ec2.autoscale.connect_to_region("us-west-1") launch_config = autoscale_conn.get_all_launch_configurations()[0] @@ -1046,12 +971,20 @@ def test_iam_roles(): stack = conn.describe_stacks()[0] resources = stack.describe_resources() instance_profile_resources = [ - resource for resource in resources if resource.resource_type == 'AWS::IAM::InstanceProfile'] - {ip.physical_resource_id for ip in instance_profile_resources}.should.equal(set(instance_profile_names)) + resource + for resource in resources + if resource.resource_type == "AWS::IAM::InstanceProfile" + ] + {ip.physical_resource_id for ip in instance_profile_resources}.should.equal( + set(instance_profile_names) + ) role_resources = [ - resource for resource in resources if resource.resource_type == 'AWS::IAM::Role'] - {r.physical_resource_id for r in role_resources}.should.equal(set(role_name_to_id.values())) + resource for resource in resources if resource.resource_type == "AWS::IAM::Role" + ] + {r.physical_resource_id for r in role_resources}.should.equal( + set(role_name_to_id.values()) + ) @mock_ec2_deprecated() @@ -1060,9 +993,7 @@ def test_single_instance_with_ebs_volume(): template_json = json.dumps(single_instance_with_ebs_volume.template) conn = boto.cloudformation.connect_to_region("us-west-1") conn.create_stack( - "test_stack", - template_body=template_json, - parameters=[("KeyName", "key_name")] + "test_stack", template_body=template_json, parameters=[("KeyName", "key_name")] ) ec2_conn = boto.ec2.connect_to_region("us-west-1") @@ -1071,15 +1002,19 @@ def test_single_instance_with_ebs_volume(): volumes = ec2_conn.get_all_volumes() # Grab the mounted drive - volume = [ - volume for volume in volumes if volume.attach_data.device == '/dev/sdh'][0] - volume.volume_state().should.equal('in-use') + volume = [volume for volume in volumes if volume.attach_data.device == "/dev/sdh"][ + 0 + ] + volume.volume_state().should.equal("in-use") volume.attach_data.instance_id.should.equal(ec2_instance.id) stack = conn.describe_stacks()[0] resources = stack.describe_resources() ebs_volumes = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::Volume'] + resource + for resource in resources + if resource.resource_type == "AWS::EC2::Volume" + ] ebs_volumes[0].physical_resource_id.should.equal(volume.id) @@ -1088,8 +1023,7 @@ def test_create_template_without_required_param(): template_json = json.dumps(single_instance_with_ebs_volume.template) conn = boto.cloudformation.connect_to_region("us-west-1") conn.create_stack.when.called_with( - "test_stack", - template_body=template_json, + "test_stack", template_body=template_json ).should.throw(BotoServerError) @@ -1105,7 +1039,8 @@ def test_classic_eip(): stack = conn.describe_stacks()[0] resources = stack.describe_resources() cfn_eip = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::EIP'][0] + resource for resource in resources if resource.resource_type == "AWS::EC2::EIP" + ][0] cfn_eip.physical_resource_id.should.equal(eip.public_ip) @@ -1121,7 +1056,8 @@ def test_vpc_eip(): stack = conn.describe_stacks()[0] resources = stack.describe_resources() cfn_eip = [ - resource for resource in resources if resource.resource_type == 'AWS::EC2::EIP'][0] + resource for resource in resources if resource.resource_type == "AWS::EC2::EIP" + ][0] cfn_eip.physical_resource_id.should.equal(eip.public_ip) @@ -1136,7 +1072,7 @@ def test_fn_join(): stack = conn.describe_stacks()[0] fn_join_output = stack.outputs[0] - fn_join_output.value.should.equal('test eip:{0}'.format(eip.public_ip)) + fn_join_output.value.should.equal("test eip:{0}".format(eip.public_ip)) @mock_cloudformation_deprecated() @@ -1145,23 +1081,15 @@ def test_conditional_resources(): sqs_template = { "AWSTemplateFormatVersion": "2010-09-09", "Parameters": { - "EnvType": { - "Description": "Environment type.", - "Type": "String", - } - }, - "Conditions": { - "CreateQueue": {"Fn::Equals": [{"Ref": "EnvType"}, "prod"]} + "EnvType": {"Description": "Environment type.", "Type": "String"} }, + "Conditions": {"CreateQueue": {"Fn::Equals": [{"Ref": "EnvType"}, "prod"]}}, "Resources": { "QueueGroup": { "Condition": "CreateQueue", "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, + } }, } sqs_template_json = json.dumps(sqs_template) @@ -1190,43 +1118,30 @@ def test_conditional_resources(): def test_conditional_if_handling(): dummy_template = { "AWSTemplateFormatVersion": "2010-09-09", - "Conditions": { - "EnvEqualsPrd": { - "Fn::Equals": [ - { - "Ref": "ENV" - }, - "prd" - ] - } - }, + "Conditions": {"EnvEqualsPrd": {"Fn::Equals": [{"Ref": "ENV"}, "prd"]}}, "Parameters": { "ENV": { "Default": "dev", "Description": "Deployment environment for the stack (dev/prd)", - "Type": "String" - }, + "Type": "String", + } }, "Description": "Stack 1", "Resources": { "App1": { "Properties": { "ImageId": { - "Fn::If": [ - "EnvEqualsPrd", - "ami-00000000", - "ami-ffffffff" - ] - }, + "Fn::If": ["EnvEqualsPrd", "ami-00000000", "ami-ffffffff"] + } }, - "Type": "AWS::EC2::Instance" - }, - } + "Type": "AWS::EC2::Instance", + } + }, } dummy_template_json = json.dumps(dummy_template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack('test_stack1', template_body=dummy_template_json) + conn.create_stack("test_stack1", template_body=dummy_template_json) ec2_conn = boto.ec2.connect_to_region("us-west-1") reservation = ec2_conn.get_all_instances()[0] ec2_instance = reservation.instances[0] @@ -1235,7 +1150,8 @@ def test_conditional_if_handling(): conn = boto.cloudformation.connect_to_region("us-west-2") conn.create_stack( - 'test_stack1', template_body=dummy_template_json, parameters=[("ENV", "prd")]) + "test_stack1", template_body=dummy_template_json, parameters=[("ENV", "prd")] + ) ec2_conn = boto.ec2.connect_to_region("us-west-2") reservation = ec2_conn.get_all_instances()[0] ec2_instance = reservation.instances[0] @@ -1253,7 +1169,7 @@ def test_cloudformation_mapping(): "us-west-1": {"32": "ami-c9c7978c", "64": "ami-cfc7978a"}, "eu-west-1": {"32": "ami-37c2f643", "64": "ami-31c2f645"}, "ap-southeast-1": {"32": "ami-66f28c34", "64": "ami-60f28c32"}, - "ap-northeast-1": {"32": "ami-9c03a89d", "64": "ami-a003a8a1"} + "ap-northeast-1": {"32": "ami-9c03a89d", "64": "ami-a003a8a1"}, } }, "Resources": { @@ -1263,24 +1179,24 @@ def test_cloudformation_mapping(): "ImageId": { "Fn::FindInMap": ["RegionMap", {"Ref": "AWS::Region"}, "32"] }, - "InstanceType": "m1.small" + "InstanceType": "m1.small", }, "Type": "AWS::EC2::Instance", - }, + } }, } dummy_template_json = json.dumps(dummy_template) conn = boto.cloudformation.connect_to_region("us-east-1") - conn.create_stack('test_stack1', template_body=dummy_template_json) + conn.create_stack("test_stack1", template_body=dummy_template_json) ec2_conn = boto.ec2.connect_to_region("us-east-1") reservation = ec2_conn.get_all_instances()[0] ec2_instance = reservation.instances[0] ec2_instance.image_id.should.equal("ami-6411e20d") conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack('test_stack1', template_body=dummy_template_json) + conn.create_stack("test_stack1", template_body=dummy_template_json) ec2_conn = boto.ec2.connect_to_region("us-west-1") reservation = ec2_conn.get_all_instances()[0] ec2_instance = reservation.instances[0] @@ -1294,42 +1210,39 @@ def test_route53_roundrobin(): template_json = json.dumps(route53_roundrobin.template) conn = boto.cloudformation.connect_to_region("us-west-1") - stack = conn.create_stack( - "test_stack", - template_body=template_json, - ) + stack = conn.create_stack("test_stack", template_body=template_json) - zones = route53_conn.get_all_hosted_zones()['ListHostedZonesResponse'][ - 'HostedZones'] + zones = route53_conn.get_all_hosted_zones()["ListHostedZonesResponse"][ + "HostedZones" + ] list(zones).should.have.length_of(1) - zone_id = zones[0]['Id'] - zone_id = zone_id.split('/') + zone_id = zones[0]["Id"] + zone_id = zone_id.split("/") zone_id = zone_id[2] rrsets = route53_conn.get_all_rrsets(zone_id) rrsets.hosted_zone_id.should.equal(zone_id) rrsets.should.have.length_of(2) record_set1 = rrsets[0] - record_set1.name.should.equal('test_stack.us-west-1.my_zone.') + record_set1.name.should.equal("test_stack.us-west-1.my_zone.") record_set1.identifier.should.equal("test_stack AWS") - record_set1.type.should.equal('CNAME') - record_set1.ttl.should.equal('900') - record_set1.weight.should.equal('3') + record_set1.type.should.equal("CNAME") + record_set1.ttl.should.equal("900") + record_set1.weight.should.equal("3") record_set1.resource_records[0].should.equal("aws.amazon.com") record_set2 = rrsets[1] - record_set2.name.should.equal('test_stack.us-west-1.my_zone.') + record_set2.name.should.equal("test_stack.us-west-1.my_zone.") record_set2.identifier.should.equal("test_stack Amazon") - record_set2.type.should.equal('CNAME') - record_set2.ttl.should.equal('900') - record_set2.weight.should.equal('1') + record_set2.type.should.equal("CNAME") + record_set2.ttl.should.equal("900") + record_set2.weight.should.equal("1") record_set2.resource_records[0].should.equal("www.amazon.com") stack = conn.describe_stacks()[0] output = stack.outputs[0] - output.key.should.equal('DomainName') - output.value.should.equal( - 'arn:aws:route53:::hostedzone/{0}'.format(zone_id)) + output.key.should.equal("DomainName") + output.value.should.equal("arn:aws:route53:::hostedzone/{0}".format(zone_id)) @mock_cloudformation_deprecated() @@ -1341,28 +1254,26 @@ def test_route53_ec2_instance_with_public_ip(): template_json = json.dumps(route53_ec2_instance_with_public_ip.template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=template_json, - ) + conn.create_stack("test_stack", template_body=template_json) instance_id = ec2_conn.get_all_reservations()[0].instances[0].id - zones = route53_conn.get_all_hosted_zones()['ListHostedZonesResponse'][ - 'HostedZones'] + zones = route53_conn.get_all_hosted_zones()["ListHostedZonesResponse"][ + "HostedZones" + ] list(zones).should.have.length_of(1) - zone_id = zones[0]['Id'] - zone_id = zone_id.split('/') + zone_id = zones[0]["Id"] + zone_id = zone_id.split("/") zone_id = zone_id[2] rrsets = route53_conn.get_all_rrsets(zone_id) rrsets.should.have.length_of(1) record_set1 = rrsets[0] - record_set1.name.should.equal('{0}.us-west-1.my_zone.'.format(instance_id)) + record_set1.name.should.equal("{0}.us-west-1.my_zone.".format(instance_id)) record_set1.identifier.should.equal(None) - record_set1.type.should.equal('A') - record_set1.ttl.should.equal('900') + record_set1.type.should.equal("A") + record_set1.ttl.should.equal("900") record_set1.weight.should.equal(None) record_set1.resource_records[0].should.equal("10.0.0.25") @@ -1374,17 +1285,15 @@ def test_route53_associate_health_check(): template_json = json.dumps(route53_health_check.template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=template_json, - ) + conn.create_stack("test_stack", template_body=template_json) - checks = route53_conn.get_list_health_checks()['ListHealthChecksResponse'][ - 'HealthChecks'] + checks = route53_conn.get_list_health_checks()["ListHealthChecksResponse"][ + "HealthChecks" + ] list(checks).should.have.length_of(1) check = checks[0] - health_check_id = check['Id'] - config = check['HealthCheckConfig'] + health_check_id = check["Id"] + config = check["HealthCheckConfig"] config["FailureThreshold"].should.equal("3") config["IPAddress"].should.equal("10.0.0.4") config["Port"].should.equal("80") @@ -1392,11 +1301,12 @@ def test_route53_associate_health_check(): config["ResourcePath"].should.equal("/") config["Type"].should.equal("HTTP") - zones = route53_conn.get_all_hosted_zones()['ListHostedZonesResponse'][ - 'HostedZones'] + zones = route53_conn.get_all_hosted_zones()["ListHostedZonesResponse"][ + "HostedZones" + ] list(zones).should.have.length_of(1) - zone_id = zones[0]['Id'] - zone_id = zone_id.split('/') + zone_id = zones[0]["Id"] + zone_id = zone_id.split("/") zone_id = zone_id[2] rrsets = route53_conn.get_all_rrsets(zone_id) @@ -1413,16 +1323,14 @@ def test_route53_with_update(): template_json = json.dumps(route53_health_check.template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) - zones = route53_conn.get_all_hosted_zones()['ListHostedZonesResponse'][ - 'HostedZones'] + zones = route53_conn.get_all_hosted_zones()["ListHostedZonesResponse"][ + "HostedZones" + ] list(zones).should.have.length_of(1) - zone_id = zones[0]['Id'] - zone_id = zone_id.split('/') + zone_id = zones[0]["Id"] + zone_id = zone_id.split("/") zone_id = zone_id[2] rrsets = route53_conn.get_all_rrsets(zone_id) @@ -1431,19 +1339,18 @@ def test_route53_with_update(): record_set = rrsets[0] record_set.resource_records.should.equal(["my.example.com"]) - route53_health_check.template['Resources']['myDNSRecord'][ - 'Properties']['ResourceRecords'] = ["my_other.example.com"] + route53_health_check.template["Resources"]["myDNSRecord"]["Properties"][ + "ResourceRecords" + ] = ["my_other.example.com"] template_json = json.dumps(route53_health_check.template) - cf_conn.update_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.update_stack("test_stack", template_body=template_json) - zones = route53_conn.get_all_hosted_zones()['ListHostedZonesResponse'][ - 'HostedZones'] + zones = route53_conn.get_all_hosted_zones()["ListHostedZonesResponse"][ + "HostedZones" + ] list(zones).should.have.length_of(1) - zone_id = zones[0]['Id'] - zone_id = zone_id.split('/') + zone_id = zones[0]["Id"] + zone_id = zone_id.split("/") zone_id = zone_id[2] rrsets = route53_conn.get_all_rrsets(zone_id) @@ -1463,37 +1370,32 @@ def test_sns_topic(): "Type": "AWS::SNS::Topic", "Properties": { "Subscription": [ - {"Endpoint": "https://example.com", "Protocol": "https"}, + {"Endpoint": "https://example.com", "Protocol": "https"} ], "TopicName": "my_topics", - } + }, } }, "Outputs": { - "topic_name": { - "Value": {"Fn::GetAtt": ["MySNSTopic", "TopicName"]} - }, - "topic_arn": { - "Value": {"Ref": "MySNSTopic"} - }, - } + "topic_name": {"Value": {"Fn::GetAtt": ["MySNSTopic", "TopicName"]}}, + "topic_arn": {"Value": {"Ref": "MySNSTopic"}}, + }, } template_json = json.dumps(dummy_template) conn = boto.cloudformation.connect_to_region("us-west-1") - stack = conn.create_stack( - "test_stack", - template_body=template_json, - ) + stack = conn.create_stack("test_stack", template_body=template_json) sns_conn = boto.sns.connect_to_region("us-west-1") - topics = sns_conn.get_all_topics()["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"] + topics = sns_conn.get_all_topics()["ListTopicsResponse"]["ListTopicsResult"][ + "Topics" + ] topics.should.have.length_of(1) - topic_arn = topics[0]['TopicArn'] + topic_arn = topics[0]["TopicArn"] topic_arn.should.contain("my_topics") subscriptions = sns_conn.get_all_subscriptions()["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["Subscriptions"] + "ListSubscriptionsResult" + ]["Subscriptions"] subscriptions.should.have.length_of(1) subscription = subscriptions[0] subscription["TopicArn"].should.equal(topic_arn) @@ -1502,9 +1404,9 @@ def test_sns_topic(): subscription["Endpoint"].should.equal("https://example.com") stack = conn.describe_stacks()[0] - topic_name_output = [x for x in stack.outputs if x.key == 'topic_name'][0] + topic_name_output = [x for x in stack.outputs if x.key == "topic_name"][0] topic_name_output.value.should.equal("my_topics") - topic_arn_output = [x for x in stack.outputs if x.key == 'topic_arn'][0] + topic_arn_output = [x for x in stack.outputs if x.key == "topic_arn"][0] topic_arn_output.value.should.equal(topic_arn) @@ -1514,44 +1416,33 @@ def test_vpc_gateway_attachment_creation_should_attach_itself_to_vpc(): template = { "AWSTemplateFormatVersion": "2010-09-09", "Resources": { - "internetgateway": { - "Type": "AWS::EC2::InternetGateway" - }, + "internetgateway": {"Type": "AWS::EC2::InternetGateway"}, "testvpc": { "Type": "AWS::EC2::VPC", "Properties": { "CidrBlock": "10.0.0.0/16", "EnableDnsHostnames": "true", "EnableDnsSupport": "true", - "InstanceTenancy": "default" + "InstanceTenancy": "default", }, }, "vpcgatewayattachment": { "Type": "AWS::EC2::VPCGatewayAttachment", "Properties": { - "InternetGatewayId": { - "Ref": "internetgateway" - }, - "VpcId": { - "Ref": "testvpc" - } + "InternetGatewayId": {"Ref": "internetgateway"}, + "VpcId": {"Ref": "testvpc"}, }, }, - } + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) vpc_conn = boto.vpc.connect_to_region("us-west-1") - vpc = vpc_conn.get_all_vpcs(filters={'cidrBlock': '10.0.0.0/16'})[0] - igws = vpc_conn.get_all_internet_gateways( - filters={'attachment.vpc-id': vpc.id} - ) + vpc = vpc_conn.get_all_vpcs(filters={"cidrBlock": "10.0.0.0/16"})[0] + igws = vpc_conn.get_all_internet_gateways(filters={"attachment.vpc-id": vpc.id}) igws.should.have.length_of(1) @@ -1567,20 +1458,14 @@ def test_vpc_peering_creation(): "Resources": { "vpcpeeringconnection": { "Type": "AWS::EC2::VPCPeeringConnection", - "Properties": { - "PeerVpcId": peer_vpc.id, - "VpcId": vpc_source.id, - } - }, - } + "Properties": {"PeerVpcId": peer_vpc.id, "VpcId": vpc_source.id}, + } + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) peering_connections = vpc_conn.get_all_vpc_peering_connections() peering_connections.should.have.length_of(1) @@ -1596,24 +1481,14 @@ def test_multiple_security_group_ingress_separate_from_security_group_by_id(): "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupDescription": "test security group", - "Tags": [ - { - "Key": "sg-name", - "Value": "sg1" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg1"}], }, }, "test-security-group2": { "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupDescription": "test security group", - "Tags": [ - { - "Key": "sg-name", - "Value": "sg2" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg2"}], }, }, "test-sg-ingress": { @@ -1624,39 +1499,36 @@ def test_multiple_security_group_ingress_separate_from_security_group_by_id(): "FromPort": "80", "ToPort": "8080", "SourceSecurityGroupId": {"Ref": "test-security-group2"}, - } - } - } + }, + }, + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) ec2_conn = boto.ec2.connect_to_region("us-west-1") - security_group1 = ec2_conn.get_all_security_groups( - filters={"tag:sg-name": "sg1"})[0] - security_group2 = ec2_conn.get_all_security_groups( - filters={"tag:sg-name": "sg2"})[0] + security_group1 = ec2_conn.get_all_security_groups(filters={"tag:sg-name": "sg1"})[ + 0 + ] + security_group2 = ec2_conn.get_all_security_groups(filters={"tag:sg-name": "sg2"})[ + 0 + ] security_group1.rules.should.have.length_of(1) security_group1.rules[0].grants.should.have.length_of(1) - security_group1.rules[0].grants[ - 0].group_id.should.equal(security_group2.id) - security_group1.rules[0].ip_protocol.should.equal('tcp') - security_group1.rules[0].from_port.should.equal('80') - security_group1.rules[0].to_port.should.equal('8080') + security_group1.rules[0].grants[0].group_id.should.equal(security_group2.id) + security_group1.rules[0].ip_protocol.should.equal("tcp") + security_group1.rules[0].from_port.should.equal("80") + security_group1.rules[0].to_port.should.equal("8080") @mock_cloudformation_deprecated @mock_ec2_deprecated def test_security_group_ingress_separate_from_security_group_by_id(): ec2_conn = boto.ec2.connect_to_region("us-west-1") - ec2_conn.create_security_group( - "test-security-group1", "test security group") + ec2_conn.create_security_group("test-security-group1", "test security group") template = { "AWSTemplateFormatVersion": "2010-09-09", @@ -1665,12 +1537,7 @@ def test_security_group_ingress_separate_from_security_group_by_id(): "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupDescription": "test security group", - "Tags": [ - { - "Key": "sg-name", - "Value": "sg2" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg2"}], }, }, "test-sg-ingress": { @@ -1681,29 +1548,27 @@ def test_security_group_ingress_separate_from_security_group_by_id(): "FromPort": "80", "ToPort": "8080", "SourceSecurityGroupId": {"Ref": "test-security-group2"}, - } - } - } + }, + }, + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) security_group1 = ec2_conn.get_all_security_groups( - groupnames=["test-security-group1"])[0] - security_group2 = ec2_conn.get_all_security_groups( - filters={"tag:sg-name": "sg2"})[0] + groupnames=["test-security-group1"] + )[0] + security_group2 = ec2_conn.get_all_security_groups(filters={"tag:sg-name": "sg2"})[ + 0 + ] security_group1.rules.should.have.length_of(1) security_group1.rules[0].grants.should.have.length_of(1) - security_group1.rules[0].grants[ - 0].group_id.should.equal(security_group2.id) - security_group1.rules[0].ip_protocol.should.equal('tcp') - security_group1.rules[0].from_port.should.equal('80') - security_group1.rules[0].to_port.should.equal('8080') + security_group1.rules[0].grants[0].group_id.should.equal(security_group2.id) + security_group1.rules[0].ip_protocol.should.equal("tcp") + security_group1.rules[0].from_port.should.equal("80") + security_group1.rules[0].to_port.should.equal("8080") @mock_cloudformation_deprecated @@ -1720,12 +1585,7 @@ def test_security_group_ingress_separate_from_security_group_by_id_using_vpc(): "Properties": { "GroupDescription": "test security group", "VpcId": vpc.id, - "Tags": [ - { - "Key": "sg-name", - "Value": "sg1" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg1"}], }, }, "test-security-group2": { @@ -1733,12 +1593,7 @@ def test_security_group_ingress_separate_from_security_group_by_id_using_vpc(): "Properties": { "GroupDescription": "test security group", "VpcId": vpc.id, - "Tags": [ - { - "Key": "sg-name", - "Value": "sg2" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg2"}], }, }, "test-sg-ingress": { @@ -1750,29 +1605,27 @@ def test_security_group_ingress_separate_from_security_group_by_id_using_vpc(): "FromPort": "80", "ToPort": "8080", "SourceSecurityGroupId": {"Ref": "test-security-group2"}, - } - } - } + }, + }, + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) - security_group1 = vpc_conn.get_all_security_groups( - filters={"tag:sg-name": "sg1"})[0] - security_group2 = vpc_conn.get_all_security_groups( - filters={"tag:sg-name": "sg2"})[0] + cf_conn.create_stack("test_stack", template_body=template_json) + security_group1 = vpc_conn.get_all_security_groups(filters={"tag:sg-name": "sg1"})[ + 0 + ] + security_group2 = vpc_conn.get_all_security_groups(filters={"tag:sg-name": "sg2"})[ + 0 + ] security_group1.rules.should.have.length_of(1) security_group1.rules[0].grants.should.have.length_of(1) - security_group1.rules[0].grants[ - 0].group_id.should.equal(security_group2.id) - security_group1.rules[0].ip_protocol.should.equal('tcp') - security_group1.rules[0].from_port.should.equal('80') - security_group1.rules[0].to_port.should.equal('8080') + security_group1.rules[0].grants[0].group_id.should.equal(security_group2.id) + security_group1.rules[0].ip_protocol.should.equal("tcp") + security_group1.rules[0].from_port.should.equal("80") + security_group1.rules[0].to_port.should.equal("8080") @mock_cloudformation_deprecated @@ -1789,44 +1642,30 @@ def test_security_group_with_update(): "Properties": { "GroupDescription": "test security group", "VpcId": vpc1.id, - "Tags": [ - { - "Key": "sg-name", - "Value": "sg" - } - ] + "Tags": [{"Key": "sg-name", "Value": "sg"}], }, - }, - } + } + }, } template_json = json.dumps(template) cf_conn = boto.cloudformation.connect_to_region("us-west-1") - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) - security_group = vpc_conn.get_all_security_groups( - filters={"tag:sg-name": "sg"})[0] + cf_conn.create_stack("test_stack", template_body=template_json) + security_group = vpc_conn.get_all_security_groups(filters={"tag:sg-name": "sg"})[0] security_group.vpc_id.should.equal(vpc1.id) vpc2 = vpc_conn.create_vpc("10.1.0.0/16") - template['Resources'][ - 'test-security-group']['Properties']['VpcId'] = vpc2.id + template["Resources"]["test-security-group"]["Properties"]["VpcId"] = vpc2.id template_json = json.dumps(template) - cf_conn.update_stack( - "test_stack", - template_body=template_json, - ) - security_group = vpc_conn.get_all_security_groups( - filters={"tag:sg-name": "sg"})[0] + cf_conn.update_stack("test_stack", template_body=template_json) + security_group = vpc_conn.get_all_security_groups(filters={"tag:sg-name": "sg"})[0] security_group.vpc_id.should.equal(vpc2.id) @mock_cloudformation_deprecated @mock_ec2_deprecated def test_subnets_should_be_created_with_availability_zone(): - vpc_conn = boto.vpc.connect_to_region('us-west-1') + vpc_conn = boto.vpc.connect_to_region("us-west-1") vpc = vpc_conn.create_vpc("10.0.0.0/16") subnet_template = { @@ -1838,18 +1677,15 @@ def test_subnets_should_be_created_with_availability_zone(): "VpcId": vpc.id, "CidrBlock": "10.0.0.0/24", "AvailabilityZone": "us-west-1b", - } + }, } - } + }, } cf_conn = boto.cloudformation.connect_to_region("us-west-1") template_json = json.dumps(subnet_template) - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) - subnet = vpc_conn.get_all_subnets(filters={'cidrBlock': '10.0.0.0/24'})[0] - subnet.availability_zone.should.equal('us-west-1b') + cf_conn.create_stack("test_stack", template_body=template_json) + subnet = vpc_conn.get_all_subnets(filters={"cidrBlock": "10.0.0.0/24"})[0] + subnet.availability_zone.should.equal("us-west-1b") @mock_cloudformation_deprecated @@ -1867,71 +1703,53 @@ def test_datapipeline(): "Fields": [ { "Key": "failureAndRerunMode", - "StringValue": "CASCADE" - }, - { - "Key": "scheduleType", - "StringValue": "cron" - }, - { - "Key": "schedule", - "RefValue": "DefaultSchedule" + "StringValue": "CASCADE", }, + {"Key": "scheduleType", "StringValue": "cron"}, + {"Key": "schedule", "RefValue": "DefaultSchedule"}, { "Key": "pipelineLogUri", - "StringValue": "s3://bucket/logs" - }, - { - "Key": "type", - "StringValue": "Default" + "StringValue": "s3://bucket/logs", }, + {"Key": "type", "StringValue": "Default"}, ], "Id": "Default", - "Name": "Default" + "Name": "Default", }, { "Fields": [ { "Key": "startDateTime", - "StringValue": "1970-01-01T01:00:00" + "StringValue": "1970-01-01T01:00:00", }, - { - "Key": "period", - "StringValue": "1 Day" - }, - { - "Key": "type", - "StringValue": "Schedule" - } + {"Key": "period", "StringValue": "1 Day"}, + {"Key": "type", "StringValue": "Schedule"}, ], "Id": "DefaultSchedule", - "Name": "RunOnce" - } + "Name": "RunOnce", + }, ], - "PipelineTags": [] + "PipelineTags": [], }, - "Type": "AWS::DataPipeline::Pipeline" + "Type": "AWS::DataPipeline::Pipeline", } - } + }, } cf_conn = boto.cloudformation.connect_to_region("us-east-1") template_json = json.dumps(dp_template) - stack_id = cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + stack_id = cf_conn.create_stack("test_stack", template_body=template_json) - dp_conn = boto.datapipeline.connect_to_region('us-east-1') + dp_conn = boto.datapipeline.connect_to_region("us-east-1") data_pipelines = dp_conn.list_pipelines() - data_pipelines['pipelineIdList'].should.have.length_of(1) - data_pipelines['pipelineIdList'][0][ - 'name'].should.equal('testDataPipeline') + data_pipelines["pipelineIdList"].should.have.length_of(1) + data_pipelines["pipelineIdList"][0]["name"].should.equal("testDataPipeline") stack_resources = cf_conn.list_stack_resources(stack_id) stack_resources.should.have.length_of(1) stack_resources[0].physical_resource_id.should.equal( - data_pipelines['pipelineIdList'][0]['id']) + data_pipelines["pipelineIdList"][0]["id"] + ) @mock_cloudformation @@ -1957,45 +1775,40 @@ def lambda_handler(event, context): "MemorySize": 128, "Role": "test-role", "Runtime": "python2.7", - "Environment": { - "Variables": { - "TEST_ENV_KEY": "test-env-val", - } - }, - } + "Environment": {"Variables": {"TEST_ENV_KEY": "test-env-val"}}, + }, } - } + }, } template_json = json.dumps(template) - cf_conn = boto3.client('cloudformation', 'us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=template_json, - ) + cf_conn = boto3.client("cloudformation", "us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=template_json) - conn = boto3.client('lambda', 'us-east-1') + conn = boto3.client("lambda", "us-east-1") result = conn.list_functions() - result['Functions'].should.have.length_of(1) - result['Functions'][0]['Description'].should.equal('Test function') - result['Functions'][0]['Handler'].should.equal('lambda_function.handler') - result['Functions'][0]['MemorySize'].should.equal(128) - result['Functions'][0]['Role'].should.equal('test-role') - result['Functions'][0]['Runtime'].should.equal('python2.7') - result['Functions'][0]['Environment'].should.equal({ - "Variables": {"TEST_ENV_KEY": "test-env-val"} - }) + result["Functions"].should.have.length_of(1) + result["Functions"][0]["Description"].should.equal("Test function") + result["Functions"][0]["Handler"].should.equal("lambda_function.handler") + result["Functions"][0]["MemorySize"].should.equal(128) + result["Functions"][0]["Role"].should.equal("test-role") + result["Functions"][0]["Runtime"].should.equal("python2.7") + result["Functions"][0]["Environment"].should.equal( + {"Variables": {"TEST_ENV_KEY": "test-env-val"}} + ) @mock_cloudformation @mock_ec2 def test_nat_gateway(): - ec2_conn = boto3.client('ec2', 'us-east-1') - vpc_id = ec2_conn.create_vpc(CidrBlock="10.0.0.0/16")['Vpc']['VpcId'] - subnet_id = ec2_conn.create_subnet( - CidrBlock='10.0.1.0/24', VpcId=vpc_id)['Subnet']['SubnetId'] - route_table_id = ec2_conn.create_route_table( - VpcId=vpc_id)['RouteTable']['RouteTableId'] + ec2_conn = boto3.client("ec2", "us-east-1") + vpc_id = ec2_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"]["VpcId"] + subnet_id = ec2_conn.create_subnet(CidrBlock="10.0.1.0/24", VpcId=vpc_id)["Subnet"][ + "SubnetId" + ] + route_table_id = ec2_conn.create_route_table(VpcId=vpc_id)["RouteTable"][ + "RouteTableId" + ] template = { "AWSTemplateFormatVersion": "2010-09-09", @@ -2005,97 +1818,83 @@ def test_nat_gateway(): "Type": "AWS::EC2::NatGateway", "Properties": { "AllocationId": {"Fn::GetAtt": ["EIP", "AllocationId"]}, - "SubnetId": subnet_id - } - }, - "EIP": { - "Type": "AWS::EC2::EIP", - "Properties": { - "Domain": "vpc" - } + "SubnetId": subnet_id, + }, }, + "EIP": {"Type": "AWS::EC2::EIP", "Properties": {"Domain": "vpc"}}, "Route": { "Type": "AWS::EC2::Route", "Properties": { "RouteTableId": route_table_id, "DestinationCidrBlock": "0.0.0.0/0", - "NatGatewayId": {"Ref": "NAT"} - } - }, - "internetgateway": { - "Type": "AWS::EC2::InternetGateway" + "NatGatewayId": {"Ref": "NAT"}, + }, }, + "internetgateway": {"Type": "AWS::EC2::InternetGateway"}, "vpcgatewayattachment": { "Type": "AWS::EC2::VPCGatewayAttachment", "Properties": { - "InternetGatewayId": { - "Ref": "internetgateway" - }, + "InternetGatewayId": {"Ref": "internetgateway"}, "VpcId": vpc_id, }, - } - } + }, + }, } - cf_conn = boto3.client('cloudformation', 'us-east-1') - cf_conn.create_stack( - StackName="test_stack", - TemplateBody=json.dumps(template), - ) + cf_conn = boto3.client("cloudformation", "us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=json.dumps(template)) result = ec2_conn.describe_nat_gateways() - result['NatGateways'].should.have.length_of(1) - result['NatGateways'][0]['VpcId'].should.equal(vpc_id) - result['NatGateways'][0]['SubnetId'].should.equal(subnet_id) - result['NatGateways'][0]['State'].should.equal('available') + result["NatGateways"].should.have.length_of(1) + result["NatGateways"][0]["VpcId"].should.equal(vpc_id) + result["NatGateways"][0]["SubnetId"].should.equal(subnet_id) + result["NatGateways"][0]["State"].should.equal("available") @mock_cloudformation() @mock_kms() def test_stack_kms(): kms_key_template = { - 'Resources': { - 'kmskey': { - 'Properties': { - 'Description': 'A kms key', - 'EnableKeyRotation': True, - 'Enabled': True, - 'KeyPolicy': 'a policy', + "Resources": { + "kmskey": { + "Properties": { + "Description": "A kms key", + "EnableKeyRotation": True, + "Enabled": True, + "KeyPolicy": "a policy", }, - 'Type': 'AWS::KMS::Key' + "Type": "AWS::KMS::Key", } } } kms_key_template_json = json.dumps(kms_key_template) - cf_conn = boto3.client('cloudformation', 'us-east-1') - cf_conn.create_stack( - StackName='test_stack', - TemplateBody=kms_key_template_json, - ) + cf_conn = boto3.client("cloudformation", "us-east-1") + cf_conn.create_stack(StackName="test_stack", TemplateBody=kms_key_template_json) - kms_conn = boto3.client('kms', 'us-east-1') - keys = kms_conn.list_keys()['Keys'] + kms_conn = boto3.client("kms", "us-east-1") + keys = kms_conn.list_keys()["Keys"] len(keys).should.equal(1) - result = kms_conn.describe_key(KeyId=keys[0]['KeyId']) + result = kms_conn.describe_key(KeyId=keys[0]["KeyId"]) - result['KeyMetadata']['Enabled'].should.equal(True) - result['KeyMetadata']['KeyUsage'].should.equal('ENCRYPT_DECRYPT') + result["KeyMetadata"]["Enabled"].should.equal(True) + result["KeyMetadata"]["KeyUsage"].should.equal("ENCRYPT_DECRYPT") @mock_cloudformation() @mock_ec2() def test_stack_spot_fleet(): - conn = boto3.client('ec2', 'us-east-1') + conn = boto3.client("ec2", "us-east-1") - vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")['Vpc'] + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] subnet = conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.0.0/16', AvailabilityZone='us-east-1a')['Subnet'] - subnet_id = subnet['SubnetId'] + VpcId=vpc["VpcId"], CidrBlock="10.0.0.0/16", AvailabilityZone="us-east-1a" + )["Subnet"] + subnet_id = subnet["SubnetId"] spot_fleet_template = { - 'Resources': { + "Resources": { "SpotFleet": { "Type": "AWS::EC2::SpotFleet", "Properties": { @@ -2107,7 +1906,7 @@ def test_stack_spot_fleet(): "LaunchSpecifications": [ { "EbsOptimized": "false", - "InstanceType": 't2.small', + "InstanceType": "t2.small", "ImageId": "ami-1234", "SubnetId": subnet_id, "WeightedCapacity": "2", @@ -2115,71 +1914,74 @@ def test_stack_spot_fleet(): }, { "EbsOptimized": "true", - "InstanceType": 't2.large', + "InstanceType": "t2.large", "ImageId": "ami-1234", "Monitoring": {"Enabled": "true"}, "SecurityGroups": [{"GroupId": "sg-123"}], "SubnetId": subnet_id, - "IamInstanceProfile": {"Arn": "arn:aws:iam::123456789012:role/fleet"}, + "IamInstanceProfile": { + "Arn": "arn:aws:iam::123456789012:role/fleet" + }, "WeightedCapacity": "4", "SpotPrice": "10.00", - } - ] + }, + ], } - } + }, } } } spot_fleet_template_json = json.dumps(spot_fleet_template) - cf_conn = boto3.client('cloudformation', 'us-east-1') + cf_conn = boto3.client("cloudformation", "us-east-1") stack_id = cf_conn.create_stack( - StackName='test_stack', - TemplateBody=spot_fleet_template_json, - )['StackId'] + StackName="test_stack", TemplateBody=spot_fleet_template_json + )["StackId"] stack_resources = cf_conn.list_stack_resources(StackName=stack_id) - stack_resources['StackResourceSummaries'].should.have.length_of(1) - spot_fleet_id = stack_resources[ - 'StackResourceSummaries'][0]['PhysicalResourceId'] + stack_resources["StackResourceSummaries"].should.have.length_of(1) + spot_fleet_id = stack_resources["StackResourceSummaries"][0]["PhysicalResourceId"] spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] len(spot_fleet_requests).should.equal(1) spot_fleet_request = spot_fleet_requests[0] - spot_fleet_request['SpotFleetRequestState'].should.equal("active") - spot_fleet_config = spot_fleet_request['SpotFleetRequestConfig'] + spot_fleet_request["SpotFleetRequestState"].should.equal("active") + spot_fleet_config = spot_fleet_request["SpotFleetRequestConfig"] - spot_fleet_config['SpotPrice'].should.equal('0.12') - spot_fleet_config['TargetCapacity'].should.equal(6) - spot_fleet_config['IamFleetRole'].should.equal( - 'arn:aws:iam::123456789012:role/fleet') - spot_fleet_config['AllocationStrategy'].should.equal('diversified') - spot_fleet_config['FulfilledCapacity'].should.equal(6.0) + spot_fleet_config["SpotPrice"].should.equal("0.12") + spot_fleet_config["TargetCapacity"].should.equal(6) + spot_fleet_config["IamFleetRole"].should.equal( + "arn:aws:iam::123456789012:role/fleet" + ) + spot_fleet_config["AllocationStrategy"].should.equal("diversified") + spot_fleet_config["FulfilledCapacity"].should.equal(6.0) - len(spot_fleet_config['LaunchSpecifications']).should.equal(2) - launch_spec = spot_fleet_config['LaunchSpecifications'][0] + len(spot_fleet_config["LaunchSpecifications"]).should.equal(2) + launch_spec = spot_fleet_config["LaunchSpecifications"][0] - launch_spec['EbsOptimized'].should.equal(False) - launch_spec['ImageId'].should.equal("ami-1234") - launch_spec['InstanceType'].should.equal("t2.small") - launch_spec['SubnetId'].should.equal(subnet_id) - launch_spec['SpotPrice'].should.equal("0.13") - launch_spec['WeightedCapacity'].should.equal(2.0) + launch_spec["EbsOptimized"].should.equal(False) + launch_spec["ImageId"].should.equal("ami-1234") + launch_spec["InstanceType"].should.equal("t2.small") + launch_spec["SubnetId"].should.equal(subnet_id) + launch_spec["SpotPrice"].should.equal("0.13") + launch_spec["WeightedCapacity"].should.equal(2.0) @mock_cloudformation() @mock_ec2() def test_stack_spot_fleet_should_figure_out_default_price(): - conn = boto3.client('ec2', 'us-east-1') + conn = boto3.client("ec2", "us-east-1") - vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")['Vpc'] + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] subnet = conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.0.0/16', AvailabilityZone='us-east-1a')['Subnet'] - subnet_id = subnet['SubnetId'] + VpcId=vpc["VpcId"], CidrBlock="10.0.0.0/16", AvailabilityZone="us-east-1a" + )["Subnet"] + subnet_id = subnet["SubnetId"] spot_fleet_template = { - 'Resources': { + "Resources": { "SpotFleet1": { "Type": "AWS::EC2::SpotFleet", "Properties": { @@ -2190,54 +1992,55 @@ def test_stack_spot_fleet_should_figure_out_default_price(): "LaunchSpecifications": [ { "EbsOptimized": "false", - "InstanceType": 't2.small', + "InstanceType": "t2.small", "ImageId": "ami-1234", "SubnetId": subnet_id, "WeightedCapacity": "2", }, { "EbsOptimized": "true", - "InstanceType": 't2.large', + "InstanceType": "t2.large", "ImageId": "ami-1234", "Monitoring": {"Enabled": "true"}, "SecurityGroups": [{"GroupId": "sg-123"}], "SubnetId": subnet_id, - "IamInstanceProfile": {"Arn": "arn:aws:iam::123456789012:role/fleet"}, + "IamInstanceProfile": { + "Arn": "arn:aws:iam::123456789012:role/fleet" + }, "WeightedCapacity": "4", - } - ] + }, + ], } - } + }, } } } spot_fleet_template_json = json.dumps(spot_fleet_template) - cf_conn = boto3.client('cloudformation', 'us-east-1') + cf_conn = boto3.client("cloudformation", "us-east-1") stack_id = cf_conn.create_stack( - StackName='test_stack', - TemplateBody=spot_fleet_template_json, - )['StackId'] + StackName="test_stack", TemplateBody=spot_fleet_template_json + )["StackId"] stack_resources = cf_conn.list_stack_resources(StackName=stack_id) - stack_resources['StackResourceSummaries'].should.have.length_of(1) - spot_fleet_id = stack_resources[ - 'StackResourceSummaries'][0]['PhysicalResourceId'] + stack_resources["StackResourceSummaries"].should.have.length_of(1) + spot_fleet_id = stack_resources["StackResourceSummaries"][0]["PhysicalResourceId"] spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] len(spot_fleet_requests).should.equal(1) spot_fleet_request = spot_fleet_requests[0] - spot_fleet_request['SpotFleetRequestState'].should.equal("active") - spot_fleet_config = spot_fleet_request['SpotFleetRequestConfig'] + spot_fleet_request["SpotFleetRequestState"].should.equal("active") + spot_fleet_config = spot_fleet_request["SpotFleetRequestConfig"] - assert 'SpotPrice' not in spot_fleet_config - len(spot_fleet_config['LaunchSpecifications']).should.equal(2) - launch_spec1 = spot_fleet_config['LaunchSpecifications'][0] - launch_spec2 = spot_fleet_config['LaunchSpecifications'][1] + assert "SpotPrice" not in spot_fleet_config + len(spot_fleet_config["LaunchSpecifications"]).should.equal(2) + launch_spec1 = spot_fleet_config["LaunchSpecifications"][0] + launch_spec2 = spot_fleet_config["LaunchSpecifications"][1] - assert 'SpotPrice' not in launch_spec1 - assert 'SpotPrice' not in launch_spec2 + assert "SpotPrice" not in launch_spec1 + assert "SpotPrice" not in launch_spec2 @mock_ec2 @@ -2262,19 +2065,15 @@ def test_stack_elbv2_resources_integration(): }, "Resources": { "alb": { - "Type": "AWS::ElasticLoadBalancingV2::LoadBalancer", - "Properties": { - "Name": "myelbv2", - "Scheme": "internet-facing", - "Subnets": [{ - "Ref": "mysubnet", - }], - "SecurityGroups": [{ - "Ref": "mysg", - }], - "Type": "application", - "IpAddressType": "ipv4", - } + "Type": "AWS::ElasticLoadBalancingV2::LoadBalancer", + "Properties": { + "Name": "myelbv2", + "Scheme": "internet-facing", + "Subnets": [{"Ref": "mysubnet"}], + "SecurityGroups": [{"Ref": "mysg"}], + "Type": "application", + "IpAddressType": "ipv4", + }, }, "mytargetgroup1": { "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", @@ -2286,23 +2085,14 @@ def test_stack_elbv2_resources_integration(): "HealthCheckTimeoutSeconds": 5, "HealthyThresholdCount": 30, "UnhealthyThresholdCount": 5, - "Matcher": { - "HttpCode": "200,201" - }, + "Matcher": {"HttpCode": "200,201"}, "Name": "mytargetgroup1", "Port": 80, "Protocol": "HTTP", "TargetType": "instance", - "Targets": [{ - "Id": { - "Ref": "ec2instance", - "Port": 80, - }, - }], - "VpcId": { - "Ref": "myvpc", - } - } + "Targets": [{"Id": {"Ref": "ec2instance", "Port": 80}}], + "VpcId": {"Ref": "myvpc"}, + }, }, "mytargetgroup2": { "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", @@ -2318,250 +2108,206 @@ def test_stack_elbv2_resources_integration(): "Port": 8080, "Protocol": "HTTP", "TargetType": "instance", - "Targets": [{ - "Id": { - "Ref": "ec2instance", - "Port": 8080, - }, - }], - "VpcId": { - "Ref": "myvpc", - } - } + "Targets": [{"Id": {"Ref": "ec2instance", "Port": 8080}}], + "VpcId": {"Ref": "myvpc"}, + }, }, "listener": { "Type": "AWS::ElasticLoadBalancingV2::Listener", "Properties": { - "DefaultActions": [{ - "Type": "forward", - "TargetGroupArn": {"Ref": "mytargetgroup1"} - }], + "DefaultActions": [ + {"Type": "forward", "TargetGroupArn": {"Ref": "mytargetgroup1"}} + ], "LoadBalancerArn": {"Ref": "alb"}, "Port": "80", - "Protocol": "HTTP" - } + "Protocol": "HTTP", + }, }, "myvpc": { "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - } + "Properties": {"CidrBlock": "10.0.0.0/16"}, }, "mysubnet": { "Type": "AWS::EC2::Subnet", - "Properties": { - "CidrBlock": "10.0.0.0/27", - "VpcId": {"Ref": "myvpc"}, - } + "Properties": {"CidrBlock": "10.0.0.0/27", "VpcId": {"Ref": "myvpc"}}, }, "mysg": { "Type": "AWS::EC2::SecurityGroup", "Properties": { "GroupName": "mysg", "GroupDescription": "test security group", - "VpcId": {"Ref": "myvpc"} - } + "VpcId": {"Ref": "myvpc"}, + }, }, "ec2instance": { "Type": "AWS::EC2::Instance", - "Properties": { - "ImageId": "ami-1234abcd", - "UserData": "some user data", - } + "Properties": {"ImageId": "ami-1234abcd", "UserData": "some user data"}, }, }, } alb_template_json = json.dumps(alb_template) cfn_conn = boto3.client("cloudformation", "us-west-1") - cfn_conn.create_stack( - StackName="elb_stack", - TemplateBody=alb_template_json, - ) + cfn_conn.create_stack(StackName="elb_stack", TemplateBody=alb_template_json) elbv2_conn = boto3.client("elbv2", "us-west-1") - load_balancers = elbv2_conn.describe_load_balancers()['LoadBalancers'] + load_balancers = elbv2_conn.describe_load_balancers()["LoadBalancers"] len(load_balancers).should.equal(1) - load_balancers[0]['LoadBalancerName'].should.equal('myelbv2') - load_balancers[0]['Scheme'].should.equal('internet-facing') - load_balancers[0]['Type'].should.equal('application') - load_balancers[0]['IpAddressType'].should.equal('ipv4') + load_balancers[0]["LoadBalancerName"].should.equal("myelbv2") + load_balancers[0]["Scheme"].should.equal("internet-facing") + load_balancers[0]["Type"].should.equal("application") + load_balancers[0]["IpAddressType"].should.equal("ipv4") target_groups = sorted( - elbv2_conn.describe_target_groups()['TargetGroups'], - key=lambda tg: tg['TargetGroupName']) # sort to do comparison with indexes + elbv2_conn.describe_target_groups()["TargetGroups"], + key=lambda tg: tg["TargetGroupName"], + ) # sort to do comparison with indexes len(target_groups).should.equal(2) - target_groups[0]['HealthCheckIntervalSeconds'].should.equal(30) - target_groups[0]['HealthCheckPath'].should.equal('/status') - target_groups[0]['HealthCheckPort'].should.equal('80') - target_groups[0]['HealthCheckProtocol'].should.equal('HTTP') - target_groups[0]['HealthCheckTimeoutSeconds'].should.equal(5) - target_groups[0]['HealthyThresholdCount'].should.equal(30) - target_groups[0]['UnhealthyThresholdCount'].should.equal(5) - target_groups[0]['Matcher'].should.equal({'HttpCode': '200,201'}) - target_groups[0]['TargetGroupName'].should.equal('mytargetgroup1') - target_groups[0]['Port'].should.equal(80) - target_groups[0]['Protocol'].should.equal('HTTP') - target_groups[0]['TargetType'].should.equal('instance') + target_groups[0]["HealthCheckIntervalSeconds"].should.equal(30) + target_groups[0]["HealthCheckPath"].should.equal("/status") + target_groups[0]["HealthCheckPort"].should.equal("80") + target_groups[0]["HealthCheckProtocol"].should.equal("HTTP") + target_groups[0]["HealthCheckTimeoutSeconds"].should.equal(5) + target_groups[0]["HealthyThresholdCount"].should.equal(30) + target_groups[0]["UnhealthyThresholdCount"].should.equal(5) + target_groups[0]["Matcher"].should.equal({"HttpCode": "200,201"}) + target_groups[0]["TargetGroupName"].should.equal("mytargetgroup1") + target_groups[0]["Port"].should.equal(80) + target_groups[0]["Protocol"].should.equal("HTTP") + target_groups[0]["TargetType"].should.equal("instance") - target_groups[1]['HealthCheckIntervalSeconds'].should.equal(30) - target_groups[1]['HealthCheckPath'].should.equal('/status') - target_groups[1]['HealthCheckPort'].should.equal('8080') - target_groups[1]['HealthCheckProtocol'].should.equal('HTTP') - target_groups[1]['HealthCheckTimeoutSeconds'].should.equal(5) - target_groups[1]['HealthyThresholdCount'].should.equal(30) - target_groups[1]['UnhealthyThresholdCount'].should.equal(5) - target_groups[1]['Matcher'].should.equal({'HttpCode': '200'}) - target_groups[1]['TargetGroupName'].should.equal('mytargetgroup2') - target_groups[1]['Port'].should.equal(8080) - target_groups[1]['Protocol'].should.equal('HTTP') - target_groups[1]['TargetType'].should.equal('instance') + target_groups[1]["HealthCheckIntervalSeconds"].should.equal(30) + target_groups[1]["HealthCheckPath"].should.equal("/status") + target_groups[1]["HealthCheckPort"].should.equal("8080") + target_groups[1]["HealthCheckProtocol"].should.equal("HTTP") + target_groups[1]["HealthCheckTimeoutSeconds"].should.equal(5) + target_groups[1]["HealthyThresholdCount"].should.equal(30) + target_groups[1]["UnhealthyThresholdCount"].should.equal(5) + target_groups[1]["Matcher"].should.equal({"HttpCode": "200"}) + target_groups[1]["TargetGroupName"].should.equal("mytargetgroup2") + target_groups[1]["Port"].should.equal(8080) + target_groups[1]["Protocol"].should.equal("HTTP") + target_groups[1]["TargetType"].should.equal("instance") - listeners = elbv2_conn.describe_listeners(LoadBalancerArn=load_balancers[0]['LoadBalancerArn'])['Listeners'] + listeners = elbv2_conn.describe_listeners( + LoadBalancerArn=load_balancers[0]["LoadBalancerArn"] + )["Listeners"] len(listeners).should.equal(1) - listeners[0]['LoadBalancerArn'].should.equal(load_balancers[0]['LoadBalancerArn']) - listeners[0]['Port'].should.equal(80) - listeners[0]['Protocol'].should.equal('HTTP') - listeners[0]['DefaultActions'].should.equal([{ - "Type": "forward", - "TargetGroupArn": target_groups[0]['TargetGroupArn'] - }]) + listeners[0]["LoadBalancerArn"].should.equal(load_balancers[0]["LoadBalancerArn"]) + listeners[0]["Port"].should.equal(80) + listeners[0]["Protocol"].should.equal("HTTP") + listeners[0]["DefaultActions"].should.equal( + [{"Type": "forward", "TargetGroupArn": target_groups[0]["TargetGroupArn"]}] + ) # test outputs - stacks = cfn_conn.describe_stacks(StackName='elb_stack')['Stacks'] + stacks = cfn_conn.describe_stacks(StackName="elb_stack")["Stacks"] len(stacks).should.equal(1) - dns = list(filter(lambda item: item['OutputKey'] == 'albdns', stacks[0]['Outputs']))[0] - name = list(filter(lambda item: item['OutputKey'] == 'albname', stacks[0]['Outputs']))[0] + dns = list( + filter(lambda item: item["OutputKey"] == "albdns", stacks[0]["Outputs"]) + )[0] + name = list( + filter(lambda item: item["OutputKey"] == "albname", stacks[0]["Outputs"]) + )[0] - dns['OutputValue'].should.equal(load_balancers[0]['DNSName']) - name['OutputValue'].should.equal(load_balancers[0]['LoadBalancerName']) + dns["OutputValue"].should.equal(load_balancers[0]["DNSName"]) + name["OutputValue"].should.equal(load_balancers[0]["LoadBalancerName"]) @mock_dynamodb2 @mock_cloudformation def test_stack_dynamodb_resources_integration(): dynamodb_template = { - "AWSTemplateFormatVersion": "2010-09-09", - "Resources": { - "myDynamoDBTable": { - "Type": "AWS::DynamoDB::Table", - "Properties": { - "AttributeDefinitions": [ - { - "AttributeName": "Album", - "AttributeType": "S" - }, - { - "AttributeName": "Artist", - "AttributeType": "S" - }, - { - "AttributeName": "Sales", - "AttributeType": "N" - }, - { - "AttributeName": "NumberOfSongs", - "AttributeType": "N" - } - ], - "KeySchema": [ - { - "AttributeName": "Album", - "KeyType": "HASH" - }, - { - "AttributeName": "Artist", - "KeyType": "RANGE" - } - ], - "ProvisionedThroughput": { - "ReadCapacityUnits": "5", - "WriteCapacityUnits": "5" - }, - "TableName": "myTableName", - "GlobalSecondaryIndexes": [{ - "IndexName": "myGSI", - "KeySchema": [ - { - "AttributeName": "Sales", - "KeyType": "HASH" + "AWSTemplateFormatVersion": "2010-09-09", + "Resources": { + "myDynamoDBTable": { + "Type": "AWS::DynamoDB::Table", + "Properties": { + "AttributeDefinitions": [ + {"AttributeName": "Album", "AttributeType": "S"}, + {"AttributeName": "Artist", "AttributeType": "S"}, + {"AttributeName": "Sales", "AttributeType": "N"}, + {"AttributeName": "NumberOfSongs", "AttributeType": "N"}, + ], + "KeySchema": [ + {"AttributeName": "Album", "KeyType": "HASH"}, + {"AttributeName": "Artist", "KeyType": "RANGE"}, + ], + "ProvisionedThroughput": { + "ReadCapacityUnits": "5", + "WriteCapacityUnits": "5", + }, + "TableName": "myTableName", + "GlobalSecondaryIndexes": [ + { + "IndexName": "myGSI", + "KeySchema": [ + {"AttributeName": "Sales", "KeyType": "HASH"}, + {"AttributeName": "Artist", "KeyType": "RANGE"}, + ], + "Projection": { + "NonKeyAttributes": ["Album", "NumberOfSongs"], + "ProjectionType": "INCLUDE", + }, + "ProvisionedThroughput": { + "ReadCapacityUnits": "5", + "WriteCapacityUnits": "5", + }, + }, + { + "IndexName": "myGSI2", + "KeySchema": [ + {"AttributeName": "NumberOfSongs", "KeyType": "HASH"}, + {"AttributeName": "Sales", "KeyType": "RANGE"}, + ], + "Projection": { + "NonKeyAttributes": ["Album", "Artist"], + "ProjectionType": "INCLUDE", + }, + "ProvisionedThroughput": { + "ReadCapacityUnits": "5", + "WriteCapacityUnits": "5", + }, + }, + ], + "LocalSecondaryIndexes": [ + { + "IndexName": "myLSI", + "KeySchema": [ + {"AttributeName": "Album", "KeyType": "HASH"}, + {"AttributeName": "Sales", "KeyType": "RANGE"}, + ], + "Projection": { + "NonKeyAttributes": ["Artist", "NumberOfSongs"], + "ProjectionType": "INCLUDE", + }, + } + ], }, - { - "AttributeName": "Artist", - "KeyType": "RANGE" - } - ], - "Projection": { - "NonKeyAttributes": ["Album","NumberOfSongs"], - "ProjectionType": "INCLUDE" - }, - "ProvisionedThroughput": { - "ReadCapacityUnits": "5", - "WriteCapacityUnits": "5" - } - }, - { - "IndexName": "myGSI2", - "KeySchema": [ - { - "AttributeName": "NumberOfSongs", - "KeyType": "HASH" - }, - { - "AttributeName": "Sales", - "KeyType": "RANGE" - } - ], - "Projection": { - "NonKeyAttributes": ["Album","Artist"], - "ProjectionType": "INCLUDE" - }, - "ProvisionedThroughput": { - "ReadCapacityUnits": "5", - "WriteCapacityUnits": "5" - } - }], - "LocalSecondaryIndexes":[{ - "IndexName": "myLSI", - "KeySchema": [ - { - "AttributeName": "Album", - "KeyType": "HASH" - }, - { - "AttributeName": "Sales", - "KeyType": "RANGE" - } - ], - "Projection": { - "NonKeyAttributes": ["Artist","NumberOfSongs"], - "ProjectionType": "INCLUDE" - } - }] - } - } - } + } + }, } dynamodb_template_json = json.dumps(dynamodb_template) - cfn_conn = boto3.client('cloudformation', 'us-east-1') + cfn_conn = boto3.client("cloudformation", "us-east-1") cfn_conn.create_stack( - StackName='dynamodb_stack', - TemplateBody=dynamodb_template_json, + StackName="dynamodb_stack", TemplateBody=dynamodb_template_json ) - dynamodb_conn = boto3.resource('dynamodb', region_name='us-east-1') - table = dynamodb_conn.Table('myTableName') - table.name.should.equal('myTableName') + dynamodb_conn = boto3.resource("dynamodb", region_name="us-east-1") + table = dynamodb_conn.Table("myTableName") + table.name.should.equal("myTableName") - table.put_item(Item={"Album": "myAlbum", "Artist": "myArtist", "Sales": 10, "NumberOfSongs": 5}) + table.put_item( + Item={"Album": "myAlbum", "Artist": "myArtist", "Sales": 10, "NumberOfSongs": 5} + ) response = table.get_item(Key={"Album": "myAlbum", "Artist": "myArtist"}) - response['Item']['Album'].should.equal('myAlbum') - response['Item']['Sales'].should.equal(Decimal('10')) - response['Item']['NumberOfSongs'].should.equal(Decimal('5')) - response['Item']['Album'].should.equal('myAlbum') + response["Item"]["Album"].should.equal("myAlbum") + response["Item"]["Sales"].should.equal(Decimal("10")) + response["Item"]["NumberOfSongs"].should.equal(Decimal("5")) + response["Item"]["Album"].should.equal("myAlbum") diff --git a/tests/test_cloudformation/test_import_value.py b/tests/test_cloudformation/test_import_value.py index 04c2b5801..41a8a8a30 100644 --- a/tests/test_cloudformation/test_import_value.py +++ b/tests/test_cloudformation/test_import_value.py @@ -11,9 +11,9 @@ from botocore.exceptions import ClientError # Package modules from moto import mock_cloudformation -AWS_REGION = 'us-west-1' +AWS_REGION = "us-west-1" -SG_STACK_NAME = 'simple-sg-stack' +SG_STACK_NAME = "simple-sg-stack" SG_TEMPLATE = """ AWSTemplateFormatVersion: 2010-09-09 Description: Simple test CF template for moto_cloudformation @@ -42,7 +42,7 @@ Outputs: """ -EC2_STACK_NAME = 'simple-ec2-stack' +EC2_STACK_NAME = "simple-ec2-stack" EC2_TEMPLATE = """ --- # The latest template format version is "2010-09-09" and as of 2018-04-09 @@ -65,23 +65,25 @@ class TestSimpleInstance(unittest.TestCase): def test_simple_instance(self): """Test that we can create a simple CloudFormation stack that imports values from an existing CloudFormation stack""" with mock_cloudformation(): - client = boto3.client('cloudformation', region_name=AWS_REGION) + client = boto3.client("cloudformation", region_name=AWS_REGION) client.create_stack(StackName=SG_STACK_NAME, TemplateBody=SG_TEMPLATE) - response = client.create_stack(StackName=EC2_STACK_NAME, TemplateBody=EC2_TEMPLATE) - self.assertIn('StackId', response) - response = client.describe_stacks(StackName=response['StackId']) - self.assertIn('Stacks', response) - stack_info = response['Stacks'] + response = client.create_stack( + StackName=EC2_STACK_NAME, TemplateBody=EC2_TEMPLATE + ) + self.assertIn("StackId", response) + response = client.describe_stacks(StackName=response["StackId"]) + self.assertIn("Stacks", response) + stack_info = response["Stacks"] self.assertEqual(1, len(stack_info)) - self.assertIn('StackName', stack_info[0]) - self.assertEqual(EC2_STACK_NAME, stack_info[0]['StackName']) + self.assertIn("StackName", stack_info[0]) + self.assertEqual(EC2_STACK_NAME, stack_info[0]["StackName"]) def test_simple_instance_missing_export(self): """Test that we get an exception if a CloudFormation stack tries to imports a non-existent export value""" with mock_cloudformation(): - client = boto3.client('cloudformation', region_name=AWS_REGION) + client = boto3.client("cloudformation", region_name=AWS_REGION) with self.assertRaises(ClientError) as e: client.create_stack(StackName=EC2_STACK_NAME, TemplateBody=EC2_TEMPLATE) - self.assertIn('Error', e.exception.response) - self.assertIn('Code', e.exception.response['Error']) - self.assertEqual('ExportNotFound', e.exception.response['Error']['Code']) + self.assertIn("Error", e.exception.response) + self.assertIn("Code", e.exception.response["Error"]) + self.assertEqual("ExportNotFound", e.exception.response["Error"]["Code"]) diff --git a/tests/test_cloudformation/test_server.py b/tests/test_cloudformation/test_server.py index de3ab77b5..f3f037c42 100644 --- a/tests/test_cloudformation/test_server.py +++ b/tests/test_cloudformation/test_server.py @@ -7,27 +7,30 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_cloudformation_server_get(): backend = server.create_backend_app("cloudformation") - stack_name = 'test stack' + stack_name = "test stack" test_client = backend.test_client() - template_body = { - "Resources": {}, - } - create_stack_resp = test_client.action_data("CreateStack", StackName=stack_name, - TemplateBody=json.dumps(template_body)) + template_body = {"Resources": {}} + create_stack_resp = test_client.action_data( + "CreateStack", StackName=stack_name, TemplateBody=json.dumps(template_body) + ) create_stack_resp.should.match( - r".*.*.*.*.*", re.DOTALL) + r".*.*.*.*.*", + re.DOTALL, + ) stack_id_from_create_response = re.search( - "(.*)", create_stack_resp).groups()[0] + "(.*)", create_stack_resp + ).groups()[0] list_stacks_resp = test_client.action_data("ListStacks") stack_id_from_list_response = re.search( - "(.*)", list_stacks_resp).groups()[0] + "(.*)", list_stacks_resp + ).groups()[0] stack_id_from_create_response.should.equal(stack_id_from_list_response) diff --git a/tests/test_cloudformation/test_stack_parsing.py b/tests/test_cloudformation/test_stack_parsing.py index 25242e352..85df76592 100644 --- a/tests/test_cloudformation/test_stack_parsing.py +++ b/tests/test_cloudformation/test_stack_parsing.py @@ -7,91 +7,57 @@ import sure # noqa from moto.cloudformation.exceptions import ValidationError from moto.cloudformation.models import FakeStack -from moto.cloudformation.parsing import resource_class_from_type, parse_condition, Export +from moto.cloudformation.parsing import ( + resource_class_from_type, + parse_condition, + Export, +) from moto.sqs.models import Queue from moto.s3.models import FakeBucket from moto.cloudformation.utils import yaml_tag_constructor from boto.cloudformation.stack import Output - dummy_template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "Create a multi-az, load balanced, Auto Scaled sample web site. The Auto Scaling trigger is based on the CPU utilization of the web servers. The AMI is chosen based on the region in which the stack is run. This example creates a web service running across all availability zones in a region. The instances are load balanced with a simple health check. The web site is available on port 80, however, the instances can be configured to listen on any port (8888 by default). **WARNING** This template creates one or more Amazon EC2 instances. You will be billed for the AWS resources used if you create a stack from this template.", - "Resources": { "Queue": { "Type": "AWS::SQS::Queue", - "Properties": { - "QueueName": "my-queue", - "VisibilityTimeout": 60, - } - }, - "S3Bucket": { - "Type": "AWS::S3::Bucket", - "DeletionPolicy": "Retain" + "Properties": {"QueueName": "my-queue", "VisibilityTimeout": 60}, }, + "S3Bucket": {"Type": "AWS::S3::Bucket", "DeletionPolicy": "Retain"}, }, } name_type_template = { "AWSTemplateFormatVersion": "2010-09-09", - "Description": "Create a multi-az, load balanced, Auto Scaled sample web site. The Auto Scaling trigger is based on the CPU utilization of the web servers. The AMI is chosen based on the region in which the stack is run. This example creates a web service running across all availability zones in a region. The instances are load balanced with a simple health check. The web site is available on port 80, however, the instances can be configured to listen on any port (8888 by default). **WARNING** This template creates one or more Amazon EC2 instances. You will be billed for the AWS resources used if you create a stack from this template.", - "Resources": { - "Queue": { - "Type": "AWS::SQS::Queue", - "Properties": { - "VisibilityTimeout": 60, - } - }, + "Queue": {"Type": "AWS::SQS::Queue", "Properties": {"VisibilityTimeout": 60}} }, } output_dict = { "Outputs": { - "Output1": { - "Value": {"Ref": "Queue"}, - "Description": "This is a description." - } + "Output1": {"Value": {"Ref": "Queue"}, "Description": "This is a description."} } } bad_output = { - "Outputs": { - "Output1": { - "Value": {"Fn::GetAtt": ["Queue", "InvalidAttribute"]} - } - } + "Outputs": {"Output1": {"Value": {"Fn::GetAtt": ["Queue", "InvalidAttribute"]}}} } get_attribute_output = { - "Outputs": { - "Output1": { - "Value": {"Fn::GetAtt": ["Queue", "QueueName"]} - } - } + "Outputs": {"Output1": {"Value": {"Fn::GetAtt": ["Queue", "QueueName"]}}} } -get_availability_zones_output = { - "Outputs": { - "Output1": { - "Value": {"Fn::GetAZs": ""} - } - } -} +get_availability_zones_output = {"Outputs": {"Output1": {"Value": {"Fn::GetAZs": ""}}}} parameters = { "Parameters": { - "Param": { - "Type": "String", - }, - "NoEchoParam": { - "Type": "String", - "NoEcho": True - } + "Param": {"Type": "String"}, + "NoEchoParam": {"Type": "String", "NoEcho": True}, } } @@ -101,11 +67,11 @@ split_select_template = { "Queue": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::Select": [ "1", {"Fn::Split": [ "-", "123-myqueue" ] } ] }, + "QueueName": {"Fn::Select": ["1", {"Fn::Split": ["-", "123-myqueue"]}]}, "VisibilityTimeout": 60, - } + }, } - } + }, } sub_template = { @@ -114,18 +80,18 @@ sub_template = { "Queue1": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::Sub": '${AWS::StackName}-queue-${!Literal}'}, + "QueueName": {"Fn::Sub": "${AWS::StackName}-queue-${!Literal}"}, "VisibilityTimeout": 60, - } + }, }, "Queue2": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::Sub": '${Queue1.QueueName}'}, + "QueueName": {"Fn::Sub": "${Queue1.QueueName}"}, "VisibilityTimeout": 60, - } + }, }, - } + }, } export_value_template = { @@ -134,17 +100,12 @@ export_value_template = { "Queue": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::Sub": '${AWS::StackName}-queue'}, + "QueueName": {"Fn::Sub": "${AWS::StackName}-queue"}, "VisibilityTimeout": 60, - } + }, } }, - "Outputs": { - "Output1": { - "Value": "value", - "Export": {"Name": 'queue-us-west-1'} - } - } + "Outputs": {"Output1": {"Value": "value", "Export": {"Name": "queue-us-west-1"}}}, } import_value_template = { @@ -153,33 +114,30 @@ import_value_template = { "Queue": { "Type": "AWS::SQS::Queue", "Properties": { - "QueueName": {"Fn::ImportValue": 'queue-us-west-1'}, + "QueueName": {"Fn::ImportValue": "queue-us-west-1"}, "VisibilityTimeout": 60, - } + }, } - } + }, } -outputs_template = dict(list(dummy_template.items()) + - list(output_dict.items())) -bad_outputs_template = dict( - list(dummy_template.items()) + list(bad_output.items())) +outputs_template = dict(list(dummy_template.items()) + list(output_dict.items())) +bad_outputs_template = dict(list(dummy_template.items()) + list(bad_output.items())) get_attribute_outputs_template = dict( - list(dummy_template.items()) + list(get_attribute_output.items())) + list(dummy_template.items()) + list(get_attribute_output.items()) +) get_availability_zones_template = dict( - list(dummy_template.items()) + list(get_availability_zones_output.items())) + list(dummy_template.items()) + list(get_availability_zones_output.items()) +) -parameters_template = dict( - list(dummy_template.items()) + list(parameters.items())) +parameters_template = dict(list(dummy_template.items()) + list(parameters.items())) dummy_template_json = json.dumps(dummy_template) name_type_template_json = json.dumps(name_type_template) output_type_template_json = json.dumps(outputs_template) bad_output_template_json = json.dumps(bad_outputs_template) -get_attribute_outputs_template_json = json.dumps( - get_attribute_outputs_template) -get_availability_zones_template_json = json.dumps( - get_availability_zones_template) +get_attribute_outputs_template_json = json.dumps(get_attribute_outputs_template) +get_availability_zones_template_json = json.dumps(get_availability_zones_template) parameters_template_json = json.dumps(parameters_template) split_select_template_json = json.dumps(split_select_template) sub_template_json = json.dumps(sub_template) @@ -193,15 +151,16 @@ def test_parse_stack_resources(): name="test_stack", template=dummy_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.resource_map.should.have.length_of(2) - queue = stack.resource_map['Queue'] + queue = stack.resource_map["Queue"] queue.should.be.a(Queue) queue.name.should.equal("my-queue") - bucket = stack.resource_map['S3Bucket'] + bucket = stack.resource_map["S3Bucket"] bucket.should.be.a(FakeBucket) bucket.physical_resource_id.should.equal(bucket.name) @@ -209,8 +168,7 @@ def test_parse_stack_resources(): @patch("moto.cloudformation.parsing.logger") def test_missing_resource_logs(logger): resource_class_from_type("foobar") - logger.warning.assert_called_with( - 'No Moto CloudFormation support for %s', 'foobar') + logger.warning.assert_called_with("No Moto CloudFormation support for %s", "foobar") def test_parse_stack_with_name_type_resource(): @@ -219,10 +177,11 @@ def test_parse_stack_with_name_type_resource(): name="test_stack", template=name_type_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.resource_map.should.have.length_of(1) - list(stack.resource_map.keys())[0].should.equal('Queue') + list(stack.resource_map.keys())[0].should.equal("Queue") queue = list(stack.resource_map.values())[0] queue.should.be.a(Queue) @@ -233,10 +192,11 @@ def test_parse_stack_with_yaml_template(): name="test_stack", template=yaml.dump(name_type_template), parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.resource_map.should.have.length_of(1) - list(stack.resource_map.keys())[0].should.equal('Queue') + list(stack.resource_map.keys())[0].should.equal("Queue") queue = list(stack.resource_map.values())[0] queue.should.be.a(Queue) @@ -247,10 +207,11 @@ def test_parse_stack_with_outputs(): name="test_stack", template=output_type_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.output_map.should.have.length_of(1) - list(stack.output_map.keys())[0].should.equal('Output1') + list(stack.output_map.keys())[0].should.equal("Output1") output = list(stack.output_map.values())[0] output.should.be.a(Output) output.description.should.equal("This is a description.") @@ -262,14 +223,16 @@ def test_parse_stack_with_get_attribute_outputs(): name="test_stack", template=get_attribute_outputs_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.output_map.should.have.length_of(1) - list(stack.output_map.keys())[0].should.equal('Output1') + list(stack.output_map.keys())[0].should.equal("Output1") output = list(stack.output_map.values())[0] output.should.be.a(Output) output.value.should.equal("my-queue") + def test_parse_stack_with_get_attribute_kms(): from .fixtures.kms_key import template @@ -279,31 +242,35 @@ def test_parse_stack_with_get_attribute_kms(): name="test_stack", template=template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.output_map.should.have.length_of(1) - list(stack.output_map.keys())[0].should.equal('KeyArn') + list(stack.output_map.keys())[0].should.equal("KeyArn") output = list(stack.output_map.values())[0] output.should.be.a(Output) + def test_parse_stack_with_get_availability_zones(): stack = FakeStack( stack_id="test_id", name="test_stack", template=get_availability_zones_template_json, parameters={}, - region_name='us-east-1') + region_name="us-east-1", + ) stack.output_map.should.have.length_of(1) - list(stack.output_map.keys())[0].should.equal('Output1') + list(stack.output_map.keys())[0].should.equal("Output1") output = list(stack.output_map.values())[0] output.should.be.a(Output) - output.value.should.equal([ "us-east-1a", "us-east-1b", "us-east-1c", "us-east-1d" ]) + output.value.should.equal(["us-east-1a", "us-east-1b", "us-east-1c", "us-east-1d"]) def test_parse_stack_with_bad_get_attribute_outputs(): FakeStack.when.called_with( - "test_id", "test_stack", bad_output_template_json, {}, "us-west-1").should.throw(ValidationError) + "test_id", "test_stack", bad_output_template_json, {}, "us-west-1" + ).should.throw(ValidationError) def test_parse_stack_with_parameters(): @@ -312,7 +279,8 @@ def test_parse_stack_with_parameters(): name="test_stack", template=parameters_template_json, parameters={"Param": "visible value", "NoEchoParam": "hidden value"}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.resource_map.no_echo_parameter_keys.should.have("NoEchoParam") stack.resource_map.no_echo_parameter_keys.should_not.have("Param") @@ -334,21 +302,13 @@ def test_parse_equals_condition(): def test_parse_not_condition(): parse_condition( - condition={ - "Fn::Not": [{ - "Fn::Equals": [{"Ref": "EnvType"}, "prod"] - }] - }, + condition={"Fn::Not": [{"Fn::Equals": [{"Ref": "EnvType"}, "prod"]}]}, resources_map={"EnvType": "prod"}, condition_map={}, ).should.equal(False) parse_condition( - condition={ - "Fn::Not": [{ - "Fn::Equals": [{"Ref": "EnvType"}, "prod"] - }] - }, + condition={"Fn::Not": [{"Fn::Equals": [{"Ref": "EnvType"}, "prod"]}]}, resources_map={"EnvType": "staging"}, condition_map={}, ).should.equal(True) @@ -416,10 +376,11 @@ def test_parse_split_and_select(): name="test_stack", template=split_select_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) stack.resource_map.should.have.length_of(1) - queue = stack.resource_map['Queue'] + queue = stack.resource_map["Queue"] queue.name.should.equal("myqueue") @@ -429,10 +390,11 @@ def test_sub(): name="test_stack", template=sub_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) - queue1 = stack.resource_map['Queue1'] - queue2 = stack.resource_map['Queue2'] + queue1 = stack.resource_map["Queue1"] + queue2 = stack.resource_map["Queue2"] queue2.name.should.equal(queue1.name) @@ -442,20 +404,21 @@ def test_import(): name="test_stack", template=export_value_template_json, parameters={}, - region_name='us-west-1') + region_name="us-west-1", + ) import_stack = FakeStack( stack_id="test_id", name="test_stack", template=import_value_template_json, parameters={}, - region_name='us-west-1', - cross_stack_resources={export_stack.exports[0].value: export_stack.exports[0]}) + region_name="us-west-1", + cross_stack_resources={export_stack.exports[0].value: export_stack.exports[0]}, + ) - queue = import_stack.resource_map['Queue'] + queue = import_stack.resource_map["Queue"] queue.name.should.equal("value") - def test_short_form_func_in_yaml_teamplate(): template = """--- KeyB64: !Base64 valueToEncode @@ -476,24 +439,24 @@ def test_short_form_func_in_yaml_teamplate(): KeySplit: !Split [A, B] KeySub: !Sub A """ - yaml.add_multi_constructor('', yaml_tag_constructor, Loader=yaml.Loader) + yaml.add_multi_constructor("", yaml_tag_constructor, Loader=yaml.Loader) template_dict = yaml.load(template, Loader=yaml.Loader) key_and_expects = [ - ['KeyRef', {'Ref': 'foo'}], - ['KeyB64', {'Fn::Base64': 'valueToEncode'}], - ['KeyAnd', {'Fn::And': ['A', 'B']}], - ['KeyEquals', {'Fn::Equals': ['A', 'B']}], - ['KeyIf', {'Fn::If': ['A', 'B', 'C']}], - ['KeyNot', {'Fn::Not': ['A']}], - ['KeyOr', {'Fn::Or': ['A', 'B']}], - ['KeyFindInMap', {'Fn::FindInMap': ['A', 'B', 'C']}], - ['KeyGetAtt', {'Fn::GetAtt': ['A', 'B']}], - ['KeyGetAZs', {'Fn::GetAZs': 'A'}], - ['KeyImportValue', {'Fn::ImportValue': 'A'}], - ['KeyJoin', {'Fn::Join': [ ":", [ 'A', 'B', 'C' ] ]}], - ['KeySelect', {'Fn::Select': ['A', 'B']}], - ['KeySplit', {'Fn::Split': ['A', 'B']}], - ['KeySub', {'Fn::Sub': 'A'}], + ["KeyRef", {"Ref": "foo"}], + ["KeyB64", {"Fn::Base64": "valueToEncode"}], + ["KeyAnd", {"Fn::And": ["A", "B"]}], + ["KeyEquals", {"Fn::Equals": ["A", "B"]}], + ["KeyIf", {"Fn::If": ["A", "B", "C"]}], + ["KeyNot", {"Fn::Not": ["A"]}], + ["KeyOr", {"Fn::Or": ["A", "B"]}], + ["KeyFindInMap", {"Fn::FindInMap": ["A", "B", "C"]}], + ["KeyGetAtt", {"Fn::GetAtt": ["A", "B"]}], + ["KeyGetAZs", {"Fn::GetAZs": "A"}], + ["KeyImportValue", {"Fn::ImportValue": "A"}], + ["KeyJoin", {"Fn::Join": [":", ["A", "B", "C"]]}], + ["KeySelect", {"Fn::Select": ["A", "B"]}], + ["KeySplit", {"Fn::Split": ["A", "B"]}], + ["KeySub", {"Fn::Sub": "A"}], ] for k, v in key_and_expects: template_dict.should.have.key(k).which.should.be.equal(v) diff --git a/tests/test_cloudformation/test_validate.py b/tests/test_cloudformation/test_validate.py index e2c3af05d..4dd4d7e08 100644 --- a/tests/test_cloudformation/test_validate.py +++ b/tests/test_cloudformation/test_validate.py @@ -9,7 +9,11 @@ import botocore from moto.cloudformation.exceptions import ValidationError from moto.cloudformation.models import FakeStack -from moto.cloudformation.parsing import resource_class_from_type, parse_condition, Export +from moto.cloudformation.parsing import ( + resource_class_from_type, + parse_condition, + Export, +) from moto.sqs.models import Queue from moto.s3.models import FakeBucket from moto.cloudformation.utils import yaml_tag_constructor @@ -27,25 +31,16 @@ json_template = { "KeyName": "dummy", "InstanceType": "t2.micro", "Tags": [ - { - "Key": "Description", - "Value": "Test tag" - }, - { - "Key": "Name", - "Value": "Name tag for tests" - } - ] - } + {"Key": "Description", "Value": "Test tag"}, + {"Key": "Name", "Value": "Name tag for tests"}, + ], + }, } - } + }, } # One resource is required -json_bad_template = { - "AWSTemplateFormatVersion": "2010-09-09", - "Description": "Stack 1" -} +json_bad_template = {"AWSTemplateFormatVersion": "2010-09-09", "Description": "Stack 1"} dummy_template_json = json.dumps(json_template) dummy_bad_template_json = json.dumps(json_bad_template) @@ -53,25 +48,25 @@ dummy_bad_template_json = json.dumps(json_bad_template) @mock_cloudformation def test_boto3_json_validate_successful(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - response = cf_conn.validate_template( - TemplateBody=dummy_template_json, - ) - assert response['Description'] == "Stack 1" - assert response['Parameters'] == [] - assert response['ResponseMetadata']['HTTPStatusCode'] == 200 + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + response = cf_conn.validate_template(TemplateBody=dummy_template_json) + assert response["Description"] == "Stack 1" + assert response["Parameters"] == [] + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + @mock_cloudformation def test_boto3_json_invalid_missing_resource(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") try: - cf_conn.validate_template( - TemplateBody=dummy_bad_template_json, - ) + cf_conn.validate_template(TemplateBody=dummy_bad_template_json) assert False except botocore.exceptions.ClientError as e: - assert str(e) == 'An error occurred (ValidationError) when calling the ValidateTemplate operation: Stack' \ - ' with id Missing top level item Resources to file module does not exist' + assert ( + str(e) + == "An error occurred (ValidationError) when calling the ValidateTemplate operation: Stack" + " with id Missing top level item Resources to file module does not exist" + ) assert True @@ -91,25 +86,26 @@ yaml_bad_template = """ Description: Simple CloudFormation Test Template """ + @mock_cloudformation def test_boto3_yaml_validate_successful(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') - response = cf_conn.validate_template( - TemplateBody=yaml_template, - ) - assert response['Description'] == "Simple CloudFormation Test Template" - assert response['Parameters'] == [] - assert response['ResponseMetadata']['HTTPStatusCode'] == 200 + cf_conn = boto3.client("cloudformation", region_name="us-east-1") + response = cf_conn.validate_template(TemplateBody=yaml_template) + assert response["Description"] == "Simple CloudFormation Test Template" + assert response["Parameters"] == [] + assert response["ResponseMetadata"]["HTTPStatusCode"] == 200 + @mock_cloudformation def test_boto3_yaml_invalid_missing_resource(): - cf_conn = boto3.client('cloudformation', region_name='us-east-1') + cf_conn = boto3.client("cloudformation", region_name="us-east-1") try: - cf_conn.validate_template( - TemplateBody=yaml_bad_template, - ) + cf_conn.validate_template(TemplateBody=yaml_bad_template) assert False except botocore.exceptions.ClientError as e: - assert str(e) == 'An error occurred (ValidationError) when calling the ValidateTemplate operation: Stack' \ - ' with id Missing top level item Resources to file module does not exist' + assert ( + str(e) + == "An error occurred (ValidationError) when calling the ValidateTemplate operation: Stack" + " with id Missing top level item Resources to file module does not exist" + ) assert True diff --git a/tests/test_cloudwatch/test_cloudwatch.py b/tests/test_cloudwatch/test_cloudwatch.py index a0f3871c0..5d60dd7ee 100644 --- a/tests/test_cloudwatch/test_cloudwatch.py +++ b/tests/test_cloudwatch/test_cloudwatch.py @@ -9,22 +9,22 @@ from moto import mock_cloudwatch_deprecated def alarm_fixture(name="tester", action=None): - action = action or ['arn:alarm'] + action = action or ["arn:alarm"] return MetricAlarm( name=name, namespace="{0}_namespace".format(name), metric="{0}_metric".format(name), - comparison='>=', + comparison=">=", threshold=2.0, period=60, evaluation_periods=5, - statistic='Average', - description='A test', - dimensions={'InstanceId': ['i-0123456,i-0123457']}, + statistic="Average", + description="A test", + dimensions={"InstanceId": ["i-0123456,i-0123457"]}, alarm_actions=action, - ok_actions=['arn:ok'], - insufficient_data_actions=['arn:insufficient'], - unit='Seconds', + ok_actions=["arn:ok"], + insufficient_data_actions=["arn:insufficient"], + unit="Seconds", ) @@ -38,21 +38,20 @@ def test_create_alarm(): alarms = conn.describe_alarms() alarms.should.have.length_of(1) alarm = alarms[0] - alarm.name.should.equal('tester') - alarm.namespace.should.equal('tester_namespace') - alarm.metric.should.equal('tester_metric') - alarm.comparison.should.equal('>=') + alarm.name.should.equal("tester") + alarm.namespace.should.equal("tester_namespace") + alarm.metric.should.equal("tester_metric") + alarm.comparison.should.equal(">=") alarm.threshold.should.equal(2.0) alarm.period.should.equal(60) alarm.evaluation_periods.should.equal(5) - alarm.statistic.should.equal('Average') - alarm.description.should.equal('A test') - dict(alarm.dimensions).should.equal( - {'InstanceId': ['i-0123456,i-0123457']}) - list(alarm.alarm_actions).should.equal(['arn:alarm']) - list(alarm.ok_actions).should.equal(['arn:ok']) - list(alarm.insufficient_data_actions).should.equal(['arn:insufficient']) - alarm.unit.should.equal('Seconds') + alarm.statistic.should.equal("Average") + alarm.description.should.equal("A test") + dict(alarm.dimensions).should.equal({"InstanceId": ["i-0123456,i-0123457"]}) + list(alarm.alarm_actions).should.equal(["arn:alarm"]) + list(alarm.ok_actions).should.equal(["arn:ok"]) + list(alarm.insufficient_data_actions).should.equal(["arn:insufficient"]) + alarm.unit.should.equal("Seconds") @mock_cloudwatch_deprecated @@ -79,19 +78,18 @@ def test_put_metric_data(): conn = boto.connect_cloudwatch() conn.put_metric_data( - namespace='tester', - name='metric', + namespace="tester", + name="metric", value=1.5, - dimensions={'InstanceId': ['i-0123456,i-0123457']}, + dimensions={"InstanceId": ["i-0123456,i-0123457"]}, ) metrics = conn.list_metrics() metrics.should.have.length_of(1) metric = metrics[0] - metric.namespace.should.equal('tester') - metric.name.should.equal('metric') - dict(metric.dimensions).should.equal( - {'InstanceId': ['i-0123456,i-0123457']}) + metric.namespace.should.equal("tester") + metric.name.should.equal("metric") + dict(metric.dimensions).should.equal({"InstanceId": ["i-0123456,i-0123457"]}) @mock_cloudwatch_deprecated @@ -110,8 +108,7 @@ def test_describe_alarms(): alarms.should.have.length_of(4) alarms = conn.describe_alarms(alarm_name_prefix="nfoo") alarms.should.have.length_of(2) - alarms = conn.describe_alarms( - alarm_names=["nfoobar", "nbarfoo", "nbazfoo"]) + alarms = conn.describe_alarms(alarm_names=["nfoobar", "nbarfoo", "nbazfoo"]) alarms.should.have.length_of(3) alarms = conn.describe_alarms(action_prefix="afoo") alarms.should.have.length_of(2) diff --git a/tests/test_cloudwatch/test_cloudwatch_boto3.py b/tests/test_cloudwatch/test_cloudwatch_boto3.py index 40b5eee08..8df6cf6d4 100755 --- a/tests/test_cloudwatch/test_cloudwatch_boto3.py +++ b/tests/test_cloudwatch/test_cloudwatch_boto3.py @@ -4,221 +4,202 @@ import boto3 from botocore.exceptions import ClientError from datetime import datetime, timedelta import pytz -import sure # noqa +import sure # noqa from moto import mock_cloudwatch @mock_cloudwatch def test_put_list_dashboard(): - client = boto3.client('cloudwatch', region_name='eu-central-1') + client = boto3.client("cloudwatch", region_name="eu-central-1") widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' - client.put_dashboard(DashboardName='test1', DashboardBody=widget) + client.put_dashboard(DashboardName="test1", DashboardBody=widget) resp = client.list_dashboards() - len(resp['DashboardEntries']).should.equal(1) + len(resp["DashboardEntries"]).should.equal(1) @mock_cloudwatch def test_put_list_prefix_nomatch_dashboard(): - client = boto3.client('cloudwatch', region_name='eu-central-1') + client = boto3.client("cloudwatch", region_name="eu-central-1") widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' - client.put_dashboard(DashboardName='test1', DashboardBody=widget) - resp = client.list_dashboards(DashboardNamePrefix='nomatch') + client.put_dashboard(DashboardName="test1", DashboardBody=widget) + resp = client.list_dashboards(DashboardNamePrefix="nomatch") - len(resp['DashboardEntries']).should.equal(0) + len(resp["DashboardEntries"]).should.equal(0) @mock_cloudwatch def test_delete_dashboard(): - client = boto3.client('cloudwatch', region_name='eu-central-1') + client = boto3.client("cloudwatch", region_name="eu-central-1") widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' - client.put_dashboard(DashboardName='test1', DashboardBody=widget) - client.put_dashboard(DashboardName='test2', DashboardBody=widget) - client.put_dashboard(DashboardName='test3', DashboardBody=widget) - client.delete_dashboards(DashboardNames=['test2', 'test1']) + client.put_dashboard(DashboardName="test1", DashboardBody=widget) + client.put_dashboard(DashboardName="test2", DashboardBody=widget) + client.put_dashboard(DashboardName="test3", DashboardBody=widget) + client.delete_dashboards(DashboardNames=["test2", "test1"]) - resp = client.list_dashboards(DashboardNamePrefix='test3') - len(resp['DashboardEntries']).should.equal(1) + resp = client.list_dashboards(DashboardNamePrefix="test3") + len(resp["DashboardEntries"]).should.equal(1) @mock_cloudwatch def test_delete_dashboard_fail(): - client = boto3.client('cloudwatch', region_name='eu-central-1') + client = boto3.client("cloudwatch", region_name="eu-central-1") widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' - client.put_dashboard(DashboardName='test1', DashboardBody=widget) - client.put_dashboard(DashboardName='test2', DashboardBody=widget) - client.put_dashboard(DashboardName='test3', DashboardBody=widget) + client.put_dashboard(DashboardName="test1", DashboardBody=widget) + client.put_dashboard(DashboardName="test2", DashboardBody=widget) + client.put_dashboard(DashboardName="test3", DashboardBody=widget) # Doesnt delete anything if all dashboards to be deleted do not exist try: - client.delete_dashboards(DashboardNames=['test2', 'test1', 'test_no_match']) + client.delete_dashboards(DashboardNames=["test2", "test1", "test_no_match"]) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFound') + err.response["Error"]["Code"].should.equal("ResourceNotFound") else: - raise RuntimeError('Should of raised error') + raise RuntimeError("Should of raised error") resp = client.list_dashboards() - len(resp['DashboardEntries']).should.equal(3) + len(resp["DashboardEntries"]).should.equal(3) @mock_cloudwatch def test_get_dashboard(): - client = boto3.client('cloudwatch', region_name='eu-central-1') + client = boto3.client("cloudwatch", region_name="eu-central-1") widget = '{"widgets": [{"type": "text", "x": 0, "y": 7, "width": 3, "height": 3, "properties": {"markdown": "Hello world"}}]}' - client.put_dashboard(DashboardName='test1', DashboardBody=widget) + client.put_dashboard(DashboardName="test1", DashboardBody=widget) - resp = client.get_dashboard(DashboardName='test1') - resp.should.contain('DashboardArn') - resp.should.contain('DashboardBody') - resp['DashboardName'].should.equal('test1') + resp = client.get_dashboard(DashboardName="test1") + resp.should.contain("DashboardArn") + resp.should.contain("DashboardBody") + resp["DashboardName"].should.equal("test1") @mock_cloudwatch def test_get_dashboard_fail(): - client = boto3.client('cloudwatch', region_name='eu-central-1') + client = boto3.client("cloudwatch", region_name="eu-central-1") try: - client.get_dashboard(DashboardName='test1') + client.get_dashboard(DashboardName="test1") except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFound') + err.response["Error"]["Code"].should.equal("ResourceNotFound") else: - raise RuntimeError('Should of raised error') + raise RuntimeError("Should of raised error") @mock_cloudwatch def test_alarm_state(): - client = boto3.client('cloudwatch', region_name='eu-central-1') + client = boto3.client("cloudwatch", region_name="eu-central-1") client.put_metric_alarm( - AlarmName='testalarm1', - MetricName='cpu', - Namespace='blah', + AlarmName="testalarm1", + MetricName="cpu", + Namespace="blah", Period=10, EvaluationPeriods=5, - Statistic='Average', + Statistic="Average", Threshold=2, - ComparisonOperator='GreaterThanThreshold', + ComparisonOperator="GreaterThanThreshold", ) client.put_metric_alarm( - AlarmName='testalarm2', - MetricName='cpu', - Namespace='blah', + AlarmName="testalarm2", + MetricName="cpu", + Namespace="blah", Period=10, EvaluationPeriods=5, - Statistic='Average', + Statistic="Average", Threshold=2, - ComparisonOperator='GreaterThanThreshold', + ComparisonOperator="GreaterThanThreshold", ) # This is tested implicitly as if it doesnt work the rest will die client.set_alarm_state( - AlarmName='testalarm1', - StateValue='ALARM', - StateReason='testreason', - StateReasonData='{"some": "json_data"}' + AlarmName="testalarm1", + StateValue="ALARM", + StateReason="testreason", + StateReasonData='{"some": "json_data"}', ) - resp = client.describe_alarms( - StateValue='ALARM' - ) - len(resp['MetricAlarms']).should.equal(1) - resp['MetricAlarms'][0]['AlarmName'].should.equal('testalarm1') - resp['MetricAlarms'][0]['StateValue'].should.equal('ALARM') + resp = client.describe_alarms(StateValue="ALARM") + len(resp["MetricAlarms"]).should.equal(1) + resp["MetricAlarms"][0]["AlarmName"].should.equal("testalarm1") + resp["MetricAlarms"][0]["StateValue"].should.equal("ALARM") - resp = client.describe_alarms( - StateValue='OK' - ) - len(resp['MetricAlarms']).should.equal(1) - resp['MetricAlarms'][0]['AlarmName'].should.equal('testalarm2') - resp['MetricAlarms'][0]['StateValue'].should.equal('OK') + resp = client.describe_alarms(StateValue="OK") + len(resp["MetricAlarms"]).should.equal(1) + resp["MetricAlarms"][0]["AlarmName"].should.equal("testalarm2") + resp["MetricAlarms"][0]["StateValue"].should.equal("OK") # Just for sanity resp = client.describe_alarms() - len(resp['MetricAlarms']).should.equal(2) + len(resp["MetricAlarms"]).should.equal(2) @mock_cloudwatch def test_put_metric_data_no_dimensions(): - conn = boto3.client('cloudwatch', region_name='us-east-1') + conn = boto3.client("cloudwatch", region_name="us-east-1") conn.put_metric_data( - Namespace='tester', - MetricData=[ - dict( - MetricName='metric', - Value=1.5, - ) - ] + Namespace="tester", MetricData=[dict(MetricName="metric", Value=1.5)] ) - metrics = conn.list_metrics()['Metrics'] + metrics = conn.list_metrics()["Metrics"] metrics.should.have.length_of(1) metric = metrics[0] - metric['Namespace'].should.equal('tester') - metric['MetricName'].should.equal('metric') - + metric["Namespace"].should.equal("tester") + metric["MetricName"].should.equal("metric") @mock_cloudwatch def test_put_metric_data_with_statistics(): - conn = boto3.client('cloudwatch', region_name='us-east-1') + conn = boto3.client("cloudwatch", region_name="us-east-1") conn.put_metric_data( - Namespace='tester', + Namespace="tester", MetricData=[ dict( - MetricName='statmetric', + MetricName="statmetric", Timestamp=datetime(2015, 1, 1), # no Value to test https://github.com/spulec/moto/issues/1615 StatisticValues=dict( - SampleCount=123.0, - Sum=123.0, - Minimum=123.0, - Maximum=123.0 + SampleCount=123.0, Sum=123.0, Minimum=123.0, Maximum=123.0 ), - Unit='Milliseconds', - StorageResolution=123 + Unit="Milliseconds", + StorageResolution=123, ) - ] + ], ) - metrics = conn.list_metrics()['Metrics'] + metrics = conn.list_metrics()["Metrics"] metrics.should.have.length_of(1) metric = metrics[0] - metric['Namespace'].should.equal('tester') - metric['MetricName'].should.equal('statmetric') + metric["Namespace"].should.equal("tester") + metric["MetricName"].should.equal("statmetric") # TODO: test statistics - https://github.com/spulec/moto/issues/1615 + @mock_cloudwatch def test_get_metric_statistics(): - conn = boto3.client('cloudwatch', region_name='us-east-1') + conn = boto3.client("cloudwatch", region_name="us-east-1") utc_now = datetime.now(tz=pytz.utc) conn.put_metric_data( - Namespace='tester', - MetricData=[ - dict( - MetricName='metric', - Value=1.5, - Timestamp=utc_now - ) - ] + Namespace="tester", + MetricData=[dict(MetricName="metric", Value=1.5, Timestamp=utc_now)], ) stats = conn.get_metric_statistics( - Namespace='tester', - MetricName='metric', + Namespace="tester", + MetricName="metric", StartTime=utc_now - timedelta(seconds=60), EndTime=utc_now + timedelta(seconds=60), Period=60, - Statistics=['SampleCount', 'Sum'] + Statistics=["SampleCount", "Sum"], ) - stats['Datapoints'].should.have.length_of(1) - datapoint = stats['Datapoints'][0] - datapoint['SampleCount'].should.equal(1.0) - datapoint['Sum'].should.equal(1.5) + stats["Datapoints"].should.have.length_of(1) + datapoint = stats["Datapoints"][0] + datapoint["SampleCount"].should.equal(1.0) + datapoint["Sum"].should.equal(1.5) diff --git a/tests/test_cognitoidentity/test_cognitoidentity.py b/tests/test_cognitoidentity/test_cognitoidentity.py index 67679e896..c338891b6 100644 --- a/tests/test_cognitoidentity/test_cognitoidentity.py +++ b/tests/test_cognitoidentity/test_cognitoidentity.py @@ -10,132 +10,136 @@ from moto.cognitoidentity.utils import get_random_identity_id @mock_cognitoidentity def test_create_identity_pool(): - conn = boto3.client('cognito-identity', 'us-west-2') + conn = boto3.client("cognito-identity", "us-west-2") - result = conn.create_identity_pool(IdentityPoolName='TestPool', + result = conn.create_identity_pool( + IdentityPoolName="TestPool", AllowUnauthenticatedIdentities=False, - SupportedLoginProviders={'graph.facebook.com': '123456789012345'}, - DeveloperProviderName='devname', - OpenIdConnectProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db'], + SupportedLoginProviders={"graph.facebook.com": "123456789012345"}, + DeveloperProviderName="devname", + OpenIdConnectProviderARNs=["arn:aws:rds:eu-west-2:123456789012:db:mysql-db"], CognitoIdentityProviders=[ { - 'ProviderName': 'testprovider', - 'ClientId': 'CLIENT12345', - 'ServerSideTokenCheck': True - }, + "ProviderName": "testprovider", + "ClientId": "CLIENT12345", + "ServerSideTokenCheck": True, + } ], - SamlProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db']) - assert result['IdentityPoolId'] != '' + SamlProviderARNs=["arn:aws:rds:eu-west-2:123456789012:db:mysql-db"], + ) + assert result["IdentityPoolId"] != "" @mock_cognitoidentity def test_describe_identity_pool(): - conn = boto3.client('cognito-identity', 'us-west-2') + conn = boto3.client("cognito-identity", "us-west-2") - res = conn.create_identity_pool(IdentityPoolName='TestPool', + res = conn.create_identity_pool( + IdentityPoolName="TestPool", AllowUnauthenticatedIdentities=False, - SupportedLoginProviders={'graph.facebook.com': '123456789012345'}, - DeveloperProviderName='devname', - OpenIdConnectProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db'], + SupportedLoginProviders={"graph.facebook.com": "123456789012345"}, + DeveloperProviderName="devname", + OpenIdConnectProviderARNs=["arn:aws:rds:eu-west-2:123456789012:db:mysql-db"], CognitoIdentityProviders=[ { - 'ProviderName': 'testprovider', - 'ClientId': 'CLIENT12345', - 'ServerSideTokenCheck': True - }, + "ProviderName": "testprovider", + "ClientId": "CLIENT12345", + "ServerSideTokenCheck": True, + } ], - SamlProviderARNs=['arn:aws:rds:eu-west-2:123456789012:db:mysql-db']) + SamlProviderARNs=["arn:aws:rds:eu-west-2:123456789012:db:mysql-db"], + ) - result = conn.describe_identity_pool(IdentityPoolId=res['IdentityPoolId']) + result = conn.describe_identity_pool(IdentityPoolId=res["IdentityPoolId"]) - assert result['IdentityPoolId'] == res['IdentityPoolId'] - assert result['AllowUnauthenticatedIdentities'] == res['AllowUnauthenticatedIdentities'] - assert result['SupportedLoginProviders'] == res['SupportedLoginProviders'] - assert result['DeveloperProviderName'] == res['DeveloperProviderName'] - assert result['OpenIdConnectProviderARNs'] == res['OpenIdConnectProviderARNs'] - assert result['CognitoIdentityProviders'] == res['CognitoIdentityProviders'] - assert result['SamlProviderARNs'] == res['SamlProviderARNs'] + assert result["IdentityPoolId"] == res["IdentityPoolId"] + assert ( + result["AllowUnauthenticatedIdentities"] + == res["AllowUnauthenticatedIdentities"] + ) + assert result["SupportedLoginProviders"] == res["SupportedLoginProviders"] + assert result["DeveloperProviderName"] == res["DeveloperProviderName"] + assert result["OpenIdConnectProviderARNs"] == res["OpenIdConnectProviderARNs"] + assert result["CognitoIdentityProviders"] == res["CognitoIdentityProviders"] + assert result["SamlProviderARNs"] == res["SamlProviderARNs"] @mock_cognitoidentity def test_describe_identity_pool_with_invalid_id_raises_error(): - conn = boto3.client('cognito-identity', 'us-west-2') + conn = boto3.client("cognito-identity", "us-west-2") with assert_raises(ClientError) as cm: - conn.describe_identity_pool(IdentityPoolId='us-west-2_non-existent') + conn.describe_identity_pool(IdentityPoolId="us-west-2_non-existent") - cm.exception.operation_name.should.equal('DescribeIdentityPool') - cm.exception.response['Error']['Code'].should.equal('ResourceNotFoundException') - cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + cm.exception.operation_name.should.equal("DescribeIdentityPool") + cm.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) # testing a helper function def test_get_random_identity_id(): - assert len(get_random_identity_id('us-west-2')) > 0 - assert len(get_random_identity_id('us-west-2').split(':')[1]) == 19 + assert len(get_random_identity_id("us-west-2")) > 0 + assert len(get_random_identity_id("us-west-2").split(":")[1]) == 19 @mock_cognitoidentity def test_get_id(): # These two do NOT work in server mode. They just don't return the data from the model. - conn = boto3.client('cognito-identity', 'us-west-2') - result = conn.get_id(AccountId='someaccount', - IdentityPoolId='us-west-2:12345', - Logins={ - 'someurl': '12345' - }) + conn = boto3.client("cognito-identity", "us-west-2") + result = conn.get_id( + AccountId="someaccount", + IdentityPoolId="us-west-2:12345", + Logins={"someurl": "12345"}, + ) print(result) - assert result.get('IdentityId', "").startswith('us-west-2') or result.get('ResponseMetadata').get( - 'HTTPStatusCode') == 200 + assert ( + result.get("IdentityId", "").startswith("us-west-2") + or result.get("ResponseMetadata").get("HTTPStatusCode") == 200 + ) @mock_cognitoidentity def test_get_credentials_for_identity(): # These two do NOT work in server mode. They just don't return the data from the model. - conn = boto3.client('cognito-identity', 'us-west-2') - result = conn.get_credentials_for_identity(IdentityId='12345') + conn = boto3.client("cognito-identity", "us-west-2") + result = conn.get_credentials_for_identity(IdentityId="12345") - assert result.get('Expiration', 0) > 0 or result.get('ResponseMetadata').get('HTTPStatusCode') == 200 - assert result.get('IdentityId') == '12345' or result.get('ResponseMetadata').get('HTTPStatusCode') == 200 + assert ( + result.get("Expiration", 0) > 0 + or result.get("ResponseMetadata").get("HTTPStatusCode") == 200 + ) + assert ( + result.get("IdentityId") == "12345" + or result.get("ResponseMetadata").get("HTTPStatusCode") == 200 + ) @mock_cognitoidentity def test_get_open_id_token_for_developer_identity(): - conn = boto3.client('cognito-identity', 'us-west-2') + conn = boto3.client("cognito-identity", "us-west-2") result = conn.get_open_id_token_for_developer_identity( - IdentityPoolId='us-west-2:12345', - IdentityId='12345', - Logins={ - 'someurl': '12345' - }, - TokenDuration=123 + IdentityPoolId="us-west-2:12345", + IdentityId="12345", + Logins={"someurl": "12345"}, + TokenDuration=123, ) - assert len(result['Token']) > 0 - assert result['IdentityId'] == '12345' + assert len(result["Token"]) > 0 + assert result["IdentityId"] == "12345" @mock_cognitoidentity def test_get_open_id_token_for_developer_identity_when_no_explicit_identity_id(): - conn = boto3.client('cognito-identity', 'us-west-2') + conn = boto3.client("cognito-identity", "us-west-2") result = conn.get_open_id_token_for_developer_identity( - IdentityPoolId='us-west-2:12345', - Logins={ - 'someurl': '12345' - }, - TokenDuration=123 + IdentityPoolId="us-west-2:12345", Logins={"someurl": "12345"}, TokenDuration=123 ) - assert len(result['Token']) > 0 - assert len(result['IdentityId']) > 0 + assert len(result["Token"]) > 0 + assert len(result["IdentityId"]) > 0 @mock_cognitoidentity def test_get_open_id_token(): - conn = boto3.client('cognito-identity', 'us-west-2') - result = conn.get_open_id_token( - IdentityId='12345', - Logins={ - 'someurl': '12345' - } - ) - assert len(result['Token']) > 0 - assert result['IdentityId'] == '12345' + conn = boto3.client("cognito-identity", "us-west-2") + result = conn.get_open_id_token(IdentityId="12345", Logins={"someurl": "12345"}) + assert len(result["Token"]) > 0 + assert result["IdentityId"] == "12345" diff --git a/tests/test_cognitoidentity/test_server.py b/tests/test_cognitoidentity/test_server.py index b63d42bc0..903dae290 100644 --- a/tests/test_cognitoidentity/test_server.py +++ b/tests/test_cognitoidentity/test_server.py @@ -6,9 +6,9 @@ import sure # noqa import moto.server as server from moto import mock_cognitoidentity -''' +""" Test the different server responses -''' +""" @mock_cognitoidentity @@ -17,14 +17,16 @@ def test_create_identity_pool(): backend = server.create_backend_app("cognito-identity") test_client = backend.test_client() - res = test_client.post('/', - data={"IdentityPoolName": "test", "AllowUnauthenticatedIdentities": True}, - headers={ - "X-Amz-Target": "com.amazonaws.cognito.identity.model.AWSCognitoIdentityService.CreateIdentityPool"}, - ) + res = test_client.post( + "/", + data={"IdentityPoolName": "test", "AllowUnauthenticatedIdentities": True}, + headers={ + "X-Amz-Target": "com.amazonaws.cognito.identity.model.AWSCognitoIdentityService.CreateIdentityPool" + }, + ) json_data = json.loads(res.data.decode("utf-8")) - assert json_data['IdentityPoolName'] == "test" + assert json_data["IdentityPoolName"] == "test" @mock_cognitoidentity @@ -32,14 +34,20 @@ def test_get_id(): backend = server.create_backend_app("cognito-identity") test_client = backend.test_client() - res = test_client.post('/', - data=json.dumps({'AccountId': 'someaccount', - 'IdentityPoolId': 'us-west-2:12345', - 'Logins': {'someurl': '12345'}}), - headers={ - "X-Amz-Target": "com.amazonaws.cognito.identity.model.AWSCognitoIdentityService.GetId"}, - ) + res = test_client.post( + "/", + data=json.dumps( + { + "AccountId": "someaccount", + "IdentityPoolId": "us-west-2:12345", + "Logins": {"someurl": "12345"}, + } + ), + headers={ + "X-Amz-Target": "com.amazonaws.cognito.identity.model.AWSCognitoIdentityService.GetId" + }, + ) print(res.data) json_data = json.loads(res.data.decode("utf-8")) - assert ':' in json_data['IdentityId'] + assert ":" in json_data["IdentityId"] diff --git a/tests/test_cognitoidp/test_cognitoidp.py b/tests/test_cognitoidp/test_cognitoidp.py index 5f9d4a153..82e866ff6 100644 --- a/tests/test_cognitoidp/test_cognitoidp.py +++ b/tests/test_cognitoidp/test_cognitoidp.py @@ -6,6 +6,7 @@ import random import uuid import boto3 + # noinspection PyUnresolvedReferences import sure # noqa from botocore.exceptions import ClientError @@ -21,15 +22,10 @@ def test_create_user_pool(): name = str(uuid.uuid4()) value = str(uuid.uuid4()) - result = conn.create_user_pool( - PoolName=name, - LambdaConfig={ - "PreSignUp": value - } - ) + result = conn.create_user_pool(PoolName=name, LambdaConfig={"PreSignUp": value}) result["UserPool"]["Id"].should_not.be.none - result["UserPool"]["Id"].should.match(r'[\w-]+_[0-9a-zA-Z]+') + result["UserPool"]["Id"].should.match(r"[\w-]+_[0-9a-zA-Z]+") result["UserPool"]["Name"].should.equal(name) result["UserPool"]["LambdaConfig"]["PreSignUp"].should.equal(value) @@ -102,10 +98,7 @@ def test_describe_user_pool(): name = str(uuid.uuid4()) value = str(uuid.uuid4()) user_pool_details = conn.create_user_pool( - PoolName=name, - LambdaConfig={ - "PreSignUp": value - } + PoolName=name, LambdaConfig={"PreSignUp": value} ) result = conn.describe_user_pool(UserPoolId=user_pool_details["UserPool"]["Id"]) @@ -139,7 +132,7 @@ def test_create_user_pool_domain_custom_domain_config(): domain = str(uuid.uuid4()) custom_domain_config = { - "CertificateArn": "arn:aws:acm:us-east-1:123456789012:certificate/123456789012", + "CertificateArn": "arn:aws:acm:us-east-1:123456789012:certificate/123456789012" } user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] result = conn.create_user_pool_domain( @@ -184,7 +177,7 @@ def test_update_user_pool_domain(): domain = str(uuid.uuid4()) custom_domain_config = { - "CertificateArn": "arn:aws:acm:us-east-1:123456789012:certificate/123456789012", + "CertificateArn": "arn:aws:acm:us-east-1:123456789012:certificate/123456789012" } user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] conn.create_user_pool_domain(UserPoolId=user_pool_id, Domain=domain) @@ -203,9 +196,7 @@ def test_create_user_pool_client(): value = str(uuid.uuid4()) user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] result = conn.create_user_pool_client( - UserPoolId=user_pool_id, - ClientName=client_name, - CallbackURLs=[value], + UserPoolId=user_pool_id, ClientName=client_name, CallbackURLs=[value] ) result["UserPoolClient"]["UserPoolId"].should.equal(user_pool_id) @@ -236,11 +227,11 @@ def test_list_user_pool_clients_returns_max_items(): client_count = 10 for i in range(client_count): client_name = str(uuid.uuid4()) - conn.create_user_pool_client(UserPoolId=user_pool_id, - ClientName=client_name) + conn.create_user_pool_client(UserPoolId=user_pool_id, ClientName=client_name) max_results = 5 - result = conn.list_user_pool_clients(UserPoolId=user_pool_id, - MaxResults=max_results) + result = conn.list_user_pool_clients( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["UserPoolClients"].should.have.length_of(max_results) result.should.have.key("NextToken") @@ -254,18 +245,18 @@ def test_list_user_pool_clients_returns_next_tokens(): client_count = 10 for i in range(client_count): client_name = str(uuid.uuid4()) - conn.create_user_pool_client(UserPoolId=user_pool_id, - ClientName=client_name) + conn.create_user_pool_client(UserPoolId=user_pool_id, ClientName=client_name) max_results = 5 - result = conn.list_user_pool_clients(UserPoolId=user_pool_id, - MaxResults=max_results) + result = conn.list_user_pool_clients( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["UserPoolClients"].should.have.length_of(max_results) result.should.have.key("NextToken") next_token = result["NextToken"] - result_2 = conn.list_user_pool_clients(UserPoolId=user_pool_id, - MaxResults=max_results, - NextToken=next_token) + result_2 = conn.list_user_pool_clients( + UserPoolId=user_pool_id, MaxResults=max_results, NextToken=next_token + ) result_2["UserPoolClients"].should.have.length_of(max_results) result_2.shouldnt.have.key("NextToken") @@ -279,11 +270,11 @@ def test_list_user_pool_clients_when_max_items_more_than_total_items(): client_count = 10 for i in range(client_count): client_name = str(uuid.uuid4()) - conn.create_user_pool_client(UserPoolId=user_pool_id, - ClientName=client_name) + conn.create_user_pool_client(UserPoolId=user_pool_id, ClientName=client_name) max_results = client_count + 5 - result = conn.list_user_pool_clients(UserPoolId=user_pool_id, - MaxResults=max_results) + result = conn.list_user_pool_clients( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["UserPoolClients"].should.have.length_of(client_count) result.shouldnt.have.key("NextToken") @@ -296,14 +287,11 @@ def test_describe_user_pool_client(): value = str(uuid.uuid4()) user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] client_details = conn.create_user_pool_client( - UserPoolId=user_pool_id, - ClientName=client_name, - CallbackURLs=[value], + UserPoolId=user_pool_id, ClientName=client_name, CallbackURLs=[value] ) result = conn.describe_user_pool_client( - UserPoolId=user_pool_id, - ClientId=client_details["UserPoolClient"]["ClientId"], + UserPoolId=user_pool_id, ClientId=client_details["UserPoolClient"]["ClientId"] ) result["UserPoolClient"]["ClientName"].should.equal(client_name) @@ -321,9 +309,7 @@ def test_update_user_pool_client(): new_value = str(uuid.uuid4()) user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] client_details = conn.create_user_pool_client( - UserPoolId=user_pool_id, - ClientName=old_client_name, - CallbackURLs=[old_value], + UserPoolId=user_pool_id, ClientName=old_client_name, CallbackURLs=[old_value] ) result = conn.update_user_pool_client( @@ -344,13 +330,11 @@ def test_delete_user_pool_client(): user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] client_details = conn.create_user_pool_client( - UserPoolId=user_pool_id, - ClientName=str(uuid.uuid4()), + UserPoolId=user_pool_id, ClientName=str(uuid.uuid4()) ) conn.delete_user_pool_client( - UserPoolId=user_pool_id, - ClientId=client_details["UserPoolClient"]["ClientId"], + UserPoolId=user_pool_id, ClientId=client_details["UserPoolClient"]["ClientId"] ) caught = False @@ -377,9 +361,7 @@ def test_create_identity_provider(): UserPoolId=user_pool_id, ProviderName=provider_name, ProviderType=provider_type, - ProviderDetails={ - "thing": value - }, + ProviderDetails={"thing": value}, ) result["IdentityProvider"]["UserPoolId"].should.equal(user_pool_id) @@ -402,10 +384,7 @@ def test_list_identity_providers(): ProviderDetails={}, ) - result = conn.list_identity_providers( - UserPoolId=user_pool_id, - MaxResults=10, - ) + result = conn.list_identity_providers(UserPoolId=user_pool_id, MaxResults=10) result["Providers"].should.have.length_of(1) result["Providers"][0]["ProviderName"].should.equal(provider_name) @@ -430,8 +409,9 @@ def test_list_identity_providers_returns_max_items(): ) max_results = 5 - result = conn.list_identity_providers(UserPoolId=user_pool_id, - MaxResults=max_results) + result = conn.list_identity_providers( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["Providers"].should.have.length_of(max_results) result.should.have.key("NextToken") @@ -454,14 +434,16 @@ def test_list_identity_providers_returns_next_tokens(): ) max_results = 5 - result = conn.list_identity_providers(UserPoolId=user_pool_id, MaxResults=max_results) + result = conn.list_identity_providers( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["Providers"].should.have.length_of(max_results) result.should.have.key("NextToken") next_token = result["NextToken"] - result_2 = conn.list_identity_providers(UserPoolId=user_pool_id, - MaxResults=max_results, - NextToken=next_token) + result_2 = conn.list_identity_providers( + UserPoolId=user_pool_id, MaxResults=max_results, NextToken=next_token + ) result_2["Providers"].should.have.length_of(max_results) result_2.shouldnt.have.key("NextToken") @@ -484,7 +466,9 @@ def test_list_identity_providers_when_max_items_more_than_total_items(): ) max_results = identity_provider_count + 5 - result = conn.list_identity_providers(UserPoolId=user_pool_id, MaxResults=max_results) + result = conn.list_identity_providers( + UserPoolId=user_pool_id, MaxResults=max_results + ) result["Providers"].should.have.length_of(identity_provider_count) result.shouldnt.have.key("NextToken") @@ -501,14 +485,11 @@ def test_describe_identity_providers(): UserPoolId=user_pool_id, ProviderName=provider_name, ProviderType=provider_type, - ProviderDetails={ - "thing": value - }, + ProviderDetails={"thing": value}, ) result = conn.describe_identity_provider( - UserPoolId=user_pool_id, - ProviderName=provider_name, + UserPoolId=user_pool_id, ProviderName=provider_name ) result["IdentityProvider"]["UserPoolId"].should.equal(user_pool_id) @@ -530,17 +511,13 @@ def test_update_identity_provider(): UserPoolId=user_pool_id, ProviderName=provider_name, ProviderType=provider_type, - ProviderDetails={ - "thing": value - }, + ProviderDetails={"thing": value}, ) result = conn.update_identity_provider( UserPoolId=user_pool_id, ProviderName=provider_name, - ProviderDetails={ - "thing": new_value - }, + ProviderDetails={"thing": new_value}, ) result["IdentityProvider"]["UserPoolId"].should.equal(user_pool_id) @@ -557,16 +534,12 @@ def test_update_identity_provider_no_user_pool(): with assert_raises(conn.exceptions.ResourceNotFoundException) as cm: conn.update_identity_provider( - UserPoolId="foo", - ProviderName="bar", - ProviderDetails={ - "thing": new_value - }, + UserPoolId="foo", ProviderName="bar", ProviderDetails={"thing": new_value} ) - cm.exception.operation_name.should.equal('UpdateIdentityProvider') - cm.exception.response['Error']['Code'].should.equal('ResourceNotFoundException') - cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + cm.exception.operation_name.should.equal("UpdateIdentityProvider") + cm.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) @mock_cognitoidp @@ -583,14 +556,12 @@ def test_update_identity_provider_no_identity_provider(): conn.update_identity_provider( UserPoolId=user_pool_id, ProviderName="foo", - ProviderDetails={ - "thing": new_value - }, + ProviderDetails={"thing": new_value}, ) - cm.exception.operation_name.should.equal('UpdateIdentityProvider') - cm.exception.response['Error']['Code'].should.equal('ResourceNotFoundException') - cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + cm.exception.operation_name.should.equal("UpdateIdentityProvider") + cm.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) @mock_cognitoidp @@ -605,9 +576,7 @@ def test_delete_identity_providers(): UserPoolId=user_pool_id, ProviderName=provider_name, ProviderType=provider_type, - ProviderDetails={ - "thing": value - }, + ProviderDetails={"thing": value}, ) conn.delete_identity_provider(UserPoolId=user_pool_id, ProviderName=provider_name) @@ -615,8 +584,7 @@ def test_delete_identity_providers(): caught = False try: conn.describe_identity_provider( - UserPoolId=user_pool_id, - ProviderName=provider_name, + UserPoolId=user_pool_id, ProviderName=provider_name ) except conn.exceptions.ResourceNotFoundException: caught = True @@ -662,9 +630,9 @@ def test_create_group_with_duplicate_name_raises_error(): with assert_raises(ClientError) as cm: conn.create_group(GroupName=group_name, UserPoolId=user_pool_id) - cm.exception.operation_name.should.equal('CreateGroup') - cm.exception.response['Error']['Code'].should.equal('GroupExistsException') - cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + cm.exception.operation_name.should.equal("CreateGroup") + cm.exception.response["Error"]["Code"].should.equal("GroupExistsException") + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) @mock_cognitoidp @@ -710,7 +678,7 @@ def test_delete_group(): with assert_raises(ClientError) as cm: conn.get_group(GroupName=group_name, UserPoolId=user_pool_id) - cm.exception.response['Error']['Code'].should.equal('ResourceNotFoundException') + cm.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") @mock_cognitoidp @@ -724,7 +692,9 @@ def test_admin_add_user_to_group(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - result = conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + result = conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) list(result.keys()).should.equal(["ResponseMetadata"]) # No response expected @@ -739,8 +709,12 @@ def test_admin_add_user_to_group_again_is_noop(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) @mock_cognitoidp @@ -754,7 +728,9 @@ def test_list_users_in_group(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) result = conn.list_users_in_group(UserPoolId=user_pool_id, GroupName=group_name) @@ -775,8 +751,12 @@ def test_list_users_in_group_ignores_deleted_user(): username2 = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username2) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username2, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username2, GroupName=group_name + ) conn.admin_delete_user(UserPoolId=user_pool_id, Username=username) result = conn.list_users_in_group(UserPoolId=user_pool_id, GroupName=group_name) @@ -796,7 +776,9 @@ def test_admin_list_groups_for_user(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) result = conn.admin_list_groups_for_user(Username=username, UserPoolId=user_pool_id) @@ -817,8 +799,12 @@ def test_admin_list_groups_for_user_ignores_deleted_group(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name2) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name2 + ) conn.delete_group(GroupName=group_name, UserPoolId=user_pool_id) result = conn.admin_list_groups_for_user(Username=username, UserPoolId=user_pool_id) @@ -838,14 +824,20 @@ def test_admin_remove_user_from_group(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) - result = conn.admin_remove_user_from_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + result = conn.admin_remove_user_from_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) list(result.keys()).should.equal(["ResponseMetadata"]) # No response expected - conn.list_users_in_group(UserPoolId=user_pool_id, GroupName=group_name) \ - ["Users"].should.have.length_of(0) - conn.admin_list_groups_for_user(Username=username, UserPoolId=user_pool_id) \ - ["Groups"].should.have.length_of(0) + conn.list_users_in_group(UserPoolId=user_pool_id, GroupName=group_name)[ + "Users" + ].should.have.length_of(0) + conn.admin_list_groups_for_user(Username=username, UserPoolId=user_pool_id)[ + "Groups" + ].should.have.length_of(0) @mock_cognitoidp @@ -859,8 +851,12 @@ def test_admin_remove_user_from_group_again_is_noop(): username = str(uuid.uuid4()) conn.admin_create_user(UserPoolId=user_pool_id, Username=username) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) - conn.admin_add_user_to_group(UserPoolId=user_pool_id, Username=username, GroupName=group_name) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) + conn.admin_add_user_to_group( + UserPoolId=user_pool_id, Username=username, GroupName=group_name + ) @mock_cognitoidp @@ -873,9 +869,7 @@ def test_admin_create_user(): result = conn.admin_create_user( UserPoolId=user_pool_id, Username=username, - UserAttributes=[ - {"Name": "thing", "Value": value} - ], + UserAttributes=[{"Name": "thing", "Value": value}], ) result["User"]["Username"].should.equal(username) @@ -896,9 +890,7 @@ def test_admin_create_existing_user(): conn.admin_create_user( UserPoolId=user_pool_id, Username=username, - UserAttributes=[ - {"Name": "thing", "Value": value} - ], + UserAttributes=[{"Name": "thing", "Value": value}], ) caught = False @@ -906,9 +898,7 @@ def test_admin_create_existing_user(): conn.admin_create_user( UserPoolId=user_pool_id, Username=username, - UserAttributes=[ - {"Name": "thing", "Value": value} - ], + UserAttributes=[{"Name": "thing", "Value": value}], ) except conn.exceptions.UsernameExistsException: caught = True @@ -926,9 +916,7 @@ def test_admin_get_user(): conn.admin_create_user( UserPoolId=user_pool_id, Username=username, - UserAttributes=[ - {"Name": "thing", "Value": value} - ], + UserAttributes=[{"Name": "thing", "Value": value}], ) result = conn.admin_get_user(UserPoolId=user_pool_id, Username=username) @@ -974,8 +962,7 @@ def test_list_users_returns_limit_items(): # Given 10 users user_count = 10 for i in range(user_count): - conn.admin_create_user(UserPoolId=user_pool_id, - Username=str(uuid.uuid4())) + conn.admin_create_user(UserPoolId=user_pool_id, Username=str(uuid.uuid4())) max_results = 5 result = conn.list_users(UserPoolId=user_pool_id, Limit=max_results) result["Users"].should.have.length_of(max_results) @@ -990,8 +977,7 @@ def test_list_users_returns_pagination_tokens(): # Given 10 users user_count = 10 for i in range(user_count): - conn.admin_create_user(UserPoolId=user_pool_id, - Username=str(uuid.uuid4())) + conn.admin_create_user(UserPoolId=user_pool_id, Username=str(uuid.uuid4())) max_results = 5 result = conn.list_users(UserPoolId=user_pool_id, Limit=max_results) @@ -999,8 +985,9 @@ def test_list_users_returns_pagination_tokens(): result.should.have.key("PaginationToken") next_token = result["PaginationToken"] - result_2 = conn.list_users(UserPoolId=user_pool_id, - Limit=max_results, PaginationToken=next_token) + result_2 = conn.list_users( + UserPoolId=user_pool_id, Limit=max_results, PaginationToken=next_token + ) result_2["Users"].should.have.length_of(max_results) result_2.shouldnt.have.key("PaginationToken") @@ -1013,8 +1000,7 @@ def test_list_users_when_limit_more_than_total_items(): # Given 10 users user_count = 10 for i in range(user_count): - conn.admin_create_user(UserPoolId=user_pool_id, - Username=str(uuid.uuid4())) + conn.admin_create_user(UserPoolId=user_pool_id, Username=str(uuid.uuid4())) max_results = user_count + 5 result = conn.list_users(UserPoolId=user_pool_id, Limit=max_results) @@ -1033,8 +1019,9 @@ def test_admin_disable_user(): result = conn.admin_disable_user(UserPoolId=user_pool_id, Username=username) list(result.keys()).should.equal(["ResponseMetadata"]) # No response expected - conn.admin_get_user(UserPoolId=user_pool_id, Username=username) \ - ["Enabled"].should.equal(False) + conn.admin_get_user(UserPoolId=user_pool_id, Username=username)[ + "Enabled" + ].should.equal(False) @mock_cognitoidp @@ -1049,8 +1036,9 @@ def test_admin_enable_user(): result = conn.admin_enable_user(UserPoolId=user_pool_id, Username=username) list(result.keys()).should.equal(["ResponseMetadata"]) # No response expected - conn.admin_get_user(UserPoolId=user_pool_id, Username=username) \ - ["Enabled"].should.equal(True) + conn.admin_get_user(UserPoolId=user_pool_id, Username=username)[ + "Enabled" + ].should.equal(True) @mock_cognitoidp @@ -1080,27 +1068,21 @@ def authentication_flow(conn): client_id = conn.create_user_pool_client( UserPoolId=user_pool_id, ClientName=str(uuid.uuid4()), - ReadAttributes=[user_attribute_name] + ReadAttributes=[user_attribute_name], )["UserPoolClient"]["ClientId"] conn.admin_create_user( UserPoolId=user_pool_id, Username=username, TemporaryPassword=temporary_password, - UserAttributes=[{ - 'Name': user_attribute_name, - 'Value': user_attribute_value - }] + UserAttributes=[{"Name": user_attribute_name, "Value": user_attribute_value}], ) result = conn.admin_initiate_auth( UserPoolId=user_pool_id, ClientId=client_id, AuthFlow="ADMIN_NO_SRP_AUTH", - AuthParameters={ - "USERNAME": username, - "PASSWORD": temporary_password - }, + AuthParameters={"USERNAME": username, "PASSWORD": temporary_password}, ) # A newly created user is forced to set a new password @@ -1113,10 +1095,7 @@ def authentication_flow(conn): Session=result["Session"], ClientId=client_id, ChallengeName="NEW_PASSWORD_REQUIRED", - ChallengeResponses={ - "USERNAME": username, - "NEW_PASSWORD": new_password - } + ChallengeResponses={"USERNAME": username, "NEW_PASSWORD": new_password}, ) result["AuthenticationResult"]["IdToken"].should_not.be.none @@ -1129,9 +1108,7 @@ def authentication_flow(conn): "access_token": result["AuthenticationResult"]["AccessToken"], "username": username, "password": new_password, - "additional_fields": { - user_attribute_name: user_attribute_value - } + "additional_fields": {user_attribute_name: user_attribute_value}, } @@ -1154,7 +1131,9 @@ def test_token_legitimacy(): id_token = outputs["id_token"] access_token = outputs["access_token"] client_id = outputs["client_id"] - issuer = "https://cognito-idp.us-west-2.amazonaws.com/{}".format(outputs["user_pool_id"]) + issuer = "https://cognito-idp.us-west-2.amazonaws.com/{}".format( + outputs["user_pool_id"] + ) id_claims = json.loads(jws.verify(id_token, json_web_key, "RS256")) id_claims["iss"].should.equal(issuer) id_claims["aud"].should.equal(client_id) @@ -1185,10 +1164,7 @@ def test_change_password(): UserPoolId=outputs["user_pool_id"], ClientId=outputs["client_id"], AuthFlow="ADMIN_NO_SRP_AUTH", - AuthParameters={ - "USERNAME": outputs["username"], - "PASSWORD": newer_password, - }, + AuthParameters={"USERNAME": outputs["username"], "PASSWORD": newer_password}, ) result["AuthenticationResult"].should_not.be.none @@ -1198,7 +1174,9 @@ def test_change_password(): def test_forgot_password(): conn = boto3.client("cognito-idp", "us-west-2") - result = conn.forgot_password(ClientId=str(uuid.uuid4()), Username=str(uuid.uuid4())) + result = conn.forgot_password( + ClientId=str(uuid.uuid4()), Username=str(uuid.uuid4()) + ) result["CodeDeliveryDetails"].should_not.be.none @@ -1209,14 +1187,11 @@ def test_confirm_forgot_password(): username = str(uuid.uuid4()) user_pool_id = conn.create_user_pool(PoolName=str(uuid.uuid4()))["UserPool"]["Id"] client_id = conn.create_user_pool_client( - UserPoolId=user_pool_id, - ClientName=str(uuid.uuid4()), + UserPoolId=user_pool_id, ClientName=str(uuid.uuid4()) )["UserPoolClient"]["ClientId"] conn.admin_create_user( - UserPoolId=user_pool_id, - Username=username, - TemporaryPassword=str(uuid.uuid4()), + UserPoolId=user_pool_id, Username=username, TemporaryPassword=str(uuid.uuid4()) ) conn.confirm_forgot_password( @@ -1226,6 +1201,7 @@ def test_confirm_forgot_password(): Password=str(uuid.uuid4()), ) + @mock_cognitoidp def test_admin_update_user_attributes(): conn = boto3.client("cognito-idp", "us-west-2") @@ -1237,41 +1213,26 @@ def test_admin_update_user_attributes(): UserPoolId=user_pool_id, Username=username, UserAttributes=[ - { - 'Name': 'family_name', - 'Value': 'Doe', - }, - { - 'Name': 'given_name', - 'Value': 'John', - } - ] + {"Name": "family_name", "Value": "Doe"}, + {"Name": "given_name", "Value": "John"}, + ], ) conn.admin_update_user_attributes( UserPoolId=user_pool_id, Username=username, UserAttributes=[ - { - 'Name': 'family_name', - 'Value': 'Doe', - }, - { - 'Name': 'given_name', - 'Value': 'Jane', - } - ] + {"Name": "family_name", "Value": "Doe"}, + {"Name": "given_name", "Value": "Jane"}, + ], ) - user = conn.admin_get_user( - UserPoolId=user_pool_id, - Username=username - ) - attributes = user['UserAttributes'] + user = conn.admin_get_user(UserPoolId=user_pool_id, Username=username) + attributes = user["UserAttributes"] attributes.should.be.a(list) for attr in attributes: - val = attr['Value'] - if attr['Name'] == 'family_name': - val.should.equal('Doe') - elif attr['Name'] == 'given_name': - val.should.equal('Jane') + val = attr["Value"] + if attr["Name"] == "family_name": + val.should.equal("Doe") + elif attr["Name"] == "given_name": + val.should.equal("Jane") diff --git a/tests/test_config/test_config.py b/tests/test_config/test_config.py index 8b9a5d877..d3751b123 100644 --- a/tests/test_config/test_config.py +++ b/tests/test_config/test_config.py @@ -11,1006 +11,1248 @@ from moto.config import mock_config @mock_config def test_put_configuration_recorder(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Try without a name supplied: with assert_raises(ClientError) as ce: - client.put_configuration_recorder(ConfigurationRecorder={'roleARN': 'somearn'}) - assert ce.exception.response['Error']['Code'] == 'InvalidConfigurationRecorderNameException' - assert 'is not valid, blank string.' in ce.exception.response['Error']['Message'] + client.put_configuration_recorder(ConfigurationRecorder={"roleARN": "somearn"}) + assert ( + ce.exception.response["Error"]["Code"] + == "InvalidConfigurationRecorderNameException" + ) + assert "is not valid, blank string." in ce.exception.response["Error"]["Message"] # Try with a really long name: with assert_raises(ClientError) as ce: - client.put_configuration_recorder(ConfigurationRecorder={'name': 'a' * 257, 'roleARN': 'somearn'}) - assert ce.exception.response['Error']['Code'] == 'ValidationException' - assert 'Member must have length less than or equal to 256' in ce.exception.response['Error']['Message'] + client.put_configuration_recorder( + ConfigurationRecorder={"name": "a" * 257, "roleARN": "somearn"} + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" + assert ( + "Member must have length less than or equal to 256" + in ce.exception.response["Error"]["Message"] + ) # With resource types and flags set to True: bad_groups = [ - {'allSupported': True, 'includeGlobalResourceTypes': True, 'resourceTypes': ['item']}, - {'allSupported': False, 'includeGlobalResourceTypes': True, 'resourceTypes': ['item']}, - {'allSupported': True, 'includeGlobalResourceTypes': False, 'resourceTypes': ['item']}, - {'allSupported': False, 'includeGlobalResourceTypes': False, 'resourceTypes': []}, - {'includeGlobalResourceTypes': False, 'resourceTypes': []}, - {'includeGlobalResourceTypes': True}, - {'resourceTypes': []}, - {} + { + "allSupported": True, + "includeGlobalResourceTypes": True, + "resourceTypes": ["item"], + }, + { + "allSupported": False, + "includeGlobalResourceTypes": True, + "resourceTypes": ["item"], + }, + { + "allSupported": True, + "includeGlobalResourceTypes": False, + "resourceTypes": ["item"], + }, + { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": [], + }, + {"includeGlobalResourceTypes": False, "resourceTypes": []}, + {"includeGlobalResourceTypes": True}, + {"resourceTypes": []}, + {}, ] for bg in bad_groups: with assert_raises(ClientError) as ce: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'default', - 'roleARN': 'somearn', - 'recordingGroup': bg - }) - assert ce.exception.response['Error']['Code'] == 'InvalidRecordingGroupException' - assert ce.exception.response['Error']['Message'] == 'The recording group provided is not valid' + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "default", + "roleARN": "somearn", + "recordingGroup": bg, + } + ) + assert ( + ce.exception.response["Error"]["Code"] == "InvalidRecordingGroupException" + ) + assert ( + ce.exception.response["Error"]["Message"] + == "The recording group provided is not valid" + ) # With an invalid Resource Type: with assert_raises(ClientError) as ce: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'default', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - # 2 good, and 2 bad: - 'resourceTypes': ['AWS::EC2::Volume', 'LOLNO', 'AWS::EC2::VPC', 'LOLSTILLNO'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "default", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + # 2 good, and 2 bad: + "resourceTypes": [ + "AWS::EC2::Volume", + "LOLNO", + "AWS::EC2::VPC", + "LOLSTILLNO", + ], + }, } - }) - assert ce.exception.response['Error']['Code'] == 'ValidationException' - assert "2 validation error detected: Value '['LOLNO', 'LOLSTILLNO']" in str(ce.exception.response['Error']['Message']) - assert 'AWS::EC2::Instance' in ce.exception.response['Error']['Message'] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" + assert "2 validation error detected: Value '['LOLNO', 'LOLSTILLNO']" in str( + ce.exception.response["Error"]["Message"] + ) + assert "AWS::EC2::Instance" in ce.exception.response["Error"]["Message"] # Create a proper one: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) - result = client.describe_configuration_recorders()['ConfigurationRecorders'] + result = client.describe_configuration_recorders()["ConfigurationRecorders"] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert result[0]['roleARN'] == 'somearn' - assert not result[0]['recordingGroup']['allSupported'] - assert not result[0]['recordingGroup']['includeGlobalResourceTypes'] - assert len(result[0]['recordingGroup']['resourceTypes']) == 2 - assert 'AWS::EC2::Volume' in result[0]['recordingGroup']['resourceTypes'] \ - and 'AWS::EC2::VPC' in result[0]['recordingGroup']['resourceTypes'] + assert result[0]["name"] == "testrecorder" + assert result[0]["roleARN"] == "somearn" + assert not result[0]["recordingGroup"]["allSupported"] + assert not result[0]["recordingGroup"]["includeGlobalResourceTypes"] + assert len(result[0]["recordingGroup"]["resourceTypes"]) == 2 + assert ( + "AWS::EC2::Volume" in result[0]["recordingGroup"]["resourceTypes"] + and "AWS::EC2::VPC" in result[0]["recordingGroup"]["resourceTypes"] + ) # Now update the configuration recorder: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': True, - 'includeGlobalResourceTypes': True + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": True, + "includeGlobalResourceTypes": True, + }, } - }) - result = client.describe_configuration_recorders()['ConfigurationRecorders'] + ) + result = client.describe_configuration_recorders()["ConfigurationRecorders"] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert result[0]['roleARN'] == 'somearn' - assert result[0]['recordingGroup']['allSupported'] - assert result[0]['recordingGroup']['includeGlobalResourceTypes'] - assert len(result[0]['recordingGroup']['resourceTypes']) == 0 + assert result[0]["name"] == "testrecorder" + assert result[0]["roleARN"] == "somearn" + assert result[0]["recordingGroup"]["allSupported"] + assert result[0]["recordingGroup"]["includeGlobalResourceTypes"] + assert len(result[0]["recordingGroup"]["resourceTypes"]) == 0 # With a default recording group (i.e. lacking one) - client.put_configuration_recorder(ConfigurationRecorder={'name': 'testrecorder', 'roleARN': 'somearn'}) - result = client.describe_configuration_recorders()['ConfigurationRecorders'] + client.put_configuration_recorder( + ConfigurationRecorder={"name": "testrecorder", "roleARN": "somearn"} + ) + result = client.describe_configuration_recorders()["ConfigurationRecorders"] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert result[0]['roleARN'] == 'somearn' - assert result[0]['recordingGroup']['allSupported'] - assert not result[0]['recordingGroup']['includeGlobalResourceTypes'] - assert not result[0]['recordingGroup'].get('resourceTypes') + assert result[0]["name"] == "testrecorder" + assert result[0]["roleARN"] == "somearn" + assert result[0]["recordingGroup"]["allSupported"] + assert not result[0]["recordingGroup"]["includeGlobalResourceTypes"] + assert not result[0]["recordingGroup"].get("resourceTypes") # Can currently only have exactly 1 Config Recorder in an account/region: with assert_raises(ClientError) as ce: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'someotherrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "someotherrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + }, } - }) - assert ce.exception.response['Error']['Code'] == 'MaxNumberOfConfigurationRecordersExceededException' - assert "maximum number of configuration recorders: 1 is reached." in ce.exception.response['Error']['Message'] + ) + assert ( + ce.exception.response["Error"]["Code"] + == "MaxNumberOfConfigurationRecordersExceededException" + ) + assert ( + "maximum number of configuration recorders: 1 is reached." + in ce.exception.response["Error"]["Message"] + ) @mock_config def test_put_configuration_aggregator(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # With too many aggregation sources: with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ] + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AwsRegions": ["us-east-1", "us-west-2"], }, { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ] - } - ] + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AwsRegions": ["us-east-1", "us-west-2"], + }, + ], ) - assert 'Member must have length less than or equal to 1' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 1" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # With an invalid region config (no regions defined): with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AllAwsRegions': False + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AllAwsRegions": False, } - ] + ], ) - assert 'Your request does not specify any regions' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + assert ( + "Your request does not specify any regions" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", OrganizationAggregationSource={ - 'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole' - } + "RoleArn": "arn:aws:iam::012345678910:role/SomeRole" + }, ) - assert 'Your request does not specify any regions' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + assert ( + "Your request does not specify any regions" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" # With both region flags defined: with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ], - 'AllAwsRegions': True + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AwsRegions": ["us-east-1", "us-west-2"], + "AllAwsRegions": True, } - ] + ], ) - assert 'You must choose one of these options' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + assert ( + "You must choose one of these options" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", OrganizationAggregationSource={ - 'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole', - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ], - 'AllAwsRegions': True - } + "RoleArn": "arn:aws:iam::012345678910:role/SomeRole", + "AwsRegions": ["us-east-1", "us-west-2"], + "AllAwsRegions": True, + }, ) - assert 'You must choose one of these options' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + assert ( + "You must choose one of these options" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" # Name too long: with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='a' * 257, + ConfigurationAggregatorName="a" * 257, AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } - ] + {"AccountIds": ["012345678910"], "AllAwsRegions": True} + ], ) - assert 'configurationAggregatorName' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert "configurationAggregatorName" in ce.exception.response["Error"]["Message"] + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Too many tags (>50): with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } + {"AccountIds": ["012345678910"], "AllAwsRegions": True} + ], + Tags=[ + {"Key": "{}".format(x), "Value": "{}".format(x)} for x in range(0, 51) ], - Tags=[{'Key': '{}'.format(x), 'Value': '{}'.format(x)} for x in range(0, 51)] ) - assert 'Member must have length less than or equal to 50' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 50" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Tag key is too big (>128 chars): with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } + {"AccountIds": ["012345678910"], "AllAwsRegions": True} ], - Tags=[{'Key': 'a' * 129, 'Value': 'a'}] + Tags=[{"Key": "a" * 129, "Value": "a"}], ) - assert 'Member must have length less than or equal to 128' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 128" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Tag value is too big (>256 chars): with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } + {"AccountIds": ["012345678910"], "AllAwsRegions": True} ], - Tags=[{'Key': 'tag', 'Value': 'a' * 257}] + Tags=[{"Key": "tag", "Value": "a" * 257}], ) - assert 'Member must have length less than or equal to 256' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 256" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Duplicate Tags: with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } + {"AccountIds": ["012345678910"], "AllAwsRegions": True} ], - Tags=[{'Key': 'a', 'Value': 'a'}, {'Key': 'a', 'Value': 'a'}] + Tags=[{"Key": "a", "Value": "a"}, {"Key": "a", "Value": "a"}], ) - assert 'Duplicate tag keys found.' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidInput' + assert "Duplicate tag keys found." in ce.exception.response["Error"]["Message"] + assert ce.exception.response["Error"]["Code"] == "InvalidInput" # Invalid characters in the tag key: with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } + {"AccountIds": ["012345678910"], "AllAwsRegions": True} ], - Tags=[{'Key': '!', 'Value': 'a'}] + Tags=[{"Key": "!", "Value": "a"}], ) - assert 'Member must satisfy regular expression pattern:' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must satisfy regular expression pattern:" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # If it contains both the AccountAggregationSources and the OrganizationAggregationSource with assert_raises(ClientError) as ce: client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': False - } + {"AccountIds": ["012345678910"], "AllAwsRegions": False} ], OrganizationAggregationSource={ - 'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole', - 'AllAwsRegions': False - } + "RoleArn": "arn:aws:iam::012345678910:role/SomeRole", + "AllAwsRegions": False, + }, ) - assert 'AccountAggregationSource and the OrganizationAggregationSource' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + assert ( + "AccountAggregationSource and the OrganizationAggregationSource" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" # If it contains neither: with assert_raises(ClientError) as ce: - client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', - ) - assert 'AccountAggregationSource or the OrganizationAggregationSource' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidParameterValueException' + client.put_configuration_aggregator(ConfigurationAggregatorName="testing") + assert ( + "AccountAggregationSource or the OrganizationAggregationSource" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidParameterValueException" # Just make one: account_aggregation_source = { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ], - 'AllAwsRegions': False + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AwsRegions": ["us-east-1", "us-west-2"], + "AllAwsRegions": False, } result = client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[account_aggregation_source], ) - assert result['ConfigurationAggregator']['ConfigurationAggregatorName'] == 'testing' - assert result['ConfigurationAggregator']['AccountAggregationSources'] == [account_aggregation_source] - assert 'arn:aws:config:us-west-2:123456789012:config-aggregator/config-aggregator-' in \ - result['ConfigurationAggregator']['ConfigurationAggregatorArn'] - assert result['ConfigurationAggregator']['CreationTime'] == result['ConfigurationAggregator']['LastUpdatedTime'] - - # Update the existing one: - original_arn = result['ConfigurationAggregator']['ConfigurationAggregatorArn'] - account_aggregation_source.pop('AwsRegions') - account_aggregation_source['AllAwsRegions'] = True - result = client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', - AccountAggregationSources=[account_aggregation_source] + assert result["ConfigurationAggregator"]["ConfigurationAggregatorName"] == "testing" + assert result["ConfigurationAggregator"]["AccountAggregationSources"] == [ + account_aggregation_source + ] + assert ( + "arn:aws:config:us-west-2:123456789012:config-aggregator/config-aggregator-" + in result["ConfigurationAggregator"]["ConfigurationAggregatorArn"] + ) + assert ( + result["ConfigurationAggregator"]["CreationTime"] + == result["ConfigurationAggregator"]["LastUpdatedTime"] ) - assert result['ConfigurationAggregator']['ConfigurationAggregatorName'] == 'testing' - assert result['ConfigurationAggregator']['AccountAggregationSources'] == [account_aggregation_source] - assert result['ConfigurationAggregator']['ConfigurationAggregatorArn'] == original_arn + # Update the existing one: + original_arn = result["ConfigurationAggregator"]["ConfigurationAggregatorArn"] + account_aggregation_source.pop("AwsRegions") + account_aggregation_source["AllAwsRegions"] = True + result = client.put_configuration_aggregator( + ConfigurationAggregatorName="testing", + AccountAggregationSources=[account_aggregation_source], + ) + + assert result["ConfigurationAggregator"]["ConfigurationAggregatorName"] == "testing" + assert result["ConfigurationAggregator"]["AccountAggregationSources"] == [ + account_aggregation_source + ] + assert ( + result["ConfigurationAggregator"]["ConfigurationAggregatorArn"] == original_arn + ) # Make an org one: result = client.put_configuration_aggregator( - ConfigurationAggregatorName='testingOrg', + ConfigurationAggregatorName="testingOrg", OrganizationAggregationSource={ - 'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole', - 'AwsRegions': ['us-east-1', 'us-west-2'] - } + "RoleArn": "arn:aws:iam::012345678910:role/SomeRole", + "AwsRegions": ["us-east-1", "us-west-2"], + }, ) - assert result['ConfigurationAggregator']['ConfigurationAggregatorName'] == 'testingOrg' - assert result['ConfigurationAggregator']['OrganizationAggregationSource'] == { - 'RoleArn': 'arn:aws:iam::012345678910:role/SomeRole', - 'AwsRegions': [ - 'us-east-1', - 'us-west-2' - ], - 'AllAwsRegions': False + assert ( + result["ConfigurationAggregator"]["ConfigurationAggregatorName"] == "testingOrg" + ) + assert result["ConfigurationAggregator"]["OrganizationAggregationSource"] == { + "RoleArn": "arn:aws:iam::012345678910:role/SomeRole", + "AwsRegions": ["us-east-1", "us-west-2"], + "AllAwsRegions": False, } @mock_config def test_describe_configuration_aggregators(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Without any config aggregators: - assert not client.describe_configuration_aggregators()['ConfigurationAggregators'] + assert not client.describe_configuration_aggregators()["ConfigurationAggregators"] # Make 10 config aggregators: for x in range(0, 10): client.put_configuration_aggregator( - ConfigurationAggregatorName='testing{}'.format(x), + ConfigurationAggregatorName="testing{}".format(x), AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } - ] + {"AccountIds": ["012345678910"], "AllAwsRegions": True} + ], ) # Describe with an incorrect name: with assert_raises(ClientError) as ce: - client.describe_configuration_aggregators(ConfigurationAggregatorNames=['DoesNotExist']) - assert 'The configuration aggregator does not exist.' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationAggregatorException' + client.describe_configuration_aggregators( + ConfigurationAggregatorNames=["DoesNotExist"] + ) + assert ( + "The configuration aggregator does not exist." + in ce.exception.response["Error"]["Message"] + ) + assert ( + ce.exception.response["Error"]["Code"] + == "NoSuchConfigurationAggregatorException" + ) # Error describe with more than 1 item in the list: with assert_raises(ClientError) as ce: - client.describe_configuration_aggregators(ConfigurationAggregatorNames=['testing0', 'DoesNotExist']) - assert 'At least one of the configuration aggregators does not exist.' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationAggregatorException' + client.describe_configuration_aggregators( + ConfigurationAggregatorNames=["testing0", "DoesNotExist"] + ) + assert ( + "At least one of the configuration aggregators does not exist." + in ce.exception.response["Error"]["Message"] + ) + assert ( + ce.exception.response["Error"]["Code"] + == "NoSuchConfigurationAggregatorException" + ) # Get the normal list: result = client.describe_configuration_aggregators() - assert not result.get('NextToken') - assert len(result['ConfigurationAggregators']) == 10 + assert not result.get("NextToken") + assert len(result["ConfigurationAggregators"]) == 10 # Test filtered list: - agg_names = ['testing0', 'testing1', 'testing2'] - result = client.describe_configuration_aggregators(ConfigurationAggregatorNames=agg_names) - assert not result.get('NextToken') - assert len(result['ConfigurationAggregators']) == 3 - assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == agg_names + agg_names = ["testing0", "testing1", "testing2"] + result = client.describe_configuration_aggregators( + ConfigurationAggregatorNames=agg_names + ) + assert not result.get("NextToken") + assert len(result["ConfigurationAggregators"]) == 3 + assert [ + agg["ConfigurationAggregatorName"] for agg in result["ConfigurationAggregators"] + ] == agg_names # Test Pagination: result = client.describe_configuration_aggregators(Limit=4) - assert len(result['ConfigurationAggregators']) == 4 - assert result['NextToken'] == 'testing4' - assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == \ - ['testing{}'.format(x) for x in range(0, 4)] - result = client.describe_configuration_aggregators(Limit=4, NextToken='testing4') - assert len(result['ConfigurationAggregators']) == 4 - assert result['NextToken'] == 'testing8' - assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == \ - ['testing{}'.format(x) for x in range(4, 8)] - result = client.describe_configuration_aggregators(Limit=4, NextToken='testing8') - assert len(result['ConfigurationAggregators']) == 2 - assert not result.get('NextToken') - assert [agg['ConfigurationAggregatorName'] for agg in result['ConfigurationAggregators']] == \ - ['testing{}'.format(x) for x in range(8, 10)] + assert len(result["ConfigurationAggregators"]) == 4 + assert result["NextToken"] == "testing4" + assert [ + agg["ConfigurationAggregatorName"] for agg in result["ConfigurationAggregators"] + ] == ["testing{}".format(x) for x in range(0, 4)] + result = client.describe_configuration_aggregators(Limit=4, NextToken="testing4") + assert len(result["ConfigurationAggregators"]) == 4 + assert result["NextToken"] == "testing8" + assert [ + agg["ConfigurationAggregatorName"] for agg in result["ConfigurationAggregators"] + ] == ["testing{}".format(x) for x in range(4, 8)] + result = client.describe_configuration_aggregators(Limit=4, NextToken="testing8") + assert len(result["ConfigurationAggregators"]) == 2 + assert not result.get("NextToken") + assert [ + agg["ConfigurationAggregatorName"] for agg in result["ConfigurationAggregators"] + ] == ["testing{}".format(x) for x in range(8, 10)] # Test Pagination with Filtering: - result = client.describe_configuration_aggregators(ConfigurationAggregatorNames=['testing2', 'testing4'], Limit=1) - assert len(result['ConfigurationAggregators']) == 1 - assert result['NextToken'] == 'testing4' - assert result['ConfigurationAggregators'][0]['ConfigurationAggregatorName'] == 'testing2' - result = client.describe_configuration_aggregators(ConfigurationAggregatorNames=['testing2', 'testing4'], Limit=1, NextToken='testing4') - assert not result.get('NextToken') - assert result['ConfigurationAggregators'][0]['ConfigurationAggregatorName'] == 'testing4' + result = client.describe_configuration_aggregators( + ConfigurationAggregatorNames=["testing2", "testing4"], Limit=1 + ) + assert len(result["ConfigurationAggregators"]) == 1 + assert result["NextToken"] == "testing4" + assert ( + result["ConfigurationAggregators"][0]["ConfigurationAggregatorName"] + == "testing2" + ) + result = client.describe_configuration_aggregators( + ConfigurationAggregatorNames=["testing2", "testing4"], + Limit=1, + NextToken="testing4", + ) + assert not result.get("NextToken") + assert ( + result["ConfigurationAggregators"][0]["ConfigurationAggregatorName"] + == "testing4" + ) # Test with an invalid filter: with assert_raises(ClientError) as ce: - client.describe_configuration_aggregators(NextToken='WRONG') - assert 'The nextToken provided is invalid' == ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidNextTokenException' + client.describe_configuration_aggregators(NextToken="WRONG") + assert ( + "The nextToken provided is invalid" == ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidNextTokenException" @mock_config def test_put_aggregation_authorization(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Too many tags (>50): with assert_raises(ClientError) as ce: client.put_aggregation_authorization( - AuthorizedAccountId='012345678910', - AuthorizedAwsRegion='us-west-2', - Tags=[{'Key': '{}'.format(x), 'Value': '{}'.format(x)} for x in range(0, 51)] + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-west-2", + Tags=[ + {"Key": "{}".format(x), "Value": "{}".format(x)} for x in range(0, 51) + ], ) - assert 'Member must have length less than or equal to 50' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 50" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Tag key is too big (>128 chars): with assert_raises(ClientError) as ce: client.put_aggregation_authorization( - AuthorizedAccountId='012345678910', - AuthorizedAwsRegion='us-west-2', - Tags=[{'Key': 'a' * 129, 'Value': 'a'}] + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-west-2", + Tags=[{"Key": "a" * 129, "Value": "a"}], ) - assert 'Member must have length less than or equal to 128' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 128" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Tag value is too big (>256 chars): with assert_raises(ClientError) as ce: client.put_aggregation_authorization( - AuthorizedAccountId='012345678910', - AuthorizedAwsRegion='us-west-2', - Tags=[{'Key': 'tag', 'Value': 'a' * 257}] + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-west-2", + Tags=[{"Key": "tag", "Value": "a" * 257}], ) - assert 'Member must have length less than or equal to 256' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must have length less than or equal to 256" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Duplicate Tags: with assert_raises(ClientError) as ce: client.put_aggregation_authorization( - AuthorizedAccountId='012345678910', - AuthorizedAwsRegion='us-west-2', - Tags=[{'Key': 'a', 'Value': 'a'}, {'Key': 'a', 'Value': 'a'}] + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-west-2", + Tags=[{"Key": "a", "Value": "a"}, {"Key": "a", "Value": "a"}], ) - assert 'Duplicate tag keys found.' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidInput' + assert "Duplicate tag keys found." in ce.exception.response["Error"]["Message"] + assert ce.exception.response["Error"]["Code"] == "InvalidInput" # Invalid characters in the tag key: with assert_raises(ClientError) as ce: client.put_aggregation_authorization( - AuthorizedAccountId='012345678910', - AuthorizedAwsRegion='us-west-2', - Tags=[{'Key': '!', 'Value': 'a'}] + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-west-2", + Tags=[{"Key": "!", "Value": "a"}], ) - assert 'Member must satisfy regular expression pattern:' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'ValidationException' + assert ( + "Member must satisfy regular expression pattern:" + in ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "ValidationException" # Put a normal one there: - result = client.put_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-east-1', - Tags=[{'Key': 'tag', 'Value': 'a'}]) + result = client.put_aggregation_authorization( + AuthorizedAccountId="012345678910", + AuthorizedAwsRegion="us-east-1", + Tags=[{"Key": "tag", "Value": "a"}], + ) - assert result['AggregationAuthorization']['AggregationAuthorizationArn'] == 'arn:aws:config:us-west-2:123456789012:' \ - 'aggregation-authorization/012345678910/us-east-1' - assert result['AggregationAuthorization']['AuthorizedAccountId'] == '012345678910' - assert result['AggregationAuthorization']['AuthorizedAwsRegion'] == 'us-east-1' - assert isinstance(result['AggregationAuthorization']['CreationTime'], datetime) + assert ( + result["AggregationAuthorization"]["AggregationAuthorizationArn"] + == "arn:aws:config:us-west-2:123456789012:" + "aggregation-authorization/012345678910/us-east-1" + ) + assert result["AggregationAuthorization"]["AuthorizedAccountId"] == "012345678910" + assert result["AggregationAuthorization"]["AuthorizedAwsRegion"] == "us-east-1" + assert isinstance(result["AggregationAuthorization"]["CreationTime"], datetime) - creation_date = result['AggregationAuthorization']['CreationTime'] + creation_date = result["AggregationAuthorization"]["CreationTime"] # And again: - result = client.put_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-east-1') - assert result['AggregationAuthorization']['AggregationAuthorizationArn'] == 'arn:aws:config:us-west-2:123456789012:' \ - 'aggregation-authorization/012345678910/us-east-1' - assert result['AggregationAuthorization']['AuthorizedAccountId'] == '012345678910' - assert result['AggregationAuthorization']['AuthorizedAwsRegion'] == 'us-east-1' - assert result['AggregationAuthorization']['CreationTime'] == creation_date + result = client.put_aggregation_authorization( + AuthorizedAccountId="012345678910", AuthorizedAwsRegion="us-east-1" + ) + assert ( + result["AggregationAuthorization"]["AggregationAuthorizationArn"] + == "arn:aws:config:us-west-2:123456789012:" + "aggregation-authorization/012345678910/us-east-1" + ) + assert result["AggregationAuthorization"]["AuthorizedAccountId"] == "012345678910" + assert result["AggregationAuthorization"]["AuthorizedAwsRegion"] == "us-east-1" + assert result["AggregationAuthorization"]["CreationTime"] == creation_date @mock_config def test_describe_aggregation_authorizations(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # With no aggregation authorizations: - assert not client.describe_aggregation_authorizations()['AggregationAuthorizations'] + assert not client.describe_aggregation_authorizations()["AggregationAuthorizations"] # Make 10 account authorizations: for i in range(0, 10): - client.put_aggregation_authorization(AuthorizedAccountId='{}'.format(str(i) * 12), AuthorizedAwsRegion='us-west-2') + client.put_aggregation_authorization( + AuthorizedAccountId="{}".format(str(i) * 12), + AuthorizedAwsRegion="us-west-2", + ) result = client.describe_aggregation_authorizations() - assert len(result['AggregationAuthorizations']) == 10 - assert not result.get('NextToken') + assert len(result["AggregationAuthorizations"]) == 10 + assert not result.get("NextToken") for i in range(0, 10): - assert result['AggregationAuthorizations'][i]['AuthorizedAccountId'] == str(i) * 12 + assert ( + result["AggregationAuthorizations"][i]["AuthorizedAccountId"] == str(i) * 12 + ) # Test Pagination: result = client.describe_aggregation_authorizations(Limit=4) - assert len(result['AggregationAuthorizations']) == 4 - assert result['NextToken'] == ('4' * 12) + '/us-west-2' - assert [auth['AuthorizedAccountId'] for auth in result['AggregationAuthorizations']] == ['{}'.format(str(x) * 12) for x in range(0, 4)] + assert len(result["AggregationAuthorizations"]) == 4 + assert result["NextToken"] == ("4" * 12) + "/us-west-2" + assert [ + auth["AuthorizedAccountId"] for auth in result["AggregationAuthorizations"] + ] == ["{}".format(str(x) * 12) for x in range(0, 4)] - result = client.describe_aggregation_authorizations(Limit=4, NextToken=('4' * 12) + '/us-west-2') - assert len(result['AggregationAuthorizations']) == 4 - assert result['NextToken'] == ('8' * 12) + '/us-west-2' - assert [auth['AuthorizedAccountId'] for auth in result['AggregationAuthorizations']] == ['{}'.format(str(x) * 12) for x in range(4, 8)] + result = client.describe_aggregation_authorizations( + Limit=4, NextToken=("4" * 12) + "/us-west-2" + ) + assert len(result["AggregationAuthorizations"]) == 4 + assert result["NextToken"] == ("8" * 12) + "/us-west-2" + assert [ + auth["AuthorizedAccountId"] for auth in result["AggregationAuthorizations"] + ] == ["{}".format(str(x) * 12) for x in range(4, 8)] - result = client.describe_aggregation_authorizations(Limit=4, NextToken=('8' * 12) + '/us-west-2') - assert len(result['AggregationAuthorizations']) == 2 - assert not result.get('NextToken') - assert [auth['AuthorizedAccountId'] for auth in result['AggregationAuthorizations']] == ['{}'.format(str(x) * 12) for x in range(8, 10)] + result = client.describe_aggregation_authorizations( + Limit=4, NextToken=("8" * 12) + "/us-west-2" + ) + assert len(result["AggregationAuthorizations"]) == 2 + assert not result.get("NextToken") + assert [ + auth["AuthorizedAccountId"] for auth in result["AggregationAuthorizations"] + ] == ["{}".format(str(x) * 12) for x in range(8, 10)] # Test with an invalid filter: with assert_raises(ClientError) as ce: - client.describe_aggregation_authorizations(NextToken='WRONG') - assert 'The nextToken provided is invalid' == ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'InvalidNextTokenException' + client.describe_aggregation_authorizations(NextToken="WRONG") + assert ( + "The nextToken provided is invalid" == ce.exception.response["Error"]["Message"] + ) + assert ce.exception.response["Error"]["Code"] == "InvalidNextTokenException" @mock_config def test_delete_aggregation_authorization(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") - client.put_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-west-2') + client.put_aggregation_authorization( + AuthorizedAccountId="012345678910", AuthorizedAwsRegion="us-west-2" + ) # Delete it: - client.delete_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-west-2') + client.delete_aggregation_authorization( + AuthorizedAccountId="012345678910", AuthorizedAwsRegion="us-west-2" + ) # Verify that none are there: - assert not client.describe_aggregation_authorizations()['AggregationAuthorizations'] + assert not client.describe_aggregation_authorizations()["AggregationAuthorizations"] # Try it again -- nothing should happen: - client.delete_aggregation_authorization(AuthorizedAccountId='012345678910', AuthorizedAwsRegion='us-west-2') + client.delete_aggregation_authorization( + AuthorizedAccountId="012345678910", AuthorizedAwsRegion="us-west-2" + ) @mock_config def test_delete_configuration_aggregator(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', + ConfigurationAggregatorName="testing", AccountAggregationSources=[ - { - 'AccountIds': [ - '012345678910', - ], - 'AllAwsRegions': True - } - ] + {"AccountIds": ["012345678910"], "AllAwsRegions": True} + ], ) - client.delete_configuration_aggregator(ConfigurationAggregatorName='testing') + client.delete_configuration_aggregator(ConfigurationAggregatorName="testing") # And again to confirm that it's deleted: with assert_raises(ClientError) as ce: - client.delete_configuration_aggregator(ConfigurationAggregatorName='testing') - assert 'The configuration aggregator does not exist.' in ce.exception.response['Error']['Message'] - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationAggregatorException' + client.delete_configuration_aggregator(ConfigurationAggregatorName="testing") + assert ( + "The configuration aggregator does not exist." + in ce.exception.response["Error"]["Message"] + ) + assert ( + ce.exception.response["Error"]["Code"] + == "NoSuchConfigurationAggregatorException" + ) @mock_config def test_describe_configurations(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Without any configurations: result = client.describe_configuration_recorders() - assert not result['ConfigurationRecorders'] + assert not result["ConfigurationRecorders"] - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) - result = client.describe_configuration_recorders()['ConfigurationRecorders'] + result = client.describe_configuration_recorders()["ConfigurationRecorders"] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert result[0]['roleARN'] == 'somearn' - assert not result[0]['recordingGroup']['allSupported'] - assert not result[0]['recordingGroup']['includeGlobalResourceTypes'] - assert len(result[0]['recordingGroup']['resourceTypes']) == 2 - assert 'AWS::EC2::Volume' in result[0]['recordingGroup']['resourceTypes'] \ - and 'AWS::EC2::VPC' in result[0]['recordingGroup']['resourceTypes'] + assert result[0]["name"] == "testrecorder" + assert result[0]["roleARN"] == "somearn" + assert not result[0]["recordingGroup"]["allSupported"] + assert not result[0]["recordingGroup"]["includeGlobalResourceTypes"] + assert len(result[0]["recordingGroup"]["resourceTypes"]) == 2 + assert ( + "AWS::EC2::Volume" in result[0]["recordingGroup"]["resourceTypes"] + and "AWS::EC2::VPC" in result[0]["recordingGroup"]["resourceTypes"] + ) # Specify an incorrect name: with assert_raises(ClientError) as ce: - client.describe_configuration_recorders(ConfigurationRecorderNames=['wrong']) - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' - assert 'wrong' in ce.exception.response['Error']['Message'] + client.describe_configuration_recorders(ConfigurationRecorderNames=["wrong"]) + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) + assert "wrong" in ce.exception.response["Error"]["Message"] # And with both a good and wrong name: with assert_raises(ClientError) as ce: - client.describe_configuration_recorders(ConfigurationRecorderNames=['testrecorder', 'wrong']) - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' - assert 'wrong' in ce.exception.response['Error']['Message'] + client.describe_configuration_recorders( + ConfigurationRecorderNames=["testrecorder", "wrong"] + ) + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) + assert "wrong" in ce.exception.response["Error"]["Message"] @mock_config def test_delivery_channels(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Try without a config recorder: with assert_raises(ClientError) as ce: client.put_delivery_channel(DeliveryChannel={}) - assert ce.exception.response['Error']['Code'] == 'NoAvailableConfigurationRecorderException' - assert ce.exception.response['Error']['Message'] == 'Configuration recorder is not available to ' \ - 'put delivery channel.' + assert ( + ce.exception.response["Error"]["Code"] + == "NoAvailableConfigurationRecorderException" + ) + assert ( + ce.exception.response["Error"]["Message"] + == "Configuration recorder is not available to " + "put delivery channel." + ) # Create a config recorder to continue testing: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Try without a name supplied: with assert_raises(ClientError) as ce: client.put_delivery_channel(DeliveryChannel={}) - assert ce.exception.response['Error']['Code'] == 'InvalidDeliveryChannelNameException' - assert 'is not valid, blank string.' in ce.exception.response['Error']['Message'] + assert ( + ce.exception.response["Error"]["Code"] == "InvalidDeliveryChannelNameException" + ) + assert "is not valid, blank string." in ce.exception.response["Error"]["Message"] # Try with a really long name: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={'name': 'a' * 257}) - assert ce.exception.response['Error']['Code'] == 'ValidationException' - assert 'Member must have length less than or equal to 256' in ce.exception.response['Error']['Message'] + client.put_delivery_channel(DeliveryChannel={"name": "a" * 257}) + assert ce.exception.response["Error"]["Code"] == "ValidationException" + assert ( + "Member must have length less than or equal to 256" + in ce.exception.response["Error"]["Message"] + ) # Without specifying a bucket name: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel'}) - assert ce.exception.response['Error']['Code'] == 'NoSuchBucketException' - assert ce.exception.response['Error']['Message'] == 'Cannot find a S3 bucket with an empty bucket name.' + client.put_delivery_channel(DeliveryChannel={"name": "testchannel"}) + assert ce.exception.response["Error"]["Code"] == "NoSuchBucketException" + assert ( + ce.exception.response["Error"]["Message"] + == "Cannot find a S3 bucket with an empty bucket name." + ) with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': ''}) - assert ce.exception.response['Error']['Code'] == 'NoSuchBucketException' - assert ce.exception.response['Error']['Message'] == 'Cannot find a S3 bucket with an empty bucket name.' + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": ""} + ) + assert ce.exception.response["Error"]["Code"] == "NoSuchBucketException" + assert ( + ce.exception.response["Error"]["Message"] + == "Cannot find a S3 bucket with an empty bucket name." + ) # With an empty string for the S3 key prefix: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={ - 'name': 'testchannel', 's3BucketName': 'somebucket', 's3KeyPrefix': ''}) - assert ce.exception.response['Error']['Code'] == 'InvalidS3KeyPrefixException' - assert 'empty s3 key prefix.' in ce.exception.response['Error']['Message'] + client.put_delivery_channel( + DeliveryChannel={ + "name": "testchannel", + "s3BucketName": "somebucket", + "s3KeyPrefix": "", + } + ) + assert ce.exception.response["Error"]["Code"] == "InvalidS3KeyPrefixException" + assert "empty s3 key prefix." in ce.exception.response["Error"]["Message"] # With an empty string for the SNS ARN: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={ - 'name': 'testchannel', 's3BucketName': 'somebucket', 'snsTopicARN': ''}) - assert ce.exception.response['Error']['Code'] == 'InvalidSNSTopicARNException' - assert 'The sns topic arn' in ce.exception.response['Error']['Message'] + client.put_delivery_channel( + DeliveryChannel={ + "name": "testchannel", + "s3BucketName": "somebucket", + "snsTopicARN": "", + } + ) + assert ce.exception.response["Error"]["Code"] == "InvalidSNSTopicARNException" + assert "The sns topic arn" in ce.exception.response["Error"]["Message"] # With an invalid delivery frequency: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={ - 'name': 'testchannel', - 's3BucketName': 'somebucket', - 'configSnapshotDeliveryProperties': {'deliveryFrequency': 'WRONG'} - }) - assert ce.exception.response['Error']['Code'] == 'InvalidDeliveryFrequency' - assert 'WRONG' in ce.exception.response['Error']['Message'] - assert 'TwentyFour_Hours' in ce.exception.response['Error']['Message'] + client.put_delivery_channel( + DeliveryChannel={ + "name": "testchannel", + "s3BucketName": "somebucket", + "configSnapshotDeliveryProperties": {"deliveryFrequency": "WRONG"}, + } + ) + assert ce.exception.response["Error"]["Code"] == "InvalidDeliveryFrequency" + assert "WRONG" in ce.exception.response["Error"]["Message"] + assert "TwentyFour_Hours" in ce.exception.response["Error"]["Message"] # Create a proper one: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': 'somebucket'}) - result = client.describe_delivery_channels()['DeliveryChannels'] + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": "somebucket"} + ) + result = client.describe_delivery_channels()["DeliveryChannels"] assert len(result) == 1 assert len(result[0].keys()) == 2 - assert result[0]['name'] == 'testchannel' - assert result[0]['s3BucketName'] == 'somebucket' + assert result[0]["name"] == "testchannel" + assert result[0]["s3BucketName"] == "somebucket" # Overwrite it with another proper configuration: - client.put_delivery_channel(DeliveryChannel={ - 'name': 'testchannel', - 's3BucketName': 'somebucket', - 'snsTopicARN': 'sometopicarn', - 'configSnapshotDeliveryProperties': {'deliveryFrequency': 'TwentyFour_Hours'} - }) - result = client.describe_delivery_channels()['DeliveryChannels'] + client.put_delivery_channel( + DeliveryChannel={ + "name": "testchannel", + "s3BucketName": "somebucket", + "snsTopicARN": "sometopicarn", + "configSnapshotDeliveryProperties": { + "deliveryFrequency": "TwentyFour_Hours" + }, + } + ) + result = client.describe_delivery_channels()["DeliveryChannels"] assert len(result) == 1 assert len(result[0].keys()) == 4 - assert result[0]['name'] == 'testchannel' - assert result[0]['s3BucketName'] == 'somebucket' - assert result[0]['snsTopicARN'] == 'sometopicarn' - assert result[0]['configSnapshotDeliveryProperties']['deliveryFrequency'] == 'TwentyFour_Hours' + assert result[0]["name"] == "testchannel" + assert result[0]["s3BucketName"] == "somebucket" + assert result[0]["snsTopicARN"] == "sometopicarn" + assert ( + result[0]["configSnapshotDeliveryProperties"]["deliveryFrequency"] + == "TwentyFour_Hours" + ) # Can only have 1: with assert_raises(ClientError) as ce: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel2', 's3BucketName': 'somebucket'}) - assert ce.exception.response['Error']['Code'] == 'MaxNumberOfDeliveryChannelsExceededException' - assert 'because the maximum number of delivery channels: 1 is reached.' in ce.exception.response['Error']['Message'] + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel2", "s3BucketName": "somebucket"} + ) + assert ( + ce.exception.response["Error"]["Code"] + == "MaxNumberOfDeliveryChannelsExceededException" + ) + assert ( + "because the maximum number of delivery channels: 1 is reached." + in ce.exception.response["Error"]["Message"] + ) @mock_config def test_describe_delivery_channels(): - client = boto3.client('config', region_name='us-west-2') - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client = boto3.client("config", region_name="us-west-2") + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Without any channels: result = client.describe_delivery_channels() - assert not result['DeliveryChannels'] + assert not result["DeliveryChannels"] - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': 'somebucket'}) - result = client.describe_delivery_channels()['DeliveryChannels'] + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": "somebucket"} + ) + result = client.describe_delivery_channels()["DeliveryChannels"] assert len(result) == 1 assert len(result[0].keys()) == 2 - assert result[0]['name'] == 'testchannel' - assert result[0]['s3BucketName'] == 'somebucket' + assert result[0]["name"] == "testchannel" + assert result[0]["s3BucketName"] == "somebucket" # Overwrite it with another proper configuration: - client.put_delivery_channel(DeliveryChannel={ - 'name': 'testchannel', - 's3BucketName': 'somebucket', - 'snsTopicARN': 'sometopicarn', - 'configSnapshotDeliveryProperties': {'deliveryFrequency': 'TwentyFour_Hours'} - }) - result = client.describe_delivery_channels()['DeliveryChannels'] + client.put_delivery_channel( + DeliveryChannel={ + "name": "testchannel", + "s3BucketName": "somebucket", + "snsTopicARN": "sometopicarn", + "configSnapshotDeliveryProperties": { + "deliveryFrequency": "TwentyFour_Hours" + }, + } + ) + result = client.describe_delivery_channels()["DeliveryChannels"] assert len(result) == 1 assert len(result[0].keys()) == 4 - assert result[0]['name'] == 'testchannel' - assert result[0]['s3BucketName'] == 'somebucket' - assert result[0]['snsTopicARN'] == 'sometopicarn' - assert result[0]['configSnapshotDeliveryProperties']['deliveryFrequency'] == 'TwentyFour_Hours' + assert result[0]["name"] == "testchannel" + assert result[0]["s3BucketName"] == "somebucket" + assert result[0]["snsTopicARN"] == "sometopicarn" + assert ( + result[0]["configSnapshotDeliveryProperties"]["deliveryFrequency"] + == "TwentyFour_Hours" + ) # Specify an incorrect name: with assert_raises(ClientError) as ce: - client.describe_delivery_channels(DeliveryChannelNames=['wrong']) - assert ce.exception.response['Error']['Code'] == 'NoSuchDeliveryChannelException' - assert 'wrong' in ce.exception.response['Error']['Message'] + client.describe_delivery_channels(DeliveryChannelNames=["wrong"]) + assert ce.exception.response["Error"]["Code"] == "NoSuchDeliveryChannelException" + assert "wrong" in ce.exception.response["Error"]["Message"] # And with both a good and wrong name: with assert_raises(ClientError) as ce: - client.describe_delivery_channels(DeliveryChannelNames=['testchannel', 'wrong']) - assert ce.exception.response['Error']['Code'] == 'NoSuchDeliveryChannelException' - assert 'wrong' in ce.exception.response['Error']['Message'] + client.describe_delivery_channels(DeliveryChannelNames=["testchannel", "wrong"]) + assert ce.exception.response["Error"]["Code"] == "NoSuchDeliveryChannelException" + assert "wrong" in ce.exception.response["Error"]["Message"] @mock_config def test_start_configuration_recorder(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Without a config recorder: with assert_raises(ClientError) as ce: - client.start_configuration_recorder(ConfigurationRecorderName='testrecorder') - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' + client.start_configuration_recorder(ConfigurationRecorderName="testrecorder") + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) # Make the config recorder; - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Without a delivery channel: with assert_raises(ClientError) as ce: - client.start_configuration_recorder(ConfigurationRecorderName='testrecorder') - assert ce.exception.response['Error']['Code'] == 'NoAvailableDeliveryChannelException' + client.start_configuration_recorder(ConfigurationRecorderName="testrecorder") + assert ( + ce.exception.response["Error"]["Code"] == "NoAvailableDeliveryChannelException" + ) # Make the delivery channel: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': 'somebucket'}) + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": "somebucket"} + ) # Start it: - client.start_configuration_recorder(ConfigurationRecorderName='testrecorder') + client.start_configuration_recorder(ConfigurationRecorderName="testrecorder") # Verify it's enabled: - result = client.describe_configuration_recorder_status()['ConfigurationRecordersStatus'] - lower_bound = (datetime.utcnow() - timedelta(minutes=5)) - assert result[0]['recording'] - assert result[0]['lastStatus'] == 'PENDING' - assert lower_bound < result[0]['lastStartTime'].replace(tzinfo=None) <= datetime.utcnow() - assert lower_bound < result[0]['lastStatusChangeTime'].replace(tzinfo=None) <= datetime.utcnow() + result = client.describe_configuration_recorder_status()[ + "ConfigurationRecordersStatus" + ] + lower_bound = datetime.utcnow() - timedelta(minutes=5) + assert result[0]["recording"] + assert result[0]["lastStatus"] == "PENDING" + assert ( + lower_bound + < result[0]["lastStartTime"].replace(tzinfo=None) + <= datetime.utcnow() + ) + assert ( + lower_bound + < result[0]["lastStatusChangeTime"].replace(tzinfo=None) + <= datetime.utcnow() + ) @mock_config def test_stop_configuration_recorder(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Without a config recorder: with assert_raises(ClientError) as ce: - client.stop_configuration_recorder(ConfigurationRecorderName='testrecorder') - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' + client.stop_configuration_recorder(ConfigurationRecorderName="testrecorder") + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) # Make the config recorder; - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Make the delivery channel for creation: - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': 'somebucket'}) + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": "somebucket"} + ) # Start it: - client.start_configuration_recorder(ConfigurationRecorderName='testrecorder') - client.stop_configuration_recorder(ConfigurationRecorderName='testrecorder') + client.start_configuration_recorder(ConfigurationRecorderName="testrecorder") + client.stop_configuration_recorder(ConfigurationRecorderName="testrecorder") # Verify it's disabled: - result = client.describe_configuration_recorder_status()['ConfigurationRecordersStatus'] - lower_bound = (datetime.utcnow() - timedelta(minutes=5)) - assert not result[0]['recording'] - assert result[0]['lastStatus'] == 'PENDING' - assert lower_bound < result[0]['lastStartTime'].replace(tzinfo=None) <= datetime.utcnow() - assert lower_bound < result[0]['lastStopTime'].replace(tzinfo=None) <= datetime.utcnow() - assert lower_bound < result[0]['lastStatusChangeTime'].replace(tzinfo=None) <= datetime.utcnow() + result = client.describe_configuration_recorder_status()[ + "ConfigurationRecordersStatus" + ] + lower_bound = datetime.utcnow() - timedelta(minutes=5) + assert not result[0]["recording"] + assert result[0]["lastStatus"] == "PENDING" + assert ( + lower_bound + < result[0]["lastStartTime"].replace(tzinfo=None) + <= datetime.utcnow() + ) + assert ( + lower_bound + < result[0]["lastStopTime"].replace(tzinfo=None) + <= datetime.utcnow() + ) + assert ( + lower_bound + < result[0]["lastStatusChangeTime"].replace(tzinfo=None) + <= datetime.utcnow() + ) @mock_config def test_describe_configuration_recorder_status(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Without any: result = client.describe_configuration_recorder_status() - assert not result['ConfigurationRecordersStatus'] + assert not result["ConfigurationRecordersStatus"] # Make the config recorder; - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Without specifying a config recorder: - result = client.describe_configuration_recorder_status()['ConfigurationRecordersStatus'] + result = client.describe_configuration_recorder_status()[ + "ConfigurationRecordersStatus" + ] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert not result[0]['recording'] + assert result[0]["name"] == "testrecorder" + assert not result[0]["recording"] # With a proper name: result = client.describe_configuration_recorder_status( - ConfigurationRecorderNames=['testrecorder'])['ConfigurationRecordersStatus'] + ConfigurationRecorderNames=["testrecorder"] + )["ConfigurationRecordersStatus"] assert len(result) == 1 - assert result[0]['name'] == 'testrecorder' - assert not result[0]['recording'] + assert result[0]["name"] == "testrecorder" + assert not result[0]["recording"] # Invalid name: with assert_raises(ClientError) as ce: - client.describe_configuration_recorder_status(ConfigurationRecorderNames=['testrecorder', 'wrong']) - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' - assert 'wrong' in ce.exception.response['Error']['Message'] + client.describe_configuration_recorder_status( + ConfigurationRecorderNames=["testrecorder", "wrong"] + ) + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) + assert "wrong" in ce.exception.response["Error"]["Message"] @mock_config def test_delete_configuration_recorder(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Make the config recorder; - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) + ) # Delete it: - client.delete_configuration_recorder(ConfigurationRecorderName='testrecorder') + client.delete_configuration_recorder(ConfigurationRecorderName="testrecorder") # Try again -- it should be deleted: with assert_raises(ClientError) as ce: - client.delete_configuration_recorder(ConfigurationRecorderName='testrecorder') - assert ce.exception.response['Error']['Code'] == 'NoSuchConfigurationRecorderException' + client.delete_configuration_recorder(ConfigurationRecorderName="testrecorder") + assert ( + ce.exception.response["Error"]["Code"] == "NoSuchConfigurationRecorderException" + ) @mock_config def test_delete_delivery_channel(): - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Need a recorder to test the constraint on recording being enabled: - client.put_configuration_recorder(ConfigurationRecorder={ - 'name': 'testrecorder', - 'roleARN': 'somearn', - 'recordingGroup': { - 'allSupported': False, - 'includeGlobalResourceTypes': False, - 'resourceTypes': ['AWS::EC2::Volume', 'AWS::EC2::VPC'] + client.put_configuration_recorder( + ConfigurationRecorder={ + "name": "testrecorder", + "roleARN": "somearn", + "recordingGroup": { + "allSupported": False, + "includeGlobalResourceTypes": False, + "resourceTypes": ["AWS::EC2::Volume", "AWS::EC2::VPC"], + }, } - }) - client.put_delivery_channel(DeliveryChannel={'name': 'testchannel', 's3BucketName': 'somebucket'}) - client.start_configuration_recorder(ConfigurationRecorderName='testrecorder') + ) + client.put_delivery_channel( + DeliveryChannel={"name": "testchannel", "s3BucketName": "somebucket"} + ) + client.start_configuration_recorder(ConfigurationRecorderName="testrecorder") # With the recorder enabled: with assert_raises(ClientError) as ce: - client.delete_delivery_channel(DeliveryChannelName='testchannel') - assert ce.exception.response['Error']['Code'] == 'LastDeliveryChannelDeleteFailedException' - assert 'because there is a running configuration recorder.' in ce.exception.response['Error']['Message'] + client.delete_delivery_channel(DeliveryChannelName="testchannel") + assert ( + ce.exception.response["Error"]["Code"] + == "LastDeliveryChannelDeleteFailedException" + ) + assert ( + "because there is a running configuration recorder." + in ce.exception.response["Error"]["Message"] + ) # Stop recording: - client.stop_configuration_recorder(ConfigurationRecorderName='testrecorder') + client.stop_configuration_recorder(ConfigurationRecorderName="testrecorder") # Try again: - client.delete_delivery_channel(DeliveryChannelName='testchannel') + client.delete_delivery_channel(DeliveryChannelName="testchannel") # Verify: with assert_raises(ClientError) as ce: - client.delete_delivery_channel(DeliveryChannelName='testchannel') - assert ce.exception.response['Error']['Code'] == 'NoSuchDeliveryChannelException' + client.delete_delivery_channel(DeliveryChannelName="testchannel") + assert ce.exception.response["Error"]["Code"] == "NoSuchDeliveryChannelException" @mock_config @@ -1019,72 +1261,104 @@ def test_list_discovered_resource(): """NOTE: We are only really testing the Config part. For each individual service, please add tests for that individual service's "list_config_service_resources" function. """ - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # With nothing created yet: - assert not client.list_discovered_resources(resourceType='AWS::S3::Bucket')['resourceIdentifiers'] + assert not client.list_discovered_resources(resourceType="AWS::S3::Bucket")[ + "resourceIdentifiers" + ] # Create some S3 buckets: - s3_client = boto3.client('s3', region_name='us-west-2') + s3_client = boto3.client("s3", region_name="us-west-2") for x in range(0, 10): - s3_client.create_bucket(Bucket='bucket{}'.format(x), CreateBucketConfiguration={'LocationConstraint': 'us-west-2'}) + s3_client.create_bucket( + Bucket="bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) # And with an EU bucket -- this should not show up for the us-west-2 config backend: - eu_client = boto3.client('s3', region_name='eu-west-1') - eu_client.create_bucket(Bucket='eu-bucket', CreateBucketConfiguration={'LocationConstraint': 'eu-west-1'}) + eu_client = boto3.client("s3", region_name="eu-west-1") + eu_client.create_bucket( + Bucket="eu-bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-1"}, + ) # Now try: - result = client.list_discovered_resources(resourceType='AWS::S3::Bucket') - assert len(result['resourceIdentifiers']) == 10 + result = client.list_discovered_resources(resourceType="AWS::S3::Bucket") + assert len(result["resourceIdentifiers"]) == 10 for x in range(0, 10): - assert result['resourceIdentifiers'][x] == { - 'resourceType': 'AWS::S3::Bucket', - 'resourceId': 'bucket{}'.format(x), - 'resourceName': 'bucket{}'.format(x) + assert result["resourceIdentifiers"][x] == { + "resourceType": "AWS::S3::Bucket", + "resourceId": "bucket{}".format(x), + "resourceName": "bucket{}".format(x), } - assert not result.get('nextToken') + assert not result.get("nextToken") - result = client.list_discovered_resources(resourceType='AWS::S3::Bucket', resourceName='eu-bucket') - assert not result['resourceIdentifiers'] + result = client.list_discovered_resources( + resourceType="AWS::S3::Bucket", resourceName="eu-bucket" + ) + assert not result["resourceIdentifiers"] # Test that pagination places a proper nextToken in the response and also that the limit works: - result = client.list_discovered_resources(resourceType='AWS::S3::Bucket', limit=1, nextToken='bucket1') - assert len(result['resourceIdentifiers']) == 1 - assert result['nextToken'] == 'bucket2' + result = client.list_discovered_resources( + resourceType="AWS::S3::Bucket", limit=1, nextToken="bucket1" + ) + assert len(result["resourceIdentifiers"]) == 1 + assert result["nextToken"] == "bucket2" # Try with a resource name: - result = client.list_discovered_resources(resourceType='AWS::S3::Bucket', limit=1, resourceName='bucket1') - assert len(result['resourceIdentifiers']) == 1 - assert not result.get('nextToken') + result = client.list_discovered_resources( + resourceType="AWS::S3::Bucket", limit=1, resourceName="bucket1" + ) + assert len(result["resourceIdentifiers"]) == 1 + assert not result.get("nextToken") # Try with a resource ID: - result = client.list_discovered_resources(resourceType='AWS::S3::Bucket', limit=1, resourceIds=['bucket1']) - assert len(result['resourceIdentifiers']) == 1 - assert not result.get('nextToken') + result = client.list_discovered_resources( + resourceType="AWS::S3::Bucket", limit=1, resourceIds=["bucket1"] + ) + assert len(result["resourceIdentifiers"]) == 1 + assert not result.get("nextToken") # Try with duplicated resource IDs: - result = client.list_discovered_resources(resourceType='AWS::S3::Bucket', limit=1, resourceIds=['bucket1', 'bucket1']) - assert len(result['resourceIdentifiers']) == 1 - assert not result.get('nextToken') + result = client.list_discovered_resources( + resourceType="AWS::S3::Bucket", limit=1, resourceIds=["bucket1", "bucket1"] + ) + assert len(result["resourceIdentifiers"]) == 1 + assert not result.get("nextToken") # Test with an invalid resource type: - assert not client.list_discovered_resources(resourceType='LOL::NOT::A::RESOURCE::TYPE')['resourceIdentifiers'] + assert not client.list_discovered_resources( + resourceType="LOL::NOT::A::RESOURCE::TYPE" + )["resourceIdentifiers"] # Test with an invalid page num > 100: with assert_raises(ClientError) as ce: - client.list_discovered_resources(resourceType='AWS::S3::Bucket', limit=101) - assert '101' in ce.exception.response['Error']['Message'] + client.list_discovered_resources(resourceType="AWS::S3::Bucket", limit=101) + assert "101" in ce.exception.response["Error"]["Message"] # Test by supplying both resourceName and also resourceIds: with assert_raises(ClientError) as ce: - client.list_discovered_resources(resourceType='AWS::S3::Bucket', resourceName='whats', resourceIds=['up', 'doc']) - assert 'Both Resource ID and Resource Name cannot be specified in the request' in ce.exception.response['Error']['Message'] + client.list_discovered_resources( + resourceType="AWS::S3::Bucket", + resourceName="whats", + resourceIds=["up", "doc"], + ) + assert ( + "Both Resource ID and Resource Name cannot be specified in the request" + in ce.exception.response["Error"]["Message"] + ) # More than 20 resourceIds: - resource_ids = ['{}'.format(x) for x in range(0, 21)] + resource_ids = ["{}".format(x) for x in range(0, 21)] with assert_raises(ClientError) as ce: - client.list_discovered_resources(resourceType='AWS::S3::Bucket', resourceIds=resource_ids) - assert 'The specified list had more than 20 resource ID\'s.' in ce.exception.response['Error']['Message'] + client.list_discovered_resources( + resourceType="AWS::S3::Bucket", resourceIds=resource_ids + ) + assert ( + "The specified list had more than 20 resource ID's." + in ce.exception.response["Error"]["Message"] + ) @mock_config @@ -1093,105 +1367,142 @@ def test_list_aggregate_discovered_resource(): """NOTE: We are only really testing the Config part. For each individual service, please add tests for that individual service's "list_config_service_resources" function. """ - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # Without an aggregator: with assert_raises(ClientError) as ce: - client.list_aggregate_discovered_resources(ConfigurationAggregatorName='lolno', ResourceType='AWS::S3::Bucket') - assert 'The configuration aggregator does not exist' in ce.exception.response['Error']['Message'] + client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="lolno", ResourceType="AWS::S3::Bucket" + ) + assert ( + "The configuration aggregator does not exist" + in ce.exception.response["Error"]["Message"] + ) # Create the aggregator: account_aggregation_source = { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AllAwsRegions': True + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AllAwsRegions": True, } client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', - AccountAggregationSources=[account_aggregation_source] + ConfigurationAggregatorName="testing", + AccountAggregationSources=[account_aggregation_source], ) # With nothing created yet: - assert not client.list_aggregate_discovered_resources(ConfigurationAggregatorName='testing', - ResourceType='AWS::S3::Bucket')['ResourceIdentifiers'] + assert not client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", ResourceType="AWS::S3::Bucket" + )["ResourceIdentifiers"] # Create some S3 buckets: - s3_client = boto3.client('s3', region_name='us-west-2') + s3_client = boto3.client("s3", region_name="us-west-2") for x in range(0, 10): - s3_client.create_bucket(Bucket='bucket{}'.format(x), CreateBucketConfiguration={'LocationConstraint': 'us-west-2'}) + s3_client.create_bucket( + Bucket="bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) - s3_client_eu = boto3.client('s3', region_name='eu-west-1') + s3_client_eu = boto3.client("s3", region_name="eu-west-1") for x in range(10, 12): - s3_client_eu.create_bucket(Bucket='eu-bucket{}'.format(x), CreateBucketConfiguration={'LocationConstraint': 'eu-west-1'}) + s3_client_eu.create_bucket( + Bucket="eu-bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "eu-west-1"}, + ) # Now try: - result = client.list_aggregate_discovered_resources(ConfigurationAggregatorName='testing', ResourceType='AWS::S3::Bucket') - assert len(result['ResourceIdentifiers']) == 12 + result = client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", ResourceType="AWS::S3::Bucket" + ) + assert len(result["ResourceIdentifiers"]) == 12 for x in range(0, 10): - assert result['ResourceIdentifiers'][x] == { - 'SourceAccountId': '123456789012', - 'ResourceType': 'AWS::S3::Bucket', - 'ResourceId': 'bucket{}'.format(x), - 'ResourceName': 'bucket{}'.format(x), - 'SourceRegion': 'us-west-2' + assert result["ResourceIdentifiers"][x] == { + "SourceAccountId": "123456789012", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "bucket{}".format(x), + "ResourceName": "bucket{}".format(x), + "SourceRegion": "us-west-2", } for x in range(11, 12): - assert result['ResourceIdentifiers'][x] == { - 'SourceAccountId': '123456789012', - 'ResourceType': 'AWS::S3::Bucket', - 'ResourceId': 'eu-bucket{}'.format(x), - 'ResourceName': 'eu-bucket{}'.format(x), - 'SourceRegion': 'eu-west-1' + assert result["ResourceIdentifiers"][x] == { + "SourceAccountId": "123456789012", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "eu-bucket{}".format(x), + "ResourceName": "eu-bucket{}".format(x), + "SourceRegion": "eu-west-1", } - assert not result.get('NextToken') + assert not result.get("NextToken") # Test that pagination places a proper nextToken in the response and also that the limit works: - result = client.list_aggregate_discovered_resources(ConfigurationAggregatorName='testing', ResourceType='AWS::S3::Bucket', - Limit=1, NextToken='bucket1') - assert len(result['ResourceIdentifiers']) == 1 - assert result['NextToken'] == 'bucket2' + result = client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Limit=1, + NextToken="bucket1", + ) + assert len(result["ResourceIdentifiers"]) == 1 + assert result["NextToken"] == "bucket2" # Try with a resource name: - result = client.list_aggregate_discovered_resources(ConfigurationAggregatorName='testing', ResourceType='AWS::S3::Bucket', - Limit=1, NextToken='bucket1', Filters={'ResourceName': 'bucket1'}) - assert len(result['ResourceIdentifiers']) == 1 - assert not result.get('NextToken') + result = client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Limit=1, + NextToken="bucket1", + Filters={"ResourceName": "bucket1"}, + ) + assert len(result["ResourceIdentifiers"]) == 1 + assert not result.get("NextToken") # Try with a resource ID: - result = client.list_aggregate_discovered_resources(ConfigurationAggregatorName='testing', ResourceType='AWS::S3::Bucket', - Limit=1, NextToken='bucket1', Filters={'ResourceId': 'bucket1'}) - assert len(result['ResourceIdentifiers']) == 1 - assert not result.get('NextToken') + result = client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Limit=1, + NextToken="bucket1", + Filters={"ResourceId": "bucket1"}, + ) + assert len(result["ResourceIdentifiers"]) == 1 + assert not result.get("NextToken") # Try with a region specified: - result = client.list_aggregate_discovered_resources(ConfigurationAggregatorName='testing', ResourceType='AWS::S3::Bucket', - Filters={'Region': 'eu-west-1'}) - assert len(result['ResourceIdentifiers']) == 2 - assert result['ResourceIdentifiers'][0]['SourceRegion'] == 'eu-west-1' - assert not result.get('NextToken') + result = client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Filters={"Region": "eu-west-1"}, + ) + assert len(result["ResourceIdentifiers"]) == 2 + assert result["ResourceIdentifiers"][0]["SourceRegion"] == "eu-west-1" + assert not result.get("NextToken") # Try with both name and id set to the incorrect values: - assert not client.list_aggregate_discovered_resources(ConfigurationAggregatorName='testing', ResourceType='AWS::S3::Bucket', - Filters={'ResourceId': 'bucket1', - 'ResourceName': 'bucket2'})['ResourceIdentifiers'] + assert not client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Filters={"ResourceId": "bucket1", "ResourceName": "bucket2"}, + )["ResourceIdentifiers"] # Test with an invalid resource type: - assert not client.list_aggregate_discovered_resources(ConfigurationAggregatorName='testing', - ResourceType='LOL::NOT::A::RESOURCE::TYPE')['ResourceIdentifiers'] + assert not client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="LOL::NOT::A::RESOURCE::TYPE", + )["ResourceIdentifiers"] # Try with correct name but incorrect region: - assert not client.list_aggregate_discovered_resources(ConfigurationAggregatorName='testing', ResourceType='AWS::S3::Bucket', - Filters={'ResourceId': 'bucket1', - 'Region': 'us-west-1'})['ResourceIdentifiers'] + assert not client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Filters={"ResourceId": "bucket1", "Region": "us-west-1"}, + )["ResourceIdentifiers"] # Test with an invalid page num > 100: with assert_raises(ClientError) as ce: - client.list_aggregate_discovered_resources(ConfigurationAggregatorName='testing', ResourceType='AWS::S3::Bucket', Limit=101) - assert '101' in ce.exception.response['Error']['Message'] + client.list_aggregate_discovered_resources( + ConfigurationAggregatorName="testing", + ResourceType="AWS::S3::Bucket", + Limit=101, + ) + assert "101" in ce.exception.response["Error"]["Message"] @mock_config @@ -1200,37 +1511,57 @@ def test_get_resource_config_history(): """NOTE: We are only really testing the Config part. For each individual service, please add tests for that individual service's "get_config_resource" function. """ - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # With an invalid resource type: with assert_raises(ClientError) as ce: - client.get_resource_config_history(resourceType='NOT::A::RESOURCE', resourceId='notcreatedyet') - assert ce.exception.response['Error'] == {'Message': 'Resource notcreatedyet of resourceType:NOT::A::RESOURCE is unknown or has ' - 'not been discovered', 'Code': 'ResourceNotDiscoveredException'} + client.get_resource_config_history( + resourceType="NOT::A::RESOURCE", resourceId="notcreatedyet" + ) + assert ce.exception.response["Error"] == { + "Message": "Resource notcreatedyet of resourceType:NOT::A::RESOURCE is unknown or has " + "not been discovered", + "Code": "ResourceNotDiscoveredException", + } # With nothing created yet: with assert_raises(ClientError) as ce: - client.get_resource_config_history(resourceType='AWS::S3::Bucket', resourceId='notcreatedyet') - assert ce.exception.response['Error'] == {'Message': 'Resource notcreatedyet of resourceType:AWS::S3::Bucket is unknown or has ' - 'not been discovered', 'Code': 'ResourceNotDiscoveredException'} + client.get_resource_config_history( + resourceType="AWS::S3::Bucket", resourceId="notcreatedyet" + ) + assert ce.exception.response["Error"] == { + "Message": "Resource notcreatedyet of resourceType:AWS::S3::Bucket is unknown or has " + "not been discovered", + "Code": "ResourceNotDiscoveredException", + } # Create an S3 bucket: - s3_client = boto3.client('s3', region_name='us-west-2') + s3_client = boto3.client("s3", region_name="us-west-2") for x in range(0, 10): - s3_client.create_bucket(Bucket='bucket{}'.format(x), CreateBucketConfiguration={'LocationConstraint': 'us-west-2'}) + s3_client.create_bucket( + Bucket="bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) # Now try: - result = client.get_resource_config_history(resourceType='AWS::S3::Bucket', resourceId='bucket1')['configurationItems'] + result = client.get_resource_config_history( + resourceType="AWS::S3::Bucket", resourceId="bucket1" + )["configurationItems"] assert len(result) == 1 - assert result[0]['resourceName'] == result[0]['resourceId'] == 'bucket1' - assert result[0]['arn'] == 'arn:aws:s3:::bucket1' + assert result[0]["resourceName"] == result[0]["resourceId"] == "bucket1" + assert result[0]["arn"] == "arn:aws:s3:::bucket1" # Make a bucket in a different region and verify that it does not show up in the config backend: - s3_client = boto3.client('s3', region_name='eu-west-1') - s3_client.create_bucket(Bucket='eu-bucket', CreateBucketConfiguration={'LocationConstraint': 'eu-west-1'}) + s3_client = boto3.client("s3", region_name="eu-west-1") + s3_client.create_bucket( + Bucket="eu-bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-1"}, + ) with assert_raises(ClientError) as ce: - client.get_resource_config_history(resourceType='AWS::S3::Bucket', resourceId='eu-bucket') - assert ce.exception.response['Error']['Code'] == 'ResourceNotDiscoveredException' + client.get_resource_config_history( + resourceType="AWS::S3::Bucket", resourceId="eu-bucket" + ) + assert ce.exception.response["Error"]["Code"] == "ResourceNotDiscoveredException" @mock_config @@ -1239,42 +1570,62 @@ def test_batch_get_resource_config(): """NOTE: We are only really testing the Config part. For each individual service, please add tests for that individual service's "get_config_resource" function. """ - client = boto3.client('config', region_name='us-west-2') + client = boto3.client("config", region_name="us-west-2") # With more than 100 resourceKeys: with assert_raises(ClientError) as ce: - client.batch_get_resource_config(resourceKeys=[{'resourceType': 'AWS::S3::Bucket', 'resourceId': 'someBucket'}] * 101) - assert 'Member must have length less than or equal to 100' in ce.exception.response['Error']['Message'] + client.batch_get_resource_config( + resourceKeys=[ + {"resourceType": "AWS::S3::Bucket", "resourceId": "someBucket"} + ] + * 101 + ) + assert ( + "Member must have length less than or equal to 100" + in ce.exception.response["Error"]["Message"] + ) # With invalid resource types and resources that don't exist: - result = client.batch_get_resource_config(resourceKeys=[ - {'resourceType': 'NOT::A::RESOURCE', 'resourceId': 'NotAThing'}, {'resourceType': 'AWS::S3::Bucket', 'resourceId': 'NotAThing'}, - ]) + result = client.batch_get_resource_config( + resourceKeys=[ + {"resourceType": "NOT::A::RESOURCE", "resourceId": "NotAThing"}, + {"resourceType": "AWS::S3::Bucket", "resourceId": "NotAThing"}, + ] + ) - assert not result['baseConfigurationItems'] - assert not result['unprocessedResourceKeys'] + assert not result["baseConfigurationItems"] + assert not result["unprocessedResourceKeys"] # Create some S3 buckets: - s3_client = boto3.client('s3', region_name='us-west-2') + s3_client = boto3.client("s3", region_name="us-west-2") for x in range(0, 10): - s3_client.create_bucket(Bucket='bucket{}'.format(x), CreateBucketConfiguration={'LocationConstraint': 'us-west-2'}) + s3_client.create_bucket( + Bucket="bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) # Get them all: - keys = [{'resourceType': 'AWS::S3::Bucket', 'resourceId': 'bucket{}'.format(x)} for x in range(0, 10)] + keys = [ + {"resourceType": "AWS::S3::Bucket", "resourceId": "bucket{}".format(x)} + for x in range(0, 10) + ] result = client.batch_get_resource_config(resourceKeys=keys) - assert len(result['baseConfigurationItems']) == 10 - buckets_missing = ['bucket{}'.format(x) for x in range(0, 10)] - for r in result['baseConfigurationItems']: - buckets_missing.remove(r['resourceName']) + assert len(result["baseConfigurationItems"]) == 10 + buckets_missing = ["bucket{}".format(x) for x in range(0, 10)] + for r in result["baseConfigurationItems"]: + buckets_missing.remove(r["resourceName"]) assert not buckets_missing # Make a bucket in a different region and verify that it does not show up in the config backend: - s3_client = boto3.client('s3', region_name='eu-west-1') - s3_client.create_bucket(Bucket='eu-bucket', CreateBucketConfiguration={'LocationConstraint': 'eu-west-1'}) - keys = [{'resourceType': 'AWS::S3::Bucket', 'resourceId': 'eu-bucket'}] + s3_client = boto3.client("s3", region_name="eu-west-1") + s3_client.create_bucket( + Bucket="eu-bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-1"}, + ) + keys = [{"resourceType": "AWS::S3::Bucket", "resourceId": "eu-bucket"}] result = client.batch_get_resource_config(resourceKeys=keys) - assert not result['baseConfigurationItems'] + assert not result["baseConfigurationItems"] @mock_config @@ -1284,88 +1635,167 @@ def test_batch_get_aggregate_resource_config(): for that individual service's "get_config_resource" function. """ from moto.config.models import DEFAULT_ACCOUNT_ID - client = boto3.client('config', region_name='us-west-2') + + client = boto3.client("config", region_name="us-west-2") # Without an aggregator: - bad_ri = {'SourceAccountId': '000000000000', 'SourceRegion': 'not-a-region', 'ResourceType': 'NOT::A::RESOURCE', 'ResourceId': 'nope'} + bad_ri = { + "SourceAccountId": "000000000000", + "SourceRegion": "not-a-region", + "ResourceType": "NOT::A::RESOURCE", + "ResourceId": "nope", + } with assert_raises(ClientError) as ce: - client.batch_get_aggregate_resource_config(ConfigurationAggregatorName='lolno', ResourceIdentifiers=[bad_ri]) - assert 'The configuration aggregator does not exist' in ce.exception.response['Error']['Message'] + client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="lolno", ResourceIdentifiers=[bad_ri] + ) + assert ( + "The configuration aggregator does not exist" + in ce.exception.response["Error"]["Message"] + ) # Create the aggregator: account_aggregation_source = { - 'AccountIds': [ - '012345678910', - '111111111111', - '222222222222' - ], - 'AllAwsRegions': True + "AccountIds": ["012345678910", "111111111111", "222222222222"], + "AllAwsRegions": True, } client.put_configuration_aggregator( - ConfigurationAggregatorName='testing', - AccountAggregationSources=[account_aggregation_source] + ConfigurationAggregatorName="testing", + AccountAggregationSources=[account_aggregation_source], ) # With more than 100 items: with assert_raises(ClientError) as ce: - client.batch_get_aggregate_resource_config(ConfigurationAggregatorName='testing', ResourceIdentifiers=[bad_ri] * 101) - assert 'Member must have length less than or equal to 100' in ce.exception.response['Error']['Message'] + client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="testing", ResourceIdentifiers=[bad_ri] * 101 + ) + assert ( + "Member must have length less than or equal to 100" + in ce.exception.response["Error"]["Message"] + ) # Create some S3 buckets: - s3_client = boto3.client('s3', region_name='us-west-2') + s3_client = boto3.client("s3", region_name="us-west-2") for x in range(0, 10): - s3_client.create_bucket(Bucket='bucket{}'.format(x), CreateBucketConfiguration={'LocationConstraint': 'us-west-2'}) - s3_client.put_bucket_tagging(Bucket='bucket{}'.format(x), Tagging={'TagSet': [{'Key': 'Some', 'Value': 'Tag'}]}) + s3_client.create_bucket( + Bucket="bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) + s3_client.put_bucket_tagging( + Bucket="bucket{}".format(x), + Tagging={"TagSet": [{"Key": "Some", "Value": "Tag"}]}, + ) - s3_client_eu = boto3.client('s3', region_name='eu-west-1') + s3_client_eu = boto3.client("s3", region_name="eu-west-1") for x in range(10, 12): - s3_client_eu.create_bucket(Bucket='eu-bucket{}'.format(x), CreateBucketConfiguration={'LocationConstraint': 'eu-west-1'}) - s3_client.put_bucket_tagging(Bucket='eu-bucket{}'.format(x), Tagging={'TagSet': [{'Key': 'Some', 'Value': 'Tag'}]}) + s3_client_eu.create_bucket( + Bucket="eu-bucket{}".format(x), + CreateBucketConfiguration={"LocationConstraint": "eu-west-1"}, + ) + s3_client.put_bucket_tagging( + Bucket="eu-bucket{}".format(x), + Tagging={"TagSet": [{"Key": "Some", "Value": "Tag"}]}, + ) # Now try with resources that exist and ones that don't: - identifiers = [{'SourceAccountId': DEFAULT_ACCOUNT_ID, 'SourceRegion': 'us-west-2', 'ResourceType': 'AWS::S3::Bucket', - 'ResourceId': 'bucket{}'.format(x)} for x in range(0, 10)] - identifiers += [{'SourceAccountId': DEFAULT_ACCOUNT_ID, 'SourceRegion': 'eu-west-1', 'ResourceType': 'AWS::S3::Bucket', - 'ResourceId': 'eu-bucket{}'.format(x)} for x in range(10, 12)] + identifiers = [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": "us-west-2", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "bucket{}".format(x), + } + for x in range(0, 10) + ] + identifiers += [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": "eu-west-1", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "eu-bucket{}".format(x), + } + for x in range(10, 12) + ] identifiers += [bad_ri] - result = client.batch_get_aggregate_resource_config(ConfigurationAggregatorName='testing', ResourceIdentifiers=identifiers) - assert len(result['UnprocessedResourceIdentifiers']) == 1 - assert result['UnprocessedResourceIdentifiers'][0] == bad_ri + result = client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="testing", ResourceIdentifiers=identifiers + ) + assert len(result["UnprocessedResourceIdentifiers"]) == 1 + assert result["UnprocessedResourceIdentifiers"][0] == bad_ri # Verify all the buckets are there: - assert len(result['BaseConfigurationItems']) == 12 - missing_buckets = ['bucket{}'.format(x) for x in range(0, 10)] + ['eu-bucket{}'.format(x) for x in range(10, 12)] + assert len(result["BaseConfigurationItems"]) == 12 + missing_buckets = ["bucket{}".format(x) for x in range(0, 10)] + [ + "eu-bucket{}".format(x) for x in range(10, 12) + ] - for r in result['BaseConfigurationItems']: - missing_buckets.remove(r['resourceName']) + for r in result["BaseConfigurationItems"]: + missing_buckets.remove(r["resourceName"]) assert not missing_buckets # Verify that 'tags' is not in the result set: - for b in result['BaseConfigurationItems']: - assert not b.get('tags') - assert json.loads(b['supplementaryConfiguration']['BucketTaggingConfiguration']) == {'tagSets': [{'tags': {'Some': 'Tag'}}]} + for b in result["BaseConfigurationItems"]: + assert not b.get("tags") + assert json.loads( + b["supplementaryConfiguration"]["BucketTaggingConfiguration"] + ) == {"tagSets": [{"tags": {"Some": "Tag"}}]} # Verify that if the resource name and ID are correct that things are good: - identifiers = [{'SourceAccountId': DEFAULT_ACCOUNT_ID, 'SourceRegion': 'us-west-2', 'ResourceType': 'AWS::S3::Bucket', - 'ResourceId': 'bucket1', 'ResourceName': 'bucket1'}] - result = client.batch_get_aggregate_resource_config(ConfigurationAggregatorName='testing', ResourceIdentifiers=identifiers) - assert not result['UnprocessedResourceIdentifiers'] - assert len(result['BaseConfigurationItems']) == 1 and result['BaseConfigurationItems'][0]['resourceName'] == 'bucket1' + identifiers = [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": "us-west-2", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "bucket1", + "ResourceName": "bucket1", + } + ] + result = client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="testing", ResourceIdentifiers=identifiers + ) + assert not result["UnprocessedResourceIdentifiers"] + assert ( + len(result["BaseConfigurationItems"]) == 1 + and result["BaseConfigurationItems"][0]["resourceName"] == "bucket1" + ) # Verify that if the resource name and ID mismatch that we don't get a result: - identifiers = [{'SourceAccountId': DEFAULT_ACCOUNT_ID, 'SourceRegion': 'us-west-2', 'ResourceType': 'AWS::S3::Bucket', - 'ResourceId': 'bucket1', 'ResourceName': 'bucket2'}] - result = client.batch_get_aggregate_resource_config(ConfigurationAggregatorName='testing', ResourceIdentifiers=identifiers) - assert not result['BaseConfigurationItems'] - assert len(result['UnprocessedResourceIdentifiers']) == 1 - assert len(result['UnprocessedResourceIdentifiers']) == 1 and result['UnprocessedResourceIdentifiers'][0]['ResourceName'] == 'bucket2' + identifiers = [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": "us-west-2", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "bucket1", + "ResourceName": "bucket2", + } + ] + result = client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="testing", ResourceIdentifiers=identifiers + ) + assert not result["BaseConfigurationItems"] + assert len(result["UnprocessedResourceIdentifiers"]) == 1 + assert ( + len(result["UnprocessedResourceIdentifiers"]) == 1 + and result["UnprocessedResourceIdentifiers"][0]["ResourceName"] == "bucket2" + ) # Verify that if the region is incorrect that we don't get a result: - identifiers = [{'SourceAccountId': DEFAULT_ACCOUNT_ID, 'SourceRegion': 'eu-west-1', 'ResourceType': 'AWS::S3::Bucket', - 'ResourceId': 'bucket1'}] - result = client.batch_get_aggregate_resource_config(ConfigurationAggregatorName='testing', ResourceIdentifiers=identifiers) - assert not result['BaseConfigurationItems'] - assert len(result['UnprocessedResourceIdentifiers']) == 1 - assert len(result['UnprocessedResourceIdentifiers']) == 1 and result['UnprocessedResourceIdentifiers'][0]['SourceRegion'] == 'eu-west-1' + identifiers = [ + { + "SourceAccountId": DEFAULT_ACCOUNT_ID, + "SourceRegion": "eu-west-1", + "ResourceType": "AWS::S3::Bucket", + "ResourceId": "bucket1", + } + ] + result = client.batch_get_aggregate_resource_config( + ConfigurationAggregatorName="testing", ResourceIdentifiers=identifiers + ) + assert not result["BaseConfigurationItems"] + assert len(result["UnprocessedResourceIdentifiers"]) == 1 + assert ( + len(result["UnprocessedResourceIdentifiers"]) == 1 + and result["UnprocessedResourceIdentifiers"][0]["SourceRegion"] == "eu-west-1" + ) diff --git a/tests/test_core/test_auth.py b/tests/test_core/test_auth.py index 00229f808..7dc632188 100644 --- a/tests/test_core/test_auth.py +++ b/tests/test_core/test_auth.py @@ -3,6 +3,7 @@ import json import boto3 import sure # noqa from botocore.exceptions import ClientError + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -13,191 +14,251 @@ from moto.iam.models import ACCOUNT_ID @mock_iam -def create_user_with_access_key(user_name='test-user'): - client = boto3.client('iam', region_name='us-east-1') +def create_user_with_access_key(user_name="test-user"): + client = boto3.client("iam", region_name="us-east-1") client.create_user(UserName=user_name) - return client.create_access_key(UserName=user_name)['AccessKey'] + return client.create_access_key(UserName=user_name)["AccessKey"] @mock_iam -def create_user_with_access_key_and_inline_policy(user_name, policy_document, policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_user_with_access_key_and_inline_policy( + user_name, policy_document, policy_name="policy1" +): + client = boto3.client("iam", region_name="us-east-1") client.create_user(UserName=user_name) - client.put_user_policy(UserName=user_name, PolicyName=policy_name, PolicyDocument=json.dumps(policy_document)) - return client.create_access_key(UserName=user_name)['AccessKey'] + client.put_user_policy( + UserName=user_name, + PolicyName=policy_name, + PolicyDocument=json.dumps(policy_document), + ) + return client.create_access_key(UserName=user_name)["AccessKey"] @mock_iam -def create_user_with_access_key_and_attached_policy(user_name, policy_document, policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_user_with_access_key_and_attached_policy( + user_name, policy_document, policy_name="policy1" +): + client = boto3.client("iam", region_name="us-east-1") client.create_user(UserName=user_name) policy_arn = client.create_policy( - PolicyName=policy_name, - PolicyDocument=json.dumps(policy_document) - )['Policy']['Arn'] + PolicyName=policy_name, PolicyDocument=json.dumps(policy_document) + )["Policy"]["Arn"] client.attach_user_policy(UserName=user_name, PolicyArn=policy_arn) - return client.create_access_key(UserName=user_name)['AccessKey'] + return client.create_access_key(UserName=user_name)["AccessKey"] @mock_iam -def create_user_with_access_key_and_multiple_policies(user_name, inline_policy_document, - attached_policy_document, inline_policy_name='policy1', attached_policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_user_with_access_key_and_multiple_policies( + user_name, + inline_policy_document, + attached_policy_document, + inline_policy_name="policy1", + attached_policy_name="policy1", +): + client = boto3.client("iam", region_name="us-east-1") client.create_user(UserName=user_name) policy_arn = client.create_policy( PolicyName=attached_policy_name, - PolicyDocument=json.dumps(attached_policy_document) - )['Policy']['Arn'] + PolicyDocument=json.dumps(attached_policy_document), + )["Policy"]["Arn"] client.attach_user_policy(UserName=user_name, PolicyArn=policy_arn) - client.put_user_policy(UserName=user_name, PolicyName=inline_policy_name, PolicyDocument=json.dumps(inline_policy_document)) - return client.create_access_key(UserName=user_name)['AccessKey'] + client.put_user_policy( + UserName=user_name, + PolicyName=inline_policy_name, + PolicyDocument=json.dumps(inline_policy_document), + ) + return client.create_access_key(UserName=user_name)["AccessKey"] -def create_group_with_attached_policy_and_add_user(user_name, policy_document, - group_name='test-group', policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_group_with_attached_policy_and_add_user( + user_name, policy_document, group_name="test-group", policy_name="policy1" +): + client = boto3.client("iam", region_name="us-east-1") client.create_group(GroupName=group_name) policy_arn = client.create_policy( - PolicyName=policy_name, - PolicyDocument=json.dumps(policy_document) - )['Policy']['Arn'] + PolicyName=policy_name, PolicyDocument=json.dumps(policy_document) + )["Policy"]["Arn"] client.attach_group_policy(GroupName=group_name, PolicyArn=policy_arn) client.add_user_to_group(GroupName=group_name, UserName=user_name) -def create_group_with_inline_policy_and_add_user(user_name, policy_document, - group_name='test-group', policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_group_with_inline_policy_and_add_user( + user_name, policy_document, group_name="test-group", policy_name="policy1" +): + client = boto3.client("iam", region_name="us-east-1") client.create_group(GroupName=group_name) client.put_group_policy( GroupName=group_name, PolicyName=policy_name, - PolicyDocument=json.dumps(policy_document) + PolicyDocument=json.dumps(policy_document), ) client.add_user_to_group(GroupName=group_name, UserName=user_name) -def create_group_with_multiple_policies_and_add_user(user_name, inline_policy_document, - attached_policy_document, group_name='test-group', - inline_policy_name='policy1', attached_policy_name='policy1'): - client = boto3.client('iam', region_name='us-east-1') +def create_group_with_multiple_policies_and_add_user( + user_name, + inline_policy_document, + attached_policy_document, + group_name="test-group", + inline_policy_name="policy1", + attached_policy_name="policy1", +): + client = boto3.client("iam", region_name="us-east-1") client.create_group(GroupName=group_name) client.put_group_policy( GroupName=group_name, PolicyName=inline_policy_name, - PolicyDocument=json.dumps(inline_policy_document) + PolicyDocument=json.dumps(inline_policy_document), ) policy_arn = client.create_policy( PolicyName=attached_policy_name, - PolicyDocument=json.dumps(attached_policy_document) - )['Policy']['Arn'] + PolicyDocument=json.dumps(attached_policy_document), + )["Policy"]["Arn"] client.attach_group_policy(GroupName=group_name, PolicyArn=policy_arn) client.add_user_to_group(GroupName=group_name, UserName=user_name) @mock_iam @mock_sts -def create_role_with_attached_policy_and_assume_it(role_name, trust_policy_document, - policy_document, session_name='session1', policy_name='policy1'): - iam_client = boto3.client('iam', region_name='us-east-1') - sts_client = boto3.client('sts', region_name='us-east-1') +def create_role_with_attached_policy_and_assume_it( + role_name, + trust_policy_document, + policy_document, + session_name="session1", + policy_name="policy1", +): + iam_client = boto3.client("iam", region_name="us-east-1") + sts_client = boto3.client("sts", region_name="us-east-1") role_arn = iam_client.create_role( - RoleName=role_name, - AssumeRolePolicyDocument=json.dumps(trust_policy_document) - )['Role']['Arn'] + RoleName=role_name, AssumeRolePolicyDocument=json.dumps(trust_policy_document) + )["Role"]["Arn"] policy_arn = iam_client.create_policy( - PolicyName=policy_name, - PolicyDocument=json.dumps(policy_document) - )['Policy']['Arn'] + PolicyName=policy_name, PolicyDocument=json.dumps(policy_document) + )["Policy"]["Arn"] iam_client.attach_role_policy(RoleName=role_name, PolicyArn=policy_arn) - return sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)['Credentials'] + return sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)[ + "Credentials" + ] @mock_iam @mock_sts -def create_role_with_inline_policy_and_assume_it(role_name, trust_policy_document, - policy_document, session_name='session1', policy_name='policy1'): - iam_client = boto3.client('iam', region_name='us-east-1') - sts_client = boto3.client('sts', region_name='us-east-1') +def create_role_with_inline_policy_and_assume_it( + role_name, + trust_policy_document, + policy_document, + session_name="session1", + policy_name="policy1", +): + iam_client = boto3.client("iam", region_name="us-east-1") + sts_client = boto3.client("sts", region_name="us-east-1") role_arn = iam_client.create_role( - RoleName=role_name, - AssumeRolePolicyDocument=json.dumps(trust_policy_document) - )['Role']['Arn'] + RoleName=role_name, AssumeRolePolicyDocument=json.dumps(trust_policy_document) + )["Role"]["Arn"] iam_client.put_role_policy( RoleName=role_name, PolicyName=policy_name, - PolicyDocument=json.dumps(policy_document) + PolicyDocument=json.dumps(policy_document), ) - return sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)['Credentials'] + return sts_client.assume_role(RoleArn=role_arn, RoleSessionName=session_name)[ + "Credentials" + ] @set_initial_no_auth_action_count(0) @mock_iam def test_invalid_client_token_id(): - client = boto3.client('iam', region_name='us-east-1', aws_access_key_id='invalid', aws_secret_access_key='invalid') + client = boto3.client( + "iam", + region_name="us-east-1", + aws_access_key_id="invalid", + aws_secret_access_key="invalid", + ) with assert_raises(ClientError) as ex: client.get_user() - ex.exception.response['Error']['Code'].should.equal('InvalidClientTokenId') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('The security token included in the request is invalid.') + ex.exception.response["Error"]["Code"].should.equal("InvalidClientTokenId") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "The security token included in the request is invalid." + ) @set_initial_no_auth_action_count(0) @mock_ec2 def test_auth_failure(): - client = boto3.client('ec2', region_name='us-east-1', aws_access_key_id='invalid', aws_secret_access_key='invalid') + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id="invalid", + aws_secret_access_key="invalid", + ) with assert_raises(ClientError) as ex: client.describe_instances() - ex.exception.response['Error']['Code'].should.equal('AuthFailure') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(401) - ex.exception.response['Error']['Message'].should.equal('AWS was not able to validate the provided access credentials') + ex.exception.response["Error"]["Code"].should.equal("AuthFailure") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(401) + ex.exception.response["Error"]["Message"].should.equal( + "AWS was not able to validate the provided access credentials" + ) @set_initial_no_auth_action_count(2) @mock_iam def test_signature_does_not_match(): access_key = create_user_with_access_key() - client = boto3.client('iam', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key='invalid') + client = boto3.client( + "iam", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key="invalid", + ) with assert_raises(ClientError) as ex: client.get_user() - ex.exception.response['Error']['Code'].should.equal('SignatureDoesNotMatch') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details.') + ex.exception.response["Error"]["Code"].should.equal("SignatureDoesNotMatch") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "The request signature we calculated does not match the signature you provided. Check your AWS Secret Access Key and signing method. Consult the service documentation for details." + ) @set_initial_no_auth_action_count(2) @mock_ec2 def test_auth_failure_with_valid_access_key_id(): access_key = create_user_with_access_key() - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key='invalid') + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key="invalid", + ) with assert_raises(ClientError) as ex: client.describe_instances() - ex.exception.response['Error']['Code'].should.equal('AuthFailure') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(401) - ex.exception.response['Error']['Message'].should.equal('AWS was not able to validate the provided access credentials') + ex.exception.response["Error"]["Code"].should.equal("AuthFailure") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(401) + ex.exception.response["Error"]["Message"].should.equal( + "AWS was not able to validate the provided access credentials" + ) @set_initial_no_auth_action_count(2) @mock_ec2 def test_access_denied_with_no_policy(): - user_name = 'test-user' + user_name = "test-user" access_key = create_user_with_access_key(user_name) - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) with assert_raises(ClientError) as ex: client.describe_instances() - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal( - 'User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}'.format( + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}".format( account_id=ACCOUNT_ID, user_name=user_name, - operation="ec2:DescribeInstances" + operation="ec2:DescribeInstances", ) ) @@ -205,32 +266,29 @@ def test_access_denied_with_no_policy(): @set_initial_no_auth_action_count(3) @mock_ec2 def test_access_denied_with_not_allowing_policy(): - user_name = 'test-user' + user_name = "test-user" inline_policy_document = { "Version": "2012-10-17", "Statement": [ - { - "Effect": "Allow", - "Action": [ - "ec2:Describe*" - ], - "Resource": "*" - } - ] + {"Effect": "Allow", "Action": ["ec2:Describe*"], "Resource": "*"} + ], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) with assert_raises(ClientError) as ex: client.run_instances(MaxCount=1, MinCount=1) - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal( - 'User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}'.format( - account_id=ACCOUNT_ID, - user_name=user_name, - operation="ec2:RunInstances" + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}".format( + account_id=ACCOUNT_ID, user_name=user_name, operation="ec2:RunInstances" ) ) @@ -238,37 +296,30 @@ def test_access_denied_with_not_allowing_policy(): @set_initial_no_auth_action_count(3) @mock_ec2 def test_access_denied_with_denying_policy(): - user_name = 'test-user' + user_name = "test-user" inline_policy_document = { "Version": "2012-10-17", "Statement": [ - { - "Effect": "Allow", - "Action": [ - "ec2:*", - ], - "Resource": "*" - }, - { - "Effect": "Deny", - "Action": "ec2:CreateVpc", - "Resource": "*" - } - ] + {"Effect": "Allow", "Action": ["ec2:*"], "Resource": "*"}, + {"Effect": "Deny", "Action": "ec2:CreateVpc", "Resource": "*"}, + ], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) with assert_raises(ClientError) as ex: client.create_vpc(CidrBlock="10.0.0.0/16") - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal( - 'User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}'.format( - account_id=ACCOUNT_ID, - user_name=user_name, - operation="ec2:CreateVpc" + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}".format( + account_id=ACCOUNT_ID, user_name=user_name, operation="ec2:CreateVpc" ) ) @@ -276,203 +327,173 @@ def test_access_denied_with_denying_policy(): @set_initial_no_auth_action_count(3) @mock_sts def test_get_caller_identity_allowed_with_denying_policy(): - user_name = 'test-user' + user_name = "test-user" inline_policy_document = { "Version": "2012-10-17", "Statement": [ - { - "Effect": "Deny", - "Action": "sts:GetCallerIdentity", - "Resource": "*" - } - ] + {"Effect": "Deny", "Action": "sts:GetCallerIdentity", "Resource": "*"} + ], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - client = boto3.client('sts', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + client = boto3.client( + "sts", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) client.get_caller_identity().should.be.a(dict) @set_initial_no_auth_action_count(3) @mock_ec2 def test_allowed_with_wildcard_action(): - user_name = 'test-user' + user_name = "test-user" inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "ec2:Describe*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "ec2:Describe*", "Resource": "*"}], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) - client.describe_tags()['Tags'].should.be.empty + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) + client.describe_tags()["Tags"].should.be.empty @set_initial_no_auth_action_count(4) @mock_iam def test_allowed_with_explicit_action_in_attached_policy(): - user_name = 'test-user' + user_name = "test-user" attached_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "iam:ListGroups", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "iam:ListGroups", "Resource": "*"}], } - access_key = create_user_with_access_key_and_attached_policy(user_name, attached_policy_document) - client = boto3.client('iam', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) - client.list_groups()['Groups'].should.be.empty + access_key = create_user_with_access_key_and_attached_policy( + user_name, attached_policy_document + ) + client = boto3.client( + "iam", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) + client.list_groups()["Groups"].should.be.empty @set_initial_no_auth_action_count(8) @mock_s3 @mock_iam def test_s3_access_denied_with_denying_attached_group_policy(): - user_name = 'test-user' + user_name = "test-user" attached_policy_document = { "Version": "2012-10-17", "Statement": [ - { - "Effect": "Allow", - "Action": "s3:ListAllMyBuckets", - "Resource": "*" - } - ] + {"Effect": "Allow", "Action": "s3:ListAllMyBuckets", "Resource": "*"} + ], } group_attached_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Deny", - "Action": "s3:List*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Deny", "Action": "s3:List*", "Resource": "*"}], } - access_key = create_user_with_access_key_and_attached_policy(user_name, attached_policy_document) - create_group_with_attached_policy_and_add_user(user_name, group_attached_policy_document) - client = boto3.client('s3', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_attached_policy( + user_name, attached_policy_document + ) + create_group_with_attached_policy_and_add_user( + user_name, group_attached_policy_document + ) + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) with assert_raises(ClientError) as ex: client.list_buckets() - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('Access Denied') + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal("Access Denied") @set_initial_no_auth_action_count(6) @mock_s3 @mock_iam def test_s3_access_denied_with_denying_inline_group_policy(): - user_name = 'test-user' - bucket_name = 'test-bucket' + user_name = "test-user" + bucket_name = "test-bucket" inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "*", "Resource": "*"}], } group_inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Deny", - "Action": "s3:GetObject", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Deny", "Action": "s3:GetObject", "Resource": "*"}], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - create_group_with_inline_policy_and_add_user(user_name, group_inline_policy_document) - client = boto3.client('s3', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + create_group_with_inline_policy_and_add_user( + user_name, group_inline_policy_document + ) + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) client.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as ex: - client.get_object(Bucket=bucket_name, Key='sdfsdf') - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('Access Denied') + client.get_object(Bucket=bucket_name, Key="sdfsdf") + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal("Access Denied") @set_initial_no_auth_action_count(10) @mock_iam @mock_ec2 def test_access_denied_with_many_irrelevant_policies(): - user_name = 'test-user' + user_name = "test-user" inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "ec2:Describe*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "ec2:Describe*", "Resource": "*"}], } attached_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "s3:*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "s3:*", "Resource": "*"}], } group_inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Deny", - "Action": "iam:List*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Deny", "Action": "iam:List*", "Resource": "*"}], } group_attached_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Deny", - "Action": "lambda:*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Deny", "Action": "lambda:*", "Resource": "*"}], } - access_key = create_user_with_access_key_and_multiple_policies(user_name, inline_policy_document, - attached_policy_document) - create_group_with_multiple_policies_and_add_user(user_name, group_inline_policy_document, - group_attached_policy_document) - client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_multiple_policies( + user_name, inline_policy_document, attached_policy_document + ) + create_group_with_multiple_policies_and_add_user( + user_name, group_inline_policy_document, group_attached_policy_document + ) + client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) with assert_raises(ClientError) as ex: client.create_key_pair(KeyName="TestKey") - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal( - 'User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}'.format( - account_id=ACCOUNT_ID, - user_name=user_name, - operation="ec2:CreateKeyPair" + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "User: arn:aws:iam::{account_id}:user/{user_name} is not authorized to perform: {operation}".format( + account_id=ACCOUNT_ID, user_name=user_name, operation="ec2:CreateKeyPair" ) ) @@ -483,14 +504,16 @@ def test_access_denied_with_many_irrelevant_policies(): @mock_ec2 @mock_elbv2 def test_allowed_with_temporary_credentials(): - role_name = 'test-role' + role_name = "test-role" trust_policy_document = { "Version": "2012-10-17", "Statement": { "Effect": "Allow", - "Principal": {"AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID)}, - "Action": "sts:AssumeRole" - } + "Principal": { + "AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID) + }, + "Action": "sts:AssumeRole", + }, } attached_policy_document = { "Version": "2012-10-17", @@ -499,30 +522,35 @@ def test_allowed_with_temporary_credentials(): "Effect": "Allow", "Action": [ "elasticloadbalancing:CreateLoadBalancer", - "ec2:DescribeSubnets" + "ec2:DescribeSubnets", ], - "Resource": "*" + "Resource": "*", } - ] + ], } - credentials = create_role_with_attached_policy_and_assume_it(role_name, trust_policy_document, attached_policy_document) - elbv2_client = boto3.client('elbv2', region_name='us-east-1', - aws_access_key_id=credentials['AccessKeyId'], - aws_secret_access_key=credentials['SecretAccessKey'], - aws_session_token=credentials['SessionToken']) - ec2_client = boto3.client('ec2', region_name='us-east-1', - aws_access_key_id=credentials['AccessKeyId'], - aws_secret_access_key=credentials['SecretAccessKey'], - aws_session_token=credentials['SessionToken']) - subnets = ec2_client.describe_subnets()['Subnets'] + credentials = create_role_with_attached_policy_and_assume_it( + role_name, trust_policy_document, attached_policy_document + ) + elbv2_client = boto3.client( + "elbv2", + region_name="us-east-1", + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + ec2_client = boto3.client( + "ec2", + region_name="us-east-1", + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + subnets = ec2_client.describe_subnets()["Subnets"] len(subnets).should.be.greater_than(1) elbv2_client.create_load_balancer( - Name='test-load-balancer', - Subnets=[ - subnets[0]['SubnetId'], - subnets[1]['SubnetId'] - ] - )['LoadBalancers'].should.have.length_of(1) + Name="test-load-balancer", + Subnets=[subnets[0]["SubnetId"], subnets[1]["SubnetId"]], + )["LoadBalancers"].should.have.length_of(1) @set_initial_no_auth_action_count(3) @@ -530,48 +558,48 @@ def test_allowed_with_temporary_credentials(): @mock_sts @mock_rds2 def test_access_denied_with_temporary_credentials(): - role_name = 'test-role' - session_name = 'test-session' + role_name = "test-role" + session_name = "test-session" trust_policy_document = { "Version": "2012-10-17", "Statement": { "Effect": "Allow", - "Principal": {"AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID)}, - "Action": "sts:AssumeRole" - } + "Principal": { + "AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID) + }, + "Action": "sts:AssumeRole", + }, } attached_policy_document = { "Version": "2012-10-17", "Statement": [ - { - "Effect": "Allow", - "Action": [ - 'rds:Describe*' - ], - "Resource": "*" - } - ] + {"Effect": "Allow", "Action": ["rds:Describe*"], "Resource": "*"} + ], } - credentials = create_role_with_inline_policy_and_assume_it(role_name, trust_policy_document, - attached_policy_document, session_name) - client = boto3.client('rds', region_name='us-east-1', - aws_access_key_id=credentials['AccessKeyId'], - aws_secret_access_key=credentials['SecretAccessKey'], - aws_session_token=credentials['SessionToken']) + credentials = create_role_with_inline_policy_and_assume_it( + role_name, trust_policy_document, attached_policy_document, session_name + ) + client = boto3.client( + "rds", + region_name="us-east-1", + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) with assert_raises(ClientError) as ex: client.create_db_instance( - DBInstanceIdentifier='test-db-instance', - DBInstanceClass='db.t3', - Engine='aurora-postgresql' + DBInstanceIdentifier="test-db-instance", + DBInstanceClass="db.t3", + Engine="aurora-postgresql", ) - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal( - 'User: arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name} is not authorized to perform: {operation}'.format( + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "User: arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name} is not authorized to perform: {operation}".format( account_id=ACCOUNT_ID, role_name=role_name, session_name=session_name, - operation="rds:CreateDBInstance" + operation="rds:CreateDBInstance", ) ) @@ -579,89 +607,95 @@ def test_access_denied_with_temporary_credentials(): @set_initial_no_auth_action_count(3) @mock_iam def test_get_user_from_credentials(): - user_name = 'new-test-user' + user_name = "new-test-user" inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "iam:*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "iam:*", "Resource": "*"}], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - client = boto3.client('iam', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) - client.get_user()['User']['UserName'].should.equal(user_name) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + client = boto3.client( + "iam", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) + client.get_user()["User"]["UserName"].should.equal(user_name) @set_initial_no_auth_action_count(0) @mock_s3 def test_s3_invalid_access_key_id(): - client = boto3.client('s3', region_name='us-east-1', aws_access_key_id='invalid', aws_secret_access_key='invalid') + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id="invalid", + aws_secret_access_key="invalid", + ) with assert_raises(ClientError) as ex: client.list_buckets() - ex.exception.response['Error']['Code'].should.equal('InvalidAccessKeyId') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('The AWS Access Key Id you provided does not exist in our records.') + ex.exception.response["Error"]["Code"].should.equal("InvalidAccessKeyId") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "The AWS Access Key Id you provided does not exist in our records." + ) @set_initial_no_auth_action_count(3) @mock_s3 @mock_iam def test_s3_signature_does_not_match(): - bucket_name = 'test-bucket' + bucket_name = "test-bucket" access_key = create_user_with_access_key() - client = boto3.client('s3', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key='invalid') + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key="invalid", + ) client.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as ex: client.put_object(Bucket=bucket_name, Key="abc") - ex.exception.response['Error']['Code'].should.equal('SignatureDoesNotMatch') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('The request signature we calculated does not match the signature you provided. Check your key and signing method.') + ex.exception.response["Error"]["Code"].should.equal("SignatureDoesNotMatch") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal( + "The request signature we calculated does not match the signature you provided. Check your key and signing method." + ) @set_initial_no_auth_action_count(7) @mock_s3 @mock_iam def test_s3_access_denied_not_action(): - user_name = 'test-user' - bucket_name = 'test-bucket' + user_name = "test-user" + bucket_name = "test-bucket" inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": "*", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": "*", "Resource": "*"}], } group_inline_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Deny", - "NotAction": "iam:GetUser", - "Resource": "*" - } - ] + "Statement": [{"Effect": "Deny", "NotAction": "iam:GetUser", "Resource": "*"}], } - access_key = create_user_with_access_key_and_inline_policy(user_name, inline_policy_document) - create_group_with_inline_policy_and_add_user(user_name, group_inline_policy_document) - client = boto3.client('s3', region_name='us-east-1', - aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']) + access_key = create_user_with_access_key_and_inline_policy( + user_name, inline_policy_document + ) + create_group_with_inline_policy_and_add_user( + user_name, group_inline_policy_document + ) + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ) client.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as ex: - client.delete_object(Bucket=bucket_name, Key='sdfsdf') - ex.exception.response['Error']['Code'].should.equal('AccessDenied') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(403) - ex.exception.response['Error']['Message'].should.equal('Access Denied') + client.delete_object(Bucket=bucket_name, Key="sdfsdf") + ex.exception.response["Error"]["Code"].should.equal("AccessDenied") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(403) + ex.exception.response["Error"]["Message"].should.equal("Access Denied") @set_initial_no_auth_action_count(4) @@ -669,38 +703,38 @@ def test_s3_access_denied_not_action(): @mock_sts @mock_s3 def test_s3_invalid_token_with_temporary_credentials(): - role_name = 'test-role' - session_name = 'test-session' - bucket_name = 'test-bucket-888' + role_name = "test-role" + session_name = "test-session" + bucket_name = "test-bucket-888" trust_policy_document = { "Version": "2012-10-17", "Statement": { "Effect": "Allow", - "Principal": {"AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID)}, - "Action": "sts:AssumeRole" - } + "Principal": { + "AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID) + }, + "Action": "sts:AssumeRole", + }, } attached_policy_document = { "Version": "2012-10-17", - "Statement": [ - { - "Effect": "Allow", - "Action": [ - '*' - ], - "Resource": "*" - } - ] + "Statement": [{"Effect": "Allow", "Action": ["*"], "Resource": "*"}], } - credentials = create_role_with_inline_policy_and_assume_it(role_name, trust_policy_document, - attached_policy_document, session_name) - client = boto3.client('s3', region_name='us-east-1', - aws_access_key_id=credentials['AccessKeyId'], - aws_secret_access_key=credentials['SecretAccessKey'], - aws_session_token='invalid') + credentials = create_role_with_inline_policy_and_assume_it( + role_name, trust_policy_document, attached_policy_document, session_name + ) + client = boto3.client( + "s3", + region_name="us-east-1", + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token="invalid", + ) client.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as ex: client.list_bucket_metrics_configurations(Bucket=bucket_name) - ex.exception.response['Error']['Code'].should.equal('InvalidToken') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal('The provided token is malformed or otherwise invalid.') + ex.exception.response["Error"]["Code"].should.equal("InvalidToken") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "The provided token is malformed or otherwise invalid." + ) diff --git a/tests/test_core/test_context_manager.py b/tests/test_core/test_context_manager.py index 4824e021f..d20c2187f 100644 --- a/tests/test_core/test_context_manager.py +++ b/tests/test_core/test_context_manager.py @@ -5,8 +5,8 @@ from moto import mock_sqs, settings def test_context_manager_returns_mock(): with mock_sqs() as sqs_mock: - conn = boto3.client("sqs", region_name='us-west-1') + conn = boto3.client("sqs", region_name="us-west-1") conn.create_queue(QueueName="queue1") if not settings.TEST_SERVER_MODE: - list(sqs_mock.backends['us-west-1'].queues.keys()).should.equal(['queue1']) + list(sqs_mock.backends["us-west-1"].queues.keys()).should.equal(["queue1"]) diff --git a/tests/test_core/test_decorator_calls.py b/tests/test_core/test_decorator_calls.py index 5d2f6a4ef..408ca6819 100644 --- a/tests/test_core/test_decorator_calls.py +++ b/tests/test_core/test_decorator_calls.py @@ -9,9 +9,9 @@ from nose.tools import assert_raises from moto import mock_ec2_deprecated, mock_s3_deprecated -''' +""" Test the different ways that the decorator can be used -''' +""" @mock_ec2_deprecated @@ -21,32 +21,32 @@ def test_basic_connect(): @mock_ec2_deprecated def test_basic_decorator(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") list(conn.get_all_instances()).should.equal([]) def test_context_manager(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError): conn.get_all_instances() with mock_ec2_deprecated(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") list(conn.get_all_instances()).should.equal([]) with assert_raises(EC2ResponseError): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") conn.get_all_instances() def test_decorator_start_and_stop(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError): conn.get_all_instances() mock = mock_ec2_deprecated() mock.start() - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") list(conn.get_all_instances()).should.equal([]) mock.stop() @@ -60,12 +60,12 @@ def test_decorater_wrapped_gets_set(): Moto decorator's __wrapped__ should get set to the tests function """ test_decorater_wrapped_gets_set.__wrapped__.__name__.should.equal( - 'test_decorater_wrapped_gets_set') + "test_decorater_wrapped_gets_set" + ) @mock_ec2_deprecated class Tester(object): - def test_the_class(self): conn = boto.connect_ec2() list(conn.get_all_instances()).should.have.length_of(0) @@ -77,19 +77,17 @@ class Tester(object): @mock_s3_deprecated class TesterWithSetup(unittest.TestCase): - def setUp(self): self.conn = boto.connect_s3() - self.conn.create_bucket('mybucket') + self.conn.create_bucket("mybucket") def test_still_the_same(self): - bucket = self.conn.get_bucket('mybucket') + bucket = self.conn.get_bucket("mybucket") bucket.name.should.equal("mybucket") @mock_s3_deprecated class TesterWithStaticmethod(object): - @staticmethod def static(*args): assert not args or not isinstance(args[0], TesterWithStaticmethod) diff --git a/tests/test_core/test_instance_metadata.py b/tests/test_core/test_instance_metadata.py index f8bf24814..d30138d5d 100644 --- a/tests/test_core/test_instance_metadata.py +++ b/tests/test_core/test_instance_metadata.py @@ -6,9 +6,9 @@ import requests from moto import mock_ec2, settings if settings.TEST_SERVER_MODE: - BASE_URL = 'http://localhost:5000' + BASE_URL = "http://localhost:5000" else: - BASE_URL = 'http://169.254.169.254' + BASE_URL = "http://169.254.169.254" @mock_ec2 @@ -21,26 +21,28 @@ def test_latest_meta_data(): def test_meta_data_iam(): res = requests.get("{0}/latest/meta-data/iam".format(BASE_URL)) json_response = res.json() - default_role = json_response['security-credentials']['default-role'] - default_role.should.contain('AccessKeyId') - default_role.should.contain('SecretAccessKey') - default_role.should.contain('Token') - default_role.should.contain('Expiration') + default_role = json_response["security-credentials"]["default-role"] + default_role.should.contain("AccessKeyId") + default_role.should.contain("SecretAccessKey") + default_role.should.contain("Token") + default_role.should.contain("Expiration") @mock_ec2 def test_meta_data_security_credentials(): res = requests.get( - "{0}/latest/meta-data/iam/security-credentials/".format(BASE_URL)) + "{0}/latest/meta-data/iam/security-credentials/".format(BASE_URL) + ) res.content.should.equal(b"default-role") @mock_ec2 def test_meta_data_default_role(): res = requests.get( - "{0}/latest/meta-data/iam/security-credentials/default-role".format(BASE_URL)) + "{0}/latest/meta-data/iam/security-credentials/default-role".format(BASE_URL) + ) json_response = res.json() - json_response.should.contain('AccessKeyId') - json_response.should.contain('SecretAccessKey') - json_response.should.contain('Token') - json_response.should.contain('Expiration') + json_response.should.contain("AccessKeyId") + json_response.should.contain("SecretAccessKey") + json_response.should.contain("Token") + json_response.should.contain("Expiration") diff --git a/tests/test_core/test_moto_api.py b/tests/test_core/test_moto_api.py index cb0ca8939..6482d903e 100644 --- a/tests/test_core/test_moto_api.py +++ b/tests/test_core/test_moto_api.py @@ -6,28 +6,32 @@ import requests import boto3 from moto import mock_sqs, settings -base_url = "http://localhost:5000" if settings.TEST_SERVER_MODE else "http://motoapi.amazonaws.com" +base_url = ( + "http://localhost:5000" + if settings.TEST_SERVER_MODE + else "http://motoapi.amazonaws.com" +) @mock_sqs def test_reset_api(): - conn = boto3.client("sqs", region_name='us-west-1') + conn = boto3.client("sqs", region_name="us-west-1") conn.create_queue(QueueName="queue1") - conn.list_queues()['QueueUrls'].should.have.length_of(1) + conn.list_queues()["QueueUrls"].should.have.length_of(1) res = requests.post("{base_url}/moto-api/reset".format(base_url=base_url)) res.content.should.equal(b'{"status": "ok"}') - conn.list_queues().shouldnt.contain('QueueUrls') # No more queues + conn.list_queues().shouldnt.contain("QueueUrls") # No more queues @mock_sqs def test_data_api(): - conn = boto3.client("sqs", region_name='us-west-1') + conn = boto3.client("sqs", region_name="us-west-1") conn.create_queue(QueueName="queue1") res = requests.post("{base_url}/moto-api/data.json".format(base_url=base_url)) - queues = res.json()['sqs']['Queue'] + queues = res.json()["sqs"]["Queue"] len(queues).should.equal(1) queue = queues[0] - queue['name'].should.equal("queue1") + queue["name"].should.equal("queue1") diff --git a/tests/test_core/test_nested.py b/tests/test_core/test_nested.py index 7c0b8f687..04b04257c 100644 --- a/tests/test_core/test_nested.py +++ b/tests/test_core/test_nested.py @@ -9,14 +9,13 @@ from moto import mock_sqs_deprecated, mock_ec2_deprecated class TestNestedDecorators(unittest.TestCase): - @mock_sqs_deprecated def setup_sqs_queue(self): conn = SQSConnection() - q = conn.create_queue('some-queue') + q = conn.create_queue("some-queue") m = Message() - m.set_body('This is my first message.') + m.set_body("This is my first message.") q.write(m) self.assertEqual(q.count(), 1) @@ -26,4 +25,4 @@ class TestNestedDecorators(unittest.TestCase): self.setup_sqs_queue() conn = EC2Connection() - conn.run_instances('ami-123456') + conn.run_instances("ami-123456") diff --git a/tests/test_core/test_request_mocking.py b/tests/test_core/test_request_mocking.py index ee3ec5f88..2c44d52ce 100644 --- a/tests/test_core/test_request_mocking.py +++ b/tests/test_core/test_request_mocking.py @@ -7,7 +7,7 @@ from moto import mock_sqs, settings @mock_sqs def test_passthrough_requests(): - conn = boto3.client("sqs", region_name='us-west-1') + conn = boto3.client("sqs", region_name="us-west-1") conn.create_queue(QueueName="queue1") res = requests.get("https://httpbin.org/ip") @@ -15,6 +15,7 @@ def test_passthrough_requests(): if not settings.TEST_SERVER_MODE: + @mock_sqs def test_requests_to_amazon_subdomains_dont_work(): res = requests.get("https://fakeservice.amazonaws.com/foo/bar") diff --git a/tests/test_core/test_responses.py b/tests/test_core/test_responses.py index d0f672ab8..587e3584b 100644 --- a/tests/test_core/test_responses.py +++ b/tests/test_core/test_responses.py @@ -9,81 +9,86 @@ from moto.core.responses import flatten_json_request_body def test_flatten_json_request_body(): - spec = AWSServiceSpec( - 'data/emr/2009-03-31/service-2.json').input_spec('RunJobFlow') + spec = AWSServiceSpec("data/emr/2009-03-31/service-2.json").input_spec("RunJobFlow") body = { - 'Name': 'cluster', - 'Instances': { - 'Ec2KeyName': 'ec2key', - 'InstanceGroups': [ - {'InstanceRole': 'MASTER', - 'InstanceType': 'm1.small'}, - {'InstanceRole': 'CORE', - 'InstanceType': 'm1.medium'}, + "Name": "cluster", + "Instances": { + "Ec2KeyName": "ec2key", + "InstanceGroups": [ + {"InstanceRole": "MASTER", "InstanceType": "m1.small"}, + {"InstanceRole": "CORE", "InstanceType": "m1.medium"}, ], - 'Placement': {'AvailabilityZone': 'us-east-1'}, + "Placement": {"AvailabilityZone": "us-east-1"}, }, - 'Steps': [ - {'HadoopJarStep': { - 'Properties': [ - {'Key': 'k1', 'Value': 'v1'}, - {'Key': 'k2', 'Value': 'v2'} - ], - 'Args': ['arg1', 'arg2']}}, + "Steps": [ + { + "HadoopJarStep": { + "Properties": [ + {"Key": "k1", "Value": "v1"}, + {"Key": "k2", "Value": "v2"}, + ], + "Args": ["arg1", "arg2"], + } + } + ], + "Configurations": [ + { + "Classification": "class", + "Properties": {"propkey1": "propkey1", "propkey2": "propkey2"}, + }, + {"Classification": "anotherclass", "Properties": {"propkey3": "propkey3"}}, ], - 'Configurations': [ - {'Classification': 'class', - 'Properties': {'propkey1': 'propkey1', - 'propkey2': 'propkey2'}}, - {'Classification': 'anotherclass', - 'Properties': {'propkey3': 'propkey3'}}, - ] } - flat = flatten_json_request_body('', body, spec) - flat['Name'].should.equal(body['Name']) - flat['Instances.Ec2KeyName'].should.equal(body['Instances']['Ec2KeyName']) + flat = flatten_json_request_body("", body, spec) + flat["Name"].should.equal(body["Name"]) + flat["Instances.Ec2KeyName"].should.equal(body["Instances"]["Ec2KeyName"]) for idx in range(2): - flat['Instances.InstanceGroups.member.' + str(idx + 1) + '.InstanceRole'].should.equal( - body['Instances']['InstanceGroups'][idx]['InstanceRole']) - flat['Instances.InstanceGroups.member.' + str(idx + 1) + '.InstanceType'].should.equal( - body['Instances']['InstanceGroups'][idx]['InstanceType']) - flat['Instances.Placement.AvailabilityZone'].should.equal( - body['Instances']['Placement']['AvailabilityZone']) + flat[ + "Instances.InstanceGroups.member." + str(idx + 1) + ".InstanceRole" + ].should.equal(body["Instances"]["InstanceGroups"][idx]["InstanceRole"]) + flat[ + "Instances.InstanceGroups.member." + str(idx + 1) + ".InstanceType" + ].should.equal(body["Instances"]["InstanceGroups"][idx]["InstanceType"]) + flat["Instances.Placement.AvailabilityZone"].should.equal( + body["Instances"]["Placement"]["AvailabilityZone"] + ) for idx in range(1): - prefix = 'Steps.member.' + str(idx + 1) + '.HadoopJarStep' - step = body['Steps'][idx]['HadoopJarStep'] + prefix = "Steps.member." + str(idx + 1) + ".HadoopJarStep" + step = body["Steps"][idx]["HadoopJarStep"] i = 0 - while prefix + '.Properties.member.' + str(i + 1) + '.Key' in flat: - flat[prefix + '.Properties.member.' + - str(i + 1) + '.Key'].should.equal(step['Properties'][i]['Key']) - flat[prefix + '.Properties.member.' + - str(i + 1) + '.Value'].should.equal(step['Properties'][i]['Value']) + while prefix + ".Properties.member." + str(i + 1) + ".Key" in flat: + flat[prefix + ".Properties.member." + str(i + 1) + ".Key"].should.equal( + step["Properties"][i]["Key"] + ) + flat[prefix + ".Properties.member." + str(i + 1) + ".Value"].should.equal( + step["Properties"][i]["Value"] + ) i += 1 i = 0 - while prefix + '.Args.member.' + str(i + 1) in flat: - flat[prefix + '.Args.member.' + - str(i + 1)].should.equal(step['Args'][i]) + while prefix + ".Args.member." + str(i + 1) in flat: + flat[prefix + ".Args.member." + str(i + 1)].should.equal(step["Args"][i]) i += 1 for idx in range(2): - flat['Configurations.member.' + str(idx + 1) + '.Classification'].should.equal( - body['Configurations'][idx]['Classification']) + flat["Configurations.member." + str(idx + 1) + ".Classification"].should.equal( + body["Configurations"][idx]["Classification"] + ) props = {} i = 1 - keyfmt = 'Configurations.member.{0}.Properties.entry.{1}' + keyfmt = "Configurations.member.{0}.Properties.entry.{1}" key = keyfmt.format(idx + 1, i) - while key + '.key' in flat: - props[flat[key + '.key']] = flat[key + '.value'] + while key + ".key" in flat: + props[flat[key + ".key"]] = flat[key + ".value"] i += 1 key = keyfmt.format(idx + 1, i) - props.should.equal(body['Configurations'][idx]['Properties']) + props.should.equal(body["Configurations"][idx]["Properties"]) def test_parse_qs_unicode_decode_error(): body = b'{"key": "%D0"}, "C": "#0 = :0"}' - request = AWSPreparedRequest('GET', 'http://request', {'foo': 'bar'}, body, False) + request = AWSPreparedRequest("GET", "http://request", {"foo": "bar"}, body, False) BaseResponse().setup_class(request, request.url, request.headers) diff --git a/tests/test_core/test_server.py b/tests/test_core/test_server.py index bd00b17c3..5514223af 100644 --- a/tests/test_core/test_server.py +++ b/tests/test_core/test_server.py @@ -8,13 +8,15 @@ from moto.server import main, create_backend_app, DomainDispatcherApplication def test_wrong_arguments(): try: main(["name", "test1", "test2", "test3"]) - assert False, ("main() when called with the incorrect number of args" - " should raise a system exit") + assert False, ( + "main() when called with the incorrect number of args" + " should raise a system exit" + ) except SystemExit: pass -@patch('moto.server.run_simple') +@patch("moto.server.run_simple") def test_right_arguments(run_simple): main(["s3"]) func_call = run_simple.call_args[0] @@ -22,7 +24,7 @@ def test_right_arguments(run_simple): func_call[1].should.equal(5000) -@patch('moto.server.run_simple') +@patch("moto.server.run_simple") def test_port_argument(run_simple): main(["s3", "--port", "8080"]) func_call = run_simple.call_args[0] @@ -33,15 +35,15 @@ def test_port_argument(run_simple): def test_domain_dispatched(): dispatcher = DomainDispatcherApplication(create_backend_app) backend_app = dispatcher.get_application( - {"HTTP_HOST": "email.us-east1.amazonaws.com"}) + {"HTTP_HOST": "email.us-east1.amazonaws.com"} + ) keys = list(backend_app.view_functions.keys()) - keys[0].should.equal('EmailResponse.dispatch') + keys[0].should.equal("EmailResponse.dispatch") def test_domain_dispatched_with_service(): # If we pass a particular service, always return that. dispatcher = DomainDispatcherApplication(create_backend_app, service="s3") - backend_app = dispatcher.get_application( - {"HTTP_HOST": "s3.us-east1.amazonaws.com"}) + backend_app = dispatcher.get_application({"HTTP_HOST": "s3.us-east1.amazonaws.com"}) keys = set(backend_app.view_functions.keys()) - keys.should.contain('ResponseObject.key_response') + keys.should.contain("ResponseObject.key_response") diff --git a/tests/test_core/test_socket.py b/tests/test_core/test_socket.py index 2e73d7b5f..5e446ca1a 100644 --- a/tests/test_core/test_socket.py +++ b/tests/test_core/test_socket.py @@ -6,16 +6,16 @@ from six import PY3 class TestSocketPair(unittest.TestCase): - @mock_dynamodb2_deprecated def test_asyncio_deprecated(self): if PY3: self.assertIn( - 'moto.packages.httpretty.core.fakesock.socket', + "moto.packages.httpretty.core.fakesock.socket", str(socket.socket), - 'Our mock should be present' + "Our mock should be present", ) import asyncio + self.assertIsNotNone(asyncio.get_event_loop()) @mock_dynamodb2_deprecated @@ -24,9 +24,9 @@ class TestSocketPair(unittest.TestCase): # In Python2, the fakesocket is not set, for some reason. if PY3: self.assertIn( - 'moto.packages.httpretty.core.fakesock.socket', + "moto.packages.httpretty.core.fakesock.socket", str(socket.socket), - 'Our mock should be present' + "Our mock should be present", ) a, b = socket.socketpair() self.assertIsNotNone(a) @@ -36,7 +36,6 @@ class TestSocketPair(unittest.TestCase): if b: b.close() - @mock_dynamodb2 def test_socket_pair(self): a, b = socket.socketpair() diff --git a/tests/test_core/test_url_mapping.py b/tests/test_core/test_url_mapping.py index 8f7921a5a..4dccc4f21 100644 --- a/tests/test_core/test_url_mapping.py +++ b/tests/test_core/test_url_mapping.py @@ -14,8 +14,9 @@ def test_flask_path_converting_simple(): def test_flask_path_converting_regex(): - convert_regex_to_flask_path( - "/(?P[a-zA-Z0-9\-_]+)").should.equal('/') + convert_regex_to_flask_path("/(?P[a-zA-Z0-9\-_]+)").should.equal( + '/' + ) convert_regex_to_flask_path("(?P\d+)/(?P.*)$").should.equal( '/' diff --git a/tests/test_core/test_utils.py b/tests/test_core/test_utils.py index 8dbf21716..7c72aaccd 100644 --- a/tests/test_core/test_utils.py +++ b/tests/test_core/test_utils.py @@ -3,7 +3,11 @@ from __future__ import unicode_literals import sure # noqa from freezegun import freeze_time -from moto.core.utils import camelcase_to_underscores, underscores_to_camelcase, unix_time +from moto.core.utils import ( + camelcase_to_underscores, + underscores_to_camelcase, + unix_time, +) def test_camelcase_to_underscores(): @@ -18,9 +22,7 @@ def test_camelcase_to_underscores(): def test_underscores_to_camelcase(): - cases = { - "the_new_attribute": "theNewAttribute", - } + cases = {"the_new_attribute": "theNewAttribute"} for arg, expected in cases.items(): underscores_to_camelcase(arg).should.equal(expected) diff --git a/tests/test_datapipeline/test_datapipeline.py b/tests/test_datapipeline/test_datapipeline.py index ce190c7e4..42063b506 100644 --- a/tests/test_datapipeline/test_datapipeline.py +++ b/tests/test_datapipeline/test_datapipeline.py @@ -9,8 +9,8 @@ from moto.datapipeline.utils import remove_capitalization_of_dict_keys def get_value_from_fields(key, fields): for field in fields: - if field['key'] == key: - return field['stringValue'] + if field["key"] == key: + return field["stringValue"] @mock_datapipeline_deprecated @@ -20,62 +20,46 @@ def test_create_pipeline(): res = conn.create_pipeline("mypipeline", "some-unique-id") pipeline_id = res["pipelineId"] - pipeline_descriptions = conn.describe_pipelines( - [pipeline_id])["pipelineDescriptionList"] + pipeline_descriptions = conn.describe_pipelines([pipeline_id])[ + "pipelineDescriptionList" + ] pipeline_descriptions.should.have.length_of(1) pipeline_description = pipeline_descriptions[0] - pipeline_description['name'].should.equal("mypipeline") + pipeline_description["name"].should.equal("mypipeline") pipeline_description["pipelineId"].should.equal(pipeline_id) - fields = pipeline_description['fields'] + fields = pipeline_description["fields"] - get_value_from_fields('@pipelineState', fields).should.equal("PENDING") - get_value_from_fields('uniqueId', fields).should.equal("some-unique-id") + get_value_from_fields("@pipelineState", fields).should.equal("PENDING") + get_value_from_fields("uniqueId", fields).should.equal("some-unique-id") PIPELINE_OBJECTS = [ { "id": "Default", "name": "Default", - "fields": [{ - "key": "workerGroup", - "stringValue": "workerGroup" - }] + "fields": [{"key": "workerGroup", "stringValue": "workerGroup"}], }, { "id": "Schedule", "name": "Schedule", - "fields": [{ - "key": "startDateTime", - "stringValue": "2012-12-12T00:00:00" - }, { - "key": "type", - "stringValue": "Schedule" - }, { - "key": "period", - "stringValue": "1 hour" - }, { - "key": "endDateTime", - "stringValue": "2012-12-21T18:00:00" - }] + "fields": [ + {"key": "startDateTime", "stringValue": "2012-12-12T00:00:00"}, + {"key": "type", "stringValue": "Schedule"}, + {"key": "period", "stringValue": "1 hour"}, + {"key": "endDateTime", "stringValue": "2012-12-21T18:00:00"}, + ], }, { "id": "SayHello", "name": "SayHello", - "fields": [{ - "key": "type", - "stringValue": "ShellCommandActivity" - }, { - "key": "command", - "stringValue": "echo hello" - }, { - "key": "parent", - "refValue": "Default" - }, { - "key": "schedule", - "refValue": "Schedule" - }] - } + "fields": [ + {"key": "type", "stringValue": "ShellCommandActivity"}, + {"key": "command", "stringValue": "echo hello"}, + {"key": "parent", "refValue": "Default"}, + {"key": "schedule", "refValue": "Schedule"}, + ], + }, ] @@ -88,14 +72,13 @@ def test_creating_pipeline_definition(): conn.put_pipeline_definition(PIPELINE_OBJECTS, pipeline_id) pipeline_definition = conn.get_pipeline_definition(pipeline_id) - pipeline_definition['pipelineObjects'].should.have.length_of(3) - default_object = pipeline_definition['pipelineObjects'][0] - default_object['name'].should.equal("Default") - default_object['id'].should.equal("Default") - default_object['fields'].should.equal([{ - "key": "workerGroup", - "stringValue": "workerGroup" - }]) + pipeline_definition["pipelineObjects"].should.have.length_of(3) + default_object = pipeline_definition["pipelineObjects"][0] + default_object["name"].should.equal("Default") + default_object["id"].should.equal("Default") + default_object["fields"].should.equal( + [{"key": "workerGroup", "stringValue": "workerGroup"}] + ) @mock_datapipeline_deprecated @@ -107,15 +90,15 @@ def test_describing_pipeline_objects(): conn.put_pipeline_definition(PIPELINE_OBJECTS, pipeline_id) objects = conn.describe_objects(["Schedule", "Default"], pipeline_id)[ - 'pipelineObjects'] + "pipelineObjects" + ] objects.should.have.length_of(2) - default_object = [x for x in objects if x['id'] == 'Default'][0] - default_object['name'].should.equal("Default") - default_object['fields'].should.equal([{ - "key": "workerGroup", - "stringValue": "workerGroup" - }]) + default_object = [x for x in objects if x["id"] == "Default"][0] + default_object["name"].should.equal("Default") + default_object["fields"].should.equal( + [{"key": "workerGroup", "stringValue": "workerGroup"}] + ) @mock_datapipeline_deprecated @@ -127,13 +110,14 @@ def test_activate_pipeline(): pipeline_id = res["pipelineId"] conn.activate_pipeline(pipeline_id) - pipeline_descriptions = conn.describe_pipelines( - [pipeline_id])["pipelineDescriptionList"] + pipeline_descriptions = conn.describe_pipelines([pipeline_id])[ + "pipelineDescriptionList" + ] pipeline_descriptions.should.have.length_of(1) pipeline_description = pipeline_descriptions[0] - fields = pipeline_description['fields'] + fields = pipeline_description["fields"] - get_value_from_fields('@pipelineState', fields).should.equal("SCHEDULED") + get_value_from_fields("@pipelineState", fields).should.equal("SCHEDULED") @mock_datapipeline_deprecated @@ -160,14 +144,12 @@ def test_listing_pipelines(): response["hasMoreResults"].should.be(False) response["marker"].should.be.none response["pipelineIdList"].should.have.length_of(2) - response["pipelineIdList"].should.contain({ - "id": res1["pipelineId"], - "name": "mypipeline1", - }) - response["pipelineIdList"].should.contain({ - "id": res2["pipelineId"], - "name": "mypipeline2" - }) + response["pipelineIdList"].should.contain( + {"id": res1["pipelineId"], "name": "mypipeline1"} + ) + response["pipelineIdList"].should.contain( + {"id": res2["pipelineId"], "name": "mypipeline2"} + ) @mock_datapipeline_deprecated @@ -179,7 +161,7 @@ def test_listing_paginated_pipelines(): response = conn.list_pipelines() response["hasMoreResults"].should.be(True) - response["marker"].should.equal(response["pipelineIdList"][-1]['id']) + response["marker"].should.equal(response["pipelineIdList"][-1]["id"]) response["pipelineIdList"].should.have.length_of(50) @@ -188,17 +170,13 @@ def test_remove_capitalization_of_dict_keys(): result = remove_capitalization_of_dict_keys( { "Id": "IdValue", - "Fields": [{ - "Key": "KeyValue", - "StringValue": "StringValueValue" - }] + "Fields": [{"Key": "KeyValue", "StringValue": "StringValueValue"}], } ) - result.should.equal({ - "id": "IdValue", - "fields": [{ - "key": "KeyValue", - "stringValue": "StringValueValue" - }], - }) + result.should.equal( + { + "id": "IdValue", + "fields": [{"key": "KeyValue", "stringValue": "StringValueValue"}], + } + ) diff --git a/tests/test_datapipeline/test_server.py b/tests/test_datapipeline/test_server.py index 03c77b034..49b8c39ce 100644 --- a/tests/test_datapipeline/test_server.py +++ b/tests/test_datapipeline/test_server.py @@ -6,9 +6,9 @@ import sure # noqa import moto.server as server from moto import mock_datapipeline -''' +""" Test the different server responses -''' +""" @mock_datapipeline @@ -16,13 +16,11 @@ def test_list_streams(): backend = server.create_backend_app("datapipeline") test_client = backend.test_client() - res = test_client.post('/', - data={"pipelineIds": ["ASdf"]}, - headers={ - "X-Amz-Target": "DataPipeline.DescribePipelines"}, - ) + res = test_client.post( + "/", + data={"pipelineIds": ["ASdf"]}, + headers={"X-Amz-Target": "DataPipeline.DescribePipelines"}, + ) json_data = json.loads(res.data.decode("utf-8")) - json_data.should.equal({ - 'pipelineDescriptionList': [] - }) + json_data.should.equal({"pipelineDescriptionList": []}) diff --git a/tests/test_dynamodb/test_dynamodb.py b/tests/test_dynamodb/test_dynamodb.py index d48519755..931e57e06 100644 --- a/tests/test_dynamodb/test_dynamodb.py +++ b/tests/test_dynamodb/test_dynamodb.py @@ -15,20 +15,17 @@ from boto.exception import DynamoDBResponseError @mock_dynamodb_deprecated def test_list_tables(): - name = 'TestTable' - dynamodb_backend.create_table( - name, hash_key_attr="name", hash_key_type="S") - conn = boto.connect_dynamodb('the_key', 'the_secret') - assert conn.list_tables() == ['TestTable'] + name = "TestTable" + dynamodb_backend.create_table(name, hash_key_attr="name", hash_key_type="S") + conn = boto.connect_dynamodb("the_key", "the_secret") + assert conn.list_tables() == ["TestTable"] @mock_dynamodb_deprecated def test_list_tables_layer_1(): - dynamodb_backend.create_table( - "test_1", hash_key_attr="name", hash_key_type="S") - dynamodb_backend.create_table( - "test_2", hash_key_attr="name", hash_key_type="S") - conn = boto.connect_dynamodb('the_key', 'the_secret') + dynamodb_backend.create_table("test_1", hash_key_attr="name", hash_key_type="S") + dynamodb_backend.create_table("test_2", hash_key_attr="name", hash_key_type="S") + conn = boto.connect_dynamodb("the_key", "the_secret") res = conn.layer1.list_tables(limit=1) expected = {"TableNames": ["test_1"], "LastEvaluatedTableName": "test_1"} res.should.equal(expected) @@ -40,15 +37,15 @@ def test_list_tables_layer_1(): @mock_dynamodb_deprecated def test_describe_missing_table(): - conn = boto.connect_dynamodb('the_key', 'the_secret') + conn = boto.connect_dynamodb("the_key", "the_secret") with assert_raises(DynamoDBResponseError): - conn.describe_table('messages') + conn.describe_table("messages") @mock_dynamodb_deprecated def test_dynamodb_with_connect_to_region(): # this will work if connected with boto.connect_dynamodb() - dynamodb = boto.dynamodb.connect_to_region('us-west-2') + dynamodb = boto.dynamodb.connect_to_region("us-west-2") - schema = dynamodb.create_schema('column1', str(), 'column2', int()) - dynamodb.create_table('table1', schema, 200, 200) + schema = dynamodb.create_schema("column1", str(), "column2", int()) + dynamodb.create_table("table1", schema, 200, 200) diff --git a/tests/test_dynamodb/test_dynamodb_table_with_range_key.py b/tests/test_dynamodb/test_dynamodb_table_with_range_key.py index 2a482b31e..40301025f 100644 --- a/tests/test_dynamodb/test_dynamodb_table_with_range_key.py +++ b/tests/test_dynamodb/test_dynamodb_table_with_range_key.py @@ -13,17 +13,14 @@ from boto.exception import DynamoDBResponseError def create_table(conn): message_table_schema = conn.create_schema( - hash_key_name='forum_name', + hash_key_name="forum_name", hash_key_proto_value=str, - range_key_name='subject', - range_key_proto_value=str + range_key_name="subject", + range_key_proto_value=str, ) table = conn.create_table( - name='messages', - schema=message_table_schema, - read_units=10, - write_units=10 + name="messages", schema=message_table_schema, read_units=10, write_units=10 ) return table @@ -35,29 +32,23 @@ def test_create_table(): create_table(conn) expected = { - 'Table': { - 'CreationDateTime': 1326499200.0, - 'ItemCount': 0, - 'KeySchema': { - 'HashKeyElement': { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - 'RangeKeyElement': { - 'AttributeName': 'subject', - 'AttributeType': 'S' - } + "Table": { + "CreationDateTime": 1326499200.0, + "ItemCount": 0, + "KeySchema": { + "HashKeyElement": {"AttributeName": "forum_name", "AttributeType": "S"}, + "RangeKeyElement": {"AttributeName": "subject", "AttributeType": "S"}, }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 10 + "ProvisionedThroughput": { + "ReadCapacityUnits": 10, + "WriteCapacityUnits": 10, }, - 'TableName': 'messages', - 'TableSizeBytes': 0, - 'TableStatus': 'ACTIVE' + "TableName": "messages", + "TableSizeBytes": 0, + "TableStatus": "ACTIVE", } } - conn.describe_table('messages').should.equal(expected) + conn.describe_table("messages").should.equal(expected) @mock_dynamodb_deprecated @@ -66,11 +57,12 @@ def test_delete_table(): create_table(conn) conn.list_tables().should.have.length_of(1) - conn.layer1.delete_table('messages') + conn.layer1.delete_table("messages") conn.list_tables().should.have.length_of(0) - conn.layer1.delete_table.when.called_with( - 'messages').should.throw(DynamoDBResponseError) + conn.layer1.delete_table.when.called_with("messages").should.throw( + DynamoDBResponseError + ) @mock_dynamodb_deprecated @@ -93,45 +85,47 @@ def test_item_add_and_describe_and_update(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = table.new_item( - hash_key='LOLCat Forum', - range_key='Check this out!', - attrs=item_data, + hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data ) item.put() table.has_item("LOLCat Forum", "Check this out!").should.equal(True) returned_item = table.get_item( - hash_key='LOLCat Forum', - range_key='Check this out!', - attributes_to_get=['Body', 'SentBy'] + hash_key="LOLCat Forum", + range_key="Check this out!", + attributes_to_get=["Body", "SentBy"], + ) + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "subject": "Check this out!", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'subject': 'Check this out!', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - }) - item['SentBy'] = 'User B' + item["SentBy"] = "User B" item.put() returned_item = table.get_item( - hash_key='LOLCat Forum', - range_key='Check this out!', - attributes_to_get=['Body', 'SentBy'] + hash_key="LOLCat Forum", + range_key="Check this out!", + attributes_to_get=["Body", "SentBy"], + ) + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "subject": "Check this out!", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'subject': 'Check this out!', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - }) @mock_dynamodb_deprecated @@ -139,11 +133,8 @@ def test_item_put_without_table(): conn = boto.connect_dynamodb() conn.layer1.put_item.when.called_with( - table_name='undeclared-table', - item=dict( - hash_key='LOLCat Forum', - range_key='Check this out!', - ), + table_name="undeclared-table", + item=dict(hash_key="LOLCat Forum", range_key="Check this out!"), ).should.throw(DynamoDBResponseError) @@ -152,10 +143,9 @@ def test_get_missing_item(): conn = boto.connect_dynamodb() table = create_table(conn) - table.get_item.when.called_with( - hash_key='tester', - range_key='other', - ).should.throw(DynamoDBKeyNotFoundError) + table.get_item.when.called_with(hash_key="tester", range_key="other").should.throw( + DynamoDBKeyNotFoundError + ) table.has_item("foobar", "more").should.equal(False) @@ -164,11 +154,8 @@ def test_get_item_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.get_item.when.called_with( - table_name='undeclared-table', - key={ - 'HashKeyElement': {'S': 'tester'}, - 'RangeKeyElement': {'S': 'test-range'}, - }, + table_name="undeclared-table", + key={"HashKeyElement": {"S": "tester"}, "RangeKeyElement": {"S": "test-range"}}, ).should.throw(DynamoDBKeyNotFoundError) @@ -182,10 +169,7 @@ def test_get_item_without_range_key(): range_key_proto_value=int, ) table = conn.create_table( - name='messages', - schema=message_table_schema, - read_units=10, - write_units=10 + name="messages", schema=message_table_schema, read_units=10, write_units=10 ) hash_key = 3241526475 @@ -193,8 +177,9 @@ def test_get_item_without_range_key(): new_item = table.new_item(hash_key=hash_key, range_key=range_key) new_item.put() - table.get_item.when.called_with( - hash_key=hash_key).should.throw(DynamoDBValidationError) + table.get_item.when.called_with(hash_key=hash_key).should.throw( + DynamoDBValidationError + ) @mock_dynamodb_deprecated @@ -203,14 +188,12 @@ def test_delete_item(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = table.new_item( - hash_key='LOLCat Forum', - range_key='Check this out!', - attrs=item_data, + hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data ) item.put() @@ -218,7 +201,7 @@ def test_delete_item(): table.item_count.should.equal(1) response = item.delete() - response.should.equal({u'Attributes': [], u'ConsumedCapacityUnits': 0.5}) + response.should.equal({"Attributes": [], "ConsumedCapacityUnits": 0.5}) table.refresh() table.item_count.should.equal(0) @@ -231,31 +214,31 @@ def test_delete_item_with_attribute_response(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = table.new_item( - hash_key='LOLCat Forum', - range_key='Check this out!', - attrs=item_data, + hash_key="LOLCat Forum", range_key="Check this out!", attrs=item_data ) item.put() table.refresh() table.item_count.should.equal(1) - response = item.delete(return_values='ALL_OLD') - response.should.equal({ - 'Attributes': { - 'Body': 'http://url_to_lolcat.gif', - 'forum_name': 'LOLCat Forum', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'SentBy': 'User A', - 'subject': 'Check this out!' - }, - 'ConsumedCapacityUnits': 0.5 - }) + response = item.delete(return_values="ALL_OLD") + response.should.equal( + { + "Attributes": { + "Body": "http://url_to_lolcat.gif", + "forum_name": "LOLCat Forum", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "SentBy": "User A", + "subject": "Check this out!", + }, + "ConsumedCapacityUnits": 0.5, + } + ) table.refresh() table.item_count.should.equal(0) @@ -267,11 +250,8 @@ def test_delete_item_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.delete_item.when.called_with( - table_name='undeclared-table', - key={ - 'HashKeyElement': {'S': 'tester'}, - 'RangeKeyElement': {'S': 'test-range'}, - }, + table_name="undeclared-table", + key={"HashKeyElement": {"S": "tester"}, "RangeKeyElement": {"S": "test-range"}}, ).should.throw(DynamoDBResponseError) @@ -281,54 +261,42 @@ def test_query(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key', - range_key='456', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key', - range_key='123', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key', - range_key='789', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="789", attrs=item_data) item.put() - results = table.query(hash_key='the-key', - range_key_condition=condition.GT('1')) - results.response['Items'].should.have.length_of(3) + results = table.query(hash_key="the-key", range_key_condition=condition.GT("1")) + results.response["Items"].should.have.length_of(3) - results = table.query(hash_key='the-key', - range_key_condition=condition.GT('234')) - results.response['Items'].should.have.length_of(2) + results = table.query(hash_key="the-key", range_key_condition=condition.GT("234")) + results.response["Items"].should.have.length_of(2) - results = table.query(hash_key='the-key', - range_key_condition=condition.GT('9999')) - results.response['Items'].should.have.length_of(0) + results = table.query(hash_key="the-key", range_key_condition=condition.GT("9999")) + results.response["Items"].should.have.length_of(0) - results = table.query(hash_key='the-key', - range_key_condition=condition.CONTAINS('12')) - results.response['Items'].should.have.length_of(1) + results = table.query( + hash_key="the-key", range_key_condition=condition.CONTAINS("12") + ) + results.response["Items"].should.have.length_of(1) - results = table.query(hash_key='the-key', - range_key_condition=condition.BEGINS_WITH('7')) - results.response['Items'].should.have.length_of(1) + results = table.query( + hash_key="the-key", range_key_condition=condition.BEGINS_WITH("7") + ) + results.response["Items"].should.have.length_of(1) - results = table.query(hash_key='the-key', - range_key_condition=condition.BETWEEN('567', '890')) - results.response['Items'].should.have.length_of(1) + results = table.query( + hash_key="the-key", range_key_condition=condition.BETWEEN("567", "890") + ) + results.response["Items"].should.have.length_of(1) @mock_dynamodb_deprecated @@ -336,12 +304,10 @@ def test_query_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.query.when.called_with( - table_name='undeclared-table', - hash_key_value={'S': 'the-key'}, + table_name="undeclared-table", + hash_key_value={"S": "the-key"}, range_key_conditions={ - "AttributeValueList": [{ - "S": "User B" - }], + "AttributeValueList": [{"S": "User B"}], "ComparisonOperator": "EQ", }, ).should.throw(DynamoDBResponseError) @@ -353,61 +319,49 @@ def test_scan(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key', - range_key='456', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key', - range_key='123', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data) item.put() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item = table.new_item( - hash_key='the-key', - range_key='789', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="789", attrs=item_data) item.put() results = table.scan() - results.response['Items'].should.have.length_of(3) + results.response["Items"].should.have.length_of(3) - results = table.scan(scan_filter={'SentBy': condition.EQ('User B')}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"SentBy": condition.EQ("User B")}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Body': condition.BEGINS_WITH('http')}) - results.response['Items'].should.have.length_of(3) + results = table.scan(scan_filter={"Body": condition.BEGINS_WITH("http")}) + results.response["Items"].should.have.length_of(3) - results = table.scan(scan_filter={'Ids': condition.CONTAINS(2)}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"Ids": condition.CONTAINS(2)}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Ids': condition.NOT_NULL()}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"Ids": condition.NOT_NULL()}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Ids': condition.NULL()}) - results.response['Items'].should.have.length_of(2) + results = table.scan(scan_filter={"Ids": condition.NULL()}) + results.response["Items"].should.have.length_of(2) - results = table.scan(scan_filter={'PK': condition.BETWEEN(8, 9)}) - results.response['Items'].should.have.length_of(0) + results = table.scan(scan_filter={"PK": condition.BETWEEN(8, 9)}) + results.response["Items"].should.have.length_of(0) - results = table.scan(scan_filter={'PK': condition.BETWEEN(5, 8)}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"PK": condition.BETWEEN(5, 8)}) + results.response["Items"].should.have.length_of(1) @mock_dynamodb_deprecated @@ -415,13 +369,11 @@ def test_scan_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.scan.when.called_with( - table_name='undeclared-table', + table_name="undeclared-table", scan_filter={ "SentBy": { - "AttributeValueList": [{ - "S": "User B"} - ], - "ComparisonOperator": "EQ" + "AttributeValueList": [{"S": "User B"}], + "ComparisonOperator": "EQ", } }, ).should.throw(DynamoDBResponseError) @@ -433,7 +385,7 @@ def test_scan_after_has_item(): table = create_table(conn) list(table.scan()).should.equal([]) - table.has_item(hash_key='the-key', range_key='123') + table.has_item(hash_key="the-key", range_key="123") list(table.scan()).should.equal([]) @@ -446,27 +398,31 @@ def test_write_batch(): batch_list = conn.new_batch_write_list() items = [] - items.append(table.new_item( - hash_key='the-key', - range_key='123', - attrs={ - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }, - )) + items.append( + table.new_item( + hash_key="the-key", + range_key="123", + attrs={ + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + }, + ) + ) - items.append(table.new_item( - hash_key='the-key', - range_key='789', - attrs={ - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, - }, - )) + items.append( + table.new_item( + hash_key="the-key", + range_key="789", + attrs={ + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, + }, + ) + ) batch_list.add_batch(table, puts=items) conn.batch_write_item(batch_list) @@ -475,7 +431,7 @@ def test_write_batch(): table.item_count.should.equal(2) batch_list = conn.new_batch_write_list() - batch_list.add_batch(table, deletes=[('the-key', '789')]) + batch_list.add_batch(table, deletes=[("the-key", "789")]) conn.batch_write_item(batch_list) table.refresh() @@ -488,39 +444,27 @@ def test_batch_read(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key', - range_key='456', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="456", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key', - range_key='123', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", range_key="123", attrs=item_data) item.put() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item = table.new_item( - hash_key='another-key', - range_key='789', - attrs=item_data, - ) + item = table.new_item(hash_key="another-key", range_key="789", attrs=item_data) item.put() - items = table.batch_get_item([('the-key', '123'), ('another-key', '789')]) + items = table.batch_get_item([("the-key", "123"), ("another-key", "789")]) # Iterate through so that batch_item gets called count = len([x for x in items]) count.should.equal(2) diff --git a/tests/test_dynamodb/test_dynamodb_table_without_range_key.py b/tests/test_dynamodb/test_dynamodb_table_without_range_key.py index ebd0c2051..e5a268c97 100644 --- a/tests/test_dynamodb/test_dynamodb_table_without_range_key.py +++ b/tests/test_dynamodb/test_dynamodb_table_without_range_key.py @@ -13,15 +13,11 @@ from boto.exception import DynamoDBResponseError def create_table(conn): message_table_schema = conn.create_schema( - hash_key_name='forum_name', - hash_key_proto_value=str, + hash_key_name="forum_name", hash_key_proto_value=str ) table = conn.create_table( - name='messages', - schema=message_table_schema, - read_units=10, - write_units=10 + name="messages", schema=message_table_schema, read_units=10, write_units=10 ) return table @@ -33,25 +29,22 @@ def test_create_table(): create_table(conn) expected = { - 'Table': { - 'CreationDateTime': 1326499200.0, - 'ItemCount': 0, - 'KeySchema': { - 'HashKeyElement': { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, + "Table": { + "CreationDateTime": 1326499200.0, + "ItemCount": 0, + "KeySchema": { + "HashKeyElement": {"AttributeName": "forum_name", "AttributeType": "S"} }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 10 + "ProvisionedThroughput": { + "ReadCapacityUnits": 10, + "WriteCapacityUnits": 10, }, - 'TableName': 'messages', - 'TableSizeBytes': 0, - 'TableStatus': 'ACTIVE', + "TableName": "messages", + "TableSizeBytes": 0, + "TableStatus": "ACTIVE", } } - conn.describe_table('messages').should.equal(expected) + conn.describe_table("messages").should.equal(expected) @mock_dynamodb_deprecated @@ -60,11 +53,12 @@ def test_delete_table(): create_table(conn) conn.list_tables().should.have.length_of(1) - conn.layer1.delete_table('messages') + conn.layer1.delete_table("messages") conn.list_tables().should.have.length_of(0) - conn.layer1.delete_table.when.called_with( - 'messages').should.throw(DynamoDBResponseError) + conn.layer1.delete_table.when.called_with("messages").should.throw( + DynamoDBResponseError + ) @mock_dynamodb_deprecated @@ -87,38 +81,37 @@ def test_item_add_and_describe_and_update(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='LOLCat Forum', - attrs=item_data, - ) + item = table.new_item(hash_key="LOLCat Forum", attrs=item_data) item.put() returned_item = table.get_item( - hash_key='LOLCat Forum', - attributes_to_get=['Body', 'SentBy'] + hash_key="LOLCat Forum", attributes_to_get=["Body", "SentBy"] + ) + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - }) - item['SentBy'] = 'User B' + item["SentBy"] = "User B" item.put() returned_item = table.get_item( - hash_key='LOLCat Forum', - attributes_to_get=['Body', 'SentBy'] + hash_key="LOLCat Forum", attributes_to_get=["Body", "SentBy"] + ) + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - }) @mock_dynamodb_deprecated @@ -126,10 +119,7 @@ def test_item_put_without_table(): conn = boto.connect_dynamodb() conn.layer1.put_item.when.called_with( - table_name='undeclared-table', - item=dict( - hash_key='LOLCat Forum', - ), + table_name="undeclared-table", item=dict(hash_key="LOLCat Forum") ).should.throw(DynamoDBResponseError) @@ -138,9 +128,9 @@ def test_get_missing_item(): conn = boto.connect_dynamodb() table = create_table(conn) - table.get_item.when.called_with( - hash_key='tester', - ).should.throw(DynamoDBKeyNotFoundError) + table.get_item.when.called_with(hash_key="tester").should.throw( + DynamoDBKeyNotFoundError + ) @mock_dynamodb_deprecated @@ -148,10 +138,7 @@ def test_get_item_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.get_item.when.called_with( - table_name='undeclared-table', - key={ - 'HashKeyElement': {'S': 'tester'}, - }, + table_name="undeclared-table", key={"HashKeyElement": {"S": "tester"}} ).should.throw(DynamoDBKeyNotFoundError) @@ -161,21 +148,18 @@ def test_delete_item(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='LOLCat Forum', - attrs=item_data, - ) + item = table.new_item(hash_key="LOLCat Forum", attrs=item_data) item.put() table.refresh() table.item_count.should.equal(1) response = item.delete() - response.should.equal({u'Attributes': [], u'ConsumedCapacityUnits': 0.5}) + response.should.equal({"Attributes": [], "ConsumedCapacityUnits": 0.5}) table.refresh() table.item_count.should.equal(0) @@ -188,29 +172,28 @@ def test_delete_item_with_attribute_response(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='LOLCat Forum', - attrs=item_data, - ) + item = table.new_item(hash_key="LOLCat Forum", attrs=item_data) item.put() table.refresh() table.item_count.should.equal(1) - response = item.delete(return_values='ALL_OLD') - response.should.equal({ - u'Attributes': { - u'Body': u'http://url_to_lolcat.gif', - u'forum_name': u'LOLCat Forum', - u'ReceivedTime': u'12/9/2011 11:36:03 PM', - u'SentBy': u'User A', - }, - u'ConsumedCapacityUnits': 0.5 - }) + response = item.delete(return_values="ALL_OLD") + response.should.equal( + { + "Attributes": { + "Body": "http://url_to_lolcat.gif", + "forum_name": "LOLCat Forum", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "SentBy": "User A", + }, + "ConsumedCapacityUnits": 0.5, + } + ) table.refresh() table.item_count.should.equal(0) @@ -222,10 +205,7 @@ def test_delete_item_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.delete_item.when.called_with( - table_name='undeclared-table', - key={ - 'HashKeyElement': {'S': 'tester'}, - }, + table_name="undeclared-table", key={"HashKeyElement": {"S": "tester"}} ).should.throw(DynamoDBResponseError) @@ -235,18 +215,15 @@ def test_query(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", attrs=item_data) item.put() - results = table.query(hash_key='the-key') - results.response['Items'].should.have.length_of(1) + results = table.query(hash_key="the-key") + results.response["Items"].should.have.length_of(1) @mock_dynamodb_deprecated @@ -254,8 +231,7 @@ def test_query_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.query.when.called_with( - table_name='undeclared-table', - hash_key_value={'S': 'the-key'}, + table_name="undeclared-table", hash_key_value={"S": "the-key"} ).should.throw(DynamoDBResponseError) @@ -265,58 +241,49 @@ def test_scan(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key2', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key2", attrs=item_data) item.put() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item = table.new_item( - hash_key='the-key3', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key3", attrs=item_data) item.put() results = table.scan() - results.response['Items'].should.have.length_of(3) + results.response["Items"].should.have.length_of(3) - results = table.scan(scan_filter={'SentBy': condition.EQ('User B')}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"SentBy": condition.EQ("User B")}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Body': condition.BEGINS_WITH('http')}) - results.response['Items'].should.have.length_of(3) + results = table.scan(scan_filter={"Body": condition.BEGINS_WITH("http")}) + results.response["Items"].should.have.length_of(3) - results = table.scan(scan_filter={'Ids': condition.CONTAINS(2)}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"Ids": condition.CONTAINS(2)}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Ids': condition.NOT_NULL()}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"Ids": condition.NOT_NULL()}) + results.response["Items"].should.have.length_of(1) - results = table.scan(scan_filter={'Ids': condition.NULL()}) - results.response['Items'].should.have.length_of(2) + results = table.scan(scan_filter={"Ids": condition.NULL()}) + results.response["Items"].should.have.length_of(2) - results = table.scan(scan_filter={'PK': condition.BETWEEN(8, 9)}) - results.response['Items'].should.have.length_of(0) + results = table.scan(scan_filter={"PK": condition.BETWEEN(8, 9)}) + results.response["Items"].should.have.length_of(0) - results = table.scan(scan_filter={'PK': condition.BETWEEN(5, 8)}) - results.response['Items'].should.have.length_of(1) + results = table.scan(scan_filter={"PK": condition.BETWEEN(5, 8)}) + results.response["Items"].should.have.length_of(1) @mock_dynamodb_deprecated @@ -324,13 +291,11 @@ def test_scan_with_undeclared_table(): conn = boto.connect_dynamodb() conn.layer1.scan.when.called_with( - table_name='undeclared-table', + table_name="undeclared-table", scan_filter={ "SentBy": { - "AttributeValueList": [{ - "S": "User B"} - ], - "ComparisonOperator": "EQ" + "AttributeValueList": [{"S": "User B"}], + "ComparisonOperator": "EQ", } }, ).should.throw(DynamoDBResponseError) @@ -342,7 +307,7 @@ def test_scan_after_has_item(): table = create_table(conn) list(table.scan()).should.equal([]) - table.has_item('the-key') + table.has_item("the-key") list(table.scan()).should.equal([]) @@ -355,25 +320,29 @@ def test_write_batch(): batch_list = conn.new_batch_write_list() items = [] - items.append(table.new_item( - hash_key='the-key', - attrs={ - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }, - )) + items.append( + table.new_item( + hash_key="the-key", + attrs={ + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + }, + ) + ) - items.append(table.new_item( - hash_key='the-key2', - attrs={ - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, - }, - )) + items.append( + table.new_item( + hash_key="the-key2", + attrs={ + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, + }, + ) + ) batch_list.add_batch(table, puts=items) conn.batch_write_item(batch_list) @@ -382,7 +351,7 @@ def test_write_batch(): table.item_count.should.equal(2) batch_list = conn.new_batch_write_list() - batch_list.add_batch(table, deletes=[('the-key')]) + batch_list.add_batch(table, deletes=[("the-key")]) conn.batch_write_item(batch_list) table.refresh() @@ -395,36 +364,27 @@ def test_batch_read(): table = create_table(conn) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item = table.new_item( - hash_key='the-key1', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key1", attrs=item_data) item.put() - item = table.new_item( - hash_key='the-key2', - attrs=item_data, - ) + item = table.new_item(hash_key="the-key2", attrs=item_data) item.put() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item = table.new_item( - hash_key='another-key', - attrs=item_data, - ) + item = table.new_item(hash_key="another-key", attrs=item_data) item.put() - items = table.batch_get_item([('the-key1'), ('another-key')]) + items = table.batch_get_item([("the-key1"), ("another-key")]) # Iterate through so that batch_item gets called count = len([x for x in items]) count.should.have.equal(2) diff --git a/tests/test_dynamodb/test_server.py b/tests/test_dynamodb/test_server.py index 66004bbe1..310643628 100644 --- a/tests/test_dynamodb/test_server.py +++ b/tests/test_dynamodb/test_server.py @@ -3,18 +3,18 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_table_list(): backend = server.create_backend_app("dynamodb") test_client = backend.test_client() - res = test_client.get('/') + res = test_client.get("/") res.status_code.should.equal(404) - headers = {'X-Amz-Target': 'TestTable.ListTables'} - res = test_client.get('/', headers=headers) - res.data.should.contain(b'TableNames') + headers = {"X-Amz-Target": "TestTable.ListTables"} + res = test_client.get("/", headers=headers) + res.data.should.contain(b"TableNames") diff --git a/tests/test_dynamodb2/test_dynamodb.py b/tests/test_dynamodb2/test_dynamodb.py index ca8a7c935..f6ed4f13d 100644 --- a/tests/test_dynamodb2/test_dynamodb.py +++ b/tests/test_dynamodb2/test_dynamodb.py @@ -19,6 +19,7 @@ import moto.dynamodb2.comparisons import moto.dynamodb2.models from nose.tools import assert_raises + try: import boto.dynamodb2 except ImportError: @@ -28,16 +29,18 @@ except ImportError: @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_list_tables(): - name = 'TestTable' + name = "TestTable" # Should make tables properly with boto - dynamodb_backend2.create_table(name, schema=[ - {u'KeyType': u'HASH', u'AttributeName': u'forum_name'}, - {u'KeyType': u'RANGE', u'AttributeName': u'subject'} - ]) + dynamodb_backend2.create_table( + name, + schema=[ + {"KeyType": "HASH", "AttributeName": "forum_name"}, + {"KeyType": "RANGE", "AttributeName": "subject"}, + ], + ) conn = boto.dynamodb2.connect_to_region( - 'us-east-1', - aws_access_key_id="ak", - aws_secret_access_key="sk") + "us-east-1", aws_access_key_id="ak", aws_secret_access_key="sk" + ) assert conn.list_tables()["TableNames"] == [name] @@ -45,16 +48,15 @@ def test_list_tables(): @mock_dynamodb2_deprecated def test_list_tables_layer_1(): # Should make tables properly with boto - dynamodb_backend2.create_table("test_1", schema=[ - {u'KeyType': u'HASH', u'AttributeName': u'name'} - ]) - dynamodb_backend2.create_table("test_2", schema=[ - {u'KeyType': u'HASH', u'AttributeName': u'name'} - ]) + dynamodb_backend2.create_table( + "test_1", schema=[{"KeyType": "HASH", "AttributeName": "name"}] + ) + dynamodb_backend2.create_table( + "test_2", schema=[{"KeyType": "HASH", "AttributeName": "name"}] + ) conn = boto.dynamodb2.connect_to_region( - 'us-east-1', - aws_access_key_id="ak", - aws_secret_access_key="sk") + "us-east-1", aws_access_key_id="ak", aws_secret_access_key="sk" + ) res = conn.list_tables(limit=1) expected = {"TableNames": ["test_1"], "LastEvaluatedTableName": "test_1"} @@ -69,30 +71,36 @@ def test_list_tables_layer_1(): @mock_dynamodb2_deprecated def test_describe_missing_table(): conn = boto.dynamodb2.connect_to_region( - 'us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") + "us-west-2", aws_access_key_id="ak", aws_secret_access_key="sk" + ) with assert_raises(JSONResponseError): - conn.describe_table('messages') + conn.describe_table("messages") @requires_boto_gte("2.9") @mock_dynamodb2 def test_list_table_tags(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'id','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'id','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) table_description = conn.describe_table(TableName=name) - arn = table_description['Table']['TableArn'] + arn = table_description["Table"]["TableArn"] # Tag table - tags = [{'Key': 'TestTag', 'Value': 'TestValue'}, {'Key': 'TestTag2', 'Value': 'TestValue2'}] + tags = [ + {"Key": "TestTag", "Value": "TestValue"}, + {"Key": "TestTag2", "Value": "TestValue2"}, + ] conn.tag_resource(ResourceArn=arn, Tags=tags) # Check tags @@ -100,28 +108,32 @@ def test_list_table_tags(): assert resp["Tags"] == tags # Remove 1 tag - conn.untag_resource(ResourceArn=arn, TagKeys=['TestTag']) + conn.untag_resource(ResourceArn=arn, TagKeys=["TestTag"]) # Check tags resp = conn.list_tags_of_resource(ResourceArn=arn) - assert resp["Tags"] == [{'Key': 'TestTag2', 'Value': 'TestValue2'}] + assert resp["Tags"] == [{"Key": "TestTag2", "Value": "TestValue2"}] @requires_boto_gte("2.9") @mock_dynamodb2 def test_list_table_tags_empty(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'id','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'id','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) table_description = conn.describe_table(TableName=name) - arn = table_description['Table']['TableArn'] - tags = [{'Key':'TestTag', 'Value': 'TestValue'}] + arn = table_description["Table"]["TableArn"] + tags = [{"Key": "TestTag", "Value": "TestValue"}] # conn.tag_resource(ResourceArn=arn, # Tags=tags) resp = conn.list_tags_of_resource(ResourceArn=arn) @@ -131,914 +143,786 @@ def test_list_table_tags_empty(): @requires_boto_gte("2.9") @mock_dynamodb2 def test_list_table_tags_paginated(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'id','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'id','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) table_description = conn.describe_table(TableName=name) - arn = table_description['Table']['TableArn'] + arn = table_description["Table"]["TableArn"] for i in range(11): - tags = [{'Key':'TestTag%d' % i, 'Value': 'TestValue'}] - conn.tag_resource(ResourceArn=arn, - Tags=tags) + tags = [{"Key": "TestTag%d" % i, "Value": "TestValue"}] + conn.tag_resource(ResourceArn=arn, Tags=tags) resp = conn.list_tags_of_resource(ResourceArn=arn) assert len(resp["Tags"]) == 10 - assert 'NextToken' in resp.keys() - resp2 = conn.list_tags_of_resource(ResourceArn=arn, - NextToken=resp['NextToken']) + assert "NextToken" in resp.keys() + resp2 = conn.list_tags_of_resource(ResourceArn=arn, NextToken=resp["NextToken"]) assert len(resp2["Tags"]) == 1 - assert 'NextToken' not in resp2.keys() + assert "NextToken" not in resp2.keys() @requires_boto_gte("2.9") @mock_dynamodb2 def test_list_not_found_table_tags(): - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - arn = 'DymmyArn' + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + arn = "DymmyArn" try: conn.list_tags_of_resource(ResourceArn=arn) except ClientError as exception: - assert exception.response['Error']['Code'] == "ResourceNotFoundException" + assert exception.response["Error"]["Code"] == "ResourceNotFoundException" @requires_boto_gte("2.9") @mock_dynamodb2 def test_item_add_empty_string_exception(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'forum_name','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'forum_name','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) with assert_raises(ClientError) as ex: conn.put_item( TableName=name, Item={ - 'forum_name': { 'S': 'LOLCat Forum' }, - 'subject': { 'S': 'Check this out!' }, - 'Body': { 'S': 'http://url_to_lolcat.gif'}, - 'SentBy': { 'S': "" }, - 'ReceivedTime': { 'S': '12/9/2011 11:36:03 PM'}, - } + "forum_name": {"S": "LOLCat Forum"}, + "subject": {"S": "Check this out!"}, + "Body": {"S": "http://url_to_lolcat.gif"}, + "SentBy": {"S": ""}, + "ReceivedTime": {"S": "12/9/2011 11:36:03 PM"}, + }, ) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'One or more parameter values were invalid: An AttributeValue may not contain an empty string' + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "One or more parameter values were invalid: An AttributeValue may not contain an empty string" ) @requires_boto_gte("2.9") @mock_dynamodb2 def test_update_item_with_empty_string_exception(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'forum_name','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'forum_name','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) conn.put_item( TableName=name, Item={ - 'forum_name': { 'S': 'LOLCat Forum' }, - 'subject': { 'S': 'Check this out!' }, - 'Body': { 'S': 'http://url_to_lolcat.gif'}, - 'SentBy': { 'S': "test" }, - 'ReceivedTime': { 'S': '12/9/2011 11:36:03 PM'}, - } + "forum_name": {"S": "LOLCat Forum"}, + "subject": {"S": "Check this out!"}, + "Body": {"S": "http://url_to_lolcat.gif"}, + "SentBy": {"S": "test"}, + "ReceivedTime": {"S": "12/9/2011 11:36:03 PM"}, + }, ) with assert_raises(ClientError) as ex: conn.update_item( TableName=name, - Key={ - 'forum_name': { 'S': 'LOLCat Forum'}, - }, - UpdateExpression='set Body=:Body', - ExpressionAttributeValues={ - ':Body': {'S': ''} - }) + Key={"forum_name": {"S": "LOLCat Forum"}}, + UpdateExpression="set Body=:Body", + ExpressionAttributeValues={":Body": {"S": ""}}, + ) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'One or more parameter values were invalid: An AttributeValue may not contain an empty string' + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "One or more parameter values were invalid: An AttributeValue may not contain an empty string" ) @requires_boto_gte("2.9") @mock_dynamodb2 def test_query_invalid_table(): - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) try: - conn.query(TableName='invalid_table', KeyConditionExpression='index1 = :partitionkeyval', ExpressionAttributeValues={':partitionkeyval': {'S':'test'}}) + conn.query( + TableName="invalid_table", + KeyConditionExpression="index1 = :partitionkeyval", + ExpressionAttributeValues={":partitionkeyval": {"S": "test"}}, + ) except ClientError as exception: - assert exception.response['Error']['Code'] == "ResourceNotFoundException" + assert exception.response["Error"]["Code"] == "ResourceNotFoundException" @requires_boto_gte("2.9") @mock_dynamodb2 def test_scan_returns_consumed_capacity(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") - - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'forum_name','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'forum_name','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) - - conn.put_item( - TableName=name, - Item={ - 'forum_name': { 'S': 'LOLCat Forum' }, - 'subject': { 'S': 'Check this out!' }, - 'Body': { 'S': 'http://url_to_lolcat.gif'}, - 'SentBy': { 'S': "test" }, - 'ReceivedTime': { 'S': '12/9/2011 11:36:03 PM'}, - } - ) - - response = conn.scan( - TableName=name, + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", ) - assert 'ConsumedCapacity' in response - assert 'CapacityUnits' in response['ConsumedCapacity'] - assert response['ConsumedCapacity']['TableName'] == name + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + + conn.put_item( + TableName=name, + Item={ + "forum_name": {"S": "LOLCat Forum"}, + "subject": {"S": "Check this out!"}, + "Body": {"S": "http://url_to_lolcat.gif"}, + "SentBy": {"S": "test"}, + "ReceivedTime": {"S": "12/9/2011 11:36:03 PM"}, + }, + ) + + response = conn.scan(TableName=name) + + assert "ConsumedCapacity" in response + assert "CapacityUnits" in response["ConsumedCapacity"] + assert response["ConsumedCapacity"]["TableName"] == name @requires_boto_gte("2.9") @mock_dynamodb2 def test_put_item_with_special_chars(): - name = 'TestTable' - conn = boto3.client('dynamodb', - region_name='us-west-2', - aws_access_key_id="ak", - aws_secret_access_key="sk") + name = "TestTable" + conn = boto3.client( + "dynamodb", + region_name="us-west-2", + aws_access_key_id="ak", + aws_secret_access_key="sk", + ) - conn.create_table(TableName=name, - KeySchema=[{'AttributeName':'forum_name','KeyType':'HASH'}], - AttributeDefinitions=[{'AttributeName':'forum_name','AttributeType':'S'}], - ProvisionedThroughput={'ReadCapacityUnits':5,'WriteCapacityUnits':5}) + conn.create_table( + TableName=name, + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) conn.put_item( - TableName=name, - Item={ - 'forum_name': { 'S': 'LOLCat Forum' }, - 'subject': { 'S': 'Check this out!' }, - 'Body': { 'S': 'http://url_to_lolcat.gif'}, - 'SentBy': { 'S': "test" }, - 'ReceivedTime': { 'S': '12/9/2011 11:36:03 PM'}, - '"': {"S": "foo"}, - } - ) + TableName=name, + Item={ + "forum_name": {"S": "LOLCat Forum"}, + "subject": {"S": "Check this out!"}, + "Body": {"S": "http://url_to_lolcat.gif"}, + "SentBy": {"S": "test"}, + "ReceivedTime": {"S": "12/9/2011 11:36:03 PM"}, + '"': {"S": "foo"}, + }, + ) @requires_boto_gte("2.9") @mock_dynamodb2 def test_query_returns_consumed_capacity(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message' - }) - - results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} ) - assert 'ConsumedCapacity' in results - assert 'CapacityUnits' in results['ConsumedCapacity'] - assert results['ConsumedCapacity']['CapacityUnits'] == 1 + results = table.query(KeyConditionExpression=Key("forum_name").eq("the-key")) + + assert "ConsumedCapacity" in results + assert "CapacityUnits" in results["ConsumedCapacity"] + assert results["ConsumedCapacity"]["CapacityUnits"] == 1 @mock_dynamodb2 def test_basic_projection_expression_using_get_item(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) + + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", } ) - table = dynamodb.Table('users') - - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message' - }) - - table.put_item(Item={ - 'forum_name': 'not-the-key', - 'subject': '123', - 'body': 'some other test message' - }) result = table.get_item( - Key = { - 'forum_name': 'the-key', - 'subject': '123' - }, - ProjectionExpression='body, subject' + Key={"forum_name": "the-key", "subject": "123"}, + ProjectionExpression="body, subject", ) - result['Item'].should.be.equal({ - 'subject': '123', - 'body': 'some test message' - }) + result["Item"].should.be.equal({"subject": "123", "body": "some test message"}) # The projection expression should not remove data from storage - result = table.get_item( - Key = { - 'forum_name': 'the-key', - 'subject': '123' - } - ) + result = table.get_item(Key={"forum_name": "the-key", "subject": "123"}) - result['Item'].should.be.equal({ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message' - }) + result["Item"].should.be.equal( + {"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) @mock_dynamodb2 def test_basic_projection_expressions_using_query(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) + + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", } ) - table = dynamodb.Table('users') - - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message' - }) - - table.put_item(Item={ - 'forum_name': 'not-the-key', - 'subject': '123', - 'body': 'some other test message' - }) # Test a query returning all items results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='body, subject' + KeyConditionExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="body, subject", ) - assert 'body' in results['Items'][0] - assert results['Items'][0]['body'] == 'some test message' - assert 'subject' in results['Items'][0] + assert "body" in results["Items"][0] + assert results["Items"][0]["body"] == "some test message" + assert "subject" in results["Items"][0] - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '1234', - 'body': 'yet another test message' - }) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "1234", + "body": "yet another test message", + } + ) results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='body' + KeyConditionExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="body", ) - assert 'body' in results['Items'][0] - assert 'subject' not in results['Items'][0] - assert results['Items'][0]['body'] == 'some test message' - assert 'body' in results['Items'][1] - assert 'subject' not in results['Items'][1] - assert results['Items'][1]['body'] == 'yet another test message' + assert "body" in results["Items"][0] + assert "subject" not in results["Items"][0] + assert results["Items"][0]["body"] == "some test message" + assert "body" in results["Items"][1] + assert "subject" not in results["Items"][1] + assert results["Items"][1]["body"] == "yet another test message" # The projection expression should not remove data from storage - results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key'), - ) - assert 'subject' in results['Items'][0] - assert 'body' in results['Items'][1] - assert 'forum_name' in results['Items'][1] + results = table.query(KeyConditionExpression=Key("forum_name").eq("the-key")) + assert "subject" in results["Items"][0] + assert "body" in results["Items"][1] + assert "forum_name" in results["Items"][1] @mock_dynamodb2 def test_basic_projection_expressions_using_scan(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) + + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", } ) - table = dynamodb.Table('users') - - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message' - }) - - table.put_item(Item={ - 'forum_name': 'not-the-key', - 'subject': '123', - 'body': 'some other test message' - }) # Test a scan returning all items results = table.scan( - FilterExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='body, subject' + FilterExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="body, subject", ) - assert 'body' in results['Items'][0] - assert results['Items'][0]['body'] == 'some test message' - assert 'subject' in results['Items'][0] + assert "body" in results["Items"][0] + assert results["Items"][0]["body"] == "some test message" + assert "subject" in results["Items"][0] - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '1234', - 'body': 'yet another test message' - }) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "1234", + "body": "yet another test message", + } + ) results = table.scan( - FilterExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='body' + FilterExpression=Key("forum_name").eq("the-key"), ProjectionExpression="body" ) - assert 'body' in results['Items'][0] - assert 'subject' not in results['Items'][0] - assert 'forum_name' not in results['Items'][0] - assert 'body' in results['Items'][1] - assert 'subject' not in results['Items'][1] - assert 'forum_name' not in results['Items'][1] + assert "body" in results["Items"][0] + assert "subject" not in results["Items"][0] + assert "forum_name" not in results["Items"][0] + assert "body" in results["Items"][1] + assert "subject" not in results["Items"][1] + assert "forum_name" not in results["Items"][1] # The projection expression should not remove data from storage - results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key'), - ) - assert 'subject' in results['Items'][0] - assert 'body' in results['Items'][1] - assert 'forum_name' in results['Items'][1] + results = table.query(KeyConditionExpression=Key("forum_name").eq("the-key")) + assert "subject" in results["Items"][0] + assert "body" in results["Items"][1] + assert "forum_name" in results["Items"][1] @mock_dynamodb2 def test_basic_projection_expression_using_get_item_with_attr_expression_names(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "body": "some test message", + "attachment": "something", } ) - table = dynamodb.Table('users') - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message', - 'attachment': 'something' - }) - - table.put_item(Item={ - 'forum_name': 'not-the-key', - 'subject': '123', - 'body': 'some other test message', - 'attachment': 'something' - }) + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", + "attachment": "something", + } + ) result = table.get_item( - Key={ - 'forum_name': 'the-key', - 'subject': '123' - }, - ProjectionExpression='#rl, #rt, subject', - ExpressionAttributeNames={ - '#rl': 'body', - '#rt': 'attachment' - }, + Key={"forum_name": "the-key", "subject": "123"}, + ProjectionExpression="#rl, #rt, subject", + ExpressionAttributeNames={"#rl": "body", "#rt": "attachment"}, ) - result['Item'].should.be.equal({ - 'subject': '123', - 'body': 'some test message', - 'attachment': 'something' - }) + result["Item"].should.be.equal( + {"subject": "123", "body": "some test message", "attachment": "something"} + ) @mock_dynamodb2 def test_basic_projection_expressions_using_query_with_attr_expression_names(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "body": "some test message", + "attachment": "something", } ) - table = dynamodb.Table('users') - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message', - 'attachment': 'something' - }) - - table.put_item(Item={ - 'forum_name': 'not-the-key', - 'subject': '123', - 'body': 'some other test message', - 'attachment': 'something' - }) + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", + "attachment": "something", + } + ) # Test a query returning all items results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='#rl, #rt, subject', - ExpressionAttributeNames={ - '#rl': 'body', - '#rt': 'attachment' - }, + KeyConditionExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="#rl, #rt, subject", + ExpressionAttributeNames={"#rl": "body", "#rt": "attachment"}, ) - assert 'body' in results['Items'][0] - assert results['Items'][0]['body'] == 'some test message' - assert 'subject' in results['Items'][0] - assert results['Items'][0]['subject'] == '123' - assert 'attachment' in results['Items'][0] - assert results['Items'][0]['attachment'] == 'something' + assert "body" in results["Items"][0] + assert results["Items"][0]["body"] == "some test message" + assert "subject" in results["Items"][0] + assert results["Items"][0]["subject"] == "123" + assert "attachment" in results["Items"][0] + assert results["Items"][0]["attachment"] == "something" @mock_dynamodb2 def test_basic_projection_expressions_using_scan_with_attr_expression_names(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "body": "some test message", + "attachment": "something", } ) - table = dynamodb.Table('users') - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message', - 'attachment': 'something' - }) - - table.put_item(Item={ - 'forum_name': 'not-the-key', - 'subject': '123', - 'body': 'some other test message', - 'attachment': 'something' - }) + table.put_item( + Item={ + "forum_name": "not-the-key", + "subject": "123", + "body": "some other test message", + "attachment": "something", + } + ) # Test a scan returning all items results = table.scan( - FilterExpression=Key('forum_name').eq( - 'the-key'), - ProjectionExpression='#rl, #rt, subject', - ExpressionAttributeNames={ - '#rl': 'body', - '#rt': 'attachment' - }, + FilterExpression=Key("forum_name").eq("the-key"), + ProjectionExpression="#rl, #rt, subject", + ExpressionAttributeNames={"#rl": "body", "#rt": "attachment"}, ) - assert 'body' in results['Items'][0] - assert 'attachment' in results['Items'][0] - assert 'subject' in results['Items'][0] - assert 'form_name' not in results['Items'][0] + assert "body" in results["Items"][0] + assert "attachment" in results["Items"][0] + assert "subject" in results["Items"][0] + assert "form_name" not in results["Items"][0] # Test without a FilterExpression results = table.scan( - ProjectionExpression='#rl, #rt, subject', - ExpressionAttributeNames={ - '#rl': 'body', - '#rt': 'attachment' - }, + ProjectionExpression="#rl, #rt, subject", + ExpressionAttributeNames={"#rl": "body", "#rt": "attachment"}, ) - assert 'body' in results['Items'][0] - assert 'attachment' in results['Items'][0] - assert 'subject' in results['Items'][0] - assert 'form_name' not in results['Items'][0] + assert "body" in results["Items"][0] + assert "attachment" in results["Items"][0] + assert "subject" in results["Items"][0] + assert "form_name" not in results["Items"][0] @mock_dynamodb2 def test_put_item_returns_consumed_capacity(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - response = table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message', - }) + response = table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) - assert 'ConsumedCapacity' in response + assert "ConsumedCapacity" in response @mock_dynamodb2 def test_update_item_returns_consumed_capacity(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message', - }) + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) - response = table.update_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }, - UpdateExpression='set body=:tb', - ExpressionAttributeValues={ - ':tb': 'a new message' - }) + response = table.update_item( + Key={"forum_name": "the-key", "subject": "123"}, + UpdateExpression="set body=:tb", + ExpressionAttributeValues={":tb": "a new message"}, + ) - assert 'ConsumedCapacity' in response - assert 'CapacityUnits' in response['ConsumedCapacity'] - assert 'TableName' in response['ConsumedCapacity'] + assert "ConsumedCapacity" in response + assert "CapacityUnits" in response["ConsumedCapacity"] + assert "TableName" in response["ConsumedCapacity"] @mock_dynamodb2 def test_get_item_returns_consumed_capacity(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': 'some test message', - }) + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "body": "some test message"} + ) - response = table.get_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }) + response = table.get_item(Key={"forum_name": "the-key", "subject": "123"}) - assert 'ConsumedCapacity' in response - assert 'CapacityUnits' in response['ConsumedCapacity'] - assert 'TableName' in response['ConsumedCapacity'] + assert "ConsumedCapacity" in response + assert "CapacityUnits" in response["ConsumedCapacity"] + assert "TableName" in response["ConsumedCapacity"] def test_filter_expression(): - row1 = moto.dynamodb2.models.Item(None, None, None, None, {'Id': {'N': '8'}, 'Subs': {'N': '5'}, 'Desc': {'S': 'Some description'}, 'KV': {'SS': ['test1', 'test2']}}) - row2 = moto.dynamodb2.models.Item(None, None, None, None, {'Id': {'N': '8'}, 'Subs': {'N': '10'}, 'Desc': {'S': 'A description'}, 'KV': {'SS': ['test3', 'test4']}}) + row1 = moto.dynamodb2.models.Item( + None, + None, + None, + None, + { + "Id": {"N": "8"}, + "Subs": {"N": "5"}, + "Desc": {"S": "Some description"}, + "KV": {"SS": ["test1", "test2"]}, + }, + ) + row2 = moto.dynamodb2.models.Item( + None, + None, + None, + None, + { + "Id": {"N": "8"}, + "Subs": {"N": "10"}, + "Desc": {"S": "A description"}, + "KV": {"SS": ["test3", "test4"]}, + }, + ) # NOT test 1 - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT attribute_not_exists(Id)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "NOT attribute_not_exists(Id)", {}, {} + ) filter_expr.expr(row1).should.be(True) # NOT test 2 - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('NOT (Id = :v0)', {}, {':v0': {'N': '8'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "NOT (Id = :v0)", {}, {":v0": {"N": "8"}} + ) filter_expr.expr(row1).should.be(False) # Id = 8 so should be false # AND test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id > :v0 AND Subs < :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '7'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "Id > :v0 AND Subs < :v1", {}, {":v0": {"N": "5"}, ":v1": {"N": "7"}} + ) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # OR test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 OR Id=:v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '8'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "Id = :v0 OR Id=:v1", {}, {":v0": {"N": "5"}, ":v1": {"N": "8"}} + ) filter_expr.expr(row1).should.be(True) # BETWEEN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id BETWEEN :v0 AND :v1', {}, {':v0': {'N': '5'}, ':v1': {'N': '10'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "Id BETWEEN :v0 AND :v1", {}, {":v0": {"N": "5"}, ":v1": {"N": "10"}} + ) filter_expr.expr(row1).should.be(True) # PAREN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id = :v0 AND (Subs = :v0 OR Subs = :v1)', {}, {':v0': {'N': '8'}, ':v1': {'N': '5'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "Id = :v0 AND (Subs = :v0 OR Subs = :v1)", + {}, + {":v0": {"N": "8"}, ":v1": {"N": "5"}}, + ) filter_expr.expr(row1).should.be(True) # IN test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('Id IN (:v0, :v1, :v2)', {}, { - ':v0': {'N': '7'}, - ':v1': {'N': '8'}, - ':v2': {'N': '9'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "Id IN (:v0, :v1, :v2)", + {}, + {":v0": {"N": "7"}, ":v1": {"N": "8"}, ":v2": {"N": "9"}}, + ) filter_expr.expr(row1).should.be(True) # attribute function tests (with extra spaces) - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_exists(Id) AND attribute_not_exists (User)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "attribute_exists(Id) AND attribute_not_exists (User)", {}, {} + ) filter_expr.expr(row1).should.be(True) - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('attribute_type(Id, :v0)', {}, {':v0': {'S': 'N'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "attribute_type(Id, :v0)", {}, {":v0": {"S": "N"}} + ) filter_expr.expr(row1).should.be(True) # beginswith function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('begins_with(Desc, :v0)', {}, {':v0': {'S': 'Some'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "begins_with(Desc, :v0)", {}, {":v0": {"S": "Some"}} + ) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # contains function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('contains(KV, :v0)', {}, {':v0': {'S': 'test1'}}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "contains(KV, :v0)", {}, {":v0": {"S": "test1"}} + ) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) # size function test - filter_expr = moto.dynamodb2.comparisons.get_filter_expression('size(Desc) > size(KV)', {}, {}) + filter_expr = moto.dynamodb2.comparisons.get_filter_expression( + "size(Desc) > size(KV)", {}, {} + ) filter_expr.expr(row1).should.be(True) # Expression from @batkuip filter_expr = moto.dynamodb2.comparisons.get_filter_expression( - '(#n0 < :v0 AND attribute_not_exists(#n1))', - {'#n0': 'Subs', '#n1': 'fanout_ts'}, - {':v0': {'N': '7'}} + "(#n0 < :v0 AND attribute_not_exists(#n1))", + {"#n0": "Subs", "#n1": "fanout_ts"}, + {":v0": {"N": "7"}}, ) filter_expr.expr(row1).should.be(True) # Expression from to check contains on string value filter_expr = moto.dynamodb2.comparisons.get_filter_expression( - 'contains(#n0, :v0)', - {'#n0': 'Desc'}, - {':v0': {'S': 'Some'}} + "contains(#n0, :v0)", {"#n0": "Desc"}, {":v0": {"S": "Some"}} ) filter_expr.expr(row1).should.be(True) filter_expr.expr(row2).should.be(False) @@ -1046,1143 +930,1074 @@ def test_filter_expression(): @mock_dynamodb2 def test_query_filter(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'nested': {'M': { - 'version': {'S': 'version1'}, - 'contents': {'L': [ - {'S': 'value1'}, {'S': 'value2'}, - ]}, - }}, - } + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "nested": { + "M": { + "version": {"S": "version1"}, + "contents": {"L": [{"S": "value1"}, {"S": "value2"}]}, + } + }, + }, ) client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app2'}, - 'nested': {'M': { - 'version': {'S': 'version2'}, - 'contents': {'L': [ - {'S': 'value1'}, {'S': 'value2'}, - ]}, - }}, - } + "client": {"S": "client1"}, + "app": {"S": "app2"}, + "nested": { + "M": { + "version": {"S": "version2"}, + "contents": {"L": [{"S": "value1"}, {"S": "value2"}]}, + } + }, + }, ) - table = dynamodb.Table('test1') - response = table.query( - KeyConditionExpression=Key('client').eq('client1') - ) - assert response['Count'] == 2 + table = dynamodb.Table("test1") + response = table.query(KeyConditionExpression=Key("client").eq("client1")) + assert response["Count"] == 2 response = table.query( - KeyConditionExpression=Key('client').eq('client1'), - FilterExpression=Attr('app').eq('app2') + KeyConditionExpression=Key("client").eq("client1"), + FilterExpression=Attr("app").eq("app2"), ) - assert response['Count'] == 1 - assert response['Items'][0]['app'] == 'app2' + assert response["Count"] == 1 + assert response["Items"][0]["app"] == "app2" response = table.query( - KeyConditionExpression=Key('client').eq('client1'), - FilterExpression=Attr('app').contains('app') + KeyConditionExpression=Key("client").eq("client1"), + FilterExpression=Attr("app").contains("app"), ) - assert response['Count'] == 2 + assert response["Count"] == 2 response = table.query( - KeyConditionExpression=Key('client').eq('client1'), - FilterExpression=Attr('nested.version').contains('version') + KeyConditionExpression=Key("client").eq("client1"), + FilterExpression=Attr("nested.version").contains("version"), ) - assert response['Count'] == 2 + assert response["Count"] == 2 response = table.query( - KeyConditionExpression=Key('client').eq('client1'), - FilterExpression=Attr('nested.contents[0]').eq('value1') + KeyConditionExpression=Key("client").eq("client1"), + FilterExpression=Attr("nested.contents[0]").eq("value1"), ) - assert response['Count'] == 2 + assert response["Count"] == 2 @mock_dynamodb2 def test_query_filter_overlapping_expression_prefixes(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'nested': {'M': { - 'version': {'S': 'version1'}, - 'contents': {'L': [ - {'S': 'value1'}, {'S': 'value2'}, - ]}, - }}, - }) - - table = dynamodb.Table('test1') - response = table.query( - KeyConditionExpression=Key('client').eq('client1') & Key('app').eq('app1'), - ProjectionExpression='#1, #10, nested', - ExpressionAttributeNames={ - '#1': 'client', - '#10': 'app', - } + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "nested": { + "M": { + "version": {"S": "version1"}, + "contents": {"L": [{"S": "value1"}, {"S": "value2"}]}, + } + }, + }, ) - assert response['Count'] == 1 - assert response['Items'][0] == { - 'client': 'client1', - 'app': 'app1', - 'nested': { - 'version': 'version1', - 'contents': ['value1', 'value2'] - } + table = dynamodb.Table("test1") + response = table.query( + KeyConditionExpression=Key("client").eq("client1") & Key("app").eq("app1"), + ProjectionExpression="#1, #10, nested", + ExpressionAttributeNames={"#1": "client", "#10": "app"}, + ) + + assert response["Count"] == 1 + assert response["Items"][0] == { + "client": "client1", + "app": "app1", + "nested": {"version": "version1", "contents": ["value1", "value2"]}, } @mock_dynamodb2 def test_scan_filter(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'} - } + TableName="test1", Item={"client": {"S": "client1"}, "app": {"S": "app1"}} ) - table = dynamodb.Table('test1') - response = table.scan( - FilterExpression=Attr('app').eq('app2') - ) - assert response['Count'] == 0 + table = dynamodb.Table("test1") + response = table.scan(FilterExpression=Attr("app").eq("app2")) + assert response["Count"] == 0 - response = table.scan( - FilterExpression=Attr('app').eq('app1') - ) - assert response['Count'] == 1 + response = table.scan(FilterExpression=Attr("app").eq("app1")) + assert response["Count"] == 1 - response = table.scan( - FilterExpression=Attr('app').ne('app2') - ) - assert response['Count'] == 1 + response = table.scan(FilterExpression=Attr("app").ne("app2")) + assert response["Count"] == 1 - response = table.scan( - FilterExpression=Attr('app').ne('app1') - ) - assert response['Count'] == 0 + response = table.scan(FilterExpression=Attr("app").ne("app1")) + assert response["Count"] == 0 @mock_dynamodb2 def test_scan_filter2(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'N'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "N"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'N': '1'} - } + TableName="test1", Item={"client": {"S": "client1"}, "app": {"N": "1"}} ) response = client.scan( - TableName='test1', - Select='ALL_ATTRIBUTES', - FilterExpression='#tb >= :dt', + TableName="test1", + Select="ALL_ATTRIBUTES", + FilterExpression="#tb >= :dt", ExpressionAttributeNames={"#tb": "app"}, - ExpressionAttributeValues={":dt": {"N": str(1)}} + ExpressionAttributeValues={":dt": {"N": str(1)}}, ) - assert response['Count'] == 1 + assert response["Count"] == 1 @mock_dynamodb2 def test_scan_filter3(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'N'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "N"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'N': '1'}, - 'active': {'BOOL': True} - } + TableName="test1", + Item={"client": {"S": "client1"}, "app": {"N": "1"}, "active": {"BOOL": True}}, ) - table = dynamodb.Table('test1') - response = table.scan( - FilterExpression=Attr('active').eq(True) - ) - assert response['Count'] == 1 + table = dynamodb.Table("test1") + response = table.scan(FilterExpression=Attr("active").eq(True)) + assert response["Count"] == 1 - response = table.scan( - FilterExpression=Attr('active').ne(True) - ) - assert response['Count'] == 0 + response = table.scan(FilterExpression=Attr("active").ne(True)) + assert response["Count"] == 0 - response = table.scan( - FilterExpression=Attr('active').ne(False) - ) - assert response['Count'] == 1 + response = table.scan(FilterExpression=Attr("active").ne(False)) + assert response["Count"] == 1 - response = table.scan( - FilterExpression=Attr('app').ne(1) - ) - assert response['Count'] == 0 + response = table.scan(FilterExpression=Attr("app").ne(1)) + assert response["Count"] == 0 - response = table.scan( - FilterExpression=Attr('app').ne(2) - ) - assert response['Count'] == 1 + response = table.scan(FilterExpression=Attr("app").ne(2)) + assert response["Count"] == 1 @mock_dynamodb2 def test_scan_filter4(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'N'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "N"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) - table = dynamodb.Table('test1') + table = dynamodb.Table("test1") response = table.scan( - FilterExpression=Attr('epoch_ts').lt(7) & Attr('fanout_ts').not_exists() + FilterExpression=Attr("epoch_ts").lt(7) & Attr("fanout_ts").not_exists() ) # Just testing - assert response['Count'] == 0 + assert response["Count"] == 0 @mock_dynamodb2 def test_bad_scan_filter(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) - table = dynamodb.Table('test1') + table = dynamodb.Table("test1") # Bad expression try: - table.scan( - FilterExpression='client test' - ) + table.scan(FilterExpression="client test") except ClientError as err: - err.response['Error']['Code'].should.equal('ValidationError') + err.response["Error"]["Code"].should.equal("ValidationError") else: - raise RuntimeError('Should have raised ResourceInUseException') + raise RuntimeError("Should have raised ResourceInUseException") @mock_dynamodb2 def test_create_table_pay_per_request(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - BillingMode="PAY_PER_REQUEST" + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + BillingMode="PAY_PER_REQUEST", ) @mock_dynamodb2 def test_create_table_error_pay_per_request_with_provisioned_param(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") try: client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123}, - BillingMode="PAY_PER_REQUEST" + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, + BillingMode="PAY_PER_REQUEST", ) except ClientError as err: - err.response['Error']['Code'].should.equal('ValidationException') + err.response["Error"]["Code"].should.equal("ValidationException") @mock_dynamodb2 def test_duplicate_create(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) try: client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceInUseException') + err.response["Error"]["Code"].should.equal("ResourceInUseException") else: - raise RuntimeError('Should have raised ResourceInUseException') + raise RuntimeError("Should have raised ResourceInUseException") @mock_dynamodb2 def test_delete_table(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) - client.delete_table(TableName='test1') + client.delete_table(TableName="test1") resp = client.list_tables() - len(resp['TableNames']).should.equal(0) + len(resp["TableNames"]).should.equal(0) try: - client.delete_table(TableName='test1') + client.delete_table(TableName="test1") except ClientError as err: - err.response['Error']['Code'].should.equal('ResourceNotFoundException') + err.response["Error"]["Code"].should.equal("ResourceNotFoundException") else: - raise RuntimeError('Should have raised ResourceNotFoundException') + raise RuntimeError("Should have raised ResourceNotFoundException") @mock_dynamodb2 def test_delete_item(): - client = boto3.client('dynamodb', region_name='us-east-1') - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'} - } + TableName="test1", Item={"client": {"S": "client1"}, "app": {"S": "app1"}} ) client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app2'} - } + TableName="test1", Item={"client": {"S": "client1"}, "app": {"S": "app2"}} ) - table = dynamodb.Table('test1') + table = dynamodb.Table("test1") response = table.scan() - assert response['Count'] == 2 + assert response["Count"] == 2 # Test ReturnValues validation with assert_raises(ClientError) as ex: - table.delete_item(Key={'client': 'client1', 'app': 'app1'}, - ReturnValues='ALL_NEW') + table.delete_item( + Key={"client": "client1", "app": "app1"}, ReturnValues="ALL_NEW" + ) # Test deletion and returning old value - response = table.delete_item(Key={'client': 'client1', 'app': 'app1'}, ReturnValues='ALL_OLD') - response['Attributes'].should.contain('client') - response['Attributes'].should.contain('app') + response = table.delete_item( + Key={"client": "client1", "app": "app1"}, ReturnValues="ALL_OLD" + ) + response["Attributes"].should.contain("client") + response["Attributes"].should.contain("app") response = table.scan() - assert response['Count'] == 1 + assert response["Count"] == 1 # Test deletion returning nothing - response = table.delete_item(Key={'client': 'client1', 'app': 'app2'}) - len(response['Attributes']).should.equal(0) + response = table.delete_item(Key={"client": "client1", "app": "app2"}) + len(response["Attributes"]).should.equal(0) response = table.scan() - assert response['Count'] == 0 + assert response["Count"] == 0 @mock_dynamodb2 def test_describe_limits(): - client = boto3.client('dynamodb', region_name='eu-central-1') + client = boto3.client("dynamodb", region_name="eu-central-1") resp = client.describe_limits() - resp['AccountMaxReadCapacityUnits'].should.equal(20000) - resp['AccountMaxWriteCapacityUnits'].should.equal(20000) - resp['TableMaxWriteCapacityUnits'].should.equal(10000) - resp['TableMaxReadCapacityUnits'].should.equal(10000) + resp["AccountMaxReadCapacityUnits"].should.equal(20000) + resp["AccountMaxWriteCapacityUnits"].should.equal(20000) + resp["TableMaxWriteCapacityUnits"].should.equal(10000) + resp["TableMaxReadCapacityUnits"].should.equal(10000) @mock_dynamodb2 def test_set_ttl(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.update_time_to_live( - TableName='test1', - TimeToLiveSpecification={ - 'Enabled': True, - 'AttributeName': 'expire' - } + TableName="test1", + TimeToLiveSpecification={"Enabled": True, "AttributeName": "expire"}, ) - resp = client.describe_time_to_live(TableName='test1') - resp['TimeToLiveDescription']['TimeToLiveStatus'].should.equal('ENABLED') - resp['TimeToLiveDescription']['AttributeName'].should.equal('expire') + resp = client.describe_time_to_live(TableName="test1") + resp["TimeToLiveDescription"]["TimeToLiveStatus"].should.equal("ENABLED") + resp["TimeToLiveDescription"]["AttributeName"].should.equal("expire") client.update_time_to_live( - TableName='test1', - TimeToLiveSpecification={ - 'Enabled': False, - 'AttributeName': 'expire' - } + TableName="test1", + TimeToLiveSpecification={"Enabled": False, "AttributeName": "expire"}, ) - resp = client.describe_time_to_live(TableName='test1') - resp['TimeToLiveDescription']['TimeToLiveStatus'].should.equal('DISABLED') + resp = client.describe_time_to_live(TableName="test1") + resp["TimeToLiveDescription"]["TimeToLiveStatus"].should.equal("DISABLED") # https://github.com/spulec/moto/issues/1043 @mock_dynamodb2 def test_query_missing_expr_names(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, + ) + client.put_item( + TableName="test1", Item={"client": {"S": "test1"}, "app": {"S": "test1"}} + ) + client.put_item( + TableName="test1", Item={"client": {"S": "test2"}, "app": {"S": "test2"}} ) - client.put_item(TableName='test1', Item={'client': {'S': 'test1'}, 'app': {'S': 'test1'}}) - client.put_item(TableName='test1', Item={'client': {'S': 'test2'}, 'app': {'S': 'test2'}}) - resp = client.query(TableName='test1', KeyConditionExpression='client=:client', - ExpressionAttributeValues={':client': {'S': 'test1'}}) + resp = client.query( + TableName="test1", + KeyConditionExpression="client=:client", + ExpressionAttributeValues={":client": {"S": "test1"}}, + ) - resp['Count'].should.equal(1) - resp['Items'][0]['client']['S'].should.equal('test1') + resp["Count"].should.equal(1) + resp["Items"][0]["client"]["S"].should.equal("test1") - resp = client.query(TableName='test1', KeyConditionExpression=':name=test2', - ExpressionAttributeNames={':name': 'client'}) + resp = client.query( + TableName="test1", + KeyConditionExpression=":name=test2", + ExpressionAttributeNames={":name": "client"}, + ) - resp['Count'].should.equal(1) - resp['Items'][0]['client']['S'].should.equal('test2') + resp["Count"].should.equal(1) + resp["Items"][0]["client"]["S"].should.equal("test2") # https://github.com/spulec/moto/issues/2328 @mock_dynamodb2 def test_update_item_with_list(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. dynamodb.create_table( - TableName='Table', - KeySchema=[ - { - 'AttributeName': 'key', - 'KeyType': 'HASH' - } - ], - AttributeDefinitions=[ - { - 'AttributeName': 'key', - 'AttributeType': 'S' - }, - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1 - } + TableName="Table", + KeySchema=[{"AttributeName": "key", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "key", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) - table = dynamodb.Table('Table') + table = dynamodb.Table("Table") table.update_item( - Key={'key': 'the-key'}, - AttributeUpdates={ - 'list': {'Value': [1, 2], 'Action': 'PUT'} - } + Key={"key": "the-key"}, + AttributeUpdates={"list": {"Value": [1, 2], "Action": "PUT"}}, ) - resp = table.get_item(Key={'key': 'the-key'}) - resp['Item'].should.equal({ - 'key': 'the-key', - 'list': [1, 2] - }) + resp = table.get_item(Key={"key": "the-key"}) + resp["Item"].should.equal({"key": "the-key", "list": [1, 2]}) # https://github.com/spulec/moto/issues/1342 @mock_dynamodb2 def test_update_item_on_map(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') - client = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + table = dynamodb.Table("users") + + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "body": {"nested": {"data": "test"}}, } ) - table = dynamodb.Table('users') - - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'body': {'nested': {'data': 'test'}}, - }) resp = table.scan() - resp['Items'][0]['body'].should.equal({'nested': {'data': 'test'}}) + resp["Items"][0]["body"].should.equal({"nested": {"data": "test"}}) # Nonexistent nested attributes are supported for existing top-level attributes. - table.update_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }, - UpdateExpression='SET body.#nested.#data = :tb, body.nested.#nonexistentnested.#data = :tb2', + table.update_item( + Key={"forum_name": "the-key", "subject": "123"}, + UpdateExpression="SET body.#nested.#data = :tb, body.nested.#nonexistentnested.#data = :tb2", ExpressionAttributeNames={ - '#nested': 'nested', - '#nonexistentnested': 'nonexistentnested', - '#data': 'data' + "#nested": "nested", + "#nonexistentnested": "nonexistentnested", + "#data": "data", }, - ExpressionAttributeValues={ - ':tb': 'new_value', - ':tb2': 'other_value' - }) + ExpressionAttributeValues={":tb": "new_value", ":tb2": "other_value"}, + ) resp = table.scan() - resp['Items'][0]['body'].should.equal({ - 'nested': { - 'data': 'new_value', - 'nonexistentnested': {'data': 'other_value'} - } - }) + resp["Items"][0]["body"].should.equal( + {"nested": {"data": "new_value", "nonexistentnested": {"data": "other_value"}}} + ) # Test nested value for a nonexistent attribute. with assert_raises(client.exceptions.ConditionalCheckFailedException): - table.update_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }, - UpdateExpression='SET nonexistent.#nested = :tb', - ExpressionAttributeNames={ - '#nested': 'nested' - }, - ExpressionAttributeValues={ - ':tb': 'new_value' - }) - + table.update_item( + Key={"forum_name": "the-key", "subject": "123"}, + UpdateExpression="SET nonexistent.#nested = :tb", + ExpressionAttributeNames={"#nested": "nested"}, + ExpressionAttributeValues={":tb": "new_value"}, + ) # https://github.com/spulec/moto/issues/1358 @mock_dynamodb2 def test_update_if_not_exists(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123' - }) + table.put_item(Item={"forum_name": "the-key", "subject": "123"}) - table.update_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }, + table.update_item( + Key={"forum_name": "the-key", "subject": "123"}, # if_not_exists without space - UpdateExpression='SET created_at=if_not_exists(created_at,:created_at)', - ExpressionAttributeValues={ - ':created_at': 123 - } + UpdateExpression="SET created_at=if_not_exists(created_at,:created_at)", + ExpressionAttributeValues={":created_at": 123}, ) resp = table.scan() - assert resp['Items'][0]['created_at'] == 123 + assert resp["Items"][0]["created_at"] == 123 - table.update_item(Key={ - 'forum_name': 'the-key', - 'subject': '123' - }, + table.update_item( + Key={"forum_name": "the-key", "subject": "123"}, # if_not_exists with space - UpdateExpression='SET created_at = if_not_exists (created_at, :created_at)', - ExpressionAttributeValues={ - ':created_at': 456 - } + UpdateExpression="SET created_at = if_not_exists (created_at, :created_at)", + ExpressionAttributeValues={":created_at": 456}, ) resp = table.scan() # Still the original value - assert resp['Items'][0]['created_at'] == 123 + assert resp["Items"][0]["created_at"] == 123 # https://github.com/spulec/moto/issues/1937 @mock_dynamodb2 def test_update_return_attributes(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='moto-test', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1} + TableName="moto-test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) def update(col, to, rv): return dynamodb.update_item( - TableName='moto-test', - Key={'id': {'S': 'foo'}}, - AttributeUpdates={col: {'Value': {'S': to}, 'Action': 'PUT'}}, - ReturnValues=rv + TableName="moto-test", + Key={"id": {"S": "foo"}}, + AttributeUpdates={col: {"Value": {"S": to}, "Action": "PUT"}}, + ReturnValues=rv, ) - r = update('col1', 'val1', 'ALL_NEW') - assert r['Attributes'] == {'id': {'S': 'foo'}, 'col1': {'S': 'val1'}} + r = update("col1", "val1", "ALL_NEW") + assert r["Attributes"] == {"id": {"S": "foo"}, "col1": {"S": "val1"}} - r = update('col1', 'val2', 'ALL_OLD') - assert r['Attributes'] == {'id': {'S': 'foo'}, 'col1': {'S': 'val1'}} + r = update("col1", "val2", "ALL_OLD") + assert r["Attributes"] == {"id": {"S": "foo"}, "col1": {"S": "val1"}} - r = update('col2', 'val3', 'UPDATED_NEW') - assert r['Attributes'] == {'col2': {'S': 'val3'}} + r = update("col2", "val3", "UPDATED_NEW") + assert r["Attributes"] == {"col2": {"S": "val3"}} - r = update('col2', 'val4', 'UPDATED_OLD') - assert r['Attributes'] == {'col2': {'S': 'val3'}} + r = update("col2", "val4", "UPDATED_OLD") + assert r["Attributes"] == {"col2": {"S": "val3"}} - r = update('col1', 'val5', 'NONE') - assert r['Attributes'] == {} + r = update("col1", "val5", "NONE") + assert r["Attributes"] == {} with assert_raises(ClientError) as ex: - r = update('col1', 'val6', 'WRONG') + r = update("col1", "val6", "WRONG") @mock_dynamodb2 def test_put_return_attributes(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='moto-test', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1} + TableName="moto-test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) r = dynamodb.put_item( - TableName='moto-test', - Item={'id': {'S': 'foo'}, 'col1': {'S': 'val1'}}, - ReturnValues='NONE' + TableName="moto-test", + Item={"id": {"S": "foo"}, "col1": {"S": "val1"}}, + ReturnValues="NONE", ) - assert 'Attributes' not in r + assert "Attributes" not in r r = dynamodb.put_item( - TableName='moto-test', - Item={'id': {'S': 'foo'}, 'col1': {'S': 'val2'}}, - ReturnValues='ALL_OLD' + TableName="moto-test", + Item={"id": {"S": "foo"}, "col1": {"S": "val2"}}, + ReturnValues="ALL_OLD", ) - assert r['Attributes'] == {'id': {'S': 'foo'}, 'col1': {'S': 'val1'}} + assert r["Attributes"] == {"id": {"S": "foo"}, "col1": {"S": "val1"}} with assert_raises(ClientError) as ex: dynamodb.put_item( - TableName='moto-test', - Item={'id': {'S': 'foo'}, 'col1': {'S': 'val3'}}, - ReturnValues='ALL_NEW' + TableName="moto-test", + Item={"id": {"S": "foo"}, "col1": {"S": "val3"}}, + ReturnValues="ALL_NEW", ) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal('Return values set to invalid value') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "Return values set to invalid value" + ) @mock_dynamodb2 def test_query_global_secondary_index_when_created_via_update_table_resource(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. dynamodb.create_table( - TableName='users', - KeySchema=[ - { - 'AttributeName': 'user_id', - 'KeyType': 'HASH' - }, - ], - AttributeDefinitions=[ - { - 'AttributeName': 'user_id', - 'AttributeType': 'N', - } - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - }, + TableName="users", + KeySchema=[{"AttributeName": "user_id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "user_id", "AttributeType": "N"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.update( - AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - ], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], GlobalSecondaryIndexUpdates=[ - {'Create': - { - 'IndexName': 'forum_name_index', - 'KeySchema': [ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH', - }, - ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 + { + "Create": { + "IndexName": "forum_name_index", + "KeySchema": [{"AttributeName": "forum_name", "KeyType": "HASH"}], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 5, + "WriteCapacityUnits": 5, }, } } - ] + ], ) next_user_id = 1 - for my_forum_name in ['cats', 'dogs']: - for my_subject in ['my pet is the cutest', 'wow look at what my pet did', "don't you love my pet?"]: - table.put_item(Item={'user_id': next_user_id, 'forum_name': my_forum_name, 'subject': my_subject}) + for my_forum_name in ["cats", "dogs"]: + for my_subject in [ + "my pet is the cutest", + "wow look at what my pet did", + "don't you love my pet?", + ]: + table.put_item( + Item={ + "user_id": next_user_id, + "forum_name": my_forum_name, + "subject": my_subject, + } + ) next_user_id += 1 # get all the cat users forum_only_query_response = table.query( - IndexName='forum_name_index', - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('forum_name').eq('cats'), + IndexName="forum_name_index", + Select="ALL_ATTRIBUTES", + KeyConditionExpression=Key("forum_name").eq("cats"), ) - forum_only_items = forum_only_query_response['Items'] + forum_only_items = forum_only_query_response["Items"] assert len(forum_only_items) == 3 for item in forum_only_items: - assert item['forum_name'] == 'cats' + assert item["forum_name"] == "cats" # query all cat users with a particular subject forum_and_subject_query_results = table.query( - IndexName='forum_name_index', - Select='ALL_ATTRIBUTES', - KeyConditionExpression=Key('forum_name').eq('cats'), - FilterExpression=Attr('subject').eq('my pet is the cutest'), + IndexName="forum_name_index", + Select="ALL_ATTRIBUTES", + KeyConditionExpression=Key("forum_name").eq("cats"), + FilterExpression=Attr("subject").eq("my pet is the cutest"), ) - forum_and_subject_items = forum_and_subject_query_results['Items'] + forum_and_subject_items = forum_and_subject_query_results["Items"] assert len(forum_and_subject_items) == 1 - assert forum_and_subject_items[0] == {'user_id': Decimal('1'), 'forum_name': 'cats', - 'subject': 'my pet is the cutest'} + assert forum_and_subject_items[0] == { + "user_id": Decimal("1"), + "forum_name": "cats", + "subject": "my pet is the cutest", + } @mock_dynamodb2 def test_dynamodb_streams_1(): - conn = boto3.client('dynamodb', region_name='us-east-1') + conn = boto3.client("dynamodb", region_name="us-east-1") resp = conn.create_table( - TableName='test-streams', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + TableName="test-streams", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, StreamSpecification={ - 'StreamEnabled': True, - 'StreamViewType': 'NEW_AND_OLD_IMAGES' - } + "StreamEnabled": True, + "StreamViewType": "NEW_AND_OLD_IMAGES", + }, ) - assert 'StreamSpecification' in resp['TableDescription'] - assert resp['TableDescription']['StreamSpecification'] == { - 'StreamEnabled': True, - 'StreamViewType': 'NEW_AND_OLD_IMAGES' + assert "StreamSpecification" in resp["TableDescription"] + assert resp["TableDescription"]["StreamSpecification"] == { + "StreamEnabled": True, + "StreamViewType": "NEW_AND_OLD_IMAGES", } - assert 'LatestStreamLabel' in resp['TableDescription'] - assert 'LatestStreamArn' in resp['TableDescription'] + assert "LatestStreamLabel" in resp["TableDescription"] + assert "LatestStreamArn" in resp["TableDescription"] - resp = conn.delete_table(TableName='test-streams') + resp = conn.delete_table(TableName="test-streams") - assert 'StreamSpecification' in resp['TableDescription'] + assert "StreamSpecification" in resp["TableDescription"] @mock_dynamodb2 def test_dynamodb_streams_2(): - conn = boto3.client('dynamodb', region_name='us-east-1') + conn = boto3.client("dynamodb", region_name="us-east-1") resp = conn.create_table( - TableName='test-stream-update', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + TableName="test-stream-update", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) - assert 'StreamSpecification' not in resp['TableDescription'] + assert "StreamSpecification" not in resp["TableDescription"] resp = conn.update_table( - TableName='test-stream-update', - StreamSpecification={ - 'StreamEnabled': True, - 'StreamViewType': 'NEW_IMAGE' - } + TableName="test-stream-update", + StreamSpecification={"StreamEnabled": True, "StreamViewType": "NEW_IMAGE"}, ) - assert 'StreamSpecification' in resp['TableDescription'] - assert resp['TableDescription']['StreamSpecification'] == { - 'StreamEnabled': True, - 'StreamViewType': 'NEW_IMAGE' + assert "StreamSpecification" in resp["TableDescription"] + assert resp["TableDescription"]["StreamSpecification"] == { + "StreamEnabled": True, + "StreamViewType": "NEW_IMAGE", } - assert 'LatestStreamLabel' in resp['TableDescription'] - assert 'LatestStreamArn' in resp['TableDescription'] + assert "LatestStreamLabel" in resp["TableDescription"] + assert "LatestStreamArn" in resp["TableDescription"] @mock_dynamodb2 def test_condition_expressions(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") # Create the DynamoDB table. client.create_table( - TableName='test1', - AttributeDefinitions=[{'AttributeName': 'client', 'AttributeType': 'S'}, {'AttributeName': 'app', 'AttributeType': 'S'}], - KeySchema=[{'AttributeName': 'client', 'KeyType': 'HASH'}, {'AttributeName': 'app', 'KeyType': 'RANGE'}], - ProvisionedThroughput={'ReadCapacityUnits': 123, 'WriteCapacityUnits': 123} + TableName="test1", + AttributeDefinitions=[ + {"AttributeName": "client", "AttributeType": "S"}, + {"AttributeName": "app", "AttributeType": "S"}, + ], + KeySchema=[ + {"AttributeName": "client", "KeyType": "HASH"}, + {"AttributeName": "app", "KeyType": "RANGE"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 123, "WriteCapacityUnits": 123}, ) client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, - } + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, + }, ) client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, }, - ConditionExpression='attribute_exists(#existing) AND attribute_not_exists(#nonexistent) AND #match = :match', + ConditionExpression="attribute_exists(#existing) AND attribute_not_exists(#nonexistent) AND #match = :match", ExpressionAttributeNames={ - '#existing': 'existing', - '#nonexistent': 'nope', - '#match': 'match', + "#existing": "existing", + "#nonexistent": "nope", + "#match": "match", }, + ExpressionAttributeValues={":match": {"S": "match"}}, + ) + + client.put_item( + TableName="test1", + Item={ + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, + }, + ConditionExpression="NOT(attribute_exists(#nonexistent1) AND attribute_exists(#nonexistent2))", + ExpressionAttributeNames={"#nonexistent1": "nope", "#nonexistent2": "nope2"}, + ) + + client.put_item( + TableName="test1", + Item={ + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, + }, + ConditionExpression="attribute_exists(#nonexistent) OR attribute_exists(#existing)", + ExpressionAttributeNames={"#nonexistent": "nope", "#existing": "existing"}, + ) + + client.put_item( + TableName="test1", + Item={ + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, + }, + ConditionExpression="#client BETWEEN :a AND :z", + ExpressionAttributeNames={"#client": "client"}, + ExpressionAttributeValues={":a": {"S": "a"}, ":z": {"S": "z"}}, + ) + + client.put_item( + TableName="test1", + Item={ + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, + }, + ConditionExpression="#client IN (:client1, :client2)", + ExpressionAttributeNames={"#client": "client"}, ExpressionAttributeValues={ - ':match': {'S': 'match'} - } - ) - - client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, + ":client1": {"S": "client1"}, + ":client2": {"S": "client2"}, }, - ConditionExpression='NOT(attribute_exists(#nonexistent1) AND attribute_exists(#nonexistent2))', - ExpressionAttributeNames={ - '#nonexistent1': 'nope', - '#nonexistent2': 'nope2' - } - ) - - client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, - }, - ConditionExpression='attribute_exists(#nonexistent) OR attribute_exists(#existing)', - ExpressionAttributeNames={ - '#nonexistent': 'nope', - '#existing': 'existing' - } - ) - - client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, - }, - ConditionExpression='#client BETWEEN :a AND :z', - ExpressionAttributeNames={ - '#client': 'client', - }, - ExpressionAttributeValues={ - ':a': {'S': 'a'}, - ':z': {'S': 'z'}, - } - ) - - client.put_item( - TableName='test1', - Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, - }, - ConditionExpression='#client IN (:client1, :client2)', - ExpressionAttributeNames={ - '#client': 'client', - }, - ExpressionAttributeValues={ - ':client1': {'S': 'client1'}, - ':client2': {'S': 'client2'}, - } ) with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, }, - ConditionExpression='attribute_exists(#nonexistent1) AND attribute_exists(#nonexistent2)', + ConditionExpression="attribute_exists(#nonexistent1) AND attribute_exists(#nonexistent2)", ExpressionAttributeNames={ - '#nonexistent1': 'nope', - '#nonexistent2': 'nope2' - } + "#nonexistent1": "nope", + "#nonexistent2": "nope2", + }, ) with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, }, - ConditionExpression='NOT(attribute_not_exists(#nonexistent1) AND attribute_not_exists(#nonexistent2))', + ConditionExpression="NOT(attribute_not_exists(#nonexistent1) AND attribute_not_exists(#nonexistent2))", ExpressionAttributeNames={ - '#nonexistent1': 'nope', - '#nonexistent2': 'nope2' - } + "#nonexistent1": "nope", + "#nonexistent2": "nope2", + }, ) with assert_raises(client.exceptions.ConditionalCheckFailedException): client.put_item( - TableName='test1', + TableName="test1", Item={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - 'match': {'S': 'match'}, - 'existing': {'S': 'existing'}, + "client": {"S": "client1"}, + "app": {"S": "app1"}, + "match": {"S": "match"}, + "existing": {"S": "existing"}, }, - ConditionExpression='attribute_exists(#existing) AND attribute_not_exists(#nonexistent) AND #match = :match', + ConditionExpression="attribute_exists(#existing) AND attribute_not_exists(#nonexistent) AND #match = :match", ExpressionAttributeNames={ - '#existing': 'existing', - '#nonexistent': 'nope', - '#match': 'match', + "#existing": "existing", + "#nonexistent": "nope", + "#match": "match", }, - ExpressionAttributeValues={ - ':match': {'S': 'match2'} - } + ExpressionAttributeValues={":match": {"S": "match2"}}, ) # Make sure update_item honors ConditionExpression as well client.update_item( - TableName='test1', - Key={ - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - }, - UpdateExpression='set #match=:match', - ConditionExpression='attribute_exists(#existing)', - ExpressionAttributeNames={ - '#existing': 'existing', - '#match': 'match', - }, - ExpressionAttributeValues={ - ':match': {'S': 'match'} - } + TableName="test1", + Key={"client": {"S": "client1"}, "app": {"S": "app1"}}, + UpdateExpression="set #match=:match", + ConditionExpression="attribute_exists(#existing)", + ExpressionAttributeNames={"#existing": "existing", "#match": "match"}, + ExpressionAttributeValues={":match": {"S": "match"}}, ) with assert_raises(client.exceptions.ConditionalCheckFailedException): client.update_item( - TableName='test1', - Key={ - 'client': { 'S': 'client1'}, - 'app': { 'S': 'app1'}, - }, - UpdateExpression='set #match=:match', - ConditionExpression='attribute_not_exists(#existing)', - ExpressionAttributeValues={ - ':match': {'S': 'match'} - }, - ExpressionAttributeNames={ - '#existing': 'existing', - '#match': 'match', - }, + TableName="test1", + Key={"client": {"S": "client1"}, "app": {"S": "app1"}}, + UpdateExpression="set #match=:match", + ConditionExpression="attribute_not_exists(#existing)", + ExpressionAttributeValues={":match": {"S": "match"}}, + ExpressionAttributeNames={"#existing": "existing", "#match": "match"}, ) with assert_raises(client.exceptions.ConditionalCheckFailedException): client.delete_item( - TableName = 'test1', - Key = { - 'client': {'S': 'client1'}, - 'app': {'S': 'app1'}, - }, - ConditionExpression = 'attribute_not_exists(#existing)', - ExpressionAttributeValues = { - ':match': {'S': 'match'} - }, - ExpressionAttributeNames = { - '#existing': 'existing', - '#match': 'match', - }, + TableName="test1", + Key={"client": {"S": "client1"}, "app": {"S": "app1"}}, + ConditionExpression="attribute_not_exists(#existing)", + ExpressionAttributeValues={":match": {"S": "match"}}, + ExpressionAttributeNames={"#existing": "existing", "#match": "match"}, ) @mock_dynamodb2 def test_condition_expression__attr_doesnt_exist(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") client.create_table( - TableName='test', - KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'forum_name', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}) + TableName="test", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + ) - client.put_item(TableName='test', - Item={'forum_name': {'S': 'foo'}, 'ttl': {'N': 'bar'}}) + client.put_item( + TableName="test", Item={"forum_name": {"S": "foo"}, "ttl": {"N": "bar"}} + ) def update_if_attr_doesnt_exist(): # Test nonexistent top-level attribute. client.update_item( - TableName='test', - Key={ - 'forum_name': {'S': 'the-key'}, - 'subject': {'S': 'the-subject'}, - }, - UpdateExpression='set #new_state=:new_state, #ttl=:ttl', - ConditionExpression='attribute_not_exists(#new_state)', - ExpressionAttributeNames={'#new_state': 'foobar', '#ttl': 'ttl'}, + TableName="test", + Key={"forum_name": {"S": "the-key"}, "subject": {"S": "the-subject"}}, + UpdateExpression="set #new_state=:new_state, #ttl=:ttl", + ConditionExpression="attribute_not_exists(#new_state)", + ExpressionAttributeNames={"#new_state": "foobar", "#ttl": "ttl"}, ExpressionAttributeValues={ - ':new_state': {'S': 'some-value'}, - ':ttl': {'N': '12345.67'}, + ":new_state": {"S": "some-value"}, + ":ttl": {"N": "12345.67"}, }, - ReturnValues='ALL_NEW', + ReturnValues="ALL_NEW", ) update_if_attr_doesnt_exist() @@ -2194,657 +2009,832 @@ def test_condition_expression__attr_doesnt_exist(): @mock_dynamodb2 def test_condition_expression__or_order(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") client.create_table( - TableName='test', - KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], - AttributeDefinitions=[ - {'AttributeName': 'forum_name', 'AttributeType': 'S'}, - ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + TableName="test", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) # ensure that the RHS of the OR expression is not evaluated if the LHS # returns true (as it would result an error) client.update_item( - TableName='test', - Key={ - 'forum_name': {'S': 'the-key'}, - }, - UpdateExpression='set #ttl=:ttl', - ConditionExpression='attribute_not_exists(#ttl) OR #ttl <= :old_ttl', - ExpressionAttributeNames={'#ttl': 'ttl'}, - ExpressionAttributeValues={ - ':ttl': {'N': '6'}, - ':old_ttl': {'N': '5'}, - } + TableName="test", + Key={"forum_name": {"S": "the-key"}}, + UpdateExpression="set #ttl=:ttl", + ConditionExpression="attribute_not_exists(#ttl) OR #ttl <= :old_ttl", + ExpressionAttributeNames={"#ttl": "ttl"}, + ExpressionAttributeValues={":ttl": {"N": "6"}, ":old_ttl": {"N": "5"}}, ) @mock_dynamodb2 def test_condition_expression__and_order(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") client.create_table( - TableName='test', - KeySchema=[{'AttributeName': 'forum_name', 'KeyType': 'HASH'}], - AttributeDefinitions=[ - {'AttributeName': 'forum_name', 'AttributeType': 'S'}, - ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + TableName="test", + KeySchema=[{"AttributeName": "forum_name", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "forum_name", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) - + # ensure that the RHS of the AND expression is not evaluated if the LHS # returns true (as it would result an error) with assert_raises(client.exceptions.ConditionalCheckFailedException): client.update_item( - TableName='test', - Key={ - 'forum_name': {'S': 'the-key'}, - }, - UpdateExpression='set #ttl=:ttl', - ConditionExpression='attribute_exists(#ttl) AND #ttl <= :old_ttl', - ExpressionAttributeNames={'#ttl': 'ttl'}, - ExpressionAttributeValues={ - ':ttl': {'N': '6'}, - ':old_ttl': {'N': '5'}, - } + TableName="test", + Key={"forum_name": {"S": "the-key"}}, + UpdateExpression="set #ttl=:ttl", + ConditionExpression="attribute_exists(#ttl) AND #ttl <= :old_ttl", + ExpressionAttributeNames={"#ttl": "ttl"}, + ExpressionAttributeValues={":ttl": {"N": "6"}, ":old_ttl": {"N": "5"}}, ) @mock_dynamodb2 def test_query_gsi_with_range_key(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='test', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], + TableName="test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], AttributeDefinitions=[ - {'AttributeName': 'id', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_hash_key', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_range_key', 'AttributeType': 'S'} + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "gsi_hash_key", "AttributeType": "S"}, + {"AttributeName": "gsi_range_key", "AttributeType": "S"}, ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, GlobalSecondaryIndexes=[ { - 'IndexName': 'test_gsi', - 'KeySchema': [ - { - 'AttributeName': 'gsi_hash_key', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'gsi_range_key', - 'KeyType': 'RANGE' - }, + "IndexName": "test_gsi", + "KeySchema": [ + {"AttributeName": "gsi_hash_key", "KeyType": "HASH"}, + {"AttributeName": "gsi_range_key", "KeyType": "RANGE"}, ], - 'Projection': { - 'ProjectionType': 'ALL', + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1 - } - }, - ] + } + ], ) dynamodb.put_item( - TableName='test', + TableName="test", Item={ - 'id': {'S': 'test1'}, - 'gsi_hash_key': {'S': 'key1'}, - 'gsi_range_key': {'S': 'range1'}, - } + "id": {"S": "test1"}, + "gsi_hash_key": {"S": "key1"}, + "gsi_range_key": {"S": "range1"}, + }, ) dynamodb.put_item( - TableName='test', - Item={ - 'id': {'S': 'test2'}, - 'gsi_hash_key': {'S': 'key1'}, - } + TableName="test", Item={"id": {"S": "test2"}, "gsi_hash_key": {"S": "key1"}} ) - res = dynamodb.query(TableName='test', IndexName='test_gsi', - KeyConditionExpression='gsi_hash_key = :gsi_hash_key AND gsi_range_key = :gsi_range_key', - ExpressionAttributeValues={ - ':gsi_hash_key': {'S': 'key1'}, - ':gsi_range_key': {'S': 'range1'} - }) + res = dynamodb.query( + TableName="test", + IndexName="test_gsi", + KeyConditionExpression="gsi_hash_key = :gsi_hash_key AND gsi_range_key = :gsi_range_key", + ExpressionAttributeValues={ + ":gsi_hash_key": {"S": "key1"}, + ":gsi_range_key": {"S": "range1"}, + }, + ) res.should.have.key("Count").equal(1) res.should.have.key("Items") - res['Items'][0].should.equal({ - 'id': {'S': 'test1'}, - 'gsi_hash_key': {'S': 'key1'}, - 'gsi_range_key': {'S': 'range1'}, - }) + res["Items"][0].should.equal( + { + "id": {"S": "test1"}, + "gsi_hash_key": {"S": "key1"}, + "gsi_range_key": {"S": "range1"}, + } + ) @mock_dynamodb2 def test_scan_by_non_exists_index(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='test', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], + TableName="test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], AttributeDefinitions=[ - {'AttributeName': 'id', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_col', 'AttributeType': 'S'} + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "gsi_col", "AttributeType": "S"}, ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, GlobalSecondaryIndexes=[ { - 'IndexName': 'test_gsi', - 'KeySchema': [ - { - 'AttributeName': 'gsi_col', - 'KeyType': 'HASH' - }, - ], - 'Projection': { - 'ProjectionType': 'ALL', + "IndexName": "test_gsi", + "KeySchema": [{"AttributeName": "gsi_col", "KeyType": "HASH"}], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1 - } - }, - ] + } + ], ) with assert_raises(ClientError) as ex: - dynamodb.scan(TableName='test', IndexName='non_exists_index') + dynamodb.scan(TableName="test", IndexName="non_exists_index") - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'The table does not have the specified index: non_exists_index' + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "The table does not have the specified index: non_exists_index" ) @mock_dynamodb2 def test_batch_items_returns_all(): dynamodb = _create_user_table() - returned_items = dynamodb.batch_get_item(RequestItems={ - 'users': { - 'Keys': [{ - 'username': {'S': 'user0'} - }, { - 'username': {'S': 'user1'} - }, { - 'username': {'S': 'user2'} - }, { - 'username': {'S': 'user3'} - }], - 'ConsistentRead': True + returned_items = dynamodb.batch_get_item( + RequestItems={ + "users": { + "Keys": [ + {"username": {"S": "user0"}}, + {"username": {"S": "user1"}}, + {"username": {"S": "user2"}}, + {"username": {"S": "user3"}}, + ], + "ConsistentRead": True, + } } - })['Responses']['users'] + )["Responses"]["users"] assert len(returned_items) == 3 - assert [item['username']['S'] for item in returned_items] == ['user1', 'user2', 'user3'] + assert [item["username"]["S"] for item in returned_items] == [ + "user1", + "user2", + "user3", + ] @mock_dynamodb2 def test_batch_items_with_basic_projection_expression(): dynamodb = _create_user_table() - returned_items = dynamodb.batch_get_item(RequestItems={ - 'users': { - 'Keys': [{ - 'username': {'S': 'user0'} - }, { - 'username': {'S': 'user1'} - }, { - 'username': {'S': 'user2'} - }, { - 'username': {'S': 'user3'} - }], - 'ConsistentRead': True, - 'ProjectionExpression': 'username' + returned_items = dynamodb.batch_get_item( + RequestItems={ + "users": { + "Keys": [ + {"username": {"S": "user0"}}, + {"username": {"S": "user1"}}, + {"username": {"S": "user2"}}, + {"username": {"S": "user3"}}, + ], + "ConsistentRead": True, + "ProjectionExpression": "username", + } } - })['Responses']['users'] + )["Responses"]["users"] returned_items.should.have.length_of(3) - [item['username']['S'] for item in returned_items].should.be.equal(['user1', 'user2', 'user3']) - [item.get('foo') for item in returned_items].should.be.equal([None, None, None]) + [item["username"]["S"] for item in returned_items].should.be.equal( + ["user1", "user2", "user3"] + ) + [item.get("foo") for item in returned_items].should.be.equal([None, None, None]) # The projection expression should not remove data from storage - returned_items = dynamodb.batch_get_item(RequestItems = { - 'users': { - 'Keys': [{ - 'username': {'S': 'user0'} - }, { - 'username': {'S': 'user1'} - }, { - 'username': {'S': 'user2'} - }, { - 'username': {'S': 'user3'} - }], - 'ConsistentRead': True + returned_items = dynamodb.batch_get_item( + RequestItems={ + "users": { + "Keys": [ + {"username": {"S": "user0"}}, + {"username": {"S": "user1"}}, + {"username": {"S": "user2"}}, + {"username": {"S": "user3"}}, + ], + "ConsistentRead": True, + } } - })['Responses']['users'] + )["Responses"]["users"] - [item['username']['S'] for item in returned_items].should.be.equal(['user1', 'user2', 'user3']) - [item['foo']['S'] for item in returned_items].should.be.equal(['bar', 'bar', 'bar']) + [item["username"]["S"] for item in returned_items].should.be.equal( + ["user1", "user2", "user3"] + ) + [item["foo"]["S"] for item in returned_items].should.be.equal(["bar", "bar", "bar"]) @mock_dynamodb2 def test_batch_items_with_basic_projection_expression_and_attr_expression_names(): dynamodb = _create_user_table() - returned_items = dynamodb.batch_get_item(RequestItems={ - 'users': { - 'Keys': [{ - 'username': {'S': 'user0'} - }, { - 'username': {'S': 'user1'} - }, { - 'username': {'S': 'user2'} - }, { - 'username': {'S': 'user3'} - }], - 'ConsistentRead': True, - 'ProjectionExpression': '#rl', - 'ExpressionAttributeNames': { - '#rl': 'username' - }, + returned_items = dynamodb.batch_get_item( + RequestItems={ + "users": { + "Keys": [ + {"username": {"S": "user0"}}, + {"username": {"S": "user1"}}, + {"username": {"S": "user2"}}, + {"username": {"S": "user3"}}, + ], + "ConsistentRead": True, + "ProjectionExpression": "#rl", + "ExpressionAttributeNames": {"#rl": "username"}, + } } - })['Responses']['users'] + )["Responses"]["users"] returned_items.should.have.length_of(3) - [item['username']['S'] for item in returned_items].should.be.equal(['user1', 'user2', 'user3']) - [item.get('foo') for item in returned_items].should.be.equal([None, None, None]) + [item["username"]["S"] for item in returned_items].should.be.equal( + ["user1", "user2", "user3"] + ) + [item.get("foo") for item in returned_items].should.be.equal([None, None, None]) @mock_dynamodb2 def test_batch_items_should_throw_exception_for_duplicate_request(): client = _create_user_table() with assert_raises(ClientError) as ex: - client.batch_get_item(RequestItems={ - 'users': { - 'Keys': [{ - 'username': {'S': 'user0'} - }, { - 'username': {'S': 'user0'} - }], - 'ConsistentRead': True - }}) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.equal('Provided list of item keys contains duplicates') + client.batch_get_item( + RequestItems={ + "users": { + "Keys": [ + {"username": {"S": "user0"}}, + {"username": {"S": "user0"}}, + ], + "ConsistentRead": True, + } + } + ) + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal( + "Provided list of item keys contains duplicates" + ) @mock_dynamodb2 def test_index_with_unknown_attributes_should_fail(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") - expected_exception = 'Some index key attributes are not defined in AttributeDefinitions.' + expected_exception = ( + "Some index key attributes are not defined in AttributeDefinitions." + ) with assert_raises(ClientError) as ex: dynamodb.create_table( AttributeDefinitions=[ - {'AttributeName': 'customer_nr', 'AttributeType': 'S'}, - {'AttributeName': 'last_name', 'AttributeType': 'S'}], - TableName='table_with_missing_attribute_definitions', + {"AttributeName": "customer_nr", "AttributeType": "S"}, + {"AttributeName": "last_name", "AttributeType": "S"}, + ], + TableName="table_with_missing_attribute_definitions", KeySchema=[ - {'AttributeName': 'customer_nr', 'KeyType': 'HASH'}, - {'AttributeName': 'last_name', 'KeyType': 'RANGE'}], - LocalSecondaryIndexes=[{ - 'IndexName': 'indexthataddsanadditionalattribute', - 'KeySchema': [ - {'AttributeName': 'customer_nr', 'KeyType': 'HASH'}, - {'AttributeName': 'postcode', 'KeyType': 'RANGE'}], - 'Projection': { 'ProjectionType': 'ALL' } - }], - BillingMode='PAY_PER_REQUEST') + {"AttributeName": "customer_nr", "KeyType": "HASH"}, + {"AttributeName": "last_name", "KeyType": "RANGE"}, + ], + LocalSecondaryIndexes=[ + { + "IndexName": "indexthataddsanadditionalattribute", + "KeySchema": [ + {"AttributeName": "customer_nr", "KeyType": "HASH"}, + {"AttributeName": "postcode", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + } + ], + BillingMode="PAY_PER_REQUEST", + ) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.contain(expected_exception) + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.contain(expected_exception) @mock_dynamodb2 def test_update_list_index__set_existing_index(): - table_name = 'test_list_index_access' + table_name = "test_list_index_access" client = create_table_with_list(table_name) - client.put_item(TableName=table_name, - Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}}) - client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, - UpdateExpression='set itemlist[1]=:Item', - ExpressionAttributeValues={':Item': {'S': 'bar2_update'}}) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo"}, + "itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo"}}, + UpdateExpression="set itemlist[1]=:Item", + ExpressionAttributeValues={":Item": {"S": "bar2_update"}}, + ) # - result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] - result['id'].should.equal({'S': 'foo'}) - result['itemlist'].should.equal({'L': [{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]}) + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo"}})["Item"] + result["id"].should.equal({"S": "foo"}) + result["itemlist"].should.equal( + {"L": [{"S": "bar1"}, {"S": "bar2_update"}, {"S": "bar3"}]} + ) @mock_dynamodb2 def test_update_list_index__set_existing_nested_index(): - table_name = 'test_list_index_access' + table_name = "test_list_index_access" client = create_table_with_list(table_name) - client.put_item(TableName=table_name, - Item={'id': {'S': 'foo2'}, 'itemmap': {'M': {'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}}}}) - client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, - UpdateExpression='set itemmap.itemlist[1]=:Item', - ExpressionAttributeValues={':Item': {'S': 'bar2_update'}}) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo2"}, + "itemmap": { + "M": {"itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}} + }, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="set itemmap.itemlist[1]=:Item", + ExpressionAttributeValues={":Item": {"S": "bar2_update"}}, + ) # - result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item'] - result['id'].should.equal({'S': 'foo2'}) - result['itemmap']['M']['itemlist']['L'].should.equal([{'S': 'bar1'}, {'S': 'bar2_update'}, {'S': 'bar3'}]) + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})["Item"] + result["id"].should.equal({"S": "foo2"}) + result["itemmap"]["M"]["itemlist"]["L"].should.equal( + [{"S": "bar1"}, {"S": "bar2_update"}, {"S": "bar3"}] + ) @mock_dynamodb2 def test_update_list_index__set_index_out_of_range(): - table_name = 'test_list_index_access' + table_name = "test_list_index_access" client = create_table_with_list(table_name) - client.put_item(TableName=table_name, - Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}}) - client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, - UpdateExpression='set itemlist[10]=:Item', - ExpressionAttributeValues={':Item': {'S': 'bar10'}}) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo"}, + "itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo"}}, + UpdateExpression="set itemlist[10]=:Item", + ExpressionAttributeValues={":Item": {"S": "bar10"}}, + ) # - result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] - assert result['id'] == {'S': 'foo'} - assert result['itemlist'] == {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}, {'S': 'bar10'}]} + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo"}})["Item"] + assert result["id"] == {"S": "foo"} + assert result["itemlist"] == { + "L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}, {"S": "bar10"}] + } @mock_dynamodb2 def test_update_list_index__set_nested_index_out_of_range(): - table_name = 'test_list_index_access' + table_name = "test_list_index_access" client = create_table_with_list(table_name) - client.put_item(TableName=table_name, - Item={'id': {'S': 'foo2'}, 'itemmap': {'M': {'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}}}}) - client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, - UpdateExpression='set itemmap.itemlist[10]=:Item', - ExpressionAttributeValues={':Item': {'S': 'bar10'}}) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo2"}, + "itemmap": { + "M": {"itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}} + }, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="set itemmap.itemlist[10]=:Item", + ExpressionAttributeValues={":Item": {"S": "bar10"}}, + ) # - result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item'] - assert result['id'] == {'S': 'foo2'} - assert result['itemmap']['M']['itemlist']['L'] == [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}, {'S': 'bar10'}] + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})["Item"] + assert result["id"] == {"S": "foo2"} + assert result["itemmap"]["M"]["itemlist"]["L"] == [ + {"S": "bar1"}, + {"S": "bar2"}, + {"S": "bar3"}, + {"S": "bar10"}, + ] @mock_dynamodb2 def test_update_list_index__set_double_nested_index(): - table_name = 'test_list_index_access' + table_name = "test_list_index_access" client = create_table_with_list(table_name) - client.put_item(TableName=table_name, - Item={'id': {'S': 'foo2'}, - 'itemmap': {'M': {'itemlist': {'L': [{'M': {'foo': {'S': 'bar11'}, 'foos': {'S': 'bar12'}}}, - {'M': {'foo': {'S': 'bar21'}, 'foos': {'S': 'bar21'}}}]}}}}) - client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, - UpdateExpression='set itemmap.itemlist[1].foos=:Item', - ExpressionAttributeValues={':Item': {'S': 'bar22'}}) + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo2"}, + "itemmap": { + "M": { + "itemlist": { + "L": [ + {"M": {"foo": {"S": "bar11"}, "foos": {"S": "bar12"}}}, + {"M": {"foo": {"S": "bar21"}, "foos": {"S": "bar21"}}}, + ] + } + } + }, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="set itemmap.itemlist[1].foos=:Item", + ExpressionAttributeValues={":Item": {"S": "bar22"}}, + ) # - result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item'] - assert result['id'] == {'S': 'foo2'} - len(result['itemmap']['M']['itemlist']['L']).should.equal(2) - result['itemmap']['M']['itemlist']['L'][0].should.equal({'M': {'foo': {'S': 'bar11'}, 'foos': {'S': 'bar12'}}}) # unchanged - result['itemmap']['M']['itemlist']['L'][1].should.equal({'M': {'foo': {'S': 'bar21'}, 'foos': {'S': 'bar22'}}}) # updated + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})["Item"] + assert result["id"] == {"S": "foo2"} + len(result["itemmap"]["M"]["itemlist"]["L"]).should.equal(2) + result["itemmap"]["M"]["itemlist"]["L"][0].should.equal( + {"M": {"foo": {"S": "bar11"}, "foos": {"S": "bar12"}}} + ) # unchanged + result["itemmap"]["M"]["itemlist"]["L"][1].should.equal( + {"M": {"foo": {"S": "bar21"}, "foos": {"S": "bar22"}}} + ) # updated @mock_dynamodb2 def test_update_list_index__set_index_of_a_string(): - table_name = 'test_list_index_access' + table_name = "test_list_index_access" client = create_table_with_list(table_name) - client.put_item(TableName=table_name, Item={'id': {'S': 'foo2'}, 'itemstr': {'S': 'somestring'}}) + client.put_item( + TableName=table_name, Item={"id": {"S": "foo2"}, "itemstr": {"S": "somestring"}} + ) with assert_raises(ClientError) as ex: - client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, - UpdateExpression='set itemstr[1]=:Item', - ExpressionAttributeValues={':Item': {'S': 'string_update'}}) - result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item'] + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="set itemstr[1]=:Item", + ExpressionAttributeValues={":Item": {"S": "string_update"}}, + ) + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})[ + "Item" + ] - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.equal( - 'The document path provided in the update expression is invalid for update') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal( + "The document path provided in the update expression is invalid for update" + ) @mock_dynamodb2 def test_remove_top_level_attribute(): - table_name = 'test_remove' + table_name = "test_remove" client = create_table_with_list(table_name) - client.put_item(TableName=table_name, - Item={'id': {'S': 'foo'}, 'item': {'S': 'bar'}}) - client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE item') + client.put_item( + TableName=table_name, Item={"id": {"S": "foo"}, "item": {"S": "bar"}} + ) + client.update_item( + TableName=table_name, Key={"id": {"S": "foo"}}, UpdateExpression="REMOVE item" + ) # - result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] - result.should.equal({'id': {'S': 'foo'}}) + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo"}})["Item"] + result.should.equal({"id": {"S": "foo"}}) @mock_dynamodb2 def test_remove_list_index__remove_existing_index(): - table_name = 'test_list_index_access' + table_name = "test_list_index_access" client = create_table_with_list(table_name) - client.put_item(TableName=table_name, - Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}}) - client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[1]') + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo"}, + "itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo"}}, + UpdateExpression="REMOVE itemlist[1]", + ) # - result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] - result['id'].should.equal({'S': 'foo'}) - result['itemlist'].should.equal({'L': [{'S': 'bar1'}, {'S': 'bar3'}]}) + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo"}})["Item"] + result["id"].should.equal({"S": "foo"}) + result["itemlist"].should.equal({"L": [{"S": "bar1"}, {"S": "bar3"}]}) @mock_dynamodb2 def test_remove_list_index__remove_existing_nested_index(): - table_name = 'test_list_index_access' + table_name = "test_list_index_access" client = create_table_with_list(table_name) - client.put_item(TableName=table_name, - Item={'id': {'S': 'foo2'}, 'itemmap': {'M': {'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}]}}}}) - client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, UpdateExpression='REMOVE itemmap.itemlist[1]') + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo2"}, + "itemmap": {"M": {"itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}]}}}, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="REMOVE itemmap.itemlist[1]", + ) # - result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item'] - result['id'].should.equal({'S': 'foo2'}) - result['itemmap']['M']['itemlist']['L'].should.equal([{'S': 'bar1'}]) + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})["Item"] + result["id"].should.equal({"S": "foo2"}) + result["itemmap"]["M"]["itemlist"]["L"].should.equal([{"S": "bar1"}]) @mock_dynamodb2 def test_remove_list_index__remove_existing_double_nested_index(): - table_name = 'test_list_index_access' + table_name = "test_list_index_access" client = create_table_with_list(table_name) - client.put_item(TableName=table_name, - Item={'id': {'S': 'foo2'}, - 'itemmap': {'M': {'itemlist': {'L': [{'M': {'foo00': {'S': 'bar1'}, - 'foo01': {'S': 'bar2'}}}, - {'M': {'foo10': {'S': 'bar1'}, - 'foo11': {'S': 'bar2'}}}]}}}}) - client.update_item(TableName=table_name, Key={'id': {'S': 'foo2'}}, - UpdateExpression='REMOVE itemmap.itemlist[1].foo10') + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo2"}, + "itemmap": { + "M": { + "itemlist": { + "L": [ + {"M": {"foo00": {"S": "bar1"}, "foo01": {"S": "bar2"}}}, + {"M": {"foo10": {"S": "bar1"}, "foo11": {"S": "bar2"}}}, + ] + } + } + }, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo2"}}, + UpdateExpression="REMOVE itemmap.itemlist[1].foo10", + ) # - result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo2'}})['Item'] - assert result['id'] == {'S': 'foo2'} - assert result['itemmap']['M']['itemlist']['L'][0]['M'].should.equal({'foo00': {'S': 'bar1'}, - 'foo01': {'S': 'bar2'}}) # untouched - assert result['itemmap']['M']['itemlist']['L'][1]['M'].should.equal({'foo11': {'S': 'bar2'}}) # changed + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo2"}})["Item"] + assert result["id"] == {"S": "foo2"} + assert result["itemmap"]["M"]["itemlist"]["L"][0]["M"].should.equal( + {"foo00": {"S": "bar1"}, "foo01": {"S": "bar2"}} + ) # untouched + assert result["itemmap"]["M"]["itemlist"]["L"][1]["M"].should.equal( + {"foo11": {"S": "bar2"}} + ) # changed @mock_dynamodb2 def test_remove_list_index__remove_index_out_of_range(): - table_name = 'test_list_index_access' + table_name = "test_list_index_access" client = create_table_with_list(table_name) - client.put_item(TableName=table_name, - Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]}}) - client.update_item(TableName=table_name, Key={'id': {'S': 'foo'}}, UpdateExpression='REMOVE itemlist[10]') + client.put_item( + TableName=table_name, + Item={ + "id": {"S": "foo"}, + "itemlist": {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]}, + }, + ) + client.update_item( + TableName=table_name, + Key={"id": {"S": "foo"}}, + UpdateExpression="REMOVE itemlist[10]", + ) # - result = client.get_item(TableName=table_name, Key={'id': {'S': 'foo'}})['Item'] - assert result['id'] == {'S': 'foo'} - assert result['itemlist'] == {'L': [{'S': 'bar1'}, {'S': 'bar2'}, {'S': 'bar3'}]} + result = client.get_item(TableName=table_name, Key={"id": {"S": "foo"}})["Item"] + assert result["id"] == {"S": "foo"} + assert result["itemlist"] == {"L": [{"S": "bar1"}, {"S": "bar2"}, {"S": "bar3"}]} def create_table_with_list(table_name): - client = boto3.client('dynamodb', region_name='us-east-1') - client.create_table(TableName=table_name, - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], - BillingMode='PAY_PER_REQUEST') + client = boto3.client("dynamodb", region_name="us-east-1") + client.create_table( + TableName=table_name, + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + BillingMode="PAY_PER_REQUEST", + ) return client @mock_dynamodb2 def test_sorted_query_with_numerical_sort_key(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') - dynamodb.create_table(TableName="CarCollection", - KeySchema=[{ 'AttributeName': "CarModel", 'KeyType': 'HASH'}, - {'AttributeName': "CarPrice", 'KeyType': 'RANGE'}], - AttributeDefinitions=[{'AttributeName': "CarModel", 'AttributeType': "S"}, - {'AttributeName': "CarPrice", 'AttributeType': "N"}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}) + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") + dynamodb.create_table( + TableName="CarCollection", + KeySchema=[ + {"AttributeName": "CarModel", "KeyType": "HASH"}, + {"AttributeName": "CarPrice", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "CarModel", "AttributeType": "S"}, + {"AttributeName": "CarPrice", "AttributeType": "N"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + ) def create_item(price): return {"CarModel": "M", "CarPrice": price} - table = dynamodb.Table('CarCollection') + table = dynamodb.Table("CarCollection") items = list(map(create_item, [2, 1, 10, 3])) for item in items: table.put_item(Item=item) - response = table.query(KeyConditionExpression=Key('CarModel').eq("M")) + response = table.query(KeyConditionExpression=Key("CarModel").eq("M")) - response_items = response['Items'] + response_items = response["Items"] assert len(items) == len(response_items) assert all(isinstance(item["CarPrice"], Decimal) for item in response_items) response_prices = [item["CarPrice"] for item in response_items] expected_prices = [Decimal(item["CarPrice"]) for item in items] expected_prices.sort() - assert expected_prices == response_prices, "result items are not sorted by numerical value" + assert ( + expected_prices == response_prices + ), "result items are not sorted by numerical value" # https://github.com/spulec/moto/issues/1874 @mock_dynamodb2 def test_item_size_is_under_400KB(): - dynamodb = boto3.resource('dynamodb') - client = boto3.client('dynamodb') + dynamodb = boto3.resource("dynamodb") + client = boto3.client("dynamodb") dynamodb.create_table( - TableName='moto-test', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1} + TableName="moto-test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) - table = dynamodb.Table('moto-test') + table = dynamodb.Table("moto-test") - large_item = 'x' * 410 * 1000 - assert_failure_due_to_item_size(func=client.put_item, - TableName='moto-test', - Item={'id': {'S': 'foo'}, 'item': {'S': large_item}}) - assert_failure_due_to_item_size(func=table.put_item, Item = {'id': 'bar', 'item': large_item}) - assert_failure_due_to_item_size(func=client.update_item, - TableName='moto-test', - Key={'id': {'S': 'foo2'}}, - UpdateExpression='set item=:Item', - ExpressionAttributeValues={':Item': {'S': large_item}}) + large_item = "x" * 410 * 1000 + assert_failure_due_to_item_size( + func=client.put_item, + TableName="moto-test", + Item={"id": {"S": "foo"}, "item": {"S": large_item}}, + ) + assert_failure_due_to_item_size( + func=table.put_item, Item={"id": "bar", "item": large_item} + ) + assert_failure_due_to_item_size( + func=client.update_item, + TableName="moto-test", + Key={"id": {"S": "foo2"}}, + UpdateExpression="set item=:Item", + ExpressionAttributeValues={":Item": {"S": large_item}}, + ) # Assert op fails when updating a nested item - assert_failure_due_to_item_size(func=table.put_item, - Item={'id': 'bar', 'itemlist': [{'item': large_item}]}) - assert_failure_due_to_item_size(func=client.put_item, - TableName='moto-test', - Item={'id': {'S': 'foo'}, 'itemlist': {'L': [{'M': {'item1': {'S': large_item}}}]}}) + assert_failure_due_to_item_size( + func=table.put_item, Item={"id": "bar", "itemlist": [{"item": large_item}]} + ) + assert_failure_due_to_item_size( + func=client.put_item, + TableName="moto-test", + Item={ + "id": {"S": "foo"}, + "itemlist": {"L": [{"M": {"item1": {"S": large_item}}}]}, + }, + ) def assert_failure_due_to_item_size(func, **kwargs): with assert_raises(ClientError) as ex: func(**kwargs) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.equal('Item size has exceeded the maximum allowed size') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal( + "Item size has exceeded the maximum allowed size" + ) @mock_dynamodb2 # https://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_Query.html#DDB-Query-request-KeyConditionExpression def test_hash_key_cannot_use_begins_with_operations(): - dynamodb = boto3.resource('dynamodb') + dynamodb = boto3.resource("dynamodb") table = dynamodb.create_table( - TableName='test-table', - KeySchema=[{'AttributeName': 'key', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'key', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}) + TableName="test-table", + KeySchema=[{"AttributeName": "key", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "key", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + ) - items = [{'key': 'prefix-$LATEST', 'value': '$LATEST'}, - {'key': 'prefix-DEV', 'value': 'DEV'}, - {'key': 'prefix-PROD', 'value': 'PROD'}] + items = [ + {"key": "prefix-$LATEST", "value": "$LATEST"}, + {"key": "prefix-DEV", "value": "DEV"}, + {"key": "prefix-PROD", "value": "PROD"}, + ] with table.batch_writer() as batch: for item in items: batch.put_item(Item=item) - table = dynamodb.Table('test-table') + table = dynamodb.Table("test-table") with assert_raises(ClientError) as ex: - table.query(KeyConditionExpression=Key('key').begins_with('prefix-')) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.equal('Query key condition not supported') + table.query(KeyConditionExpression=Key("key").begins_with("prefix-")) + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal( + "Query key condition not supported" + ) @mock_dynamodb2 def test_update_supports_complex_expression_attribute_values(): - client = boto3.client('dynamodb') + client = boto3.client("dynamodb") - client.create_table(AttributeDefinitions=[{'AttributeName': 'SHA256', 'AttributeType': 'S'}], - TableName='TestTable', - KeySchema=[{'AttributeName': 'SHA256', 'KeyType': 'HASH'}], - ProvisionedThroughput={'ReadCapacityUnits': 5, 'WriteCapacityUnits': 5}) + client.create_table( + AttributeDefinitions=[{"AttributeName": "SHA256", "AttributeType": "S"}], + TableName="TestTable", + KeySchema=[{"AttributeName": "SHA256", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) - client.update_item(TableName='TestTable', - Key={'SHA256': {'S': 'sha-of-file'}}, - UpdateExpression=('SET MD5 = :md5,' - 'MyStringSet = :string_set,' - 'MyMap = :map' ), - ExpressionAttributeValues={':md5': {'S': 'md5-of-file'}, - ':string_set': {'SS': ['string1', 'string2']}, - ':map': {'M': {'EntryKey': {'SS': ['thing1', 'thing2']}}}}) - result = client.get_item(TableName='TestTable', Key={'SHA256': {'S': 'sha-of-file'}})['Item'] - result.should.equal({u'MyStringSet': {u'SS': [u'string1', u'string2']}, - 'MyMap': {u'M': {u'EntryKey': {u'SS': [u'thing1', u'thing2']}}}, - 'SHA256': {u'S': u'sha-of-file'}, - 'MD5': {u'S': u'md5-of-file'}}) + client.update_item( + TableName="TestTable", + Key={"SHA256": {"S": "sha-of-file"}}, + UpdateExpression=( + "SET MD5 = :md5," "MyStringSet = :string_set," "MyMap = :map" + ), + ExpressionAttributeValues={ + ":md5": {"S": "md5-of-file"}, + ":string_set": {"SS": ["string1", "string2"]}, + ":map": {"M": {"EntryKey": {"SS": ["thing1", "thing2"]}}}, + }, + ) + result = client.get_item( + TableName="TestTable", Key={"SHA256": {"S": "sha-of-file"}} + )["Item"] + result.should.equal( + { + "MyStringSet": {"SS": ["string1", "string2"]}, + "MyMap": {"M": {"EntryKey": {"SS": ["thing1", "thing2"]}}}, + "SHA256": {"S": "sha-of-file"}, + "MD5": {"S": "md5-of-file"}, + } + ) @mock_dynamodb2 def test_update_supports_list_append(): - client = boto3.client('dynamodb') + client = boto3.client("dynamodb") - client.create_table(AttributeDefinitions=[{'AttributeName': 'SHA256', 'AttributeType': 'S'}], - TableName='TestTable', - KeySchema=[{'AttributeName': 'SHA256', 'KeyType': 'HASH'}], - ProvisionedThroughput={'ReadCapacityUnits': 5, 'WriteCapacityUnits': 5}) - client.put_item(TableName='TestTable', - Item={'SHA256': {'S': 'sha-of-file'}, 'crontab': {'L': [{'S': 'bar1'}]}}) + client.create_table( + AttributeDefinitions=[{"AttributeName": "SHA256", "AttributeType": "S"}], + TableName="TestTable", + KeySchema=[{"AttributeName": "SHA256", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + client.put_item( + TableName="TestTable", + Item={"SHA256": {"S": "sha-of-file"}, "crontab": {"L": [{"S": "bar1"}]}}, + ) # Update item using list_append expression - client.update_item(TableName='TestTable', - Key={'SHA256': {'S': 'sha-of-file'}}, - UpdateExpression="SET crontab = list_append(crontab, :i)", - ExpressionAttributeValues={':i': {'L': [{'S': 'bar2'}]}}) + client.update_item( + TableName="TestTable", + Key={"SHA256": {"S": "sha-of-file"}}, + UpdateExpression="SET crontab = list_append(crontab, :i)", + ExpressionAttributeValues={":i": {"L": [{"S": "bar2"}]}}, + ) # Verify item is appended to the existing list - result = client.get_item(TableName='TestTable', Key={'SHA256': {'S': 'sha-of-file'}})['Item'] - result.should.equal({'SHA256': {'S': 'sha-of-file'}, - 'crontab': {'L': [{'S': 'bar1'}, {'S': 'bar2'}]}}) + result = client.get_item( + TableName="TestTable", Key={"SHA256": {"S": "sha-of-file"}} + )["Item"] + result.should.equal( + { + "SHA256": {"S": "sha-of-file"}, + "crontab": {"L": [{"S": "bar1"}, {"S": "bar2"}]}, + } + ) @mock_dynamodb2 def test_update_catches_invalid_list_append_operation(): - client = boto3.client('dynamodb') + client = boto3.client("dynamodb") - client.create_table(AttributeDefinitions=[{'AttributeName': 'SHA256', 'AttributeType': 'S'}], - TableName='TestTable', - KeySchema=[{'AttributeName': 'SHA256', 'KeyType': 'HASH'}], - ProvisionedThroughput={'ReadCapacityUnits': 5, 'WriteCapacityUnits': 5}) - client.put_item(TableName='TestTable', - Item={'SHA256': {'S': 'sha-of-file'}, 'crontab': {'L': [{'S': 'bar1'}]}}) + client.create_table( + AttributeDefinitions=[{"AttributeName": "SHA256", "AttributeType": "S"}], + TableName="TestTable", + KeySchema=[{"AttributeName": "SHA256", "KeyType": "HASH"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + client.put_item( + TableName="TestTable", + Item={"SHA256": {"S": "sha-of-file"}, "crontab": {"L": [{"S": "bar1"}]}}, + ) # Update item using invalid list_append expression with assert_raises(ParamValidationError) as ex: - client.update_item(TableName='TestTable', - Key={'SHA256': {'S': 'sha-of-file'}}, - UpdateExpression="SET crontab = list_append(crontab, :i)", - ExpressionAttributeValues={':i': [{'S': 'bar2'}]}) + client.update_item( + TableName="TestTable", + Key={"SHA256": {"S": "sha-of-file"}}, + UpdateExpression="SET crontab = list_append(crontab, :i)", + ExpressionAttributeValues={":i": [{"S": "bar2"}]}, + ) # Verify correct error is returned str(ex.exception).should.match("Parameter validation failed:") - str(ex.exception).should.match("Invalid type for parameter ExpressionAttributeValues.") + str(ex.exception).should.match( + "Invalid type for parameter ExpressionAttributeValues." + ) def _create_user_table(): - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") client.create_table( - TableName='users', - KeySchema=[{'AttributeName': 'username', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'username', 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 5, 'WriteCapacityUnits': 5} + TableName="users", + KeySchema=[{"AttributeName": "username", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "username", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, + ) + client.put_item( + TableName="users", Item={"username": {"S": "user1"}, "foo": {"S": "bar"}} + ) + client.put_item( + TableName="users", Item={"username": {"S": "user2"}, "foo": {"S": "bar"}} + ) + client.put_item( + TableName="users", Item={"username": {"S": "user3"}, "foo": {"S": "bar"}} ) - client.put_item(TableName='users', Item={'username': {'S': 'user1'}, 'foo': {'S': 'bar'}}) - client.put_item(TableName='users', Item={'username': {'S': 'user2'}, 'foo': {'S': 'bar'}}) - client.put_item(TableName='users', Item={'username': {'S': 'user3'}, 'foo': {'S': 'bar'}}) return client diff --git a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py index 7d1975eda..b12b41ac0 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_with_range_key.py @@ -11,6 +11,7 @@ from freezegun import freeze_time from moto import mock_dynamodb2, mock_dynamodb2_deprecated from boto.exception import JSONResponseError from tests.helpers import requires_boto_gte + try: from boto.dynamodb2.fields import GlobalAllIndex, HashKey, RangeKey, AllIndex from boto.dynamodb2.table import Item, Table @@ -22,36 +23,28 @@ except ImportError: def create_table(): - table = Table.create('messages', schema=[ - HashKey('forum_name'), - RangeKey('subject'), - ], throughput={ - 'read': 10, - 'write': 10, - }) + table = Table.create( + "messages", + schema=[HashKey("forum_name"), RangeKey("subject")], + throughput={"read": 10, "write": 10}, + ) return table def create_table_with_local_indexes(): table = Table.create( - 'messages', - schema=[ - HashKey('forum_name'), - RangeKey('subject'), - ], - throughput={ - 'read': 10, - 'write': 10, - }, + "messages", + schema=[HashKey("forum_name"), RangeKey("subject")], + throughput={"read": 10, "write": 10}, indexes=[ AllIndex( - 'threads_index', + "threads_index", parts=[ - HashKey('forum_name', data_type=STRING), - RangeKey('threads', data_type=NUMBER), - ] + HashKey("forum_name", data_type=STRING), + RangeKey("threads", data_type=NUMBER), + ], ) - ] + ], ) return table @@ -67,25 +60,28 @@ def iterate_results(res): def test_create_table(): table = create_table() expected = { - 'Table': { - 'AttributeDefinitions': [ - {'AttributeName': 'forum_name', 'AttributeType': 'S'}, - {'AttributeName': 'subject', 'AttributeType': 'S'} + "Table": { + "AttributeDefinitions": [ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - 'ProvisionedThroughput': { - 'NumberOfDecreasesToday': 0, 'WriteCapacityUnits': 10, 'ReadCapacityUnits': 10 + "ProvisionedThroughput": { + "NumberOfDecreasesToday": 0, + "WriteCapacityUnits": 10, + "ReadCapacityUnits": 10, }, - 'TableSizeBytes': 0, - 'TableName': 'messages', - 'TableStatus': 'ACTIVE', - 'TableArn': 'arn:aws:dynamodb:us-east-1:123456789011:table/messages', - 'KeySchema': [ - {'KeyType': 'HASH', 'AttributeName': 'forum_name'}, - {'KeyType': 'RANGE', 'AttributeName': 'subject'} + "TableSizeBytes": 0, + "TableName": "messages", + "TableStatus": "ACTIVE", + "TableArn": "arn:aws:dynamodb:us-east-1:123456789011:table/messages", + "KeySchema": [ + {"KeyType": "HASH", "AttributeName": "forum_name"}, + {"KeyType": "RANGE", "AttributeName": "subject"}, ], - 'LocalSecondaryIndexes': [], - 'ItemCount': 0, 'CreationDateTime': 1326499200.0, - 'GlobalSecondaryIndexes': [] + "LocalSecondaryIndexes": [], + "ItemCount": 0, + "CreationDateTime": 1326499200.0, + "GlobalSecondaryIndexes": [], } } table.describe().should.equal(expected) @@ -97,38 +93,38 @@ def test_create_table(): def test_create_table_with_local_index(): table = create_table_with_local_indexes() expected = { - 'Table': { - 'AttributeDefinitions': [ - {'AttributeName': 'forum_name', 'AttributeType': 'S'}, - {'AttributeName': 'subject', 'AttributeType': 'S'}, - {'AttributeName': 'threads', 'AttributeType': 'N'} + "Table": { + "AttributeDefinitions": [ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + {"AttributeName": "threads", "AttributeType": "N"}, ], - 'ProvisionedThroughput': { - 'NumberOfDecreasesToday': 0, - 'WriteCapacityUnits': 10, - 'ReadCapacityUnits': 10, + "ProvisionedThroughput": { + "NumberOfDecreasesToday": 0, + "WriteCapacityUnits": 10, + "ReadCapacityUnits": 10, }, - 'TableSizeBytes': 0, - 'TableName': 'messages', - 'TableStatus': 'ACTIVE', - 'TableArn': 'arn:aws:dynamodb:us-east-1:123456789011:table/messages', - 'KeySchema': [ - {'KeyType': 'HASH', 'AttributeName': 'forum_name'}, - {'KeyType': 'RANGE', 'AttributeName': 'subject'} + "TableSizeBytes": 0, + "TableName": "messages", + "TableStatus": "ACTIVE", + "TableArn": "arn:aws:dynamodb:us-east-1:123456789011:table/messages", + "KeySchema": [ + {"KeyType": "HASH", "AttributeName": "forum_name"}, + {"KeyType": "RANGE", "AttributeName": "subject"}, ], - 'LocalSecondaryIndexes': [ + "LocalSecondaryIndexes": [ { - 'IndexName': 'threads_index', - 'KeySchema': [ - {'AttributeName': 'forum_name', 'KeyType': 'HASH'}, - {'AttributeName': 'threads', 'KeyType': 'RANGE'} + "IndexName": "threads_index", + "KeySchema": [ + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "threads", "KeyType": "RANGE"}, ], - 'Projection': {'ProjectionType': 'ALL'} + "Projection": {"ProjectionType": "ALL"}, } ], - 'ItemCount': 0, - 'CreationDateTime': 1326499200.0, - 'GlobalSecondaryIndexes': [] + "ItemCount": 0, + "CreationDateTime": 1326499200.0, + "GlobalSecondaryIndexes": [], } } table.describe().should.equal(expected) @@ -143,8 +139,7 @@ def test_delete_table(): table.delete() conn.list_tables()["TableNames"].should.have.length_of(0) - conn.delete_table.when.called_with( - 'messages').should.throw(JSONResponseError) + conn.delete_table.when.called_with("messages").should.throw(JSONResponseError) @requires_boto_gte("2.9") @@ -153,18 +148,12 @@ def test_update_table_throughput(): table = create_table() table.throughput["read"].should.equal(10) table.throughput["write"].should.equal(10) - table.update(throughput={ - 'read': 5, - 'write': 15, - }) + table.update(throughput={"read": 5, "write": 15}) table.throughput["read"].should.equal(5) table.throughput["write"].should.equal(15) - table.update(throughput={ - 'read': 5, - 'write': 6, - }) + table.update(throughput={"read": 5, "write": 6}) table.describe() @@ -176,44 +165,45 @@ def test_update_table_throughput(): @mock_dynamodb2_deprecated def test_item_add_and_describe_and_update(): table = create_table() - ok = table.put_item(data={ - 'forum_name': 'LOLCat Forum', - 'subject': 'Check this out!', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) + ok = table.put_item( + data={ + "forum_name": "LOLCat Forum", + "subject": "Check this out!", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } + ) ok.should.equal(True) - table.get_item(forum_name="LOLCat Forum", - subject='Check this out!').should_not.be.none + table.get_item( + forum_name="LOLCat Forum", subject="Check this out!" + ).should_not.be.none - returned_item = table.get_item( - forum_name='LOLCat Forum', - subject='Check this out!' + returned_item = table.get_item(forum_name="LOLCat Forum", subject="Check this out!") + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "subject": "Check this out!", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'subject': 'Check this out!', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) - returned_item['SentBy'] = 'User B' + returned_item["SentBy"] = "User B" returned_item.save(overwrite=True) - returned_item = table.get_item( - forum_name='LOLCat Forum', - subject='Check this out!' + returned_item = table.get_item(forum_name="LOLCat Forum", subject="Check this out!") + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "subject": "Check this out!", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'subject': 'Check this out!', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) @requires_boto_gte("2.9") @@ -222,40 +212,38 @@ def test_item_partial_save(): table = create_table() data = { - 'forum_name': 'LOLCat Forum', - 'subject': 'The LOLz', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', + "forum_name": "LOLCat Forum", + "subject": "The LOLz", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", } table.put_item(data=data) - returned_item = table.get_item( - forum_name="LOLCat Forum", subject='The LOLz') + returned_item = table.get_item(forum_name="LOLCat Forum", subject="The LOLz") - returned_item['SentBy'] = 'User B' + returned_item["SentBy"] = "User B" returned_item.partial_save() - returned_item = table.get_item( - forum_name='LOLCat Forum', - subject='The LOLz' + returned_item = table.get_item(forum_name="LOLCat Forum", subject="The LOLz") + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "subject": "The LOLz", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'subject': 'The LOLz', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - }) @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_item_put_without_table(): - table = Table('undeclared-table') + table = Table("undeclared-table") item_data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = Item(table, item_data) item.save.when.called_with().should.throw(JSONResponseError) @@ -266,36 +254,35 @@ def test_item_put_without_table(): def test_get_missing_item(): table = create_table() - table.get_item.when.called_with( - hash_key='tester', - range_key='other', - ).should.throw(ValidationException) + table.get_item.when.called_with(hash_key="tester", range_key="other").should.throw( + ValidationException + ) @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_get_item_with_undeclared_table(): - table = Table('undeclared-table') - table.get_item.when.called_with( - test_hash=3241526475).should.throw(JSONResponseError) + table = Table("undeclared-table") + table.get_item.when.called_with(test_hash=3241526475).should.throw( + JSONResponseError + ) @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_get_item_without_range_key(): - table = Table.create('messages', schema=[ - HashKey('test_hash'), - RangeKey('test_range'), - ], throughput={ - 'read': 10, - 'write': 10, - }) + table = Table.create( + "messages", + schema=[HashKey("test_hash"), RangeKey("test_range")], + throughput={"read": 10, "write": 10}, + ) hash_key = 3241526475 range_key = 1234567890987 - table.put_item(data={'test_hash': hash_key, 'test_range': range_key}) - table.get_item.when.called_with( - test_hash=hash_key).should.throw(ValidationException) + table.put_item(data={"test_hash": hash_key, "test_range": range_key}) + table.get_item.when.called_with(test_hash=hash_key).should.throw( + ValidationException + ) @requires_boto_gte("2.30.0") @@ -303,13 +290,13 @@ def test_get_item_without_range_key(): def test_delete_item(): table = create_table() item_data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = Item(table, item_data) - item['subject'] = 'Check this out!' + item["subject"] = "Check this out!" item.save() table.count().should.equal(1) @@ -326,10 +313,10 @@ def test_delete_item(): def test_delete_item_with_undeclared_table(): table = Table("undeclared-table") item_data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = Item(table, item_data) item.delete.when.called_with().should.throw(JSONResponseError) @@ -341,70 +328,65 @@ def test_query(): table = create_table() item_data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'subject': 'Check this out!' + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "subject": "Check this out!", } item = Item(table, item_data) item.save(overwrite=True) - item['forum_name'] = 'the-key' - item['subject'] = '456' + item["forum_name"] = "the-key" + item["subject"] = "456" item.save(overwrite=True) - item['forum_name'] = 'the-key' - item['subject'] = '123' + item["forum_name"] = "the-key" + item["subject"] = "123" item.save(overwrite=True) - item['forum_name'] = 'the-key' - item['subject'] = '789' + item["forum_name"] = "the-key" + item["subject"] = "789" item.save(overwrite=True) table.count().should.equal(4) - results = table.query_2(forum_name__eq='the-key', - subject__gt='1', consistent=True) + results = table.query_2(forum_name__eq="the-key", subject__gt="1", consistent=True) expected = ["123", "456", "789"] for index, item in enumerate(results): item["subject"].should.equal(expected[index]) - results = table.query_2(forum_name__eq="the-key", - subject__gt='1', reverse=True) + results = table.query_2(forum_name__eq="the-key", subject__gt="1", reverse=True) for index, item in enumerate(results): item["subject"].should.equal(expected[len(expected) - 1 - index]) - results = table.query_2(forum_name__eq='the-key', - subject__gt='1', consistent=True) + results = table.query_2(forum_name__eq="the-key", subject__gt="1", consistent=True) sum(1 for _ in results).should.equal(3) - results = table.query_2(forum_name__eq='the-key', - subject__gt='234', consistent=True) + results = table.query_2( + forum_name__eq="the-key", subject__gt="234", consistent=True + ) sum(1 for _ in results).should.equal(2) - results = table.query_2(forum_name__eq='the-key', subject__gt='9999') + results = table.query_2(forum_name__eq="the-key", subject__gt="9999") sum(1 for _ in results).should.equal(0) - results = table.query_2(forum_name__eq='the-key', subject__beginswith='12') + results = table.query_2(forum_name__eq="the-key", subject__beginswith="12") sum(1 for _ in results).should.equal(1) - results = table.query_2(forum_name__eq='the-key', subject__beginswith='7') + results = table.query_2(forum_name__eq="the-key", subject__beginswith="7") sum(1 for _ in results).should.equal(1) - results = table.query_2(forum_name__eq='the-key', - subject__between=['567', '890']) + results = table.query_2(forum_name__eq="the-key", subject__between=["567", "890"]) sum(1 for _ in results).should.equal(1) @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_query_with_undeclared_table(): - table = Table('undeclared') + table = Table("undeclared") results = table.query( - forum_name__eq='Amazon DynamoDB', - subject__beginswith='DynamoDB', - limit=1 + forum_name__eq="Amazon DynamoDB", subject__beginswith="DynamoDB", limit=1 ) iterate_results.when.called_with(results).should.throw(JSONResponseError) @@ -414,30 +396,30 @@ def test_query_with_undeclared_table(): def test_scan(): table = create_table() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item_data['forum_name'] = 'the-key' - item_data['subject'] = '456' + item_data["forum_name"] = "the-key" + item_data["subject"] = "456" item = Item(table, item_data) item.save() - item['forum_name'] = 'the-key' - item['subject'] = '123' + item["forum_name"] = "the-key" + item["subject"] = "123" item.save() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:09 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:09 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item_data['forum_name'] = 'the-key' - item_data['subject'] = '789' + item_data["forum_name"] = "the-key" + item_data["subject"] = "789" item = Item(table, item_data) item.save() @@ -445,10 +427,10 @@ def test_scan(): results = table.scan() sum(1 for _ in results).should.equal(3) - results = table.scan(SentBy__eq='User B') + results = table.scan(SentBy__eq="User B") sum(1 for _ in results).should.equal(1) - results = table.scan(Body__beginswith='http') + results = table.scan(Body__beginswith="http") sum(1 for _ in results).should.equal(3) results = table.scan(Ids__null=False) @@ -469,13 +451,11 @@ def test_scan(): def test_scan_with_undeclared_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.scan.when.called_with( - table_name='undeclared-table', + table_name="undeclared-table", scan_filter={ "SentBy": { - "AttributeValueList": [{ - "S": "User B"} - ], - "ComparisonOperator": "EQ" + "AttributeValueList": [{"S": "User B"}], + "ComparisonOperator": "EQ", } }, ).should.throw(JSONResponseError) @@ -486,27 +466,28 @@ def test_scan_with_undeclared_table(): def test_write_batch(): table = create_table() with table.batch_write() as batch: - batch.put_item(data={ - 'forum_name': 'the-key', - 'subject': '123', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) - batch.put_item(data={ - 'forum_name': 'the-key', - 'subject': '789', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) + batch.put_item( + data={ + "forum_name": "the-key", + "subject": "123", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } + ) + batch.put_item( + data={ + "forum_name": "the-key", + "subject": "789", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } + ) table.count().should.equal(2) with table.batch_write() as batch: - batch.delete_item( - forum_name='the-key', - subject='789' - ) + batch.delete_item(forum_name="the-key", subject="789") table.count().should.equal(1) @@ -516,37 +497,37 @@ def test_write_batch(): def test_batch_read(): table = create_table() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item_data['forum_name'] = 'the-key' - item_data['subject'] = '456' + item_data["forum_name"] = "the-key" + item_data["subject"] = "456" item = Item(table, item_data) item.save() item = Item(table, item_data) - item_data['forum_name'] = 'the-key' - item_data['subject'] = '123' + item_data["forum_name"] = "the-key" + item_data["subject"] = "123" item.save() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } item = Item(table, item_data) - item_data['forum_name'] = 'another-key' - item_data['subject'] = '789' + item_data["forum_name"] = "another-key" + item_data["subject"] = "789" item.save() results = table.batch_get( keys=[ - {'forum_name': 'the-key', 'subject': '123'}, - {'forum_name': 'another-key', 'subject': '789'}, + {"forum_name": "the-key", "subject": "123"}, + {"forum_name": "another-key", "subject": "789"}, ] ) @@ -560,95 +541,76 @@ def test_batch_read(): def test_get_key_fields(): table = create_table() kf = table.get_key_fields() - kf.should.equal(['forum_name', 'subject']) + kf.should.equal(["forum_name", "subject"]) @mock_dynamodb2_deprecated def test_create_with_global_indexes(): conn = boto.dynamodb2.layer1.DynamoDBConnection() - Table.create('messages', schema=[ - HashKey('subject'), - RangeKey('version'), - ], global_indexes=[ - GlobalAllIndex('topic-created_at-index', - parts=[ - HashKey('topic'), - RangeKey('created_at', data_type='N') - ], - throughput={ - 'read': 6, - 'write': 1 - } - ), - ]) + Table.create( + "messages", + schema=[HashKey("subject"), RangeKey("version")], + global_indexes=[ + GlobalAllIndex( + "topic-created_at-index", + parts=[HashKey("topic"), RangeKey("created_at", data_type="N")], + throughput={"read": 6, "write": 1}, + ) + ], + ) table_description = conn.describe_table("messages") - table_description['Table']["GlobalSecondaryIndexes"].should.equal([ - { - "IndexName": "topic-created_at-index", - "KeySchema": [ - { - "AttributeName": "topic", - "KeyType": "HASH" + table_description["Table"]["GlobalSecondaryIndexes"].should.equal( + [ + { + "IndexName": "topic-created_at-index", + "KeySchema": [ + {"AttributeName": "topic", "KeyType": "HASH"}, + {"AttributeName": "created_at", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 6, + "WriteCapacityUnits": 1, }, - { - "AttributeName": "created_at", - "KeyType": "RANGE" - }, - ], - "Projection": { - "ProjectionType": "ALL" - }, - "ProvisionedThroughput": { - "ReadCapacityUnits": 6, - "WriteCapacityUnits": 1, } - } - ]) + ] + ) @mock_dynamodb2_deprecated def test_query_with_global_indexes(): - table = Table.create('messages', schema=[ - HashKey('subject'), - RangeKey('version'), - ], global_indexes=[ - GlobalAllIndex('topic-created_at-index', - parts=[ - HashKey('topic'), - RangeKey('created_at', data_type='N') - ], - throughput={ - 'read': 6, - 'write': 1 - } - ), - GlobalAllIndex('status-created_at-index', - parts=[ - HashKey('status'), - RangeKey('created_at', data_type='N') - ], - throughput={ - 'read': 2, - 'write': 1 - } - ) - ]) + table = Table.create( + "messages", + schema=[HashKey("subject"), RangeKey("version")], + global_indexes=[ + GlobalAllIndex( + "topic-created_at-index", + parts=[HashKey("topic"), RangeKey("created_at", data_type="N")], + throughput={"read": 6, "write": 1}, + ), + GlobalAllIndex( + "status-created_at-index", + parts=[HashKey("status"), RangeKey("created_at", data_type="N")], + throughput={"read": 2, "write": 1}, + ), + ], + ) item_data = { - 'subject': 'Check this out!', - 'version': '1', - 'created_at': 0, - 'status': 'inactive' + "subject": "Check this out!", + "version": "1", + "created_at": 0, + "status": "inactive", } item = Item(table, item_data) item.save(overwrite=True) - item['version'] = '2' + item["version"] = "2" item.save(overwrite=True) - results = table.query(status__eq='active') + results = table.query(status__eq="active") list(results).should.have.length_of(0) @@ -656,19 +618,20 @@ def test_query_with_global_indexes(): def test_query_with_local_indexes(): table = create_table_with_local_indexes() item_data = { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, - 'status': 'inactive' + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, + "status": "inactive", } item = Item(table, item_data) item.save(overwrite=True) - item['version'] = '2' + item["version"] = "2" item.save(overwrite=True) - results = table.query(forum_name__eq='Cool Forum', - index='threads_index', threads__eq=1) + results = table.query( + forum_name__eq="Cool Forum", index="threads_index", threads__eq=1 + ) list(results).should.have.length_of(1) @@ -678,29 +641,29 @@ def test_query_filter_eq(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '1', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "1", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '1', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "1", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) results = table.query_2( - forum_name__eq='Cool Forum', index='threads_index', threads__eq=5 + forum_name__eq="Cool Forum", index="threads_index", threads__eq=5 ) list(results).should.have.length_of(1) @@ -711,30 +674,30 @@ def test_query_filter_lt(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '1', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "1", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '1', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "1", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) results = table.query( - forum_name__eq='Cool Forum', index='threads_index', threads__lt=5 + forum_name__eq="Cool Forum", index="threads_index", threads__lt=5 ) results = list(results) results.should.have.length_of(2) @@ -746,30 +709,30 @@ def test_query_filter_gt(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '1', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "1", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '1', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "1", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) results = table.query( - forum_name__eq='Cool Forum', index='threads_index', threads__gt=1 + forum_name__eq="Cool Forum", index="threads_index", threads__gt=1 ) list(results).should.have.length_of(1) @@ -780,30 +743,30 @@ def test_query_filter_lte(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '1', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "1", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '1', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "1", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) results = table.query( - forum_name__eq='Cool Forum', index='threads_index', threads__lte=5 + forum_name__eq="Cool Forum", index="threads_index", threads__lte=5 ) list(results).should.have.length_of(3) @@ -814,30 +777,30 @@ def test_query_filter_gte(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '1', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "1", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '1', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "1", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) results = table.query( - forum_name__eq='Cool Forum', index='threads_index', threads__gte=1 + forum_name__eq="Cool Forum", index="threads_index", threads__gte=1 ) list(results).should.have.length_of(2) @@ -848,37 +811,33 @@ def test_query_non_hash_range_key(): table = create_table_with_local_indexes() item_data = [ { - 'forum_name': 'Cool Forum', - 'subject': 'Check this out!', - 'version': '1', - 'threads': 1, + "forum_name": "Cool Forum", + "subject": "Check this out!", + "version": "1", + "threads": 1, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Read this now!', - 'version': '3', - 'threads': 5, + "forum_name": "Cool Forum", + "subject": "Read this now!", + "version": "3", + "threads": 5, }, { - 'forum_name': 'Cool Forum', - 'subject': 'Please read this... please', - 'version': '2', - 'threads': 0, - } + "forum_name": "Cool Forum", + "subject": "Please read this... please", + "version": "2", + "threads": 0, + }, ] for data in item_data: item = Item(table, data) item.save(overwrite=True) - results = table.query( - forum_name__eq='Cool Forum', version__gt="2" - ) + results = table.query(forum_name__eq="Cool Forum", version__gt="2") results = list(results) results.should.have.length_of(1) - results = table.query( - forum_name__eq='Cool Forum', version__lt="3" - ) + results = table.query(forum_name__eq="Cool Forum", version__lt="3") results = list(results) results.should.have.length_of(2) @@ -887,94 +846,83 @@ def test_query_non_hash_range_key(): def test_reverse_query(): conn = boto.dynamodb2.layer1.DynamoDBConnection() - table = Table.create('messages', schema=[ - HashKey('subject'), - RangeKey('created_at', data_type='N') - ]) + table = Table.create( + "messages", schema=[HashKey("subject"), RangeKey("created_at", data_type="N")] + ) for i in range(10): - table.put_item({ - 'subject': "Hi", - 'created_at': i - }) + table.put_item({"subject": "Hi", "created_at": i}) - results = table.query_2(subject__eq="Hi", - created_at__lt=6, - limit=4, - reverse=True) + results = table.query_2(subject__eq="Hi", created_at__lt=6, limit=4, reverse=True) expected = [Decimal(5), Decimal(4), Decimal(3), Decimal(2)] - [r['created_at'] for r in results].should.equal(expected) + [r["created_at"] for r in results].should.equal(expected) @mock_dynamodb2_deprecated def test_lookup(): from decimal import Decimal - table = Table.create('messages', schema=[ - HashKey('test_hash'), - RangeKey('test_range'), - ], throughput={ - 'read': 10, - 'write': 10, - }) + + table = Table.create( + "messages", + schema=[HashKey("test_hash"), RangeKey("test_range")], + throughput={"read": 10, "write": 10}, + ) hash_key = 3241526475 range_key = 1234567890987 - data = {'test_hash': hash_key, 'test_range': range_key} + data = {"test_hash": hash_key, "test_range": range_key} table.put_item(data=data) message = table.lookup(hash_key, range_key) - message.get('test_hash').should.equal(Decimal(hash_key)) - message.get('test_range').should.equal(Decimal(range_key)) + message.get("test_hash").should.equal(Decimal(hash_key)) + message.get("test_range").should.equal(Decimal(range_key)) @mock_dynamodb2_deprecated def test_failed_overwrite(): - table = Table.create('messages', schema=[ - HashKey('id'), - RangeKey('range'), - ], throughput={ - 'read': 7, - 'write': 3, - }) + table = Table.create( + "messages", + schema=[HashKey("id"), RangeKey("range")], + throughput={"read": 7, "write": 3}, + ) - data1 = {'id': '123', 'range': 'abc', 'data': '678'} + data1 = {"id": "123", "range": "abc", "data": "678"} table.put_item(data=data1) - data2 = {'id': '123', 'range': 'abc', 'data': '345'} + data2 = {"id": "123", "range": "abc", "data": "345"} table.put_item(data=data2, overwrite=True) - data3 = {'id': '123', 'range': 'abc', 'data': '812'} + data3 = {"id": "123", "range": "abc", "data": "812"} table.put_item.when.called_with(data=data3).should.throw( - ConditionalCheckFailedException) + ConditionalCheckFailedException + ) - returned_item = table.lookup('123', 'abc') + returned_item = table.lookup("123", "abc") dict(returned_item).should.equal(data2) - data4 = {'id': '123', 'range': 'ghi', 'data': 812} + data4 = {"id": "123", "range": "ghi", "data": 812} table.put_item(data=data4) - returned_item = table.lookup('123', 'ghi') + returned_item = table.lookup("123", "ghi") dict(returned_item).should.equal(data4) @mock_dynamodb2_deprecated def test_conflicting_writes(): - table = Table.create('messages', schema=[ - HashKey('id'), - RangeKey('range'), - ]) + table = Table.create("messages", schema=[HashKey("id"), RangeKey("range")]) - item_data = {'id': '123', 'range': 'abc', 'data': '678'} + item_data = {"id": "123", "range": "abc", "data": "678"} item1 = Item(table, item_data) item2 = Item(table, item_data) item1.save() - item1['data'] = '579' - item2['data'] = '912' + item1["data"] = "579" + item2["data"] = "912" item1.save() item2.save.when.called_with().should.throw(ConditionalCheckFailedException) + """ boto3 """ @@ -982,464 +930,351 @@ boto3 @mock_dynamodb2 def test_boto3_conditions(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123' - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '456' - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '789' - }) + table.put_item(Item={"forum_name": "the-key", "subject": "123"}) + table.put_item(Item={"forum_name": "the-key", "subject": "456"}) + table.put_item(Item={"forum_name": "the-key", "subject": "789"}) # Test a query returning all items results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").gt('1'), + KeyConditionExpression=Key("forum_name").eq("the-key") & Key("subject").gt("1"), ScanIndexForward=True, ) expected = ["123", "456", "789"] - for index, item in enumerate(results['Items']): + for index, item in enumerate(results["Items"]): item["subject"].should.equal(expected[index]) # Return all items again, but in reverse results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").gt('1'), + KeyConditionExpression=Key("forum_name").eq("the-key") & Key("subject").gt("1"), ScanIndexForward=False, ) - for index, item in enumerate(reversed(results['Items'])): + for index, item in enumerate(reversed(results["Items"])): item["subject"].should.equal(expected[index]) # Filter the subjects to only return some of the results results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").gt('234'), + KeyConditionExpression=Key("forum_name").eq("the-key") + & Key("subject").gt("234"), ConsistentRead=True, ) - results['Count'].should.equal(2) + results["Count"].should.equal(2) # Filter to return no results results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").gt('9999') + KeyConditionExpression=Key("forum_name").eq("the-key") + & Key("subject").gt("9999") ) - results['Count'].should.equal(0) + results["Count"].should.equal(0) results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").begins_with('12') + KeyConditionExpression=Key("forum_name").eq("the-key") + & Key("subject").begins_with("12") ) - results['Count'].should.equal(1) + results["Count"].should.equal(1) results = table.query( - KeyConditionExpression=Key("subject").begins_with( - '7') & Key('forum_name').eq('the-key') + KeyConditionExpression=Key("subject").begins_with("7") + & Key("forum_name").eq("the-key") ) - results['Count'].should.equal(1) + results["Count"].should.equal(1) results = table.query( - KeyConditionExpression=Key('forum_name').eq( - 'the-key') & Key("subject").between('567', '890') + KeyConditionExpression=Key("forum_name").eq("the-key") + & Key("subject").between("567", "890") ) - results['Count'].should.equal(1) + results["Count"].should.equal(1) @mock_dynamodb2 def test_boto3_put_item_with_conditions(): import botocore - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123' - }) + table.put_item(Item={"forum_name": "the-key", "subject": "123"}) table.put_item( - Item={ - 'forum_name': 'the-key-2', - 'subject': '1234', - }, - ConditionExpression='attribute_not_exists(forum_name) AND attribute_not_exists(subject)' + Item={"forum_name": "the-key-2", "subject": "1234"}, + ConditionExpression="attribute_not_exists(forum_name) AND attribute_not_exists(subject)", ) table.put_item.when.called_with( - Item={ - 'forum_name': 'the-key', - 'subject': '123' - }, - ConditionExpression='attribute_not_exists(forum_name) AND attribute_not_exists(subject)' + Item={"forum_name": "the-key", "subject": "123"}, + ConditionExpression="attribute_not_exists(forum_name) AND attribute_not_exists(subject)", ).should.throw(botocore.exceptions.ClientError) table.put_item.when.called_with( - Item={ - 'forum_name': 'bogus-key', - 'subject': 'bogus', - 'test': '123' - }, - ConditionExpression='attribute_exists(forum_name) AND attribute_exists(subject)' + Item={"forum_name": "bogus-key", "subject": "bogus", "test": "123"}, + ConditionExpression="attribute_exists(forum_name) AND attribute_exists(subject)", ).should.throw(botocore.exceptions.ClientError) def _create_table_with_range_key(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], - GlobalSecondaryIndexes=[{ - 'IndexName': 'TestGSI', - 'KeySchema': [ - { - 'AttributeName': 'username', - 'KeyType': 'HASH', + GlobalSecondaryIndexes=[ + { + "IndexName": "TestGSI", + "KeySchema": [ + {"AttributeName": "username", "KeyType": "HASH"}, + {"AttributeName": "created", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 5, + "WriteCapacityUnits": 5, }, - { - 'AttributeName': 'created', - 'KeyType': 'RANGE', - } - ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } - }], - AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'username', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'created', - 'AttributeType': 'N' } ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + {"AttributeName": "username", "AttributeType": "S"}, + {"AttributeName": "created", "AttributeType": "N"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - return dynamodb.Table('users') + return dynamodb.Table("users") @mock_dynamodb2 def test_update_item_range_key_set(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'username': 'johndoe', - 'created': Decimal('3'), - }) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "username": "johndoe", + "created": Decimal("3"), + } + ) - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} table.update_item( Key=item_key, AttributeUpdates={ - 'username': { - 'Action': u'PUT', - 'Value': 'johndoe2' - }, - 'created': { - 'Action': u'PUT', - 'Value': Decimal('4'), - }, - 'mapfield': { - 'Action': u'PUT', - 'Value': {'key': 'value'}, - } + "username": {"Action": "PUT", "Value": "johndoe2"}, + "created": {"Action": "PUT", "Value": Decimal("4")}, + "mapfield": {"Action": "PUT", "Value": {"key": "value"}}, }, ) - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'username': "johndoe2", - 'forum_name': 'the-key', - 'subject': '123', - 'created': '4', - 'mapfield': {'key': 'value'}, - }) + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + { + "username": "johndoe2", + "forum_name": "the-key", + "subject": "123", + "created": "4", + "mapfield": {"key": "value"}, + } + ) @mock_dynamodb2 def test_update_item_does_not_exist_is_created(): table = _create_table_with_range_key() - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} result = table.update_item( Key=item_key, AttributeUpdates={ - 'username': { - 'Action': u'PUT', - 'Value': 'johndoe2' - }, - 'created': { - 'Action': u'PUT', - 'Value': Decimal('4'), - }, - 'mapfield': { - 'Action': u'PUT', - 'Value': {'key': 'value'}, - } + "username": {"Action": "PUT", "Value": "johndoe2"}, + "created": {"Action": "PUT", "Value": Decimal("4")}, + "mapfield": {"Action": "PUT", "Value": {"key": "value"}}, }, - ReturnValues='ALL_OLD', + ReturnValues="ALL_OLD", ) - assert not result.get('Attributes') + assert not result.get("Attributes") - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'username': "johndoe2", - 'forum_name': 'the-key', - 'subject': '123', - 'created': '4', - 'mapfield': {'key': 'value'}, - }) + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + { + "username": "johndoe2", + "forum_name": "the-key", + "subject": "123", + "created": "4", + "mapfield": {"key": "value"}, + } + ) @mock_dynamodb2 def test_update_item_add_value(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'numeric_field': Decimal('-1'), - }) - - item_key = {'forum_name': 'the-key', 'subject': '123'} - table.update_item( - Key=item_key, - AttributeUpdates={ - 'numeric_field': { - 'Action': u'ADD', - 'Value': Decimal('2'), - }, - }, + table.put_item( + Item={"forum_name": "the-key", "subject": "123", "numeric_field": Decimal("-1")} ) - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'numeric_field': '1', - 'forum_name': 'the-key', - 'subject': '123', - }) + item_key = {"forum_name": "the-key", "subject": "123"} + table.update_item( + Key=item_key, + AttributeUpdates={"numeric_field": {"Action": "ADD", "Value": Decimal("2")}}, + ) + + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + {"numeric_field": "1", "forum_name": "the-key", "subject": "123"} + ) @mock_dynamodb2 def test_update_item_add_value_string_set(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'string_set': set(['str1', 'str2']), - }) - - item_key = {'forum_name': 'the-key', 'subject': '123'} - table.update_item( - Key=item_key, - AttributeUpdates={ - 'string_set': { - 'Action': u'ADD', - 'Value': set(['str3']), - }, - }, + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "string_set": set(["str1", "str2"]), + } + ) + + item_key = {"forum_name": "the-key", "subject": "123"} + table.update_item( + Key=item_key, + AttributeUpdates={"string_set": {"Action": "ADD", "Value": set(["str3"])}}, + ) + + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + { + "string_set": set(["str1", "str2", "str3"]), + "forum_name": "the-key", + "subject": "123", + } ) - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'string_set': set(['str1', 'str2', 'str3']), - 'forum_name': 'the-key', - 'subject': '123', - }) @mock_dynamodb2 def test_update_item_delete_value_string_set(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'string_set': set(['str1', 'str2']), - }) - - item_key = {'forum_name': 'the-key', 'subject': '123'} - table.update_item( - Key=item_key, - AttributeUpdates={ - 'string_set': { - 'Action': u'DELETE', - 'Value': set(['str2']), - }, - }, + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "string_set": set(["str1", "str2"]), + } + ) + + item_key = {"forum_name": "the-key", "subject": "123"} + table.update_item( + Key=item_key, + AttributeUpdates={"string_set": {"Action": "DELETE", "Value": set(["str2"])}}, + ) + + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + {"string_set": set(["str1"]), "forum_name": "the-key", "subject": "123"} ) - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'string_set': set(['str1']), - 'forum_name': 'the-key', - 'subject': '123', - }) @mock_dynamodb2 def test_update_item_add_value_does_not_exist_is_created(): table = _create_table_with_range_key() - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} table.update_item( Key=item_key, - AttributeUpdates={ - 'numeric_field': { - 'Action': u'ADD', - 'Value': Decimal('2'), - }, - }, + AttributeUpdates={"numeric_field": {"Action": "ADD", "Value": Decimal("2")}}, ) - returned_item = dict((k, str(v) if isinstance(v, Decimal) else v) - for k, v in table.get_item(Key=item_key)['Item'].items()) - dict(returned_item).should.equal({ - 'numeric_field': '2', - 'forum_name': 'the-key', - 'subject': '123', - }) + returned_item = dict( + (k, str(v) if isinstance(v, Decimal) else v) + for k, v in table.get_item(Key=item_key)["Item"].items() + ) + dict(returned_item).should.equal( + {"numeric_field": "2", "forum_name": "the-key", "subject": "123"} + ) @mock_dynamodb2 def test_update_item_with_expression(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'field': '1' - }) + table.put_item(Item={"forum_name": "the-key", "subject": "123", "field": "1"}) - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} - table.update_item( - Key=item_key, - UpdateExpression='SET field=2', + table.update_item(Key=item_key, UpdateExpression="SET field=2") + dict(table.get_item(Key=item_key)["Item"]).should.equal( + {"field": "2", "forum_name": "the-key", "subject": "123"} ) - dict(table.get_item(Key=item_key)['Item']).should.equal({ - 'field': '2', - 'forum_name': 'the-key', - 'subject': '123', - }) - table.update_item( - Key=item_key, - UpdateExpression='SET field = 3', + table.update_item(Key=item_key, UpdateExpression="SET field = 3") + dict(table.get_item(Key=item_key)["Item"]).should.equal( + {"field": "3", "forum_name": "the-key", "subject": "123"} ) - dict(table.get_item(Key=item_key)['Item']).should.equal({ - 'field': '3', - 'forum_name': 'the-key', - 'subject': '123', - }) + @mock_dynamodb2 def test_update_item_add_with_expression(): table = _create_table_with_range_key() - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} current_item = { - 'forum_name': 'the-key', - 'subject': '123', - 'str_set': {'item1', 'item2', 'item3'}, - 'num_set': {1, 2, 3}, - 'num_val': 6 + "forum_name": "the-key", + "subject": "123", + "str_set": {"item1", "item2", "item3"}, + "num_set": {1, 2, 3}, + "num_val": 6, } # Put an entry in the DB to play with @@ -1448,69 +1283,56 @@ def test_update_item_add_with_expression(): # Update item to add a string value to a string set table.update_item( Key=item_key, - UpdateExpression='ADD str_set :v', - ExpressionAttributeValues={ - ':v': {'item4'} - } + UpdateExpression="ADD str_set :v", + ExpressionAttributeValues={":v": {"item4"}}, ) - current_item['str_set'] = current_item['str_set'].union({'item4'}) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + current_item["str_set"] = current_item["str_set"].union({"item4"}) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Update item to add a num value to a num set table.update_item( Key=item_key, - UpdateExpression='ADD num_set :v', - ExpressionAttributeValues={ - ':v': {6} - } + UpdateExpression="ADD num_set :v", + ExpressionAttributeValues={":v": {6}}, ) - current_item['num_set'] = current_item['num_set'].union({6}) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + current_item["num_set"] = current_item["num_set"].union({6}) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Update item to add a value to a number value table.update_item( Key=item_key, - UpdateExpression='ADD num_val :v', - ExpressionAttributeValues={ - ':v': 20 - } + UpdateExpression="ADD num_val :v", + ExpressionAttributeValues={":v": 20}, ) - current_item['num_val'] = current_item['num_val'] + 20 - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + current_item["num_val"] = current_item["num_val"] + 20 + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Attempt to add a number value to a string set, should raise Client Error table.update_item.when.called_with( Key=item_key, - UpdateExpression='ADD str_set :v', - ExpressionAttributeValues={ - ':v': 20 - } + UpdateExpression="ADD str_set :v", + ExpressionAttributeValues={":v": 20}, ).should.have.raised(ClientError) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Attempt to add a number set to the string set, should raise a ClientError table.update_item.when.called_with( Key=item_key, - UpdateExpression='ADD str_set :v', - ExpressionAttributeValues={ - ':v': { 20 } - } + UpdateExpression="ADD str_set :v", + ExpressionAttributeValues={":v": {20}}, ).should.have.raised(ClientError) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Attempt to update with a bad expression table.update_item.when.called_with( - Key=item_key, - UpdateExpression='ADD str_set bad_value' + Key=item_key, UpdateExpression="ADD str_set bad_value" ).should.have.raised(ClientError) # Attempt to add a string value instead of a string set table.update_item.when.called_with( Key=item_key, - UpdateExpression='ADD str_set :v', - ExpressionAttributeValues={ - ':v': 'new_string' - } + UpdateExpression="ADD str_set :v", + ExpressionAttributeValues={":v": "new_string"}, ).should.have.raised(ClientError) @@ -1518,13 +1340,13 @@ def test_update_item_add_with_expression(): def test_update_item_delete_with_expression(): table = _create_table_with_range_key() - item_key = {'forum_name': 'the-key', 'subject': '123'} + item_key = {"forum_name": "the-key", "subject": "123"} current_item = { - 'forum_name': 'the-key', - 'subject': '123', - 'str_set': {'item1', 'item2', 'item3'}, - 'num_set': {1, 2, 3}, - 'num_val': 6 + "forum_name": "the-key", + "subject": "123", + "str_set": {"item1", "item2", "item3"}, + "num_set": {1, 2, 3}, + "num_val": 6, } # Put an entry in the DB to play with @@ -1533,49 +1355,40 @@ def test_update_item_delete_with_expression(): # Update item to delete a string value from a string set table.update_item( Key=item_key, - UpdateExpression='DELETE str_set :v', - ExpressionAttributeValues={ - ':v': {'item2'} - } + UpdateExpression="DELETE str_set :v", + ExpressionAttributeValues={":v": {"item2"}}, ) - current_item['str_set'] = current_item['str_set'].difference({'item2'}) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + current_item["str_set"] = current_item["str_set"].difference({"item2"}) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Update item to delete a num value from a num set table.update_item( Key=item_key, - UpdateExpression='DELETE num_set :v', - ExpressionAttributeValues={ - ':v': {2} - } + UpdateExpression="DELETE num_set :v", + ExpressionAttributeValues={":v": {2}}, ) - current_item['num_set'] = current_item['num_set'].difference({2}) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + current_item["num_set"] = current_item["num_set"].difference({2}) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Try to delete on a number, this should fail table.update_item.when.called_with( Key=item_key, - UpdateExpression='DELETE num_val :v', - ExpressionAttributeValues={ - ':v': 20 - } + UpdateExpression="DELETE num_val :v", + ExpressionAttributeValues={":v": 20}, ).should.have.raised(ClientError) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Try to delete a string set from a number set table.update_item.when.called_with( Key=item_key, - UpdateExpression='DELETE num_set :v', - ExpressionAttributeValues={ - ':v': {'del_str'} - } + UpdateExpression="DELETE num_set :v", + ExpressionAttributeValues={":v": {"del_str"}}, ).should.have.raised(ClientError) - dict(table.get_item(Key=item_key)['Item']).should.equal(current_item) + dict(table.get_item(Key=item_key)["Item"]).should.equal(current_item) # Attempt to update with a bad expression table.update_item.when.called_with( - Key=item_key, - UpdateExpression='DELETE num_val badvalue' + Key=item_key, UpdateExpression="DELETE num_val badvalue" ).should.have.raised(ClientError) @@ -1583,394 +1396,309 @@ def test_update_item_delete_with_expression(): def test_boto3_query_gsi_range_comparison(): table = _create_table_with_range_key() - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '123', - 'username': 'johndoe', - 'created': 3, - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '456', - 'username': 'johndoe', - 'created': 1, - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '789', - 'username': 'johndoe', - 'created': 2, - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '159', - 'username': 'janedoe', - 'created': 2, - }) - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '601', - 'username': 'janedoe', - 'created': 5, - }) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "123", + "username": "johndoe", + "created": 3, + } + ) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "456", + "username": "johndoe", + "created": 1, + } + ) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "789", + "username": "johndoe", + "created": 2, + } + ) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "159", + "username": "janedoe", + "created": 2, + } + ) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "601", + "username": "janedoe", + "created": 5, + } + ) # Test a query returning all johndoe items results = table.query( - KeyConditionExpression=Key('username').eq( - 'johndoe') & Key("created").gt(0), + KeyConditionExpression=Key("username").eq("johndoe") & Key("created").gt(0), ScanIndexForward=True, - IndexName='TestGSI', + IndexName="TestGSI", ) expected = ["456", "789", "123"] - for index, item in enumerate(results['Items']): + for index, item in enumerate(results["Items"]): item["subject"].should.equal(expected[index]) # Return all johndoe items again, but in reverse results = table.query( - KeyConditionExpression=Key('username').eq( - 'johndoe') & Key("created").gt(0), + KeyConditionExpression=Key("username").eq("johndoe") & Key("created").gt(0), ScanIndexForward=False, - IndexName='TestGSI', + IndexName="TestGSI", ) - for index, item in enumerate(reversed(results['Items'])): + for index, item in enumerate(reversed(results["Items"])): item["subject"].should.equal(expected[index]) # Filter the creation to only return some of the results # And reverse order of hash + range key results = table.query( - KeyConditionExpression=Key("created").gt( - 1) & Key('username').eq('johndoe'), + KeyConditionExpression=Key("created").gt(1) & Key("username").eq("johndoe"), ConsistentRead=True, - IndexName='TestGSI', + IndexName="TestGSI", ) - results['Count'].should.equal(2) + results["Count"].should.equal(2) # Filter to return no results results = table.query( - KeyConditionExpression=Key('username').eq( - 'janedoe') & Key("created").gt(9), - IndexName='TestGSI', + KeyConditionExpression=Key("username").eq("janedoe") & Key("created").gt(9), + IndexName="TestGSI", ) - results['Count'].should.equal(0) + results["Count"].should.equal(0) results = table.query( - KeyConditionExpression=Key('username').eq( - 'janedoe') & Key("created").eq(5), - IndexName='TestGSI', + KeyConditionExpression=Key("username").eq("janedoe") & Key("created").eq(5), + IndexName="TestGSI", ) - results['Count'].should.equal(1) + results["Count"].should.equal(1) # Test range key sorting results = table.query( - KeyConditionExpression=Key('username').eq( - 'johndoe') & Key("created").gt(0), - IndexName='TestGSI', + KeyConditionExpression=Key("username").eq("johndoe") & Key("created").gt(0), + IndexName="TestGSI", ) - expected = [Decimal('1'), Decimal('2'), Decimal('3')] - for index, item in enumerate(results['Items']): + expected = [Decimal("1"), Decimal("2"), Decimal("3")] + for index, item in enumerate(results["Items"]): item["created"].should.equal(expected[index]) @mock_dynamodb2 def test_boto3_update_table_throughput(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 6 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 6}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.provisioned_throughput['ReadCapacityUnits'].should.equal(5) - table.provisioned_throughput['WriteCapacityUnits'].should.equal(6) + table.provisioned_throughput["ReadCapacityUnits"].should.equal(5) + table.provisioned_throughput["WriteCapacityUnits"].should.equal(6) - table.update(ProvisionedThroughput={ - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 11, - }) + table.update( + ProvisionedThroughput={"ReadCapacityUnits": 10, "WriteCapacityUnits": 11} + ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - table.provisioned_throughput['ReadCapacityUnits'].should.equal(10) - table.provisioned_throughput['WriteCapacityUnits'].should.equal(11) + table.provisioned_throughput["ReadCapacityUnits"].should.equal(10) + table.provisioned_throughput["WriteCapacityUnits"].should.equal(11) @mock_dynamodb2 def test_boto3_update_table_gsi_throughput(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], - GlobalSecondaryIndexes=[{ - 'IndexName': 'TestGSI', - 'KeySchema': [ - { - 'AttributeName': 'username', - 'KeyType': 'HASH', + GlobalSecondaryIndexes=[ + { + "IndexName": "TestGSI", + "KeySchema": [ + {"AttributeName": "username", "KeyType": "HASH"}, + {"AttributeName": "created", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 3, + "WriteCapacityUnits": 4, }, - { - 'AttributeName': 'created', - 'KeyType': 'RANGE', - } - ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 3, - 'WriteCapacityUnits': 4 - } - }], - AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'username', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'created', - 'AttributeType': 'S' } ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 6 - } + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + {"AttributeName": "username", "AttributeType": "S"}, + {"AttributeName": "created", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 6}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") - gsi_throughput = table.global_secondary_indexes[0]['ProvisionedThroughput'] - gsi_throughput['ReadCapacityUnits'].should.equal(3) - gsi_throughput['WriteCapacityUnits'].should.equal(4) + gsi_throughput = table.global_secondary_indexes[0]["ProvisionedThroughput"] + gsi_throughput["ReadCapacityUnits"].should.equal(3) + gsi_throughput["WriteCapacityUnits"].should.equal(4) - table.provisioned_throughput['ReadCapacityUnits'].should.equal(5) - table.provisioned_throughput['WriteCapacityUnits'].should.equal(6) + table.provisioned_throughput["ReadCapacityUnits"].should.equal(5) + table.provisioned_throughput["WriteCapacityUnits"].should.equal(6) - table.update(GlobalSecondaryIndexUpdates=[{ - 'Update': { - 'IndexName': 'TestGSI', - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 11, + table.update( + GlobalSecondaryIndexUpdates=[ + { + "Update": { + "IndexName": "TestGSI", + "ProvisionedThroughput": { + "ReadCapacityUnits": 10, + "WriteCapacityUnits": 11, + }, + } } - }, - }]) + ] + ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") # Primary throughput has not changed - table.provisioned_throughput['ReadCapacityUnits'].should.equal(5) - table.provisioned_throughput['WriteCapacityUnits'].should.equal(6) + table.provisioned_throughput["ReadCapacityUnits"].should.equal(5) + table.provisioned_throughput["WriteCapacityUnits"].should.equal(6) - gsi_throughput = table.global_secondary_indexes[0]['ProvisionedThroughput'] - gsi_throughput['ReadCapacityUnits'].should.equal(10) - gsi_throughput['WriteCapacityUnits'].should.equal(11) + gsi_throughput = table.global_secondary_indexes[0]["ProvisionedThroughput"] + gsi_throughput["ReadCapacityUnits"].should.equal(10) + gsi_throughput["WriteCapacityUnits"].should.equal(11) @mock_dynamodb2 def test_update_table_gsi_create(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 6 - } + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 6}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.global_secondary_indexes.should.have.length_of(0) - table.update(GlobalSecondaryIndexUpdates=[{ - 'Create': { - 'IndexName': 'TestGSI', - 'KeySchema': [ - { - 'AttributeName': 'username', - 'KeyType': 'HASH', - }, - { - 'AttributeName': 'created', - 'KeyType': 'RANGE', + table.update( + GlobalSecondaryIndexUpdates=[ + { + "Create": { + "IndexName": "TestGSI", + "KeySchema": [ + {"AttributeName": "username", "KeyType": "HASH"}, + {"AttributeName": "created", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 3, + "WriteCapacityUnits": 4, + }, } - ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 3, - 'WriteCapacityUnits': 4 } - }, - }]) + ] + ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.global_secondary_indexes.should.have.length_of(1) - gsi_throughput = table.global_secondary_indexes[0]['ProvisionedThroughput'] - assert gsi_throughput['ReadCapacityUnits'].should.equal(3) - assert gsi_throughput['WriteCapacityUnits'].should.equal(4) + gsi_throughput = table.global_secondary_indexes[0]["ProvisionedThroughput"] + assert gsi_throughput["ReadCapacityUnits"].should.equal(3) + assert gsi_throughput["WriteCapacityUnits"].should.equal(4) # Check update works - table.update(GlobalSecondaryIndexUpdates=[{ - 'Update': { - 'IndexName': 'TestGSI', - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 10, - 'WriteCapacityUnits': 11, + table.update( + GlobalSecondaryIndexUpdates=[ + { + "Update": { + "IndexName": "TestGSI", + "ProvisionedThroughput": { + "ReadCapacityUnits": 10, + "WriteCapacityUnits": 11, + }, + } } - }, - }]) - table = dynamodb.Table('users') + ] + ) + table = dynamodb.Table("users") - gsi_throughput = table.global_secondary_indexes[0]['ProvisionedThroughput'] - assert gsi_throughput['ReadCapacityUnits'].should.equal(10) - assert gsi_throughput['WriteCapacityUnits'].should.equal(11) + gsi_throughput = table.global_secondary_indexes[0]["ProvisionedThroughput"] + assert gsi_throughput["ReadCapacityUnits"].should.equal(10) + assert gsi_throughput["WriteCapacityUnits"].should.equal(11) - table.update(GlobalSecondaryIndexUpdates=[{ - 'Delete': { - 'IndexName': 'TestGSI', - }, - }]) + table.update(GlobalSecondaryIndexUpdates=[{"Delete": {"IndexName": "TestGSI"}}]) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.global_secondary_indexes.should.have.length_of(0) @mock_dynamodb2 def test_update_table_gsi_throughput(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. table = dynamodb.create_table( - TableName='users', + TableName="users", KeySchema=[ - { - 'AttributeName': 'forum_name', - 'KeyType': 'HASH' - }, - { - 'AttributeName': 'subject', - 'KeyType': 'RANGE' - }, + {"AttributeName": "forum_name", "KeyType": "HASH"}, + {"AttributeName": "subject", "KeyType": "RANGE"}, ], - GlobalSecondaryIndexes=[{ - 'IndexName': 'TestGSI', - 'KeySchema': [ - { - 'AttributeName': 'username', - 'KeyType': 'HASH', + GlobalSecondaryIndexes=[ + { + "IndexName": "TestGSI", + "KeySchema": [ + {"AttributeName": "username", "KeyType": "HASH"}, + {"AttributeName": "created", "KeyType": "RANGE"}, + ], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 3, + "WriteCapacityUnits": 4, }, - { - 'AttributeName': 'created', - 'KeyType': 'RANGE', - } - ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 3, - 'WriteCapacityUnits': 4 - } - }], - AttributeDefinitions=[ - { - 'AttributeName': 'forum_name', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'subject', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'username', - 'AttributeType': 'S' - }, - { - 'AttributeName': 'created', - 'AttributeType': 'S' } ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 6 - } + AttributeDefinitions=[ + {"AttributeName": "forum_name", "AttributeType": "S"}, + {"AttributeName": "subject", "AttributeType": "S"}, + {"AttributeName": "username", "AttributeType": "S"}, + {"AttributeName": "created", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 6}, ) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.global_secondary_indexes.should.have.length_of(1) - table.update(GlobalSecondaryIndexUpdates=[{ - 'Delete': { - 'IndexName': 'TestGSI', - }, - }]) + table.update(GlobalSecondaryIndexUpdates=[{"Delete": {"IndexName": "TestGSI"}}]) - table = dynamodb.Table('users') + table = dynamodb.Table("users") table.global_secondary_indexes.should.have.length_of(0) @@ -1978,140 +1706,131 @@ def test_update_table_gsi_throughput(): def test_query_pagination(): table = _create_table_with_range_key() for i in range(10): - table.put_item(Item={ - 'forum_name': 'the-key', - 'subject': '{0}'.format(i), - 'username': 'johndoe', - 'created': Decimal('3'), - }) + table.put_item( + Item={ + "forum_name": "the-key", + "subject": "{0}".format(i), + "username": "johndoe", + "created": Decimal("3"), + } + ) - page1 = table.query( - KeyConditionExpression=Key('forum_name').eq('the-key'), - Limit=6 - ) - page1['Count'].should.equal(6) - page1['Items'].should.have.length_of(6) - page1.should.have.key('LastEvaluatedKey') + page1 = table.query(KeyConditionExpression=Key("forum_name").eq("the-key"), Limit=6) + page1["Count"].should.equal(6) + page1["Items"].should.have.length_of(6) + page1.should.have.key("LastEvaluatedKey") page2 = table.query( - KeyConditionExpression=Key('forum_name').eq('the-key'), + KeyConditionExpression=Key("forum_name").eq("the-key"), Limit=6, - ExclusiveStartKey=page1['LastEvaluatedKey'] + ExclusiveStartKey=page1["LastEvaluatedKey"], ) - page2['Count'].should.equal(4) - page2['Items'].should.have.length_of(4) - page2.should_not.have.key('LastEvaluatedKey') + page2["Count"].should.equal(4) + page2["Items"].should.have.length_of(4) + page2.should_not.have.key("LastEvaluatedKey") - results = page1['Items'] + page2['Items'] - subjects = set([int(r['subject']) for r in results]) + results = page1["Items"] + page2["Items"] + subjects = set([int(r["subject"]) for r in results]) subjects.should.equal(set(range(10))) @mock_dynamodb2 def test_scan_by_index(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='test', + TableName="test", KeySchema=[ - {'AttributeName': 'id', 'KeyType': 'HASH'}, - {'AttributeName': 'range_key', 'KeyType': 'RANGE'}, + {"AttributeName": "id", "KeyType": "HASH"}, + {"AttributeName": "range_key", "KeyType": "RANGE"}, ], AttributeDefinitions=[ - {'AttributeName': 'id', 'AttributeType': 'S'}, - {'AttributeName': 'range_key', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_col', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_range_key', 'AttributeType': 'S'}, - {'AttributeName': 'lsi_range_key', 'AttributeType': 'S'}, + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "range_key", "AttributeType": "S"}, + {"AttributeName": "gsi_col", "AttributeType": "S"}, + {"AttributeName": "gsi_range_key", "AttributeType": "S"}, + {"AttributeName": "lsi_range_key", "AttributeType": "S"}, ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, GlobalSecondaryIndexes=[ { - 'IndexName': 'test_gsi', - 'KeySchema': [ - {'AttributeName': 'gsi_col', 'KeyType': 'HASH'}, - {'AttributeName': 'gsi_range_key', 'KeyType': 'RANGE'}, + "IndexName": "test_gsi", + "KeySchema": [ + {"AttributeName": "gsi_col", "KeyType": "HASH"}, + {"AttributeName": "gsi_range_key", "KeyType": "RANGE"}, ], - 'Projection': { - 'ProjectionType': 'ALL', + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1 - } - }, + } ], LocalSecondaryIndexes=[ { - 'IndexName': 'test_lsi', - 'KeySchema': [ - {'AttributeName': 'id', 'KeyType': 'HASH'}, - {'AttributeName': 'lsi_range_key', 'KeyType': 'RANGE'}, + "IndexName": "test_lsi", + "KeySchema": [ + {"AttributeName": "id", "KeyType": "HASH"}, + {"AttributeName": "lsi_range_key", "KeyType": "RANGE"}, ], - 'Projection': { - 'ProjectionType': 'ALL', - }, - }, - ] + "Projection": {"ProjectionType": "ALL"}, + } + ], ) dynamodb.put_item( - TableName='test', + TableName="test", Item={ - 'id': {'S': '1'}, - 'range_key': {'S': '1'}, - 'col1': {'S': 'val1'}, - 'gsi_col': {'S': '1'}, - 'gsi_range_key': {'S': '1'}, - 'lsi_range_key': {'S': '1'}, - } + "id": {"S": "1"}, + "range_key": {"S": "1"}, + "col1": {"S": "val1"}, + "gsi_col": {"S": "1"}, + "gsi_range_key": {"S": "1"}, + "lsi_range_key": {"S": "1"}, + }, ) dynamodb.put_item( - TableName='test', + TableName="test", Item={ - 'id': {'S': '1'}, - 'range_key': {'S': '2'}, - 'col1': {'S': 'val2'}, - 'gsi_col': {'S': '1'}, - 'gsi_range_key': {'S': '2'}, - 'lsi_range_key': {'S': '2'}, - } + "id": {"S": "1"}, + "range_key": {"S": "2"}, + "col1": {"S": "val2"}, + "gsi_col": {"S": "1"}, + "gsi_range_key": {"S": "2"}, + "lsi_range_key": {"S": "2"}, + }, ) dynamodb.put_item( - TableName='test', - Item={ - 'id': {'S': '3'}, - 'range_key': {'S': '1'}, - 'col1': {'S': 'val3'}, - } + TableName="test", + Item={"id": {"S": "3"}, "range_key": {"S": "1"}, "col1": {"S": "val3"}}, ) - res = dynamodb.scan(TableName='test') - assert res['Count'] == 3 - assert len(res['Items']) == 3 + res = dynamodb.scan(TableName="test") + assert res["Count"] == 3 + assert len(res["Items"]) == 3 - res = dynamodb.scan(TableName='test', IndexName='test_gsi') - assert res['Count'] == 2 - assert len(res['Items']) == 2 + res = dynamodb.scan(TableName="test", IndexName="test_gsi") + assert res["Count"] == 2 + assert len(res["Items"]) == 2 - res = dynamodb.scan(TableName='test', IndexName='test_gsi', Limit=1) - assert res['Count'] == 1 - assert len(res['Items']) == 1 - last_eval_key = res['LastEvaluatedKey'] - assert last_eval_key['id']['S'] == '1' - assert last_eval_key['gsi_col']['S'] == '1' - assert last_eval_key['gsi_range_key']['S'] == '1' + res = dynamodb.scan(TableName="test", IndexName="test_gsi", Limit=1) + assert res["Count"] == 1 + assert len(res["Items"]) == 1 + last_eval_key = res["LastEvaluatedKey"] + assert last_eval_key["id"]["S"] == "1" + assert last_eval_key["gsi_col"]["S"] == "1" + assert last_eval_key["gsi_range_key"]["S"] == "1" - res = dynamodb.scan(TableName='test', IndexName='test_lsi') - assert res['Count'] == 2 - assert len(res['Items']) == 2 + res = dynamodb.scan(TableName="test", IndexName="test_lsi") + assert res["Count"] == 2 + assert len(res["Items"]) == 2 - res = dynamodb.scan(TableName='test', IndexName='test_lsi', Limit=1) - assert res['Count'] == 1 - assert len(res['Items']) == 1 - last_eval_key = res['LastEvaluatedKey'] - assert last_eval_key['id']['S'] == '1' - assert last_eval_key['range_key']['S'] == '1' - assert last_eval_key['lsi_range_key']['S'] == '1' + res = dynamodb.scan(TableName="test", IndexName="test_lsi", Limit=1) + assert res["Count"] == 1 + assert len(res["Items"]) == 1 + last_eval_key = res["LastEvaluatedKey"] + assert last_eval_key["id"]["S"] == "1" + assert last_eval_key["range_key"]["S"] == "1" + assert last_eval_key["lsi_range_key"]["S"] == "1" diff --git a/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py b/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py index b2209d990..08d7724f8 100644 --- a/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py +++ b/tests/test_dynamodb2/test_dynamodb_table_without_range_key.py @@ -9,6 +9,7 @@ from boto.exception import JSONResponseError from moto import mock_dynamodb2, mock_dynamodb2_deprecated from tests.helpers import requires_boto_gte import botocore + try: from boto.dynamodb2.fields import HashKey from boto.dynamodb2.table import Table @@ -19,12 +20,9 @@ except ImportError: def create_table(): - table = Table.create('messages', schema=[ - HashKey('forum_name') - ], throughput={ - 'read': 10, - 'write': 10, - }) + table = Table.create( + "messages", schema=[HashKey("forum_name")], throughput={"read": 10, "write": 10} + ) return table @@ -34,32 +32,31 @@ def create_table(): def test_create_table(): create_table() expected = { - 'Table': { - 'AttributeDefinitions': [ - {'AttributeName': 'forum_name', 'AttributeType': 'S'} + "Table": { + "AttributeDefinitions": [ + {"AttributeName": "forum_name", "AttributeType": "S"} ], - 'ProvisionedThroughput': { - 'NumberOfDecreasesToday': 0, 'WriteCapacityUnits': 10, 'ReadCapacityUnits': 10 + "ProvisionedThroughput": { + "NumberOfDecreasesToday": 0, + "WriteCapacityUnits": 10, + "ReadCapacityUnits": 10, }, - 'TableSizeBytes': 0, - 'TableName': 'messages', - 'TableStatus': 'ACTIVE', - 'TableArn': 'arn:aws:dynamodb:us-east-1:123456789011:table/messages', - 'KeySchema': [ - {'KeyType': 'HASH', 'AttributeName': 'forum_name'} - ], - 'ItemCount': 0, 'CreationDateTime': 1326499200.0, - 'GlobalSecondaryIndexes': [], - 'LocalSecondaryIndexes': [] + "TableSizeBytes": 0, + "TableName": "messages", + "TableStatus": "ACTIVE", + "TableArn": "arn:aws:dynamodb:us-east-1:123456789011:table/messages", + "KeySchema": [{"KeyType": "HASH", "AttributeName": "forum_name"}], + "ItemCount": 0, + "CreationDateTime": 1326499200.0, + "GlobalSecondaryIndexes": [], + "LocalSecondaryIndexes": [], } } conn = boto.dynamodb2.connect_to_region( - 'us-east-1', - aws_access_key_id="ak", - aws_secret_access_key="sk" + "us-east-1", aws_access_key_id="ak", aws_secret_access_key="sk" ) - conn.describe_table('messages').should.equal(expected) + conn.describe_table("messages").should.equal(expected) @requires_boto_gte("2.9") @@ -69,11 +66,10 @@ def test_delete_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.list_tables()["TableNames"].should.have.length_of(1) - conn.delete_table('messages') + conn.delete_table("messages") conn.list_tables()["TableNames"].should.have.length_of(0) - conn.delete_table.when.called_with( - 'messages').should.throw(JSONResponseError) + conn.delete_table.when.called_with("messages").should.throw(JSONResponseError) @requires_boto_gte("2.9") @@ -83,10 +79,7 @@ def test_update_table_throughput(): table.throughput["read"].should.equal(10) table.throughput["write"].should.equal(10) - table.update(throughput={ - 'read': 5, - 'write': 6, - }) + table.update(throughput={"read": 5, "write": 6}) table.throughput["read"].should.equal(5) table.throughput["write"].should.equal(6) @@ -98,32 +91,34 @@ def test_item_add_and_describe_and_update(): table = create_table() data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", } table.put_item(data=data) returned_item = table.get_item(forum_name="LOLCat Forum") returned_item.should_not.be.none - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - }) + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + } + ) - returned_item['SentBy'] = 'User B' + returned_item["SentBy"] = "User B" returned_item.save(overwrite=True) - returned_item = table.get_item( - forum_name='LOLCat Forum' + returned_item = table.get_item(forum_name="LOLCat Forum") + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - }) @requires_boto_gte("2.9") @@ -132,25 +127,25 @@ def test_item_partial_save(): table = create_table() data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", } table.put_item(data=data) returned_item = table.get_item(forum_name="LOLCat Forum") - returned_item['SentBy'] = 'User B' + returned_item["SentBy"] = "User B" returned_item.partial_save() - returned_item = table.get_item( - forum_name='LOLCat Forum' + returned_item = table.get_item(forum_name="LOLCat Forum") + dict(returned_item).should.equal( + { + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + } ) - dict(returned_item).should.equal({ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - }) @requires_boto_gte("2.9") @@ -159,12 +154,12 @@ def test_item_put_without_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.put_item.when.called_with( - table_name='undeclared-table', + table_name="undeclared-table", item={ - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - } + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + }, ).should.throw(JSONResponseError) @@ -174,8 +169,7 @@ def test_get_item_with_undeclared_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.get_item.when.called_with( - table_name='undeclared-table', - key={"forum_name": {"S": "LOLCat Forum"}}, + table_name="undeclared-table", key={"forum_name": {"S": "LOLCat Forum"}} ).should.throw(JSONResponseError) @@ -185,10 +179,10 @@ def test_delete_item(): table = create_table() item_data = { - 'forum_name': 'LOLCat Forum', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "forum_name": "LOLCat Forum", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = Item(table, item_data) item.save() @@ -210,8 +204,7 @@ def test_delete_item_with_undeclared_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.delete_item.when.called_with( - table_name='undeclared-table', - key={"forum_name": {"S": "LOLCat Forum"}}, + table_name="undeclared-table", key={"forum_name": {"S": "LOLCat Forum"}} ).should.throw(JSONResponseError) @@ -221,17 +214,17 @@ def test_query(): table = create_table() item_data = { - 'forum_name': 'the-key', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "forum_name": "the-key", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } item = Item(table, item_data) item.save(overwrite=True) table.count().should.equal(1) table = Table("messages") - results = table.query(forum_name__eq='the-key') + results = table.query(forum_name__eq="the-key") sum(1 for _ in results).should.equal(1) @@ -241,9 +234,13 @@ def test_query_with_undeclared_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.query.when.called_with( - table_name='undeclared-table', - key_conditions={"forum_name": { - "ComparisonOperator": "EQ", "AttributeValueList": [{"S": "the-key"}]}} + table_name="undeclared-table", + key_conditions={ + "forum_name": { + "ComparisonOperator": "EQ", + "AttributeValueList": [{"S": "the-key"}], + } + }, ).should.throw(JSONResponseError) @@ -253,36 +250,36 @@ def test_scan(): table = create_table() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item_data['forum_name'] = 'the-key' + item_data["forum_name"] = "the-key" item = Item(table, item_data) item.save() - item['forum_name'] = 'the-key2' + item["forum_name"] = "the-key2" item.save(overwrite=True) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } - item_data['forum_name'] = 'the-key3' + item_data["forum_name"] = "the-key3" item = Item(table, item_data) item.save() results = table.scan() sum(1 for _ in results).should.equal(3) - results = table.scan(SentBy__eq='User B') + results = table.scan(SentBy__eq="User B") sum(1 for _ in results).should.equal(1) - results = table.scan(Body__beginswith='http') + results = table.scan(Body__beginswith="http") sum(1 for _ in results).should.equal(3) results = table.scan(Ids__null=False) @@ -304,13 +301,11 @@ def test_scan_with_undeclared_table(): conn = boto.dynamodb2.layer1.DynamoDBConnection() conn.scan.when.called_with( - table_name='undeclared-table', + table_name="undeclared-table", scan_filter={ "SentBy": { - "AttributeValueList": [{ - "S": "User B"} - ], - "ComparisonOperator": "EQ" + "AttributeValueList": [{"S": "User B"}], + "ComparisonOperator": "EQ", } }, ).should.throw(JSONResponseError) @@ -322,27 +317,28 @@ def test_write_batch(): table = create_table() with table.batch_write() as batch: - batch.put_item(data={ - 'forum_name': 'the-key', - 'subject': '123', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) - batch.put_item(data={ - 'forum_name': 'the-key2', - 'subject': '789', - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - }) + batch.put_item( + data={ + "forum_name": "the-key", + "subject": "123", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } + ) + batch.put_item( + data={ + "forum_name": "the-key2", + "subject": "789", + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + } + ) table.count().should.equal(2) with table.batch_write() as batch: - batch.delete_item( - forum_name='the-key', - subject='789' - ) + batch.delete_item(forum_name="the-key", subject="789") table.count().should.equal(1) @@ -353,34 +349,31 @@ def test_batch_read(): table = create_table() item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User A', - 'ReceivedTime': '12/9/2011 11:36:03 PM', + "Body": "http://url_to_lolcat.gif", + "SentBy": "User A", + "ReceivedTime": "12/9/2011 11:36:03 PM", } - item_data['forum_name'] = 'the-key1' + item_data["forum_name"] = "the-key1" item = Item(table, item_data) item.save() item = Item(table, item_data) - item_data['forum_name'] = 'the-key2' + item_data["forum_name"] = "the-key2" item.save(overwrite=True) item_data = { - 'Body': 'http://url_to_lolcat.gif', - 'SentBy': 'User B', - 'ReceivedTime': '12/9/2011 11:36:03 PM', - 'Ids': set([1, 2, 3]), - 'PK': 7, + "Body": "http://url_to_lolcat.gif", + "SentBy": "User B", + "ReceivedTime": "12/9/2011 11:36:03 PM", + "Ids": set([1, 2, 3]), + "PK": 7, } item = Item(table, item_data) - item_data['forum_name'] = 'another-key' + item_data["forum_name"] = "another-key" item.save(overwrite=True) results = table.batch_get( - keys=[ - {'forum_name': 'the-key1'}, - {'forum_name': 'another-key'}, - ] + keys=[{"forum_name": "the-key1"}, {"forum_name": "another-key"}] ) # Iterate through so that batch_item gets called @@ -393,196 +386,136 @@ def test_batch_read(): def test_get_key_fields(): table = create_table() kf = table.get_key_fields() - kf[0].should.equal('forum_name') + kf[0].should.equal("forum_name") @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_get_missing_item(): table = create_table() - table.get_item.when.called_with( - forum_name='missing').should.throw(ItemNotFound) + table.get_item.when.called_with(forum_name="missing").should.throw(ItemNotFound) @requires_boto_gte("2.9") @mock_dynamodb2_deprecated def test_get_special_item(): - table = Table.create('messages', schema=[ - HashKey('date-joined') - ], throughput={ - 'read': 10, - 'write': 10, - }) + table = Table.create( + "messages", + schema=[HashKey("date-joined")], + throughput={"read": 10, "write": 10}, + ) - data = { - 'date-joined': 127549192, - 'SentBy': 'User A', - } + data = {"date-joined": 127549192, "SentBy": "User A"} table.put_item(data=data) - returned_item = table.get_item(**{'date-joined': 127549192}) + returned_item = table.get_item(**{"date-joined": 127549192}) dict(returned_item).should.equal(data) @mock_dynamodb2_deprecated def test_update_item_remove(): conn = boto.dynamodb2.connect_to_region("us-east-1") - table = Table.create('messages', schema=[ - HashKey('username') - ]) + table = Table.create("messages", schema=[HashKey("username")]) - data = { - 'username': "steve", - 'SentBy': 'User A', - 'SentTo': 'User B', - } + data = {"username": "steve", "SentBy": "User A", "SentTo": "User B"} table.put_item(data=data) - key_map = { - 'username': {"S": "steve"} - } + key_map = {"username": {"S": "steve"}} # Then remove the SentBy field - conn.update_item("messages", key_map, - update_expression="REMOVE SentBy, SentTo") + conn.update_item("messages", key_map, update_expression="REMOVE SentBy, SentTo") returned_item = table.get_item(username="steve") - dict(returned_item).should.equal({ - 'username': "steve", - }) + dict(returned_item).should.equal({"username": "steve"}) @mock_dynamodb2_deprecated def test_update_item_nested_remove(): conn = boto.dynamodb2.connect_to_region("us-east-1") - table = Table.create('messages', schema=[ - HashKey('username') - ]) + table = Table.create("messages", schema=[HashKey("username")]) - data = { - 'username': "steve", - 'Meta': { - 'FullName': 'Steve Urkel' - } - } + data = {"username": "steve", "Meta": {"FullName": "Steve Urkel"}} table.put_item(data=data) - key_map = { - 'username': {"S": "steve"} - } + key_map = {"username": {"S": "steve"}} # Then remove the Meta.FullName field - conn.update_item("messages", key_map, - update_expression="REMOVE Meta.FullName") + conn.update_item("messages", key_map, update_expression="REMOVE Meta.FullName") returned_item = table.get_item(username="steve") - dict(returned_item).should.equal({ - 'username': "steve", - 'Meta': {} - }) + dict(returned_item).should.equal({"username": "steve", "Meta": {}}) @mock_dynamodb2_deprecated def test_update_item_double_nested_remove(): conn = boto.dynamodb2.connect_to_region("us-east-1") - table = Table.create('messages', schema=[ - HashKey('username') - ]) + table = Table.create("messages", schema=[HashKey("username")]) - data = { - 'username': "steve", - 'Meta': { - 'Name': { - 'First': 'Steve', - 'Last': 'Urkel' - } - } - } + data = {"username": "steve", "Meta": {"Name": {"First": "Steve", "Last": "Urkel"}}} table.put_item(data=data) - key_map = { - 'username': {"S": "steve"} - } + key_map = {"username": {"S": "steve"}} # Then remove the Meta.FullName field - conn.update_item("messages", key_map, - update_expression="REMOVE Meta.Name.First") + conn.update_item("messages", key_map, update_expression="REMOVE Meta.Name.First") returned_item = table.get_item(username="steve") - dict(returned_item).should.equal({ - 'username': "steve", - 'Meta': { - 'Name': { - 'Last': 'Urkel' - } - } - }) + dict(returned_item).should.equal( + {"username": "steve", "Meta": {"Name": {"Last": "Urkel"}}} + ) + @mock_dynamodb2_deprecated def test_update_item_set(): conn = boto.dynamodb2.connect_to_region("us-east-1") - table = Table.create('messages', schema=[ - HashKey('username') - ]) + table = Table.create("messages", schema=[HashKey("username")]) - data = { - 'username': "steve", - 'SentBy': 'User A', - } + data = {"username": "steve", "SentBy": "User A"} table.put_item(data=data) - key_map = { - 'username': {"S": "steve"} - } + key_map = {"username": {"S": "steve"}} - conn.update_item("messages", key_map, - update_expression="SET foo=bar, blah=baz REMOVE SentBy") + conn.update_item( + "messages", key_map, update_expression="SET foo=bar, blah=baz REMOVE SentBy" + ) returned_item = table.get_item(username="steve") - dict(returned_item).should.equal({ - 'username': "steve", - 'foo': 'bar', - 'blah': 'baz', - }) + dict(returned_item).should.equal({"username": "steve", "foo": "bar", "blah": "baz"}) @mock_dynamodb2_deprecated def test_failed_overwrite(): - table = Table.create('messages', schema=[ - HashKey('id'), - ], throughput={ - 'read': 7, - 'write': 3, - }) + table = Table.create( + "messages", schema=[HashKey("id")], throughput={"read": 7, "write": 3} + ) - data1 = {'id': '123', 'data': '678'} + data1 = {"id": "123", "data": "678"} table.put_item(data=data1) - data2 = {'id': '123', 'data': '345'} + data2 = {"id": "123", "data": "345"} table.put_item(data=data2, overwrite=True) - data3 = {'id': '123', 'data': '812'} + data3 = {"id": "123", "data": "812"} table.put_item.when.called_with(data=data3).should.throw( - ConditionalCheckFailedException) + ConditionalCheckFailedException + ) - returned_item = table.lookup('123') + returned_item = table.lookup("123") dict(returned_item).should.equal(data2) - data4 = {'id': '124', 'data': 812} + data4 = {"id": "124", "data": 812} table.put_item(data=data4) - returned_item = table.lookup('124') + returned_item = table.lookup("124") dict(returned_item).should.equal(data4) @mock_dynamodb2_deprecated def test_conflicting_writes(): - table = Table.create('messages', schema=[ - HashKey('id'), - ]) + table = Table.create("messages", schema=[HashKey("id")]) - item_data = {'id': '123', 'data': '678'} + item_data = {"id": "123", "data": "678"} item1 = Item(table, item_data) item2 = Item(table, item_data) item1.save() - item1['data'] = '579' - item2['data'] = '912' + item1["data"] = "579" + item2["data"] = "912" item1.save() item2.save.when.called_with().should.throw(ConditionalCheckFailedException) @@ -595,230 +528,178 @@ boto3 @mock_dynamodb2 def test_boto3_create_table(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") table = dynamodb.create_table( - TableName='users', - KeySchema=[ - { - 'AttributeName': 'username', - 'KeyType': 'HASH' - }, - ], - AttributeDefinitions=[ - { - 'AttributeName': 'username', - 'AttributeType': 'S' - }, - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + TableName="users", + KeySchema=[{"AttributeName": "username", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "username", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - table.name.should.equal('users') + table.name.should.equal("users") def _create_user_table(): - dynamodb = boto3.resource('dynamodb', region_name='us-east-1') + dynamodb = boto3.resource("dynamodb", region_name="us-east-1") table = dynamodb.create_table( - TableName='users', - KeySchema=[ - { - 'AttributeName': 'username', - 'KeyType': 'HASH' - }, - ], - AttributeDefinitions=[ - { - 'AttributeName': 'username', - 'AttributeType': 'S' - }, - ], - ProvisionedThroughput={ - 'ReadCapacityUnits': 5, - 'WriteCapacityUnits': 5 - } + TableName="users", + KeySchema=[{"AttributeName": "username", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "username", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) - return dynamodb.Table('users') + return dynamodb.Table("users") @mock_dynamodb2 def test_boto3_conditions(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe'}) - table.put_item(Item={'username': 'janedoe'}) + table.put_item(Item={"username": "johndoe"}) + table.put_item(Item={"username": "janedoe"}) - response = table.query( - KeyConditionExpression=Key('username').eq('johndoe') - ) - response['Count'].should.equal(1) - response['Items'].should.have.length_of(1) - response['Items'][0].should.equal({"username": "johndoe"}) + response = table.query(KeyConditionExpression=Key("username").eq("johndoe")) + response["Count"].should.equal(1) + response["Items"].should.have.length_of(1) + response["Items"][0].should.equal({"username": "johndoe"}) @mock_dynamodb2 def test_boto3_put_item_conditions_pass(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.put_item( - Item={'username': 'johndoe', 'foo': 'baz'}, - Expected={ - 'foo': { - 'ComparisonOperator': 'EQ', - 'AttributeValueList': ['bar'] - } - }) - final_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(final_item)['Item']['foo'].should.equal("baz") + Item={"username": "johndoe", "foo": "baz"}, + Expected={"foo": {"ComparisonOperator": "EQ", "AttributeValueList": ["bar"]}}, + ) + final_item = table.get_item(Key={"username": "johndoe"}) + assert dict(final_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_put_item_conditions_pass_because_expect_not_exists_by_compare_to_null(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.put_item( - Item={'username': 'johndoe', 'foo': 'baz'}, - Expected={ - 'whatever': { - 'ComparisonOperator': 'NULL', - } - }) - final_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(final_item)['Item']['foo'].should.equal("baz") + Item={"username": "johndoe", "foo": "baz"}, + Expected={"whatever": {"ComparisonOperator": "NULL"}}, + ) + final_item = table.get_item(Key={"username": "johndoe"}) + assert dict(final_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_put_item_conditions_pass_because_expect_exists_by_compare_to_not_null(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.put_item( - Item={'username': 'johndoe', 'foo': 'baz'}, - Expected={ - 'foo': { - 'ComparisonOperator': 'NOT_NULL', - } - }) - final_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(final_item)['Item']['foo'].should.equal("baz") + Item={"username": "johndoe", "foo": "baz"}, + Expected={"foo": {"ComparisonOperator": "NOT_NULL"}}, + ) + final_item = table.get_item(Key={"username": "johndoe"}) + assert dict(final_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_put_item_conditions_fail(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.put_item.when.called_with( - Item={'username': 'johndoe', 'foo': 'baz'}, - Expected={ - 'foo': { - 'ComparisonOperator': 'NE', - 'AttributeValueList': ['bar'] - } - }).should.throw(botocore.client.ClientError) + Item={"username": "johndoe", "foo": "baz"}, + Expected={"foo": {"ComparisonOperator": "NE", "AttributeValueList": ["bar"]}}, + ).should.throw(botocore.client.ClientError) + @mock_dynamodb2 def test_boto3_update_item_conditions_fail(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'baz'}) + table.put_item(Item={"username": "johndoe", "foo": "baz"}) table.update_item.when.called_with( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=bar', - Expected={ - 'foo': { - 'Value': 'bar', - } - }).should.throw(botocore.client.ClientError) + Key={"username": "johndoe"}, + UpdateExpression="SET foo=bar", + Expected={"foo": {"Value": "bar"}}, + ).should.throw(botocore.client.ClientError) + @mock_dynamodb2 def test_boto3_update_item_conditions_fail_because_expect_not_exists(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'baz'}) + table.put_item(Item={"username": "johndoe", "foo": "baz"}) table.update_item.when.called_with( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=bar', - Expected={ - 'foo': { - 'Exists': False - } - }).should.throw(botocore.client.ClientError) + Key={"username": "johndoe"}, + UpdateExpression="SET foo=bar", + Expected={"foo": {"Exists": False}}, + ).should.throw(botocore.client.ClientError) + @mock_dynamodb2 def test_boto3_update_item_conditions_fail_because_expect_not_exists_by_compare_to_null(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'baz'}) + table.put_item(Item={"username": "johndoe", "foo": "baz"}) table.update_item.when.called_with( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=bar', - Expected={ - 'foo': { - 'ComparisonOperator': 'NULL', - } - }).should.throw(botocore.client.ClientError) + Key={"username": "johndoe"}, + UpdateExpression="SET foo=bar", + Expected={"foo": {"ComparisonOperator": "NULL"}}, + ).should.throw(botocore.client.ClientError) + @mock_dynamodb2 def test_boto3_update_item_conditions_pass(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=baz', - Expected={ - 'foo': { - 'Value': 'bar', - } - }) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal("baz") + Key={"username": "johndoe"}, + UpdateExpression="SET foo=baz", + Expected={"foo": {"Value": "bar"}}, + ) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_update_item_conditions_pass_because_expect_not_exists(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=baz', - Expected={ - 'whatever': { - 'Exists': False, - } - }) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal("baz") + Key={"username": "johndoe"}, + UpdateExpression="SET foo=baz", + Expected={"whatever": {"Exists": False}}, + ) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_update_item_conditions_pass_because_expect_not_exists_by_compare_to_null(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=baz', - Expected={ - 'whatever': { - 'ComparisonOperator': 'NULL', - } - }) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal("baz") + Key={"username": "johndoe"}, + UpdateExpression="SET foo=baz", + Expected={"whatever": {"ComparisonOperator": "NULL"}}, + ) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal("baz") + @mock_dynamodb2 def test_boto3_update_item_conditions_pass_because_expect_exists_by_compare_to_not_null(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=baz', - Expected={ - 'foo': { - 'ComparisonOperator': 'NOT_NULL', - } - }) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal("baz") + Key={"username": "johndoe"}, + UpdateExpression="SET foo=baz", + Expected={"foo": {"ComparisonOperator": "NOT_NULL"}}, + ) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal("baz") @mock_dynamodb2 def test_boto3_update_settype_item_with_conditions(): class OrderedSet(set): """A set with predictable iteration order""" + def __init__(self, values): super(OrderedSet, self).__init__(values) self.__ordered_values = values @@ -827,143 +708,113 @@ def test_boto3_update_settype_item_with_conditions(): return iter(self.__ordered_values) table = _create_user_table() - table.put_item(Item={'username': 'johndoe'}) + table.put_item(Item={"username": "johndoe"}) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=:new_value', - ExpressionAttributeValues={ - ':new_value': OrderedSet(['hello', 'world']), - }, + Key={"username": "johndoe"}, + UpdateExpression="SET foo=:new_value", + ExpressionAttributeValues={":new_value": OrderedSet(["hello", "world"])}, ) table.update_item( - Key={'username': 'johndoe'}, - UpdateExpression='SET foo=:new_value', - ExpressionAttributeValues={ - ':new_value': set(['baz']), - }, + Key={"username": "johndoe"}, + UpdateExpression="SET foo=:new_value", + ExpressionAttributeValues={":new_value": set(["baz"])}, Expected={ - 'foo': { - 'ComparisonOperator': 'EQ', - 'AttributeValueList': [ - OrderedSet(['world', 'hello']), # Opposite order to original + "foo": { + "ComparisonOperator": "EQ", + "AttributeValueList": [ + OrderedSet(["world", "hello"]) # Opposite order to original ], } }, ) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal(set(['baz'])) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal(set(["baz"])) @mock_dynamodb2 def test_boto3_put_item_conditions_pass(): table = _create_user_table() - table.put_item(Item={'username': 'johndoe', 'foo': 'bar'}) + table.put_item(Item={"username": "johndoe", "foo": "bar"}) table.put_item( - Item={'username': 'johndoe', 'foo': 'baz'}, - Expected={ - 'foo': { - 'ComparisonOperator': 'EQ', - 'AttributeValueList': ['bar'] - } - }) - returned_item = table.get_item(Key={'username': 'johndoe'}) - assert dict(returned_item)['Item']['foo'].should.equal("baz") + Item={"username": "johndoe", "foo": "baz"}, + Expected={"foo": {"ComparisonOperator": "EQ", "AttributeValueList": ["bar"]}}, + ) + returned_item = table.get_item(Key={"username": "johndoe"}) + assert dict(returned_item)["Item"]["foo"].should.equal("baz") @mock_dynamodb2 def test_scan_pagination(): table = _create_user_table() - expected_usernames = ['user{0}'.format(i) for i in range(10)] + expected_usernames = ["user{0}".format(i) for i in range(10)] for u in expected_usernames: - table.put_item(Item={'username': u}) + table.put_item(Item={"username": u}) page1 = table.scan(Limit=6) - page1['Count'].should.equal(6) - page1['Items'].should.have.length_of(6) - page1.should.have.key('LastEvaluatedKey') + page1["Count"].should.equal(6) + page1["Items"].should.have.length_of(6) + page1.should.have.key("LastEvaluatedKey") - page2 = table.scan(Limit=6, - ExclusiveStartKey=page1['LastEvaluatedKey']) - page2['Count'].should.equal(4) - page2['Items'].should.have.length_of(4) - page2.should_not.have.key('LastEvaluatedKey') + page2 = table.scan(Limit=6, ExclusiveStartKey=page1["LastEvaluatedKey"]) + page2["Count"].should.equal(4) + page2["Items"].should.have.length_of(4) + page2.should_not.have.key("LastEvaluatedKey") - results = page1['Items'] + page2['Items'] - usernames = set([r['username'] for r in results]) + results = page1["Items"] + page2["Items"] + usernames = set([r["username"] for r in results]) usernames.should.equal(set(expected_usernames)) @mock_dynamodb2 def test_scan_by_index(): - dynamodb = boto3.client('dynamodb', region_name='us-east-1') + dynamodb = boto3.client("dynamodb", region_name="us-east-1") dynamodb.create_table( - TableName='test', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], + TableName="test", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], AttributeDefinitions=[ - {'AttributeName': 'id', 'AttributeType': 'S'}, - {'AttributeName': 'gsi_col', 'AttributeType': 'S'} + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "gsi_col", "AttributeType": "S"}, ], - ProvisionedThroughput={'ReadCapacityUnits': 1, 'WriteCapacityUnits': 1}, + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, GlobalSecondaryIndexes=[ { - 'IndexName': 'test_gsi', - 'KeySchema': [ - { - 'AttributeName': 'gsi_col', - 'KeyType': 'HASH' - }, - ], - 'Projection': { - 'ProjectionType': 'ALL', + "IndexName": "test_gsi", + "KeySchema": [{"AttributeName": "gsi_col", "KeyType": "HASH"}], + "Projection": {"ProjectionType": "ALL"}, + "ProvisionedThroughput": { + "ReadCapacityUnits": 1, + "WriteCapacityUnits": 1, }, - 'ProvisionedThroughput': { - 'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1 - } - }, - ] + } + ], ) dynamodb.put_item( - TableName='test', - Item={ - 'id': {'S': '1'}, - 'col1': {'S': 'val1'}, - 'gsi_col': {'S': 'gsi_val1'}, - } + TableName="test", + Item={"id": {"S": "1"}, "col1": {"S": "val1"}, "gsi_col": {"S": "gsi_val1"}}, ) dynamodb.put_item( - TableName='test', - Item={ - 'id': {'S': '2'}, - 'col1': {'S': 'val2'}, - 'gsi_col': {'S': 'gsi_val2'}, - } + TableName="test", + Item={"id": {"S": "2"}, "col1": {"S": "val2"}, "gsi_col": {"S": "gsi_val2"}}, ) - dynamodb.put_item( - TableName='test', - Item={ - 'id': {'S': '3'}, - 'col1': {'S': 'val3'}, - } - ) + dynamodb.put_item(TableName="test", Item={"id": {"S": "3"}, "col1": {"S": "val3"}}) - res = dynamodb.scan(TableName='test') - assert res['Count'] == 3 - assert len(res['Items']) == 3 + res = dynamodb.scan(TableName="test") + assert res["Count"] == 3 + assert len(res["Items"]) == 3 - res = dynamodb.scan(TableName='test', IndexName='test_gsi') - assert res['Count'] == 2 - assert len(res['Items']) == 2 + res = dynamodb.scan(TableName="test", IndexName="test_gsi") + assert res["Count"] == 2 + assert len(res["Items"]) == 2 - res = dynamodb.scan(TableName='test', IndexName='test_gsi', Limit=1) - assert res['Count'] == 1 - assert len(res['Items']) == 1 - last_eval_key = res['LastEvaluatedKey'] - assert last_eval_key['id']['S'] == '1' - assert last_eval_key['gsi_col']['S'] == 'gsi_val1' + res = dynamodb.scan(TableName="test", IndexName="test_gsi", Limit=1) + assert res["Count"] == 1 + assert len(res["Items"]) == 1 + last_eval_key = res["LastEvaluatedKey"] + assert last_eval_key["id"]["S"] == "1" + assert last_eval_key["gsi_col"]["S"] == "gsi_val1" diff --git a/tests/test_dynamodb2/test_server.py b/tests/test_dynamodb2/test_server.py index af820beaf..880909fac 100644 --- a/tests/test_dynamodb2/test_server.py +++ b/tests/test_dynamodb2/test_server.py @@ -3,17 +3,17 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_table_list(): backend = server.create_backend_app("dynamodb2") test_client = backend.test_client() - res = test_client.get('/') + res = test_client.get("/") res.status_code.should.equal(404) - headers = {'X-Amz-Target': 'TestTable.ListTables'} - res = test_client.get('/', headers=headers) - res.data.should.contain(b'TableNames') + headers = {"X-Amz-Target": "TestTable.ListTables"} + res = test_client.get("/", headers=headers) + res.data.should.contain(b"TableNames") diff --git a/tests/test_dynamodbstreams/test_dynamodbstreams.py b/tests/test_dynamodbstreams/test_dynamodbstreams.py index deb9f9283..01cf915af 100644 --- a/tests/test_dynamodbstreams/test_dynamodbstreams.py +++ b/tests/test_dynamodbstreams/test_dynamodbstreams.py @@ -6,198 +6,190 @@ import boto3 from moto import mock_dynamodb2, mock_dynamodbstreams -class TestCore(): +class TestCore: stream_arn = None mocks = [] - + def setup(self): self.mocks = [mock_dynamodb2(), mock_dynamodbstreams()] for m in self.mocks: m.start() - + # create a table with a stream - conn = boto3.client('dynamodb', region_name='us-east-1') + conn = boto3.client("dynamodb", region_name="us-east-1") resp = conn.create_table( - TableName='test-streams', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', - 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1}, + TableName="test-streams", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, StreamSpecification={ - 'StreamEnabled': True, - 'StreamViewType': 'NEW_AND_OLD_IMAGES' - } + "StreamEnabled": True, + "StreamViewType": "NEW_AND_OLD_IMAGES", + }, ) - self.stream_arn = resp['TableDescription']['LatestStreamArn'] + self.stream_arn = resp["TableDescription"]["LatestStreamArn"] def teardown(self): - conn = boto3.client('dynamodb', region_name='us-east-1') - conn.delete_table(TableName='test-streams') + conn = boto3.client("dynamodb", region_name="us-east-1") + conn.delete_table(TableName="test-streams") self.stream_arn = None for m in self.mocks: m.stop() - def test_verify_stream(self): - conn = boto3.client('dynamodb', region_name='us-east-1') - resp = conn.describe_table(TableName='test-streams') - assert 'LatestStreamArn' in resp['Table'] + conn = boto3.client("dynamodb", region_name="us-east-1") + resp = conn.describe_table(TableName="test-streams") + assert "LatestStreamArn" in resp["Table"] def test_describe_stream(self): - conn = boto3.client('dynamodbstreams', region_name='us-east-1') + conn = boto3.client("dynamodbstreams", region_name="us-east-1") resp = conn.describe_stream(StreamArn=self.stream_arn) - assert 'StreamDescription' in resp - desc = resp['StreamDescription'] - assert desc['StreamArn'] == self.stream_arn - assert desc['TableName'] == 'test-streams' + assert "StreamDescription" in resp + desc = resp["StreamDescription"] + assert desc["StreamArn"] == self.stream_arn + assert desc["TableName"] == "test-streams" def test_list_streams(self): - conn = boto3.client('dynamodbstreams', region_name='us-east-1') + conn = boto3.client("dynamodbstreams", region_name="us-east-1") resp = conn.list_streams() - assert resp['Streams'][0]['StreamArn'] == self.stream_arn + assert resp["Streams"][0]["StreamArn"] == self.stream_arn - resp = conn.list_streams(TableName='no-stream') - assert not resp['Streams'] + resp = conn.list_streams(TableName="no-stream") + assert not resp["Streams"] def test_get_shard_iterator(self): - conn = boto3.client('dynamodbstreams', region_name='us-east-1') + conn = boto3.client("dynamodbstreams", region_name="us-east-1") resp = conn.describe_stream(StreamArn=self.stream_arn) - shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] - + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] + resp = conn.get_shard_iterator( StreamArn=self.stream_arn, ShardId=shard_id, - ShardIteratorType='TRIM_HORIZON' + ShardIteratorType="TRIM_HORIZON", ) - assert 'ShardIterator' in resp + assert "ShardIterator" in resp def test_get_shard_iterator_at_sequence_number(self): - conn = boto3.client('dynamodbstreams', region_name='us-east-1') + conn = boto3.client("dynamodbstreams", region_name="us-east-1") resp = conn.describe_stream(StreamArn=self.stream_arn) - shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] - + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] + resp = conn.get_shard_iterator( StreamArn=self.stream_arn, ShardId=shard_id, - ShardIteratorType='AT_SEQUENCE_NUMBER', - SequenceNumber=resp['StreamDescription']['Shards'][0]['SequenceNumberRange']['StartingSequenceNumber'] + ShardIteratorType="AT_SEQUENCE_NUMBER", + SequenceNumber=resp["StreamDescription"]["Shards"][0][ + "SequenceNumberRange" + ]["StartingSequenceNumber"], ) - assert 'ShardIterator' in resp + assert "ShardIterator" in resp def test_get_shard_iterator_after_sequence_number(self): - conn = boto3.client('dynamodbstreams', region_name='us-east-1') + conn = boto3.client("dynamodbstreams", region_name="us-east-1") resp = conn.describe_stream(StreamArn=self.stream_arn) - shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] - + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] + resp = conn.get_shard_iterator( StreamArn=self.stream_arn, ShardId=shard_id, - ShardIteratorType='AFTER_SEQUENCE_NUMBER', - SequenceNumber=resp['StreamDescription']['Shards'][0]['SequenceNumberRange']['StartingSequenceNumber'] + ShardIteratorType="AFTER_SEQUENCE_NUMBER", + SequenceNumber=resp["StreamDescription"]["Shards"][0][ + "SequenceNumberRange" + ]["StartingSequenceNumber"], ) - assert 'ShardIterator' in resp - + assert "ShardIterator" in resp + def test_get_records_empty(self): - conn = boto3.client('dynamodbstreams', region_name='us-east-1') + conn = boto3.client("dynamodbstreams", region_name="us-east-1") resp = conn.describe_stream(StreamArn=self.stream_arn) - shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] - + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] + resp = conn.get_shard_iterator( - StreamArn=self.stream_arn, - ShardId=shard_id, - ShardIteratorType='LATEST' + StreamArn=self.stream_arn, ShardId=shard_id, ShardIteratorType="LATEST" ) - iterator_id = resp['ShardIterator'] + iterator_id = resp["ShardIterator"] resp = conn.get_records(ShardIterator=iterator_id) - assert 'Records' in resp - assert len(resp['Records']) == 0 + assert "Records" in resp + assert len(resp["Records"]) == 0 def test_get_records_seq(self): - conn = boto3.client('dynamodb', region_name='us-east-1') + conn = boto3.client("dynamodb", region_name="us-east-1") conn.put_item( - TableName='test-streams', - Item={ - 'id': {'S': 'entry1'}, - 'first_col': {'S': 'foo'} - } + TableName="test-streams", + Item={"id": {"S": "entry1"}, "first_col": {"S": "foo"}}, ) conn.put_item( - TableName='test-streams', + TableName="test-streams", Item={ - 'id': {'S': 'entry1'}, - 'first_col': {'S': 'bar'}, - 'second_col': {'S': 'baz'} - } + "id": {"S": "entry1"}, + "first_col": {"S": "bar"}, + "second_col": {"S": "baz"}, + }, ) - conn.delete_item( - TableName='test-streams', - Key={'id': {'S': 'entry1'}} - ) - - conn = boto3.client('dynamodbstreams', region_name='us-east-1') - + conn.delete_item(TableName="test-streams", Key={"id": {"S": "entry1"}}) + + conn = boto3.client("dynamodbstreams", region_name="us-east-1") + resp = conn.describe_stream(StreamArn=self.stream_arn) - shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] - + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] + resp = conn.get_shard_iterator( StreamArn=self.stream_arn, ShardId=shard_id, - ShardIteratorType='TRIM_HORIZON' + ShardIteratorType="TRIM_HORIZON", ) - iterator_id = resp['ShardIterator'] + iterator_id = resp["ShardIterator"] resp = conn.get_records(ShardIterator=iterator_id) - assert len(resp['Records']) == 3 - assert resp['Records'][0]['eventName'] == 'INSERT' - assert resp['Records'][1]['eventName'] == 'MODIFY' - assert resp['Records'][2]['eventName'] == 'DELETE' + assert len(resp["Records"]) == 3 + assert resp["Records"][0]["eventName"] == "INSERT" + assert resp["Records"][1]["eventName"] == "MODIFY" + assert resp["Records"][2]["eventName"] == "DELETE" - sequence_number_modify = resp['Records'][1]['dynamodb']['SequenceNumber'] + sequence_number_modify = resp["Records"][1]["dynamodb"]["SequenceNumber"] # now try fetching from the next shard iterator, it should be # empty - resp = conn.get_records(ShardIterator=resp['NextShardIterator']) - assert len(resp['Records']) == 0 + resp = conn.get_records(ShardIterator=resp["NextShardIterator"]) + assert len(resp["Records"]) == 0 # check that if we get the shard iterator AT_SEQUENCE_NUMBER will get the MODIFY event resp = conn.get_shard_iterator( StreamArn=self.stream_arn, ShardId=shard_id, - ShardIteratorType='AT_SEQUENCE_NUMBER', - SequenceNumber=sequence_number_modify + ShardIteratorType="AT_SEQUENCE_NUMBER", + SequenceNumber=sequence_number_modify, ) - iterator_id = resp['ShardIterator'] + iterator_id = resp["ShardIterator"] resp = conn.get_records(ShardIterator=iterator_id) - assert len(resp['Records']) == 2 - assert resp['Records'][0]['eventName'] == 'MODIFY' - assert resp['Records'][1]['eventName'] == 'DELETE' + assert len(resp["Records"]) == 2 + assert resp["Records"][0]["eventName"] == "MODIFY" + assert resp["Records"][1]["eventName"] == "DELETE" # check that if we get the shard iterator AFTER_SEQUENCE_NUMBER will get the DELETE event resp = conn.get_shard_iterator( StreamArn=self.stream_arn, ShardId=shard_id, - ShardIteratorType='AFTER_SEQUENCE_NUMBER', - SequenceNumber=sequence_number_modify + ShardIteratorType="AFTER_SEQUENCE_NUMBER", + SequenceNumber=sequence_number_modify, ) - iterator_id = resp['ShardIterator'] + iterator_id = resp["ShardIterator"] resp = conn.get_records(ShardIterator=iterator_id) - assert len(resp['Records']) == 1 - assert resp['Records'][0]['eventName'] == 'DELETE' + assert len(resp["Records"]) == 1 + assert resp["Records"][0]["eventName"] == "DELETE" - -class TestEdges(): +class TestEdges: mocks = [] def setup(self): @@ -209,82 +201,70 @@ class TestEdges(): for m in self.mocks: m.stop() - def test_enable_stream_on_table(self): - conn = boto3.client('dynamodb', region_name='us-east-1') + conn = boto3.client("dynamodb", region_name="us-east-1") resp = conn.create_table( - TableName='test-streams', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}], - AttributeDefinitions=[{'AttributeName': 'id', - 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1} + TableName="test-streams", + KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}], + AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, ) - assert 'StreamSpecification' not in resp['TableDescription'] - + assert "StreamSpecification" not in resp["TableDescription"] + resp = conn.update_table( - TableName='test-streams', - StreamSpecification={ - 'StreamViewType': 'KEYS_ONLY' - } + TableName="test-streams", + StreamSpecification={"StreamViewType": "KEYS_ONLY"}, ) - assert 'StreamSpecification' in resp['TableDescription'] - assert resp['TableDescription']['StreamSpecification'] == { - 'StreamEnabled': True, - 'StreamViewType': 'KEYS_ONLY' + assert "StreamSpecification" in resp["TableDescription"] + assert resp["TableDescription"]["StreamSpecification"] == { + "StreamEnabled": True, + "StreamViewType": "KEYS_ONLY", } - assert 'LatestStreamLabel' in resp['TableDescription'] + assert "LatestStreamLabel" in resp["TableDescription"] # now try to enable it again with assert_raises(conn.exceptions.ResourceInUseException): resp = conn.update_table( - TableName='test-streams', - StreamSpecification={ - 'StreamViewType': 'OLD_IMAGES' - } + TableName="test-streams", + StreamSpecification={"StreamViewType": "OLD_IMAGES"}, ) - + def test_stream_with_range_key(self): - dyn = boto3.client('dynamodb', region_name='us-east-1') + dyn = boto3.client("dynamodb", region_name="us-east-1") resp = dyn.create_table( - TableName='test-streams', - KeySchema=[{'AttributeName': 'id', 'KeyType': 'HASH'}, - {'AttributeName': 'color', 'KeyType': 'RANGE'}], - AttributeDefinitions=[{'AttributeName': 'id', - 'AttributeType': 'S'}, - {'AttributeName': 'color', - 'AttributeType': 'S'}], - ProvisionedThroughput={'ReadCapacityUnits': 1, - 'WriteCapacityUnits': 1}, - StreamSpecification={ - 'StreamViewType': 'NEW_IMAGES' - } + TableName="test-streams", + KeySchema=[ + {"AttributeName": "id", "KeyType": "HASH"}, + {"AttributeName": "color", "KeyType": "RANGE"}, + ], + AttributeDefinitions=[ + {"AttributeName": "id", "AttributeType": "S"}, + {"AttributeName": "color", "AttributeType": "S"}, + ], + ProvisionedThroughput={"ReadCapacityUnits": 1, "WriteCapacityUnits": 1}, + StreamSpecification={"StreamViewType": "NEW_IMAGES"}, ) - stream_arn = resp['TableDescription']['LatestStreamArn'] + stream_arn = resp["TableDescription"]["LatestStreamArn"] - streams = boto3.client('dynamodbstreams', region_name='us-east-1') + streams = boto3.client("dynamodbstreams", region_name="us-east-1") resp = streams.describe_stream(StreamArn=stream_arn) - shard_id = resp['StreamDescription']['Shards'][0]['ShardId'] + shard_id = resp["StreamDescription"]["Shards"][0]["ShardId"] resp = streams.get_shard_iterator( - StreamArn=stream_arn, - ShardId=shard_id, - ShardIteratorType='LATEST' + StreamArn=stream_arn, ShardId=shard_id, ShardIteratorType="LATEST" ) - iterator_id = resp['ShardIterator'] + iterator_id = resp["ShardIterator"] dyn.put_item( - TableName='test-streams', - Item={'id': {'S': 'row1'}, 'color': {'S': 'blue'}} + TableName="test-streams", Item={"id": {"S": "row1"}, "color": {"S": "blue"}} ) dyn.put_item( - TableName='test-streams', - Item={'id': {'S': 'row2'}, 'color': {'S': 'green'}} + TableName="test-streams", + Item={"id": {"S": "row2"}, "color": {"S": "green"}}, ) resp = streams.get_records(ShardIterator=iterator_id) - assert len(resp['Records']) == 2 - assert resp['Records'][0]['eventName'] == 'INSERT' - assert resp['Records'][1]['eventName'] == 'INSERT' - + assert len(resp["Records"]) == 2 + assert resp["Records"][0]["eventName"] == "INSERT" + assert resp["Records"][1]["eventName"] == "INSERT" diff --git a/tests/test_ec2/helpers.py b/tests/test_ec2/helpers.py index 94c9c10cb..6dd281874 100644 --- a/tests/test_ec2/helpers.py +++ b/tests/test_ec2/helpers.py @@ -9,7 +9,8 @@ def rsa_check_private_key(private_key_material): assert isinstance(private_key_material, six.string_types) private_key = serialization.load_pem_private_key( - data=private_key_material.encode('ascii'), + data=private_key_material.encode("ascii"), backend=default_backend(), - password=None) + password=None, + ) assert isinstance(private_key, rsa.RSAPrivateKey) diff --git a/tests/test_ec2/test_account_attributes.py b/tests/test_ec2/test_account_attributes.py index 30309bec8..a3135f22e 100644 --- a/tests/test_ec2/test_account_attributes.py +++ b/tests/test_ec2/test_account_attributes.py @@ -6,39 +6,32 @@ import sure # noqa @mock_ec2 def test_describe_account_attributes(): - conn = boto3.client('ec2', region_name='us-east-1') + conn = boto3.client("ec2", region_name="us-east-1") response = conn.describe_account_attributes() - expected_attribute_values = [{ - 'AttributeValues': [{ - 'AttributeValue': '5' - }], - 'AttributeName': 'vpc-max-security-groups-per-interface' - }, { - 'AttributeValues': [{ - 'AttributeValue': '20' - }], - 'AttributeName': 'max-instances' - }, { - 'AttributeValues': [{ - 'AttributeValue': 'EC2' - }, { - 'AttributeValue': 'VPC' - }], - 'AttributeName': 'supported-platforms' - }, { - 'AttributeValues': [{ - 'AttributeValue': 'none' - }], - 'AttributeName': 'default-vpc' - }, { - 'AttributeValues': [{ - 'AttributeValue': '5' - }], - 'AttributeName': 'max-elastic-ips' - }, { - 'AttributeValues': [{ - 'AttributeValue': '5' - }], - 'AttributeName': 'vpc-max-elastic-ips' - }] - response['AccountAttributes'].should.equal(expected_attribute_values) + expected_attribute_values = [ + { + "AttributeValues": [{"AttributeValue": "5"}], + "AttributeName": "vpc-max-security-groups-per-interface", + }, + { + "AttributeValues": [{"AttributeValue": "20"}], + "AttributeName": "max-instances", + }, + { + "AttributeValues": [{"AttributeValue": "EC2"}, {"AttributeValue": "VPC"}], + "AttributeName": "supported-platforms", + }, + { + "AttributeValues": [{"AttributeValue": "none"}], + "AttributeName": "default-vpc", + }, + { + "AttributeValues": [{"AttributeValue": "5"}], + "AttributeName": "max-elastic-ips", + }, + { + "AttributeValues": [{"AttributeValue": "5"}], + "AttributeName": "vpc-max-elastic-ips", + }, + ] + response["AccountAttributes"].should.equal(expected_attribute_values) diff --git a/tests/test_ec2/test_amis.py b/tests/test_ec2/test_amis.py index feff4a16c..f65352c7c 100644 --- a/tests/test_ec2/test_amis.py +++ b/tests/test_ec2/test_amis.py @@ -5,6 +5,7 @@ import boto.ec2 import boto3 from boto.exception import EC2ResponseError from botocore.exceptions import ClientError + # Ensure 'assert_raises' context manager support for Python 2.6 from nose.tools import assert_raises import sure # noqa @@ -16,22 +17,24 @@ from tests.helpers import requires_boto_gte @mock_ec2_deprecated def test_ami_create_and_delete(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") initial_ami_count = len(AMIS) conn.get_all_volumes().should.have.length_of(0) conn.get_all_snapshots().should.have.length_of(initial_ami_count) - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: image_id = conn.create_image( - instance.id, "test-ami", "this is a test ami", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + instance.id, "test-ami", "this is a test ami", dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateImage operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateImage operation: Request would have succeeded, but DryRun flag is set" + ) image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") @@ -56,30 +59,36 @@ def test_ami_create_and_delete(): snapshots = conn.get_all_snapshots() snapshots.should.have.length_of(initial_ami_count + 1) - retrieved_image_snapshot_id = retrieved_image.block_device_mapping.current_value.snapshot_id + retrieved_image_snapshot_id = ( + retrieved_image.block_device_mapping.current_value.snapshot_id + ) [s.id for s in snapshots].should.contain(retrieved_image_snapshot_id) snapshot = [s for s in snapshots if s.id == retrieved_image_snapshot_id][0] snapshot.description.should.equal( - "Auto-created snapshot for AMI {0}".format(retrieved_image.id)) + "Auto-created snapshot for AMI {0}".format(retrieved_image.id) + ) # root device should be in AMI's block device mappings - root_mapping = retrieved_image.block_device_mapping.get(retrieved_image.root_device_name) + root_mapping = retrieved_image.block_device_mapping.get( + retrieved_image.root_device_name + ) root_mapping.should_not.be.none # Deregister with assert_raises(EC2ResponseError) as ex: success = conn.deregister_image(image_id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeregisterImage operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeregisterImage operation: Request would have succeeded, but DryRun flag is set" + ) success = conn.deregister_image(image_id) success.should.be.true with assert_raises(EC2ResponseError) as cm: conn.deregister_image(image_id) - cm.exception.code.should.equal('InvalidAMIID.NotFound') + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -93,11 +102,10 @@ def test_ami_copy(): conn.get_all_volumes().should.have.length_of(0) conn.get_all_snapshots().should.have.length_of(initial_ami_count) - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] - source_image_id = conn.create_image( - instance.id, "test-ami", "this is a test ami") + source_image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") instance.terminate() source_image = conn.get_all_images(image_ids=[source_image_id])[0] @@ -105,21 +113,29 @@ def test_ami_copy(): # the image_id to fetch the full info. with assert_raises(EC2ResponseError) as ex: copy_image_ref = conn.copy_image( - source_image.region.name, source_image.id, "test-copy-ami", "this is a test copy ami", - dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + source_image.region.name, + source_image.id, + "test-copy-ami", + "this is a test copy ami", + dry_run=True, + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CopyImage operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CopyImage operation: Request would have succeeded, but DryRun flag is set" + ) copy_image_ref = conn.copy_image( - source_image.region.name, source_image.id, "test-copy-ami", "this is a test copy ami") + source_image.region.name, + source_image.id, + "test-copy-ami", + "this is a test copy ami", + ) copy_image_id = copy_image_ref.image_id copy_image = conn.get_all_images(image_ids=[copy_image_id])[0] copy_image.id.should.equal(copy_image_id) - copy_image.virtualization_type.should.equal( - source_image.virtualization_type) + copy_image.virtualization_type.should.equal(source_image.virtualization_type) copy_image.architecture.should.equal(source_image.architecture) copy_image.kernel_id.should.equal(source_image.kernel_id) copy_image.platform.should.equal(source_image.platform) @@ -131,30 +147,37 @@ def test_ami_copy(): conn.get_all_snapshots().should.have.length_of(initial_ami_count + 2) copy_image.block_device_mapping.current_value.snapshot_id.should_not.equal( - source_image.block_device_mapping.current_value.snapshot_id) + source_image.block_device_mapping.current_value.snapshot_id + ) # Copy from non-existent source ID. with assert_raises(EC2ResponseError) as cm: - conn.copy_image(source_image.region.name, 'ami-abcd1234', - "test-copy-ami", "this is a test copy ami") - cm.exception.code.should.equal('InvalidAMIID.NotFound') + conn.copy_image( + source_image.region.name, + "ami-abcd1234", + "test-copy-ami", + "this is a test copy ami", + ) + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Copy from non-existent source region. with assert_raises(EC2ResponseError) as cm: - invalid_region = 'us-east-1' if (source_image.region.name != - 'us-east-1') else 'us-west-1' - conn.copy_image(invalid_region, source_image.id, - "test-copy-ami", "this is a test copy ami") - cm.exception.code.should.equal('InvalidAMIID.NotFound') + invalid_region = ( + "us-east-1" if (source_image.region.name != "us-east-1") else "us-west-1" + ) + conn.copy_image( + invalid_region, source_image.id, "test-copy-ami", "this is a test copy ami" + ) + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2 def test_copy_image_changes_owner_id(): - conn = boto3.client('ec2', region_name='us-east-1') + conn = boto3.client("ec2", region_name="us-east-1") # this source AMI ID is from moto/ec2/resources/amis.json source_ami_id = "ami-03cf127a" @@ -168,7 +191,8 @@ def test_copy_image_changes_owner_id(): SourceImageId=source_ami_id, Name="new-image", Description="a copy of an image", - SourceRegion="us-east-1") + SourceRegion="us-east-1", + ) describe_resp = conn.describe_images(Owners=["self"]) describe_resp["Images"][0]["OwnerId"].should.equal(OWNER_ID) @@ -177,18 +201,19 @@ def test_copy_image_changes_owner_id(): @mock_ec2_deprecated def test_ami_tagging(): - conn = boto.connect_vpc('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_vpc("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_all_images()[0] with assert_raises(EC2ResponseError) as ex: image.add_tag("a key", "some value", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) image.add_tag("a key", "some value") @@ -204,368 +229,374 @@ def test_ami_tagging(): @mock_ec2_deprecated def test_ami_create_from_missing_instance(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") args = ["i-abcdefg", "test-ami", "this is a test ami"] with assert_raises(EC2ResponseError) as cm: conn.create_image(*args) - cm.exception.code.should.equal('InvalidInstanceID.NotFound') + cm.exception.code.should.equal("InvalidInstanceID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_ami_pulls_attributes_from_instance(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.modify_attribute("kernel", "test-kernel") image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) - image.kernel_id.should.equal('test-kernel') + image.kernel_id.should.equal("test-kernel") @mock_ec2_deprecated def test_ami_filters(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - reservationA = conn.run_instances('ami-1234abcd') + reservationA = conn.run_instances("ami-1234abcd") instanceA = reservationA.instances[0] instanceA.modify_attribute("architecture", "i386") instanceA.modify_attribute("kernel", "k-1234abcd") instanceA.modify_attribute("platform", "windows") instanceA.modify_attribute("virtualization_type", "hvm") - imageA_id = conn.create_image( - instanceA.id, "test-ami-A", "this is a test ami") + imageA_id = conn.create_image(instanceA.id, "test-ami-A", "this is a test ami") imageA = conn.get_image(imageA_id) - reservationB = conn.run_instances('ami-abcd1234') + reservationB = conn.run_instances("ami-abcd1234") instanceB = reservationB.instances[0] instanceB.modify_attribute("architecture", "x86_64") instanceB.modify_attribute("kernel", "k-abcd1234") instanceB.modify_attribute("platform", "linux") instanceB.modify_attribute("virtualization_type", "paravirtual") - imageB_id = conn.create_image( - instanceB.id, "test-ami-B", "this is a test ami") + imageB_id = conn.create_image(instanceB.id, "test-ami-B", "this is a test ami") imageB = conn.get_image(imageB_id) imageB.set_launch_permissions(group_names=("all")) - amis_by_architecture = conn.get_all_images( - filters={'architecture': 'x86_64'}) + amis_by_architecture = conn.get_all_images(filters={"architecture": "x86_64"}) set([ami.id for ami in amis_by_architecture]).should.contain(imageB.id) len(amis_by_architecture).should.equal(35) - amis_by_kernel = conn.get_all_images(filters={'kernel-id': 'k-abcd1234'}) + amis_by_kernel = conn.get_all_images(filters={"kernel-id": "k-abcd1234"}) set([ami.id for ami in amis_by_kernel]).should.equal(set([imageB.id])) amis_by_virtualization = conn.get_all_images( - filters={'virtualization-type': 'paravirtual'}) - set([ami.id for ami in amis_by_virtualization] - ).should.contain(imageB.id) + filters={"virtualization-type": "paravirtual"} + ) + set([ami.id for ami in amis_by_virtualization]).should.contain(imageB.id) len(amis_by_virtualization).should.equal(3) - amis_by_platform = conn.get_all_images(filters={'platform': 'windows'}) + amis_by_platform = conn.get_all_images(filters={"platform": "windows"}) set([ami.id for ami in amis_by_platform]).should.contain(imageA.id) len(amis_by_platform).should.equal(24) - amis_by_id = conn.get_all_images(filters={'image-id': imageA.id}) + amis_by_id = conn.get_all_images(filters={"image-id": imageA.id}) set([ami.id for ami in amis_by_id]).should.equal(set([imageA.id])) - amis_by_state = conn.get_all_images(filters={'state': 'available'}) + amis_by_state = conn.get_all_images(filters={"state": "available"}) ami_ids_by_state = [ami.id for ami in amis_by_state] ami_ids_by_state.should.contain(imageA.id) ami_ids_by_state.should.contain(imageB.id) len(amis_by_state).should.equal(36) - amis_by_name = conn.get_all_images(filters={'name': imageA.name}) + amis_by_name = conn.get_all_images(filters={"name": imageA.name}) set([ami.id for ami in amis_by_name]).should.equal(set([imageA.id])) - amis_by_public = conn.get_all_images(filters={'is-public': 'true'}) + amis_by_public = conn.get_all_images(filters={"is-public": "true"}) set([ami.id for ami in amis_by_public]).should.contain(imageB.id) len(amis_by_public).should.equal(35) - amis_by_nonpublic = conn.get_all_images(filters={'is-public': 'false'}) + amis_by_nonpublic = conn.get_all_images(filters={"is-public": "false"}) set([ami.id for ami in amis_by_nonpublic]).should.contain(imageA.id) len(amis_by_nonpublic).should.equal(1) @mock_ec2_deprecated def test_ami_filtering_via_tag(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") - reservationA = conn.run_instances('ami-1234abcd') + reservationA = conn.run_instances("ami-1234abcd") instanceA = reservationA.instances[0] - imageA_id = conn.create_image( - instanceA.id, "test-ami-A", "this is a test ami") + imageA_id = conn.create_image(instanceA.id, "test-ami-A", "this is a test ami") imageA = conn.get_image(imageA_id) imageA.add_tag("a key", "some value") - reservationB = conn.run_instances('ami-abcd1234') + reservationB = conn.run_instances("ami-abcd1234") instanceB = reservationB.instances[0] - imageB_id = conn.create_image( - instanceB.id, "test-ami-B", "this is a test ami") + imageB_id = conn.create_image(instanceB.id, "test-ami-B", "this is a test ami") imageB = conn.get_image(imageB_id) imageB.add_tag("another key", "some other value") - amis_by_tagA = conn.get_all_images(filters={'tag:a key': 'some value'}) + amis_by_tagA = conn.get_all_images(filters={"tag:a key": "some value"}) set([ami.id for ami in amis_by_tagA]).should.equal(set([imageA.id])) - amis_by_tagB = conn.get_all_images( - filters={'tag:another key': 'some other value'}) + amis_by_tagB = conn.get_all_images(filters={"tag:another key": "some other value"}) set([ami.id for ami in amis_by_tagB]).should.equal(set([imageB.id])) @mock_ec2_deprecated def test_getting_missing_ami(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.get_image('ami-missing') - cm.exception.code.should.equal('InvalidAMIID.NotFound') + conn.get_image("ami-missing") + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_getting_malformed_ami(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.get_image('foo-missing') - cm.exception.code.should.equal('InvalidAMIID.Malformed') + conn.get_image("foo-missing") + cm.exception.code.should.equal("InvalidAMIID.Malformed") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_ami_attribute_group_permissions(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) # Baseline - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.name.should.equal('launch_permission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.name.should.equal("launch_permission") attributes.attrs.should.have.length_of(0) - ADD_GROUP_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'add', - 'groups': 'all'} + ADD_GROUP_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "add", + "groups": "all", + } - REMOVE_GROUP_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'remove', - 'groups': 'all'} + REMOVE_GROUP_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "remove", + "groups": "all", + } # Add 'all' group and confirm with assert_raises(EC2ResponseError) as ex: - conn.modify_image_attribute( - **dict(ADD_GROUP_ARGS, **{'dry_run': True})) - ex.exception.error_code.should.equal('DryRunOperation') + conn.modify_image_attribute(**dict(ADD_GROUP_ARGS, **{"dry_run": True})) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyImageAttribute operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyImageAttribute operation: Request would have succeeded, but DryRun flag is set" + ) conn.modify_image_attribute(**ADD_GROUP_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.attrs['groups'].should.have.length_of(1) - attributes.attrs['groups'].should.equal(['all']) + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.attrs["groups"].should.have.length_of(1) + attributes.attrs["groups"].should.equal(["all"]) image = conn.get_image(image_id) image.is_public.should.equal(True) # Add is idempotent - conn.modify_image_attribute.when.called_with( - **ADD_GROUP_ARGS).should_not.throw(EC2ResponseError) + conn.modify_image_attribute.when.called_with(**ADD_GROUP_ARGS).should_not.throw( + EC2ResponseError + ) # Remove 'all' group and confirm conn.modify_image_attribute(**REMOVE_GROUP_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") attributes.attrs.should.have.length_of(0) image = conn.get_image(image_id) image.is_public.should.equal(False) # Remove is idempotent - conn.modify_image_attribute.when.called_with( - **REMOVE_GROUP_ARGS).should_not.throw(EC2ResponseError) + conn.modify_image_attribute.when.called_with(**REMOVE_GROUP_ARGS).should_not.throw( + EC2ResponseError + ) @mock_ec2_deprecated def test_ami_attribute_user_permissions(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) # Baseline - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.name.should.equal('launch_permission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.name.should.equal("launch_permission") attributes.attrs.should.have.length_of(0) # Both str and int values should work. - USER1 = '123456789011' + USER1 = "123456789011" USER2 = 123456789022 - ADD_USERS_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'add', - 'user_ids': [USER1, USER2]} + ADD_USERS_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "add", + "user_ids": [USER1, USER2], + } - REMOVE_USERS_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'remove', - 'user_ids': [USER1, USER2]} + REMOVE_USERS_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "remove", + "user_ids": [USER1, USER2], + } - REMOVE_SINGLE_USER_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'remove', - 'user_ids': [USER1]} + REMOVE_SINGLE_USER_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "remove", + "user_ids": [USER1], + } # Add multiple users and confirm conn.modify_image_attribute(**ADD_USERS_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.attrs['user_ids'].should.have.length_of(2) - set(attributes.attrs['user_ids']).should.equal( - set([str(USER1), str(USER2)])) + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.attrs["user_ids"].should.have.length_of(2) + set(attributes.attrs["user_ids"]).should.equal(set([str(USER1), str(USER2)])) image = conn.get_image(image_id) image.is_public.should.equal(False) # Add is idempotent - conn.modify_image_attribute.when.called_with( - **ADD_USERS_ARGS).should_not.throw(EC2ResponseError) + conn.modify_image_attribute.when.called_with(**ADD_USERS_ARGS).should_not.throw( + EC2ResponseError + ) # Remove single user and confirm conn.modify_image_attribute(**REMOVE_SINGLE_USER_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.attrs['user_ids'].should.have.length_of(1) - set(attributes.attrs['user_ids']).should.equal(set([str(USER2)])) + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.attrs["user_ids"].should.have.length_of(1) + set(attributes.attrs["user_ids"]).should.equal(set([str(USER2)])) image = conn.get_image(image_id) image.is_public.should.equal(False) # Remove multiple users and confirm conn.modify_image_attribute(**REMOVE_USERS_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") attributes.attrs.should.have.length_of(0) image = conn.get_image(image_id) image.is_public.should.equal(False) # Remove is idempotent - conn.modify_image_attribute.when.called_with( - **REMOVE_USERS_ARGS).should_not.throw(EC2ResponseError) + conn.modify_image_attribute.when.called_with(**REMOVE_USERS_ARGS).should_not.throw( + EC2ResponseError + ) @mock_ec2 def test_ami_describe_executable_users(): - conn = boto3.client('ec2', region_name='us-east-1') - ec2 = boto3.resource('ec2', 'us-east-1') - ec2.create_instances(ImageId='', - MinCount=1, - MaxCount=1) - response = conn.describe_instances(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}]) - instance_id = response['Reservations'][0]['Instances'][0]['InstanceId'] - image_id = conn.create_image(InstanceId=instance_id, - Name='TestImage', )['ImageId'] + conn = boto3.client("ec2", region_name="us-east-1") + ec2 = boto3.resource("ec2", "us-east-1") + ec2.create_instances(ImageId="", MinCount=1, MaxCount=1) + response = conn.describe_instances( + Filters=[{"Name": "instance-state-name", "Values": ["running"]}] + ) + instance_id = response["Reservations"][0]["Instances"][0]["InstanceId"] + image_id = conn.create_image(InstanceId=instance_id, Name="TestImage")["ImageId"] - USER1 = '123456789011' + USER1 = "123456789011" - ADD_USER_ARGS = {'ImageId': image_id, - 'Attribute': 'launchPermission', - 'OperationType': 'add', - 'UserIds': [USER1]} + ADD_USER_ARGS = { + "ImageId": image_id, + "Attribute": "launchPermission", + "OperationType": "add", + "UserIds": [USER1], + } # Add users and get no images conn.modify_image_attribute(**ADD_USER_ARGS) - attributes = conn.describe_image_attribute(ImageId=image_id, - Attribute='LaunchPermissions', - DryRun=False) - attributes['LaunchPermissions'].should.have.length_of(1) - attributes['LaunchPermissions'][0]['UserId'].should.equal(USER1) - images = conn.describe_images(ExecutableUsers=[USER1])['Images'] + attributes = conn.describe_image_attribute( + ImageId=image_id, Attribute="LaunchPermissions", DryRun=False + ) + attributes["LaunchPermissions"].should.have.length_of(1) + attributes["LaunchPermissions"][0]["UserId"].should.equal(USER1) + images = conn.describe_images(ExecutableUsers=[USER1])["Images"] images.should.have.length_of(1) - images[0]['ImageId'].should.equal(image_id) + images[0]["ImageId"].should.equal(image_id) @mock_ec2 def test_ami_describe_executable_users_negative(): - conn = boto3.client('ec2', region_name='us-east-1') - ec2 = boto3.resource('ec2', 'us-east-1') - ec2.create_instances(ImageId='', - MinCount=1, - MaxCount=1) - response = conn.describe_instances(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}]) - instance_id = response['Reservations'][0]['Instances'][0]['InstanceId'] - image_id = conn.create_image(InstanceId=instance_id, - Name='TestImage')['ImageId'] + conn = boto3.client("ec2", region_name="us-east-1") + ec2 = boto3.resource("ec2", "us-east-1") + ec2.create_instances(ImageId="", MinCount=1, MaxCount=1) + response = conn.describe_instances( + Filters=[{"Name": "instance-state-name", "Values": ["running"]}] + ) + instance_id = response["Reservations"][0]["Instances"][0]["InstanceId"] + image_id = conn.create_image(InstanceId=instance_id, Name="TestImage")["ImageId"] - USER1 = '123456789011' - USER2 = '113355789012' + USER1 = "123456789011" + USER2 = "113355789012" - ADD_USER_ARGS = {'ImageId': image_id, - 'Attribute': 'launchPermission', - 'OperationType': 'add', - 'UserIds': [USER1]} + ADD_USER_ARGS = { + "ImageId": image_id, + "Attribute": "launchPermission", + "OperationType": "add", + "UserIds": [USER1], + } # Add users and get no images # Add users and get no images conn.modify_image_attribute(**ADD_USER_ARGS) - attributes = conn.describe_image_attribute(ImageId=image_id, - Attribute='LaunchPermissions', - DryRun=False) - attributes['LaunchPermissions'].should.have.length_of(1) - attributes['LaunchPermissions'][0]['UserId'].should.equal(USER1) - images = conn.describe_images(ExecutableUsers=[USER2])['Images'] + attributes = conn.describe_image_attribute( + ImageId=image_id, Attribute="LaunchPermissions", DryRun=False + ) + attributes["LaunchPermissions"].should.have.length_of(1) + attributes["LaunchPermissions"][0]["UserId"].should.equal(USER1) + images = conn.describe_images(ExecutableUsers=[USER2])["Images"] images.should.have.length_of(0) @mock_ec2 def test_ami_describe_executable_users_and_filter(): - conn = boto3.client('ec2', region_name='us-east-1') - ec2 = boto3.resource('ec2', 'us-east-1') - ec2.create_instances(ImageId='', - MinCount=1, - MaxCount=1) - response = conn.describe_instances(Filters=[{'Name': 'instance-state-name', 'Values': ['running']}]) - instance_id = response['Reservations'][0]['Instances'][0]['InstanceId'] - image_id = conn.create_image(InstanceId=instance_id, - Name='ImageToDelete', )['ImageId'] + conn = boto3.client("ec2", region_name="us-east-1") + ec2 = boto3.resource("ec2", "us-east-1") + ec2.create_instances(ImageId="", MinCount=1, MaxCount=1) + response = conn.describe_instances( + Filters=[{"Name": "instance-state-name", "Values": ["running"]}] + ) + instance_id = response["Reservations"][0]["Instances"][0]["InstanceId"] + image_id = conn.create_image(InstanceId=instance_id, Name="ImageToDelete")[ + "ImageId" + ] - USER1 = '123456789011' + USER1 = "123456789011" - ADD_USER_ARGS = {'ImageId': image_id, - 'Attribute': 'launchPermission', - 'OperationType': 'add', - 'UserIds': [USER1]} + ADD_USER_ARGS = { + "ImageId": image_id, + "Attribute": "launchPermission", + "OperationType": "add", + "UserIds": [USER1], + } # Add users and get no images conn.modify_image_attribute(**ADD_USER_ARGS) - attributes = conn.describe_image_attribute(ImageId=image_id, - Attribute='LaunchPermissions', - DryRun=False) - attributes['LaunchPermissions'].should.have.length_of(1) - attributes['LaunchPermissions'][0]['UserId'].should.equal(USER1) - images = conn.describe_images(ExecutableUsers=[USER1], - Filters=[{'Name': 'state', 'Values': ['available']}])['Images'] + attributes = conn.describe_image_attribute( + ImageId=image_id, Attribute="LaunchPermissions", DryRun=False + ) + attributes["LaunchPermissions"].should.have.length_of(1) + attributes["LaunchPermissions"][0]["UserId"].should.equal(USER1) + images = conn.describe_images( + ExecutableUsers=[USER1], Filters=[{"Name": "state", "Values": ["available"]}] + )["Images"] images.should.have.length_of(1) - images[0]['ImageId'].should.equal(image_id) + images[0]["ImageId"].should.equal(image_id) @mock_ec2_deprecated @@ -575,49 +606,50 @@ def test_ami_attribute_user_and_group_permissions(): Just spot-check this -- input variations, idempotency, etc are validated via user-specific and group-specific tests above. """ - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) # Baseline - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.name.should.equal('launch_permission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.name.should.equal("launch_permission") attributes.attrs.should.have.length_of(0) - USER1 = '123456789011' - USER2 = '123456789022' + USER1 = "123456789011" + USER2 = "123456789022" - ADD_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'add', - 'groups': ['all'], - 'user_ids': [USER1, USER2]} + ADD_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "add", + "groups": ["all"], + "user_ids": [USER1, USER2], + } - REMOVE_ARGS = {'image_id': image.id, - 'attribute': 'launchPermission', - 'operation': 'remove', - 'groups': ['all'], - 'user_ids': [USER1, USER2]} + REMOVE_ARGS = { + "image_id": image.id, + "attribute": "launchPermission", + "operation": "remove", + "groups": ["all"], + "user_ids": [USER1, USER2], + } # Add and confirm conn.modify_image_attribute(**ADD_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') - attributes.attrs['user_ids'].should.have.length_of(2) - set(attributes.attrs['user_ids']).should.equal(set([USER1, USER2])) - set(attributes.attrs['groups']).should.equal(set(['all'])) + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") + attributes.attrs["user_ids"].should.have.length_of(2) + set(attributes.attrs["user_ids"]).should.equal(set([USER1, USER2])) + set(attributes.attrs["groups"]).should.equal(set(["all"])) image = conn.get_image(image_id) image.is_public.should.equal(True) # Remove and confirm conn.modify_image_attribute(**REMOVE_ARGS) - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") attributes.attrs.should.have.length_of(0) image = conn.get_image(image_id) image.is_public.should.equal(False) @@ -625,130 +657,138 @@ def test_ami_attribute_user_and_group_permissions(): @mock_ec2_deprecated def test_ami_attribute_error_cases(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) # Error: Add with group != 'all' with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute(image.id, - attribute='launchPermission', - operation='add', - groups='everyone') - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_image_attribute( + image.id, attribute="launchPermission", operation="add", groups="everyone" + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add with user ID that isn't an integer. with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute(image.id, - attribute='launchPermission', - operation='add', - user_ids='12345678901A') - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_image_attribute( + image.id, + attribute="launchPermission", + operation="add", + user_ids="12345678901A", + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add with user ID that is > length 12. with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute(image.id, - attribute='launchPermission', - operation='add', - user_ids='1234567890123') - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_image_attribute( + image.id, + attribute="launchPermission", + operation="add", + user_ids="1234567890123", + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add with user ID that is < length 12. with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute(image.id, - attribute='launchPermission', - operation='add', - user_ids='12345678901') - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_image_attribute( + image.id, + attribute="launchPermission", + operation="add", + user_ids="12345678901", + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add with one invalid user ID among other valid IDs, ensure no # partial changes. with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute(image.id, - attribute='launchPermission', - operation='add', - user_ids=['123456789011', 'foo', '123456789022']) - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_image_attribute( + image.id, + attribute="launchPermission", + operation="add", + user_ids=["123456789011", "foo", "123456789022"], + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none - attributes = conn.get_image_attribute( - image.id, attribute='launchPermission') + attributes = conn.get_image_attribute(image.id, attribute="launchPermission") attributes.attrs.should.have.length_of(0) # Error: Add with invalid image ID with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute("ami-abcd1234", - attribute='launchPermission', - operation='add', - groups='all') - cm.exception.code.should.equal('InvalidAMIID.NotFound') + conn.modify_image_attribute( + "ami-abcd1234", attribute="launchPermission", operation="add", groups="all" + ) + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Remove with invalid image ID with assert_raises(EC2ResponseError) as cm: - conn.modify_image_attribute("ami-abcd1234", - attribute='launchPermission', - operation='remove', - groups='all') - cm.exception.code.should.equal('InvalidAMIID.NotFound') + conn.modify_image_attribute( + "ami-abcd1234", + attribute="launchPermission", + operation="remove", + groups="all", + ) + cm.exception.code.should.equal("InvalidAMIID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2 def test_ami_describe_non_existent(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Valid pattern but non-existent id - img = ec2.Image('ami-abcd1234') + img = ec2.Image("ami-abcd1234") with assert_raises(ClientError): img.load() # Invalid ami pattern - img = ec2.Image('not_an_ami_id') + img = ec2.Image("not_an_ami_id") with assert_raises(ClientError): img.load() @mock_ec2 def test_ami_filter_wildcard(): - ec2_resource = boto3.resource('ec2', region_name='us-west-1') - ec2_client = boto3.client('ec2', region_name='us-west-1') + ec2_resource = boto3.resource("ec2", region_name="us-west-1") + ec2_client = boto3.client("ec2", region_name="us-west-1") - instance = ec2_resource.create_instances(ImageId='ami-1234abcd', MinCount=1, MaxCount=1)[0] - instance.create_image(Name='test-image') + instance = ec2_resource.create_instances( + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 + )[0] + instance.create_image(Name="test-image") # create an image with the same owner but will not match the filter - instance.create_image(Name='not-matching-image') + instance.create_image(Name="not-matching-image") my_images = ec2_client.describe_images( - Owners=['111122223333'], - Filters=[{'Name': 'name', 'Values': ['test*']}] - )['Images'] + Owners=["111122223333"], Filters=[{"Name": "name", "Values": ["test*"]}] + )["Images"] my_images.should.have.length_of(1) @mock_ec2 def test_ami_filter_by_owner_id(): - client = boto3.client('ec2', region_name='us-east-1') + client = boto3.client("ec2", region_name="us-east-1") - ubuntu_id = '099720109477' + ubuntu_id = "099720109477" ubuntu_images = client.describe_images(Owners=[ubuntu_id]) all_images = client.describe_images() - ubuntu_ids = [ami['OwnerId'] for ami in ubuntu_images['Images']] - all_ids = [ami['OwnerId'] for ami in all_images['Images']] + ubuntu_ids = [ami["OwnerId"] for ami in ubuntu_images["Images"]] + all_ids = [ami["OwnerId"] for ami in all_images["Images"]] # Assert all ubuntu_ids are the same and one equals ubuntu_id assert all(ubuntu_ids) and ubuntu_ids[0] == ubuntu_id @@ -758,42 +798,42 @@ def test_ami_filter_by_owner_id(): @mock_ec2 def test_ami_filter_by_self(): - ec2_resource = boto3.resource('ec2', region_name='us-west-1') - ec2_client = boto3.client('ec2', region_name='us-west-1') + ec2_resource = boto3.resource("ec2", region_name="us-west-1") + ec2_client = boto3.client("ec2", region_name="us-west-1") - my_images = ec2_client.describe_images(Owners=['self'])['Images'] + my_images = ec2_client.describe_images(Owners=["self"])["Images"] my_images.should.have.length_of(0) # Create a new image - instance = ec2_resource.create_instances(ImageId='ami-1234abcd', MinCount=1, MaxCount=1)[0] - instance.create_image(Name='test-image') + instance = ec2_resource.create_instances( + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 + )[0] + instance.create_image(Name="test-image") - my_images = ec2_client.describe_images(Owners=['self'])['Images'] + my_images = ec2_client.describe_images(Owners=["self"])["Images"] my_images.should.have.length_of(1) @mock_ec2 def test_ami_snapshots_have_correct_owner(): - ec2_client = boto3.client('ec2', region_name='us-west-1') + ec2_client = boto3.client("ec2", region_name="us-west-1") images_response = ec2_client.describe_images() owner_id_to_snapshot_ids = {} - for image in images_response['Images']: - owner_id = image['OwnerId'] + for image in images_response["Images"]: + owner_id = image["OwnerId"] snapshot_ids = [ - block_device_mapping['Ebs']['SnapshotId'] - for block_device_mapping in image['BlockDeviceMappings'] + block_device_mapping["Ebs"]["SnapshotId"] + for block_device_mapping in image["BlockDeviceMappings"] ] existing_snapshot_ids = owner_id_to_snapshot_ids.get(owner_id, []) - owner_id_to_snapshot_ids[owner_id] = ( - existing_snapshot_ids + snapshot_ids - ) + owner_id_to_snapshot_ids[owner_id] = existing_snapshot_ids + snapshot_ids for owner_id in owner_id_to_snapshot_ids: snapshots_rseponse = ec2_client.describe_snapshots( SnapshotIds=owner_id_to_snapshot_ids[owner_id] ) - for snapshot in snapshots_rseponse['Snapshots']: - assert owner_id == snapshot['OwnerId'] + for snapshot in snapshots_rseponse["Snapshots"]: + assert owner_id == snapshot["OwnerId"] diff --git a/tests/test_ec2/test_availability_zones_and_regions.py b/tests/test_ec2/test_availability_zones_and_regions.py index c64f075ca..349be7936 100644 --- a/tests/test_ec2/test_availability_zones_and_regions.py +++ b/tests/test_ec2/test_availability_zones_and_regions.py @@ -9,7 +9,7 @@ from moto import mock_ec2, mock_ec2_deprecated @mock_ec2_deprecated def test_describe_regions(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") regions = conn.get_all_regions() regions.should.have.length_of(16) for region in regions: @@ -18,7 +18,7 @@ def test_describe_regions(): @mock_ec2_deprecated def test_availability_zones(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") regions = conn.get_all_regions() for region in regions: conn = boto.ec2.connect_to_region(region.name) @@ -30,25 +30,25 @@ def test_availability_zones(): @mock_ec2 def test_boto3_describe_regions(): - ec2 = boto3.client('ec2', 'us-east-1') + ec2 = boto3.client("ec2", "us-east-1") resp = ec2.describe_regions() - resp['Regions'].should.have.length_of(16) - for rec in resp['Regions']: - rec['Endpoint'].should.contain(rec['RegionName']) + resp["Regions"].should.have.length_of(16) + for rec in resp["Regions"]: + rec["Endpoint"].should.contain(rec["RegionName"]) - test_region = 'us-east-1' + test_region = "us-east-1" resp = ec2.describe_regions(RegionNames=[test_region]) - resp['Regions'].should.have.length_of(1) - resp['Regions'][0].should.have.key('RegionName').which.should.equal(test_region) + resp["Regions"].should.have.length_of(1) + resp["Regions"][0].should.have.key("RegionName").which.should.equal(test_region) @mock_ec2 def test_boto3_availability_zones(): - ec2 = boto3.client('ec2', 'us-east-1') + ec2 = boto3.client("ec2", "us-east-1") resp = ec2.describe_regions() - regions = [r['RegionName'] for r in resp['Regions']] + regions = [r["RegionName"] for r in resp["Regions"]] for region in regions: - conn = boto3.client('ec2', region) + conn = boto3.client("ec2", region) resp = conn.describe_availability_zones() - for rec in resp['AvailabilityZones']: - rec['ZoneName'].should.contain(region) + for rec in resp["AvailabilityZones"]: + rec["ZoneName"].should.contain(region) diff --git a/tests/test_ec2/test_customer_gateways.py b/tests/test_ec2/test_customer_gateways.py index 589f887f6..a676a2b5d 100644 --- a/tests/test_ec2/test_customer_gateways.py +++ b/tests/test_ec2/test_customer_gateways.py @@ -10,22 +10,20 @@ from moto import mock_ec2_deprecated @mock_ec2_deprecated def test_create_customer_gateways(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") - customer_gateway = conn.create_customer_gateway( - 'ipsec.1', '205.251.242.54', 65534) + customer_gateway = conn.create_customer_gateway("ipsec.1", "205.251.242.54", 65534) customer_gateway.should_not.be.none - customer_gateway.id.should.match(r'cgw-\w+') - customer_gateway.type.should.equal('ipsec.1') + customer_gateway.id.should.match(r"cgw-\w+") + customer_gateway.type.should.equal("ipsec.1") customer_gateway.bgp_asn.should.equal(65534) - customer_gateway.ip_address.should.equal('205.251.242.54') + customer_gateway.ip_address.should.equal("205.251.242.54") @mock_ec2_deprecated def test_describe_customer_gateways(): - conn = boto.connect_vpc('the_key', 'the_secret') - customer_gateway = conn.create_customer_gateway( - 'ipsec.1', '205.251.242.54', 65534) + conn = boto.connect_vpc("the_key", "the_secret") + customer_gateway = conn.create_customer_gateway("ipsec.1", "205.251.242.54", 65534) cgws = conn.get_all_customer_gateways() cgws.should.have.length_of(1) cgws[0].id.should.match(customer_gateway.id) @@ -33,10 +31,9 @@ def test_describe_customer_gateways(): @mock_ec2_deprecated def test_delete_customer_gateways(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") - customer_gateway = conn.create_customer_gateway( - 'ipsec.1', '205.251.242.54', 65534) + customer_gateway = conn.create_customer_gateway("ipsec.1", "205.251.242.54", 65534) customer_gateway.should_not.be.none cgws = conn.get_all_customer_gateways() cgws[0].id.should.match(customer_gateway.id) @@ -47,6 +44,6 @@ def test_delete_customer_gateways(): @mock_ec2_deprecated def test_delete_customer_gateways_bad_id(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.delete_customer_gateway('cgw-0123abcd') + conn.delete_customer_gateway("cgw-0123abcd") diff --git a/tests/test_ec2/test_dhcp_options.py b/tests/test_ec2/test_dhcp_options.py index 4e2520241..4aaceaa07 100644 --- a/tests/test_ec2/test_dhcp_options.py +++ b/tests/test_ec2/test_dhcp_options.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -11,16 +12,15 @@ import sure # noqa from moto import mock_ec2, mock_ec2_deprecated -SAMPLE_DOMAIN_NAME = u'example.com' -SAMPLE_NAME_SERVERS = [u'10.0.0.6', u'10.0.0.7'] +SAMPLE_DOMAIN_NAME = "example.com" +SAMPLE_NAME_SERVERS = ["10.0.0.6", "10.0.0.7"] @mock_ec2_deprecated def test_dhcp_options_associate(): """ associate dhcp option """ - conn = boto.connect_vpc('the_key', 'the_secret') - dhcp_options = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + conn = boto.connect_vpc("the_key", "the_secret") + dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) vpc = conn.create_vpc("10.0.0.0/16") rval = conn.associate_dhcp_options(dhcp_options.id, vpc.id) @@ -30,12 +30,12 @@ def test_dhcp_options_associate(): @mock_ec2_deprecated def test_dhcp_options_associate_invalid_dhcp_id(): """ associate dhcp option bad dhcp options id """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") with assert_raises(EC2ResponseError) as cm: conn.associate_dhcp_options("foo", vpc.id) - cm.exception.code.should.equal('InvalidDhcpOptionID.NotFound') + cm.exception.code.should.equal("InvalidDhcpOptionID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -43,13 +43,12 @@ def test_dhcp_options_associate_invalid_dhcp_id(): @mock_ec2_deprecated def test_dhcp_options_associate_invalid_vpc_id(): """ associate dhcp option invalid vpc id """ - conn = boto.connect_vpc('the_key', 'the_secret') - dhcp_options = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + conn = boto.connect_vpc("the_key", "the_secret") + dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) with assert_raises(EC2ResponseError) as cm: conn.associate_dhcp_options(dhcp_options.id, "foo") - cm.exception.code.should.equal('InvalidVpcID.NotFound') + cm.exception.code.should.equal("InvalidVpcID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -57,9 +56,8 @@ def test_dhcp_options_associate_invalid_vpc_id(): @mock_ec2_deprecated def test_dhcp_options_delete_with_vpc(): """Test deletion of dhcp options with vpc""" - conn = boto.connect_vpc('the_key', 'the_secret') - dhcp_options = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + conn = boto.connect_vpc("the_key", "the_secret") + dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) dhcp_options_id = dhcp_options.id vpc = conn.create_vpc("10.0.0.0/16") @@ -68,7 +66,7 @@ def test_dhcp_options_delete_with_vpc(): with assert_raises(EC2ResponseError) as cm: conn.delete_dhcp_options(dhcp_options_id) - cm.exception.code.should.equal('DependencyViolation') + cm.exception.code.should.equal("DependencyViolation") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -76,7 +74,7 @@ def test_dhcp_options_delete_with_vpc(): with assert_raises(EC2ResponseError) as cm: conn.get_all_dhcp_options([dhcp_options_id]) - cm.exception.code.should.equal('InvalidDhcpOptionID.NotFound') + cm.exception.code.should.equal("InvalidDhcpOptionID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -84,32 +82,33 @@ def test_dhcp_options_delete_with_vpc(): @mock_ec2_deprecated def test_create_dhcp_options(): """Create most basic dhcp option""" - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") - dhcp_option = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) - dhcp_option.options[u'domain-name'][0].should.be.equal(SAMPLE_DOMAIN_NAME) - dhcp_option.options[ - u'domain-name-servers'][0].should.be.equal(SAMPLE_NAME_SERVERS[0]) - dhcp_option.options[ - u'domain-name-servers'][1].should.be.equal(SAMPLE_NAME_SERVERS[1]) + dhcp_option = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + dhcp_option.options["domain-name"][0].should.be.equal(SAMPLE_DOMAIN_NAME) + dhcp_option.options["domain-name-servers"][0].should.be.equal( + SAMPLE_NAME_SERVERS[0] + ) + dhcp_option.options["domain-name-servers"][1].should.be.equal( + SAMPLE_NAME_SERVERS[1] + ) @mock_ec2_deprecated def test_create_dhcp_options_invalid_options(): """Create invalid dhcp options""" - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") servers = ["f", "f", "f", "f", "f"] with assert_raises(EC2ResponseError) as cm: conn.create_dhcp_options(ntp_servers=servers) - cm.exception.code.should.equal('InvalidParameterValue') + cm.exception.code.should.equal("InvalidParameterValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none with assert_raises(EC2ResponseError) as cm: conn.create_dhcp_options(netbios_node_type="0") - cm.exception.code.should.equal('InvalidParameterValue') + cm.exception.code.should.equal("InvalidParameterValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -117,7 +116,7 @@ def test_create_dhcp_options_invalid_options(): @mock_ec2_deprecated def test_describe_dhcp_options(): """Test dhcp options lookup by id""" - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") dhcp_option = conn.create_dhcp_options() dhcp_options = conn.get_all_dhcp_options([dhcp_option.id]) @@ -130,11 +129,11 @@ def test_describe_dhcp_options(): @mock_ec2_deprecated def test_describe_dhcp_options_invalid_id(): """get error on invalid dhcp_option_id lookup""" - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.get_all_dhcp_options(["1"]) - cm.exception.code.should.equal('InvalidDhcpOptionID.NotFound') + cm.exception.code.should.equal("InvalidDhcpOptionID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -142,7 +141,7 @@ def test_describe_dhcp_options_invalid_id(): @mock_ec2_deprecated def test_delete_dhcp_options(): """delete dhcp option""" - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") dhcp_option = conn.create_dhcp_options() dhcp_options = conn.get_all_dhcp_options([dhcp_option.id]) @@ -152,40 +151,40 @@ def test_delete_dhcp_options(): with assert_raises(EC2ResponseError) as cm: conn.get_all_dhcp_options([dhcp_option.id]) - cm.exception.code.should.equal('InvalidDhcpOptionID.NotFound') + cm.exception.code.should.equal("InvalidDhcpOptionID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_delete_dhcp_options_invalid_id(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") conn.create_dhcp_options() with assert_raises(EC2ResponseError) as cm: conn.delete_dhcp_options("dopt-abcd1234") - cm.exception.code.should.equal('InvalidDhcpOptionID.NotFound') + cm.exception.code.should.equal("InvalidDhcpOptionID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_delete_dhcp_options_malformed_id(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") conn.create_dhcp_options() with assert_raises(EC2ResponseError) as cm: conn.delete_dhcp_options("foo-abcd1234") - cm.exception.code.should.equal('InvalidDhcpOptionsId.Malformed') + cm.exception.code.should.equal("InvalidDhcpOptionsId.Malformed") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_dhcp_tagging(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") dhcp_option = conn.create_dhcp_options() dhcp_option.add_tag("a key", "some value") @@ -202,39 +201,35 @@ def test_dhcp_tagging(): @mock_ec2_deprecated def test_dhcp_options_get_by_tag(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") - dhcp1 = conn.create_dhcp_options('example.com', ['10.0.10.2']) - dhcp1.add_tag('Name', 'TestDhcpOptions1') - dhcp1.add_tag('test-tag', 'test-value') + dhcp1 = conn.create_dhcp_options("example.com", ["10.0.10.2"]) + dhcp1.add_tag("Name", "TestDhcpOptions1") + dhcp1.add_tag("test-tag", "test-value") - dhcp2 = conn.create_dhcp_options('example.com', ['10.0.20.2']) - dhcp2.add_tag('Name', 'TestDhcpOptions2') - dhcp2.add_tag('test-tag', 'test-value') + dhcp2 = conn.create_dhcp_options("example.com", ["10.0.20.2"]) + dhcp2.add_tag("Name", "TestDhcpOptions2") + dhcp2.add_tag("test-tag", "test-value") - filters = {'tag:Name': 'TestDhcpOptions1', 'tag:test-tag': 'test-value'} + filters = {"tag:Name": "TestDhcpOptions1", "tag:test-tag": "test-value"} dhcp_options_sets = conn.get_all_dhcp_options(filters=filters) dhcp_options_sets.should.have.length_of(1) - dhcp_options_sets[0].options[ - 'domain-name'][0].should.be.equal('example.com') - dhcp_options_sets[0].options[ - 'domain-name-servers'][0].should.be.equal('10.0.10.2') - dhcp_options_sets[0].tags['Name'].should.equal('TestDhcpOptions1') - dhcp_options_sets[0].tags['test-tag'].should.equal('test-value') + dhcp_options_sets[0].options["domain-name"][0].should.be.equal("example.com") + dhcp_options_sets[0].options["domain-name-servers"][0].should.be.equal("10.0.10.2") + dhcp_options_sets[0].tags["Name"].should.equal("TestDhcpOptions1") + dhcp_options_sets[0].tags["test-tag"].should.equal("test-value") - filters = {'tag:Name': 'TestDhcpOptions2', 'tag:test-tag': 'test-value'} + filters = {"tag:Name": "TestDhcpOptions2", "tag:test-tag": "test-value"} dhcp_options_sets = conn.get_all_dhcp_options(filters=filters) dhcp_options_sets.should.have.length_of(1) - dhcp_options_sets[0].options[ - 'domain-name'][0].should.be.equal('example.com') - dhcp_options_sets[0].options[ - 'domain-name-servers'][0].should.be.equal('10.0.20.2') - dhcp_options_sets[0].tags['Name'].should.equal('TestDhcpOptions2') - dhcp_options_sets[0].tags['test-tag'].should.equal('test-value') + dhcp_options_sets[0].options["domain-name"][0].should.be.equal("example.com") + dhcp_options_sets[0].options["domain-name-servers"][0].should.be.equal("10.0.20.2") + dhcp_options_sets[0].tags["Name"].should.equal("TestDhcpOptions2") + dhcp_options_sets[0].tags["test-tag"].should.equal("test-value") - filters = {'tag:test-tag': 'test-value'} + filters = {"tag:test-tag": "test-value"} dhcp_options_sets = conn.get_all_dhcp_options(filters=filters) dhcp_options_sets.should.have.length_of(2) @@ -242,92 +237,101 @@ def test_dhcp_options_get_by_tag(): @mock_ec2_deprecated def test_dhcp_options_get_by_id(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") - dhcp1 = conn.create_dhcp_options('test1.com', ['10.0.10.2']) - dhcp1.add_tag('Name', 'TestDhcpOptions1') - dhcp1.add_tag('test-tag', 'test-value') + dhcp1 = conn.create_dhcp_options("test1.com", ["10.0.10.2"]) + dhcp1.add_tag("Name", "TestDhcpOptions1") + dhcp1.add_tag("test-tag", "test-value") dhcp1_id = dhcp1.id - dhcp2 = conn.create_dhcp_options('test2.com', ['10.0.20.2']) - dhcp2.add_tag('Name', 'TestDhcpOptions2') - dhcp2.add_tag('test-tag', 'test-value') + dhcp2 = conn.create_dhcp_options("test2.com", ["10.0.20.2"]) + dhcp2.add_tag("Name", "TestDhcpOptions2") + dhcp2.add_tag("test-tag", "test-value") dhcp2_id = dhcp2.id dhcp_options_sets = conn.get_all_dhcp_options() dhcp_options_sets.should.have.length_of(2) - dhcp_options_sets = conn.get_all_dhcp_options( - filters={'dhcp-options-id': dhcp1_id}) + dhcp_options_sets = conn.get_all_dhcp_options(filters={"dhcp-options-id": dhcp1_id}) dhcp_options_sets.should.have.length_of(1) - dhcp_options_sets[0].options['domain-name'][0].should.be.equal('test1.com') - dhcp_options_sets[0].options[ - 'domain-name-servers'][0].should.be.equal('10.0.10.2') + dhcp_options_sets[0].options["domain-name"][0].should.be.equal("test1.com") + dhcp_options_sets[0].options["domain-name-servers"][0].should.be.equal("10.0.10.2") - dhcp_options_sets = conn.get_all_dhcp_options( - filters={'dhcp-options-id': dhcp2_id}) + dhcp_options_sets = conn.get_all_dhcp_options(filters={"dhcp-options-id": dhcp2_id}) dhcp_options_sets.should.have.length_of(1) - dhcp_options_sets[0].options['domain-name'][0].should.be.equal('test2.com') - dhcp_options_sets[0].options[ - 'domain-name-servers'][0].should.be.equal('10.0.20.2') + dhcp_options_sets[0].options["domain-name"][0].should.be.equal("test2.com") + dhcp_options_sets[0].options["domain-name-servers"][0].should.be.equal("10.0.20.2") @mock_ec2 def test_dhcp_options_get_by_value_filter(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.10.2']} - ]) + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.10.2"]}, + ] + ) - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.20.2']} - ]) + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.20.2"]}, + ] + ) - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.30.2']} - ]) + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.30.2"]}, + ] + ) - filters = [{'Name': 'value', 'Values': ['10.0.10.2']}] + filters = [{"Name": "value", "Values": ["10.0.10.2"]}] dhcp_options_sets = list(ec2.dhcp_options_sets.filter(Filters=filters)) dhcp_options_sets.should.have.length_of(1) @mock_ec2 def test_dhcp_options_get_by_key_filter(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.10.2']} - ]) + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.10.2"]}, + ] + ) - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.20.2']} - ]) + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.20.2"]}, + ] + ) - ec2.create_dhcp_options(DhcpConfigurations=[ - {'Key': 'domain-name', 'Values': ['example.com']}, - {'Key': 'domain-name-servers', 'Values': ['10.0.30.2']} - ]) + ec2.create_dhcp_options( + DhcpConfigurations=[ + {"Key": "domain-name", "Values": ["example.com"]}, + {"Key": "domain-name-servers", "Values": ["10.0.30.2"]}, + ] + ) - filters = [{'Name': 'key', 'Values': ['domain-name']}] + filters = [{"Name": "key", "Values": ["domain-name"]}] dhcp_options_sets = list(ec2.dhcp_options_sets.filter(Filters=filters)) dhcp_options_sets.should.have.length_of(3) @mock_ec2_deprecated def test_dhcp_options_get_by_invalid_filter(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) - filters = {'invalid-filter': 'invalid-value'} + filters = {"invalid-filter": "invalid-value"} - conn.get_all_dhcp_options.when.called_with( - filters=filters).should.throw(NotImplementedError) + conn.get_all_dhcp_options.when.called_with(filters=filters).should.throw( + NotImplementedError + ) diff --git a/tests/test_ec2/test_elastic_block_store.py b/tests/test_ec2/test_elastic_block_store.py index 9dbaa5ea6..3c7e17ec8 100644 --- a/tests/test_ec2/test_elastic_block_store.py +++ b/tests/test_ec2/test_elastic_block_store.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -32,10 +33,11 @@ def test_create_and_delete_volume(): with assert_raises(EC2ResponseError) as ex: volume.delete(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteVolume operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeleteVolume operation: Request would have succeeded, but DryRun flag is set" + ) volume.delete() @@ -46,7 +48,7 @@ def test_create_and_delete_volume(): # Deleting something that was already deleted should throw an error with assert_raises(EC2ResponseError) as cm: volume.delete() - cm.exception.code.should.equal('InvalidVolume.NotFound') + cm.exception.code.should.equal("InvalidVolume.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -56,10 +58,11 @@ def test_create_encrypted_volume_dryrun(): conn = boto.ec2.connect_to_region("us-east-1") with assert_raises(EC2ResponseError) as ex: conn.create_volume(80, "us-east-1a", encrypted=True, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateVolume operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateVolume operation: Request would have succeeded, but DryRun flag is set" + ) @mock_ec2_deprecated @@ -69,10 +72,11 @@ def test_create_encrypted_volume(): with assert_raises(EC2ResponseError) as ex: conn.create_volume(80, "us-east-1a", encrypted=True, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateVolume operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateVolume operation: Request would have succeeded, but DryRun flag is set" + ) all_volumes = [vol for vol in conn.get_all_volumes() if vol.id == volume.id] all_volumes[0].encrypted.should.be(True) @@ -87,13 +91,13 @@ def test_filter_volume_by_id(): vol1 = conn.get_all_volumes(volume_ids=volume3.id) vol1.should.have.length_of(1) vol1[0].size.should.equal(20) - vol1[0].zone.should.equal('us-east-1c') + vol1[0].zone.should.equal("us-east-1c") vol2 = conn.get_all_volumes(volume_ids=[volume1.id, volume2.id]) vol2.should.have.length_of(2) with assert_raises(EC2ResponseError) as cm: - conn.get_all_volumes(volume_ids=['vol-does_not_exist']) - cm.exception.code.should.equal('InvalidVolume.NotFound') + conn.get_all_volumes(volume_ids=["vol-does_not_exist"]) + cm.exception.code.should.equal("InvalidVolume.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -102,7 +106,7 @@ def test_filter_volume_by_id(): def test_volume_filters(): conn = boto.ec2.connect_to_region("us-east-1") - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.update() @@ -111,142 +115,155 @@ def test_volume_filters(): volume2 = conn.create_volume(36, "us-east-1b", encrypted=False) volume3 = conn.create_volume(20, "us-east-1c", encrypted=True) - snapshot = volume3.create_snapshot(description='testsnap') + snapshot = volume3.create_snapshot(description="testsnap") volume4 = conn.create_volume(25, "us-east-1a", snapshot=snapshot) - conn.create_tags([volume1.id], {'testkey1': 'testvalue1'}) - conn.create_tags([volume2.id], {'testkey2': 'testvalue2'}) + conn.create_tags([volume1.id], {"testkey1": "testvalue1"}) + conn.create_tags([volume2.id], {"testkey2": "testvalue2"}) volume1.update() volume2.update() volume3.update() volume4.update() - block_mapping = instance.block_device_mapping['/dev/sda1'] + block_mapping = instance.block_device_mapping["/dev/sda1"] - volume_ids = (volume1.id, volume2.id, volume3.id, volume4.id, block_mapping.volume_id) - - volumes_by_attach_time = conn.get_all_volumes( - filters={'attachment.attach-time': block_mapping.attach_time}) - set([vol.id for vol in volumes_by_attach_time] - ).should.equal({block_mapping.volume_id}) - - volumes_by_attach_device = conn.get_all_volumes( - filters={'attachment.device': '/dev/sda1'}) - set([vol.id for vol in volumes_by_attach_device] - ).should.equal({block_mapping.volume_id}) - - volumes_by_attach_instance_id = conn.get_all_volumes( - filters={'attachment.instance-id': instance.id}) - set([vol.id for vol in volumes_by_attach_instance_id] - ).should.equal({block_mapping.volume_id}) - - volumes_by_attach_status = conn.get_all_volumes( - filters={'attachment.status': 'attached'}) - set([vol.id for vol in volumes_by_attach_status] - ).should.equal({block_mapping.volume_id}) - - volumes_by_create_time = conn.get_all_volumes( - filters={'create-time': volume4.create_time}) - set([vol.create_time for vol in volumes_by_create_time] - ).should.equal({volume4.create_time}) - - volumes_by_size = conn.get_all_volumes(filters={'size': volume2.size}) - set([vol.id for vol in volumes_by_size]).should.equal({volume2.id}) - - volumes_by_snapshot_id = conn.get_all_volumes( - filters={'snapshot-id': snapshot.id}) - set([vol.id for vol in volumes_by_snapshot_id] - ).should.equal({volume4.id}) - - volumes_by_status = conn.get_all_volumes(filters={'status': 'in-use'}) - set([vol.id for vol in volumes_by_status]).should.equal( - {block_mapping.volume_id}) - - volumes_by_id = conn.get_all_volumes(filters={'volume-id': volume1.id}) - set([vol.id for vol in volumes_by_id]).should.equal({volume1.id}) - - volumes_by_tag_key = conn.get_all_volumes(filters={'tag-key': 'testkey1'}) - set([vol.id for vol in volumes_by_tag_key]).should.equal({volume1.id}) - - volumes_by_tag_value = conn.get_all_volumes( - filters={'tag-value': 'testvalue1'}) - set([vol.id for vol in volumes_by_tag_value] - ).should.equal({volume1.id}) - - volumes_by_tag = conn.get_all_volumes( - filters={'tag:testkey1': 'testvalue1'}) - set([vol.id for vol in volumes_by_tag]).should.equal({volume1.id}) - - volumes_by_unencrypted = conn.get_all_volumes( - filters={'encrypted': 'false'}) - set([vol.id for vol in volumes_by_unencrypted if vol.id in volume_ids]).should.equal( - {block_mapping.volume_id, volume2.id} + volume_ids = ( + volume1.id, + volume2.id, + volume3.id, + volume4.id, + block_mapping.volume_id, ) - volumes_by_encrypted = conn.get_all_volumes(filters={'encrypted': 'true'}) + volumes_by_attach_time = conn.get_all_volumes( + filters={"attachment.attach-time": block_mapping.attach_time} + ) + set([vol.id for vol in volumes_by_attach_time]).should.equal( + {block_mapping.volume_id} + ) + + volumes_by_attach_device = conn.get_all_volumes( + filters={"attachment.device": "/dev/sda1"} + ) + set([vol.id for vol in volumes_by_attach_device]).should.equal( + {block_mapping.volume_id} + ) + + volumes_by_attach_instance_id = conn.get_all_volumes( + filters={"attachment.instance-id": instance.id} + ) + set([vol.id for vol in volumes_by_attach_instance_id]).should.equal( + {block_mapping.volume_id} + ) + + volumes_by_attach_status = conn.get_all_volumes( + filters={"attachment.status": "attached"} + ) + set([vol.id for vol in volumes_by_attach_status]).should.equal( + {block_mapping.volume_id} + ) + + volumes_by_create_time = conn.get_all_volumes( + filters={"create-time": volume4.create_time} + ) + set([vol.create_time for vol in volumes_by_create_time]).should.equal( + {volume4.create_time} + ) + + volumes_by_size = conn.get_all_volumes(filters={"size": volume2.size}) + set([vol.id for vol in volumes_by_size]).should.equal({volume2.id}) + + volumes_by_snapshot_id = conn.get_all_volumes(filters={"snapshot-id": snapshot.id}) + set([vol.id for vol in volumes_by_snapshot_id]).should.equal({volume4.id}) + + volumes_by_status = conn.get_all_volumes(filters={"status": "in-use"}) + set([vol.id for vol in volumes_by_status]).should.equal({block_mapping.volume_id}) + + volumes_by_id = conn.get_all_volumes(filters={"volume-id": volume1.id}) + set([vol.id for vol in volumes_by_id]).should.equal({volume1.id}) + + volumes_by_tag_key = conn.get_all_volumes(filters={"tag-key": "testkey1"}) + set([vol.id for vol in volumes_by_tag_key]).should.equal({volume1.id}) + + volumes_by_tag_value = conn.get_all_volumes(filters={"tag-value": "testvalue1"}) + set([vol.id for vol in volumes_by_tag_value]).should.equal({volume1.id}) + + volumes_by_tag = conn.get_all_volumes(filters={"tag:testkey1": "testvalue1"}) + set([vol.id for vol in volumes_by_tag]).should.equal({volume1.id}) + + volumes_by_unencrypted = conn.get_all_volumes(filters={"encrypted": "false"}) + set( + [vol.id for vol in volumes_by_unencrypted if vol.id in volume_ids] + ).should.equal({block_mapping.volume_id, volume2.id}) + + volumes_by_encrypted = conn.get_all_volumes(filters={"encrypted": "true"}) set([vol.id for vol in volumes_by_encrypted if vol.id in volume_ids]).should.equal( {volume1.id, volume3.id, volume4.id} ) - volumes_by_availability_zone = conn.get_all_volumes(filters={'availability-zone': 'us-east-1b'}) - set([vol.id for vol in volumes_by_availability_zone if vol.id in volume_ids]).should.equal( - {volume2.id} + volumes_by_availability_zone = conn.get_all_volumes( + filters={"availability-zone": "us-east-1b"} ) + set( + [vol.id for vol in volumes_by_availability_zone if vol.id in volume_ids] + ).should.equal({volume2.id}) @mock_ec2_deprecated def test_volume_attach_and_detach(): conn = boto.ec2.connect_to_region("us-east-1") - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] volume = conn.create_volume(80, "us-east-1a") volume.update() - volume.volume_state().should.equal('available') + volume.volume_state().should.equal("available") with assert_raises(EC2ResponseError) as ex: volume.attach(instance.id, "/dev/sdh", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AttachVolume operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the AttachVolume operation: Request would have succeeded, but DryRun flag is set" + ) volume.attach(instance.id, "/dev/sdh") volume.update() - volume.volume_state().should.equal('in-use') - volume.attachment_state().should.equal('attached') + volume.volume_state().should.equal("in-use") + volume.attachment_state().should.equal("attached") volume.attach_data.instance_id.should.equal(instance.id) with assert_raises(EC2ResponseError) as ex: volume.detach(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DetachVolume operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DetachVolume operation: Request would have succeeded, but DryRun flag is set" + ) volume.detach() volume.update() - volume.volume_state().should.equal('available') + volume.volume_state().should.equal("available") with assert_raises(EC2ResponseError) as cm1: - volume.attach('i-1234abcd', "/dev/sdh") - cm1.exception.code.should.equal('InvalidInstanceID.NotFound') + volume.attach("i-1234abcd", "/dev/sdh") + cm1.exception.code.should.equal("InvalidInstanceID.NotFound") cm1.exception.status.should.equal(400) cm1.exception.request_id.should_not.be.none with assert_raises(EC2ResponseError) as cm2: conn.detach_volume(volume.id, instance.id, "/dev/sdh") - cm2.exception.code.should.equal('InvalidAttachment.NotFound') + cm2.exception.code.should.equal("InvalidAttachment.NotFound") cm2.exception.status.should.equal(400) cm2.exception.request_id.should_not.be.none with assert_raises(EC2ResponseError) as cm3: - conn.detach_volume(volume.id, 'i-1234abcd', "/dev/sdh") - cm3.exception.code.should.equal('InvalidInstanceID.NotFound') + conn.detach_volume(volume.id, "i-1234abcd", "/dev/sdh") + cm3.exception.code.should.equal("InvalidInstanceID.NotFound") cm3.exception.status.should.equal(400) cm3.exception.request_id.should_not.be.none @@ -257,19 +274,20 @@ def test_create_snapshot(): volume = conn.create_volume(80, "us-east-1a") with assert_raises(EC2ResponseError) as ex: - snapshot = volume.create_snapshot('a dryrun snapshot', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + snapshot = volume.create_snapshot("a dryrun snapshot", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateSnapshot operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateSnapshot operation: Request would have succeeded, but DryRun flag is set" + ) - snapshot = volume.create_snapshot('a test snapshot') + snapshot = volume.create_snapshot("a test snapshot") snapshot.update() - snapshot.status.should.equal('completed') + snapshot.status.should.equal("completed") snapshots = [snap for snap in conn.get_all_snapshots() if snap.id == snapshot.id] snapshots.should.have.length_of(1) - snapshots[0].description.should.equal('a test snapshot') + snapshots[0].description.should.equal("a test snapshot") snapshots[0].start_time.should_not.be.none snapshots[0].encrypted.should.be(False) @@ -285,7 +303,7 @@ def test_create_snapshot(): # Deleting something that was already deleted should throw an error with assert_raises(EC2ResponseError) as cm: snapshot.delete() - cm.exception.code.should.equal('InvalidSnapshot.NotFound') + cm.exception.code.should.equal("InvalidSnapshot.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -294,13 +312,13 @@ def test_create_snapshot(): def test_create_encrypted_snapshot(): conn = boto.ec2.connect_to_region("us-east-1") volume = conn.create_volume(80, "us-east-1a", encrypted=True) - snapshot = volume.create_snapshot('a test snapshot') + snapshot = volume.create_snapshot("a test snapshot") snapshot.update() - snapshot.status.should.equal('completed') + snapshot.status.should.equal("completed") snapshots = [snap for snap in conn.get_all_snapshots() if snap.id == snapshot.id] snapshots.should.have.length_of(1) - snapshots[0].description.should.equal('a test snapshot') + snapshots[0].description.should.equal("a test snapshot") snapshots[0].start_time.should_not.be.none snapshots[0].encrypted.should.be(True) @@ -309,11 +327,11 @@ def test_create_encrypted_snapshot(): def test_filter_snapshot_by_id(): conn = boto.ec2.connect_to_region("us-east-1") volume1 = conn.create_volume(36, "us-east-1a") - snap1 = volume1.create_snapshot('a test snapshot 1') - volume2 = conn.create_volume(42, 'us-east-1a') - snap2 = volume2.create_snapshot('a test snapshot 2') - volume3 = conn.create_volume(84, 'us-east-1a') - snap3 = volume3.create_snapshot('a test snapshot 3') + snap1 = volume1.create_snapshot("a test snapshot 1") + volume2 = conn.create_volume(42, "us-east-1a") + snap2 = volume2.create_snapshot("a test snapshot 2") + volume3 = conn.create_volume(84, "us-east-1a") + snap3 = volume3.create_snapshot("a test snapshot 3") snapshots1 = conn.get_all_snapshots(snapshot_ids=snap2.id) snapshots1.should.have.length_of(1) snapshots1[0].volume_id.should.equal(volume2.id) @@ -326,8 +344,8 @@ def test_filter_snapshot_by_id(): s.region.name.should.equal(conn.region.name) with assert_raises(EC2ResponseError) as cm: - conn.get_all_snapshots(snapshot_ids=['snap-does_not_exist']) - cm.exception.code.should.equal('InvalidSnapshot.NotFound') + conn.get_all_snapshots(snapshot_ids=["snap-does_not_exist"]) + cm.exception.code.should.equal("InvalidSnapshot.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -338,67 +356,62 @@ def test_snapshot_filters(): volume1 = conn.create_volume(20, "us-east-1a", encrypted=False) volume2 = conn.create_volume(25, "us-east-1a", encrypted=True) - snapshot1 = volume1.create_snapshot(description='testsnapshot1') - snapshot2 = volume1.create_snapshot(description='testsnapshot2') - snapshot3 = volume2.create_snapshot(description='testsnapshot3') + snapshot1 = volume1.create_snapshot(description="testsnapshot1") + snapshot2 = volume1.create_snapshot(description="testsnapshot2") + snapshot3 = volume2.create_snapshot(description="testsnapshot3") - conn.create_tags([snapshot1.id], {'testkey1': 'testvalue1'}) - conn.create_tags([snapshot2.id], {'testkey2': 'testvalue2'}) + conn.create_tags([snapshot1.id], {"testkey1": "testvalue1"}) + conn.create_tags([snapshot2.id], {"testkey2": "testvalue2"}) snapshots_by_description = conn.get_all_snapshots( - filters={'description': 'testsnapshot1'}) - set([snap.id for snap in snapshots_by_description] - ).should.equal({snapshot1.id}) + filters={"description": "testsnapshot1"} + ) + set([snap.id for snap in snapshots_by_description]).should.equal({snapshot1.id}) - snapshots_by_id = conn.get_all_snapshots( - filters={'snapshot-id': snapshot1.id}) - set([snap.id for snap in snapshots_by_id] - ).should.equal({snapshot1.id}) + snapshots_by_id = conn.get_all_snapshots(filters={"snapshot-id": snapshot1.id}) + set([snap.id for snap in snapshots_by_id]).should.equal({snapshot1.id}) snapshots_by_start_time = conn.get_all_snapshots( - filters={'start-time': snapshot1.start_time}) - set([snap.start_time for snap in snapshots_by_start_time] - ).should.equal({snapshot1.start_time}) + filters={"start-time": snapshot1.start_time} + ) + set([snap.start_time for snap in snapshots_by_start_time]).should.equal( + {snapshot1.start_time} + ) - snapshots_by_volume_id = conn.get_all_snapshots( - filters={'volume-id': volume1.id}) - set([snap.id for snap in snapshots_by_volume_id] - ).should.equal({snapshot1.id, snapshot2.id}) + snapshots_by_volume_id = conn.get_all_snapshots(filters={"volume-id": volume1.id}) + set([snap.id for snap in snapshots_by_volume_id]).should.equal( + {snapshot1.id, snapshot2.id} + ) - snapshots_by_status = conn.get_all_snapshots( - filters={'status': 'completed'}) - ({snapshot1.id, snapshot2.id, snapshot3.id} - - {snap.id for snap in snapshots_by_status}).should.have.length_of(0) + snapshots_by_status = conn.get_all_snapshots(filters={"status": "completed"}) + ( + {snapshot1.id, snapshot2.id, snapshot3.id} + - {snap.id for snap in snapshots_by_status} + ).should.have.length_of(0) snapshots_by_volume_size = conn.get_all_snapshots( - filters={'volume-size': volume1.size}) - set([snap.id for snap in snapshots_by_volume_size] - ).should.equal({snapshot1.id, snapshot2.id}) + filters={"volume-size": volume1.size} + ) + set([snap.id for snap in snapshots_by_volume_size]).should.equal( + {snapshot1.id, snapshot2.id} + ) - snapshots_by_tag_key = conn.get_all_snapshots( - filters={'tag-key': 'testkey1'}) - set([snap.id for snap in snapshots_by_tag_key] - ).should.equal({snapshot1.id}) + snapshots_by_tag_key = conn.get_all_snapshots(filters={"tag-key": "testkey1"}) + set([snap.id for snap in snapshots_by_tag_key]).should.equal({snapshot1.id}) - snapshots_by_tag_value = conn.get_all_snapshots( - filters={'tag-value': 'testvalue1'}) - set([snap.id for snap in snapshots_by_tag_value] - ).should.equal({snapshot1.id}) + snapshots_by_tag_value = conn.get_all_snapshots(filters={"tag-value": "testvalue1"}) + set([snap.id for snap in snapshots_by_tag_value]).should.equal({snapshot1.id}) - snapshots_by_tag = conn.get_all_snapshots( - filters={'tag:testkey1': 'testvalue1'}) - set([snap.id for snap in snapshots_by_tag] - ).should.equal({snapshot1.id}) + snapshots_by_tag = conn.get_all_snapshots(filters={"tag:testkey1": "testvalue1"}) + set([snap.id for snap in snapshots_by_tag]).should.equal({snapshot1.id}) - snapshots_by_encrypted = conn.get_all_snapshots( - filters={'encrypted': 'true'}) - set([snap.id for snap in snapshots_by_encrypted] - ).should.equal({snapshot3.id}) + snapshots_by_encrypted = conn.get_all_snapshots(filters={"encrypted": "true"}) + set([snap.id for snap in snapshots_by_encrypted]).should.equal({snapshot3.id}) - snapshots_by_owner_id = conn.get_all_snapshots( - filters={'owner-id': OWNER_ID}) - set([snap.id for snap in snapshots_by_owner_id] - ).should.equal({snapshot1.id, snapshot2.id, snapshot3.id}) + snapshots_by_owner_id = conn.get_all_snapshots(filters={"owner-id": OWNER_ID}) + set([snap.id for snap in snapshots_by_owner_id]).should.equal( + {snapshot1.id, snapshot2.id, snapshot3.id} + ) @mock_ec2_deprecated @@ -411,119 +424,139 @@ def test_snapshot_attribute(): # Baseline attributes = conn.get_snapshot_attribute( - snapshot.id, attribute='createVolumePermission') - attributes.name.should.equal('create_volume_permission') + snapshot.id, attribute="createVolumePermission" + ) + attributes.name.should.equal("create_volume_permission") attributes.attrs.should.have.length_of(0) - ADD_GROUP_ARGS = {'snapshot_id': snapshot.id, - 'attribute': 'createVolumePermission', - 'operation': 'add', - 'groups': 'all'} + ADD_GROUP_ARGS = { + "snapshot_id": snapshot.id, + "attribute": "createVolumePermission", + "operation": "add", + "groups": "all", + } - REMOVE_GROUP_ARGS = {'snapshot_id': snapshot.id, - 'attribute': 'createVolumePermission', - 'operation': 'remove', - 'groups': 'all'} + REMOVE_GROUP_ARGS = { + "snapshot_id": snapshot.id, + "attribute": "createVolumePermission", + "operation": "remove", + "groups": "all", + } # Add 'all' group and confirm with assert_raises(EC2ResponseError) as ex: - conn.modify_snapshot_attribute( - **dict(ADD_GROUP_ARGS, **{'dry_run': True})) - ex.exception.error_code.should.equal('DryRunOperation') + conn.modify_snapshot_attribute(**dict(ADD_GROUP_ARGS, **{"dry_run": True})) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifySnapshotAttribute operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifySnapshotAttribute operation: Request would have succeeded, but DryRun flag is set" + ) conn.modify_snapshot_attribute(**ADD_GROUP_ARGS) attributes = conn.get_snapshot_attribute( - snapshot.id, attribute='createVolumePermission') - attributes.attrs['groups'].should.have.length_of(1) - attributes.attrs['groups'].should.equal(['all']) + snapshot.id, attribute="createVolumePermission" + ) + attributes.attrs["groups"].should.have.length_of(1) + attributes.attrs["groups"].should.equal(["all"]) # Add is idempotent - conn.modify_snapshot_attribute.when.called_with( - **ADD_GROUP_ARGS).should_not.throw(EC2ResponseError) + conn.modify_snapshot_attribute.when.called_with(**ADD_GROUP_ARGS).should_not.throw( + EC2ResponseError + ) # Remove 'all' group and confirm with assert_raises(EC2ResponseError) as ex: - conn.modify_snapshot_attribute( - **dict(REMOVE_GROUP_ARGS, **{'dry_run': True})) - ex.exception.error_code.should.equal('DryRunOperation') + conn.modify_snapshot_attribute(**dict(REMOVE_GROUP_ARGS, **{"dry_run": True})) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifySnapshotAttribute operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifySnapshotAttribute operation: Request would have succeeded, but DryRun flag is set" + ) conn.modify_snapshot_attribute(**REMOVE_GROUP_ARGS) attributes = conn.get_snapshot_attribute( - snapshot.id, attribute='createVolumePermission') + snapshot.id, attribute="createVolumePermission" + ) attributes.attrs.should.have.length_of(0) # Remove is idempotent conn.modify_snapshot_attribute.when.called_with( - **REMOVE_GROUP_ARGS).should_not.throw(EC2ResponseError) + **REMOVE_GROUP_ARGS + ).should_not.throw(EC2ResponseError) # Error: Add with group != 'all' with assert_raises(EC2ResponseError) as cm: - conn.modify_snapshot_attribute(snapshot.id, - attribute='createVolumePermission', - operation='add', - groups='everyone') - cm.exception.code.should.equal('InvalidAMIAttributeItemValue') + conn.modify_snapshot_attribute( + snapshot.id, + attribute="createVolumePermission", + operation="add", + groups="everyone", + ) + cm.exception.code.should.equal("InvalidAMIAttributeItemValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add with invalid snapshot ID with assert_raises(EC2ResponseError) as cm: - conn.modify_snapshot_attribute("snapshot-abcd1234", - attribute='createVolumePermission', - operation='add', - groups='all') - cm.exception.code.should.equal('InvalidSnapshot.NotFound') + conn.modify_snapshot_attribute( + "snapshot-abcd1234", + attribute="createVolumePermission", + operation="add", + groups="all", + ) + cm.exception.code.should.equal("InvalidSnapshot.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Remove with invalid snapshot ID with assert_raises(EC2ResponseError) as cm: - conn.modify_snapshot_attribute("snapshot-abcd1234", - attribute='createVolumePermission', - operation='remove', - groups='all') - cm.exception.code.should.equal('InvalidSnapshot.NotFound') + conn.modify_snapshot_attribute( + "snapshot-abcd1234", + attribute="createVolumePermission", + operation="remove", + groups="all", + ) + cm.exception.code.should.equal("InvalidSnapshot.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Add or remove with user ID instead of group - conn.modify_snapshot_attribute.when.called_with(snapshot.id, - attribute='createVolumePermission', - operation='add', - user_ids=['user']).should.throw(NotImplementedError) - conn.modify_snapshot_attribute.when.called_with(snapshot.id, - attribute='createVolumePermission', - operation='remove', - user_ids=['user']).should.throw(NotImplementedError) + conn.modify_snapshot_attribute.when.called_with( + snapshot.id, + attribute="createVolumePermission", + operation="add", + user_ids=["user"], + ).should.throw(NotImplementedError) + conn.modify_snapshot_attribute.when.called_with( + snapshot.id, + attribute="createVolumePermission", + operation="remove", + user_ids=["user"], + ).should.throw(NotImplementedError) @mock_ec2_deprecated def test_create_volume_from_snapshot(): conn = boto.ec2.connect_to_region("us-east-1") volume = conn.create_volume(80, "us-east-1a") - snapshot = volume.create_snapshot('a test snapshot') + snapshot = volume.create_snapshot("a test snapshot") with assert_raises(EC2ResponseError) as ex: - snapshot = volume.create_snapshot('a test snapshot', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + snapshot = volume.create_snapshot("a test snapshot", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateSnapshot operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateSnapshot operation: Request would have succeeded, but DryRun flag is set" + ) - snapshot = volume.create_snapshot('a test snapshot') + snapshot = volume.create_snapshot("a test snapshot") snapshot.update() - snapshot.status.should.equal('completed') + snapshot.status.should.equal("completed") - new_volume = snapshot.create_volume('us-east-1a') + new_volume = snapshot.create_volume("us-east-1a") new_volume.size.should.equal(80) new_volume.snapshot_id.should.equal(snapshot.id) @@ -533,11 +566,11 @@ def test_create_volume_from_encrypted_snapshot(): conn = boto.ec2.connect_to_region("us-east-1") volume = conn.create_volume(80, "us-east-1a", encrypted=True) - snapshot = volume.create_snapshot('a test snapshot') + snapshot = volume.create_snapshot("a test snapshot") snapshot.update() - snapshot.status.should.equal('completed') + snapshot.status.should.equal("completed") - new_volume = snapshot.create_volume('us-east-1a') + new_volume = snapshot.create_volume("us-east-1a") new_volume.size.should.equal(80) new_volume.snapshot_id.should.equal(snapshot.id) new_volume.encrypted.should.be(True) @@ -553,131 +586,133 @@ def test_modify_attribute_blockDeviceMapping(): """ conn = boto.ec2.connect_to_region("us-east-1") - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: - instance.modify_attribute('blockDeviceMapping', { - '/dev/sda1': True}, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + instance.modify_attribute( + "blockDeviceMapping", {"/dev/sda1": True}, dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyInstanceAttribute operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyInstanceAttribute operation: Request would have succeeded, but DryRun flag is set" + ) - instance.modify_attribute('blockDeviceMapping', {'/dev/sda1': True}) + instance.modify_attribute("blockDeviceMapping", {"/dev/sda1": True}) instance = ec2_backends[conn.region.name].get_instance(instance.id) - instance.block_device_mapping.should.have.key('/dev/sda1') - instance.block_device_mapping[ - '/dev/sda1'].delete_on_termination.should.be(True) + instance.block_device_mapping.should.have.key("/dev/sda1") + instance.block_device_mapping["/dev/sda1"].delete_on_termination.should.be(True) @mock_ec2_deprecated def test_volume_tag_escaping(): conn = boto.ec2.connect_to_region("us-east-1") - vol = conn.create_volume(10, 'us-east-1a') - snapshot = conn.create_snapshot(vol.id, 'Desc') + vol = conn.create_volume(10, "us-east-1a") + snapshot = conn.create_snapshot(vol.id, "Desc") with assert_raises(EC2ResponseError) as ex: - snapshot.add_tags({'key': ''}, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + snapshot.add_tags({"key": ""}, dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) snaps = [snap for snap in conn.get_all_snapshots() if snap.id == snapshot.id] - dict(snaps[0].tags).should_not.be.equal( - {'key': ''}) + dict(snaps[0].tags).should_not.be.equal({"key": ""}) - snapshot.add_tags({'key': ''}) + snapshot.add_tags({"key": ""}) snaps = [snap for snap in conn.get_all_snapshots() if snap.id == snapshot.id] - dict(snaps[0].tags).should.equal({'key': ''}) + dict(snaps[0].tags).should.equal({"key": ""}) @mock_ec2 def test_volume_property_hidden_when_no_tags_exist(): - ec2_client = boto3.client('ec2', region_name='us-east-1') + ec2_client = boto3.client("ec2", region_name="us-east-1") - volume_response = ec2_client.create_volume( - Size=10, - AvailabilityZone='us-east-1a' - ) + volume_response = ec2_client.create_volume(Size=10, AvailabilityZone="us-east-1a") - volume_response.get('Tags').should.equal(None) + volume_response.get("Tags").should.equal(None) @freeze_time @mock_ec2 def test_copy_snapshot(): - ec2_client = boto3.client('ec2', region_name='eu-west-1') - dest_ec2_client = boto3.client('ec2', region_name='eu-west-2') + ec2_client = boto3.client("ec2", region_name="eu-west-1") + dest_ec2_client = boto3.client("ec2", region_name="eu-west-2") - volume_response = ec2_client.create_volume( - AvailabilityZone='eu-west-1a', Size=10 - ) + volume_response = ec2_client.create_volume(AvailabilityZone="eu-west-1a", Size=10) create_snapshot_response = ec2_client.create_snapshot( - VolumeId=volume_response['VolumeId'] + VolumeId=volume_response["VolumeId"] ) copy_snapshot_response = dest_ec2_client.copy_snapshot( - SourceSnapshotId=create_snapshot_response['SnapshotId'], - SourceRegion="eu-west-1" + SourceSnapshotId=create_snapshot_response["SnapshotId"], + SourceRegion="eu-west-1", ) - ec2 = boto3.resource('ec2', region_name='eu-west-1') - dest_ec2 = boto3.resource('ec2', region_name='eu-west-2') + ec2 = boto3.resource("ec2", region_name="eu-west-1") + dest_ec2 = boto3.resource("ec2", region_name="eu-west-2") - source = ec2.Snapshot(create_snapshot_response['SnapshotId']) - dest = dest_ec2.Snapshot(copy_snapshot_response['SnapshotId']) + source = ec2.Snapshot(create_snapshot_response["SnapshotId"]) + dest = dest_ec2.Snapshot(copy_snapshot_response["SnapshotId"]) - attribs = ['data_encryption_key_id', 'encrypted', - 'kms_key_id', 'owner_alias', 'owner_id', - 'progress', 'state', 'state_message', - 'tags', 'volume_id', 'volume_size'] + attribs = [ + "data_encryption_key_id", + "encrypted", + "kms_key_id", + "owner_alias", + "owner_id", + "progress", + "state", + "state_message", + "tags", + "volume_id", + "volume_size", + ] for attrib in attribs: getattr(source, attrib).should.equal(getattr(dest, attrib)) # Copy from non-existent source ID. with assert_raises(ClientError) as cm: - create_snapshot_error = ec2_client.create_snapshot( - VolumeId='vol-abcd1234' - ) - cm.exception.response['Error']['Code'].should.equal('InvalidVolume.NotFound') - cm.exception.response['Error']['Message'].should.equal("The volume 'vol-abcd1234' does not exist.") - cm.exception.response['ResponseMetadata']['RequestId'].should_not.be.none - cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + create_snapshot_error = ec2_client.create_snapshot(VolumeId="vol-abcd1234") + cm.exception.response["Error"]["Code"].should.equal("InvalidVolume.NotFound") + cm.exception.response["Error"]["Message"].should.equal( + "The volume 'vol-abcd1234' does not exist." + ) + cm.exception.response["ResponseMetadata"]["RequestId"].should_not.be.none + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) # Copy from non-existent source region. with assert_raises(ClientError) as cm: copy_snapshot_response = dest_ec2_client.copy_snapshot( - SourceSnapshotId=create_snapshot_response['SnapshotId'], - SourceRegion="eu-west-2" + SourceSnapshotId=create_snapshot_response["SnapshotId"], + SourceRegion="eu-west-2", ) - cm.exception.response['Error']['Code'].should.equal('InvalidSnapshot.NotFound') - cm.exception.response['Error']['Message'].should.be.none - cm.exception.response['ResponseMetadata']['RequestId'].should_not.be.none - cm.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) + cm.exception.response["Error"]["Code"].should.equal("InvalidSnapshot.NotFound") + cm.exception.response["Error"]["Message"].should.be.none + cm.exception.response["ResponseMetadata"]["RequestId"].should_not.be.none + cm.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + @mock_ec2 def test_search_for_many_snapshots(): - ec2_client = boto3.client('ec2', region_name='eu-west-1') + ec2_client = boto3.client("ec2", region_name="eu-west-1") - volume_response = ec2_client.create_volume( - AvailabilityZone='eu-west-1a', Size=10 - ) + volume_response = ec2_client.create_volume(AvailabilityZone="eu-west-1a", Size=10) snapshot_ids = [] for i in range(1, 20): create_snapshot_response = ec2_client.create_snapshot( - VolumeId=volume_response['VolumeId'] + VolumeId=volume_response["VolumeId"] ) - snapshot_ids.append(create_snapshot_response['SnapshotId']) + snapshot_ids.append(create_snapshot_response["SnapshotId"]) - snapshots_response = ec2_client.describe_snapshots( - SnapshotIds=snapshot_ids - ) + snapshots_response = ec2_client.describe_snapshots(SnapshotIds=snapshot_ids) - assert len(snapshots_response['Snapshots']) == len(snapshot_ids) + assert len(snapshots_response["Snapshots"]) == len(snapshot_ids) diff --git a/tests/test_ec2/test_elastic_ip_addresses.py b/tests/test_ec2/test_elastic_ip_addresses.py index ca6637b18..886cdff56 100644 --- a/tests/test_ec2/test_elastic_ip_addresses.py +++ b/tests/test_ec2/test_elastic_ip_addresses.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -18,14 +19,15 @@ import logging @mock_ec2_deprecated def test_eip_allocate_classic(): """Allocate/release Classic EIP""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: standard = conn.allocate_address(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AllocateAddress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the AllocateAddress operation: Request would have succeeded, but DryRun flag is set" + ) standard = conn.allocate_address() standard.should.be.a(boto.ec2.address.Address) @@ -35,10 +37,11 @@ def test_eip_allocate_classic(): with assert_raises(EC2ResponseError) as ex: standard.release(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ReleaseAddress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ReleaseAddress operation: Request would have succeeded, but DryRun flag is set" + ) standard.release() standard.should_not.be.within(conn.get_all_addresses()) @@ -47,14 +50,15 @@ def test_eip_allocate_classic(): @mock_ec2_deprecated def test_eip_allocate_vpc(): """Allocate/release VPC EIP""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: vpc = conn.allocate_address(domain="vpc", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AllocateAddress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the AllocateAddress operation: Request would have succeeded, but DryRun flag is set" + ) vpc = conn.allocate_address(domain="vpc") vpc.should.be.a(boto.ec2.address.Address) @@ -62,26 +66,27 @@ def test_eip_allocate_vpc(): logging.debug("vpc alloc_id:".format(vpc.allocation_id)) vpc.release() + @mock_ec2 def test_specific_eip_allocate_vpc(): """Allocate VPC EIP with specific address""" - service = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') + service = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") vpc = client.allocate_address(Domain="vpc", Address="127.38.43.222") - vpc['Domain'].should.be.equal("vpc") - vpc['PublicIp'].should.be.equal("127.38.43.222") - logging.debug("vpc alloc_id:".format(vpc['AllocationId'])) + vpc["Domain"].should.be.equal("vpc") + vpc["PublicIp"].should.be.equal("127.38.43.222") + logging.debug("vpc alloc_id:".format(vpc["AllocationId"])) @mock_ec2_deprecated def test_eip_allocate_invalid_domain(): """Allocate EIP invalid domain""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.allocate_address(domain="bogus") - cm.exception.code.should.equal('InvalidParameterValue') + cm.exception.code.should.equal("InvalidParameterValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -89,9 +94,9 @@ def test_eip_allocate_invalid_domain(): @mock_ec2_deprecated def test_eip_associate_classic(): """Associate/Disassociate EIP to classic instance""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] eip = conn.allocate_address() @@ -99,17 +104,19 @@ def test_eip_associate_classic(): with assert_raises(EC2ResponseError) as cm: conn.associate_address(public_ip=eip.public_ip) - cm.exception.code.should.equal('MissingParameter') + cm.exception.code.should.equal("MissingParameter") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none with assert_raises(EC2ResponseError) as ex: - conn.associate_address(instance_id=instance.id, - public_ip=eip.public_ip, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + conn.associate_address( + instance_id=instance.id, public_ip=eip.public_ip, dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AssociateAddress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the AssociateAddress operation: Request would have succeeded, but DryRun flag is set" + ) conn.associate_address(instance_id=instance.id, public_ip=eip.public_ip) # no .update() on address ): @@ -118,15 +125,16 @@ def test_eip_associate_classic(): with assert_raises(EC2ResponseError) as ex: conn.disassociate_address(public_ip=eip.public_ip, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DisAssociateAddress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DisAssociateAddress operation: Request would have succeeded, but DryRun flag is set" + ) conn.disassociate_address(public_ip=eip.public_ip) # no .update() on address ): eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] - eip.instance_id.should.be.equal(u'') + eip.instance_id.should.be.equal("") eip.release() eip.should_not.be.within(conn.get_all_addresses()) eip = None @@ -137,37 +145,37 @@ def test_eip_associate_classic(): @mock_ec2_deprecated def test_eip_associate_vpc(): """Associate/Disassociate EIP to VPC instance""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] - eip = conn.allocate_address(domain='vpc') + eip = conn.allocate_address(domain="vpc") eip.instance_id.should.be.none with assert_raises(EC2ResponseError) as cm: conn.associate_address(allocation_id=eip.allocation_id) - cm.exception.code.should.equal('MissingParameter') + cm.exception.code.should.equal("MissingParameter") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none - conn.associate_address(instance_id=instance.id, - allocation_id=eip.allocation_id) + conn.associate_address(instance_id=instance.id, allocation_id=eip.allocation_id) # no .update() on address ): eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] eip.instance_id.should.be.equal(instance.id) conn.disassociate_address(association_id=eip.association_id) # no .update() on address ): eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] - eip.instance_id.should.be.equal(u'') + eip.instance_id.should.be.equal("") eip.association_id.should.be.none with assert_raises(EC2ResponseError) as ex: eip.release(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ReleaseAddress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ReleaseAddress operation: Request would have succeeded, but DryRun flag is set" + ) eip.release() eip = None @@ -178,34 +186,38 @@ def test_eip_associate_vpc(): @mock_ec2 def test_eip_boto3_vpc_association(): """Associate EIP to VPC instance in a new subnet with boto3""" - service = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') - vpc_res = client.create_vpc(CidrBlock='10.0.0.0/24') + service = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") + vpc_res = client.create_vpc(CidrBlock="10.0.0.0/24") subnet_res = client.create_subnet( - VpcId=vpc_res['Vpc']['VpcId'], CidrBlock='10.0.0.0/24') - instance = service.create_instances(**{ - 'InstanceType': 't2.micro', - 'ImageId': 'ami-test', - 'MinCount': 1, - 'MaxCount': 1, - 'SubnetId': subnet_res['Subnet']['SubnetId'] - })[0] - allocation_id = client.allocate_address(Domain='vpc')['AllocationId'] + VpcId=vpc_res["Vpc"]["VpcId"], CidrBlock="10.0.0.0/24" + ) + instance = service.create_instances( + **{ + "InstanceType": "t2.micro", + "ImageId": "ami-test", + "MinCount": 1, + "MaxCount": 1, + "SubnetId": subnet_res["Subnet"]["SubnetId"], + } + )[0] + allocation_id = client.allocate_address(Domain="vpc")["AllocationId"] address = service.VpcAddress(allocation_id) address.load() address.association_id.should.be.none address.instance_id.should.be.empty address.network_interface_id.should.be.empty association_id = client.associate_address( - InstanceId=instance.id, - AllocationId=allocation_id, - AllowReassociation=False) + InstanceId=instance.id, AllocationId=allocation_id, AllowReassociation=False + ) instance.load() address.reload() address.association_id.should_not.be.none instance.public_ip_address.should_not.be.none instance.public_dns_name.should_not.be.none - address.network_interface_id.should.equal(instance.network_interfaces_attribute[0].get('NetworkInterfaceId')) + address.network_interface_id.should.equal( + instance.network_interfaces_attribute[0].get("NetworkInterfaceId") + ) address.public_ip.should.equal(instance.public_ip_address) address.instance_id.should.equal(instance.id) @@ -221,22 +233,21 @@ def test_eip_boto3_vpc_association(): @mock_ec2_deprecated def test_eip_associate_network_interface(): """Associate/Disassociate EIP to NIC""" - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") eni = conn.create_network_interface(subnet.id) - eip = conn.allocate_address(domain='vpc') + eip = conn.allocate_address(domain="vpc") eip.network_interface_id.should.be.none with assert_raises(EC2ResponseError) as cm: conn.associate_address(network_interface_id=eni.id) - cm.exception.code.should.equal('MissingParameter') + cm.exception.code.should.equal("MissingParameter") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none - conn.associate_address(network_interface_id=eni.id, - allocation_id=eip.allocation_id) + conn.associate_address(network_interface_id=eni.id, allocation_id=eip.allocation_id) # no .update() on address ): eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] eip.network_interface_id.should.be.equal(eni.id) @@ -244,7 +255,7 @@ def test_eip_associate_network_interface(): conn.disassociate_address(association_id=eip.association_id) # no .update() on address ): eip = conn.get_all_addresses(addresses=[eip.public_ip])[0] - eip.network_interface_id.should.be.equal(u'') + eip.network_interface_id.should.be.equal("") eip.association_id.should.be.none eip.release() eip = None @@ -253,9 +264,9 @@ def test_eip_associate_network_interface(): @mock_ec2_deprecated def test_eip_reassociate(): """reassociate EIP""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - reservation = conn.run_instances('ami-1234abcd', min_count=2) + reservation = conn.run_instances("ami-1234abcd", min_count=2) instance1, instance2 = reservation.instances eip = conn.allocate_address() @@ -267,13 +278,15 @@ def test_eip_reassociate(): # Different ID detects resource association with assert_raises(EC2ResponseError) as cm: conn.associate_address( - instance_id=instance2.id, public_ip=eip.public_ip, allow_reassociation=False) - cm.exception.code.should.equal('Resource.AlreadyAssociated') + instance_id=instance2.id, public_ip=eip.public_ip, allow_reassociation=False + ) + cm.exception.code.should.equal("Resource.AlreadyAssociated") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none conn.associate_address.when.called_with( - instance_id=instance2.id, public_ip=eip.public_ip, allow_reassociation=True).should_not.throw(EC2ResponseError) + instance_id=instance2.id, public_ip=eip.public_ip, allow_reassociation=True + ).should_not.throw(EC2ResponseError) eip.release() eip = None @@ -285,7 +298,7 @@ def test_eip_reassociate(): @mock_ec2_deprecated def test_eip_reassociate_nic(): """reassociate EIP""" - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") @@ -293,23 +306,21 @@ def test_eip_reassociate_nic(): eni2 = conn.create_network_interface(subnet.id) eip = conn.allocate_address() - conn.associate_address(network_interface_id=eni1.id, - public_ip=eip.public_ip) + conn.associate_address(network_interface_id=eni1.id, public_ip=eip.public_ip) # Same ID is idempotent - conn.associate_address(network_interface_id=eni1.id, - public_ip=eip.public_ip) + conn.associate_address(network_interface_id=eni1.id, public_ip=eip.public_ip) # Different ID detects resource association with assert_raises(EC2ResponseError) as cm: - conn.associate_address( - network_interface_id=eni2.id, public_ip=eip.public_ip) - cm.exception.code.should.equal('Resource.AlreadyAssociated') + conn.associate_address(network_interface_id=eni2.id, public_ip=eip.public_ip) + cm.exception.code.should.equal("Resource.AlreadyAssociated") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none conn.associate_address.when.called_with( - network_interface_id=eni2.id, public_ip=eip.public_ip, allow_reassociation=True).should_not.throw(EC2ResponseError) + network_interface_id=eni2.id, public_ip=eip.public_ip, allow_reassociation=True + ).should_not.throw(EC2ResponseError) eip.release() eip = None @@ -318,16 +329,16 @@ def test_eip_reassociate_nic(): @mock_ec2_deprecated def test_eip_associate_invalid_args(): """Associate EIP, invalid args """ - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] eip = conn.allocate_address() with assert_raises(EC2ResponseError) as cm: conn.associate_address(instance_id=instance.id) - cm.exception.code.should.equal('MissingParameter') + cm.exception.code.should.equal("MissingParameter") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -337,11 +348,11 @@ def test_eip_associate_invalid_args(): @mock_ec2_deprecated def test_eip_disassociate_bogus_association(): """Disassociate bogus EIP""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.disassociate_address(association_id="bogus") - cm.exception.code.should.equal('InvalidAssociationID.NotFound') + cm.exception.code.should.equal("InvalidAssociationID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -349,11 +360,11 @@ def test_eip_disassociate_bogus_association(): @mock_ec2_deprecated def test_eip_release_bogus_eip(): """Release bogus EIP""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.release_address(allocation_id="bogus") - cm.exception.code.should.equal('InvalidAllocationID.NotFound') + cm.exception.code.should.equal("InvalidAllocationID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -361,11 +372,11 @@ def test_eip_release_bogus_eip(): @mock_ec2_deprecated def test_eip_disassociate_arg_error(): """Invalid arguments disassociate address""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.disassociate_address() - cm.exception.code.should.equal('MissingParameter') + cm.exception.code.should.equal("MissingParameter") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -373,11 +384,11 @@ def test_eip_disassociate_arg_error(): @mock_ec2_deprecated def test_eip_release_arg_error(): """Invalid arguments release address""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.release_address() - cm.exception.code.should.equal('MissingParameter') + cm.exception.code.should.equal("MissingParameter") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -385,7 +396,7 @@ def test_eip_release_arg_error(): @mock_ec2_deprecated def test_eip_describe(): """Listing of allocated Elastic IP Addresses.""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") eips = [] number_of_classic_ips = 2 number_of_vpc_ips = 2 @@ -394,23 +405,24 @@ def test_eip_describe(): for _ in range(number_of_classic_ips): eips.append(conn.allocate_address()) for _ in range(number_of_vpc_ips): - eips.append(conn.allocate_address(domain='vpc')) + eips.append(conn.allocate_address(domain="vpc")) len(eips).should.be.equal(number_of_classic_ips + number_of_vpc_ips) # Can we find each one individually? for eip in eips: if eip.allocation_id: lookup_addresses = conn.get_all_addresses( - allocation_ids=[eip.allocation_id]) + allocation_ids=[eip.allocation_id] + ) else: - lookup_addresses = conn.get_all_addresses( - addresses=[eip.public_ip]) + lookup_addresses = conn.get_all_addresses(addresses=[eip.public_ip]) len(lookup_addresses).should.be.equal(1) lookup_addresses[0].public_ip.should.be.equal(eip.public_ip) # Can we find first two when we search for them? lookup_addresses = conn.get_all_addresses( - addresses=[eips[0].public_ip, eips[1].public_ip]) + addresses=[eips[0].public_ip, eips[1].public_ip] + ) len(lookup_addresses).should.be.equal(2) lookup_addresses[0].public_ip.should.be.equal(eips[0].public_ip) lookup_addresses[1].public_ip.should.be.equal(eips[1].public_ip) @@ -424,36 +436,38 @@ def test_eip_describe(): @mock_ec2_deprecated def test_eip_describe_none(): """Error when search for bogus IP""" - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.get_all_addresses(addresses=["256.256.256.256"]) - cm.exception.code.should.equal('InvalidAddress.NotFound') + cm.exception.code.should.equal("InvalidAddress.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2 def test_eip_filters(): - service = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') - vpc_res = client.create_vpc(CidrBlock='10.0.0.0/24') + service = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") + vpc_res = client.create_vpc(CidrBlock="10.0.0.0/24") subnet_res = client.create_subnet( - VpcId=vpc_res['Vpc']['VpcId'], CidrBlock='10.0.0.0/24') + VpcId=vpc_res["Vpc"]["VpcId"], CidrBlock="10.0.0.0/24" + ) def create_inst_with_eip(): - instance = service.create_instances(**{ - 'InstanceType': 't2.micro', - 'ImageId': 'ami-test', - 'MinCount': 1, - 'MaxCount': 1, - 'SubnetId': subnet_res['Subnet']['SubnetId'] - })[0] - allocation_id = client.allocate_address(Domain='vpc')['AllocationId'] + instance = service.create_instances( + **{ + "InstanceType": "t2.micro", + "ImageId": "ami-test", + "MinCount": 1, + "MaxCount": 1, + "SubnetId": subnet_res["Subnet"]["SubnetId"], + } + )[0] + allocation_id = client.allocate_address(Domain="vpc")["AllocationId"] _ = client.associate_address( - InstanceId=instance.id, - AllocationId=allocation_id, - AllowReassociation=False) + InstanceId=instance.id, AllocationId=allocation_id, AllowReassociation=False + ) instance.load() address = service.VpcAddress(allocation_id) address.load() @@ -477,38 +491,49 @@ def test_eip_filters(): # Param search by Filter def check_vpc_filter_valid(filter_name, filter_values): - addresses = list(service.vpc_addresses.filter( - Filters=[{'Name': filter_name, - 'Values': filter_values}])) + addresses = list( + service.vpc_addresses.filter( + Filters=[{"Name": filter_name, "Values": filter_values}] + ) + ) len(addresses).should.equal(2) ips = [addr.public_ip for addr in addresses] set(ips).should.equal(set([eip1.public_ip, eip2.public_ip])) ips.should.contain(inst1.public_ip_address) def check_vpc_filter_invalid(filter_name): - addresses = list(service.vpc_addresses.filter( - Filters=[{'Name': filter_name, - 'Values': ['dummy1', 'dummy2']}])) + addresses = list( + service.vpc_addresses.filter( + Filters=[{"Name": filter_name, "Values": ["dummy1", "dummy2"]}] + ) + ) len(addresses).should.equal(0) def check_vpc_filter(filter_name, filter_values): check_vpc_filter_valid(filter_name, filter_values) check_vpc_filter_invalid(filter_name) - check_vpc_filter('allocation-id', [eip1.allocation_id, eip2.allocation_id]) - check_vpc_filter('association-id', [eip1.association_id, eip2.association_id]) - check_vpc_filter('instance-id', [inst1.id, inst2.id]) + check_vpc_filter("allocation-id", [eip1.allocation_id, eip2.allocation_id]) + check_vpc_filter("association-id", [eip1.association_id, eip2.association_id]) + check_vpc_filter("instance-id", [inst1.id, inst2.id]) check_vpc_filter( - 'network-interface-id', - [inst1.network_interfaces_attribute[0].get('NetworkInterfaceId'), - inst2.network_interfaces_attribute[0].get('NetworkInterfaceId')]) + "network-interface-id", + [ + inst1.network_interfaces_attribute[0].get("NetworkInterfaceId"), + inst2.network_interfaces_attribute[0].get("NetworkInterfaceId"), + ], + ) check_vpc_filter( - 'private-ip-address', - [inst1.network_interfaces_attribute[0].get('PrivateIpAddress'), - inst2.network_interfaces_attribute[0].get('PrivateIpAddress')]) - check_vpc_filter('public-ip', [inst1.public_ip_address, inst2.public_ip_address]) + "private-ip-address", + [ + inst1.network_interfaces_attribute[0].get("PrivateIpAddress"), + inst2.network_interfaces_attribute[0].get("PrivateIpAddress"), + ], + ) + check_vpc_filter("public-ip", [inst1.public_ip_address, inst2.public_ip_address]) # all the ips are in a VPC - addresses = list(service.vpc_addresses.filter( - Filters=[{'Name': 'domain', 'Values': ['vpc']}])) + addresses = list( + service.vpc_addresses.filter(Filters=[{"Name": "domain", "Values": ["vpc"]}]) + ) len(addresses).should.equal(3) diff --git a/tests/test_ec2/test_elastic_network_interfaces.py b/tests/test_ec2/test_elastic_network_interfaces.py index 05b45fda9..4e502586e 100644 --- a/tests/test_ec2/test_elastic_network_interfaces.py +++ b/tests/test_ec2/test_elastic_network_interfaces.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -19,16 +20,17 @@ import json @mock_ec2_deprecated def test_elastic_network_interfaces(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") with assert_raises(EC2ResponseError) as ex: eni = conn.create_network_interface(subnet.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateNetworkInterface operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateNetworkInterface operation: Request would have succeeded, but DryRun flag is set" + ) eni = conn.create_network_interface(subnet.id) @@ -37,14 +39,15 @@ def test_elastic_network_interfaces(): eni = all_enis[0] eni.groups.should.have.length_of(0) eni.private_ip_addresses.should.have.length_of(1) - eni.private_ip_addresses[0].private_ip_address.startswith('10.').should.be.true + eni.private_ip_addresses[0].private_ip_address.startswith("10.").should.be.true with assert_raises(EC2ResponseError) as ex: conn.delete_network_interface(eni.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteNetworkInterface operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeleteNetworkInterface operation: Request would have succeeded, but DryRun flag is set" + ) conn.delete_network_interface(eni.id) @@ -53,25 +56,25 @@ def test_elastic_network_interfaces(): with assert_raises(EC2ResponseError) as cm: conn.delete_network_interface(eni.id) - cm.exception.error_code.should.equal('InvalidNetworkInterfaceID.NotFound') + cm.exception.error_code.should.equal("InvalidNetworkInterfaceID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_elastic_network_interfaces_subnet_validation(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.create_network_interface("subnet-abcd1234") - cm.exception.error_code.should.equal('InvalidSubnetID.NotFound') + cm.exception.error_code.should.equal("InvalidSubnetID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_elastic_network_interfaces_with_private_ip(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") private_ip = "54.0.0.1" @@ -89,15 +92,18 @@ def test_elastic_network_interfaces_with_private_ip(): @mock_ec2_deprecated def test_elastic_network_interfaces_with_groups(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) conn.create_network_interface( - subnet.id, groups=[security_group1.id, security_group2.id]) + subnet.id, groups=[security_group1.id, security_group2.id] + ) all_enis = conn.get_all_network_interfaces() all_enis.should.have.length_of(1) @@ -105,19 +111,22 @@ def test_elastic_network_interfaces_with_groups(): eni = all_enis[0] eni.groups.should.have.length_of(2) set([group.id for group in eni.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) @requires_boto_gte("2.12.0") @mock_ec2_deprecated def test_elastic_network_interfaces_modify_attribute(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) conn.create_network_interface(subnet.id, groups=[security_group1.id]) all_enis = conn.get_all_network_interfaces() @@ -129,14 +138,15 @@ def test_elastic_network_interfaces_modify_attribute(): with assert_raises(EC2ResponseError) as ex: conn.modify_network_interface_attribute( - eni.id, 'groupset', [security_group2.id], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + eni.id, "groupset", [security_group2.id], dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyNetworkInterface operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyNetworkInterface operation: Request would have succeeded, but DryRun flag is set" + ) - conn.modify_network_interface_attribute( - eni.id, 'groupset', [security_group2.id]) + conn.modify_network_interface_attribute(eni.id, "groupset", [security_group2.id]) all_enis = conn.get_all_network_interfaces() all_enis.should.have.length_of(1) @@ -148,20 +158,22 @@ def test_elastic_network_interfaces_modify_attribute(): @mock_ec2_deprecated def test_elastic_network_interfaces_filtering(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) eni1 = conn.create_network_interface( - subnet.id, groups=[security_group1.id, security_group2.id]) - eni2 = conn.create_network_interface( - subnet.id, groups=[security_group1.id]) - eni3 = conn.create_network_interface(subnet.id, description='test description') + subnet.id, groups=[security_group1.id, security_group2.id] + ) + eni2 = conn.create_network_interface(subnet.id, groups=[security_group1.id]) + eni3 = conn.create_network_interface(subnet.id, description="test description") all_enis = conn.get_all_network_interfaces() all_enis.should.have.length_of(3) @@ -173,280 +185,322 @@ def test_elastic_network_interfaces_filtering(): # Filter by ENI ID enis_by_id = conn.get_all_network_interfaces( - filters={'network-interface-id': eni1.id}) + filters={"network-interface-id": eni1.id} + ) enis_by_id.should.have.length_of(1) set([eni.id for eni in enis_by_id]).should.equal(set([eni1.id])) # Filter by Security Group enis_by_group = conn.get_all_network_interfaces( - filters={'group-id': security_group1.id}) + filters={"group-id": security_group1.id} + ) enis_by_group.should.have.length_of(2) set([eni.id for eni in enis_by_group]).should.equal(set([eni1.id, eni2.id])) # Filter by ENI ID and Security Group enis_by_group = conn.get_all_network_interfaces( - filters={'network-interface-id': eni1.id, 'group-id': security_group1.id}) + filters={"network-interface-id": eni1.id, "group-id": security_group1.id} + ) enis_by_group.should.have.length_of(1) set([eni.id for eni in enis_by_group]).should.equal(set([eni1.id])) # Filter by Description enis_by_description = conn.get_all_network_interfaces( - filters={'description': eni3.description }) + filters={"description": eni3.description} + ) enis_by_description.should.have.length_of(1) enis_by_description[0].description.should.equal(eni3.description) # Unsupported filter conn.get_all_network_interfaces.when.called_with( - filters={'not-implemented-filter': 'foobar'}).should.throw(NotImplementedError) + filters={"not-implemented-filter": "foobar"} + ).should.throw(NotImplementedError) @mock_ec2 def test_elastic_network_interfaces_get_by_tag_name(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5" + ) with assert_raises(ClientError) as ex: - eni1.create_tags(Tags=[{'Key': 'Name', 'Value': 'eni1'}], DryRun=True) - ex.exception.response['Error']['Code'].should.equal('DryRunOperation') - ex.exception.response['ResponseMetadata'][ - 'HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + eni1.create_tags(Tags=[{"Key": "Name", "Value": "eni1"}], DryRun=True) + ex.exception.response["Error"]["Code"].should.equal("DryRunOperation") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) - eni1.create_tags(Tags=[{'Key': 'Name', 'Value': 'eni1'}]) + eni1.create_tags(Tags=[{"Key": "Name", "Value": "eni1"}]) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) - filters = [{'Name': 'tag:Name', 'Values': ['eni1']}] + filters = [{"Name": "tag:Name", "Values": ["eni1"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'tag:Name', 'Values': ['wrong-name']}] + filters = [{"Name": "tag:Name", "Values": ["wrong-name"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_get_by_availability_zone(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet1 = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.1.0/24', AvailabilityZone='us-west-2b') + VpcId=vpc.id, CidrBlock="10.0.1.0/24", AvailabilityZone="us-west-2b" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet1.id, PrivateIpAddress='10.0.0.15') + SubnetId=subnet1.id, PrivateIpAddress="10.0.0.15" + ) eni2 = ec2.create_network_interface( - SubnetId=subnet2.id, PrivateIpAddress='10.0.1.15') + SubnetId=subnet2.id, PrivateIpAddress="10.0.1.15" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id, eni2.id]) - filters = [{'Name': 'availability-zone', 'Values': ['us-west-2a']}] + filters = [{"Name": "availability-zone", "Values": ["us-west-2a"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'availability-zone', 'Values': ['us-west-2c']}] + filters = [{"Name": "availability-zone", "Values": ["us-west-2c"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_get_by_private_ip(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) - filters = [{'Name': 'private-ip-address', 'Values': ['10.0.10.5']}] + filters = [{"Name": "private-ip-address", "Values": ["10.0.10.5"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'private-ip-address', 'Values': ['10.0.10.10']}] + filters = [{"Name": "private-ip-address", "Values": ["10.0.10.10"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) - filters = [{'Name': 'addresses.private-ip-address', 'Values': ['10.0.10.5']}] + filters = [{"Name": "addresses.private-ip-address", "Values": ["10.0.10.5"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'addresses.private-ip-address', 'Values': ['10.0.10.10']}] + filters = [{"Name": "addresses.private-ip-address", "Values": ["10.0.10.10"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_get_by_vpc_id(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) - filters = [{'Name': 'vpc-id', 'Values': [subnet.vpc_id]}] + filters = [{"Name": "vpc-id", "Values": [subnet.vpc_id]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'vpc-id', 'Values': ['vpc-aaaa1111']}] + filters = [{"Name": "vpc-id", "Values": ["vpc-aaaa1111"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_get_by_subnet_id(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) - filters = [{'Name': 'subnet-id', 'Values': [subnet.id]}] + filters = [{"Name": "subnet-id", "Values": [subnet.id]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'subnet-id', 'Values': ['subnet-aaaa1111']}] + filters = [{"Name": "subnet-id", "Values": ["subnet-aaaa1111"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_get_by_description(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5', Description='test interface') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5", Description="test interface" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) - filters = [{'Name': 'description', 'Values': [eni1.description]}] + filters = [{"Name": "description", "Values": [eni1.description]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(1) - filters = [{'Name': 'description', 'Values': ['bad description']}] + filters = [{"Name": "description", "Values": ["bad description"]}] enis = list(ec2.network_interfaces.filter(Filters=filters)) enis.should.have.length_of(0) @mock_ec2 def test_elastic_network_interfaces_describe_network_interfaces_with_filter(): - ec2 = boto3.resource('ec2', region_name='us-west-2') - ec2_client = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.resource("ec2", region_name="us-west-2") + ec2_client = boto3.client("ec2", region_name="us-west-2") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-2a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-2a" + ) eni1 = ec2.create_network_interface( - SubnetId=subnet.id, PrivateIpAddress='10.0.10.5', Description='test interface') + SubnetId=subnet.id, PrivateIpAddress="10.0.10.5", Description="test interface" + ) # The status of the new interface should be 'available' - waiter = ec2_client.get_waiter('network_interface_available') + waiter = ec2_client.get_waiter("network_interface_available") waiter.wait(NetworkInterfaceIds=[eni1.id]) # Filter by network-interface-id response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'network-interface-id', 'Values': [eni1.id]}]) - response['NetworkInterfaces'].should.have.length_of(1) - response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id) - response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address) - response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description) + Filters=[{"Name": "network-interface-id", "Values": [eni1.id]}] + ) + response["NetworkInterfaces"].should.have.length_of(1) + response["NetworkInterfaces"][0]["NetworkInterfaceId"].should.equal(eni1.id) + response["NetworkInterfaces"][0]["PrivateIpAddress"].should.equal( + eni1.private_ip_address + ) + response["NetworkInterfaces"][0]["Description"].should.equal(eni1.description) response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'network-interface-id', 'Values': ['bad-id']}]) - response['NetworkInterfaces'].should.have.length_of(0) + Filters=[{"Name": "network-interface-id", "Values": ["bad-id"]}] + ) + response["NetworkInterfaces"].should.have.length_of(0) # Filter by private-ip-address response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'private-ip-address', 'Values': [eni1.private_ip_address]}]) - response['NetworkInterfaces'].should.have.length_of(1) - response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id) - response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address) - response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description) + Filters=[{"Name": "private-ip-address", "Values": [eni1.private_ip_address]}] + ) + response["NetworkInterfaces"].should.have.length_of(1) + response["NetworkInterfaces"][0]["NetworkInterfaceId"].should.equal(eni1.id) + response["NetworkInterfaces"][0]["PrivateIpAddress"].should.equal( + eni1.private_ip_address + ) + response["NetworkInterfaces"][0]["Description"].should.equal(eni1.description) response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'private-ip-address', 'Values': ['11.11.11.11']}]) - response['NetworkInterfaces'].should.have.length_of(0) + Filters=[{"Name": "private-ip-address", "Values": ["11.11.11.11"]}] + ) + response["NetworkInterfaces"].should.have.length_of(0) # Filter by sunet-id response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'subnet-id', 'Values': [eni1.subnet.id]}]) - response['NetworkInterfaces'].should.have.length_of(1) - response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id) - response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address) - response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description) + Filters=[{"Name": "subnet-id", "Values": [eni1.subnet.id]}] + ) + response["NetworkInterfaces"].should.have.length_of(1) + response["NetworkInterfaces"][0]["NetworkInterfaceId"].should.equal(eni1.id) + response["NetworkInterfaces"][0]["PrivateIpAddress"].should.equal( + eni1.private_ip_address + ) + response["NetworkInterfaces"][0]["Description"].should.equal(eni1.description) response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'subnet-id', 'Values': ['sn-bad-id']}]) - response['NetworkInterfaces'].should.have.length_of(0) + Filters=[{"Name": "subnet-id", "Values": ["sn-bad-id"]}] + ) + response["NetworkInterfaces"].should.have.length_of(0) # Filter by description response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'description', 'Values': [eni1.description]}]) - response['NetworkInterfaces'].should.have.length_of(1) - response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id) - response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address) - response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description) + Filters=[{"Name": "description", "Values": [eni1.description]}] + ) + response["NetworkInterfaces"].should.have.length_of(1) + response["NetworkInterfaces"][0]["NetworkInterfaceId"].should.equal(eni1.id) + response["NetworkInterfaces"][0]["PrivateIpAddress"].should.equal( + eni1.private_ip_address + ) + response["NetworkInterfaces"][0]["Description"].should.equal(eni1.description) response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'description', 'Values': ['bad description']}]) - response['NetworkInterfaces'].should.have.length_of(0) + Filters=[{"Name": "description", "Values": ["bad description"]}] + ) + response["NetworkInterfaces"].should.have.length_of(0) # Filter by multiple filters response = ec2_client.describe_network_interfaces( - Filters=[{'Name': 'private-ip-address', 'Values': [eni1.private_ip_address]}, - {'Name': 'network-interface-id', 'Values': [eni1.id]}, - {'Name': 'subnet-id', 'Values': [eni1.subnet.id]}]) - response['NetworkInterfaces'].should.have.length_of(1) - response['NetworkInterfaces'][0]['NetworkInterfaceId'].should.equal(eni1.id) - response['NetworkInterfaces'][0]['PrivateIpAddress'].should.equal(eni1.private_ip_address) - response['NetworkInterfaces'][0]['Description'].should.equal(eni1.description) + Filters=[ + {"Name": "private-ip-address", "Values": [eni1.private_ip_address]}, + {"Name": "network-interface-id", "Values": [eni1.id]}, + {"Name": "subnet-id", "Values": [eni1.subnet.id]}, + ] + ) + response["NetworkInterfaces"].should.have.length_of(1) + response["NetworkInterfaces"][0]["NetworkInterfaceId"].should.equal(eni1.id) + response["NetworkInterfaces"][0]["PrivateIpAddress"].should.equal( + eni1.private_ip_address + ) + response["NetworkInterfaces"][0]["Description"].should.equal(eni1.description) @mock_ec2_deprecated @@ -455,19 +509,19 @@ def test_elastic_network_interfaces_cloudformation(): template = vpc_eni.template template_json = json.dumps(template) conn = boto.cloudformation.connect_to_region("us-west-1") - conn.create_stack( - "test_stack", - template_body=template_json, - ) + conn.create_stack("test_stack", template_body=template_json) ec2_conn = boto.ec2.connect_to_region("us-west-1") eni = ec2_conn.get_all_network_interfaces()[0] eni.private_ip_addresses.should.have.length_of(1) stack = conn.describe_stacks()[0] resources = stack.describe_resources() - cfn_eni = [resource for resource in resources if resource.resource_type == - 'AWS::EC2::NetworkInterface'][0] + cfn_eni = [ + resource + for resource in resources + if resource.resource_type == "AWS::EC2::NetworkInterface" + ][0] cfn_eni.physical_resource_id.should.equal(eni.id) outputs = {output.key: output.value for output in stack.outputs} - outputs['ENIIpAddress'].should.equal(eni.private_ip_addresses[0].private_ip_address) + outputs["ENIIpAddress"].should.equal(eni.private_ip_addresses[0].private_ip_address) diff --git a/tests/test_ec2/test_general.py b/tests/test_ec2/test_general.py index 4c319d30d..7b8f3bd53 100644 --- a/tests/test_ec2/test_general.py +++ b/tests/test_ec2/test_general.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -13,8 +14,8 @@ from moto import mock_ec2_deprecated, mock_ec2 @mock_ec2_deprecated def test_console_output(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance_id = reservation.instances[0].id output = conn.get_console_output(instance_id) output.output.should_not.equal(None) @@ -22,21 +23,19 @@ def test_console_output(): @mock_ec2_deprecated def test_console_output_without_instance(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.get_console_output('i-1234abcd') - cm.exception.code.should.equal('InvalidInstanceID.NotFound') + conn.get_console_output("i-1234abcd") + cm.exception.code.should.equal("InvalidInstanceID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2 def test_console_output_boto3(): - conn = boto3.resource('ec2', 'us-east-1') - instances = conn.create_instances(ImageId='ami-1234abcd', - MinCount=1, - MaxCount=1) + conn = boto3.resource("ec2", "us-east-1") + instances = conn.create_instances(ImageId="ami-1234abcd", MinCount=1, MaxCount=1) output = instances[0].console_output() - output.get('Output').should_not.equal(None) + output.get("Output").should_not.equal(None) diff --git a/tests/test_ec2/test_instances.py b/tests/test_ec2/test_instances.py index a83384709..041bc8c85 100644 --- a/tests/test_ec2/test_instances.py +++ b/tests/test_ec2/test_instances.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 from botocore.exceptions import ClientError @@ -30,13 +31,14 @@ def add_servers(ami_id, count): @mock_ec2_deprecated def test_add_servers(): - add_servers('ami-1234abcd', 2) + add_servers("ami-1234abcd", 2) conn = boto.connect_ec2() reservations = conn.get_all_instances() assert len(reservations) == 2 instance1 = reservations[0].instances[0] - assert instance1.image_id == 'ami-1234abcd' + assert instance1.image_id == "ami-1234abcd" + ############################################ @@ -47,17 +49,18 @@ def test_instance_launch_and_terminate(): conn = boto.ec2.connect_to_region("us-east-1") with assert_raises(EC2ResponseError) as ex: - reservation = conn.run_instances('ami-1234abcd', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + reservation = conn.run_instances("ami-1234abcd", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the RunInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the RunInstance operation: Request would have succeeded, but DryRun flag is set" + ) - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") reservation.should.be.a(Reservation) reservation.instances.should.have.length_of(1) instance = reservation.instances[0] - instance.state.should.equal('pending') + instance.state.should.equal("pending") reservations = conn.get_all_instances() reservations.should.have.length_of(1) @@ -66,47 +69,46 @@ def test_instance_launch_and_terminate(): instances.should.have.length_of(1) instance = instances[0] instance.id.should.equal(instance.id) - instance.state.should.equal('running') + instance.state.should.equal("running") instance.launch_time.should.equal("2014-01-01T05:00:00.000Z") instance.vpc_id.should.equal(None) - instance.placement.should.equal('us-east-1a') + instance.placement.should.equal("us-east-1a") root_device_name = instance.root_device_name - instance.block_device_mapping[ - root_device_name].status.should.equal('in-use') + instance.block_device_mapping[root_device_name].status.should.equal("in-use") volume_id = instance.block_device_mapping[root_device_name].volume_id - volume_id.should.match(r'vol-\w+') + volume_id.should.match(r"vol-\w+") volume = conn.get_all_volumes(volume_ids=[volume_id])[0] volume.attach_data.instance_id.should.equal(instance.id) - volume.status.should.equal('in-use') + volume.status.should.equal("in-use") with assert_raises(EC2ResponseError) as ex: conn.terminate_instances([instance.id], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the TerminateInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the TerminateInstance operation: Request would have succeeded, but DryRun flag is set" + ) conn.terminate_instances([instance.id]) reservations = conn.get_all_instances() instance = reservations[0].instances[0] - instance.state.should.equal('terminated') + instance.state.should.equal("terminated") @mock_ec2_deprecated def test_terminate_empty_instances(): - conn = boto.connect_ec2('the_key', 'the_secret') - conn.terminate_instances.when.called_with( - []).should.throw(EC2ResponseError) + conn = boto.connect_ec2("the_key", "the_secret") + conn.terminate_instances.when.called_with([]).should.throw(EC2ResponseError) @freeze_time("2014-01-01 05:00:00") @mock_ec2_deprecated def test_instance_attach_volume(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] vol1 = conn.create_volume(size=36, zone=conn.region.name) @@ -124,20 +126,22 @@ def test_instance_attach_volume(): instance.block_device_mapping.should.have.length_of(3) - for v in conn.get_all_volumes(volume_ids=[instance.block_device_mapping['/dev/sdc1'].volume_id]): + for v in conn.get_all_volumes( + volume_ids=[instance.block_device_mapping["/dev/sdc1"].volume_id] + ): v.attach_data.instance_id.should.equal(instance.id) # can do due to freeze_time decorator. v.attach_data.attach_time.should.equal(instance.launch_time) # can do due to freeze_time decorator. v.create_time.should.equal(instance.launch_time) v.region.name.should.equal(instance.region.name) - v.status.should.equal('in-use') + v.status.should.equal("in-use") @mock_ec2_deprecated def test_get_instances_by_id(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=2) + reservation = conn.run_instances("ami-1234abcd", min_count=2) instance1, instance2 = reservation.instances reservations = conn.get_all_instances(instance_ids=[instance1.id]) @@ -146,8 +150,7 @@ def test_get_instances_by_id(): reservation.instances.should.have.length_of(1) reservation.instances[0].id.should.equal(instance1.id) - reservations = conn.get_all_instances( - instance_ids=[instance1.id, instance2.id]) + reservations = conn.get_all_instances(instance_ids=[instance1.id, instance2.id]) reservations.should.have.length_of(1) reservation = reservations[0] reservation.instances.should.have.length_of(2) @@ -157,78 +160,64 @@ def test_get_instances_by_id(): # Call get_all_instances with a bad id should raise an error with assert_raises(EC2ResponseError) as cm: conn.get_all_instances(instance_ids=[instance1.id, "i-1234abcd"]) - cm.exception.code.should.equal('InvalidInstanceID.NotFound') + cm.exception.code.should.equal("InvalidInstanceID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2 def test_get_paginated_instances(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-east-1') - conn = boto3.resource('ec2', 'us-east-1') + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-east-1") + conn = boto3.resource("ec2", "us-east-1") for i in range(100): - conn.create_instances(ImageId=image_id, - MinCount=1, - MaxCount=1) + conn.create_instances(ImageId=image_id, MinCount=1, MaxCount=1) resp = client.describe_instances(MaxResults=50) - reservations = resp['Reservations'] + reservations = resp["Reservations"] reservations.should.have.length_of(50) - next_token = resp['NextToken'] + next_token = resp["NextToken"] next_token.should_not.be.none resp2 = client.describe_instances(NextToken=next_token) - reservations.extend(resp2['Reservations']) + reservations.extend(resp2["Reservations"]) reservations.should.have.length_of(100) - assert 'NextToken' not in resp2.keys() + assert "NextToken" not in resp2.keys() @mock_ec2 def test_create_with_tags(): - ec2 = boto3.client('ec2', region_name='us-west-2') + ec2 = boto3.client("ec2", region_name="us-west-2") instances = ec2.run_instances( - ImageId='ami-123', + ImageId="ami-123", MinCount=1, MaxCount=1, - InstanceType='t2.micro', + InstanceType="t2.micro", TagSpecifications=[ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG1', - 'Value': 'MY_VALUE1', - }, - { - 'Key': 'MY_TAG2', - 'Value': 'MY_VALUE2', - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "MY_TAG1", "Value": "MY_VALUE1"}, + {"Key": "MY_TAG2", "Value": "MY_VALUE2"}, ], }, { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG3', - 'Value': 'MY_VALUE3', - }, - ] + "ResourceType": "instance", + "Tags": [{"Key": "MY_TAG3", "Value": "MY_VALUE3"}], }, ], ) - assert 'Tags' in instances['Instances'][0] - len(instances['Instances'][0]['Tags']).should.equal(3) + assert "Tags" in instances["Instances"][0] + len(instances["Instances"][0]["Tags"]).should.equal(3) @mock_ec2_deprecated def test_get_instances_filtering_by_state(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances conn.terminate_instances([instance1.id]) - reservations = conn.get_all_instances( - filters={'instance-state-name': 'running'}) + reservations = conn.get_all_instances(filters={"instance-state-name": "running"}) reservations.should.have.length_of(1) # Since we terminated instance1, only instance2 and instance3 should be # returned @@ -236,13 +225,15 @@ def test_get_instances_filtering_by_state(): set(instance_ids).should.equal(set([instance2.id, instance3.id])) reservations = conn.get_all_instances( - [instance2.id], filters={'instance-state-name': 'running'}) + [instance2.id], filters={"instance-state-name": "running"} + ) reservations.should.have.length_of(1) instance_ids = [instance.id for instance in reservations[0].instances] instance_ids.should.equal([instance2.id]) reservations = conn.get_all_instances( - [instance2.id], filters={'instance-state-name': 'terminated'}) + [instance2.id], filters={"instance-state-name": "terminated"} + ) list(reservations).should.equal([]) # get_all_instances should still return all 3 @@ -250,60 +241,58 @@ def test_get_instances_filtering_by_state(): reservations[0].instances.should.have.length_of(3) conn.get_all_instances.when.called_with( - filters={'not-implemented-filter': 'foobar'}).should.throw(NotImplementedError) + filters={"not-implemented-filter": "foobar"} + ).should.throw(NotImplementedError) @mock_ec2_deprecated def test_get_instances_filtering_by_instance_id(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances - reservations = conn.get_all_instances( - filters={'instance-id': instance1.id}) + reservations = conn.get_all_instances(filters={"instance-id": instance1.id}) # get_all_instances should return just instance1 reservations[0].instances.should.have.length_of(1) reservations[0].instances[0].id.should.equal(instance1.id) reservations = conn.get_all_instances( - filters={'instance-id': [instance1.id, instance2.id]}) + filters={"instance-id": [instance1.id, instance2.id]} + ) # get_all_instances should return two reservations[0].instances.should.have.length_of(2) - reservations = conn.get_all_instances( - filters={'instance-id': 'non-existing-id'}) + reservations = conn.get_all_instances(filters={"instance-id": "non-existing-id"}) reservations.should.have.length_of(0) @mock_ec2_deprecated def test_get_instances_filtering_by_instance_type(): conn = boto.connect_ec2() - reservation1 = conn.run_instances('ami-1234abcd', instance_type='m1.small') + reservation1 = conn.run_instances("ami-1234abcd", instance_type="m1.small") instance1 = reservation1.instances[0] - reservation2 = conn.run_instances('ami-1234abcd', instance_type='m1.small') + reservation2 = conn.run_instances("ami-1234abcd", instance_type="m1.small") instance2 = reservation2.instances[0] - reservation3 = conn.run_instances('ami-1234abcd', instance_type='t1.micro') + reservation3 = conn.run_instances("ami-1234abcd", instance_type="t1.micro") instance3 = reservation3.instances[0] - reservations = conn.get_all_instances( - filters={'instance-type': 'm1.small'}) + reservations = conn.get_all_instances(filters={"instance-type": "m1.small"}) # get_all_instances should return instance1,2 reservations.should.have.length_of(2) reservations[0].instances.should.have.length_of(1) reservations[1].instances.should.have.length_of(1) - instance_ids = [reservations[0].instances[0].id, - reservations[1].instances[0].id] + instance_ids = [reservations[0].instances[0].id, reservations[1].instances[0].id] set(instance_ids).should.equal(set([instance1.id, instance2.id])) - reservations = conn.get_all_instances( - filters={'instance-type': 't1.micro'}) + reservations = conn.get_all_instances(filters={"instance-type": "t1.micro"}) # get_all_instances should return one reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(1) reservations[0].instances[0].id.should.equal(instance3.id) reservations = conn.get_all_instances( - filters={'instance-type': ['t1.micro', 'm1.small']}) + filters={"instance-type": ["t1.micro", "m1.small"]} + ) reservations.should.have.length_of(3) reservations[0].instances.should.have.length_of(1) reservations[1].instances.should.have.length_of(1) @@ -313,10 +302,9 @@ def test_get_instances_filtering_by_instance_type(): reservations[1].instances[0].id, reservations[2].instances[0].id, ] - set(instance_ids).should.equal( - set([instance1.id, instance2.id, instance3.id])) + set(instance_ids).should.equal(set([instance1.id, instance2.id, instance3.id])) - reservations = conn.get_all_instances(filters={'instance-type': 'bogus'}) + reservations = conn.get_all_instances(filters={"instance-type": "bogus"}) # bogus instance-type should return none reservations.should.have.length_of(0) @@ -324,19 +312,21 @@ def test_get_instances_filtering_by_instance_type(): @mock_ec2_deprecated def test_get_instances_filtering_by_reason_code(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances instance1.stop() instance2.terminate() reservations = conn.get_all_instances( - filters={'state-reason-code': 'Client.UserInitiatedShutdown'}) + filters={"state-reason-code": "Client.UserInitiatedShutdown"} + ) # get_all_instances should return instance1 and instance2 reservations[0].instances.should.have.length_of(2) set([instance1.id, instance2.id]).should.equal( - set([i.id for i in reservations[0].instances])) + set([i.id for i in reservations[0].instances]) + ) - reservations = conn.get_all_instances(filters={'state-reason-code': ''}) + reservations = conn.get_all_instances(filters={"state-reason-code": ""}) # get_all_instances should return instance 3 reservations[0].instances.should.have.length_of(1) reservations[0].instances[0].id.should.equal(instance3.id) @@ -345,15 +335,18 @@ def test_get_instances_filtering_by_reason_code(): @mock_ec2_deprecated def test_get_instances_filtering_by_source_dest_check(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=2) + reservation = conn.run_instances("ami-1234abcd", min_count=2) instance1, instance2 = reservation.instances conn.modify_instance_attribute( - instance1.id, attribute='sourceDestCheck', value=False) + instance1.id, attribute="sourceDestCheck", value=False + ) source_dest_check_false = conn.get_all_instances( - filters={'source-dest-check': 'false'}) + filters={"source-dest-check": "false"} + ) source_dest_check_true = conn.get_all_instances( - filters={'source-dest-check': 'true'}) + filters={"source-dest-check": "true"} + ) source_dest_check_false[0].instances.should.have.length_of(1) source_dest_check_false[0].instances[0].id.should.equal(instance1.id) @@ -364,27 +357,25 @@ def test_get_instances_filtering_by_source_dest_check(): @mock_ec2_deprecated def test_get_instances_filtering_by_vpc_id(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc1 = conn.create_vpc("10.0.0.0/16") subnet1 = conn.create_subnet(vpc1.id, "10.0.0.0/27") - reservation1 = conn.run_instances( - 'ami-1234abcd', min_count=1, subnet_id=subnet1.id) + reservation1 = conn.run_instances("ami-1234abcd", min_count=1, subnet_id=subnet1.id) instance1 = reservation1.instances[0] vpc2 = conn.create_vpc("10.1.0.0/16") subnet2 = conn.create_subnet(vpc2.id, "10.1.0.0/27") - reservation2 = conn.run_instances( - 'ami-1234abcd', min_count=1, subnet_id=subnet2.id) + reservation2 = conn.run_instances("ami-1234abcd", min_count=1, subnet_id=subnet2.id) instance2 = reservation2.instances[0] - reservations1 = conn.get_all_instances(filters={'vpc-id': vpc1.id}) + reservations1 = conn.get_all_instances(filters={"vpc-id": vpc1.id}) reservations1.should.have.length_of(1) reservations1[0].instances.should.have.length_of(1) reservations1[0].instances[0].id.should.equal(instance1.id) reservations1[0].instances[0].vpc_id.should.equal(vpc1.id) reservations1[0].instances[0].subnet_id.should.equal(subnet1.id) - reservations2 = conn.get_all_instances(filters={'vpc-id': vpc2.id}) + reservations2 = conn.get_all_instances(filters={"vpc-id": vpc2.id}) reservations2.should.have.length_of(1) reservations2[0].instances.should.have.length_of(1) reservations2[0].instances[0].id.should.equal(instance2.id) @@ -395,111 +386,105 @@ def test_get_instances_filtering_by_vpc_id(): @mock_ec2_deprecated def test_get_instances_filtering_by_architecture(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=1) + reservation = conn.run_instances("ami-1234abcd", min_count=1) instance = reservation.instances - reservations = conn.get_all_instances(filters={'architecture': 'x86_64'}) + reservations = conn.get_all_instances(filters={"architecture": "x86_64"}) # get_all_instances should return the instance reservations[0].instances.should.have.length_of(1) @mock_ec2 def test_get_instances_filtering_by_image_id(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-east-1') - conn = boto3.resource('ec2', 'us-east-1') - conn.create_instances(ImageId=image_id, - MinCount=1, - MaxCount=1) + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-east-1") + conn = boto3.resource("ec2", "us-east-1") + conn.create_instances(ImageId=image_id, MinCount=1, MaxCount=1) - reservations = client.describe_instances(Filters=[{'Name': 'image-id', - 'Values': [image_id]}])['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) + reservations = client.describe_instances( + Filters=[{"Name": "image-id", "Values": [image_id]}] + )["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) @mock_ec2 def test_get_instances_filtering_by_private_dns(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-east-1') - conn = boto3.resource('ec2', 'us-east-1') - conn.create_instances(ImageId=image_id, - MinCount=1, - MaxCount=1, - PrivateIpAddress='10.0.0.1') - reservations = client.describe_instances(Filters=[ - {'Name': 'private-dns-name', 'Values': ['ip-10-0-0-1.ec2.internal']} - ])['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-east-1") + conn = boto3.resource("ec2", "us-east-1") + conn.create_instances( + ImageId=image_id, MinCount=1, MaxCount=1, PrivateIpAddress="10.0.0.1" + ) + reservations = client.describe_instances( + Filters=[{"Name": "private-dns-name", "Values": ["ip-10-0-0-1.ec2.internal"]}] + )["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) @mock_ec2 def test_get_instances_filtering_by_ni_private_dns(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-west-2') - conn = boto3.resource('ec2', 'us-west-2') - conn.create_instances(ImageId=image_id, - MinCount=1, - MaxCount=1, - PrivateIpAddress='10.0.0.1') - reservations = client.describe_instances(Filters=[ - {'Name': 'network-interface.private-dns-name', 'Values': ['ip-10-0-0-1.us-west-2.compute.internal']} - ])['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-west-2") + conn = boto3.resource("ec2", "us-west-2") + conn.create_instances( + ImageId=image_id, MinCount=1, MaxCount=1, PrivateIpAddress="10.0.0.1" + ) + reservations = client.describe_instances( + Filters=[ + { + "Name": "network-interface.private-dns-name", + "Values": ["ip-10-0-0-1.us-west-2.compute.internal"], + } + ] + )["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) @mock_ec2 def test_get_instances_filtering_by_instance_group_name(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-east-1') - client.create_security_group( - Description='test', - GroupName='test_sg' + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-east-1") + client.create_security_group(Description="test", GroupName="test_sg") + client.run_instances( + ImageId=image_id, MinCount=1, MaxCount=1, SecurityGroups=["test_sg"] ) - client.run_instances(ImageId=image_id, - MinCount=1, - MaxCount=1, - SecurityGroups=['test_sg']) - reservations = client.describe_instances(Filters=[ - {'Name': 'instance.group-name', 'Values': ['test_sg']} - ])['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) + reservations = client.describe_instances( + Filters=[{"Name": "instance.group-name", "Values": ["test_sg"]}] + )["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) @mock_ec2 def test_get_instances_filtering_by_instance_group_id(): - image_id = 'ami-1234abcd' - client = boto3.client('ec2', region_name='us-east-1') - create_sg = client.create_security_group( - Description='test', - GroupName='test_sg' + image_id = "ami-1234abcd" + client = boto3.client("ec2", region_name="us-east-1") + create_sg = client.create_security_group(Description="test", GroupName="test_sg") + group_id = create_sg["GroupId"] + client.run_instances( + ImageId=image_id, MinCount=1, MaxCount=1, SecurityGroups=["test_sg"] ) - group_id = create_sg['GroupId'] - client.run_instances(ImageId=image_id, - MinCount=1, - MaxCount=1, - SecurityGroups=['test_sg']) - reservations = client.describe_instances(Filters=[ - {'Name': 'instance.group-id', 'Values': [group_id]} - ])['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) + reservations = client.describe_instances( + Filters=[{"Name": "instance.group-id", "Values": [group_id]}] + )["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) @mock_ec2_deprecated def test_get_instances_filtering_by_tag(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances - instance1.add_tag('tag1', 'value1') - instance1.add_tag('tag2', 'value2') - instance2.add_tag('tag1', 'value1') - instance2.add_tag('tag2', 'wrong value') - instance3.add_tag('tag2', 'value2') + instance1.add_tag("tag1", "value1") + instance1.add_tag("tag2", "value2") + instance2.add_tag("tag1", "value1") + instance2.add_tag("tag2", "wrong value") + instance3.add_tag("tag2", "value2") - reservations = conn.get_all_instances(filters={'tag:tag0': 'value0'}) + reservations = conn.get_all_instances(filters={"tag:tag0": "value0"}) # get_all_instances should return no instances reservations.should.have.length_of(0) - reservations = conn.get_all_instances(filters={'tag:tag1': 'value1'}) + reservations = conn.get_all_instances(filters={"tag:tag1": "value1"}) # get_all_instances should return both instances with this tag value reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(2) @@ -507,21 +492,22 @@ def test_get_instances_filtering_by_tag(): reservations[0].instances[1].id.should.equal(instance2.id) reservations = conn.get_all_instances( - filters={'tag:tag1': 'value1', 'tag:tag2': 'value2'}) + filters={"tag:tag1": "value1", "tag:tag2": "value2"} + ) # get_all_instances should return the instance with both tag values reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(1) reservations[0].instances[0].id.should.equal(instance1.id) reservations = conn.get_all_instances( - filters={'tag:tag1': 'value1', 'tag:tag2': 'value2'}) + filters={"tag:tag1": "value1", "tag:tag2": "value2"} + ) # get_all_instances should return the instance with both tag values reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(1) reservations[0].instances[0].id.should.equal(instance1.id) - reservations = conn.get_all_instances( - filters={'tag:tag2': ['value2', 'bogus']}) + reservations = conn.get_all_instances(filters={"tag:tag2": ["value2", "bogus"]}) # get_all_instances should return both instances with one of the # acceptable tag values reservations.should.have.length_of(1) @@ -533,27 +519,26 @@ def test_get_instances_filtering_by_tag(): @mock_ec2_deprecated def test_get_instances_filtering_by_tag_value(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances - instance1.add_tag('tag1', 'value1') - instance1.add_tag('tag2', 'value2') - instance2.add_tag('tag1', 'value1') - instance2.add_tag('tag2', 'wrong value') - instance3.add_tag('tag2', 'value2') + instance1.add_tag("tag1", "value1") + instance1.add_tag("tag2", "value2") + instance2.add_tag("tag1", "value1") + instance2.add_tag("tag2", "wrong value") + instance3.add_tag("tag2", "value2") - reservations = conn.get_all_instances(filters={'tag-value': 'value0'}) + reservations = conn.get_all_instances(filters={"tag-value": "value0"}) # get_all_instances should return no instances reservations.should.have.length_of(0) - reservations = conn.get_all_instances(filters={'tag-value': 'value1'}) + reservations = conn.get_all_instances(filters={"tag-value": "value1"}) # get_all_instances should return both instances with this tag value reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(2) reservations[0].instances[0].id.should.equal(instance1.id) reservations[0].instances[1].id.should.equal(instance2.id) - reservations = conn.get_all_instances( - filters={'tag-value': ['value2', 'value1']}) + reservations = conn.get_all_instances(filters={"tag-value": ["value2", "value1"]}) # get_all_instances should return both instances with one of the # acceptable tag values reservations.should.have.length_of(1) @@ -562,8 +547,7 @@ def test_get_instances_filtering_by_tag_value(): reservations[0].instances[1].id.should.equal(instance2.id) reservations[0].instances[2].id.should.equal(instance3.id) - reservations = conn.get_all_instances( - filters={'tag-value': ['value2', 'bogus']}) + reservations = conn.get_all_instances(filters={"tag-value": ["value2", "bogus"]}) # get_all_instances should return both instances with one of the # acceptable tag values reservations.should.have.length_of(1) @@ -575,27 +559,26 @@ def test_get_instances_filtering_by_tag_value(): @mock_ec2_deprecated def test_get_instances_filtering_by_tag_name(): conn = boto.connect_ec2() - reservation = conn.run_instances('ami-1234abcd', min_count=3) + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances - instance1.add_tag('tag1') - instance1.add_tag('tag2') - instance2.add_tag('tag1') - instance2.add_tag('tag2X') - instance3.add_tag('tag3') + instance1.add_tag("tag1") + instance1.add_tag("tag2") + instance2.add_tag("tag1") + instance2.add_tag("tag2X") + instance3.add_tag("tag3") - reservations = conn.get_all_instances(filters={'tag-key': 'tagX'}) + reservations = conn.get_all_instances(filters={"tag-key": "tagX"}) # get_all_instances should return no instances reservations.should.have.length_of(0) - reservations = conn.get_all_instances(filters={'tag-key': 'tag1'}) + reservations = conn.get_all_instances(filters={"tag-key": "tag1"}) # get_all_instances should return both instances with this tag value reservations.should.have.length_of(1) reservations[0].instances.should.have.length_of(2) reservations[0].instances[0].id.should.equal(instance1.id) reservations[0].instances[1].id.should.equal(instance2.id) - reservations = conn.get_all_instances( - filters={'tag-key': ['tag1', 'tag3']}) + reservations = conn.get_all_instances(filters={"tag-key": ["tag1", "tag3"]}) # get_all_instances should return both instances with one of the # acceptable tag values reservations.should.have.length_of(1) @@ -607,8 +590,8 @@ def test_get_instances_filtering_by_tag_name(): @mock_ec2_deprecated def test_instance_start_and_stop(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', min_count=2) + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", min_count=2) instances = reservation.instances instances.should.have.length_of(2) @@ -616,103 +599,111 @@ def test_instance_start_and_stop(): with assert_raises(EC2ResponseError) as ex: stopped_instances = conn.stop_instances(instance_ids, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the StopInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the StopInstance operation: Request would have succeeded, but DryRun flag is set" + ) stopped_instances = conn.stop_instances(instance_ids) for instance in stopped_instances: - instance.state.should.equal('stopping') + instance.state.should.equal("stopping") with assert_raises(EC2ResponseError) as ex: - started_instances = conn.start_instances( - [instances[0].id], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + started_instances = conn.start_instances([instances[0].id], dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the StartInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the StartInstance operation: Request would have succeeded, but DryRun flag is set" + ) started_instances = conn.start_instances([instances[0].id]) - started_instances[0].state.should.equal('pending') + started_instances[0].state.should.equal("pending") @mock_ec2_deprecated def test_instance_reboot(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: instance.reboot(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the RebootInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the RebootInstance operation: Request would have succeeded, but DryRun flag is set" + ) instance.reboot() - instance.state.should.equal('pending') + instance.state.should.equal("pending") @mock_ec2_deprecated def test_instance_attribute_instance_type(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: instance.modify_attribute("instanceType", "m1.small", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyInstanceType operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyInstanceType operation: Request would have succeeded, but DryRun flag is set" + ) instance.modify_attribute("instanceType", "m1.small") instance_attribute = instance.get_attribute("instanceType") instance_attribute.should.be.a(InstanceAttribute) - instance_attribute.get('instanceType').should.equal("m1.small") + instance_attribute.get("instanceType").should.equal("m1.small") @mock_ec2_deprecated def test_modify_instance_attribute_security_groups(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] - sg_id = conn.create_security_group('test security group', 'this is a test security group').id - sg_id2 = conn.create_security_group('test security group 2', 'this is a test security group 2').id + sg_id = conn.create_security_group( + "test security group", "this is a test security group" + ).id + sg_id2 = conn.create_security_group( + "test security group 2", "this is a test security group 2" + ).id with assert_raises(EC2ResponseError) as ex: instance.modify_attribute("groupSet", [sg_id, sg_id2], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyInstanceSecurityGroups operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyInstanceSecurityGroups operation: Request would have succeeded, but DryRun flag is set" + ) instance.modify_attribute("groupSet", [sg_id, sg_id2]) instance_attribute = instance.get_attribute("groupSet") instance_attribute.should.be.a(InstanceAttribute) - group_list = instance_attribute.get('groupSet') + group_list = instance_attribute.get("groupSet") any(g.id == sg_id for g in group_list).should.be.ok any(g.id == sg_id2 for g in group_list).should.be.ok @mock_ec2_deprecated def test_instance_attribute_user_data(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: - instance.modify_attribute( - "userData", "this is my user data", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + instance.modify_attribute("userData", "this is my user data", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyUserData operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyUserData operation: Request would have succeeded, but DryRun flag is set" + ) instance.modify_attribute("userData", "this is my user data") @@ -723,12 +714,12 @@ def test_instance_attribute_user_data(): @mock_ec2_deprecated def test_instance_attribute_source_dest_check(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] # Default value is true - instance.sourceDestCheck.should.equal('true') + instance.sourceDestCheck.should.equal("true") instance_attribute = instance.get_attribute("sourceDestCheck") instance_attribute.should.be.a(InstanceAttribute) @@ -738,15 +729,16 @@ def test_instance_attribute_source_dest_check(): with assert_raises(EC2ResponseError) as ex: instance.modify_attribute("sourceDestCheck", False, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifySourceDestCheck operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifySourceDestCheck operation: Request would have succeeded, but DryRun flag is set" + ) instance.modify_attribute("sourceDestCheck", False) instance.update() - instance.sourceDestCheck.should.equal('false') + instance.sourceDestCheck.should.equal("false") instance_attribute = instance.get_attribute("sourceDestCheck") instance_attribute.should.be.a(InstanceAttribute) @@ -756,7 +748,7 @@ def test_instance_attribute_source_dest_check(): instance.modify_attribute("sourceDestCheck", True) instance.update() - instance.sourceDestCheck.should.equal('true') + instance.sourceDestCheck.should.equal("true") instance_attribute = instance.get_attribute("sourceDestCheck") instance_attribute.should.be.a(InstanceAttribute) @@ -766,33 +758,32 @@ def test_instance_attribute_source_dest_check(): @mock_ec2_deprecated def test_user_data_with_run_instance(): user_data = b"some user data" - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', user_data=user_data) + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", user_data=user_data) instance = reservation.instances[0] instance_attribute = instance.get_attribute("userData") instance_attribute.should.be.a(InstanceAttribute) - retrieved_user_data = instance_attribute.get("userData").encode('utf-8') + retrieved_user_data = instance_attribute.get("userData").encode("utf-8") decoded_user_data = base64.decodestring(retrieved_user_data) decoded_user_data.should.equal(b"some user data") @mock_ec2_deprecated def test_run_instance_with_security_group_name(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: - group = conn.create_security_group( - 'group1', "some description", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + group = conn.create_security_group("group1", "some description", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateSecurityGroup operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateSecurityGroup operation: Request would have succeeded, but DryRun flag is set" + ) - group = conn.create_security_group('group1', "some description") + group = conn.create_security_group("group1", "some description") - reservation = conn.run_instances('ami-1234abcd', - security_groups=['group1']) + reservation = conn.run_instances("ami-1234abcd", security_groups=["group1"]) instance = reservation.instances[0] instance.groups[0].id.should.equal(group.id) @@ -801,10 +792,9 @@ def test_run_instance_with_security_group_name(): @mock_ec2_deprecated def test_run_instance_with_security_group_id(): - conn = boto.connect_ec2('the_key', 'the_secret') - group = conn.create_security_group('group1', "some description") - reservation = conn.run_instances('ami-1234abcd', - security_group_ids=[group.id]) + conn = boto.connect_ec2("the_key", "the_secret") + group = conn.create_security_group("group1", "some description") + reservation = conn.run_instances("ami-1234abcd", security_group_ids=[group.id]) instance = reservation.instances[0] instance.groups[0].id.should.equal(group.id) @@ -813,8 +803,8 @@ def test_run_instance_with_security_group_id(): @mock_ec2_deprecated def test_run_instance_with_instance_type(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', instance_type="t1.micro") + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", instance_type="t1.micro") instance = reservation.instances[0] instance.instance_type.should.equal("t1.micro") @@ -823,7 +813,7 @@ def test_run_instance_with_instance_type(): @mock_ec2_deprecated def test_run_instance_with_default_placement(): conn = boto.ec2.connect_to_region("us-east-1") - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.placement.should.equal("us-east-1a") @@ -831,8 +821,8 @@ def test_run_instance_with_default_placement(): @mock_ec2_deprecated def test_run_instance_with_placement(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', placement="us-east-1b") + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", placement="us-east-1b") instance = reservation.instances[0] instance.placement.should.equal("us-east-1b") @@ -840,11 +830,14 @@ def test_run_instance_with_placement(): @mock_ec2 def test_run_instance_with_subnet_boto3(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") ip_networks = [ - (ipaddress.ip_network('10.0.0.0/16'), ipaddress.ip_network('10.0.99.0/24')), - (ipaddress.ip_network('192.168.42.0/24'), ipaddress.ip_network('192.168.42.0/25')) + (ipaddress.ip_network("10.0.0.0/16"), ipaddress.ip_network("10.0.99.0/24")), + ( + ipaddress.ip_network("192.168.42.0/24"), + ipaddress.ip_network("192.168.42.0/25"), + ), ] # Tests instances are created with the correct IPs @@ -853,115 +846,104 @@ def test_run_instance_with_subnet_boto3(): CidrBlock=str(vpc_cidr), AmazonProvidedIpv6CidrBlock=False, DryRun=False, - InstanceTenancy='default' + InstanceTenancy="default", ) - vpc_id = resp['Vpc']['VpcId'] + vpc_id = resp["Vpc"]["VpcId"] - resp = client.create_subnet( - CidrBlock=str(subnet_cidr), - VpcId=vpc_id - ) - subnet_id = resp['Subnet']['SubnetId'] + resp = client.create_subnet(CidrBlock=str(subnet_cidr), VpcId=vpc_id) + subnet_id = resp["Subnet"]["SubnetId"] resp = client.run_instances( - ImageId='ami-1234abcd', - MaxCount=1, - MinCount=1, - SubnetId=subnet_id + ImageId="ami-1234abcd", MaxCount=1, MinCount=1, SubnetId=subnet_id ) - instance = resp['Instances'][0] - instance['SubnetId'].should.equal(subnet_id) + instance = resp["Instances"][0] + instance["SubnetId"].should.equal(subnet_id) - priv_ipv4 = ipaddress.ip_address(six.text_type(instance['PrivateIpAddress'])) + priv_ipv4 = ipaddress.ip_address(six.text_type(instance["PrivateIpAddress"])) subnet_cidr.should.contain(priv_ipv4) @mock_ec2 def test_run_instance_with_specified_private_ipv4(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") - vpc_cidr = ipaddress.ip_network('192.168.42.0/24') - subnet_cidr = ipaddress.ip_network('192.168.42.0/25') + vpc_cidr = ipaddress.ip_network("192.168.42.0/24") + subnet_cidr = ipaddress.ip_network("192.168.42.0/25") resp = client.create_vpc( CidrBlock=str(vpc_cidr), AmazonProvidedIpv6CidrBlock=False, DryRun=False, - InstanceTenancy='default' + InstanceTenancy="default", ) - vpc_id = resp['Vpc']['VpcId'] + vpc_id = resp["Vpc"]["VpcId"] - resp = client.create_subnet( - CidrBlock=str(subnet_cidr), - VpcId=vpc_id - ) - subnet_id = resp['Subnet']['SubnetId'] + resp = client.create_subnet(CidrBlock=str(subnet_cidr), VpcId=vpc_id) + subnet_id = resp["Subnet"]["SubnetId"] resp = client.run_instances( - ImageId='ami-1234abcd', + ImageId="ami-1234abcd", MaxCount=1, MinCount=1, SubnetId=subnet_id, - PrivateIpAddress='192.168.42.5' + PrivateIpAddress="192.168.42.5", ) - instance = resp['Instances'][0] - instance['SubnetId'].should.equal(subnet_id) - instance['PrivateIpAddress'].should.equal('192.168.42.5') + instance = resp["Instances"][0] + instance["SubnetId"].should.equal(subnet_id) + instance["PrivateIpAddress"].should.equal("192.168.42.5") @mock_ec2 def test_run_instance_mapped_public_ipv4(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") - vpc_cidr = ipaddress.ip_network('192.168.42.0/24') - subnet_cidr = ipaddress.ip_network('192.168.42.0/25') + vpc_cidr = ipaddress.ip_network("192.168.42.0/24") + subnet_cidr = ipaddress.ip_network("192.168.42.0/25") resp = client.create_vpc( CidrBlock=str(vpc_cidr), AmazonProvidedIpv6CidrBlock=False, DryRun=False, - InstanceTenancy='default' + InstanceTenancy="default", ) - vpc_id = resp['Vpc']['VpcId'] + vpc_id = resp["Vpc"]["VpcId"] - resp = client.create_subnet( - CidrBlock=str(subnet_cidr), - VpcId=vpc_id - ) - subnet_id = resp['Subnet']['SubnetId'] + resp = client.create_subnet(CidrBlock=str(subnet_cidr), VpcId=vpc_id) + subnet_id = resp["Subnet"]["SubnetId"] client.modify_subnet_attribute( - SubnetId=subnet_id, - MapPublicIpOnLaunch={'Value': True} + SubnetId=subnet_id, MapPublicIpOnLaunch={"Value": True} ) resp = client.run_instances( - ImageId='ami-1234abcd', - MaxCount=1, - MinCount=1, - SubnetId=subnet_id + ImageId="ami-1234abcd", MaxCount=1, MinCount=1, SubnetId=subnet_id ) - instance = resp['Instances'][0] - instance.should.contain('PublicDnsName') - instance.should.contain('PublicIpAddress') - len(instance['PublicDnsName']).should.be.greater_than(0) - len(instance['PublicIpAddress']).should.be.greater_than(0) + instance = resp["Instances"][0] + instance.should.contain("PublicDnsName") + instance.should.contain("PublicIpAddress") + len(instance["PublicDnsName"]).should.be.greater_than(0) + len(instance["PublicIpAddress"]).should.be.greater_than(0) @mock_ec2_deprecated def test_run_instance_with_nic_autocreated(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) private_ip = "10.0.0.1" - reservation = conn.run_instances('ami-1234abcd', subnet_id=subnet.id, - security_groups=[security_group1.name], - security_group_ids=[security_group2.id], - private_ip_address=private_ip) + reservation = conn.run_instances( + "ami-1234abcd", + subnet_id=subnet.id, + security_groups=[security_group1.name], + security_group_ids=[security_group2.id], + private_ip_address=private_ip, + ) instance = reservation.instances[0] all_enis = conn.get_all_network_interfaces() @@ -974,39 +956,52 @@ def test_run_instance_with_nic_autocreated(): instance.subnet_id.should.equal(subnet.id) instance.groups.should.have.length_of(2) set([group.id for group in instance.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) eni.subnet_id.should.equal(subnet.id) eni.groups.should.have.length_of(2) set([group.id for group in eni.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) eni.private_ip_addresses.should.have.length_of(1) eni.private_ip_addresses[0].private_ip_address.should.equal(private_ip) @mock_ec2_deprecated def test_run_instance_with_nic_preexisting(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) private_ip = "54.0.0.1" eni = conn.create_network_interface( - subnet.id, private_ip, groups=[security_group1.id]) + subnet.id, private_ip, groups=[security_group1.id] + ) # Boto requires NetworkInterfaceCollection of NetworkInterfaceSpecifications... # annoying, but generates the desired querystring. - from boto.ec2.networkinterface import NetworkInterfaceSpecification, NetworkInterfaceCollection + from boto.ec2.networkinterface import ( + NetworkInterfaceSpecification, + NetworkInterfaceCollection, + ) + interface = NetworkInterfaceSpecification( - network_interface_id=eni.id, device_index=0) + network_interface_id=eni.id, device_index=0 + ) interfaces = NetworkInterfaceCollection(interface) # end Boto objects - reservation = conn.run_instances('ami-1234abcd', network_interfaces=interfaces, - security_group_ids=[security_group2.id]) + reservation = conn.run_instances( + "ami-1234abcd", + network_interfaces=interfaces, + security_group_ids=[security_group2.id], + ) instance = reservation.instances[0] instance.subnet_id.should.equal(subnet.id) @@ -1021,26 +1016,29 @@ def test_run_instance_with_nic_preexisting(): instance_eni.subnet_id.should.equal(subnet.id) instance_eni.groups.should.have.length_of(2) set([group.id for group in instance_eni.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) instance_eni.private_ip_addresses.should.have.length_of(1) - instance_eni.private_ip_addresses[ - 0].private_ip_address.should.equal(private_ip) + instance_eni.private_ip_addresses[0].private_ip_address.should.equal(private_ip) @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_instance_with_nic_attach_detach(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") security_group1 = conn.create_security_group( - 'test security group #1', 'this is a test security group') + "test security group #1", "this is a test security group" + ) security_group2 = conn.create_security_group( - 'test security group #2', 'this is a test security group') + "test security group #2", "this is a test security group" + ) reservation = conn.run_instances( - 'ami-1234abcd', security_group_ids=[security_group1.id]) + "ami-1234abcd", security_group_ids=[security_group1.id] + ) instance = reservation.instances[0] eni = conn.create_network_interface(subnet.id, groups=[security_group2.id]) @@ -1049,17 +1047,16 @@ def test_instance_with_nic_attach_detach(): instance.interfaces.should.have.length_of(1) eni.groups.should.have.length_of(1) - set([group.id for group in eni.groups]).should.equal( - set([security_group2.id])) + set([group.id for group in eni.groups]).should.equal(set([security_group2.id])) # Attach with assert_raises(EC2ResponseError) as ex: - conn.attach_network_interface( - eni.id, instance.id, device_index=1, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + conn.attach_network_interface(eni.id, instance.id, device_index=1, dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AttachNetworkInterface operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the AttachNetworkInterface operation: Request would have succeeded, but DryRun flag is set" + ) conn.attach_network_interface(eni.id, instance.id, device_index=1) @@ -1070,21 +1067,23 @@ def test_instance_with_nic_attach_detach(): instance_eni.id.should.equal(eni.id) instance_eni.groups.should.have.length_of(2) set([group.id for group in instance_eni.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) - eni = conn.get_all_network_interfaces( - filters={'network-interface-id': eni.id})[0] + eni = conn.get_all_network_interfaces(filters={"network-interface-id": eni.id})[0] eni.groups.should.have.length_of(2) set([group.id for group in eni.groups]).should.equal( - set([security_group1.id, security_group2.id])) + set([security_group1.id, security_group2.id]) + ) # Detach with assert_raises(EC2ResponseError) as ex: conn.detach_network_interface(instance_eni.attachment.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DetachNetworkInterface operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DetachNetworkInterface operation: Request would have succeeded, but DryRun flag is set" + ) conn.detach_network_interface(instance_eni.attachment.id) @@ -1092,35 +1091,35 @@ def test_instance_with_nic_attach_detach(): instance.update() instance.interfaces.should.have.length_of(1) - eni = conn.get_all_network_interfaces( - filters={'network-interface-id': eni.id})[0] + eni = conn.get_all_network_interfaces(filters={"network-interface-id": eni.id})[0] eni.groups.should.have.length_of(1) - set([group.id for group in eni.groups]).should.equal( - set([security_group2.id])) + set([group.id for group in eni.groups]).should.equal(set([security_group2.id])) # Detach with invalid attachment ID with assert_raises(EC2ResponseError) as cm: - conn.detach_network_interface('eni-attach-1234abcd') - cm.exception.code.should.equal('InvalidAttachmentID.NotFound') + conn.detach_network_interface("eni-attach-1234abcd") + cm.exception.code.should.equal("InvalidAttachmentID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_ec2_classic_has_public_ip_address(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', key_name="keypair_name") + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", key_name="keypair_name") instance = reservation.instances[0] instance.ip_address.should_not.equal(None) - instance.public_dns_name.should.contain(instance.ip_address.replace('.', '-')) + instance.public_dns_name.should.contain(instance.ip_address.replace(".", "-")) instance.private_ip_address.should_not.equal(None) - instance.private_dns_name.should.contain(instance.private_ip_address.replace('.', '-')) + instance.private_dns_name.should.contain( + instance.private_ip_address.replace(".", "-") + ) @mock_ec2_deprecated def test_run_instance_with_keypair(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', key_name="keypair_name") + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", key_name="keypair_name") instance = reservation.instances[0] instance.key_name.should.equal("keypair_name") @@ -1128,32 +1127,32 @@ def test_run_instance_with_keypair(): @mock_ec2_deprecated def test_describe_instance_status_no_instances(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") all_status = conn.get_all_instance_status() len(all_status).should.equal(0) @mock_ec2_deprecated def test_describe_instance_status_with_instances(): - conn = boto.connect_ec2('the_key', 'the_secret') - conn.run_instances('ami-1234abcd', key_name="keypair_name") + conn = boto.connect_ec2("the_key", "the_secret") + conn.run_instances("ami-1234abcd", key_name="keypair_name") all_status = conn.get_all_instance_status() len(all_status).should.equal(1) - all_status[0].instance_status.status.should.equal('ok') - all_status[0].system_status.status.should.equal('ok') + all_status[0].instance_status.status.should.equal("ok") + all_status[0].system_status.status.should.equal("ok") @mock_ec2_deprecated def test_describe_instance_status_with_instance_filter(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") # We want to filter based on this one - reservation = conn.run_instances('ami-1234abcd', key_name="keypair_name") + reservation = conn.run_instances("ami-1234abcd", key_name="keypair_name") instance = reservation.instances[0] # This is just to setup the test - conn.run_instances('ami-1234abcd', key_name="keypair_name") + conn.run_instances("ami-1234abcd", key_name="keypair_name") all_status = conn.get_all_instance_status(instance_ids=[instance.id]) len(all_status).should.equal(1) @@ -1162,7 +1161,7 @@ def test_describe_instance_status_with_instance_filter(): # Call get_all_instance_status with a bad id should raise an error with assert_raises(EC2ResponseError) as cm: conn.get_all_instance_status(instance_ids=[instance.id, "i-1234abcd"]) - cm.exception.code.should.equal('InvalidInstanceID.NotFound') + cm.exception.code.should.equal("InvalidInstanceID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -1170,8 +1169,8 @@ def test_describe_instance_status_with_instance_filter(): @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_describe_instance_status_with_non_running_instances(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd', min_count=3) + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd", min_count=3) instance1, instance2, instance3 = reservation.instances instance1.stop() instance2.terminate() @@ -1179,40 +1178,41 @@ def test_describe_instance_status_with_non_running_instances(): all_running_status = conn.get_all_instance_status() all_running_status.should.have.length_of(1) all_running_status[0].id.should.equal(instance3.id) - all_running_status[0].state_name.should.equal('running') + all_running_status[0].state_name.should.equal("running") all_status = conn.get_all_instance_status(include_all_instances=True) all_status.should.have.length_of(3) status1 = next((s for s in all_status if s.id == instance1.id), None) - status1.state_name.should.equal('stopped') + status1.state_name.should.equal("stopped") status2 = next((s for s in all_status if s.id == instance2.id), None) - status2.state_name.should.equal('terminated') + status2.state_name.should.equal("terminated") status3 = next((s for s in all_status if s.id == instance3.id), None) - status3.state_name.should.equal('running') + status3.state_name.should.equal("running") @mock_ec2_deprecated def test_get_instance_by_security_group(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - conn.run_instances('ami-1234abcd') + conn.run_instances("ami-1234abcd") instance = conn.get_only_instances()[0] - security_group = conn.create_security_group('test', 'test') + security_group = conn.create_security_group("test", "test") with assert_raises(EC2ResponseError) as ex: - conn.modify_instance_attribute(instance.id, "groupSet", [ - security_group.id], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + conn.modify_instance_attribute( + instance.id, "groupSet", [security_group.id], dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ModifyInstanceSecurityGroups operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ModifyInstanceSecurityGroups operation: Request would have succeeded, but DryRun flag is set" + ) - conn.modify_instance_attribute( - instance.id, "groupSet", [security_group.id]) + conn.modify_instance_attribute(instance.id, "groupSet", [security_group.id]) security_group_instances = security_group.instances() @@ -1222,38 +1222,31 @@ def test_get_instance_by_security_group(): @mock_ec2 def test_modify_delete_on_termination(): - ec2_client = boto3.resource('ec2', region_name='us-west-1') - result = ec2_client.create_instances(ImageId='ami-12345678', MinCount=1, MaxCount=1) + ec2_client = boto3.resource("ec2", region_name="us-west-1") + result = ec2_client.create_instances(ImageId="ami-12345678", MinCount=1, MaxCount=1) instance = result[0] instance.load() - instance.block_device_mappings[0]['Ebs']['DeleteOnTermination'].should.be(False) + instance.block_device_mappings[0]["Ebs"]["DeleteOnTermination"].should.be(False) instance.modify_attribute( - BlockDeviceMappings=[{ - 'DeviceName': '/dev/sda1', - 'Ebs': {'DeleteOnTermination': True} - }] + BlockDeviceMappings=[ + {"DeviceName": "/dev/sda1", "Ebs": {"DeleteOnTermination": True}} + ] ) instance.load() - instance.block_device_mappings[0]['Ebs']['DeleteOnTermination'].should.be(True) + instance.block_device_mappings[0]["Ebs"]["DeleteOnTermination"].should.be(True) + @mock_ec2 def test_create_instance_ebs_optimized(): - ec2_resource = boto3.resource('ec2', region_name='eu-west-1') + ec2_resource = boto3.resource("ec2", region_name="eu-west-1") instance = ec2_resource.create_instances( - ImageId = 'ami-12345678', - MaxCount = 1, - MinCount = 1, - EbsOptimized = True, + ImageId="ami-12345678", MaxCount=1, MinCount=1, EbsOptimized=True )[0] instance.load() instance.ebs_optimized.should.be(True) - instance.modify_attribute( - EbsOptimized={ - 'Value': False - } - ) + instance.modify_attribute(EbsOptimized={"Value": False}) instance.load() instance.ebs_optimized.should.be(False) @@ -1261,34 +1254,55 @@ def test_create_instance_ebs_optimized(): @mock_ec2 def test_run_multiple_instances_in_same_command(): instance_count = 4 - client = boto3.client('ec2', region_name='us-east-1') - client.run_instances(ImageId='ami-1234abcd', - MinCount=instance_count, - MaxCount=instance_count) - reservations = client.describe_instances()['Reservations'] + client = boto3.client("ec2", region_name="us-east-1") + client.run_instances( + ImageId="ami-1234abcd", MinCount=instance_count, MaxCount=instance_count + ) + reservations = client.describe_instances()["Reservations"] - reservations[0]['Instances'].should.have.length_of(instance_count) + reservations[0]["Instances"].should.have.length_of(instance_count) - instances = reservations[0]['Instances'] + instances = reservations[0]["Instances"] for i in range(0, instance_count): - instances[i]['AmiLaunchIndex'].should.be(i) + instances[i]["AmiLaunchIndex"].should.be(i) @mock_ec2 def test_describe_instance_attribute(): - client = boto3.client('ec2', region_name='us-east-1') + client = boto3.client("ec2", region_name="us-east-1") security_group_id = client.create_security_group( - GroupName='test security group', Description='this is a test security group')['GroupId'] - client.run_instances(ImageId='ami-1234abcd', - MinCount=1, - MaxCount=1, - SecurityGroupIds=[security_group_id]) - instance_id = client.describe_instances()['Reservations'][0]['Instances'][0]['InstanceId'] + GroupName="test security group", Description="this is a test security group" + )["GroupId"] + client.run_instances( + ImageId="ami-1234abcd", + MinCount=1, + MaxCount=1, + SecurityGroupIds=[security_group_id], + ) + instance_id = client.describe_instances()["Reservations"][0]["Instances"][0][ + "InstanceId" + ] - valid_instance_attributes = ['instanceType', 'kernel', 'ramdisk', 'userData', 'disableApiTermination', 'instanceInitiatedShutdownBehavior', 'rootDeviceName', 'blockDeviceMapping', 'productCodes', 'sourceDestCheck', 'groupSet', 'ebsOptimized', 'sriovNetSupport'] + valid_instance_attributes = [ + "instanceType", + "kernel", + "ramdisk", + "userData", + "disableApiTermination", + "instanceInitiatedShutdownBehavior", + "rootDeviceName", + "blockDeviceMapping", + "productCodes", + "sourceDestCheck", + "groupSet", + "ebsOptimized", + "sriovNetSupport", + ] for valid_instance_attribute in valid_instance_attributes: - response = client.describe_instance_attribute(InstanceId=instance_id, Attribute=valid_instance_attribute) + response = client.describe_instance_attribute( + InstanceId=instance_id, Attribute=valid_instance_attribute + ) if valid_instance_attribute == "groupSet": response.should.have.key("Groups") response["Groups"].should.have.length_of(1) @@ -1297,12 +1311,22 @@ def test_describe_instance_attribute(): response.should.have.key("UserData") response["UserData"].should.be.empty - invalid_instance_attributes = ['abc', 'Kernel', 'RamDisk', 'userdata', 'iNsTaNcEtYpE'] + invalid_instance_attributes = [ + "abc", + "Kernel", + "RamDisk", + "userdata", + "iNsTaNcEtYpE", + ] for invalid_instance_attribute in invalid_instance_attributes: with assert_raises(ClientError) as ex: - client.describe_instance_attribute(InstanceId=instance_id, Attribute=invalid_instance_attribute) - ex.exception.response['Error']['Code'].should.equal('InvalidParameterValue') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - message = 'Value ({invalid_instance_attribute}) for parameter attribute is invalid. Unknown attribute.'.format(invalid_instance_attribute=invalid_instance_attribute) - ex.exception.response['Error']['Message'].should.equal(message) + client.describe_instance_attribute( + InstanceId=instance_id, Attribute=invalid_instance_attribute + ) + ex.exception.response["Error"]["Code"].should.equal("InvalidParameterValue") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + message = "Value ({invalid_instance_attribute}) for parameter attribute is invalid. Unknown attribute.".format( + invalid_instance_attribute=invalid_instance_attribute + ) + ex.exception.response["Error"]["Message"].should.equal(message) diff --git a/tests/test_ec2/test_internet_gateways.py b/tests/test_ec2/test_internet_gateways.py index 3a1d0fda9..5941643cf 100644 --- a/tests/test_ec2/test_internet_gateways.py +++ b/tests/test_ec2/test_internet_gateways.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -21,20 +22,21 @@ BAD_IGW = "igw-deadbeef" @mock_ec2_deprecated def test_igw_create(): """ internet gateway create """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") conn.get_all_internet_gateways().should.have.length_of(0) with assert_raises(EC2ResponseError) as ex: igw = conn.create_internet_gateway(dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateInternetGateway operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateInternetGateway operation: Request would have succeeded, but DryRun flag is set" + ) igw = conn.create_internet_gateway() conn.get_all_internet_gateways().should.have.length_of(1) - igw.id.should.match(r'igw-[0-9a-f]+') + igw.id.should.match(r"igw-[0-9a-f]+") igw = conn.get_all_internet_gateways()[0] igw.attachments.should.have.length_of(0) @@ -43,16 +45,17 @@ def test_igw_create(): @mock_ec2_deprecated def test_igw_attach(): """ internet gateway attach """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw = conn.create_internet_gateway() vpc = conn.create_vpc(VPC_CIDR) with assert_raises(EC2ResponseError) as ex: conn.attach_internet_gateway(igw.id, vpc.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the AttachInternetGateway operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the AttachInternetGateway operation: Request would have succeeded, but DryRun flag is set" + ) conn.attach_internet_gateway(igw.id, vpc.id) @@ -63,12 +66,12 @@ def test_igw_attach(): @mock_ec2_deprecated def test_igw_attach_bad_vpc(): """ internet gateway fail to attach w/ bad vpc """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw = conn.create_internet_gateway() with assert_raises(EC2ResponseError) as cm: conn.attach_internet_gateway(igw.id, BAD_VPC) - cm.exception.code.should.equal('InvalidVpcID.NotFound') + cm.exception.code.should.equal("InvalidVpcID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -76,7 +79,7 @@ def test_igw_attach_bad_vpc(): @mock_ec2_deprecated def test_igw_attach_twice(): """ internet gateway fail to attach twice """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw = conn.create_internet_gateway() vpc1 = conn.create_vpc(VPC_CIDR) vpc2 = conn.create_vpc(VPC_CIDR) @@ -84,7 +87,7 @@ def test_igw_attach_twice(): with assert_raises(EC2ResponseError) as cm: conn.attach_internet_gateway(igw.id, vpc2.id) - cm.exception.code.should.equal('Resource.AlreadyAssociated') + cm.exception.code.should.equal("Resource.AlreadyAssociated") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -92,17 +95,18 @@ def test_igw_attach_twice(): @mock_ec2_deprecated def test_igw_detach(): """ internet gateway detach""" - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw = conn.create_internet_gateway() vpc = conn.create_vpc(VPC_CIDR) conn.attach_internet_gateway(igw.id, vpc.id) with assert_raises(EC2ResponseError) as ex: conn.detach_internet_gateway(igw.id, vpc.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DetachInternetGateway operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DetachInternetGateway operation: Request would have succeeded, but DryRun flag is set" + ) conn.detach_internet_gateway(igw.id, vpc.id) igw = conn.get_all_internet_gateways()[0] @@ -112,7 +116,7 @@ def test_igw_detach(): @mock_ec2_deprecated def test_igw_detach_wrong_vpc(): """ internet gateway fail to detach w/ wrong vpc """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw = conn.create_internet_gateway() vpc1 = conn.create_vpc(VPC_CIDR) vpc2 = conn.create_vpc(VPC_CIDR) @@ -120,7 +124,7 @@ def test_igw_detach_wrong_vpc(): with assert_raises(EC2ResponseError) as cm: conn.detach_internet_gateway(igw.id, vpc2.id) - cm.exception.code.should.equal('Gateway.NotAttached') + cm.exception.code.should.equal("Gateway.NotAttached") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -128,14 +132,14 @@ def test_igw_detach_wrong_vpc(): @mock_ec2_deprecated def test_igw_detach_invalid_vpc(): """ internet gateway fail to detach w/ invalid vpc """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw = conn.create_internet_gateway() vpc = conn.create_vpc(VPC_CIDR) conn.attach_internet_gateway(igw.id, vpc.id) with assert_raises(EC2ResponseError) as cm: conn.detach_internet_gateway(igw.id, BAD_VPC) - cm.exception.code.should.equal('Gateway.NotAttached') + cm.exception.code.should.equal("Gateway.NotAttached") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -143,13 +147,13 @@ def test_igw_detach_invalid_vpc(): @mock_ec2_deprecated def test_igw_detach_unattached(): """ internet gateway fail to detach unattached """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw = conn.create_internet_gateway() vpc = conn.create_vpc(VPC_CIDR) with assert_raises(EC2ResponseError) as cm: conn.detach_internet_gateway(igw.id, vpc.id) - cm.exception.code.should.equal('Gateway.NotAttached') + cm.exception.code.should.equal("Gateway.NotAttached") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -157,7 +161,7 @@ def test_igw_detach_unattached(): @mock_ec2_deprecated def test_igw_delete(): """ internet gateway delete""" - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc(VPC_CIDR) conn.get_all_internet_gateways().should.have.length_of(0) igw = conn.create_internet_gateway() @@ -165,10 +169,11 @@ def test_igw_delete(): with assert_raises(EC2ResponseError) as ex: conn.delete_internet_gateway(igw.id, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteInternetGateway operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeleteInternetGateway operation: Request would have succeeded, but DryRun flag is set" + ) conn.delete_internet_gateway(igw.id) conn.get_all_internet_gateways().should.have.length_of(0) @@ -177,14 +182,14 @@ def test_igw_delete(): @mock_ec2_deprecated def test_igw_delete_attached(): """ internet gateway fail to delete attached """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw = conn.create_internet_gateway() vpc = conn.create_vpc(VPC_CIDR) conn.attach_internet_gateway(igw.id, vpc.id) with assert_raises(EC2ResponseError) as cm: conn.delete_internet_gateway(igw.id) - cm.exception.code.should.equal('DependencyViolation') + cm.exception.code.should.equal("DependencyViolation") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -192,7 +197,7 @@ def test_igw_delete_attached(): @mock_ec2_deprecated def test_igw_desribe(): """ internet gateway fetch by id """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw = conn.create_internet_gateway() igw_by_search = conn.get_all_internet_gateways([igw.id])[0] igw.id.should.equal(igw_by_search.id) @@ -201,10 +206,10 @@ def test_igw_desribe(): @mock_ec2_deprecated def test_igw_describe_bad_id(): """ internet gateway fail to fetch by bad id """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.get_all_internet_gateways([BAD_IGW]) - cm.exception.code.should.equal('InvalidInternetGatewayID.NotFound') + cm.exception.code.should.equal("InvalidInternetGatewayID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -212,15 +217,14 @@ def test_igw_describe_bad_id(): @mock_ec2_deprecated def test_igw_filter_by_vpc_id(): """ internet gateway filter by vpc id """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw1 = conn.create_internet_gateway() igw2 = conn.create_internet_gateway() vpc = conn.create_vpc(VPC_CIDR) conn.attach_internet_gateway(igw1.id, vpc.id) - result = conn.get_all_internet_gateways( - filters={"attachment.vpc-id": vpc.id}) + result = conn.get_all_internet_gateways(filters={"attachment.vpc-id": vpc.id}) result.should.have.length_of(1) result[0].id.should.equal(igw1.id) @@ -228,7 +232,7 @@ def test_igw_filter_by_vpc_id(): @mock_ec2_deprecated def test_igw_filter_by_tags(): """ internet gateway filter by vpc id """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw1 = conn.create_internet_gateway() igw2 = conn.create_internet_gateway() @@ -242,13 +246,12 @@ def test_igw_filter_by_tags(): @mock_ec2_deprecated def test_igw_filter_by_internet_gateway_id(): """ internet gateway filter by internet gateway id """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw1 = conn.create_internet_gateway() igw2 = conn.create_internet_gateway() - result = conn.get_all_internet_gateways( - filters={"internet-gateway-id": igw1.id}) + result = conn.get_all_internet_gateways(filters={"internet-gateway-id": igw1.id}) result.should.have.length_of(1) result[0].id.should.equal(igw1.id) @@ -256,14 +259,13 @@ def test_igw_filter_by_internet_gateway_id(): @mock_ec2_deprecated def test_igw_filter_by_attachment_state(): """ internet gateway filter by attachment state """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") igw1 = conn.create_internet_gateway() igw2 = conn.create_internet_gateway() vpc = conn.create_vpc(VPC_CIDR) conn.attach_internet_gateway(igw1.id, vpc.id) - result = conn.get_all_internet_gateways( - filters={"attachment.state": "available"}) + result = conn.get_all_internet_gateways(filters={"attachment.state": "available"}) result.should.have.length_of(1) result[0].id.should.equal(igw1.id) diff --git a/tests/test_ec2/test_key_pairs.py b/tests/test_ec2/test_key_pairs.py index dfe6eabdf..d632c2478 100644 --- a/tests/test_ec2/test_key_pairs.py +++ b/tests/test_ec2/test_key_pairs.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -47,116 +48,119 @@ moto@github.com""" @mock_ec2_deprecated def test_key_pairs_empty(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") assert len(conn.get_all_key_pairs()) == 0 @mock_ec2_deprecated def test_key_pairs_invalid_id(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.get_all_key_pairs('foo') - cm.exception.code.should.equal('InvalidKeyPair.NotFound') + conn.get_all_key_pairs("foo") + cm.exception.code.should.equal("InvalidKeyPair.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_key_pairs_create(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: - conn.create_key_pair('foo', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + conn.create_key_pair("foo", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateKeyPair operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateKeyPair operation: Request would have succeeded, but DryRun flag is set" + ) - kp = conn.create_key_pair('foo') + kp = conn.create_key_pair("foo") rsa_check_private_key(kp.material) kps = conn.get_all_key_pairs() assert len(kps) == 1 - assert kps[0].name == 'foo' + assert kps[0].name == "foo" @mock_ec2_deprecated def test_key_pairs_create_two(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - kp1 = conn.create_key_pair('foo') + kp1 = conn.create_key_pair("foo") rsa_check_private_key(kp1.material) - kp2 = conn.create_key_pair('bar') + kp2 = conn.create_key_pair("bar") rsa_check_private_key(kp2.material) assert kp1.material != kp2.material kps = conn.get_all_key_pairs() kps.should.have.length_of(2) - assert {i.name for i in kps} == {'foo', 'bar'} + assert {i.name for i in kps} == {"foo", "bar"} - kps = conn.get_all_key_pairs('foo') + kps = conn.get_all_key_pairs("foo") kps.should.have.length_of(1) - kps[0].name.should.equal('foo') + kps[0].name.should.equal("foo") @mock_ec2_deprecated def test_key_pairs_create_exist(): - conn = boto.connect_ec2('the_key', 'the_secret') - conn.create_key_pair('foo') + conn = boto.connect_ec2("the_key", "the_secret") + conn.create_key_pair("foo") assert len(conn.get_all_key_pairs()) == 1 with assert_raises(EC2ResponseError) as cm: - conn.create_key_pair('foo') - cm.exception.code.should.equal('InvalidKeyPair.Duplicate') + conn.create_key_pair("foo") + cm.exception.code.should.equal("InvalidKeyPair.Duplicate") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_key_pairs_delete_no_exist(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") assert len(conn.get_all_key_pairs()) == 0 - r = conn.delete_key_pair('foo') + r = conn.delete_key_pair("foo") r.should.be.ok @mock_ec2_deprecated def test_key_pairs_delete_exist(): - conn = boto.connect_ec2('the_key', 'the_secret') - conn.create_key_pair('foo') + conn = boto.connect_ec2("the_key", "the_secret") + conn.create_key_pair("foo") with assert_raises(EC2ResponseError) as ex: - r = conn.delete_key_pair('foo', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + r = conn.delete_key_pair("foo", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteKeyPair operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeleteKeyPair operation: Request would have succeeded, but DryRun flag is set" + ) - r = conn.delete_key_pair('foo') + r = conn.delete_key_pair("foo") r.should.be.ok assert len(conn.get_all_key_pairs()) == 0 @mock_ec2_deprecated def test_key_pairs_import(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: - conn.import_key_pair('foo', RSA_PUBLIC_KEY_OPENSSH, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + conn.import_key_pair("foo", RSA_PUBLIC_KEY_OPENSSH, dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the ImportKeyPair operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the ImportKeyPair operation: Request would have succeeded, but DryRun flag is set" + ) - kp1 = conn.import_key_pair('foo', RSA_PUBLIC_KEY_OPENSSH) - assert kp1.name == 'foo' + kp1 = conn.import_key_pair("foo", RSA_PUBLIC_KEY_OPENSSH) + assert kp1.name == "foo" assert kp1.fingerprint == RSA_PUBLIC_KEY_FINGERPRINT - kp2 = conn.import_key_pair('foo2', RSA_PUBLIC_KEY_RFC4716) - assert kp2.name == 'foo2' + kp2 = conn.import_key_pair("foo2", RSA_PUBLIC_KEY_RFC4716) + assert kp2.name == "foo2" assert kp2.fingerprint == RSA_PUBLIC_KEY_FINGERPRINT kps = conn.get_all_key_pairs() @@ -167,58 +171,51 @@ def test_key_pairs_import(): @mock_ec2_deprecated def test_key_pairs_import_exist(): - conn = boto.connect_ec2('the_key', 'the_secret') - kp = conn.import_key_pair('foo', RSA_PUBLIC_KEY_OPENSSH) - assert kp.name == 'foo' + conn = boto.connect_ec2("the_key", "the_secret") + kp = conn.import_key_pair("foo", RSA_PUBLIC_KEY_OPENSSH) + assert kp.name == "foo" assert len(conn.get_all_key_pairs()) == 1 with assert_raises(EC2ResponseError) as cm: - conn.create_key_pair('foo') - cm.exception.code.should.equal('InvalidKeyPair.Duplicate') + conn.create_key_pair("foo") + cm.exception.code.should.equal("InvalidKeyPair.Duplicate") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_key_pairs_invalid(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: - conn.import_key_pair('foo', b'') - ex.exception.error_code.should.equal('InvalidKeyPair.Format') + conn.import_key_pair("foo", b"") + ex.exception.error_code.should.equal("InvalidKeyPair.Format") ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'Key is not in valid OpenSSH public key format') + ex.exception.message.should.equal("Key is not in valid OpenSSH public key format") with assert_raises(EC2ResponseError) as ex: - conn.import_key_pair('foo', b'garbage') - ex.exception.error_code.should.equal('InvalidKeyPair.Format') + conn.import_key_pair("foo", b"garbage") + ex.exception.error_code.should.equal("InvalidKeyPair.Format") ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'Key is not in valid OpenSSH public key format') + ex.exception.message.should.equal("Key is not in valid OpenSSH public key format") with assert_raises(EC2ResponseError) as ex: - conn.import_key_pair('foo', DSA_PUBLIC_KEY_OPENSSH) - ex.exception.error_code.should.equal('InvalidKeyPair.Format') + conn.import_key_pair("foo", DSA_PUBLIC_KEY_OPENSSH) + ex.exception.error_code.should.equal("InvalidKeyPair.Format") ex.exception.status.should.equal(400) - ex.exception.message.should.equal( - 'Key is not in valid OpenSSH public key format') + ex.exception.message.should.equal("Key is not in valid OpenSSH public key format") @mock_ec2_deprecated def test_key_pair_filters(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") - _ = conn.create_key_pair('kpfltr1') - kp2 = conn.create_key_pair('kpfltr2') - kp3 = conn.create_key_pair('kpfltr3') + _ = conn.create_key_pair("kpfltr1") + kp2 = conn.create_key_pair("kpfltr2") + kp3 = conn.create_key_pair("kpfltr3") - kp_by_name = conn.get_all_key_pairs( - filters={'key-name': 'kpfltr2'}) - set([kp.name for kp in kp_by_name] - ).should.equal(set([kp2.name])) + kp_by_name = conn.get_all_key_pairs(filters={"key-name": "kpfltr2"}) + set([kp.name for kp in kp_by_name]).should.equal(set([kp2.name])) - kp_by_name = conn.get_all_key_pairs( - filters={'fingerprint': kp3.fingerprint}) - set([kp.name for kp in kp_by_name] - ).should.equal(set([kp3.name])) + kp_by_name = conn.get_all_key_pairs(filters={"fingerprint": kp3.fingerprint}) + set([kp.name for kp in kp_by_name]).should.equal(set([kp3.name])) diff --git a/tests/test_ec2/test_launch_templates.py b/tests/test_ec2/test_launch_templates.py index 87e1d3986..4c37818d1 100644 --- a/tests/test_ec2/test_launch_templates.py +++ b/tests/test_ec2/test_launch_templates.py @@ -13,16 +13,14 @@ def test_launch_template_create(): resp = cli.create_launch_template( LaunchTemplateName="test-template", - # the absolute minimum needed to create a template without other resources LaunchTemplateData={ - "TagSpecifications": [{ - "ResourceType": "instance", - "Tags": [{ - "Key": "test", - "Value": "value", - }], - }], + "TagSpecifications": [ + { + "ResourceType": "instance", + "Tags": [{"Key": "test", "Value": "value"}], + } + ] }, ) @@ -36,18 +34,18 @@ def test_launch_template_create(): cli.create_launch_template( LaunchTemplateName="test-template", LaunchTemplateData={ - "TagSpecifications": [{ - "ResourceType": "instance", - "Tags": [{ - "Key": "test", - "Value": "value", - }], - }], + "TagSpecifications": [ + { + "ResourceType": "instance", + "Tags": [{"Key": "test", "Value": "value"}], + } + ] }, ) str(ex.exception).should.equal( - 'An error occurred (InvalidLaunchTemplateName.AlreadyExistsException) when calling the CreateLaunchTemplate operation: Launch template name already in use.') + "An error occurred (InvalidLaunchTemplateName.AlreadyExistsException) when calling the CreateLaunchTemplate operation: Launch template name already in use." + ) @mock_ec2 @@ -55,29 +53,22 @@ def test_describe_launch_template_versions(): template_data = { "ImageId": "ami-abc123", "DisableApiTermination": False, - "TagSpecifications": [{ - "ResourceType": "instance", - "Tags": [{ - "Key": "test", - "Value": "value", - }], - }], - "SecurityGroupIds": [ - "sg-1234", - "sg-ab5678", + "TagSpecifications": [ + {"ResourceType": "instance", "Tags": [{"Key": "test", "Value": "value"}]} ], + "SecurityGroupIds": ["sg-1234", "sg-ab5678"], } cli = boto3.client("ec2", region_name="us-east-1") create_resp = cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData=template_data) + LaunchTemplateName="test-template", LaunchTemplateData=template_data + ) # test using name resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - Versions=['1']) + LaunchTemplateName="test-template", Versions=["1"] + ) templ = resp["LaunchTemplateVersions"][0]["LaunchTemplateData"] templ.should.equal(template_data) @@ -85,7 +76,8 @@ def test_describe_launch_template_versions(): # test using id resp = cli.describe_launch_template_versions( LaunchTemplateId=create_resp["LaunchTemplate"]["LaunchTemplateId"], - Versions=['1']) + Versions=["1"], + ) templ = resp["LaunchTemplateVersions"][0]["LaunchTemplateData"] templ.should.equal(template_data) @@ -96,22 +88,21 @@ def test_create_launch_template_version(): cli = boto3.client("ec2", region_name="us-east-1") create_resp = cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) version_resp = cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) version_resp.should.have.key("LaunchTemplateVersion") version = version_resp["LaunchTemplateVersion"] version["DefaultVersion"].should.equal(False) - version["LaunchTemplateId"].should.equal(create_resp["LaunchTemplate"]["LaunchTemplateId"]) + version["LaunchTemplateId"].should.equal( + create_resp["LaunchTemplate"]["LaunchTemplateId"] + ) version["VersionDescription"].should.equal("new ami") version["VersionNumber"].should.equal(2) @@ -121,22 +112,21 @@ def test_create_launch_template_version_by_id(): cli = boto3.client("ec2", region_name="us-east-1") create_resp = cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) version_resp = cli.create_launch_template_version( LaunchTemplateId=create_resp["LaunchTemplate"]["LaunchTemplateId"], - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) version_resp.should.have.key("LaunchTemplateVersion") version = version_resp["LaunchTemplateVersion"] version["DefaultVersion"].should.equal(False) - version["LaunchTemplateId"].should.equal(create_resp["LaunchTemplate"]["LaunchTemplateId"]) + version["LaunchTemplateId"].should.equal( + create_resp["LaunchTemplate"]["LaunchTemplateId"] + ) version["VersionDescription"].should.equal("new ami") version["VersionNumber"].should.equal(2) @@ -146,24 +136,24 @@ def test_describe_launch_template_versions_with_multiple_versions(): cli = boto3.client("ec2", region_name="us-east-1") cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) - resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template") + resp = cli.describe_launch_template_versions(LaunchTemplateName="test-template") resp["LaunchTemplateVersions"].should.have.length_of(2) - resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal("ami-abc123") - resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal("ami-def456") + resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-abc123" + ) + resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-def456" + ) @mock_ec2 @@ -171,32 +161,32 @@ def test_describe_launch_template_versions_with_versions_option(): cli = boto3.client("ec2", region_name="us-east-1") cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-hij789" - }, - VersionDescription="new ami, again") + LaunchTemplateData={"ImageId": "ami-hij789"}, + VersionDescription="new ami, again", + ) resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - Versions=["2", "3"]) + LaunchTemplateName="test-template", Versions=["2", "3"] + ) resp["LaunchTemplateVersions"].should.have.length_of(2) - resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal("ami-def456") - resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal("ami-hij789") + resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-def456" + ) + resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-hij789" + ) @mock_ec2 @@ -204,32 +194,32 @@ def test_describe_launch_template_versions_with_min(): cli = boto3.client("ec2", region_name="us-east-1") cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-hij789" - }, - VersionDescription="new ami, again") + LaunchTemplateData={"ImageId": "ami-hij789"}, + VersionDescription="new ami, again", + ) resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - MinVersion="2") + LaunchTemplateName="test-template", MinVersion="2" + ) resp["LaunchTemplateVersions"].should.have.length_of(2) - resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal("ami-def456") - resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal("ami-hij789") + resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-def456" + ) + resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-hij789" + ) @mock_ec2 @@ -237,32 +227,32 @@ def test_describe_launch_template_versions_with_max(): cli = boto3.client("ec2", region_name="us-east-1") cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-hij789" - }, - VersionDescription="new ami, again") + LaunchTemplateData={"ImageId": "ami-hij789"}, + VersionDescription="new ami, again", + ) resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - MaxVersion="2") + LaunchTemplateName="test-template", MaxVersion="2" + ) resp["LaunchTemplateVersions"].should.have.length_of(2) - resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal("ami-abc123") - resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal("ami-def456") + resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-abc123" + ) + resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-def456" + ) @mock_ec2 @@ -270,40 +260,38 @@ def test_describe_launch_template_versions_with_min_and_max(): cli = boto3.client("ec2", region_name="us-east-1") cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-def456" - }, - VersionDescription="new ami") + LaunchTemplateData={"ImageId": "ami-def456"}, + VersionDescription="new ami", + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-hij789" - }, - VersionDescription="new ami, again") + LaunchTemplateData={"ImageId": "ami-hij789"}, + VersionDescription="new ami, again", + ) cli.create_launch_template_version( LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-345abc" - }, - VersionDescription="new ami, because why not") + LaunchTemplateData={"ImageId": "ami-345abc"}, + VersionDescription="new ami, because why not", + ) resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - MinVersion="2", - MaxVersion="3") + LaunchTemplateName="test-template", MinVersion="2", MaxVersion="3" + ) resp["LaunchTemplateVersions"].should.have.length_of(2) - resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal("ami-def456") - resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal("ami-hij789") + resp["LaunchTemplateVersions"][0]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-def456" + ) + resp["LaunchTemplateVersions"][1]["LaunchTemplateData"]["ImageId"].should.equal( + "ami-hij789" + ) @mock_ec2 @@ -312,17 +300,14 @@ def test_describe_launch_templates(): lt_ids = [] r = cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) lt_ids.append(r["LaunchTemplate"]["LaunchTemplateId"]) r = cli.create_launch_template( LaunchTemplateName="test-template2", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateData={"ImageId": "ami-abc123"}, + ) lt_ids.append(r["LaunchTemplate"]["LaunchTemplateId"]) # general call, all templates @@ -334,7 +319,8 @@ def test_describe_launch_templates(): # filter by names resp = cli.describe_launch_templates( - LaunchTemplateNames=["test-template2", "test-template"]) + LaunchTemplateNames=["test-template2", "test-template"] + ) resp.should.have.key("LaunchTemplates") resp["LaunchTemplates"].should.have.length_of(2) resp["LaunchTemplates"][0]["LaunchTemplateName"].should.equal("test-template2") @@ -353,34 +339,31 @@ def test_describe_launch_templates_with_filters(): cli = boto3.client("ec2", region_name="us-east-1") r = cli.create_launch_template( - LaunchTemplateName="test-template", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"} + ) cli.create_tags( Resources=[r["LaunchTemplate"]["LaunchTemplateId"]], Tags=[ {"Key": "tag1", "Value": "a value"}, {"Key": "another-key", "Value": "this value"}, - ]) + ], + ) cli.create_launch_template( - LaunchTemplateName="no-tags", - LaunchTemplateData={ - "ImageId": "ami-abc123" - }) + LaunchTemplateName="no-tags", LaunchTemplateData={"ImageId": "ami-abc123"} + ) - resp = cli.describe_launch_templates(Filters=[{ - "Name": "tag:tag1", "Values": ["a value"] - }]) + resp = cli.describe_launch_templates( + Filters=[{"Name": "tag:tag1", "Values": ["a value"]}] + ) resp["LaunchTemplates"].should.have.length_of(1) resp["LaunchTemplates"][0]["LaunchTemplateName"].should.equal("test-template") - resp = cli.describe_launch_templates(Filters=[{ - "Name": "launch-template-name", "Values": ["no-tags"] - }]) + resp = cli.describe_launch_templates( + Filters=[{"Name": "launch-template-name", "Values": ["no-tags"]}] + ) resp["LaunchTemplates"].should.have.length_of(1) resp["LaunchTemplates"][0]["LaunchTemplateName"].should.equal("no-tags") @@ -392,24 +375,18 @@ def test_create_launch_template_with_tag_spec(): cli.create_launch_template( LaunchTemplateName="test-template", LaunchTemplateData={"ImageId": "ami-abc123"}, - TagSpecifications=[{ - "ResourceType": "instance", - "Tags": [ - {"Key": "key", "Value": "value"} - ] - }], + TagSpecifications=[ + {"ResourceType": "instance", "Tags": [{"Key": "key", "Value": "value"}]} + ], ) resp = cli.describe_launch_template_versions( - LaunchTemplateName="test-template", - Versions=["1"]) + LaunchTemplateName="test-template", Versions=["1"] + ) version = resp["LaunchTemplateVersions"][0] version["LaunchTemplateData"].should.have.key("TagSpecifications") version["LaunchTemplateData"]["TagSpecifications"].should.have.length_of(1) - version["LaunchTemplateData"]["TagSpecifications"][0].should.equal({ - "ResourceType": "instance", - "Tags": [ - {"Key": "key", "Value": "value"} - ] - }) + version["LaunchTemplateData"]["TagSpecifications"][0].should.equal( + {"ResourceType": "instance", "Tags": [{"Key": "key", "Value": "value"}]} + ) diff --git a/tests/test_ec2/test_nat_gateway.py b/tests/test_ec2/test_nat_gateway.py index 27e8753be..484088356 100644 --- a/tests/test_ec2/test_nat_gateway.py +++ b/tests/test_ec2/test_nat_gateway.py @@ -6,104 +6,99 @@ from moto import mock_ec2 @mock_ec2 def test_describe_nat_gateways(): - conn = boto3.client('ec2', 'us-east-1') + conn = boto3.client("ec2", "us-east-1") response = conn.describe_nat_gateways() - response['NatGateways'].should.have.length_of(0) + response["NatGateways"].should.have.length_of(0) @mock_ec2 def test_create_nat_gateway(): - conn = boto3.client('ec2', 'us-east-1') - vpc = conn.create_vpc(CidrBlock='10.0.0.0/16') - vpc_id = vpc['Vpc']['VpcId'] + conn = boto3.client("ec2", "us-east-1") + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16") + vpc_id = vpc["Vpc"]["VpcId"] subnet = conn.create_subnet( - VpcId=vpc_id, - CidrBlock='10.0.1.0/27', - AvailabilityZone='us-east-1a', + VpcId=vpc_id, CidrBlock="10.0.1.0/27", AvailabilityZone="us-east-1a" ) - allocation_id = conn.allocate_address(Domain='vpc')['AllocationId'] - subnet_id = subnet['Subnet']['SubnetId'] + allocation_id = conn.allocate_address(Domain="vpc")["AllocationId"] + subnet_id = subnet["Subnet"]["SubnetId"] - response = conn.create_nat_gateway( - SubnetId=subnet_id, - AllocationId=allocation_id, - ) + response = conn.create_nat_gateway(SubnetId=subnet_id, AllocationId=allocation_id) - response['NatGateway']['VpcId'].should.equal(vpc_id) - response['NatGateway']['SubnetId'].should.equal(subnet_id) - response['NatGateway']['State'].should.equal('available') + response["NatGateway"]["VpcId"].should.equal(vpc_id) + response["NatGateway"]["SubnetId"].should.equal(subnet_id) + response["NatGateway"]["State"].should.equal("available") @mock_ec2 def test_delete_nat_gateway(): - conn = boto3.client('ec2', 'us-east-1') - vpc = conn.create_vpc(CidrBlock='10.0.0.0/16') - vpc_id = vpc['Vpc']['VpcId'] + conn = boto3.client("ec2", "us-east-1") + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16") + vpc_id = vpc["Vpc"]["VpcId"] subnet = conn.create_subnet( - VpcId=vpc_id, - CidrBlock='10.0.1.0/27', - AvailabilityZone='us-east-1a', + VpcId=vpc_id, CidrBlock="10.0.1.0/27", AvailabilityZone="us-east-1a" ) - allocation_id = conn.allocate_address(Domain='vpc')['AllocationId'] - subnet_id = subnet['Subnet']['SubnetId'] + allocation_id = conn.allocate_address(Domain="vpc")["AllocationId"] + subnet_id = subnet["Subnet"]["SubnetId"] nat_gateway = conn.create_nat_gateway( - SubnetId=subnet_id, - AllocationId=allocation_id, + SubnetId=subnet_id, AllocationId=allocation_id ) - nat_gateway_id = nat_gateway['NatGateway']['NatGatewayId'] + nat_gateway_id = nat_gateway["NatGateway"]["NatGatewayId"] response = conn.delete_nat_gateway(NatGatewayId=nat_gateway_id) # this is hard to match against, so remove it - response['ResponseMetadata'].pop('HTTPHeaders', None) - response['ResponseMetadata'].pop('RetryAttempts', None) - response.should.equal({ - 'NatGatewayId': nat_gateway_id, - 'ResponseMetadata': { - 'HTTPStatusCode': 200, - 'RequestId': '741fc8ab-6ebe-452b-b92b-example' + response["ResponseMetadata"].pop("HTTPHeaders", None) + response["ResponseMetadata"].pop("RetryAttempts", None) + response.should.equal( + { + "NatGatewayId": nat_gateway_id, + "ResponseMetadata": { + "HTTPStatusCode": 200, + "RequestId": "741fc8ab-6ebe-452b-b92b-example", + }, } - }) + ) @mock_ec2 def test_create_and_describe_nat_gateway(): - conn = boto3.client('ec2', 'us-east-1') - vpc = conn.create_vpc(CidrBlock='10.0.0.0/16') - vpc_id = vpc['Vpc']['VpcId'] + conn = boto3.client("ec2", "us-east-1") + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16") + vpc_id = vpc["Vpc"]["VpcId"] subnet = conn.create_subnet( - VpcId=vpc_id, - CidrBlock='10.0.1.0/27', - AvailabilityZone='us-east-1a', + VpcId=vpc_id, CidrBlock="10.0.1.0/27", AvailabilityZone="us-east-1a" ) - allocation_id = conn.allocate_address(Domain='vpc')['AllocationId'] - subnet_id = subnet['Subnet']['SubnetId'] + allocation_id = conn.allocate_address(Domain="vpc")["AllocationId"] + subnet_id = subnet["Subnet"]["SubnetId"] create_response = conn.create_nat_gateway( - SubnetId=subnet_id, - AllocationId=allocation_id, + SubnetId=subnet_id, AllocationId=allocation_id ) - nat_gateway_id = create_response['NatGateway']['NatGatewayId'] + nat_gateway_id = create_response["NatGateway"]["NatGatewayId"] describe_response = conn.describe_nat_gateways() - enis = conn.describe_network_interfaces()['NetworkInterfaces'] - eni_id = enis[0]['NetworkInterfaceId'] - public_ip = conn.describe_addresses(AllocationIds=[allocation_id])[ - 'Addresses'][0]['PublicIp'] + enis = conn.describe_network_interfaces()["NetworkInterfaces"] + eni_id = enis[0]["NetworkInterfaceId"] + public_ip = conn.describe_addresses(AllocationIds=[allocation_id])["Addresses"][0][ + "PublicIp" + ] - describe_response['NatGateways'].should.have.length_of(1) - describe_response['NatGateways'][0][ - 'NatGatewayId'].should.equal(nat_gateway_id) - describe_response['NatGateways'][0]['State'].should.equal('available') - describe_response['NatGateways'][0]['SubnetId'].should.equal(subnet_id) - describe_response['NatGateways'][0]['VpcId'].should.equal(vpc_id) - describe_response['NatGateways'][0]['NatGatewayAddresses'][ - 0]['AllocationId'].should.equal(allocation_id) - describe_response['NatGateways'][0]['NatGatewayAddresses'][ - 0]['NetworkInterfaceId'].should.equal(eni_id) - assert describe_response['NatGateways'][0][ - 'NatGatewayAddresses'][0]['PrivateIp'].startswith('10.') - describe_response['NatGateways'][0]['NatGatewayAddresses'][ - 0]['PublicIp'].should.equal(public_ip) + describe_response["NatGateways"].should.have.length_of(1) + describe_response["NatGateways"][0]["NatGatewayId"].should.equal(nat_gateway_id) + describe_response["NatGateways"][0]["State"].should.equal("available") + describe_response["NatGateways"][0]["SubnetId"].should.equal(subnet_id) + describe_response["NatGateways"][0]["VpcId"].should.equal(vpc_id) + describe_response["NatGateways"][0]["NatGatewayAddresses"][0][ + "AllocationId" + ].should.equal(allocation_id) + describe_response["NatGateways"][0]["NatGatewayAddresses"][0][ + "NetworkInterfaceId" + ].should.equal(eni_id) + assert describe_response["NatGateways"][0]["NatGatewayAddresses"][0][ + "PrivateIp" + ].startswith("10.") + describe_response["NatGateways"][0]["NatGatewayAddresses"][0][ + "PublicIp" + ].should.equal(public_ip) diff --git a/tests/test_ec2/test_network_acls.py b/tests/test_ec2/test_network_acls.py index 1c69624bf..fb62f7178 100644 --- a/tests/test_ec2/test_network_acls.py +++ b/tests/test_ec2/test_network_acls.py @@ -10,7 +10,7 @@ from moto import mock_ec2_deprecated, mock_ec2 @mock_ec2_deprecated def test_default_network_acl_created_with_vpc(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") all_network_acls = conn.get_all_network_acls() all_network_acls.should.have.length_of(2) @@ -18,7 +18,7 @@ def test_default_network_acl_created_with_vpc(): @mock_ec2_deprecated def test_network_acls(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") network_acl = conn.create_network_acl(vpc.id) all_network_acls = conn.get_all_network_acls() @@ -27,7 +27,7 @@ def test_network_acls(): @mock_ec2_deprecated def test_new_subnet_associates_with_default_network_acl(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.get_all_vpcs()[0] subnet = conn.create_subnet(vpc.id, "172.31.112.0/20") @@ -41,88 +41,100 @@ def test_new_subnet_associates_with_default_network_acl(): @mock_ec2_deprecated def test_network_acl_entries(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") network_acl = conn.create_network_acl(vpc.id) network_acl_entry = conn.create_network_acl_entry( - network_acl.id, 110, 6, - 'ALLOW', '0.0.0.0/0', False, - port_range_from='443', - port_range_to='443' + network_acl.id, + 110, + 6, + "ALLOW", + "0.0.0.0/0", + False, + port_range_from="443", + port_range_to="443", ) all_network_acls = conn.get_all_network_acls() all_network_acls.should.have.length_of(3) - test_network_acl = next(na for na in all_network_acls - if na.id == network_acl.id) + test_network_acl = next(na for na in all_network_acls if na.id == network_acl.id) entries = test_network_acl.network_acl_entries entries.should.have.length_of(1) - entries[0].rule_number.should.equal('110') - entries[0].protocol.should.equal('6') - entries[0].rule_action.should.equal('ALLOW') + entries[0].rule_number.should.equal("110") + entries[0].protocol.should.equal("6") + entries[0].rule_action.should.equal("ALLOW") @mock_ec2_deprecated def test_delete_network_acl_entry(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") network_acl = conn.create_network_acl(vpc.id) conn.create_network_acl_entry( - network_acl.id, 110, 6, - 'ALLOW', '0.0.0.0/0', False, - port_range_from='443', - port_range_to='443' - ) - conn.delete_network_acl_entry( - network_acl.id, 110, False + network_acl.id, + 110, + 6, + "ALLOW", + "0.0.0.0/0", + False, + port_range_from="443", + port_range_to="443", ) + conn.delete_network_acl_entry(network_acl.id, 110, False) all_network_acls = conn.get_all_network_acls() - test_network_acl = next(na for na in all_network_acls - if na.id == network_acl.id) + test_network_acl = next(na for na in all_network_acls if na.id == network_acl.id) entries = test_network_acl.network_acl_entries entries.should.have.length_of(0) @mock_ec2_deprecated def test_replace_network_acl_entry(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") network_acl = conn.create_network_acl(vpc.id) conn.create_network_acl_entry( - network_acl.id, 110, 6, - 'ALLOW', '0.0.0.0/0', False, - port_range_from='443', - port_range_to='443' + network_acl.id, + 110, + 6, + "ALLOW", + "0.0.0.0/0", + False, + port_range_from="443", + port_range_to="443", ) conn.replace_network_acl_entry( - network_acl.id, 110, -1, - 'DENY', '0.0.0.0/0', False, - port_range_from='22', - port_range_to='22' + network_acl.id, + 110, + -1, + "DENY", + "0.0.0.0/0", + False, + port_range_from="22", + port_range_to="22", ) all_network_acls = conn.get_all_network_acls() - test_network_acl = next(na for na in all_network_acls - if na.id == network_acl.id) + test_network_acl = next(na for na in all_network_acls if na.id == network_acl.id) entries = test_network_acl.network_acl_entries entries.should.have.length_of(1) - entries[0].rule_number.should.equal('110') - entries[0].protocol.should.equal('-1') - entries[0].rule_action.should.equal('DENY') + entries[0].rule_number.should.equal("110") + entries[0].protocol.should.equal("-1") + entries[0].rule_action.should.equal("DENY") + @mock_ec2_deprecated def test_associate_new_network_acl_with_subnet(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") network_acl = conn.create_network_acl(vpc.id) @@ -132,8 +144,7 @@ def test_associate_new_network_acl_with_subnet(): all_network_acls = conn.get_all_network_acls() all_network_acls.should.have.length_of(3) - test_network_acl = next(na for na in all_network_acls - if na.id == network_acl.id) + test_network_acl = next(na for na in all_network_acls if na.id == network_acl.id) test_network_acl.associations.should.have.length_of(1) test_network_acl.associations[0].subnet_id.should.equal(subnet.id) @@ -141,7 +152,7 @@ def test_associate_new_network_acl_with_subnet(): @mock_ec2_deprecated def test_delete_network_acl(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") network_acl = conn.create_network_acl(vpc.id) @@ -161,7 +172,7 @@ def test_delete_network_acl(): @mock_ec2_deprecated def test_network_acl_tagging(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") network_acl = conn.create_network_acl(vpc.id) @@ -172,46 +183,45 @@ def test_network_acl_tagging(): tag.value.should.equal("some value") all_network_acls = conn.get_all_network_acls() - test_network_acl = next(na for na in all_network_acls - if na.id == network_acl.id) + test_network_acl = next(na for na in all_network_acls if na.id == network_acl.id) test_network_acl.tags.should.have.length_of(1) test_network_acl.tags["a key"].should.equal("some value") @mock_ec2 def test_new_subnet_in_new_vpc_associates_with_default_network_acl(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - new_vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + ec2 = boto3.resource("ec2", region_name="us-west-1") + new_vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") new_vpc.reload() - subnet = ec2.create_subnet(VpcId=new_vpc.id, CidrBlock='10.0.0.0/24') + subnet = ec2.create_subnet(VpcId=new_vpc.id, CidrBlock="10.0.0.0/24") subnet.reload() new_vpcs_default_network_acl = next(iter(new_vpc.network_acls.all()), None) new_vpcs_default_network_acl.reload() new_vpcs_default_network_acl.vpc_id.should.equal(new_vpc.id) new_vpcs_default_network_acl.associations.should.have.length_of(1) - new_vpcs_default_network_acl.associations[0]['SubnetId'].should.equal(subnet.id) + new_vpcs_default_network_acl.associations[0]["SubnetId"].should.equal(subnet.id) @mock_ec2 def test_default_network_acl_default_entries(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") default_network_acl = next(iter(ec2.network_acls.all()), None) default_network_acl.is_default.should.be.ok default_network_acl.entries.should.have.length_of(4) unique_entries = [] for entry in default_network_acl.entries: - entry['CidrBlock'].should.equal('0.0.0.0/0') - entry['Protocol'].should.equal('-1') - entry['RuleNumber'].should.be.within([100, 32767]) - entry['RuleAction'].should.be.within(['allow', 'deny']) - assert type(entry['Egress']) is bool - if entry['RuleAction'] == 'allow': - entry['RuleNumber'].should.be.equal(100) + entry["CidrBlock"].should.equal("0.0.0.0/0") + entry["Protocol"].should.equal("-1") + entry["RuleNumber"].should.be.within([100, 32767]) + entry["RuleAction"].should.be.within(["allow", "deny"]) + assert type(entry["Egress"]) is bool + if entry["RuleAction"] == "allow": + entry["RuleNumber"].should.be.equal(100) else: - entry['RuleNumber'].should.be.equal(32767) + entry["RuleNumber"].should.be.equal(32767) if entry not in unique_entries: unique_entries.append(entry) @@ -220,33 +230,48 @@ def test_default_network_acl_default_entries(): @mock_ec2 def test_delete_default_network_acl_default_entry(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") default_network_acl = next(iter(ec2.network_acls.all()), None) default_network_acl.is_default.should.be.ok default_network_acl.entries.should.have.length_of(4) first_default_network_acl_entry = default_network_acl.entries[0] - default_network_acl.delete_entry(Egress=first_default_network_acl_entry['Egress'], - RuleNumber=first_default_network_acl_entry['RuleNumber']) + default_network_acl.delete_entry( + Egress=first_default_network_acl_entry["Egress"], + RuleNumber=first_default_network_acl_entry["RuleNumber"], + ) default_network_acl.entries.should.have.length_of(3) @mock_ec2 def test_duplicate_network_acl_entry(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") default_network_acl = next(iter(ec2.network_acls.all()), None) default_network_acl.is_default.should.be.ok rule_number = 200 egress = True - default_network_acl.create_entry(CidrBlock="0.0.0.0/0", Egress=egress, Protocol="-1", RuleAction="allow", RuleNumber=rule_number) + default_network_acl.create_entry( + CidrBlock="0.0.0.0/0", + Egress=egress, + Protocol="-1", + RuleAction="allow", + RuleNumber=rule_number, + ) with assert_raises(ClientError) as ex: - default_network_acl.create_entry(CidrBlock="10.0.0.0/0", Egress=egress, Protocol="-1", RuleAction="deny", RuleNumber=rule_number) + default_network_acl.create_entry( + CidrBlock="10.0.0.0/0", + Egress=egress, + Protocol="-1", + RuleAction="deny", + RuleNumber=rule_number, + ) str(ex.exception).should.equal( "An error occurred (NetworkAclEntryAlreadyExists) when calling the CreateNetworkAclEntry " - "operation: The network acl entry identified by {} already exists.".format(rule_number)) - - + "operation: The network acl entry identified by {} already exists.".format( + rule_number + ) + ) diff --git a/tests/test_ec2/test_regions.py b/tests/test_ec2/test_regions.py index f94c78eaf..551b739f2 100644 --- a/tests/test_ec2/test_regions.py +++ b/tests/test_ec2/test_regions.py @@ -7,38 +7,41 @@ from moto import mock_ec2_deprecated, mock_autoscaling_deprecated, mock_elb_depr from moto.ec2 import ec2_backends + def test_use_boto_regions(): boto_regions = {r.name for r in boto.ec2.regions()} moto_regions = set(ec2_backends) moto_regions.should.equal(boto_regions) + def add_servers_to_region(ami_id, count, region): conn = boto.ec2.connect_to_region(region) for index in range(count): conn.run_instances(ami_id) + @mock_ec2_deprecated def test_add_servers_to_a_single_region(): - region = 'ap-northeast-1' - add_servers_to_region('ami-1234abcd', 1, region) - add_servers_to_region('ami-5678efgh', 1, region) + region = "ap-northeast-1" + add_servers_to_region("ami-1234abcd", 1, region) + add_servers_to_region("ami-5678efgh", 1, region) conn = boto.ec2.connect_to_region(region) reservations = conn.get_all_instances() len(reservations).should.equal(2) reservations.sort(key=lambda x: x.instances[0].image_id) - reservations[0].instances[0].image_id.should.equal('ami-1234abcd') - reservations[1].instances[0].image_id.should.equal('ami-5678efgh') + reservations[0].instances[0].image_id.should.equal("ami-1234abcd") + reservations[1].instances[0].image_id.should.equal("ami-5678efgh") @mock_ec2_deprecated def test_add_servers_to_multiple_regions(): - region1 = 'us-east-1' - region2 = 'ap-northeast-1' - add_servers_to_region('ami-1234abcd', 1, region1) - add_servers_to_region('ami-5678efgh', 1, region2) + region1 = "us-east-1" + region2 = "ap-northeast-1" + add_servers_to_region("ami-1234abcd", 1, region1) + add_servers_to_region("ami-5678efgh", 1, region2) us_conn = boto.ec2.connect_to_region(region1) ap_conn = boto.ec2.connect_to_region(region2) @@ -48,33 +51,35 @@ def test_add_servers_to_multiple_regions(): len(us_reservations).should.equal(1) len(ap_reservations).should.equal(1) - us_reservations[0].instances[0].image_id.should.equal('ami-1234abcd') - ap_reservations[0].instances[0].image_id.should.equal('ami-5678efgh') + us_reservations[0].instances[0].image_id.should.equal("ami-1234abcd") + ap_reservations[0].instances[0].image_id.should.equal("ami-5678efgh") @mock_autoscaling_deprecated @mock_elb_deprecated def test_create_autoscaling_group(): - elb_conn = boto.ec2.elb.connect_to_region('us-east-1') + elb_conn = boto.ec2.elb.connect_to_region("us-east-1") elb_conn.create_load_balancer( - 'us_test_lb', zones=[], listeners=[(80, 8080, 'http')]) - elb_conn = boto.ec2.elb.connect_to_region('ap-northeast-1') + "us_test_lb", zones=[], listeners=[(80, 8080, "http")] + ) + elb_conn = boto.ec2.elb.connect_to_region("ap-northeast-1") elb_conn.create_load_balancer( - 'ap_test_lb', zones=[], listeners=[(80, 8080, 'http')]) + "ap_test_lb", zones=[], listeners=[(80, 8080, "http")] + ) - us_conn = boto.ec2.autoscale.connect_to_region('us-east-1') + us_conn = boto.ec2.autoscale.connect_to_region("us-east-1") config = boto.ec2.autoscale.LaunchConfiguration( - name='us_tester', - image_id='ami-abcd1234', - instance_type='m1.small', + name="us_tester", image_id="ami-abcd1234", instance_type="m1.small" ) x = us_conn.create_launch_configuration(config) - us_subnet_id = list(ec2_backends['us-east-1'].subnets['us-east-1c'].keys())[0] - ap_subnet_id = list(ec2_backends['ap-northeast-1'].subnets['ap-northeast-1a'].keys())[0] + us_subnet_id = list(ec2_backends["us-east-1"].subnets["us-east-1c"].keys())[0] + ap_subnet_id = list( + ec2_backends["ap-northeast-1"].subnets["ap-northeast-1a"].keys() + )[0] group = boto.ec2.autoscale.AutoScalingGroup( - name='us_tester_group', - availability_zones=['us-east-1c'], + name="us_tester_group", + availability_zones=["us-east-1c"], default_cooldown=60, desired_capacity=2, health_check_period=100, @@ -89,17 +94,15 @@ def test_create_autoscaling_group(): ) us_conn.create_auto_scaling_group(group) - ap_conn = boto.ec2.autoscale.connect_to_region('ap-northeast-1') + ap_conn = boto.ec2.autoscale.connect_to_region("ap-northeast-1") config = boto.ec2.autoscale.LaunchConfiguration( - name='ap_tester', - image_id='ami-efgh5678', - instance_type='m1.small', + name="ap_tester", image_id="ami-efgh5678", instance_type="m1.small" ) ap_conn.create_launch_configuration(config) group = boto.ec2.autoscale.AutoScalingGroup( - name='ap_tester_group', - availability_zones=['ap-northeast-1a'], + name="ap_tester_group", + availability_zones=["ap-northeast-1a"], default_cooldown=60, desired_capacity=2, health_check_period=100, @@ -118,33 +121,35 @@ def test_create_autoscaling_group(): len(ap_conn.get_all_groups()).should.equal(1) us_group = us_conn.get_all_groups()[0] - us_group.name.should.equal('us_tester_group') - list(us_group.availability_zones).should.equal(['us-east-1c']) + us_group.name.should.equal("us_tester_group") + list(us_group.availability_zones).should.equal(["us-east-1c"]) us_group.desired_capacity.should.equal(2) us_group.max_size.should.equal(2) us_group.min_size.should.equal(2) us_group.vpc_zone_identifier.should.equal(us_subnet_id) - us_group.launch_config_name.should.equal('us_tester') + us_group.launch_config_name.should.equal("us_tester") us_group.default_cooldown.should.equal(60) us_group.health_check_period.should.equal(100) us_group.health_check_type.should.equal("EC2") list(us_group.load_balancers).should.equal(["us_test_lb"]) us_group.placement_group.should.equal("us_test_placement") list(us_group.termination_policies).should.equal( - ["OldestInstance", "NewestInstance"]) + ["OldestInstance", "NewestInstance"] + ) ap_group = ap_conn.get_all_groups()[0] - ap_group.name.should.equal('ap_tester_group') - list(ap_group.availability_zones).should.equal(['ap-northeast-1a']) + ap_group.name.should.equal("ap_tester_group") + list(ap_group.availability_zones).should.equal(["ap-northeast-1a"]) ap_group.desired_capacity.should.equal(2) ap_group.max_size.should.equal(2) ap_group.min_size.should.equal(2) ap_group.vpc_zone_identifier.should.equal(ap_subnet_id) - ap_group.launch_config_name.should.equal('ap_tester') + ap_group.launch_config_name.should.equal("ap_tester") ap_group.default_cooldown.should.equal(60) ap_group.health_check_period.should.equal(100) ap_group.health_check_type.should.equal("EC2") list(ap_group.load_balancers).should.equal(["ap_test_lb"]) ap_group.placement_group.should.equal("ap_test_placement") list(ap_group.termination_policies).should.equal( - ["OldestInstance", "NewestInstance"]) + ["OldestInstance", "NewestInstance"] + ) diff --git a/tests/test_ec2/test_route_tables.py b/tests/test_ec2/test_route_tables.py index de33b3f7a..b82313bc8 100644 --- a/tests/test_ec2/test_route_tables.py +++ b/tests/test_ec2/test_route_tables.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -15,10 +16,10 @@ from tests.helpers import requires_boto_gte @mock_ec2_deprecated def test_route_tables_defaults(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - all_route_tables = conn.get_all_route_tables(filters={'vpc-id': vpc.id}) + all_route_tables = conn.get_all_route_tables(filters={"vpc-id": vpc.id}) all_route_tables.should.have.length_of(1) main_route_table = all_route_tables[0] @@ -28,23 +29,23 @@ def test_route_tables_defaults(): routes.should.have.length_of(1) local_route = routes[0] - local_route.gateway_id.should.equal('local') - local_route.state.should.equal('active') + local_route.gateway_id.should.equal("local") + local_route.state.should.equal("active") local_route.destination_cidr_block.should.equal(vpc.cidr_block) vpc.delete() - all_route_tables = conn.get_all_route_tables(filters={'vpc-id': vpc.id}) + all_route_tables = conn.get_all_route_tables(filters={"vpc-id": vpc.id}) all_route_tables.should.have.length_of(0) @mock_ec2_deprecated def test_route_tables_additional(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") route_table = conn.create_route_table(vpc.id) - all_route_tables = conn.get_all_route_tables(filters={'vpc-id': vpc.id}) + all_route_tables = conn.get_all_route_tables(filters={"vpc-id": vpc.id}) all_route_tables.should.have.length_of(2) all_route_tables[0].vpc_id.should.equal(vpc.id) all_route_tables[1].vpc_id.should.equal(vpc.id) @@ -56,31 +57,31 @@ def test_route_tables_additional(): routes.should.have.length_of(1) local_route = routes[0] - local_route.gateway_id.should.equal('local') - local_route.state.should.equal('active') + local_route.gateway_id.should.equal("local") + local_route.state.should.equal("active") local_route.destination_cidr_block.should.equal(vpc.cidr_block) with assert_raises(EC2ResponseError) as cm: conn.delete_vpc(vpc.id) - cm.exception.code.should.equal('DependencyViolation') + cm.exception.code.should.equal("DependencyViolation") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none conn.delete_route_table(route_table.id) - all_route_tables = conn.get_all_route_tables(filters={'vpc-id': vpc.id}) + all_route_tables = conn.get_all_route_tables(filters={"vpc-id": vpc.id}) all_route_tables.should.have.length_of(1) with assert_raises(EC2ResponseError) as cm: conn.delete_route_table("rtb-1234abcd") - cm.exception.code.should.equal('InvalidRouteTableID.NotFound') + cm.exception.code.should.equal("InvalidRouteTableID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_route_tables_filters_standard(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc1 = conn.create_vpc("10.0.0.0/16") route_table1 = conn.create_route_table(vpc1.id) @@ -92,39 +93,39 @@ def test_route_tables_filters_standard(): all_route_tables.should.have.length_of(5) # Filter by main route table - main_route_tables = conn.get_all_route_tables( - filters={'association.main': 'true'}) + main_route_tables = conn.get_all_route_tables(filters={"association.main": "true"}) main_route_tables.should.have.length_of(3) - main_route_table_ids = [ - route_table.id for route_table in main_route_tables] + main_route_table_ids = [route_table.id for route_table in main_route_tables] main_route_table_ids.should_not.contain(route_table1.id) main_route_table_ids.should_not.contain(route_table2.id) # Filter by VPC - vpc1_route_tables = conn.get_all_route_tables(filters={'vpc-id': vpc1.id}) + vpc1_route_tables = conn.get_all_route_tables(filters={"vpc-id": vpc1.id}) vpc1_route_tables.should.have.length_of(2) - vpc1_route_table_ids = [ - route_table.id for route_table in vpc1_route_tables] + vpc1_route_table_ids = [route_table.id for route_table in vpc1_route_tables] vpc1_route_table_ids.should.contain(route_table1.id) vpc1_route_table_ids.should_not.contain(route_table2.id) # Filter by VPC and main route table vpc2_main_route_tables = conn.get_all_route_tables( - filters={'association.main': 'true', 'vpc-id': vpc2.id}) + filters={"association.main": "true", "vpc-id": vpc2.id} + ) vpc2_main_route_tables.should.have.length_of(1) vpc2_main_route_table_ids = [ - route_table.id for route_table in vpc2_main_route_tables] + route_table.id for route_table in vpc2_main_route_tables + ] vpc2_main_route_table_ids.should_not.contain(route_table1.id) vpc2_main_route_table_ids.should_not.contain(route_table2.id) # Unsupported filter conn.get_all_route_tables.when.called_with( - filters={'not-implemented-filter': 'foobar'}).should.throw(NotImplementedError) + filters={"not-implemented-filter": "foobar"} + ).should.throw(NotImplementedError) @mock_ec2_deprecated def test_route_tables_filters_associations(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet1 = conn.create_subnet(vpc.id, "10.0.0.0/24") @@ -142,21 +143,24 @@ def test_route_tables_filters_associations(): # Filter by association ID association1_route_tables = conn.get_all_route_tables( - filters={'association.route-table-association-id': association_id1}) + filters={"association.route-table-association-id": association_id1} + ) association1_route_tables.should.have.length_of(1) association1_route_tables[0].id.should.equal(route_table1.id) association1_route_tables[0].associations.should.have.length_of(2) # Filter by route table ID route_table2_route_tables = conn.get_all_route_tables( - filters={'association.route-table-id': route_table2.id}) + filters={"association.route-table-id": route_table2.id} + ) route_table2_route_tables.should.have.length_of(1) route_table2_route_tables[0].id.should.equal(route_table2.id) route_table2_route_tables[0].associations.should.have.length_of(1) # Filter by subnet ID subnet_route_tables = conn.get_all_route_tables( - filters={'association.subnet-id': subnet1.id}) + filters={"association.subnet-id": subnet1.id} + ) subnet_route_tables.should.have.length_of(1) subnet_route_tables[0].id.should.equal(route_table1.id) association1_route_tables[0].associations.should.have.length_of(2) @@ -164,7 +168,7 @@ def test_route_tables_filters_associations(): @mock_ec2_deprecated def test_route_table_associations(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") route_table = conn.create_route_table(vpc.id) @@ -189,14 +193,13 @@ def test_route_table_associations(): route_table.associations[0].subnet_id.should.equal(subnet.id) # Associate is idempotent - association_id_idempotent = conn.associate_route_table( - route_table.id, subnet.id) + association_id_idempotent = conn.associate_route_table(route_table.id, subnet.id) association_id_idempotent.should.equal(association_id) # Error: Attempt delete associated route table. with assert_raises(EC2ResponseError) as cm: conn.delete_route_table(route_table.id) - cm.exception.code.should.equal('DependencyViolation') + cm.exception.code.should.equal("DependencyViolation") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -210,21 +213,21 @@ def test_route_table_associations(): # Error: Disassociate with invalid association ID with assert_raises(EC2ResponseError) as cm: conn.disassociate_route_table(association_id) - cm.exception.code.should.equal('InvalidAssociationID.NotFound') + cm.exception.code.should.equal("InvalidAssociationID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Associate with invalid subnet ID with assert_raises(EC2ResponseError) as cm: conn.associate_route_table(route_table.id, "subnet-1234abcd") - cm.exception.code.should.equal('InvalidSubnetID.NotFound') + cm.exception.code.should.equal("InvalidSubnetID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Associate with invalid route table ID with assert_raises(EC2ResponseError) as cm: conn.associate_route_table("rtb-1234abcd", subnet.id) - cm.exception.code.should.equal('InvalidRouteTableID.NotFound') + cm.exception.code.should.equal("InvalidRouteTableID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -236,7 +239,7 @@ def test_route_table_replace_route_table_association(): Note: Boto has deprecated replace_route_table_assocation (which returns status) and now uses replace_route_table_assocation_with_assoc (which returns association ID). """ - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") route_table1 = conn.create_route_table(vpc.id) @@ -267,7 +270,8 @@ def test_route_table_replace_route_table_association(): # Replace Association association_id2 = conn.replace_route_table_association_with_assoc( - association_id1, route_table2.id) + association_id1, route_table2.id + ) # Refresh route_table1 = conn.get_all_route_tables(route_table1.id)[0] @@ -284,120 +288,128 @@ def test_route_table_replace_route_table_association(): # Replace Association is idempotent association_id_idempotent = conn.replace_route_table_association_with_assoc( - association_id2, route_table2.id) + association_id2, route_table2.id + ) association_id_idempotent.should.equal(association_id2) # Error: Replace association with invalid association ID with assert_raises(EC2ResponseError) as cm: conn.replace_route_table_association_with_assoc( - "rtbassoc-1234abcd", route_table1.id) - cm.exception.code.should.equal('InvalidAssociationID.NotFound') + "rtbassoc-1234abcd", route_table1.id + ) + cm.exception.code.should.equal("InvalidAssociationID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Error: Replace association with invalid route table ID with assert_raises(EC2ResponseError) as cm: - conn.replace_route_table_association_with_assoc( - association_id2, "rtb-1234abcd") - cm.exception.code.should.equal('InvalidRouteTableID.NotFound') + conn.replace_route_table_association_with_assoc(association_id2, "rtb-1234abcd") + cm.exception.code.should.equal("InvalidRouteTableID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_route_table_get_by_tag(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") - vpc = conn.create_vpc('10.0.0.0/16') + vpc = conn.create_vpc("10.0.0.0/16") route_table = conn.create_route_table(vpc.id) - route_table.add_tag('Name', 'TestRouteTable') + route_table.add_tag("Name", "TestRouteTable") - route_tables = conn.get_all_route_tables( - filters={'tag:Name': 'TestRouteTable'}) + route_tables = conn.get_all_route_tables(filters={"tag:Name": "TestRouteTable"}) route_tables.should.have.length_of(1) route_tables[0].vpc_id.should.equal(vpc.id) route_tables[0].id.should.equal(route_table.id) route_tables[0].tags.should.have.length_of(1) - route_tables[0].tags['Name'].should.equal('TestRouteTable') + route_tables[0].tags["Name"].should.equal("TestRouteTable") @mock_ec2 def test_route_table_get_by_tag_boto3(): - ec2 = boto3.resource('ec2', region_name='eu-central-1') + ec2 = boto3.resource("ec2", region_name="eu-central-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") route_table = ec2.create_route_table(VpcId=vpc.id) - route_table.create_tags(Tags=[{'Key': 'Name', 'Value': 'TestRouteTable'}]) + route_table.create_tags(Tags=[{"Key": "Name", "Value": "TestRouteTable"}]) - filters = [{'Name': 'tag:Name', 'Values': ['TestRouteTable']}] + filters = [{"Name": "tag:Name", "Values": ["TestRouteTable"]}] route_tables = list(ec2.route_tables.filter(Filters=filters)) route_tables.should.have.length_of(1) route_tables[0].vpc_id.should.equal(vpc.id) route_tables[0].id.should.equal(route_table.id) route_tables[0].tags.should.have.length_of(1) - route_tables[0].tags[0].should.equal( - {'Key': 'Name', 'Value': 'TestRouteTable'}) + route_tables[0].tags[0].should.equal({"Key": "Name", "Value": "TestRouteTable"}) @mock_ec2_deprecated def test_routes_additional(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - main_route_table = conn.get_all_route_tables(filters={'vpc-id': vpc.id})[0] + main_route_table = conn.get_all_route_tables(filters={"vpc-id": vpc.id})[0] local_route = main_route_table.routes[0] igw = conn.create_internet_gateway() ROUTE_CIDR = "10.0.0.4/24" conn.create_route(main_route_table.id, ROUTE_CIDR, gateway_id=igw.id) - main_route_table = conn.get_all_route_tables( - filters={'vpc-id': vpc.id})[0] # Refresh route table + main_route_table = conn.get_all_route_tables(filters={"vpc-id": vpc.id})[ + 0 + ] # Refresh route table main_route_table.routes.should.have.length_of(2) new_routes = [ - route for route in main_route_table.routes if route.destination_cidr_block != vpc.cidr_block] + route + for route in main_route_table.routes + if route.destination_cidr_block != vpc.cidr_block + ] new_routes.should.have.length_of(1) new_route = new_routes[0] new_route.gateway_id.should.equal(igw.id) new_route.instance_id.should.be.none - new_route.state.should.equal('active') + new_route.state.should.equal("active") new_route.destination_cidr_block.should.equal(ROUTE_CIDR) conn.delete_route(main_route_table.id, ROUTE_CIDR) - main_route_table = conn.get_all_route_tables( - filters={'vpc-id': vpc.id})[0] # Refresh route table + main_route_table = conn.get_all_route_tables(filters={"vpc-id": vpc.id})[ + 0 + ] # Refresh route table main_route_table.routes.should.have.length_of(1) new_routes = [ - route for route in main_route_table.routes if route.destination_cidr_block != vpc.cidr_block] + route + for route in main_route_table.routes + if route.destination_cidr_block != vpc.cidr_block + ] new_routes.should.have.length_of(0) with assert_raises(EC2ResponseError) as cm: conn.delete_route(main_route_table.id, ROUTE_CIDR) - cm.exception.code.should.equal('InvalidRoute.NotFound') + cm.exception.code.should.equal("InvalidRoute.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_routes_replace(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") main_route_table = conn.get_all_route_tables( - filters={'association.main': 'true', 'vpc-id': vpc.id})[0] + filters={"association.main": "true", "vpc-id": vpc.id} + )[0] local_route = main_route_table.routes[0] ROUTE_CIDR = "10.0.0.4/24" # Various route targets igw = conn.create_internet_gateway() - reservation = conn.run_instances('ami-1234abcd') + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] # Create initial route @@ -407,17 +419,19 @@ def test_routes_replace(): def get_target_route(): route_table = conn.get_all_route_tables(main_route_table.id)[0] routes = [ - route for route in route_table.routes if route.destination_cidr_block != vpc.cidr_block] + route + for route in route_table.routes + if route.destination_cidr_block != vpc.cidr_block + ] routes.should.have.length_of(1) return routes[0] - conn.replace_route(main_route_table.id, ROUTE_CIDR, - instance_id=instance.id) + conn.replace_route(main_route_table.id, ROUTE_CIDR, instance_id=instance.id) target_route = get_target_route() target_route.gateway_id.should.be.none target_route.instance_id.should.equal(instance.id) - target_route.state.should.equal('active') + target_route.state.should.equal("active") target_route.destination_cidr_block.should.equal(ROUTE_CIDR) conn.replace_route(main_route_table.id, ROUTE_CIDR, gateway_id=igw.id) @@ -425,12 +439,12 @@ def test_routes_replace(): target_route = get_target_route() target_route.gateway_id.should.equal(igw.id) target_route.instance_id.should.be.none - target_route.state.should.equal('active') + target_route.state.should.equal("active") target_route.destination_cidr_block.should.equal(ROUTE_CIDR) with assert_raises(EC2ResponseError) as cm: - conn.replace_route('rtb-1234abcd', ROUTE_CIDR, gateway_id=igw.id) - cm.exception.code.should.equal('InvalidRouteTableID.NotFound') + conn.replace_route("rtb-1234abcd", ROUTE_CIDR, gateway_id=igw.id) + cm.exception.code.should.equal("InvalidRouteTableID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -438,7 +452,7 @@ def test_routes_replace(): @requires_boto_gte("2.19.0") @mock_ec2_deprecated def test_routes_not_supported(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") main_route_table = conn.get_all_route_tables()[0] local_route = main_route_table.routes[0] @@ -447,42 +461,49 @@ def test_routes_not_supported(): # Create conn.create_route.when.called_with( - main_route_table.id, ROUTE_CIDR, interface_id='eni-1234abcd').should.throw(NotImplementedError) + main_route_table.id, ROUTE_CIDR, interface_id="eni-1234abcd" + ).should.throw(NotImplementedError) # Replace igw = conn.create_internet_gateway() conn.create_route(main_route_table.id, ROUTE_CIDR, gateway_id=igw.id) conn.replace_route.when.called_with( - main_route_table.id, ROUTE_CIDR, interface_id='eni-1234abcd').should.throw(NotImplementedError) + main_route_table.id, ROUTE_CIDR, interface_id="eni-1234abcd" + ).should.throw(NotImplementedError) @requires_boto_gte("2.34.0") @mock_ec2_deprecated def test_routes_vpc_peering_connection(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") main_route_table = conn.get_all_route_tables( - filters={'association.main': 'true', 'vpc-id': vpc.id})[0] + filters={"association.main": "true", "vpc-id": vpc.id} + )[0] local_route = main_route_table.routes[0] ROUTE_CIDR = "10.0.0.4/24" peer_vpc = conn.create_vpc("11.0.0.0/16") vpc_pcx = conn.create_vpc_peering_connection(vpc.id, peer_vpc.id) - conn.create_route(main_route_table.id, ROUTE_CIDR, - vpc_peering_connection_id=vpc_pcx.id) + conn.create_route( + main_route_table.id, ROUTE_CIDR, vpc_peering_connection_id=vpc_pcx.id + ) # Refresh route table main_route_table = conn.get_all_route_tables(main_route_table.id)[0] new_routes = [ - route for route in main_route_table.routes if route.destination_cidr_block != vpc.cidr_block] + route + for route in main_route_table.routes + if route.destination_cidr_block != vpc.cidr_block + ] new_routes.should.have.length_of(1) new_route = new_routes[0] new_route.gateway_id.should.be.none new_route.instance_id.should.be.none new_route.vpc_peering_connection_id.should.equal(vpc_pcx.id) - new_route.state.should.equal('blackhole') + new_route.state.should.equal("blackhole") new_route.destination_cidr_block.should.equal(ROUTE_CIDR) @@ -490,10 +511,11 @@ def test_routes_vpc_peering_connection(): @mock_ec2_deprecated def test_routes_vpn_gateway(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") main_route_table = conn.get_all_route_tables( - filters={'association.main': 'true', 'vpc-id': vpc.id})[0] + filters={"association.main": "true", "vpc-id": vpc.id} + )[0] ROUTE_CIDR = "10.0.0.4/24" vpn_gw = conn.create_vpn_gateway(type="ipsec.1") @@ -502,7 +524,10 @@ def test_routes_vpn_gateway(): main_route_table = conn.get_all_route_tables(main_route_table.id)[0] new_routes = [ - route for route in main_route_table.routes if route.destination_cidr_block != vpc.cidr_block] + route + for route in main_route_table.routes + if route.destination_cidr_block != vpc.cidr_block + ] new_routes.should.have.length_of(1) new_route = new_routes[0] @@ -514,7 +539,7 @@ def test_routes_vpn_gateway(): @mock_ec2_deprecated def test_network_acl_tagging(): - conn = boto.connect_vpc('the_key', 'the secret') + conn = boto.connect_vpc("the_key", "the secret") vpc = conn.create_vpc("10.0.0.0/16") route_table = conn.create_route_table(vpc.id) @@ -525,17 +550,16 @@ def test_network_acl_tagging(): tag.value.should.equal("some value") all_route_tables = conn.get_all_route_tables() - test_route_table = next(na for na in all_route_tables - if na.id == route_table.id) + test_route_table = next(na for na in all_route_tables if na.id == route_table.id) test_route_table.tags.should.have.length_of(1) test_route_table.tags["a key"].should.equal("some value") @mock_ec2 def test_create_route_with_invalid_destination_cidr_block_parameter(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok @@ -546,9 +570,14 @@ def test_create_route_with_invalid_destination_cidr_block_parameter(): vpc.attach_internet_gateway(InternetGatewayId=internet_gateway.id) internet_gateway.reload() - destination_cidr_block = '1000.1.0.0/20' + destination_cidr_block = "1000.1.0.0/20" with assert_raises(ClientError) as ex: - route = route_table.create_route(DestinationCidrBlock=destination_cidr_block, GatewayId=internet_gateway.id) + route = route_table.create_route( + DestinationCidrBlock=destination_cidr_block, GatewayId=internet_gateway.id + ) str(ex.exception).should.equal( "An error occurred (InvalidParameterValue) when calling the CreateRoute " - "operation: Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format(destination_cidr_block)) \ No newline at end of file + "operation: Value ({}) for parameter destinationCidrBlock is invalid. This is not a valid CIDR block.".format( + destination_cidr_block + ) + ) diff --git a/tests/test_ec2/test_security_groups.py b/tests/test_ec2/test_security_groups.py index c09b1e8f4..d872bdf87 100644 --- a/tests/test_ec2/test_security_groups.py +++ b/tests/test_ec2/test_security_groups.py @@ -17,27 +17,31 @@ from moto import mock_ec2, mock_ec2_deprecated @mock_ec2_deprecated def test_create_and_describe_security_group(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as ex: security_group = conn.create_security_group( - 'test security group', 'this is a test security group', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + "test security group", "this is a test security group", dry_run=True + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateSecurityGroup operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateSecurityGroup operation: Request would have succeeded, but DryRun flag is set" + ) security_group = conn.create_security_group( - 'test security group', 'this is a test security group') + "test security group", "this is a test security group" + ) - security_group.name.should.equal('test security group') - security_group.description.should.equal('this is a test security group') + security_group.name.should.equal("test security group") + security_group.description.should.equal("this is a test security group") # Trying to create another group with the same name should throw an error with assert_raises(EC2ResponseError) as cm: conn.create_security_group( - 'test security group', 'this is a test security group') - cm.exception.code.should.equal('InvalidGroup.Duplicate') + "test security group", "this is a test security group" + ) + cm.exception.code.should.equal("InvalidGroup.Duplicate") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -50,18 +54,18 @@ def test_create_and_describe_security_group(): @mock_ec2_deprecated def test_create_security_group_without_description_raises_error(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.create_security_group('test security group', '') - cm.exception.code.should.equal('MissingParameter') + conn.create_security_group("test security group", "") + cm.exception.code.should.equal("MissingParameter") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_default_security_group(): - conn = boto.ec2.connect_to_region('us-east-1') + conn = boto.ec2.connect_to_region("us-east-1") groups = conn.get_all_security_groups() groups.should.have.length_of(2) groups[0].name.should.equal("default") @@ -69,43 +73,47 @@ def test_default_security_group(): @mock_ec2_deprecated def test_create_and_describe_vpc_security_group(): - conn = boto.connect_ec2('the_key', 'the_secret') - vpc_id = 'vpc-5300000c' + conn = boto.connect_ec2("the_key", "the_secret") + vpc_id = "vpc-5300000c" security_group = conn.create_security_group( - 'test security group', 'this is a test security group', vpc_id=vpc_id) + "test security group", "this is a test security group", vpc_id=vpc_id + ) security_group.vpc_id.should.equal(vpc_id) - security_group.name.should.equal('test security group') - security_group.description.should.equal('this is a test security group') + security_group.name.should.equal("test security group") + security_group.description.should.equal("this is a test security group") # Trying to create another group with the same name in the same VPC should # throw an error with assert_raises(EC2ResponseError) as cm: conn.create_security_group( - 'test security group', 'this is a test security group', vpc_id) - cm.exception.code.should.equal('InvalidGroup.Duplicate') + "test security group", "this is a test security group", vpc_id + ) + cm.exception.code.should.equal("InvalidGroup.Duplicate") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none - all_groups = conn.get_all_security_groups(filters={'vpc_id': [vpc_id]}) + all_groups = conn.get_all_security_groups(filters={"vpc_id": [vpc_id]}) all_groups[0].vpc_id.should.equal(vpc_id) all_groups.should.have.length_of(1) - all_groups[0].name.should.equal('test security group') + all_groups[0].name.should.equal("test security group") @mock_ec2_deprecated def test_create_two_security_groups_with_same_name_in_different_vpc(): - conn = boto.connect_ec2('the_key', 'the_secret') - vpc_id = 'vpc-5300000c' - vpc_id2 = 'vpc-5300000d' + conn = boto.connect_ec2("the_key", "the_secret") + vpc_id = "vpc-5300000c" + vpc_id2 = "vpc-5300000d" conn.create_security_group( - 'test security group', 'this is a test security group', vpc_id) + "test security group", "this is a test security group", vpc_id + ) conn.create_security_group( - 'test security group', 'this is a test security group', vpc_id2) + "test security group", "this is a test security group", vpc_id2 + ) all_groups = conn.get_all_security_groups() @@ -117,28 +125,29 @@ def test_create_two_security_groups_with_same_name_in_different_vpc(): @mock_ec2_deprecated def test_deleting_security_groups(): - conn = boto.connect_ec2('the_key', 'the_secret') - security_group1 = conn.create_security_group('test1', 'test1') - conn.create_security_group('test2', 'test2') + conn = boto.connect_ec2("the_key", "the_secret") + security_group1 = conn.create_security_group("test1", "test1") + conn.create_security_group("test2", "test2") conn.get_all_security_groups().should.have.length_of(4) # Deleting a group that doesn't exist should throw an error with assert_raises(EC2ResponseError) as cm: - conn.delete_security_group('foobar') - cm.exception.code.should.equal('InvalidGroup.NotFound') + conn.delete_security_group("foobar") + cm.exception.code.should.equal("InvalidGroup.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Delete by name with assert_raises(EC2ResponseError) as ex: - conn.delete_security_group('test2', dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + conn.delete_security_group("test2", dry_run=True) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteSecurityGroup operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeleteSecurityGroup operation: Request would have succeeded, but DryRun flag is set" + ) - conn.delete_security_group('test2') + conn.delete_security_group("test2") conn.get_all_security_groups().should.have.length_of(3) # Delete by group id @@ -148,9 +157,9 @@ def test_deleting_security_groups(): @mock_ec2_deprecated def test_delete_security_group_in_vpc(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") vpc_id = "vpc-12345" - security_group1 = conn.create_security_group('test1', 'test1', vpc_id) + security_group1 = conn.create_security_group("test1", "test1", vpc_id) # this should not throw an exception conn.delete_security_group(group_id=security_group1.id) @@ -158,87 +167,130 @@ def test_delete_security_group_in_vpc(): @mock_ec2_deprecated def test_authorize_ip_range_and_revoke(): - conn = boto.connect_ec2('the_key', 'the_secret') - security_group = conn.create_security_group('test', 'test') + conn = boto.connect_ec2("the_key", "the_secret") + security_group = conn.create_security_group("test", "test") with assert_raises(EC2ResponseError) as ex: success = security_group.authorize( - ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ip_protocol="tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + dry_run=True, + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the GrantSecurityGroupIngress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the GrantSecurityGroupIngress operation: Request would have succeeded, but DryRun flag is set" + ) success = security_group.authorize( - ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32") + ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32" + ) assert success.should.be.true - security_group = conn.get_all_security_groups(groupnames=['test'])[0] + security_group = conn.get_all_security_groups(groupnames=["test"])[0] int(security_group.rules[0].to_port).should.equal(2222) - security_group.rules[0].grants[ - 0].cidr_ip.should.equal("123.123.123.123/32") + security_group.rules[0].grants[0].cidr_ip.should.equal("123.123.123.123/32") # Wrong Cidr should throw error with assert_raises(EC2ResponseError) as cm: - security_group.revoke(ip_protocol="tcp", from_port="22", - to_port="2222", cidr_ip="123.123.123.122/32") - cm.exception.code.should.equal('InvalidPermission.NotFound') + security_group.revoke( + ip_protocol="tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.122/32", + ) + cm.exception.code.should.equal("InvalidPermission.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Actually revoke with assert_raises(EC2ResponseError) as ex: - security_group.revoke(ip_protocol="tcp", from_port="22", - to_port="2222", cidr_ip="123.123.123.123/32", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + security_group.revoke( + ip_protocol="tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + dry_run=True, + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the RevokeSecurityGroupIngress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the RevokeSecurityGroupIngress operation: Request would have succeeded, but DryRun flag is set" + ) - security_group.revoke(ip_protocol="tcp", from_port="22", - to_port="2222", cidr_ip="123.123.123.123/32") + security_group.revoke( + ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32" + ) security_group = conn.get_all_security_groups()[0] security_group.rules.should.have.length_of(0) # Test for egress as well egress_security_group = conn.create_security_group( - 'testegress', 'testegress', vpc_id='vpc-3432589') + "testegress", "testegress", vpc_id="vpc-3432589" + ) with assert_raises(EC2ResponseError) as ex: success = conn.authorize_security_group_egress( - egress_security_group.id, "tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + egress_security_group.id, + "tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + dry_run=True, + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the GrantSecurityGroupEgress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the GrantSecurityGroupEgress operation: Request would have succeeded, but DryRun flag is set" + ) success = conn.authorize_security_group_egress( - egress_security_group.id, "tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32") + egress_security_group.id, + "tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + ) assert success.should.be.true - egress_security_group = conn.get_all_security_groups( - groupnames='testegress')[0] + egress_security_group = conn.get_all_security_groups(groupnames="testegress")[0] # There are two egress rules associated with the security group: # the default outbound rule and the new one int(egress_security_group.rules_egress[1].to_port).should.equal(2222) - egress_security_group.rules_egress[1].grants[ - 0].cidr_ip.should.equal("123.123.123.123/32") + egress_security_group.rules_egress[1].grants[0].cidr_ip.should.equal( + "123.123.123.123/32" + ) # Wrong Cidr should throw error egress_security_group.revoke.when.called_with( - ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.122/32").should.throw(EC2ResponseError) + ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.122/32" + ).should.throw(EC2ResponseError) # Actually revoke with assert_raises(EC2ResponseError) as ex: conn.revoke_security_group_egress( - egress_security_group.id, "tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + egress_security_group.id, + "tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + dry_run=True, + ) + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the RevokeSecurityGroupEgress operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the RevokeSecurityGroupEgress operation: Request would have succeeded, but DryRun flag is set" + ) conn.revoke_security_group_egress( - egress_security_group.id, "tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123/32") + egress_security_group.id, + "tcp", + from_port="22", + to_port="2222", + cidr_ip="123.123.123.123/32", + ) egress_security_group = conn.get_all_security_groups()[0] # There is still the default outbound rule @@ -247,55 +299,69 @@ def test_authorize_ip_range_and_revoke(): @mock_ec2_deprecated def test_authorize_other_group_and_revoke(): - conn = boto.connect_ec2('the_key', 'the_secret') - security_group = conn.create_security_group('test', 'test') - other_security_group = conn.create_security_group('other', 'other') - wrong_group = conn.create_security_group('wrong', 'wrong') + conn = boto.connect_ec2("the_key", "the_secret") + security_group = conn.create_security_group("test", "test") + other_security_group = conn.create_security_group("other", "other") + wrong_group = conn.create_security_group("wrong", "wrong") success = security_group.authorize( - ip_protocol="tcp", from_port="22", to_port="2222", src_group=other_security_group) + ip_protocol="tcp", + from_port="22", + to_port="2222", + src_group=other_security_group, + ) assert success.should.be.true security_group = [ - group for group in conn.get_all_security_groups() if group.name == 'test'][0] + group for group in conn.get_all_security_groups() if group.name == "test" + ][0] int(security_group.rules[0].to_port).should.equal(2222) - security_group.rules[0].grants[ - 0].group_id.should.equal(other_security_group.id) + security_group.rules[0].grants[0].group_id.should.equal(other_security_group.id) # Wrong source group should throw error with assert_raises(EC2ResponseError) as cm: - security_group.revoke(ip_protocol="tcp", from_port="22", - to_port="2222", src_group=wrong_group) - cm.exception.code.should.equal('InvalidPermission.NotFound') + security_group.revoke( + ip_protocol="tcp", from_port="22", to_port="2222", src_group=wrong_group + ) + cm.exception.code.should.equal("InvalidPermission.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none # Actually revoke - security_group.revoke(ip_protocol="tcp", from_port="22", - to_port="2222", src_group=other_security_group) + security_group.revoke( + ip_protocol="tcp", + from_port="22", + to_port="2222", + src_group=other_security_group, + ) security_group = [ - group for group in conn.get_all_security_groups() if group.name == 'test'][0] + group for group in conn.get_all_security_groups() if group.name == "test" + ][0] security_group.rules.should.have.length_of(0) @mock_ec2 def test_authorize_other_group_egress_and_revoke(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") sg01 = ec2.create_security_group( - GroupName='sg01', Description='Test security group sg01', VpcId=vpc.id) + GroupName="sg01", Description="Test security group sg01", VpcId=vpc.id + ) sg02 = ec2.create_security_group( - GroupName='sg02', Description='Test security group sg02', VpcId=vpc.id) + GroupName="sg02", Description="Test security group sg02", VpcId=vpc.id + ) ip_permission = { - 'IpProtocol': 'tcp', - 'FromPort': 27017, - 'ToPort': 27017, - 'UserIdGroupPairs': [{'GroupId': sg02.id, 'GroupName': 'sg02', 'UserId': sg02.owner_id}], - 'IpRanges': [] + "IpProtocol": "tcp", + "FromPort": 27017, + "ToPort": 27017, + "UserIdGroupPairs": [ + {"GroupId": sg02.id, "GroupName": "sg02", "UserId": sg02.owner_id} + ], + "IpRanges": [], } sg01.authorize_egress(IpPermissions=[ip_permission]) @@ -308,32 +374,41 @@ def test_authorize_other_group_egress_and_revoke(): @mock_ec2_deprecated def test_authorize_group_in_vpc(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") vpc_id = "vpc-12345" # create 2 groups in a vpc - security_group = conn.create_security_group('test1', 'test1', vpc_id) - other_security_group = conn.create_security_group('test2', 'test2', vpc_id) + security_group = conn.create_security_group("test1", "test1", vpc_id) + other_security_group = conn.create_security_group("test2", "test2", vpc_id) success = security_group.authorize( - ip_protocol="tcp", from_port="22", to_port="2222", src_group=other_security_group) + ip_protocol="tcp", + from_port="22", + to_port="2222", + src_group=other_security_group, + ) success.should.be.true # Check that the rule is accurate security_group = [ - group for group in conn.get_all_security_groups() if group.name == 'test1'][0] + group for group in conn.get_all_security_groups() if group.name == "test1" + ][0] int(security_group.rules[0].to_port).should.equal(2222) - security_group.rules[0].grants[ - 0].group_id.should.equal(other_security_group.id) + security_group.rules[0].grants[0].group_id.should.equal(other_security_group.id) # Now remove the rule success = security_group.revoke( - ip_protocol="tcp", from_port="22", to_port="2222", src_group=other_security_group) + ip_protocol="tcp", + from_port="22", + to_port="2222", + src_group=other_security_group, + ) success.should.be.true # And check that it gets revoked security_group = [ - group for group in conn.get_all_security_groups() if group.name == 'test1'][0] + group for group in conn.get_all_security_groups() if group.name == "test1" + ][0] security_group.rules.should.have.length_of(0) @@ -341,31 +416,32 @@ def test_authorize_group_in_vpc(): def test_get_all_security_groups(): conn = boto.connect_ec2() sg1 = conn.create_security_group( - name='test1', description='test1', vpc_id='vpc-mjm05d27') - conn.create_security_group(name='test2', description='test2') + name="test1", description="test1", vpc_id="vpc-mjm05d27" + ) + conn.create_security_group(name="test2", description="test2") - resp = conn.get_all_security_groups(groupnames=['test1']) + resp = conn.get_all_security_groups(groupnames=["test1"]) resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) with assert_raises(EC2ResponseError) as cm: - conn.get_all_security_groups(groupnames=['does_not_exist']) - cm.exception.code.should.equal('InvalidGroup.NotFound') + conn.get_all_security_groups(groupnames=["does_not_exist"]) + cm.exception.code.should.equal("InvalidGroup.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) - resp = conn.get_all_security_groups(filters={'vpc-id': ['vpc-mjm05d27']}) + resp = conn.get_all_security_groups(filters={"vpc-id": ["vpc-mjm05d27"]}) resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) - resp = conn.get_all_security_groups(filters={'vpc_id': ['vpc-mjm05d27']}) + resp = conn.get_all_security_groups(filters={"vpc_id": ["vpc-mjm05d27"]}) resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) - resp = conn.get_all_security_groups(filters={'description': ['test1']}) + resp = conn.get_all_security_groups(filters={"description": ["test1"]}) resp.should.have.length_of(1) resp[0].id.should.equal(sg1.id) @@ -375,12 +451,13 @@ def test_get_all_security_groups(): @mock_ec2_deprecated def test_authorize_bad_cidr_throws_invalid_parameter_value(): - conn = boto.connect_ec2('the_key', 'the_secret') - security_group = conn.create_security_group('test', 'test') + conn = boto.connect_ec2("the_key", "the_secret") + security_group = conn.create_security_group("test", "test") with assert_raises(EC2ResponseError) as cm: security_group.authorize( - ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123") - cm.exception.code.should.equal('InvalidParameterValue') + ip_protocol="tcp", from_port="22", to_port="2222", cidr_ip="123.123.123.123" + ) + cm.exception.code.should.equal("InvalidParameterValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -394,10 +471,11 @@ def test_security_group_tagging(): with assert_raises(EC2ResponseError) as ex: sg.add_tag("Test", "Tag", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) sg.add_tag("Test", "Tag") @@ -416,20 +494,19 @@ def test_security_group_tag_filtering(): sg = conn.create_security_group("test-sg", "Test SG") sg.add_tag("test-tag", "test-value") - groups = conn.get_all_security_groups( - filters={"tag:test-tag": "test-value"}) + groups = conn.get_all_security_groups(filters={"tag:test-tag": "test-value"}) groups.should.have.length_of(1) @mock_ec2_deprecated def test_authorize_all_protocols_with_no_port_specification(): conn = boto.connect_ec2() - sg = conn.create_security_group('test', 'test') + sg = conn.create_security_group("test", "test") - success = sg.authorize(ip_protocol='-1', cidr_ip='0.0.0.0/0') + success = sg.authorize(ip_protocol="-1", cidr_ip="0.0.0.0/0") success.should.be.true - sg = conn.get_all_security_groups('test')[0] + sg = conn.get_all_security_groups("test")[0] sg.rules[0].from_port.should.equal(None) sg.rules[0].to_port.should.equal(None) @@ -437,63 +514,68 @@ def test_authorize_all_protocols_with_no_port_specification(): @mock_ec2_deprecated def test_sec_group_rule_limit(): ec2_conn = boto.connect_ec2() - sg = ec2_conn.create_security_group('test', 'test') - other_sg = ec2_conn.create_security_group('test_2', 'test_other') + sg = ec2_conn.create_security_group("test", "test") + other_sg = ec2_conn.create_security_group("test_2", "test_other") # INGRESS with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - cidr_ip=['{0}.0.0.0/0'.format(i) for i in range(110)]) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, + ip_protocol="-1", + cidr_ip=["{0}.0.0.0/0".format(i) for i in range(110)], + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") sg.rules.should.be.empty # authorize a rule targeting a different sec group (because this count too) success = ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - src_security_group_group_id=other_sg.id) + group_id=sg.id, ip_protocol="-1", src_security_group_group_id=other_sg.id + ) success.should.be.true # fill the rules up the limit success = ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - cidr_ip=['{0}.0.0.0/0'.format(i) for i in range(99)]) + group_id=sg.id, + ip_protocol="-1", + cidr_ip=["{0}.0.0.0/0".format(i) for i in range(99)], + ) success.should.be.true # verify that we cannot authorize past the limit for a CIDR IP with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', cidr_ip=['100.0.0.0/0']) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", cidr_ip=["100.0.0.0/0"] + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # verify that we cannot authorize past the limit for a different sec group with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - src_security_group_group_id=other_sg.id) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", src_security_group_group_id=other_sg.id + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # EGRESS # authorize a rule targeting a different sec group (because this count too) ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - src_group_id=other_sg.id) + group_id=sg.id, ip_protocol="-1", src_group_id=other_sg.id + ) # fill the rules up the limit # remember that by default, when created a sec group contains 1 egress rule # so our other_sg rule + 98 CIDR IP rules + 1 by default == 100 the limit for i in range(98): ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - cidr_ip='{0}.0.0.0/0'.format(i)) + group_id=sg.id, ip_protocol="-1", cidr_ip="{0}.0.0.0/0".format(i) + ) # verify that we cannot authorize past the limit for a CIDR IP with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - cidr_ip='101.0.0.0/0') - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", cidr_ip="101.0.0.0/0" + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # verify that we cannot authorize past the limit for a different sec group with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - src_group_id=other_sg.id) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", src_group_id=other_sg.id + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") @mock_ec2_deprecated @@ -501,87 +583,93 @@ def test_sec_group_rule_limit_vpc(): ec2_conn = boto.connect_ec2() vpc_conn = boto.connect_vpc() - vpc = vpc_conn.create_vpc('10.0.0.0/16') + vpc = vpc_conn.create_vpc("10.0.0.0/16") - sg = ec2_conn.create_security_group('test', 'test', vpc_id=vpc.id) - other_sg = ec2_conn.create_security_group('test_2', 'test', vpc_id=vpc.id) + sg = ec2_conn.create_security_group("test", "test", vpc_id=vpc.id) + other_sg = ec2_conn.create_security_group("test_2", "test", vpc_id=vpc.id) # INGRESS with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - cidr_ip=['{0}.0.0.0/0'.format(i) for i in range(110)]) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, + ip_protocol="-1", + cidr_ip=["{0}.0.0.0/0".format(i) for i in range(110)], + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") sg.rules.should.be.empty # authorize a rule targeting a different sec group (because this count too) success = ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - src_security_group_group_id=other_sg.id) + group_id=sg.id, ip_protocol="-1", src_security_group_group_id=other_sg.id + ) success.should.be.true # fill the rules up the limit success = ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - cidr_ip=['{0}.0.0.0/0'.format(i) for i in range(49)]) + group_id=sg.id, + ip_protocol="-1", + cidr_ip=["{0}.0.0.0/0".format(i) for i in range(49)], + ) # verify that we cannot authorize past the limit for a CIDR IP success.should.be.true with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', cidr_ip=['100.0.0.0/0']) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", cidr_ip=["100.0.0.0/0"] + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # verify that we cannot authorize past the limit for a different sec group with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group( - group_id=sg.id, ip_protocol='-1', - src_security_group_group_id=other_sg.id) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", src_security_group_group_id=other_sg.id + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # EGRESS # authorize a rule targeting a different sec group (because this count too) ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - src_group_id=other_sg.id) + group_id=sg.id, ip_protocol="-1", src_group_id=other_sg.id + ) # fill the rules up the limit # remember that by default, when created a sec group contains 1 egress rule # so our other_sg rule + 48 CIDR IP rules + 1 by default == 50 the limit for i in range(48): ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - cidr_ip='{0}.0.0.0/0'.format(i)) + group_id=sg.id, ip_protocol="-1", cidr_ip="{0}.0.0.0/0".format(i) + ) # verify that we cannot authorize past the limit for a CIDR IP with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - cidr_ip='50.0.0.0/0') - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", cidr_ip="50.0.0.0/0" + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") # verify that we cannot authorize past the limit for a different sec group with assert_raises(EC2ResponseError) as cm: ec2_conn.authorize_security_group_egress( - group_id=sg.id, ip_protocol='-1', - src_group_id=other_sg.id) - cm.exception.error_code.should.equal('RulesPerSecurityGroupLimitExceeded') + group_id=sg.id, ip_protocol="-1", src_group_id=other_sg.id + ) + cm.exception.error_code.should.equal("RulesPerSecurityGroupLimitExceeded") -''' +""" Boto3 -''' +""" @mock_ec2 def test_add_same_rule_twice_throws_error(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") sg = ec2.create_security_group( - GroupName='sg1', Description='Test security group sg1', VpcId=vpc.id) + GroupName="sg1", Description="Test security group sg1", VpcId=vpc.id + ) ip_permissions = [ { - 'IpProtocol': 'tcp', - 'FromPort': 27017, - 'ToPort': 27017, - 'IpRanges': [{"CidrIp": "1.2.3.4/32"}] - }, + "IpProtocol": "tcp", + "FromPort": 27017, + "ToPort": 27017, + "IpRanges": [{"CidrIp": "1.2.3.4/32"}], + } ] sg.authorize_ingress(IpPermissions=ip_permissions) @@ -591,82 +679,89 @@ def test_add_same_rule_twice_throws_error(): @mock_ec2 def test_security_group_tagging_boto3(): - conn = boto3.client('ec2', region_name='us-east-1') + conn = boto3.client("ec2", region_name="us-east-1") sg = conn.create_security_group(GroupName="test-sg", Description="Test SG") with assert_raises(ClientError) as ex: - conn.create_tags(Resources=[sg['GroupId']], Tags=[ - {'Key': 'Test', 'Value': 'Tag'}], DryRun=True) - ex.exception.response['Error']['Code'].should.equal('DryRunOperation') - ex.exception.response['ResponseMetadata'][ - 'HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + conn.create_tags( + Resources=[sg["GroupId"]], + Tags=[{"Key": "Test", "Value": "Tag"}], + DryRun=True, + ) + ex.exception.response["Error"]["Code"].should.equal("DryRunOperation") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) - conn.create_tags(Resources=[sg['GroupId']], Tags=[ - {'Key': 'Test', 'Value': 'Tag'}]) + conn.create_tags(Resources=[sg["GroupId"]], Tags=[{"Key": "Test", "Value": "Tag"}]) describe = conn.describe_security_groups( - Filters=[{'Name': 'tag-value', 'Values': ['Tag']}]) - tag = describe["SecurityGroups"][0]['Tags'][0] - tag['Value'].should.equal("Tag") - tag['Key'].should.equal("Test") + Filters=[{"Name": "tag-value", "Values": ["Tag"]}] + ) + tag = describe["SecurityGroups"][0]["Tags"][0] + tag["Value"].should.equal("Tag") + tag["Key"].should.equal("Test") @mock_ec2 def test_security_group_wildcard_tag_filter_boto3(): - conn = boto3.client('ec2', region_name='us-east-1') + conn = boto3.client("ec2", region_name="us-east-1") sg = conn.create_security_group(GroupName="test-sg", Description="Test SG") - conn.create_tags(Resources=[sg['GroupId']], Tags=[ - {'Key': 'Test', 'Value': 'Tag'}]) + conn.create_tags(Resources=[sg["GroupId"]], Tags=[{"Key": "Test", "Value": "Tag"}]) describe = conn.describe_security_groups( - Filters=[{'Name': 'tag-value', 'Values': ['*']}]) + Filters=[{"Name": "tag-value", "Values": ["*"]}] + ) - tag = describe["SecurityGroups"][0]['Tags'][0] - tag['Value'].should.equal("Tag") - tag['Key'].should.equal("Test") + tag = describe["SecurityGroups"][0]["Tags"][0] + tag["Value"].should.equal("Tag") + tag["Key"].should.equal("Test") @mock_ec2 def test_authorize_and_revoke_in_bulk(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") sg01 = ec2.create_security_group( - GroupName='sg01', Description='Test security group sg01', VpcId=vpc.id) + GroupName="sg01", Description="Test security group sg01", VpcId=vpc.id + ) sg02 = ec2.create_security_group( - GroupName='sg02', Description='Test security group sg02', VpcId=vpc.id) + GroupName="sg02", Description="Test security group sg02", VpcId=vpc.id + ) sg03 = ec2.create_security_group( - GroupName='sg03', Description='Test security group sg03') + GroupName="sg03", Description="Test security group sg03" + ) ip_permissions = [ { - 'IpProtocol': 'tcp', - 'FromPort': 27017, - 'ToPort': 27017, - 'UserIdGroupPairs': [{'GroupId': sg02.id, 'GroupName': 'sg02', - 'UserId': sg02.owner_id}], - 'IpRanges': [] + "IpProtocol": "tcp", + "FromPort": 27017, + "ToPort": 27017, + "UserIdGroupPairs": [ + {"GroupId": sg02.id, "GroupName": "sg02", "UserId": sg02.owner_id} + ], + "IpRanges": [], }, { - 'IpProtocol': 'tcp', - 'FromPort': 27018, - 'ToPort': 27018, - 'UserIdGroupPairs': [{'GroupId': sg02.id, 'UserId': sg02.owner_id}], - 'IpRanges': [] + "IpProtocol": "tcp", + "FromPort": 27018, + "ToPort": 27018, + "UserIdGroupPairs": [{"GroupId": sg02.id, "UserId": sg02.owner_id}], + "IpRanges": [], }, { - 'IpProtocol': 'tcp', - 'FromPort': 27017, - 'ToPort': 27017, - 'UserIdGroupPairs': [{'GroupName': 'sg03', 'UserId': sg03.owner_id}], - 'IpRanges': [] - } + "IpProtocol": "tcp", + "FromPort": 27017, + "ToPort": 27017, + "UserIdGroupPairs": [{"GroupName": "sg03", "UserId": sg03.owner_id}], + "IpRanges": [], + }, ] expected_ip_permissions = copy.deepcopy(ip_permissions) - expected_ip_permissions[1]['UserIdGroupPairs'][0]['GroupName'] = 'sg02' - expected_ip_permissions[2]['UserIdGroupPairs'][0]['GroupId'] = sg03.id + expected_ip_permissions[1]["UserIdGroupPairs"][0]["GroupName"] = "sg02" + expected_ip_permissions[2]["UserIdGroupPairs"][0]["GroupId"] = sg03.id sg01.authorize_ingress(IpPermissions=ip_permissions) sg01.ip_permissions.should.have.length_of(3) @@ -691,11 +786,13 @@ def test_authorize_and_revoke_in_bulk(): @mock_ec2 def test_security_group_ingress_without_multirule(): - ec2 = boto3.resource('ec2', 'ca-central-1') - sg = ec2.create_security_group(Description='Test SG', GroupName='test-sg') + ec2 = boto3.resource("ec2", "ca-central-1") + sg = ec2.create_security_group(Description="Test SG", GroupName="test-sg") assert len(sg.ip_permissions) == 0 - sg.authorize_ingress(CidrIp='192.168.0.1/32', FromPort=22, ToPort=22, IpProtocol='tcp') + sg.authorize_ingress( + CidrIp="192.168.0.1/32", FromPort=22, ToPort=22, IpProtocol="tcp" + ) # Fails assert len(sg.ip_permissions) == 1 @@ -703,11 +800,13 @@ def test_security_group_ingress_without_multirule(): @mock_ec2 def test_security_group_ingress_without_multirule_after_reload(): - ec2 = boto3.resource('ec2', 'ca-central-1') - sg = ec2.create_security_group(Description='Test SG', GroupName='test-sg') + ec2 = boto3.resource("ec2", "ca-central-1") + sg = ec2.create_security_group(Description="Test SG", GroupName="test-sg") assert len(sg.ip_permissions) == 0 - sg.authorize_ingress(CidrIp='192.168.0.1/32', FromPort=22, ToPort=22, IpProtocol='tcp') + sg.authorize_ingress( + CidrIp="192.168.0.1/32", FromPort=22, ToPort=22, IpProtocol="tcp" + ) # Also Fails sg_after = ec2.SecurityGroup(sg.id) @@ -716,22 +815,21 @@ def test_security_group_ingress_without_multirule_after_reload(): @mock_ec2_deprecated def test_get_all_security_groups_filter_with_same_vpc_id(): - conn = boto.connect_ec2('the_key', 'the_secret') - vpc_id = 'vpc-5300000c' - security_group = conn.create_security_group( - 'test1', 'test1', vpc_id=vpc_id) - security_group2 = conn.create_security_group( - 'test2', 'test2', vpc_id=vpc_id) + conn = boto.connect_ec2("the_key", "the_secret") + vpc_id = "vpc-5300000c" + security_group = conn.create_security_group("test1", "test1", vpc_id=vpc_id) + security_group2 = conn.create_security_group("test2", "test2", vpc_id=vpc_id) security_group.vpc_id.should.equal(vpc_id) security_group2.vpc_id.should.equal(vpc_id) security_groups = conn.get_all_security_groups( - group_ids=[security_group.id], filters={'vpc-id': [vpc_id]}) + group_ids=[security_group.id], filters={"vpc-id": [vpc_id]} + ) security_groups.should.have.length_of(1) with assert_raises(EC2ResponseError) as cm: - conn.get_all_security_groups(group_ids=['does_not_exist']) - cm.exception.code.should.equal('InvalidGroup.NotFound') + conn.get_all_security_groups(group_ids=["does_not_exist"]) + cm.exception.code.should.equal("InvalidGroup.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none diff --git a/tests/test_ec2/test_server.py b/tests/test_ec2/test_server.py index 00be62593..f09146b2a 100644 --- a/tests/test_ec2/test_server.py +++ b/tests/test_ec2/test_server.py @@ -4,9 +4,9 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_ec2_server_get(): @@ -14,13 +14,12 @@ def test_ec2_server_get(): test_client = backend.test_client() res = test_client.get( - '/?Action=RunInstances&ImageId=ami-60a54009', - headers={"Host": "ec2.us-east-1.amazonaws.com"} + "/?Action=RunInstances&ImageId=ami-60a54009", + headers={"Host": "ec2.us-east-1.amazonaws.com"}, ) - groups = re.search("(.*)", - res.data.decode('utf-8')) + groups = re.search("(.*)", res.data.decode("utf-8")) instance_id = groups.groups()[0] - res = test_client.get('/?Action=DescribeInstances') - res.data.decode('utf-8').should.contain(instance_id) + res = test_client.get("/?Action=DescribeInstances") + res.data.decode("utf-8").should.contain(instance_id) diff --git a/tests/test_ec2/test_spot_fleet.py b/tests/test_ec2/test_spot_fleet.py index 6221d633f..7b27764a1 100644 --- a/tests/test_ec2/test_spot_fleet.py +++ b/tests/test_ec2/test_spot_fleet.py @@ -7,381 +7,368 @@ from moto import mock_ec2 def get_subnet_id(conn): - vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")['Vpc'] + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] subnet = conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.0.0/16', AvailabilityZone='us-east-1a')['Subnet'] - subnet_id = subnet['SubnetId'] + VpcId=vpc["VpcId"], CidrBlock="10.0.0.0/16", AvailabilityZone="us-east-1a" + )["Subnet"] + subnet_id = subnet["SubnetId"] return subnet_id def spot_config(subnet_id, allocation_strategy="lowestPrice"): return { - 'ClientToken': 'string', - 'SpotPrice': '0.12', - 'TargetCapacity': 6, - 'IamFleetRole': 'arn:aws:iam::123456789012:role/fleet', - 'LaunchSpecifications': [{ - 'ImageId': 'ami-123', - 'KeyName': 'my-key', - 'SecurityGroups': [ - { - 'GroupId': 'sg-123' - }, - ], - 'UserData': 'some user data', - 'InstanceType': 't2.small', - 'BlockDeviceMappings': [ - { - 'VirtualName': 'string', - 'DeviceName': 'string', - 'Ebs': { - 'SnapshotId': 'string', - 'VolumeSize': 123, - 'DeleteOnTermination': True | False, - 'VolumeType': 'standard', - 'Iops': 123, - 'Encrypted': True | False + "ClientToken": "string", + "SpotPrice": "0.12", + "TargetCapacity": 6, + "IamFleetRole": "arn:aws:iam::123456789012:role/fleet", + "LaunchSpecifications": [ + { + "ImageId": "ami-123", + "KeyName": "my-key", + "SecurityGroups": [{"GroupId": "sg-123"}], + "UserData": "some user data", + "InstanceType": "t2.small", + "BlockDeviceMappings": [ + { + "VirtualName": "string", + "DeviceName": "string", + "Ebs": { + "SnapshotId": "string", + "VolumeSize": 123, + "DeleteOnTermination": True | False, + "VolumeType": "standard", + "Iops": 123, + "Encrypted": True | False, }, - 'NoDevice': 'string' - }, - ], - 'Monitoring': { - 'Enabled': True + "NoDevice": "string", + } + ], + "Monitoring": {"Enabled": True}, + "SubnetId": subnet_id, + "IamInstanceProfile": {"Arn": "arn:aws:iam::123456789012:role/fleet"}, + "EbsOptimized": False, + "WeightedCapacity": 2.0, + "SpotPrice": "0.13", }, - 'SubnetId': subnet_id, - 'IamInstanceProfile': { - 'Arn': 'arn:aws:iam::123456789012:role/fleet' + { + "ImageId": "ami-123", + "KeyName": "my-key", + "SecurityGroups": [{"GroupId": "sg-123"}], + "UserData": "some user data", + "InstanceType": "t2.large", + "Monitoring": {"Enabled": True}, + "SubnetId": subnet_id, + "IamInstanceProfile": {"Arn": "arn:aws:iam::123456789012:role/fleet"}, + "EbsOptimized": False, + "WeightedCapacity": 4.0, + "SpotPrice": "10.00", }, - 'EbsOptimized': False, - 'WeightedCapacity': 2.0, - 'SpotPrice': '0.13', - }, { - 'ImageId': 'ami-123', - 'KeyName': 'my-key', - 'SecurityGroups': [ - { - 'GroupId': 'sg-123' - }, - ], - 'UserData': 'some user data', - 'InstanceType': 't2.large', - 'Monitoring': { - 'Enabled': True - }, - 'SubnetId': subnet_id, - 'IamInstanceProfile': { - 'Arn': 'arn:aws:iam::123456789012:role/fleet' - }, - 'EbsOptimized': False, - 'WeightedCapacity': 4.0, - 'SpotPrice': '10.00', - }], - 'AllocationStrategy': allocation_strategy, - 'FulfilledCapacity': 6, + ], + "AllocationStrategy": allocation_strategy, + "FulfilledCapacity": 6, } @mock_ec2 def test_create_spot_fleet_with_lowest_price(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] len(spot_fleet_requests).should.equal(1) spot_fleet_request = spot_fleet_requests[0] - spot_fleet_request['SpotFleetRequestState'].should.equal("active") - spot_fleet_config = spot_fleet_request['SpotFleetRequestConfig'] + spot_fleet_request["SpotFleetRequestState"].should.equal("active") + spot_fleet_config = spot_fleet_request["SpotFleetRequestConfig"] - spot_fleet_config['SpotPrice'].should.equal('0.12') - spot_fleet_config['TargetCapacity'].should.equal(6) - spot_fleet_config['IamFleetRole'].should.equal( - 'arn:aws:iam::123456789012:role/fleet') - spot_fleet_config['AllocationStrategy'].should.equal('lowestPrice') - spot_fleet_config['FulfilledCapacity'].should.equal(6.0) + spot_fleet_config["SpotPrice"].should.equal("0.12") + spot_fleet_config["TargetCapacity"].should.equal(6) + spot_fleet_config["IamFleetRole"].should.equal( + "arn:aws:iam::123456789012:role/fleet" + ) + spot_fleet_config["AllocationStrategy"].should.equal("lowestPrice") + spot_fleet_config["FulfilledCapacity"].should.equal(6.0) - len(spot_fleet_config['LaunchSpecifications']).should.equal(2) - launch_spec = spot_fleet_config['LaunchSpecifications'][0] + len(spot_fleet_config["LaunchSpecifications"]).should.equal(2) + launch_spec = spot_fleet_config["LaunchSpecifications"][0] - launch_spec['EbsOptimized'].should.equal(False) - launch_spec['SecurityGroups'].should.equal([{"GroupId": "sg-123"}]) - launch_spec['IamInstanceProfile'].should.equal( - {"Arn": "arn:aws:iam::123456789012:role/fleet"}) - launch_spec['ImageId'].should.equal("ami-123") - launch_spec['InstanceType'].should.equal("t2.small") - launch_spec['KeyName'].should.equal("my-key") - launch_spec['Monitoring'].should.equal({"Enabled": True}) - launch_spec['SpotPrice'].should.equal("0.13") - launch_spec['SubnetId'].should.equal(subnet_id) - launch_spec['UserData'].should.equal("some user data") - launch_spec['WeightedCapacity'].should.equal(2.0) + launch_spec["EbsOptimized"].should.equal(False) + launch_spec["SecurityGroups"].should.equal([{"GroupId": "sg-123"}]) + launch_spec["IamInstanceProfile"].should.equal( + {"Arn": "arn:aws:iam::123456789012:role/fleet"} + ) + launch_spec["ImageId"].should.equal("ami-123") + launch_spec["InstanceType"].should.equal("t2.small") + launch_spec["KeyName"].should.equal("my-key") + launch_spec["Monitoring"].should.equal({"Enabled": True}) + launch_spec["SpotPrice"].should.equal("0.13") + launch_spec["SubnetId"].should.equal(subnet_id) + launch_spec["UserData"].should.equal("some user data") + launch_spec["WeightedCapacity"].should.equal(2.0) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(3) @mock_ec2 def test_create_diversified_spot_fleet(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) - diversified_config = spot_config( - subnet_id, allocation_strategy='diversified') + diversified_config = spot_config(subnet_id, allocation_strategy="diversified") - spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=diversified_config - ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_res = conn.request_spot_fleet(SpotFleetRequestConfig=diversified_config) + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(2) - instance_types = set([instance['InstanceType'] for instance in instances]) + instance_types = set([instance["InstanceType"] for instance in instances]) instance_types.should.equal(set(["t2.small", "t2.large"])) - instances[0]['InstanceId'].should.contain("i-") + instances[0]["InstanceId"].should.contain("i-") @mock_ec2 def test_create_spot_fleet_request_with_tag_spec(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) tag_spec = [ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'tag-1', - 'Value': 'foo', - }, - { - 'Key': 'tag-2', - 'Value': 'bar', - }, - ] - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "tag-1", "Value": "foo"}, + {"Key": "tag-2", "Value": "bar"}, + ], + } ] config = spot_config(subnet_id) - config['LaunchSpecifications'][0]['TagSpecifications'] = tag_spec - spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=config - ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + config["LaunchSpecifications"][0]["TagSpecifications"] = tag_spec + spot_fleet_res = conn.request_spot_fleet(SpotFleetRequestConfig=config) + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] - spot_fleet_config = spot_fleet_requests[0]['SpotFleetRequestConfig'] - spot_fleet_config['LaunchSpecifications'][0]['TagSpecifications'][0][ - 'ResourceType'].should.equal('instance') - for tag in tag_spec[0]['Tags']: - spot_fleet_config['LaunchSpecifications'][0]['TagSpecifications'][0]['Tags'].should.contain(tag) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] + spot_fleet_config = spot_fleet_requests[0]["SpotFleetRequestConfig"] + spot_fleet_config["LaunchSpecifications"][0]["TagSpecifications"][0][ + "ResourceType" + ].should.equal("instance") + for tag in tag_spec[0]["Tags"]: + spot_fleet_config["LaunchSpecifications"][0]["TagSpecifications"][0][ + "Tags" + ].should.contain(tag) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = conn.describe_instances(InstanceIds=[i['InstanceId'] for i in instance_res['ActiveInstances']]) - for instance in instances['Reservations'][0]['Instances']: - for tag in tag_spec[0]['Tags']: - instance['Tags'].should.contain(tag) + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = conn.describe_instances( + InstanceIds=[i["InstanceId"] for i in instance_res["ActiveInstances"]] + ) + for instance in instances["Reservations"][0]["Instances"]: + for tag in tag_spec[0]["Tags"]: + instance["Tags"].should.contain(tag) @mock_ec2 def test_cancel_spot_fleet_request(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] conn.cancel_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id], TerminateInstances=True) + SpotFleetRequestIds=[spot_fleet_id], TerminateInstances=True + ) spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] len(spot_fleet_requests).should.equal(0) @mock_ec2 def test_modify_spot_fleet_request_up(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=20) + conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=20) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(10) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(20) - spot_fleet_config['FulfilledCapacity'].should.equal(20.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(20) + spot_fleet_config["FulfilledCapacity"].should.equal(20.0) @mock_ec2 def test_modify_spot_fleet_request_up_diversified(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config( - subnet_id, allocation_strategy='diversified'), + SpotFleetRequestConfig=spot_config(subnet_id, allocation_strategy="diversified") ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=19) + conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=19) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(7) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(19) - spot_fleet_config['FulfilledCapacity'].should.equal(20.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(19) + spot_fleet_config["FulfilledCapacity"].should.equal(20.0) @mock_ec2 def test_modify_spot_fleet_request_down_no_terminate(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=1, ExcessCapacityTerminationPolicy="noTermination") + SpotFleetRequestId=spot_fleet_id, + TargetCapacity=1, + ExcessCapacityTerminationPolicy="noTermination", + ) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(3) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(1) - spot_fleet_config['FulfilledCapacity'].should.equal(6.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(1) + spot_fleet_config["FulfilledCapacity"].should.equal(6.0) @mock_ec2 def test_modify_spot_fleet_request_down_odd(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=7) - conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=5) + conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=7) + conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=5) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(3) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(5) - spot_fleet_config['FulfilledCapacity'].should.equal(6.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(5) + spot_fleet_config["FulfilledCapacity"].should.equal(6.0) @mock_ec2 def test_modify_spot_fleet_request_down(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=1) + conn.modify_spot_fleet_request(SpotFleetRequestId=spot_fleet_id, TargetCapacity=1) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(1) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(1) - spot_fleet_config['FulfilledCapacity'].should.equal(2.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(1) + spot_fleet_config["FulfilledCapacity"].should.equal(2.0) @mock_ec2 def test_modify_spot_fleet_request_down_no_terminate_after_custom_terminate(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) spot_fleet_res = conn.request_spot_fleet( - SpotFleetRequestConfig=spot_config(subnet_id), + SpotFleetRequestConfig=spot_config(subnet_id) ) - spot_fleet_id = spot_fleet_res['SpotFleetRequestId'] + spot_fleet_id = spot_fleet_res["SpotFleetRequestId"] - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] - conn.terminate_instances(InstanceIds=[i['InstanceId'] for i in instances[1:]]) + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] + conn.terminate_instances(InstanceIds=[i["InstanceId"] for i in instances[1:]]) conn.modify_spot_fleet_request( - SpotFleetRequestId=spot_fleet_id, TargetCapacity=1, ExcessCapacityTerminationPolicy="noTermination") + SpotFleetRequestId=spot_fleet_id, + TargetCapacity=1, + ExcessCapacityTerminationPolicy="noTermination", + ) - instance_res = conn.describe_spot_fleet_instances( - SpotFleetRequestId=spot_fleet_id) - instances = instance_res['ActiveInstances'] + instance_res = conn.describe_spot_fleet_instances(SpotFleetRequestId=spot_fleet_id) + instances = instance_res["ActiveInstances"] len(instances).should.equal(1) spot_fleet_config = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'][0]['SpotFleetRequestConfig'] - spot_fleet_config['TargetCapacity'].should.equal(1) - spot_fleet_config['FulfilledCapacity'].should.equal(2.0) + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"][0]["SpotFleetRequestConfig"] + spot_fleet_config["TargetCapacity"].should.equal(1) + spot_fleet_config["FulfilledCapacity"].should.equal(2.0) @mock_ec2 def test_create_spot_fleet_without_spot_price(): - conn = boto3.client("ec2", region_name='us-west-2') + conn = boto3.client("ec2", region_name="us-west-2") subnet_id = get_subnet_id(conn) # remove prices to force a fallback to ondemand price spot_config_without_price = spot_config(subnet_id) - del spot_config_without_price['SpotPrice'] - for spec in spot_config_without_price['LaunchSpecifications']: - del spec['SpotPrice'] + del spot_config_without_price["SpotPrice"] + for spec in spot_config_without_price["LaunchSpecifications"]: + del spec["SpotPrice"] - spot_fleet_id = conn.request_spot_fleet(SpotFleetRequestConfig=spot_config_without_price)['SpotFleetRequestId'] + spot_fleet_id = conn.request_spot_fleet( + SpotFleetRequestConfig=spot_config_without_price + )["SpotFleetRequestId"] spot_fleet_requests = conn.describe_spot_fleet_requests( - SpotFleetRequestIds=[spot_fleet_id])['SpotFleetRequestConfigs'] + SpotFleetRequestIds=[spot_fleet_id] + )["SpotFleetRequestConfigs"] len(spot_fleet_requests).should.equal(1) spot_fleet_request = spot_fleet_requests[0] - spot_fleet_config = spot_fleet_request['SpotFleetRequestConfig'] + spot_fleet_config = spot_fleet_request["SpotFleetRequestConfig"] - len(spot_fleet_config['LaunchSpecifications']).should.equal(2) - launch_spec1 = spot_fleet_config['LaunchSpecifications'][0] - launch_spec2 = spot_fleet_config['LaunchSpecifications'][1] + len(spot_fleet_config["LaunchSpecifications"]).should.equal(2) + launch_spec1 = spot_fleet_config["LaunchSpecifications"][0] + launch_spec2 = spot_fleet_config["LaunchSpecifications"][1] # AWS will figure out the price - assert 'SpotPrice' not in launch_spec1 - assert 'SpotPrice' not in launch_spec2 + assert "SpotPrice" not in launch_spec1 + assert "SpotPrice" not in launch_spec2 diff --git a/tests/test_ec2/test_spot_instances.py b/tests/test_ec2/test_spot_instances.py index ab08d392c..cfc95bb82 100644 --- a/tests/test_ec2/test_spot_instances.py +++ b/tests/test_ec2/test_spot_instances.py @@ -16,14 +16,15 @@ from moto.core.utils import iso_8601_datetime_with_milliseconds @mock_ec2 def test_request_spot_instances(): - conn = boto3.client('ec2', 'us-east-1') - vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")['Vpc'] + conn = boto3.client("ec2", "us-east-1") + vpc = conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] subnet = conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.0.0/16', AvailabilityZone='us-east-1a')['Subnet'] - subnet_id = subnet['SubnetId'] + VpcId=vpc["VpcId"], CidrBlock="10.0.0.0/16", AvailabilityZone="us-east-1a" + )["Subnet"] + subnet_id = subnet["SubnetId"] - conn.create_security_group(GroupName='group1', Description='description') - conn.create_security_group(GroupName='group2', Description='description') + conn.create_security_group(GroupName="group1", Description="description") + conn.create_security_group(GroupName="group2", Description="description") start_dt = datetime.datetime(2013, 1, 1).replace(tzinfo=pytz.utc) end_dt = datetime.datetime(2013, 1, 2).replace(tzinfo=pytz.utc) @@ -32,78 +33,79 @@ def test_request_spot_instances(): with assert_raises(ClientError) as ex: request = conn.request_spot_instances( - SpotPrice="0.5", InstanceCount=1, Type='one-time', - ValidFrom=start, ValidUntil=end, LaunchGroup="the-group", - AvailabilityZoneGroup='my-group', + SpotPrice="0.5", + InstanceCount=1, + Type="one-time", + ValidFrom=start, + ValidUntil=end, + LaunchGroup="the-group", + AvailabilityZoneGroup="my-group", LaunchSpecification={ - "ImageId": 'ami-abcd1234', + "ImageId": "ami-abcd1234", "KeyName": "test", - "SecurityGroups": ['group1', 'group2'], + "SecurityGroups": ["group1", "group2"], "UserData": "some test data", - "InstanceType": 'm1.small', - "Placement": { - "AvailabilityZone": 'us-east-1c', - }, + "InstanceType": "m1.small", + "Placement": {"AvailabilityZone": "us-east-1c"}, "KernelId": "test-kernel", "RamdiskId": "test-ramdisk", - "Monitoring": { - "Enabled": True, - }, + "Monitoring": {"Enabled": True}, "SubnetId": subnet_id, }, DryRun=True, ) - ex.exception.response['Error']['Code'].should.equal('DryRunOperation') - ex.exception.response['ResponseMetadata'][ - 'HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal( - 'An error occurred (DryRunOperation) when calling the RequestSpotInstance operation: Request would have succeeded, but DryRun flag is set') + ex.exception.response["Error"]["Code"].should.equal("DryRunOperation") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal( + "An error occurred (DryRunOperation) when calling the RequestSpotInstance operation: Request would have succeeded, but DryRun flag is set" + ) request = conn.request_spot_instances( - SpotPrice="0.5", InstanceCount=1, Type='one-time', - ValidFrom=start, ValidUntil=end, LaunchGroup="the-group", - AvailabilityZoneGroup='my-group', + SpotPrice="0.5", + InstanceCount=1, + Type="one-time", + ValidFrom=start, + ValidUntil=end, + LaunchGroup="the-group", + AvailabilityZoneGroup="my-group", LaunchSpecification={ - "ImageId": 'ami-abcd1234', + "ImageId": "ami-abcd1234", "KeyName": "test", - "SecurityGroups": ['group1', 'group2'], + "SecurityGroups": ["group1", "group2"], "UserData": "some test data", - "InstanceType": 'm1.small', - "Placement": { - "AvailabilityZone": 'us-east-1c', - }, + "InstanceType": "m1.small", + "Placement": {"AvailabilityZone": "us-east-1c"}, "KernelId": "test-kernel", "RamdiskId": "test-ramdisk", - "Monitoring": { - "Enabled": True, - }, + "Monitoring": {"Enabled": True}, "SubnetId": subnet_id, }, ) - requests = conn.describe_spot_instance_requests()['SpotInstanceRequests'] + requests = conn.describe_spot_instance_requests()["SpotInstanceRequests"] requests.should.have.length_of(1) request = requests[0] - request['State'].should.equal("open") - request['SpotPrice'].should.equal("0.5") - request['Type'].should.equal('one-time') - request['ValidFrom'].should.equal(start_dt) - request['ValidUntil'].should.equal(end_dt) - request['LaunchGroup'].should.equal("the-group") - request['AvailabilityZoneGroup'].should.equal('my-group') + request["State"].should.equal("open") + request["SpotPrice"].should.equal("0.5") + request["Type"].should.equal("one-time") + request["ValidFrom"].should.equal(start_dt) + request["ValidUntil"].should.equal(end_dt) + request["LaunchGroup"].should.equal("the-group") + request["AvailabilityZoneGroup"].should.equal("my-group") - launch_spec = request['LaunchSpecification'] - security_group_names = [group['GroupName'] - for group in launch_spec['SecurityGroups']] - set(security_group_names).should.equal(set(['group1', 'group2'])) + launch_spec = request["LaunchSpecification"] + security_group_names = [ + group["GroupName"] for group in launch_spec["SecurityGroups"] + ] + set(security_group_names).should.equal(set(["group1", "group2"])) - launch_spec['ImageId'].should.equal('ami-abcd1234') - launch_spec['KeyName'].should.equal("test") - launch_spec['InstanceType'].should.equal('m1.small') - launch_spec['KernelId'].should.equal("test-kernel") - launch_spec['RamdiskId'].should.equal("test-ramdisk") - launch_spec['SubnetId'].should.equal(subnet_id) + launch_spec["ImageId"].should.equal("ami-abcd1234") + launch_spec["KeyName"].should.equal("test") + launch_spec["InstanceType"].should.equal("m1.small") + launch_spec["KernelId"].should.equal("test-kernel") + launch_spec["RamdiskId"].should.equal("test-ramdisk") + launch_spec["SubnetId"].should.equal(subnet_id) @mock_ec2 @@ -111,58 +113,55 @@ def test_request_spot_instances_default_arguments(): """ Test that moto set the correct default arguments """ - conn = boto3.client('ec2', 'us-east-1') + conn = boto3.client("ec2", "us-east-1") request = conn.request_spot_instances( - SpotPrice="0.5", - LaunchSpecification={ - "ImageId": 'ami-abcd1234', - } + SpotPrice="0.5", LaunchSpecification={"ImageId": "ami-abcd1234"} ) - requests = conn.describe_spot_instance_requests()['SpotInstanceRequests'] + requests = conn.describe_spot_instance_requests()["SpotInstanceRequests"] requests.should.have.length_of(1) request = requests[0] - request['State'].should.equal("open") - request['SpotPrice'].should.equal("0.5") - request['Type'].should.equal('one-time') - request.shouldnt.contain('ValidFrom') - request.shouldnt.contain('ValidUntil') - request.shouldnt.contain('LaunchGroup') - request.shouldnt.contain('AvailabilityZoneGroup') + request["State"].should.equal("open") + request["SpotPrice"].should.equal("0.5") + request["Type"].should.equal("one-time") + request.shouldnt.contain("ValidFrom") + request.shouldnt.contain("ValidUntil") + request.shouldnt.contain("LaunchGroup") + request.shouldnt.contain("AvailabilityZoneGroup") - launch_spec = request['LaunchSpecification'] + launch_spec = request["LaunchSpecification"] - security_group_names = [group['GroupName'] - for group in launch_spec['SecurityGroups']] + security_group_names = [ + group["GroupName"] for group in launch_spec["SecurityGroups"] + ] security_group_names.should.equal(["default"]) - launch_spec['ImageId'].should.equal('ami-abcd1234') - request.shouldnt.contain('KeyName') - launch_spec['InstanceType'].should.equal('m1.small') - request.shouldnt.contain('KernelId') - request.shouldnt.contain('RamdiskId') - request.shouldnt.contain('SubnetId') + launch_spec["ImageId"].should.equal("ami-abcd1234") + request.shouldnt.contain("KeyName") + launch_spec["InstanceType"].should.equal("m1.small") + request.shouldnt.contain("KernelId") + request.shouldnt.contain("RamdiskId") + request.shouldnt.contain("SubnetId") @mock_ec2_deprecated def test_cancel_spot_instance_request(): conn = boto.connect_ec2() - conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) + conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") requests = conn.get_all_spot_instance_requests() requests.should.have.length_of(1) with assert_raises(EC2ResponseError) as ex: conn.cancel_spot_instance_requests([requests[0].id], dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CancelSpotInstance operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CancelSpotInstance operation: Request would have succeeded, but DryRun flag is set" + ) conn.cancel_spot_instance_requests([requests[0].id]) @@ -177,9 +176,7 @@ def test_request_spot_instances_fulfilled(): """ conn = boto.ec2.connect_to_region("us-east-1") - request = conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) + request = conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") requests = conn.get_all_spot_instance_requests() requests.should.have.length_of(1) @@ -187,7 +184,7 @@ def test_request_spot_instances_fulfilled(): request.state.should.equal("open") - get_model('SpotInstanceRequest', 'us-east-1')[0].state = 'active' + get_model("SpotInstanceRequest", "us-east-1")[0].state = "active" requests = conn.get_all_spot_instance_requests() requests.should.have.length_of(1) @@ -203,18 +200,16 @@ def test_tag_spot_instance_request(): """ conn = boto.connect_ec2() - request = conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) - request[0].add_tag('tag1', 'value1') - request[0].add_tag('tag2', 'value2') + request = conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") + request[0].add_tag("tag1", "value1") + request[0].add_tag("tag2", "value2") requests = conn.get_all_spot_instance_requests() requests.should.have.length_of(1) request = requests[0] tag_dict = dict(request.tags) - tag_dict.should.equal({'tag1': 'value1', 'tag2': 'value2'}) + tag_dict.should.equal({"tag1": "value1", "tag2": "value2"}) @mock_ec2_deprecated @@ -224,45 +219,38 @@ def test_get_all_spot_instance_requests_filtering(): """ conn = boto.connect_ec2() - request1 = conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) - request2 = conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) - conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234', - ) - request1[0].add_tag('tag1', 'value1') - request1[0].add_tag('tag2', 'value2') - request2[0].add_tag('tag1', 'value1') - request2[0].add_tag('tag2', 'wrong') + request1 = conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") + request2 = conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") + conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") + request1[0].add_tag("tag1", "value1") + request1[0].add_tag("tag2", "value2") + request2[0].add_tag("tag1", "value1") + request2[0].add_tag("tag2", "wrong") - requests = conn.get_all_spot_instance_requests(filters={'state': 'active'}) + requests = conn.get_all_spot_instance_requests(filters={"state": "active"}) requests.should.have.length_of(0) - requests = conn.get_all_spot_instance_requests(filters={'state': 'open'}) + requests = conn.get_all_spot_instance_requests(filters={"state": "open"}) requests.should.have.length_of(3) - requests = conn.get_all_spot_instance_requests( - filters={'tag:tag1': 'value1'}) + requests = conn.get_all_spot_instance_requests(filters={"tag:tag1": "value1"}) requests.should.have.length_of(2) requests = conn.get_all_spot_instance_requests( - filters={'tag:tag1': 'value1', 'tag:tag2': 'value2'}) + filters={"tag:tag1": "value1", "tag:tag2": "value2"} + ) requests.should.have.length_of(1) @mock_ec2_deprecated def test_request_spot_instances_setting_instance_id(): conn = boto.ec2.connect_to_region("us-east-1") - request = conn.request_spot_instances( - price=0.5, image_id='ami-abcd1234') + request = conn.request_spot_instances(price=0.5, image_id="ami-abcd1234") - req = get_model('SpotInstanceRequest', 'us-east-1')[0] - req.state = 'active' - req.instance_id = 'i-12345678' + req = get_model("SpotInstanceRequest", "us-east-1")[0] + req.state = "active" + req.instance_id = "i-12345678" request = conn.get_all_spot_instance_requests()[0] - assert request.state == 'active' - assert request.instance_id == 'i-12345678' + assert request.state == "active" + assert request.instance_id == "i-12345678" diff --git a/tests/test_ec2/test_subnets.py b/tests/test_ec2/test_subnets.py index 38c36f682..f5f1af433 100644 --- a/tests/test_ec2/test_subnets.py +++ b/tests/test_ec2/test_subnets.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises # noqa from nose.tools import assert_raises @@ -16,8 +17,8 @@ from moto import mock_cloudformation_deprecated, mock_ec2, mock_ec2_deprecated @mock_ec2_deprecated def test_subnets(): - ec2 = boto.connect_ec2('the_key', 'the_secret') - conn = boto.connect_vpc('the_key', 'the_secret') + ec2 = boto.connect_ec2("the_key", "the_secret") + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") @@ -31,25 +32,25 @@ def test_subnets(): with assert_raises(EC2ResponseError) as cm: conn.delete_subnet(subnet.id) - cm.exception.code.should.equal('InvalidSubnetID.NotFound') + cm.exception.code.should.equal("InvalidSubnetID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_subnet_create_vpc_validation(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: conn.create_subnet("vpc-abcd1234", "10.0.0.0/18") - cm.exception.code.should.equal('InvalidVpcID.NotFound') + cm.exception.code.should.equal("InvalidVpcID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_subnet_tagging(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") subnet = conn.create_subnet(vpc.id, "10.0.0.0/18") @@ -67,31 +68,31 @@ def test_subnet_tagging(): @mock_ec2_deprecated def test_subnet_should_have_proper_availability_zone_set(): - conn = boto.vpc.connect_to_region('us-west-1') + conn = boto.vpc.connect_to_region("us-west-1") vpcA = conn.create_vpc("10.0.0.0/16") - subnetA = conn.create_subnet( - vpcA.id, "10.0.0.0/24", availability_zone='us-west-1b') - subnetA.availability_zone.should.equal('us-west-1b') + subnetA = conn.create_subnet(vpcA.id, "10.0.0.0/24", availability_zone="us-west-1b") + subnetA.availability_zone.should.equal("us-west-1b") @mock_ec2 def test_default_subnet(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") default_vpc = list(ec2.vpcs.all())[0] - default_vpc.cidr_block.should.equal('172.31.0.0/16') + default_vpc.cidr_block.should.equal("172.31.0.0/16") default_vpc.reload() default_vpc.is_default.should.be.ok subnet = ec2.create_subnet( - VpcId=default_vpc.id, CidrBlock='172.31.48.0/20', AvailabilityZone='us-west-1a') + VpcId=default_vpc.id, CidrBlock="172.31.48.0/20", AvailabilityZone="us-west-1a" + ) subnet.reload() subnet.map_public_ip_on_launch.shouldnt.be.ok @mock_ec2_deprecated def test_non_default_subnet(): - vpc_cli = boto.vpc.connect_to_region('us-west-1') + vpc_cli = boto.vpc.connect_to_region("us-west-1") # Create the non default VPC vpc = vpc_cli.create_vpc("10.0.0.0/16") @@ -99,34 +100,36 @@ def test_non_default_subnet(): subnet = vpc_cli.create_subnet(vpc.id, "10.0.0.0/24") subnet = vpc_cli.get_all_subnets(subnet_ids=[subnet.id])[0] - subnet.mapPublicIpOnLaunch.should.equal('false') + subnet.mapPublicIpOnLaunch.should.equal("false") @mock_ec2 def test_boto3_non_default_subnet(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the non default VPC - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-1a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-1a" + ) subnet.reload() subnet.map_public_ip_on_launch.shouldnt.be.ok @mock_ec2 def test_modify_subnet_attribute_public_ip_on_launch(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") # Get the default VPC vpc = list(ec2.vpcs.all())[0] subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock="172.31.48.0/20", AvailabilityZone='us-west-1a') + VpcId=vpc.id, CidrBlock="172.31.48.0/20", AvailabilityZone="us-west-1a" + ) # 'map_public_ip_on_launch' is set when calling 'DescribeSubnets' action subnet.reload() @@ -135,26 +138,29 @@ def test_modify_subnet_attribute_public_ip_on_launch(): subnet.map_public_ip_on_launch.shouldnt.be.ok client.modify_subnet_attribute( - SubnetId=subnet.id, MapPublicIpOnLaunch={'Value': False}) + SubnetId=subnet.id, MapPublicIpOnLaunch={"Value": False} + ) subnet.reload() subnet.map_public_ip_on_launch.shouldnt.be.ok client.modify_subnet_attribute( - SubnetId=subnet.id, MapPublicIpOnLaunch={'Value': True}) + SubnetId=subnet.id, MapPublicIpOnLaunch={"Value": True} + ) subnet.reload() subnet.map_public_ip_on_launch.should.be.ok @mock_ec2 def test_modify_subnet_attribute_assign_ipv6_address_on_creation(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") # Get the default VPC vpc = list(ec2.vpcs.all())[0] subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='172.31.112.0/20', AvailabilityZone='us-west-1a') + VpcId=vpc.id, CidrBlock="172.31.112.0/20", AvailabilityZone="us-west-1a" + ) # 'map_public_ip_on_launch' is set when calling 'DescribeSubnets' action subnet.reload() @@ -163,41 +169,46 @@ def test_modify_subnet_attribute_assign_ipv6_address_on_creation(): subnet.assign_ipv6_address_on_creation.shouldnt.be.ok client.modify_subnet_attribute( - SubnetId=subnet.id, AssignIpv6AddressOnCreation={'Value': False}) + SubnetId=subnet.id, AssignIpv6AddressOnCreation={"Value": False} + ) subnet.reload() subnet.assign_ipv6_address_on_creation.shouldnt.be.ok client.modify_subnet_attribute( - SubnetId=subnet.id, AssignIpv6AddressOnCreation={'Value': True}) + SubnetId=subnet.id, AssignIpv6AddressOnCreation={"Value": True} + ) subnet.reload() subnet.assign_ipv6_address_on_creation.should.be.ok @mock_ec2 def test_modify_subnet_attribute_validation(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-1a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-1a" + ) with assert_raises(ParamValidationError): client.modify_subnet_attribute( - SubnetId=subnet.id, MapPublicIpOnLaunch={'Value': 'invalid'}) + SubnetId=subnet.id, MapPublicIpOnLaunch={"Value": "invalid"} + ) @mock_ec2_deprecated def test_subnet_get_by_id(): - ec2 = boto.ec2.connect_to_region('us-west-1') - conn = boto.vpc.connect_to_region('us-west-1') + ec2 = boto.ec2.connect_to_region("us-west-1") + conn = boto.vpc.connect_to_region("us-west-1") vpcA = conn.create_vpc("10.0.0.0/16") - subnetA = conn.create_subnet( - vpcA.id, "10.0.0.0/24", availability_zone='us-west-1a') + subnetA = conn.create_subnet(vpcA.id, "10.0.0.0/24", availability_zone="us-west-1a") vpcB = conn.create_vpc("10.0.0.0/16") subnetB1 = conn.create_subnet( - vpcB.id, "10.0.0.0/24", availability_zone='us-west-1a') + vpcB.id, "10.0.0.0/24", availability_zone="us-west-1a" + ) subnetB2 = conn.create_subnet( - vpcB.id, "10.0.1.0/24", availability_zone='us-west-1b') + vpcB.id, "10.0.1.0/24", availability_zone="us-west-1b" + ) subnets_by_id = conn.get_all_subnets(subnet_ids=[subnetA.id, subnetB1.id]) subnets_by_id.should.have.length_of(2) @@ -206,85 +217,91 @@ def test_subnet_get_by_id(): subnetB1.id.should.be.within(subnets_by_id) with assert_raises(EC2ResponseError) as cm: - conn.get_all_subnets(subnet_ids=['subnet-does_not_exist']) - cm.exception.code.should.equal('InvalidSubnetID.NotFound') + conn.get_all_subnets(subnet_ids=["subnet-does_not_exist"]) + cm.exception.code.should.equal("InvalidSubnetID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_get_subnets_filtering(): - ec2 = boto.ec2.connect_to_region('us-west-1') - conn = boto.vpc.connect_to_region('us-west-1') + ec2 = boto.ec2.connect_to_region("us-west-1") + conn = boto.vpc.connect_to_region("us-west-1") vpcA = conn.create_vpc("10.0.0.0/16") - subnetA = conn.create_subnet( - vpcA.id, "10.0.0.0/24", availability_zone='us-west-1a') + subnetA = conn.create_subnet(vpcA.id, "10.0.0.0/24", availability_zone="us-west-1a") vpcB = conn.create_vpc("10.0.0.0/16") subnetB1 = conn.create_subnet( - vpcB.id, "10.0.0.0/24", availability_zone='us-west-1a') + vpcB.id, "10.0.0.0/24", availability_zone="us-west-1a" + ) subnetB2 = conn.create_subnet( - vpcB.id, "10.0.1.0/24", availability_zone='us-west-1b') + vpcB.id, "10.0.1.0/24", availability_zone="us-west-1b" + ) all_subnets = conn.get_all_subnets() all_subnets.should.have.length_of(3 + len(ec2.get_all_zones())) # Filter by VPC ID - subnets_by_vpc = conn.get_all_subnets(filters={'vpc-id': vpcB.id}) + subnets_by_vpc = conn.get_all_subnets(filters={"vpc-id": vpcB.id}) subnets_by_vpc.should.have.length_of(2) set([subnet.id for subnet in subnets_by_vpc]).should.equal( - set([subnetB1.id, subnetB2.id])) + set([subnetB1.id, subnetB2.id]) + ) # Filter by CIDR variations - subnets_by_cidr1 = conn.get_all_subnets(filters={'cidr': "10.0.0.0/24"}) + subnets_by_cidr1 = conn.get_all_subnets(filters={"cidr": "10.0.0.0/24"}) subnets_by_cidr1.should.have.length_of(2) - set([subnet.id for subnet in subnets_by_cidr1] - ).should.equal(set([subnetA.id, subnetB1.id])) + set([subnet.id for subnet in subnets_by_cidr1]).should.equal( + set([subnetA.id, subnetB1.id]) + ) - subnets_by_cidr2 = conn.get_all_subnets( - filters={'cidr-block': "10.0.0.0/24"}) + subnets_by_cidr2 = conn.get_all_subnets(filters={"cidr-block": "10.0.0.0/24"}) subnets_by_cidr2.should.have.length_of(2) - set([subnet.id for subnet in subnets_by_cidr2] - ).should.equal(set([subnetA.id, subnetB1.id])) + set([subnet.id for subnet in subnets_by_cidr2]).should.equal( + set([subnetA.id, subnetB1.id]) + ) - subnets_by_cidr3 = conn.get_all_subnets( - filters={'cidrBlock': "10.0.0.0/24"}) + subnets_by_cidr3 = conn.get_all_subnets(filters={"cidrBlock": "10.0.0.0/24"}) subnets_by_cidr3.should.have.length_of(2) - set([subnet.id for subnet in subnets_by_cidr3] - ).should.equal(set([subnetA.id, subnetB1.id])) + set([subnet.id for subnet in subnets_by_cidr3]).should.equal( + set([subnetA.id, subnetB1.id]) + ) # Filter by VPC ID and CIDR subnets_by_vpc_and_cidr = conn.get_all_subnets( - filters={'vpc-id': vpcB.id, 'cidr': "10.0.0.0/24"}) + filters={"vpc-id": vpcB.id, "cidr": "10.0.0.0/24"} + ) subnets_by_vpc_and_cidr.should.have.length_of(1) - set([subnet.id for subnet in subnets_by_vpc_and_cidr] - ).should.equal(set([subnetB1.id])) + set([subnet.id for subnet in subnets_by_vpc_and_cidr]).should.equal( + set([subnetB1.id]) + ) # Filter by subnet ID - subnets_by_id = conn.get_all_subnets(filters={'subnet-id': subnetA.id}) + subnets_by_id = conn.get_all_subnets(filters={"subnet-id": subnetA.id}) subnets_by_id.should.have.length_of(1) set([subnet.id for subnet in subnets_by_id]).should.equal(set([subnetA.id])) # Filter by availabilityZone subnets_by_az = conn.get_all_subnets( - filters={'availabilityZone': 'us-west-1a', 'vpc-id': vpcB.id}) + filters={"availabilityZone": "us-west-1a", "vpc-id": vpcB.id} + ) subnets_by_az.should.have.length_of(1) - set([subnet.id for subnet in subnets_by_az] - ).should.equal(set([subnetB1.id])) + set([subnet.id for subnet in subnets_by_az]).should.equal(set([subnetB1.id])) # Filter by defaultForAz - subnets_by_az = conn.get_all_subnets(filters={'defaultForAz': "true"}) + subnets_by_az = conn.get_all_subnets(filters={"defaultForAz": "true"}) subnets_by_az.should.have.length_of(len(conn.get_all_zones())) # Unsupported filter conn.get_all_subnets.when.called_with( - filters={'not-implemented-filter': 'foobar'}).should.throw(NotImplementedError) + filters={"not-implemented-filter": "foobar"} + ).should.throw(NotImplementedError) @mock_ec2_deprecated @mock_cloudformation_deprecated def test_subnet_tags_through_cloudformation(): - vpc_conn = boto.vpc.connect_to_region('us-west-1') + vpc_conn = boto.vpc.connect_to_region("us-west-1") vpc = vpc_conn.create_vpc("10.0.0.0/16") subnet_template = { @@ -296,151 +313,164 @@ def test_subnet_tags_through_cloudformation(): "VpcId": vpc.id, "CidrBlock": "10.0.0.0/24", "AvailabilityZone": "us-west-1b", - "Tags": [{ - "Key": "foo", - "Value": "bar", - }, { - "Key": "blah", - "Value": "baz", - }] - } + "Tags": [ + {"Key": "foo", "Value": "bar"}, + {"Key": "blah", "Value": "baz"}, + ], + }, } - } + }, } cf_conn = boto.cloudformation.connect_to_region("us-west-1") template_json = json.dumps(subnet_template) - cf_conn.create_stack( - "test_stack", - template_body=template_json, - ) + cf_conn.create_stack("test_stack", template_body=template_json) - subnet = vpc_conn.get_all_subnets(filters={'cidrBlock': '10.0.0.0/24'})[0] + subnet = vpc_conn.get_all_subnets(filters={"cidrBlock": "10.0.0.0/24"})[0] subnet.tags["foo"].should.equal("bar") subnet.tags["blah"].should.equal("baz") @mock_ec2 def test_create_subnet_response_fields(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = client.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-1a')['Subnet'] + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-1a" + )["Subnet"] - subnet.should.have.key('AvailabilityZone') - subnet.should.have.key('AvailabilityZoneId') - subnet.should.have.key('AvailableIpAddressCount') - subnet.should.have.key('CidrBlock') - subnet.should.have.key('State') - subnet.should.have.key('SubnetId') - subnet.should.have.key('VpcId') - subnet.shouldnt.have.key('Tags') - subnet.should.have.key('DefaultForAz').which.should.equal(False) - subnet.should.have.key('MapPublicIpOnLaunch').which.should.equal(False) - subnet.should.have.key('OwnerId') - subnet.should.have.key('AssignIpv6AddressOnCreation').which.should.equal(False) + subnet.should.have.key("AvailabilityZone") + subnet.should.have.key("AvailabilityZoneId") + subnet.should.have.key("AvailableIpAddressCount") + subnet.should.have.key("CidrBlock") + subnet.should.have.key("State") + subnet.should.have.key("SubnetId") + subnet.should.have.key("VpcId") + subnet.shouldnt.have.key("Tags") + subnet.should.have.key("DefaultForAz").which.should.equal(False) + subnet.should.have.key("MapPublicIpOnLaunch").which.should.equal(False) + subnet.should.have.key("OwnerId") + subnet.should.have.key("AssignIpv6AddressOnCreation").which.should.equal(False) - subnet_arn = "arn:aws:ec2:{region}:{owner_id}:subnet/{subnet_id}".format(region=subnet['AvailabilityZone'][0:-1], - owner_id=subnet['OwnerId'], - subnet_id=subnet['SubnetId']) - subnet.should.have.key('SubnetArn').which.should.equal(subnet_arn) - subnet.should.have.key('Ipv6CidrBlockAssociationSet').which.should.equal([]) + subnet_arn = "arn:aws:ec2:{region}:{owner_id}:subnet/{subnet_id}".format( + region=subnet["AvailabilityZone"][0:-1], + owner_id=subnet["OwnerId"], + subnet_id=subnet["SubnetId"], + ) + subnet.should.have.key("SubnetArn").which.should.equal(subnet_arn) + subnet.should.have.key("Ipv6CidrBlockAssociationSet").which.should.equal([]) @mock_ec2 def test_describe_subnet_response_fields(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet_object = ec2.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone='us-west-1a') + VpcId=vpc.id, CidrBlock="10.0.0.0/24", AvailabilityZone="us-west-1a" + ) - subnets = client.describe_subnets(SubnetIds=[subnet_object.id])['Subnets'] + subnets = client.describe_subnets(SubnetIds=[subnet_object.id])["Subnets"] subnets.should.have.length_of(1) subnet = subnets[0] - subnet.should.have.key('AvailabilityZone') - subnet.should.have.key('AvailabilityZoneId') - subnet.should.have.key('AvailableIpAddressCount') - subnet.should.have.key('CidrBlock') - subnet.should.have.key('State') - subnet.should.have.key('SubnetId') - subnet.should.have.key('VpcId') - subnet.shouldnt.have.key('Tags') - subnet.should.have.key('DefaultForAz').which.should.equal(False) - subnet.should.have.key('MapPublicIpOnLaunch').which.should.equal(False) - subnet.should.have.key('OwnerId') - subnet.should.have.key('AssignIpv6AddressOnCreation').which.should.equal(False) + subnet.should.have.key("AvailabilityZone") + subnet.should.have.key("AvailabilityZoneId") + subnet.should.have.key("AvailableIpAddressCount") + subnet.should.have.key("CidrBlock") + subnet.should.have.key("State") + subnet.should.have.key("SubnetId") + subnet.should.have.key("VpcId") + subnet.shouldnt.have.key("Tags") + subnet.should.have.key("DefaultForAz").which.should.equal(False) + subnet.should.have.key("MapPublicIpOnLaunch").which.should.equal(False) + subnet.should.have.key("OwnerId") + subnet.should.have.key("AssignIpv6AddressOnCreation").which.should.equal(False) - subnet_arn = "arn:aws:ec2:{region}:{owner_id}:subnet/{subnet_id}".format(region=subnet['AvailabilityZone'][0:-1], - owner_id=subnet['OwnerId'], - subnet_id=subnet['SubnetId']) - subnet.should.have.key('SubnetArn').which.should.equal(subnet_arn) - subnet.should.have.key('Ipv6CidrBlockAssociationSet').which.should.equal([]) + subnet_arn = "arn:aws:ec2:{region}:{owner_id}:subnet/{subnet_id}".format( + region=subnet["AvailabilityZone"][0:-1], + owner_id=subnet["OwnerId"], + subnet_id=subnet["SubnetId"], + ) + subnet.should.have.key("SubnetArn").which.should.equal(subnet_arn) + subnet.should.have.key("Ipv6CidrBlockAssociationSet").which.should.equal([]) @mock_ec2 def test_create_subnet_with_invalid_availability_zone(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - client = boto3.client('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") + client = boto3.client("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") - subnet_availability_zone = 'asfasfas' + subnet_availability_zone = "asfasfas" with assert_raises(ClientError) as ex: subnet = client.create_subnet( - VpcId=vpc.id, CidrBlock='10.0.0.0/24', AvailabilityZone=subnet_availability_zone) + VpcId=vpc.id, + CidrBlock="10.0.0.0/24", + AvailabilityZone=subnet_availability_zone, + ) assert str(ex.exception).startswith( "An error occurred (InvalidParameterValue) when calling the CreateSubnet " - "operation: Value ({}) for parameter availabilityZone is invalid. Subnets can currently only be created in the following availability zones: ".format(subnet_availability_zone)) + "operation: Value ({}) for parameter availabilityZone is invalid. Subnets can currently only be created in the following availability zones: ".format( + subnet_availability_zone + ) + ) @mock_ec2 def test_create_subnet_with_invalid_cidr_range(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok - subnet_cidr_block = '10.1.0.0/20' + subnet_cidr_block = "10.1.0.0/20" with assert_raises(ClientError) as ex: subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock=subnet_cidr_block) str(ex.exception).should.equal( "An error occurred (InvalidSubnet.Range) when calling the CreateSubnet " - "operation: The CIDR '{}' is invalid.".format(subnet_cidr_block)) + "operation: The CIDR '{}' is invalid.".format(subnet_cidr_block) + ) @mock_ec2 def test_create_subnet_with_invalid_cidr_block_parameter(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok - subnet_cidr_block = '1000.1.0.0/20' + subnet_cidr_block = "1000.1.0.0/20" with assert_raises(ClientError) as ex: subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock=subnet_cidr_block) str(ex.exception).should.equal( "An error occurred (InvalidParameterValue) when calling the CreateSubnet " - "operation: Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format(subnet_cidr_block)) + "operation: Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format( + subnet_cidr_block + ) + ) @mock_ec2 def test_create_subnets_with_overlapping_cidr_blocks(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok - subnet_cidr_block = '10.0.0.0/24' + subnet_cidr_block = "10.0.0.0/24" with assert_raises(ClientError) as ex: subnet1 = ec2.create_subnet(VpcId=vpc.id, CidrBlock=subnet_cidr_block) subnet2 = ec2.create_subnet(VpcId=vpc.id, CidrBlock=subnet_cidr_block) str(ex.exception).should.equal( "An error occurred (InvalidSubnet.Conflict) when calling the CreateSubnet " - "operation: The CIDR '{}' conflicts with another subnet".format(subnet_cidr_block)) + "operation: The CIDR '{}' conflicts with another subnet".format( + subnet_cidr_block + ) + ) diff --git a/tests/test_ec2/test_tags.py b/tests/test_ec2/test_tags.py index 2294979ba..29d2cb1e3 100644 --- a/tests/test_ec2/test_tags.py +++ b/tests/test_ec2/test_tags.py @@ -16,21 +16,23 @@ from nose.tools import assert_raises @mock_ec2_deprecated def test_add_tag(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as ex: instance.add_tag("a key", "some value", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) instance.add_tag("a key", "some value") chain = itertools.chain.from_iterable existing_instances = list( - chain([res.instances for res in conn.get_all_instances()])) + chain([res.instances for res in conn.get_all_instances()]) + ) existing_instances.should.have.length_of(1) existing_instance = existing_instances[0] existing_instance.tags["a key"].should.equal("some value") @@ -38,8 +40,8 @@ def test_add_tag(): @mock_ec2_deprecated def test_remove_tag(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("a key", "some value") @@ -51,10 +53,11 @@ def test_remove_tag(): with assert_raises(EC2ResponseError) as ex: instance.remove_tag("a key", dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the DeleteTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the DeleteTags operation: Request would have succeeded, but DryRun flag is set" + ) instance.remove_tag("a key") conn.get_all_tags().should.have.length_of(0) @@ -66,8 +69,8 @@ def test_remove_tag(): @mock_ec2_deprecated def test_get_all_tags(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("a key", "some value") @@ -80,8 +83,8 @@ def test_get_all_tags(): @mock_ec2_deprecated def test_get_all_tags_with_special_characters(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("a key", "some<> value") @@ -94,47 +97,50 @@ def test_get_all_tags_with_special_characters(): @mock_ec2_deprecated def test_create_tags(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] - tag_dict = {'a key': 'some value', - 'another key': 'some other value', - 'blank key': ''} + tag_dict = { + "a key": "some value", + "another key": "some other value", + "blank key": "", + } with assert_raises(EC2ResponseError) as ex: conn.create_tags(instance.id, tag_dict, dry_run=True) - ex.exception.error_code.should.equal('DryRunOperation') + ex.exception.error_code.should.equal("DryRunOperation") ex.exception.status.should.equal(400) ex.exception.message.should.equal( - 'An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set') + "An error occurred (DryRunOperation) when calling the CreateTags operation: Request would have succeeded, but DryRun flag is set" + ) conn.create_tags(instance.id, tag_dict) tags = conn.get_all_tags() - set([key for key in tag_dict]).should.equal( - set([tag.name for tag in tags])) + set([key for key in tag_dict]).should.equal(set([tag.name for tag in tags])) set([tag_dict[key] for key in tag_dict]).should.equal( - set([tag.value for tag in tags])) + set([tag.value for tag in tags]) + ) @mock_ec2_deprecated def test_tag_limit_exceeded(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] tag_dict = {} for i in range(51): - tag_dict['{0:02d}'.format(i + 1)] = '' + tag_dict["{0:02d}".format(i + 1)] = "" with assert_raises(EC2ResponseError) as cm: conn.create_tags(instance.id, tag_dict) - cm.exception.code.should.equal('TagLimitExceeded') + cm.exception.code.should.equal("TagLimitExceeded") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none instance.add_tag("a key", "a value") with assert_raises(EC2ResponseError) as cm: conn.create_tags(instance.id, tag_dict) - cm.exception.code.should.equal('TagLimitExceeded') + cm.exception.code.should.equal("TagLimitExceeded") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -147,158 +153,158 @@ def test_tag_limit_exceeded(): @mock_ec2_deprecated def test_invalid_parameter_tag_null(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] with assert_raises(EC2ResponseError) as cm: instance.add_tag("a key", None) - cm.exception.code.should.equal('InvalidParameterValue') + cm.exception.code.should.equal("InvalidParameterValue") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_invalid_id(): - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") with assert_raises(EC2ResponseError) as cm: - conn.create_tags('ami-blah', {'key': 'tag'}) - cm.exception.code.should.equal('InvalidID') + conn.create_tags("ami-blah", {"key": "tag"}) + cm.exception.code.should.equal("InvalidID") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none with assert_raises(EC2ResponseError) as cm: - conn.create_tags('blah-blah', {'key': 'tag'}) - cm.exception.code.should.equal('InvalidID') + conn.create_tags("blah-blah", {"key": "tag"}) + cm.exception.code.should.equal("InvalidID") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_get_all_tags_resource_id_filter(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("an instance key", "some value") image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) image.add_tag("an image key", "some value") - tags = conn.get_all_tags(filters={'resource-id': instance.id}) + tags = conn.get_all_tags(filters={"resource-id": instance.id}) tag = tags[0] tags.should.have.length_of(1) tag.res_id.should.equal(instance.id) - tag.res_type.should.equal('instance') + tag.res_type.should.equal("instance") tag.name.should.equal("an instance key") tag.value.should.equal("some value") - tags = conn.get_all_tags(filters={'resource-id': image_id}) + tags = conn.get_all_tags(filters={"resource-id": image_id}) tag = tags[0] tags.should.have.length_of(1) tag.res_id.should.equal(image_id) - tag.res_type.should.equal('image') + tag.res_type.should.equal("image") tag.name.should.equal("an image key") tag.value.should.equal("some value") @mock_ec2_deprecated def test_get_all_tags_resource_type_filter(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("an instance key", "some value") image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) image.add_tag("an image key", "some value") - tags = conn.get_all_tags(filters={'resource-type': 'instance'}) + tags = conn.get_all_tags(filters={"resource-type": "instance"}) tag = tags[0] tags.should.have.length_of(1) tag.res_id.should.equal(instance.id) - tag.res_type.should.equal('instance') + tag.res_type.should.equal("instance") tag.name.should.equal("an instance key") tag.value.should.equal("some value") - tags = conn.get_all_tags(filters={'resource-type': 'image'}) + tags = conn.get_all_tags(filters={"resource-type": "image"}) tag = tags[0] tags.should.have.length_of(1) tag.res_id.should.equal(image_id) - tag.res_type.should.equal('image') + tag.res_type.should.equal("image") tag.name.should.equal("an image key") tag.value.should.equal("some value") @mock_ec2_deprecated def test_get_all_tags_key_filter(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("an instance key", "some value") image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) image.add_tag("an image key", "some value") - tags = conn.get_all_tags(filters={'key': 'an instance key'}) + tags = conn.get_all_tags(filters={"key": "an instance key"}) tag = tags[0] tags.should.have.length_of(1) tag.res_id.should.equal(instance.id) - tag.res_type.should.equal('instance') + tag.res_type.should.equal("instance") tag.name.should.equal("an instance key") tag.value.should.equal("some value") @mock_ec2_deprecated def test_get_all_tags_value_filter(): - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") instance = reservation.instances[0] instance.add_tag("an instance key", "some value") - reservation_b = conn.run_instances('ami-1234abcd') + reservation_b = conn.run_instances("ami-1234abcd") instance_b = reservation_b.instances[0] instance_b.add_tag("an instance key", "some other value") - reservation_c = conn.run_instances('ami-1234abcd') + reservation_c = conn.run_instances("ami-1234abcd") instance_c = reservation_c.instances[0] instance_c.add_tag("an instance key", "other value*") - reservation_d = conn.run_instances('ami-1234abcd') + reservation_d = conn.run_instances("ami-1234abcd") instance_d = reservation_d.instances[0] instance_d.add_tag("an instance key", "other value**") - reservation_e = conn.run_instances('ami-1234abcd') + reservation_e = conn.run_instances("ami-1234abcd") instance_e = reservation_e.instances[0] instance_e.add_tag("an instance key", "other value*?") image_id = conn.create_image(instance.id, "test-ami", "this is a test ami") image = conn.get_image(image_id) image.add_tag("an image key", "some value") - tags = conn.get_all_tags(filters={'value': 'some value'}) + tags = conn.get_all_tags(filters={"value": "some value"}) tags.should.have.length_of(2) - tags = conn.get_all_tags(filters={'value': 'some*value'}) + tags = conn.get_all_tags(filters={"value": "some*value"}) tags.should.have.length_of(3) - tags = conn.get_all_tags(filters={'value': '*some*value'}) + tags = conn.get_all_tags(filters={"value": "*some*value"}) tags.should.have.length_of(3) - tags = conn.get_all_tags(filters={'value': '*some*value*'}) + tags = conn.get_all_tags(filters={"value": "*some*value*"}) tags.should.have.length_of(3) - tags = conn.get_all_tags(filters={'value': '*value\*'}) + tags = conn.get_all_tags(filters={"value": "*value\*"}) tags.should.have.length_of(1) - tags = conn.get_all_tags(filters={'value': '*value\*\*'}) + tags = conn.get_all_tags(filters={"value": "*value\*\*"}) tags.should.have.length_of(1) - tags = conn.get_all_tags(filters={'value': '*value\*\?'}) + tags = conn.get_all_tags(filters={"value": "*value\*\?"}) tags.should.have.length_of(1) @mock_ec2_deprecated def test_retrieved_instances_must_contain_their_tags(): - tag_key = 'Tag name' - tag_value = 'Tag value' + tag_key = "Tag name" + tag_value = "Tag value" tags_to_be_set = {tag_key: tag_value} - conn = boto.connect_ec2('the_key', 'the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2("the_key", "the_secret") + reservation = conn.run_instances("ami-1234abcd") reservation.should.be.a(Reservation) reservation.instances.should.have.length_of(1) instance = reservation.instances[0] @@ -324,10 +330,10 @@ def test_retrieved_instances_must_contain_their_tags(): @mock_ec2_deprecated def test_retrieved_volumes_must_contain_their_tags(): - tag_key = 'Tag name' - tag_value = 'Tag value' + tag_key = "Tag name" + tag_value = "Tag value" tags_to_be_set = {tag_key: tag_value} - conn = boto.connect_ec2('the_key', 'the_secret') + conn = boto.connect_ec2("the_key", "the_secret") volume = conn.create_volume(80, "us-east-1a") all_volumes = conn.get_all_volumes() @@ -347,11 +353,12 @@ def test_retrieved_volumes_must_contain_their_tags(): @mock_ec2_deprecated def test_retrieved_snapshots_must_contain_their_tags(): - tag_key = 'Tag name' - tag_value = 'Tag value' + tag_key = "Tag name" + tag_value = "Tag value" tags_to_be_set = {tag_key: tag_value} - conn = boto.connect_ec2(aws_access_key_id='the_key', - aws_secret_access_key='the_secret') + conn = boto.connect_ec2( + aws_access_key_id="the_key", aws_secret_access_key="the_secret" + ) volume = conn.create_volume(80, "eu-west-1a") snapshot = conn.create_snapshot(volume.id) conn.create_tags([snapshot.id], tags_to_be_set) @@ -370,113 +377,94 @@ def test_retrieved_snapshots_must_contain_their_tags(): @mock_ec2_deprecated def test_filter_instances_by_wildcard_tags(): - conn = boto.connect_ec2(aws_access_key_id='the_key', - aws_secret_access_key='the_secret') - reservation = conn.run_instances('ami-1234abcd') + conn = boto.connect_ec2( + aws_access_key_id="the_key", aws_secret_access_key="the_secret" + ) + reservation = conn.run_instances("ami-1234abcd") instance_a = reservation.instances[0] instance_a.add_tag("Key1", "Value1") - reservation_b = conn.run_instances('ami-1234abcd') + reservation_b = conn.run_instances("ami-1234abcd") instance_b = reservation_b.instances[0] instance_b.add_tag("Key1", "Value2") - reservations = conn.get_all_instances(filters={'tag:Key1': 'Value*'}) + reservations = conn.get_all_instances(filters={"tag:Key1": "Value*"}) reservations.should.have.length_of(2) - reservations = conn.get_all_instances(filters={'tag-key': 'Key*'}) + reservations = conn.get_all_instances(filters={"tag-key": "Key*"}) reservations.should.have.length_of(2) - reservations = conn.get_all_instances(filters={'tag-value': 'Value*'}) + reservations = conn.get_all_instances(filters={"tag-value": "Value*"}) reservations.should.have.length_of(2) @mock_ec2 def test_create_volume_with_tags(): - client = boto3.client('ec2', 'us-west-2') + client = boto3.client("ec2", "us-west-2") response = client.create_volume( - AvailabilityZone='us-west-2', - Encrypted=False, - Size=40, - TagSpecifications=[ - { - 'ResourceType': 'volume', - 'Tags': [ - { - 'Key': 'TEST_TAG', - 'Value': 'TEST_VALUE' - } - ], - } - ] - ) - - assert response['Tags'][0]['Key'] == 'TEST_TAG' - - -@mock_ec2 -def test_create_snapshot_with_tags(): - client = boto3.client('ec2', 'us-west-2') - volume_id = client.create_volume( - AvailabilityZone='us-west-2', + AvailabilityZone="us-west-2", Encrypted=False, Size=40, TagSpecifications=[ { - 'ResourceType': 'volume', - 'Tags': [ - { - 'Key': 'TEST_TAG', - 'Value': 'TEST_VALUE' - } - ], + "ResourceType": "volume", + "Tags": [{"Key": "TEST_TAG", "Value": "TEST_VALUE"}], } - ] - )['VolumeId'] + ], + ) + + assert response["Tags"][0]["Key"] == "TEST_TAG" + + +@mock_ec2 +def test_create_snapshot_with_tags(): + client = boto3.client("ec2", "us-west-2") + volume_id = client.create_volume( + AvailabilityZone="us-west-2", + Encrypted=False, + Size=40, + TagSpecifications=[ + { + "ResourceType": "volume", + "Tags": [{"Key": "TEST_TAG", "Value": "TEST_VALUE"}], + } + ], + )["VolumeId"] snapshot = client.create_snapshot( VolumeId=volume_id, TagSpecifications=[ { - 'ResourceType': 'snapshot', - 'Tags': [ - { - 'Key': 'TEST_SNAPSHOT_TAG', - 'Value': 'TEST_SNAPSHOT_VALUE' - } - ], + "ResourceType": "snapshot", + "Tags": [{"Key": "TEST_SNAPSHOT_TAG", "Value": "TEST_SNAPSHOT_VALUE"}], } - ] + ], ) - expected_tags = [{ - 'Key': 'TEST_SNAPSHOT_TAG', - 'Value': 'TEST_SNAPSHOT_VALUE' - }] + expected_tags = [{"Key": "TEST_SNAPSHOT_TAG", "Value": "TEST_SNAPSHOT_VALUE"}] - assert snapshot['Tags'] == expected_tags + assert snapshot["Tags"] == expected_tags @mock_ec2 def test_create_tag_empty_resource(): # create ec2 client in us-west-1 - client = boto3.client('ec2', region_name='us-west-1') + client = boto3.client("ec2", region_name="us-west-1") # create tag with empty resource with assert_raises(ClientError) as ex: - client.create_tags( - Resources=[], - Tags=[{'Key': 'Value'}] - ) - ex.exception.response['Error']['Code'].should.equal('MissingParameter') - ex.exception.response['Error']['Message'].should.equal('The request must contain the parameter resourceIdSet') + client.create_tags(Resources=[], Tags=[{"Key": "Value"}]) + ex.exception.response["Error"]["Code"].should.equal("MissingParameter") + ex.exception.response["Error"]["Message"].should.equal( + "The request must contain the parameter resourceIdSet" + ) @mock_ec2 def test_delete_tag_empty_resource(): # create ec2 client in us-west-1 - client = boto3.client('ec2', region_name='us-west-1') + client = boto3.client("ec2", region_name="us-west-1") # delete tag with empty resource with assert_raises(ClientError) as ex: - client.delete_tags( - Resources=[], - Tags=[{'Key': 'Value'}] - ) - ex.exception.response['Error']['Code'].should.equal('MissingParameter') - ex.exception.response['Error']['Message'].should.equal('The request must contain the parameter resourceIdSet') + client.delete_tags(Resources=[], Tags=[{"Key": "Value"}]) + ex.exception.response["Error"]["Code"].should.equal("MissingParameter") + ex.exception.response["Error"]["Message"].should.equal( + "The request must contain the parameter resourceIdSet" + ) diff --git a/tests/test_ec2/test_utils.py b/tests/test_ec2/test_utils.py index 49192dc79..75e3953bf 100644 --- a/tests/test_ec2/test_utils.py +++ b/tests/test_ec2/test_utils.py @@ -5,8 +5,8 @@ from .helpers import rsa_check_private_key def test_random_key_pair(): key_pair = utils.random_key_pair() - rsa_check_private_key(key_pair['material']) + rsa_check_private_key(key_pair["material"]) # AWS uses MD5 fingerprints, which are 47 characters long, *not* SHA1 # fingerprints with 59 characters. - assert len(key_pair['fingerprint']) == 47 + assert len(key_pair["fingerprint"]) == 47 diff --git a/tests/test_ec2/test_virtual_private_gateways.py b/tests/test_ec2/test_virtual_private_gateways.py index d90e97b45..f778ac3e5 100644 --- a/tests/test_ec2/test_virtual_private_gateways.py +++ b/tests/test_ec2/test_virtual_private_gateways.py @@ -7,54 +7,51 @@ from moto import mock_ec2_deprecated @mock_ec2_deprecated def test_virtual_private_gateways(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vpn_gateway.should_not.be.none - vpn_gateway.id.should.match(r'vgw-\w+') - vpn_gateway.type.should.equal('ipsec.1') - vpn_gateway.state.should.equal('available') - vpn_gateway.availability_zone.should.equal('us-east-1a') + vpn_gateway.id.should.match(r"vgw-\w+") + vpn_gateway.type.should.equal("ipsec.1") + vpn_gateway.state.should.equal("available") + vpn_gateway.availability_zone.should.equal("us-east-1a") @mock_ec2_deprecated def test_describe_vpn_gateway(): - conn = boto.connect_vpc('the_key', 'the_secret') - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + conn = boto.connect_vpc("the_key", "the_secret") + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vgws = conn.get_all_vpn_gateways() vgws.should.have.length_of(1) gateway = vgws[0] - gateway.id.should.match(r'vgw-\w+') + gateway.id.should.match(r"vgw-\w+") gateway.id.should.equal(vpn_gateway.id) - vpn_gateway.type.should.equal('ipsec.1') - vpn_gateway.state.should.equal('available') - vpn_gateway.availability_zone.should.equal('us-east-1a') + vpn_gateway.type.should.equal("ipsec.1") + vpn_gateway.state.should.equal("available") + vpn_gateway.availability_zone.should.equal("us-east-1a") @mock_ec2_deprecated def test_vpn_gateway_vpc_attachment(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") - conn.attach_vpn_gateway( - vpn_gateway_id=vpn_gateway.id, - vpc_id=vpc.id - ) + conn.attach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id) gateway = conn.get_all_vpn_gateways()[0] attachments = gateway.attachments attachments.should.have.length_of(1) attachments[0].vpc_id.should.equal(vpc.id) - attachments[0].state.should.equal('attached') + attachments[0].state.should.equal("attached") @mock_ec2_deprecated def test_delete_vpn_gateway(): - conn = boto.connect_vpc('the_key', 'the_secret') - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + conn = boto.connect_vpc("the_key", "the_secret") + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") conn.delete_vpn_gateway(vpn_gateway.id) vgws = conn.get_all_vpn_gateways() @@ -63,8 +60,8 @@ def test_delete_vpn_gateway(): @mock_ec2_deprecated def test_vpn_gateway_tagging(): - conn = boto.connect_vpc('the_key', 'the_secret') - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + conn = boto.connect_vpc("the_key", "the_secret") + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") vpn_gateway.add_tag("a key", "some value") tag = conn.get_all_tags()[0] @@ -80,25 +77,19 @@ def test_vpn_gateway_tagging(): @mock_ec2_deprecated def test_detach_vpn_gateway(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - vpn_gateway = conn.create_vpn_gateway('ipsec.1', 'us-east-1a') + vpn_gateway = conn.create_vpn_gateway("ipsec.1", "us-east-1a") - conn.attach_vpn_gateway( - vpn_gateway_id=vpn_gateway.id, - vpc_id=vpc.id - ) + conn.attach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id) gateway = conn.get_all_vpn_gateways()[0] attachments = gateway.attachments attachments.should.have.length_of(1) attachments[0].vpc_id.should.equal(vpc.id) - attachments[0].state.should.equal('attached') + attachments[0].state.should.equal("attached") - conn.detach_vpn_gateway( - vpn_gateway_id=vpn_gateway.id, - vpc_id=vpc.id - ) + conn.detach_vpn_gateway(vpn_gateway_id=vpn_gateway.id, vpc_id=vpc.id) gateway = conn.get_all_vpn_gateways()[0] attachments = gateway.attachments diff --git a/tests/test_ec2/test_vpc_peering.py b/tests/test_ec2/test_vpc_peering.py index edfbfb3c2..fc1646961 100644 --- a/tests/test_ec2/test_vpc_peering.py +++ b/tests/test_ec2/test_vpc_peering.py @@ -1,4 +1,5 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises from nose.tools import assert_raises @@ -17,12 +18,12 @@ from tests.helpers import requires_boto_gte @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_vpc_peering_connections(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") peer_vpc = conn.create_vpc("11.0.0.0/16") vpc_pcx = conn.create_vpc_peering_connection(vpc.id, peer_vpc.id) - vpc_pcx._status.code.should.equal('initiating-request') + vpc_pcx._status.code.should.equal("initiating-request") return vpc_pcx @@ -30,39 +31,39 @@ def test_vpc_peering_connections(): @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_vpc_peering_connections_get_all(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc_pcx = test_vpc_peering_connections() - vpc_pcx._status.code.should.equal('initiating-request') + vpc_pcx._status.code.should.equal("initiating-request") all_vpc_pcxs = conn.get_all_vpc_peering_connections() all_vpc_pcxs.should.have.length_of(1) - all_vpc_pcxs[0]._status.code.should.equal('pending-acceptance') + all_vpc_pcxs[0]._status.code.should.equal("pending-acceptance") @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_vpc_peering_connections_accept(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc_pcx = test_vpc_peering_connections() vpc_pcx = conn.accept_vpc_peering_connection(vpc_pcx.id) - vpc_pcx._status.code.should.equal('active') + vpc_pcx._status.code.should.equal("active") with assert_raises(EC2ResponseError) as cm: conn.reject_vpc_peering_connection(vpc_pcx.id) - cm.exception.code.should.equal('InvalidStateTransition') + cm.exception.code.should.equal("InvalidStateTransition") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none all_vpc_pcxs = conn.get_all_vpc_peering_connections() all_vpc_pcxs.should.have.length_of(1) - all_vpc_pcxs[0]._status.code.should.equal('active') + all_vpc_pcxs[0]._status.code.should.equal("active") @requires_boto_gte("2.32.0") @mock_ec2_deprecated def test_vpc_peering_connections_reject(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc_pcx = test_vpc_peering_connections() verdict = conn.reject_vpc_peering_connection(vpc_pcx.id) @@ -70,19 +71,19 @@ def test_vpc_peering_connections_reject(): with assert_raises(EC2ResponseError) as cm: conn.accept_vpc_peering_connection(vpc_pcx.id) - cm.exception.code.should.equal('InvalidStateTransition') + cm.exception.code.should.equal("InvalidStateTransition") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none all_vpc_pcxs = conn.get_all_vpc_peering_connections() all_vpc_pcxs.should.have.length_of(1) - all_vpc_pcxs[0]._status.code.should.equal('rejected') + all_vpc_pcxs[0]._status.code.should.equal("rejected") @requires_boto_gte("2.32.1") @mock_ec2_deprecated def test_vpc_peering_connections_delete(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc_pcx = test_vpc_peering_connections() verdict = vpc_pcx.delete() @@ -90,11 +91,11 @@ def test_vpc_peering_connections_delete(): all_vpc_pcxs = conn.get_all_vpc_peering_connections() all_vpc_pcxs.should.have.length_of(1) - all_vpc_pcxs[0]._status.code.should.equal('deleted') + all_vpc_pcxs[0]._status.code.should.equal("deleted") with assert_raises(EC2ResponseError) as cm: conn.delete_vpc_peering_connection("pcx-1234abcd") - cm.exception.code.should.equal('InvalidVpcPeeringConnectionId.NotFound') + cm.exception.code.should.equal("InvalidVpcPeeringConnectionId.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -102,17 +103,15 @@ def test_vpc_peering_connections_delete(): @mock_ec2 def test_vpc_peering_connections_cross_region(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) - vpc_pcx_usw1.status['Code'].should.equal('initiating-request') + vpc_pcx_usw1.status["Code"].should.equal("initiating-request") vpc_pcx_usw1.requester_vpc.id.should.equal(vpc_usw1.id) vpc_pcx_usw1.accepter_vpc.id.should.equal(vpc_apn1.id) # test cross region vpc peering connection exist @@ -125,35 +124,32 @@ def test_vpc_peering_connections_cross_region(): @mock_ec2 def test_vpc_peering_connections_cross_region_fail(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering wrong region with no vpc with assert_raises(ClientError) as cm: ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-2') - cm.exception.response['Error']['Code'].should.equal('InvalidVpcID.NotFound') + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-2" + ) + cm.exception.response["Error"]["Code"].should.equal("InvalidVpcID.NotFound") @mock_ec2 def test_vpc_peering_connections_cross_region_accept(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) # accept peering from ap-northeast-1 - ec2_apn1 = boto3.client('ec2', region_name='ap-northeast-1') - ec2_usw1 = boto3.client('ec2', region_name='us-west-1') + ec2_apn1 = boto3.client("ec2", region_name="ap-northeast-1") + ec2_usw1 = boto3.client("ec2", region_name="us-west-1") acp_pcx_apn1 = ec2_apn1.accept_vpc_peering_connection( VpcPeeringConnectionId=vpc_pcx_usw1.id ) @@ -163,27 +159,25 @@ def test_vpc_peering_connections_cross_region_accept(): des_pcx_usw1 = ec2_usw1.describe_vpc_peering_connections( VpcPeeringConnectionIds=[vpc_pcx_usw1.id] ) - acp_pcx_apn1['VpcPeeringConnection']['Status']['Code'].should.equal('active') - des_pcx_apn1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('active') - des_pcx_usw1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('active') + acp_pcx_apn1["VpcPeeringConnection"]["Status"]["Code"].should.equal("active") + des_pcx_apn1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("active") + des_pcx_usw1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("active") @mock_ec2 def test_vpc_peering_connections_cross_region_reject(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) # reject peering from ap-northeast-1 - ec2_apn1 = boto3.client('ec2', region_name='ap-northeast-1') - ec2_usw1 = boto3.client('ec2', region_name='us-west-1') + ec2_apn1 = boto3.client("ec2", region_name="ap-northeast-1") + ec2_usw1 = boto3.client("ec2", region_name="us-west-1") rej_pcx_apn1 = ec2_apn1.reject_vpc_peering_connection( VpcPeeringConnectionId=vpc_pcx_usw1.id ) @@ -193,27 +187,25 @@ def test_vpc_peering_connections_cross_region_reject(): des_pcx_usw1 = ec2_usw1.describe_vpc_peering_connections( VpcPeeringConnectionIds=[vpc_pcx_usw1.id] ) - rej_pcx_apn1['Return'].should.equal(True) - des_pcx_apn1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('rejected') - des_pcx_usw1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('rejected') + rej_pcx_apn1["Return"].should.equal(True) + des_pcx_apn1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("rejected") + des_pcx_usw1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("rejected") @mock_ec2 def test_vpc_peering_connections_cross_region_delete(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) # reject peering from ap-northeast-1 - ec2_apn1 = boto3.client('ec2', region_name='ap-northeast-1') - ec2_usw1 = boto3.client('ec2', region_name='us-west-1') + ec2_apn1 = boto3.client("ec2", region_name="ap-northeast-1") + ec2_usw1 = boto3.client("ec2", region_name="us-west-1") del_pcx_apn1 = ec2_apn1.delete_vpc_peering_connection( VpcPeeringConnectionId=vpc_pcx_usw1.id ) @@ -223,61 +215,57 @@ def test_vpc_peering_connections_cross_region_delete(): des_pcx_usw1 = ec2_usw1.describe_vpc_peering_connections( VpcPeeringConnectionIds=[vpc_pcx_usw1.id] ) - del_pcx_apn1['Return'].should.equal(True) - des_pcx_apn1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('deleted') - des_pcx_usw1['VpcPeeringConnections'][0]['Status']['Code'].should.equal('deleted') + del_pcx_apn1["Return"].should.equal(True) + des_pcx_apn1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("deleted") + des_pcx_usw1["VpcPeeringConnections"][0]["Status"]["Code"].should.equal("deleted") @mock_ec2 def test_vpc_peering_connections_cross_region_accept_wrong_region(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) # accept wrong peering from us-west-1 which will raise error - ec2_apn1 = boto3.client('ec2', region_name='ap-northeast-1') - ec2_usw1 = boto3.client('ec2', region_name='us-west-1') + ec2_apn1 = boto3.client("ec2", region_name="ap-northeast-1") + ec2_usw1 = boto3.client("ec2", region_name="us-west-1") with assert_raises(ClientError) as cm: - ec2_usw1.accept_vpc_peering_connection( - VpcPeeringConnectionId=vpc_pcx_usw1.id - ) - cm.exception.response['Error']['Code'].should.equal('OperationNotPermitted') - exp_msg = 'Incorrect region ({0}) specified for this request.VPC ' \ - 'peering connection {1} must be ' \ - 'accepted in region {2}'.format('us-west-1', vpc_pcx_usw1.id, 'ap-northeast-1') - cm.exception.response['Error']['Message'].should.equal(exp_msg) + ec2_usw1.accept_vpc_peering_connection(VpcPeeringConnectionId=vpc_pcx_usw1.id) + cm.exception.response["Error"]["Code"].should.equal("OperationNotPermitted") + exp_msg = ( + "Incorrect region ({0}) specified for this request.VPC " + "peering connection {1} must be " + "accepted in region {2}".format("us-west-1", vpc_pcx_usw1.id, "ap-northeast-1") + ) + cm.exception.response["Error"]["Message"].should.equal(exp_msg) @mock_ec2 def test_vpc_peering_connections_cross_region_reject_wrong_region(): # create vpc in us-west-1 and ap-northeast-1 - ec2_usw1 = boto3.resource('ec2', region_name='us-west-1') - vpc_usw1 = ec2_usw1.create_vpc(CidrBlock='10.90.0.0/16') - ec2_apn1 = boto3.resource('ec2', region_name='ap-northeast-1') - vpc_apn1 = ec2_apn1.create_vpc(CidrBlock='10.20.0.0/16') + ec2_usw1 = boto3.resource("ec2", region_name="us-west-1") + vpc_usw1 = ec2_usw1.create_vpc(CidrBlock="10.90.0.0/16") + ec2_apn1 = boto3.resource("ec2", region_name="ap-northeast-1") + vpc_apn1 = ec2_apn1.create_vpc(CidrBlock="10.20.0.0/16") # create peering vpc_pcx_usw1 = ec2_usw1.create_vpc_peering_connection( - VpcId=vpc_usw1.id, - PeerVpcId=vpc_apn1.id, - PeerRegion='ap-northeast-1', + VpcId=vpc_usw1.id, PeerVpcId=vpc_apn1.id, PeerRegion="ap-northeast-1" ) # reject wrong peering from us-west-1 which will raise error - ec2_apn1 = boto3.client('ec2', region_name='ap-northeast-1') - ec2_usw1 = boto3.client('ec2', region_name='us-west-1') + ec2_apn1 = boto3.client("ec2", region_name="ap-northeast-1") + ec2_usw1 = boto3.client("ec2", region_name="us-west-1") with assert_raises(ClientError) as cm: - ec2_usw1.reject_vpc_peering_connection( - VpcPeeringConnectionId=vpc_pcx_usw1.id - ) - cm.exception.response['Error']['Code'].should.equal('OperationNotPermitted') - exp_msg = 'Incorrect region ({0}) specified for this request.VPC ' \ - 'peering connection {1} must be accepted or ' \ - 'rejected in region {2}'.format('us-west-1', vpc_pcx_usw1.id, 'ap-northeast-1') - cm.exception.response['Error']['Message'].should.equal(exp_msg) + ec2_usw1.reject_vpc_peering_connection(VpcPeeringConnectionId=vpc_pcx_usw1.id) + cm.exception.response["Error"]["Code"].should.equal("OperationNotPermitted") + exp_msg = ( + "Incorrect region ({0}) specified for this request.VPC " + "peering connection {1} must be accepted or " + "rejected in region {2}".format("us-west-1", vpc_pcx_usw1.id, "ap-northeast-1") + ) + cm.exception.response["Error"]["Message"].should.equal(exp_msg) diff --git a/tests/test_ec2/test_vpcs.py b/tests/test_ec2/test_vpcs.py index ad17deb3c..0894a8b8e 100644 --- a/tests/test_ec2/test_vpcs.py +++ b/tests/test_ec2/test_vpcs.py @@ -1,6 +1,7 @@ from __future__ import unicode_literals + # Ensure 'assert_raises' context manager support for Python 2.6 -import tests.backport_assert_raises # flake8: noqa +import tests.backport_assert_raises # noqa from nose.tools import assert_raises from moto.ec2.exceptions import EC2ClientError from botocore.exceptions import ClientError @@ -12,15 +13,15 @@ import sure # noqa from moto import mock_ec2, mock_ec2_deprecated -SAMPLE_DOMAIN_NAME = u'example.com' -SAMPLE_NAME_SERVERS = [u'10.0.0.6', u'10.0.0.7'] +SAMPLE_DOMAIN_NAME = "example.com" +SAMPLE_NAME_SERVERS = ["10.0.0.6", "10.0.0.7"] @mock_ec2_deprecated def test_vpcs(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - vpc.cidr_block.should.equal('10.0.0.0/16') + vpc.cidr_block.should.equal("10.0.0.0/16") all_vpcs = conn.get_all_vpcs() all_vpcs.should.have.length_of(2) @@ -32,58 +33,56 @@ def test_vpcs(): with assert_raises(EC2ResponseError) as cm: conn.delete_vpc("vpc-1234abcd") - cm.exception.code.should.equal('InvalidVpcID.NotFound') + cm.exception.code.should.equal("InvalidVpcID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @mock_ec2_deprecated def test_vpc_defaults(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") conn.get_all_vpcs().should.have.length_of(2) conn.get_all_route_tables().should.have.length_of(2) - conn.get_all_security_groups( - filters={'vpc-id': [vpc.id]}).should.have.length_of(1) + conn.get_all_security_groups(filters={"vpc-id": [vpc.id]}).should.have.length_of(1) vpc.delete() conn.get_all_vpcs().should.have.length_of(1) conn.get_all_route_tables().should.have.length_of(1) - conn.get_all_security_groups( - filters={'vpc-id': [vpc.id]}).should.have.length_of(0) + conn.get_all_security_groups(filters={"vpc-id": [vpc.id]}).should.have.length_of(0) @mock_ec2_deprecated def test_vpc_isdefault_filter(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") - conn.get_all_vpcs(filters={'isDefault': 'true'}).should.have.length_of(1) + conn.get_all_vpcs(filters={"isDefault": "true"}).should.have.length_of(1) vpc.delete() - conn.get_all_vpcs(filters={'isDefault': 'true'}).should.have.length_of(1) + conn.get_all_vpcs(filters={"isDefault": "true"}).should.have.length_of(1) @mock_ec2_deprecated def test_multiple_vpcs_default_filter(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") conn.create_vpc("10.8.0.0/16") conn.create_vpc("10.0.0.0/16") conn.create_vpc("192.168.0.0/16") conn.get_all_vpcs().should.have.length_of(4) - vpc = conn.get_all_vpcs(filters={'isDefault': 'true'}) + vpc = conn.get_all_vpcs(filters={"isDefault": "true"}) vpc.should.have.length_of(1) - vpc[0].cidr_block.should.equal('172.31.0.0/16') + vpc[0].cidr_block.should.equal("172.31.0.0/16") @mock_ec2_deprecated def test_vpc_state_available_filter(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpc = conn.create_vpc("10.0.0.0/16") conn.create_vpc("10.1.0.0/16") - conn.get_all_vpcs(filters={'state': 'available'}).should.have.length_of(3) + conn.get_all_vpcs(filters={"state": "available"}).should.have.length_of(3) vpc.delete() - conn.get_all_vpcs(filters={'state': 'available'}).should.have.length_of(2) + conn.get_all_vpcs(filters={"state": "available"}).should.have.length_of(2) @mock_ec2_deprecated @@ -116,8 +115,8 @@ def test_vpc_get_by_id(): vpc2.id.should.be.within(vpc_ids) with assert_raises(EC2ResponseError) as cm: - conn.get_all_vpcs(vpc_ids=['vpc-does_not_exist']) - cm.exception.code.should.equal('InvalidVpcID.NotFound') + conn.get_all_vpcs(vpc_ids=["vpc-does_not_exist"]) + cm.exception.code.should.equal("InvalidVpcID.NotFound") cm.exception.status.should.equal(400) cm.exception.request_id.should_not.be.none @@ -129,7 +128,7 @@ def test_vpc_get_by_cidr_block(): vpc2 = conn.create_vpc("10.0.0.0/16") conn.create_vpc("10.0.0.0/24") - vpcs = conn.get_all_vpcs(filters={'cidr': '10.0.0.0/16'}) + vpcs = conn.get_all_vpcs(filters={"cidr": "10.0.0.0/16"}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -139,8 +138,7 @@ def test_vpc_get_by_cidr_block(): @mock_ec2_deprecated def test_vpc_get_by_dhcp_options_id(): conn = boto.connect_vpc() - dhcp_options = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) vpc1 = conn.create_vpc("10.0.0.0/16") vpc2 = conn.create_vpc("10.0.0.0/16") conn.create_vpc("10.0.0.0/24") @@ -148,7 +146,7 @@ def test_vpc_get_by_dhcp_options_id(): conn.associate_dhcp_options(dhcp_options.id, vpc1.id) conn.associate_dhcp_options(dhcp_options.id, vpc2.id) - vpcs = conn.get_all_vpcs(filters={'dhcp-options-id': dhcp_options.id}) + vpcs = conn.get_all_vpcs(filters={"dhcp-options-id": dhcp_options.id}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -162,11 +160,11 @@ def test_vpc_get_by_tag(): vpc2 = conn.create_vpc("10.0.0.0/16") vpc3 = conn.create_vpc("10.0.0.0/24") - vpc1.add_tag('Name', 'TestVPC') - vpc2.add_tag('Name', 'TestVPC') - vpc3.add_tag('Name', 'TestVPC2') + vpc1.add_tag("Name", "TestVPC") + vpc2.add_tag("Name", "TestVPC") + vpc3.add_tag("Name", "TestVPC2") - vpcs = conn.get_all_vpcs(filters={'tag:Name': 'TestVPC'}) + vpcs = conn.get_all_vpcs(filters={"tag:Name": "TestVPC"}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -180,13 +178,13 @@ def test_vpc_get_by_tag_key_superset(): vpc2 = conn.create_vpc("10.0.0.0/16") vpc3 = conn.create_vpc("10.0.0.0/24") - vpc1.add_tag('Name', 'TestVPC') - vpc1.add_tag('Key', 'TestVPC2') - vpc2.add_tag('Name', 'TestVPC') - vpc2.add_tag('Key', 'TestVPC2') - vpc3.add_tag('Key', 'TestVPC2') + vpc1.add_tag("Name", "TestVPC") + vpc1.add_tag("Key", "TestVPC2") + vpc2.add_tag("Name", "TestVPC") + vpc2.add_tag("Key", "TestVPC2") + vpc3.add_tag("Key", "TestVPC2") - vpcs = conn.get_all_vpcs(filters={'tag-key': 'Name'}) + vpcs = conn.get_all_vpcs(filters={"tag-key": "Name"}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -200,13 +198,13 @@ def test_vpc_get_by_tag_key_subset(): vpc2 = conn.create_vpc("10.0.0.0/16") vpc3 = conn.create_vpc("10.0.0.0/24") - vpc1.add_tag('Name', 'TestVPC') - vpc1.add_tag('Key', 'TestVPC2') - vpc2.add_tag('Name', 'TestVPC') - vpc2.add_tag('Key', 'TestVPC2') - vpc3.add_tag('Test', 'TestVPC2') + vpc1.add_tag("Name", "TestVPC") + vpc1.add_tag("Key", "TestVPC2") + vpc2.add_tag("Name", "TestVPC") + vpc2.add_tag("Key", "TestVPC2") + vpc3.add_tag("Test", "TestVPC2") - vpcs = conn.get_all_vpcs(filters={'tag-key': ['Name', 'Key']}) + vpcs = conn.get_all_vpcs(filters={"tag-key": ["Name", "Key"]}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -220,13 +218,13 @@ def test_vpc_get_by_tag_value_superset(): vpc2 = conn.create_vpc("10.0.0.0/16") vpc3 = conn.create_vpc("10.0.0.0/24") - vpc1.add_tag('Name', 'TestVPC') - vpc1.add_tag('Key', 'TestVPC2') - vpc2.add_tag('Name', 'TestVPC') - vpc2.add_tag('Key', 'TestVPC2') - vpc3.add_tag('Key', 'TestVPC2') + vpc1.add_tag("Name", "TestVPC") + vpc1.add_tag("Key", "TestVPC2") + vpc2.add_tag("Name", "TestVPC") + vpc2.add_tag("Key", "TestVPC2") + vpc3.add_tag("Key", "TestVPC2") - vpcs = conn.get_all_vpcs(filters={'tag-value': 'TestVPC'}) + vpcs = conn.get_all_vpcs(filters={"tag-value": "TestVPC"}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -240,12 +238,12 @@ def test_vpc_get_by_tag_value_subset(): vpc2 = conn.create_vpc("10.0.0.0/16") conn.create_vpc("10.0.0.0/24") - vpc1.add_tag('Name', 'TestVPC') - vpc1.add_tag('Key', 'TestVPC2') - vpc2.add_tag('Name', 'TestVPC') - vpc2.add_tag('Key', 'TestVPC2') + vpc1.add_tag("Name", "TestVPC") + vpc1.add_tag("Key", "TestVPC2") + vpc2.add_tag("Name", "TestVPC") + vpc2.add_tag("Key", "TestVPC2") - vpcs = conn.get_all_vpcs(filters={'tag-value': ['TestVPC', 'TestVPC2']}) + vpcs = conn.get_all_vpcs(filters={"tag-value": ["TestVPC", "TestVPC2"]}) vpcs.should.have.length_of(2) vpc_ids = tuple(map(lambda v: v.id, vpcs)) vpc1.id.should.be.within(vpc_ids) @@ -254,117 +252,116 @@ def test_vpc_get_by_tag_value_subset(): @mock_ec2 def test_default_vpc(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the default VPC default_vpc = list(ec2.vpcs.all())[0] - default_vpc.cidr_block.should.equal('172.31.0.0/16') - default_vpc.instance_tenancy.should.equal('default') + default_vpc.cidr_block.should.equal("172.31.0.0/16") + default_vpc.instance_tenancy.should.equal("default") default_vpc.reload() default_vpc.is_default.should.be.ok # Test default values for VPC attributes - response = default_vpc.describe_attribute(Attribute='enableDnsSupport') - attr = response.get('EnableDnsSupport') - attr.get('Value').should.be.ok + response = default_vpc.describe_attribute(Attribute="enableDnsSupport") + attr = response.get("EnableDnsSupport") + attr.get("Value").should.be.ok - response = default_vpc.describe_attribute(Attribute='enableDnsHostnames') - attr = response.get('EnableDnsHostnames') - attr.get('Value').should.be.ok + response = default_vpc.describe_attribute(Attribute="enableDnsHostnames") + attr = response.get("EnableDnsHostnames") + attr.get("Value").should.be.ok @mock_ec2 def test_non_default_vpc(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the default VPC - this already exists when backend instantiated! - #ec2.create_vpc(CidrBlock='172.31.0.0/16') + # ec2.create_vpc(CidrBlock='172.31.0.0/16') # Create the non default VPC - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") vpc.reload() vpc.is_default.shouldnt.be.ok # Test default instance_tenancy - vpc.instance_tenancy.should.equal('default') + vpc.instance_tenancy.should.equal("default") # Test default values for VPC attributes - response = vpc.describe_attribute(Attribute='enableDnsSupport') - attr = response.get('EnableDnsSupport') - attr.get('Value').should.be.ok + response = vpc.describe_attribute(Attribute="enableDnsSupport") + attr = response.get("EnableDnsSupport") + attr.get("Value").should.be.ok - response = vpc.describe_attribute(Attribute='enableDnsHostnames') - attr = response.get('EnableDnsHostnames') - attr.get('Value').shouldnt.be.ok + response = vpc.describe_attribute(Attribute="enableDnsHostnames") + attr = response.get("EnableDnsHostnames") + attr.get("Value").shouldnt.be.ok # Check Primary CIDR Block Associations cidr_block_association_set = next(iter(vpc.cidr_block_association_set), None) - cidr_block_association_set['CidrBlockState']['State'].should.equal('associated') - cidr_block_association_set['CidrBlock'].should.equal(vpc.cidr_block) - cidr_block_association_set['AssociationId'].should.contain('vpc-cidr-assoc') + cidr_block_association_set["CidrBlockState"]["State"].should.equal("associated") + cidr_block_association_set["CidrBlock"].should.equal(vpc.cidr_block) + cidr_block_association_set["AssociationId"].should.contain("vpc-cidr-assoc") @mock_ec2 def test_vpc_dedicated_tenancy(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the default VPC - ec2.create_vpc(CidrBlock='172.31.0.0/16') + ec2.create_vpc(CidrBlock="172.31.0.0/16") # Create the non default VPC - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16', InstanceTenancy='dedicated') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16", InstanceTenancy="dedicated") vpc.reload() vpc.is_default.shouldnt.be.ok - vpc.instance_tenancy.should.equal('dedicated') + vpc.instance_tenancy.should.equal("dedicated") @mock_ec2 def test_vpc_modify_enable_dns_support(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the default VPC - ec2.create_vpc(CidrBlock='172.31.0.0/16') + ec2.create_vpc(CidrBlock="172.31.0.0/16") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") # Test default values for VPC attributes - response = vpc.describe_attribute(Attribute='enableDnsSupport') - attr = response.get('EnableDnsSupport') - attr.get('Value').should.be.ok + response = vpc.describe_attribute(Attribute="enableDnsSupport") + attr = response.get("EnableDnsSupport") + attr.get("Value").should.be.ok - vpc.modify_attribute(EnableDnsSupport={'Value': False}) + vpc.modify_attribute(EnableDnsSupport={"Value": False}) - response = vpc.describe_attribute(Attribute='enableDnsSupport') - attr = response.get('EnableDnsSupport') - attr.get('Value').shouldnt.be.ok + response = vpc.describe_attribute(Attribute="enableDnsSupport") + attr = response.get("EnableDnsSupport") + attr.get("Value").shouldnt.be.ok @mock_ec2 def test_vpc_modify_enable_dns_hostnames(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Create the default VPC - ec2.create_vpc(CidrBlock='172.31.0.0/16') + ec2.create_vpc(CidrBlock="172.31.0.0/16") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") # Test default values for VPC attributes - response = vpc.describe_attribute(Attribute='enableDnsHostnames') - attr = response.get('EnableDnsHostnames') - attr.get('Value').shouldnt.be.ok + response = vpc.describe_attribute(Attribute="enableDnsHostnames") + attr = response.get("EnableDnsHostnames") + attr.get("Value").shouldnt.be.ok - vpc.modify_attribute(EnableDnsHostnames={'Value': True}) + vpc.modify_attribute(EnableDnsHostnames={"Value": True}) - response = vpc.describe_attribute(Attribute='enableDnsHostnames') - attr = response.get('EnableDnsHostnames') - attr.get('Value').should.be.ok + response = vpc.describe_attribute(Attribute="enableDnsHostnames") + attr = response.get("EnableDnsHostnames") + attr.get("Value").should.be.ok @mock_ec2_deprecated def test_vpc_associate_dhcp_options(): conn = boto.connect_vpc() - dhcp_options = conn.create_dhcp_options( - SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) + dhcp_options = conn.create_dhcp_options(SAMPLE_DOMAIN_NAME, SAMPLE_NAME_SERVERS) vpc = conn.create_vpc("10.0.0.0/16") conn.associate_dhcp_options(dhcp_options.id, vpc.id) @@ -375,117 +372,206 @@ def test_vpc_associate_dhcp_options(): @mock_ec2 def test_associate_vpc_ipv4_cidr_block(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.10.42.0/24') + vpc = ec2.create_vpc(CidrBlock="10.10.42.0/24") # Associate/Extend vpc CIDR range up to 5 ciders for i in range(43, 47): - response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, CidrBlock='10.10.{}.0/24'.format(i)) - response['CidrBlockAssociation']['CidrBlockState']['State'].should.equal('associating') - response['CidrBlockAssociation']['CidrBlock'].should.equal('10.10.{}.0/24'.format(i)) - response['CidrBlockAssociation']['AssociationId'].should.contain('vpc-cidr-assoc') + response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc.id, CidrBlock="10.10.{}.0/24".format(i) + ) + response["CidrBlockAssociation"]["CidrBlockState"]["State"].should.equal( + "associating" + ) + response["CidrBlockAssociation"]["CidrBlock"].should.equal( + "10.10.{}.0/24".format(i) + ) + response["CidrBlockAssociation"]["AssociationId"].should.contain( + "vpc-cidr-assoc" + ) # Check all associations exist vpc = ec2.Vpc(vpc.id) vpc.cidr_block_association_set.should.have.length_of(5) - vpc.cidr_block_association_set[2]['CidrBlockState']['State'].should.equal('associated') - vpc.cidr_block_association_set[4]['CidrBlockState']['State'].should.equal('associated') + vpc.cidr_block_association_set[2]["CidrBlockState"]["State"].should.equal( + "associated" + ) + vpc.cidr_block_association_set[4]["CidrBlockState"]["State"].should.equal( + "associated" + ) # Check error on adding 6th association. with assert_raises(ClientError) as ex: - response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, CidrBlock='10.10.50.0/22') + response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc.id, CidrBlock="10.10.50.0/22" + ) str(ex.exception).should.equal( "An error occurred (CidrLimitExceeded) when calling the AssociateVpcCidrBlock " - "operation: This network '{}' has met its maximum number of allowed CIDRs: 5".format(vpc.id)) + "operation: This network '{}' has met its maximum number of allowed CIDRs: 5".format( + vpc.id + ) + ) + @mock_ec2 def test_disassociate_vpc_ipv4_cidr_block(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc = ec2.create_vpc(CidrBlock='10.10.42.0/24') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, CidrBlock='10.10.43.0/24') + vpc = ec2.create_vpc(CidrBlock="10.10.42.0/24") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, CidrBlock="10.10.43.0/24") # Remove an extended cidr block vpc = ec2.Vpc(vpc.id) - non_default_assoc_cidr_block = next(iter([x for x in vpc.cidr_block_association_set if vpc.cidr_block != x['CidrBlock']]), None) - response = ec2.meta.client.disassociate_vpc_cidr_block(AssociationId=non_default_assoc_cidr_block['AssociationId']) - response['CidrBlockAssociation']['CidrBlockState']['State'].should.equal('disassociating') - response['CidrBlockAssociation']['CidrBlock'].should.equal(non_default_assoc_cidr_block['CidrBlock']) - response['CidrBlockAssociation']['AssociationId'].should.equal(non_default_assoc_cidr_block['AssociationId']) + non_default_assoc_cidr_block = next( + iter( + [ + x + for x in vpc.cidr_block_association_set + if vpc.cidr_block != x["CidrBlock"] + ] + ), + None, + ) + response = ec2.meta.client.disassociate_vpc_cidr_block( + AssociationId=non_default_assoc_cidr_block["AssociationId"] + ) + response["CidrBlockAssociation"]["CidrBlockState"]["State"].should.equal( + "disassociating" + ) + response["CidrBlockAssociation"]["CidrBlock"].should.equal( + non_default_assoc_cidr_block["CidrBlock"] + ) + response["CidrBlockAssociation"]["AssociationId"].should.equal( + non_default_assoc_cidr_block["AssociationId"] + ) # Error attempting to delete a non-existent CIDR_BLOCK association with assert_raises(ClientError) as ex: - response = ec2.meta.client.disassociate_vpc_cidr_block(AssociationId='vpc-cidr-assoc-BORING123') + response = ec2.meta.client.disassociate_vpc_cidr_block( + AssociationId="vpc-cidr-assoc-BORING123" + ) str(ex.exception).should.equal( "An error occurred (InvalidVpcCidrBlockAssociationIdError.NotFound) when calling the " "DisassociateVpcCidrBlock operation: The vpc CIDR block association ID " - "'vpc-cidr-assoc-BORING123' does not exist") + "'vpc-cidr-assoc-BORING123' does not exist" + ) # Error attempting to delete Primary CIDR BLOCK association - vpc_base_cidr_assoc_id = next(iter([x for x in vpc.cidr_block_association_set - if vpc.cidr_block == x['CidrBlock']]), {})['AssociationId'] + vpc_base_cidr_assoc_id = next( + iter( + [ + x + for x in vpc.cidr_block_association_set + if vpc.cidr_block == x["CidrBlock"] + ] + ), + {}, + )["AssociationId"] with assert_raises(ClientError) as ex: - response = ec2.meta.client.disassociate_vpc_cidr_block(AssociationId=vpc_base_cidr_assoc_id) + response = ec2.meta.client.disassociate_vpc_cidr_block( + AssociationId=vpc_base_cidr_assoc_id + ) str(ex.exception).should.equal( "An error occurred (OperationNotPermitted) when calling the DisassociateVpcCidrBlock operation: " "The vpc CIDR block with association ID {} may not be disassociated. It is the primary " - "IPv4 CIDR block of the VPC".format(vpc_base_cidr_assoc_id)) + "IPv4 CIDR block of the VPC".format(vpc_base_cidr_assoc_id) + ) + @mock_ec2 def test_cidr_block_association_filters(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - vpc1 = ec2.create_vpc(CidrBlock='10.90.0.0/16') - vpc2 = ec2.create_vpc(CidrBlock='10.91.0.0/16') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc2.id, CidrBlock='10.10.0.0/19') - vpc3 = ec2.create_vpc(CidrBlock='10.92.0.0/24') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock='10.92.1.0/24') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock='10.92.2.0/24') - vpc3_assoc_response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock='10.92.3.0/24') + ec2 = boto3.resource("ec2", region_name="us-west-1") + vpc1 = ec2.create_vpc(CidrBlock="10.90.0.0/16") + vpc2 = ec2.create_vpc(CidrBlock="10.91.0.0/16") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc2.id, CidrBlock="10.10.0.0/19") + vpc3 = ec2.create_vpc(CidrBlock="10.92.0.0/24") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock="10.92.1.0/24") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock="10.92.2.0/24") + vpc3_assoc_response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc3.id, CidrBlock="10.92.3.0/24" + ) # Test filters for a cidr-block in all VPCs cidr-block-associations - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'cidr-block-association.cidr-block', - 'Values': ['10.10.0.0/19']}])) + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + { + "Name": "cidr-block-association.cidr-block", + "Values": ["10.10.0.0/19"], + } + ] + ) + ) filtered_vpcs.should.be.length_of(1) filtered_vpcs[0].id.should.equal(vpc2.id) # Test filter for association id in VPCs - association_id = vpc3_assoc_response['CidrBlockAssociation']['AssociationId'] - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'cidr-block-association.association-id', - 'Values': [association_id]}])) + association_id = vpc3_assoc_response["CidrBlockAssociation"]["AssociationId"] + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + { + "Name": "cidr-block-association.association-id", + "Values": [association_id], + } + ] + ) + ) filtered_vpcs.should.be.length_of(1) filtered_vpcs[0].id.should.equal(vpc3.id) # Test filter for association state in VPC - this will never show anything in this test - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'cidr-block-association.association-id', - 'Values': ['failing']}])) + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + {"Name": "cidr-block-association.association-id", "Values": ["failing"]} + ] + ) + ) filtered_vpcs.should.be.length_of(0) + @mock_ec2 def test_vpc_associate_ipv6_cidr_block(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Test create VPC with IPV6 cidr range - vpc = ec2.create_vpc(CidrBlock='10.10.42.0/24', AmazonProvidedIpv6CidrBlock=True) - ipv6_cidr_block_association_set = next(iter(vpc.ipv6_cidr_block_association_set), None) - ipv6_cidr_block_association_set['Ipv6CidrBlockState']['State'].should.equal('associated') - ipv6_cidr_block_association_set['Ipv6CidrBlock'].should.contain('::/56') - ipv6_cidr_block_association_set['AssociationId'].should.contain('vpc-cidr-assoc') + vpc = ec2.create_vpc(CidrBlock="10.10.42.0/24", AmazonProvidedIpv6CidrBlock=True) + ipv6_cidr_block_association_set = next( + iter(vpc.ipv6_cidr_block_association_set), None + ) + ipv6_cidr_block_association_set["Ipv6CidrBlockState"]["State"].should.equal( + "associated" + ) + ipv6_cidr_block_association_set["Ipv6CidrBlock"].should.contain("::/56") + ipv6_cidr_block_association_set["AssociationId"].should.contain("vpc-cidr-assoc") # Test Fail on adding 2nd IPV6 association - AWS only allows 1 at this time! with assert_raises(ClientError) as ex: - response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, AmazonProvidedIpv6CidrBlock=True) + response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc.id, AmazonProvidedIpv6CidrBlock=True + ) str(ex.exception).should.equal( "An error occurred (CidrLimitExceeded) when calling the AssociateVpcCidrBlock " - "operation: This network '{}' has met its maximum number of allowed CIDRs: 1".format(vpc.id)) + "operation: This network '{}' has met its maximum number of allowed CIDRs: 1".format( + vpc.id + ) + ) # Test associate ipv6 cidr block after vpc created - vpc = ec2.create_vpc(CidrBlock='10.10.50.0/24') - response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc.id, AmazonProvidedIpv6CidrBlock=True) - response['Ipv6CidrBlockAssociation']['Ipv6CidrBlockState']['State'].should.equal('associating') - response['Ipv6CidrBlockAssociation']['Ipv6CidrBlock'].should.contain('::/56') - response['Ipv6CidrBlockAssociation']['AssociationId'].should.contain('vpc-cidr-assoc-') + vpc = ec2.create_vpc(CidrBlock="10.10.50.0/24") + response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc.id, AmazonProvidedIpv6CidrBlock=True + ) + response["Ipv6CidrBlockAssociation"]["Ipv6CidrBlockState"]["State"].should.equal( + "associating" + ) + response["Ipv6CidrBlockAssociation"]["Ipv6CidrBlock"].should.contain("::/56") + response["Ipv6CidrBlockAssociation"]["AssociationId"].should.contain( + "vpc-cidr-assoc-" + ) # Check on describe vpc that has ipv6 cidr block association vpc = ec2.Vpc(vpc.id) @@ -494,72 +580,101 @@ def test_vpc_associate_ipv6_cidr_block(): @mock_ec2 def test_vpc_disassociate_ipv6_cidr_block(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") # Test create VPC with IPV6 cidr range - vpc = ec2.create_vpc(CidrBlock='10.10.42.0/24', AmazonProvidedIpv6CidrBlock=True) + vpc = ec2.create_vpc(CidrBlock="10.10.42.0/24", AmazonProvidedIpv6CidrBlock=True) # Test disassociating the only IPV6 - assoc_id = vpc.ipv6_cidr_block_association_set[0]['AssociationId'] + assoc_id = vpc.ipv6_cidr_block_association_set[0]["AssociationId"] response = ec2.meta.client.disassociate_vpc_cidr_block(AssociationId=assoc_id) - response['Ipv6CidrBlockAssociation']['Ipv6CidrBlockState']['State'].should.equal('disassociating') - response['Ipv6CidrBlockAssociation']['Ipv6CidrBlock'].should.contain('::/56') - response['Ipv6CidrBlockAssociation']['AssociationId'].should.equal(assoc_id) + response["Ipv6CidrBlockAssociation"]["Ipv6CidrBlockState"]["State"].should.equal( + "disassociating" + ) + response["Ipv6CidrBlockAssociation"]["Ipv6CidrBlock"].should.contain("::/56") + response["Ipv6CidrBlockAssociation"]["AssociationId"].should.equal(assoc_id) @mock_ec2 def test_ipv6_cidr_block_association_filters(): - ec2 = boto3.resource('ec2', region_name='us-west-1') - vpc1 = ec2.create_vpc(CidrBlock='10.90.0.0/16') + ec2 = boto3.resource("ec2", region_name="us-west-1") + vpc1 = ec2.create_vpc(CidrBlock="10.90.0.0/16") - vpc2 = ec2.create_vpc(CidrBlock='10.91.0.0/16', AmazonProvidedIpv6CidrBlock=True) - vpc2_assoc_ipv6_assoc_id = vpc2.ipv6_cidr_block_association_set[0]['AssociationId'] - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc2.id, CidrBlock='10.10.0.0/19') + vpc2 = ec2.create_vpc(CidrBlock="10.91.0.0/16", AmazonProvidedIpv6CidrBlock=True) + vpc2_assoc_ipv6_assoc_id = vpc2.ipv6_cidr_block_association_set[0]["AssociationId"] + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc2.id, CidrBlock="10.10.0.0/19") - vpc3 = ec2.create_vpc(CidrBlock='10.92.0.0/24') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock='10.92.1.0/24') - ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock='10.92.2.0/24') - response = ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, AmazonProvidedIpv6CidrBlock=True) - vpc3_ipv6_cidr_block = response['Ipv6CidrBlockAssociation']['Ipv6CidrBlock'] + vpc3 = ec2.create_vpc(CidrBlock="10.92.0.0/24") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock="10.92.1.0/24") + ec2.meta.client.associate_vpc_cidr_block(VpcId=vpc3.id, CidrBlock="10.92.2.0/24") + response = ec2.meta.client.associate_vpc_cidr_block( + VpcId=vpc3.id, AmazonProvidedIpv6CidrBlock=True + ) + vpc3_ipv6_cidr_block = response["Ipv6CidrBlockAssociation"]["Ipv6CidrBlock"] - vpc4 = ec2.create_vpc(CidrBlock='10.95.0.0/16') # Here for its looks + vpc4 = ec2.create_vpc(CidrBlock="10.95.0.0/16") # Here for its looks # Test filters for an ipv6 cidr-block in all VPCs cidr-block-associations - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'ipv6-cidr-block-association.ipv6-cidr-block', - 'Values': [vpc3_ipv6_cidr_block]}])) + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + { + "Name": "ipv6-cidr-block-association.ipv6-cidr-block", + "Values": [vpc3_ipv6_cidr_block], + } + ] + ) + ) filtered_vpcs.should.be.length_of(1) filtered_vpcs[0].id.should.equal(vpc3.id) # Test filter for association id in VPCs - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'ipv6-cidr-block-association.association-id', - 'Values': [vpc2_assoc_ipv6_assoc_id]}])) + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + { + "Name": "ipv6-cidr-block-association.association-id", + "Values": [vpc2_assoc_ipv6_assoc_id], + } + ] + ) + ) filtered_vpcs.should.be.length_of(1) filtered_vpcs[0].id.should.equal(vpc2.id) # Test filter for association state in VPC - this will never show anything in this test - filtered_vpcs = list(ec2.vpcs.filter(Filters=[{'Name': 'ipv6-cidr-block-association.state', - 'Values': ['associated']}])) - filtered_vpcs.should.be.length_of(2) # 2 of 4 VPCs + filtered_vpcs = list( + ec2.vpcs.filter( + Filters=[ + {"Name": "ipv6-cidr-block-association.state", "Values": ["associated"]} + ] + ) + ) + filtered_vpcs.should.be.length_of(2) # 2 of 4 VPCs @mock_ec2 def test_create_vpc_with_invalid_cidr_block_parameter(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc_cidr_block = '1000.1.0.0/20' + vpc_cidr_block = "1000.1.0.0/20" with assert_raises(ClientError) as ex: vpc = ec2.create_vpc(CidrBlock=vpc_cidr_block) str(ex.exception).should.equal( "An error occurred (InvalidParameterValue) when calling the CreateVpc " - "operation: Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format(vpc_cidr_block)) + "operation: Value ({}) for parameter cidrBlock is invalid. This is not a valid CIDR block.".format( + vpc_cidr_block + ) + ) @mock_ec2 def test_create_vpc_with_invalid_cidr_range(): - ec2 = boto3.resource('ec2', region_name='us-west-1') + ec2 = boto3.resource("ec2", region_name="us-west-1") - vpc_cidr_block = '10.1.0.0/29' + vpc_cidr_block = "10.1.0.0/29" with assert_raises(ClientError) as ex: vpc = ec2.create_vpc(CidrBlock=vpc_cidr_block) str(ex.exception).should.equal( "An error occurred (InvalidVpc.Range) when calling the CreateVpc " - "operation: The CIDR '{}' is invalid.".format(vpc_cidr_block)) + "operation: The CIDR '{}' is invalid.".format(vpc_cidr_block) + ) diff --git a/tests/test_ec2/test_vpn_connections.py b/tests/test_ec2/test_vpn_connections.py index e95aa76ee..24396d3d1 100644 --- a/tests/test_ec2/test_vpn_connections.py +++ b/tests/test_ec2/test_vpn_connections.py @@ -9,19 +9,21 @@ from moto import mock_ec2_deprecated @mock_ec2_deprecated def test_create_vpn_connections(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpn_connection = conn.create_vpn_connection( - 'ipsec.1', 'vgw-0123abcd', 'cgw-0123abcd') + "ipsec.1", "vgw-0123abcd", "cgw-0123abcd" + ) vpn_connection.should_not.be.none - vpn_connection.id.should.match(r'vpn-\w+') - vpn_connection.type.should.equal('ipsec.1') + vpn_connection.id.should.match(r"vpn-\w+") + vpn_connection.type.should.equal("ipsec.1") @mock_ec2_deprecated def test_delete_vpn_connections(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") vpn_connection = conn.create_vpn_connection( - 'ipsec.1', 'vgw-0123abcd', 'cgw-0123abcd') + "ipsec.1", "vgw-0123abcd", "cgw-0123abcd" + ) list_of_vpn_connections = conn.get_all_vpn_connections() list_of_vpn_connections.should.have.length_of(1) conn.delete_vpn_connection(vpn_connection.id) @@ -31,20 +33,20 @@ def test_delete_vpn_connections(): @mock_ec2_deprecated def test_delete_vpn_connections_bad_id(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") with assert_raises(EC2ResponseError): - conn.delete_vpn_connection('vpn-0123abcd') + conn.delete_vpn_connection("vpn-0123abcd") @mock_ec2_deprecated def test_describe_vpn_connections(): - conn = boto.connect_vpc('the_key', 'the_secret') + conn = boto.connect_vpc("the_key", "the_secret") list_of_vpn_connections = conn.get_all_vpn_connections() list_of_vpn_connections.should.have.length_of(0) - conn.create_vpn_connection('ipsec.1', 'vgw-0123abcd', 'cgw-0123abcd') + conn.create_vpn_connection("ipsec.1", "vgw-0123abcd", "cgw-0123abcd") list_of_vpn_connections = conn.get_all_vpn_connections() list_of_vpn_connections.should.have.length_of(1) - vpn = conn.create_vpn_connection('ipsec.1', 'vgw-1234abcd', 'cgw-1234abcd') + vpn = conn.create_vpn_connection("ipsec.1", "vgw-1234abcd", "cgw-1234abcd") list_of_vpn_connections = conn.get_all_vpn_connections() list_of_vpn_connections.should.have.length_of(2) list_of_vpn_connections = conn.get_all_vpn_connections(vpn.id) diff --git a/tests/test_ecr/test_ecr_boto3.py b/tests/test_ecr/test_ecr_boto3.py index ec0e4e732..9115e3fad 100644 --- a/tests/test_ecr/test_ecr_boto3.py +++ b/tests/test_ecr/test_ecr_boto3.py @@ -20,1062 +20,1035 @@ from nose import SkipTest def _create_image_digest(contents=None): if not contents: - contents = 'docker_image{0}'.format(int(random() * 10 ** 6)) - return "sha256:%s" % hashlib.sha256(contents.encode('utf-8')).hexdigest() + contents = "docker_image{0}".format(int(random() * 10 ** 6)) + return "sha256:%s" % hashlib.sha256(contents.encode("utf-8")).hexdigest() def _create_image_manifest(): return { "schemaVersion": 2, "mediaType": "application/vnd.docker.distribution.manifest.v2+json", - "config": - { - "mediaType": "application/vnd.docker.container.image.v1+json", - "size": 7023, - "digest": _create_image_digest("config") - }, + "config": { + "mediaType": "application/vnd.docker.container.image.v1+json", + "size": 7023, + "digest": _create_image_digest("config"), + }, "layers": [ { "mediaType": "application/vnd.docker.image.rootfs.diff.tar.gzip", "size": 32654, - "digest": _create_image_digest("layer1") + "digest": _create_image_digest("layer1"), }, { "mediaType": "application/vnd.docker.image.rootfs.diff.tar.gzip", "size": 16724, - "digest": _create_image_digest("layer2") + "digest": _create_image_digest("layer2"), }, { "mediaType": "application/vnd.docker.image.rootfs.diff.tar.gzip", "size": 73109, # randomize image digest - "digest": _create_image_digest() - } - ] + "digest": _create_image_digest(), + }, + ], } @mock_ecr def test_create_repository(): - client = boto3.client('ecr', region_name='us-east-1') - response = client.create_repository( - repositoryName='test_ecr_repository' + client = boto3.client("ecr", region_name="us-east-1") + response = client.create_repository(repositoryName="test_ecr_repository") + response["repository"]["repositoryName"].should.equal("test_ecr_repository") + response["repository"]["repositoryArn"].should.equal( + "arn:aws:ecr:us-east-1:012345678910:repository/test_ecr_repository" + ) + response["repository"]["registryId"].should.equal("012345678910") + response["repository"]["repositoryUri"].should.equal( + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_ecr_repository" ) - response['repository']['repositoryName'].should.equal('test_ecr_repository') - response['repository']['repositoryArn'].should.equal( - 'arn:aws:ecr:us-east-1:012345678910:repository/test_ecr_repository') - response['repository']['registryId'].should.equal('012345678910') - response['repository']['repositoryUri'].should.equal( - '012345678910.dkr.ecr.us-east-1.amazonaws.com/test_ecr_repository') # response['repository']['createdAt'].should.equal(0) @mock_ecr def test_describe_repositories(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository1' - ) - _ = client.create_repository( - repositoryName='test_repository0' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository1") + _ = client.create_repository(repositoryName="test_repository0") response = client.describe_repositories() - len(response['repositories']).should.equal(2) + len(response["repositories"]).should.equal(2) - respository_arns = ['arn:aws:ecr:us-east-1:012345678910:repository/test_repository1', - 'arn:aws:ecr:us-east-1:012345678910:repository/test_repository0'] - set([response['repositories'][0]['repositoryArn'], - response['repositories'][1]['repositoryArn']]).should.equal(set(respository_arns)) + respository_arns = [ + "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1", + "arn:aws:ecr:us-east-1:012345678910:repository/test_repository0", + ] + set( + [ + response["repositories"][0]["repositoryArn"], + response["repositories"][1]["repositoryArn"], + ] + ).should.equal(set(respository_arns)) - respository_uris = ['012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1', - '012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0'] - set([response['repositories'][0]['repositoryUri'], - response['repositories'][1]['repositoryUri']]).should.equal(set(respository_uris)) + respository_uris = [ + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1", + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0", + ] + set( + [ + response["repositories"][0]["repositoryUri"], + response["repositories"][1]["repositoryUri"], + ] + ).should.equal(set(respository_uris)) @mock_ecr def test_describe_repositories_1(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository1' - ) - _ = client.create_repository( - repositoryName='test_repository0' - ) - response = client.describe_repositories(registryId='012345678910') - len(response['repositories']).should.equal(2) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository1") + _ = client.create_repository(repositoryName="test_repository0") + response = client.describe_repositories(registryId="012345678910") + len(response["repositories"]).should.equal(2) - respository_arns = ['arn:aws:ecr:us-east-1:012345678910:repository/test_repository1', - 'arn:aws:ecr:us-east-1:012345678910:repository/test_repository0'] - set([response['repositories'][0]['repositoryArn'], - response['repositories'][1]['repositoryArn']]).should.equal(set(respository_arns)) + respository_arns = [ + "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1", + "arn:aws:ecr:us-east-1:012345678910:repository/test_repository0", + ] + set( + [ + response["repositories"][0]["repositoryArn"], + response["repositories"][1]["repositoryArn"], + ] + ).should.equal(set(respository_arns)) - respository_uris = ['012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1', - '012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0'] - set([response['repositories'][0]['repositoryUri'], - response['repositories'][1]['repositoryUri']]).should.equal(set(respository_uris)) + respository_uris = [ + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1", + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository0", + ] + set( + [ + response["repositories"][0]["repositoryUri"], + response["repositories"][1]["repositoryUri"], + ] + ).should.equal(set(respository_uris)) @mock_ecr def test_describe_repositories_2(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository1' - ) - _ = client.create_repository( - repositoryName='test_repository0' - ) - response = client.describe_repositories(registryId='109876543210') - len(response['repositories']).should.equal(0) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository1") + _ = client.create_repository(repositoryName="test_repository0") + response = client.describe_repositories(registryId="109876543210") + len(response["repositories"]).should.equal(0) @mock_ecr def test_describe_repositories_3(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository1' - ) - _ = client.create_repository( - repositoryName='test_repository0' - ) - response = client.describe_repositories(repositoryNames=['test_repository1']) - len(response['repositories']).should.equal(1) - respository_arn = 'arn:aws:ecr:us-east-1:012345678910:repository/test_repository1' - response['repositories'][0]['repositoryArn'].should.equal(respository_arn) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository1") + _ = client.create_repository(repositoryName="test_repository0") + response = client.describe_repositories(repositoryNames=["test_repository1"]) + len(response["repositories"]).should.equal(1) + respository_arn = "arn:aws:ecr:us-east-1:012345678910:repository/test_repository1" + response["repositories"][0]["repositoryArn"].should.equal(respository_arn) - respository_uri = '012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1' - response['repositories'][0]['repositoryUri'].should.equal(respository_uri) + respository_uri = "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository1" + response["repositories"][0]["repositoryUri"].should.equal(respository_uri) @mock_ecr def test_describe_repositories_with_image(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="latest", ) - response = client.describe_repositories(repositoryNames=['test_repository']) - len(response['repositories']).should.equal(1) + response = client.describe_repositories(repositoryNames=["test_repository"]) + len(response["repositories"]).should.equal(1) @mock_ecr def test_delete_repository(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") + response = client.delete_repository(repositoryName="test_repository") + response["repository"]["repositoryName"].should.equal("test_repository") + response["repository"]["repositoryArn"].should.equal( + "arn:aws:ecr:us-east-1:012345678910:repository/test_repository" + ) + response["repository"]["registryId"].should.equal("012345678910") + response["repository"]["repositoryUri"].should.equal( + "012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository" ) - response = client.delete_repository(repositoryName='test_repository') - response['repository']['repositoryName'].should.equal('test_repository') - response['repository']['repositoryArn'].should.equal( - 'arn:aws:ecr:us-east-1:012345678910:repository/test_repository') - response['repository']['registryId'].should.equal('012345678910') - response['repository']['repositoryUri'].should.equal( - '012345678910.dkr.ecr.us-east-1.amazonaws.com/test_repository') # response['repository']['createdAt'].should.equal(0) response = client.describe_repositories() - len(response['repositories']).should.equal(0) + len(response["repositories"]).should.equal(0) @mock_ecr def test_put_image(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") response = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="latest", ) - response['image']['imageId']['imageTag'].should.equal('latest') - response['image']['imageId']['imageDigest'].should.contain("sha") - response['image']['repositoryName'].should.equal('test_repository') - response['image']['registryId'].should.equal('012345678910') + response["image"]["imageId"]["imageTag"].should.equal("latest") + response["image"]["imageId"]["imageDigest"].should.contain("sha") + response["image"]["repositoryName"].should.equal("test_repository") + response["image"]["registryId"].should.equal("012345678910") @mock_ecr def test_put_image_with_push_date(): - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': - raise SkipTest('Cant manipulate time in server mode') + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Cant manipulate time in server mode") - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") - with freeze_time('2018-08-28 00:00:00'): + with freeze_time("2018-08-28 00:00:00"): image1_date = datetime.now() _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="latest", ) - with freeze_time('2019-05-31 00:00:00'): + with freeze_time("2019-05-31 00:00:00"): image2_date = datetime.now() _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="latest", ) - describe_response = client.describe_images(repositoryName='test_repository') + describe_response = client.describe_images(repositoryName="test_repository") - type(describe_response['imageDetails']).should.be(list) - len(describe_response['imageDetails']).should.be(2) + type(describe_response["imageDetails"]).should.be(list) + len(describe_response["imageDetails"]).should.be(2) - set([describe_response['imageDetails'][0]['imagePushedAt'], - describe_response['imageDetails'][1]['imagePushedAt']]).should.equal(set([image1_date, image2_date])) + set( + [ + describe_response["imageDetails"][0]["imagePushedAt"], + describe_response["imageDetails"][1]["imagePushedAt"], + ] + ).should.equal(set([image1_date, image2_date])) @mock_ecr def test_put_image_with_multiple_tags(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() response = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag='v1' + imageTag="v1", ) - response['image']['imageId']['imageTag'].should.equal('v1') - response['image']['imageId']['imageDigest'].should.contain("sha") - response['image']['repositoryName'].should.equal('test_repository') - response['image']['registryId'].should.equal('012345678910') + response["image"]["imageId"]["imageTag"].should.equal("v1") + response["image"]["imageId"]["imageDigest"].should.contain("sha") + response["image"]["repositoryName"].should.equal("test_repository") + response["image"]["registryId"].should.equal("012345678910") response1 = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag='latest' + imageTag="latest", ) - response1['image']['imageId']['imageTag'].should.equal('latest') - response1['image']['imageId']['imageDigest'].should.contain("sha") - response1['image']['repositoryName'].should.equal('test_repository') - response1['image']['registryId'].should.equal('012345678910') + response1["image"]["imageId"]["imageTag"].should.equal("latest") + response1["image"]["imageId"]["imageDigest"].should.contain("sha") + response1["image"]["repositoryName"].should.equal("test_repository") + response1["image"]["registryId"].should.equal("012345678910") - response2 = client.describe_images(repositoryName='test_repository') - type(response2['imageDetails']).should.be(list) - len(response2['imageDetails']).should.be(1) + response2 = client.describe_images(repositoryName="test_repository") + type(response2["imageDetails"]).should.be(list) + len(response2["imageDetails"]).should.be(1) - response2['imageDetails'][0]['imageDigest'].should.contain("sha") + response2["imageDetails"][0]["imageDigest"].should.contain("sha") - response2['imageDetails'][0]['registryId'].should.equal("012345678910") + response2["imageDetails"][0]["registryId"].should.equal("012345678910") - response2['imageDetails'][0]['repositoryName'].should.equal("test_repository") + response2["imageDetails"][0]["repositoryName"].should.equal("test_repository") - len(response2['imageDetails'][0]['imageTags']).should.be(2) - response2['imageDetails'][0]['imageTags'].should.be.equal(['v1', 'latest']) + len(response2["imageDetails"][0]["imageTags"]).should.be(2) + response2["imageDetails"][0]["imageTags"].should.be.equal(["v1", "latest"]) @mock_ecr def test_list_images(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository_1' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository_1") - _ = client.create_repository( - repositoryName='test_repository_2' + _ = client.create_repository(repositoryName="test_repository_2") + + _ = client.put_image( + repositoryName="test_repository_1", + imageManifest=json.dumps(_create_image_manifest()), + imageTag="latest", ) _ = client.put_image( - repositoryName='test_repository_1', + repositoryName="test_repository_1", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="v1", ) _ = client.put_image( - repositoryName='test_repository_1', + repositoryName="test_repository_1", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v1' + imageTag="v2", ) _ = client.put_image( - repositoryName='test_repository_1', + repositoryName="test_repository_2", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v2' + imageTag="oldest", ) - _ = client.put_image( - repositoryName='test_repository_2', - imageManifest=json.dumps(_create_image_manifest()), - imageTag='oldest' - ) + response = client.list_images(repositoryName="test_repository_1") + type(response["imageIds"]).should.be(list) + len(response["imageIds"]).should.be(3) - response = client.list_images(repositoryName='test_repository_1') - type(response['imageIds']).should.be(list) - len(response['imageIds']).should.be(3) + image_tags = ["latest", "v1", "v2"] + set( + [ + response["imageIds"][0]["imageTag"], + response["imageIds"][1]["imageTag"], + response["imageIds"][2]["imageTag"], + ] + ).should.equal(set(image_tags)) - image_tags = ['latest', 'v1', 'v2'] - set([response['imageIds'][0]['imageTag'], - response['imageIds'][1]['imageTag'], - response['imageIds'][2]['imageTag']]).should.equal(set(image_tags)) - - response = client.list_images(repositoryName='test_repository_2') - type(response['imageIds']).should.be(list) - len(response['imageIds']).should.be(1) - response['imageIds'][0]['imageTag'].should.equal('oldest') + response = client.list_images(repositoryName="test_repository_2") + type(response["imageIds"]).should.be(list) + len(response["imageIds"]).should.be(1) + response["imageIds"][0]["imageTag"].should.equal("oldest") @mock_ecr def test_list_images_from_repository_that_doesnt_exist(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository_1' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository_1") # non existing repo error_msg = re.compile( r".*The repository with name 'repo-that-doesnt-exist' does not exist in the registry with id '123'.*", - re.MULTILINE) + re.MULTILINE, + ) client.list_images.when.called_with( - repositoryName='repo-that-doesnt-exist', - registryId='123', + repositoryName="repo-that-doesnt-exist", registryId="123" ).should.throw(Exception, error_msg) # repo does not exist in specified registry error_msg = re.compile( r".*The repository with name 'test_repository_1' does not exist in the registry with id '222'.*", - re.MULTILINE) + re.MULTILINE, + ) client.list_images.when.called_with( - repositoryName='test_repository_1', - registryId='222', + repositoryName="test_repository_1", registryId="222" ).should.throw(Exception, error_msg) @mock_ecr def test_describe_images(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") _ = client.put_image( - repositoryName='test_repository', - imageManifest=json.dumps(_create_image_manifest()) - ) - - _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v1' + imageTag="latest", ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v2' + imageTag="v1", ) - response = client.describe_images(repositoryName='test_repository') - type(response['imageDetails']).should.be(list) - len(response['imageDetails']).should.be(4) + _ = client.put_image( + repositoryName="test_repository", + imageManifest=json.dumps(_create_image_manifest()), + imageTag="v2", + ) - response['imageDetails'][0]['imageDigest'].should.contain("sha") - response['imageDetails'][1]['imageDigest'].should.contain("sha") - response['imageDetails'][2]['imageDigest'].should.contain("sha") - response['imageDetails'][3]['imageDigest'].should.contain("sha") + response = client.describe_images(repositoryName="test_repository") + type(response["imageDetails"]).should.be(list) + len(response["imageDetails"]).should.be(4) - response['imageDetails'][0]['registryId'].should.equal("012345678910") - response['imageDetails'][1]['registryId'].should.equal("012345678910") - response['imageDetails'][2]['registryId'].should.equal("012345678910") - response['imageDetails'][3]['registryId'].should.equal("012345678910") + response["imageDetails"][0]["imageDigest"].should.contain("sha") + response["imageDetails"][1]["imageDigest"].should.contain("sha") + response["imageDetails"][2]["imageDigest"].should.contain("sha") + response["imageDetails"][3]["imageDigest"].should.contain("sha") - response['imageDetails'][0]['repositoryName'].should.equal("test_repository") - response['imageDetails'][1]['repositoryName'].should.equal("test_repository") - response['imageDetails'][2]['repositoryName'].should.equal("test_repository") - response['imageDetails'][3]['repositoryName'].should.equal("test_repository") + response["imageDetails"][0]["registryId"].should.equal("012345678910") + response["imageDetails"][1]["registryId"].should.equal("012345678910") + response["imageDetails"][2]["registryId"].should.equal("012345678910") + response["imageDetails"][3]["registryId"].should.equal("012345678910") - response['imageDetails'][0].should_not.have.key('imageTags') - len(response['imageDetails'][1]['imageTags']).should.be(1) - len(response['imageDetails'][2]['imageTags']).should.be(1) - len(response['imageDetails'][3]['imageTags']).should.be(1) + response["imageDetails"][0]["repositoryName"].should.equal("test_repository") + response["imageDetails"][1]["repositoryName"].should.equal("test_repository") + response["imageDetails"][2]["repositoryName"].should.equal("test_repository") + response["imageDetails"][3]["repositoryName"].should.equal("test_repository") - image_tags = ['latest', 'v1', 'v2'] - set([response['imageDetails'][1]['imageTags'][0], - response['imageDetails'][2]['imageTags'][0], - response['imageDetails'][3]['imageTags'][0]]).should.equal(set(image_tags)) + response["imageDetails"][0].should_not.have.key("imageTags") + len(response["imageDetails"][1]["imageTags"]).should.be(1) + len(response["imageDetails"][2]["imageTags"]).should.be(1) + len(response["imageDetails"][3]["imageTags"]).should.be(1) - response['imageDetails'][0]['imageSizeInBytes'].should.equal(52428800) - response['imageDetails'][1]['imageSizeInBytes'].should.equal(52428800) - response['imageDetails'][2]['imageSizeInBytes'].should.equal(52428800) - response['imageDetails'][3]['imageSizeInBytes'].should.equal(52428800) + image_tags = ["latest", "v1", "v2"] + set( + [ + response["imageDetails"][1]["imageTags"][0], + response["imageDetails"][2]["imageTags"][0], + response["imageDetails"][3]["imageTags"][0], + ] + ).should.equal(set(image_tags)) + + response["imageDetails"][0]["imageSizeInBytes"].should.equal(52428800) + response["imageDetails"][1]["imageSizeInBytes"].should.equal(52428800) + response["imageDetails"][2]["imageSizeInBytes"].should.equal(52428800) + response["imageDetails"][3]["imageSizeInBytes"].should.equal(52428800) @mock_ecr def test_describe_images_by_tag(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") tag_map = {} - for tag in ['latest', 'v1', 'v2']: + for tag in ["latest", "v1", "v2"]: put_response = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag=tag + imageTag=tag, ) - tag_map[tag] = put_response['image'] + tag_map[tag] = put_response["image"] for tag, put_response in tag_map.items(): - response = client.describe_images(repositoryName='test_repository', imageIds=[{'imageTag': tag}]) - len(response['imageDetails']).should.be(1) - image_detail = response['imageDetails'][0] - image_detail['registryId'].should.equal("012345678910") - image_detail['repositoryName'].should.equal("test_repository") - image_detail['imageTags'].should.equal([put_response['imageId']['imageTag']]) - image_detail['imageDigest'].should.equal(put_response['imageId']['imageDigest']) + response = client.describe_images( + repositoryName="test_repository", imageIds=[{"imageTag": tag}] + ) + len(response["imageDetails"]).should.be(1) + image_detail = response["imageDetails"][0] + image_detail["registryId"].should.equal("012345678910") + image_detail["repositoryName"].should.equal("test_repository") + image_detail["imageTags"].should.equal([put_response["imageId"]["imageTag"]]) + image_detail["imageDigest"].should.equal(put_response["imageId"]["imageDigest"]) @mock_ecr def test_describe_images_tags_should_not_contain_empty_tag1(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() client.put_image( - repositoryName='test_repository', - imageManifest=json.dumps(manifest) + repositoryName="test_repository", imageManifest=json.dumps(manifest) ) - tags = ['v1', 'v2', 'latest'] + tags = ["v1", "v2", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) - response = client.describe_images(repositoryName='test_repository', imageIds=[{'imageTag': tag}]) - len(response['imageDetails']).should.be(1) - image_detail = response['imageDetails'][0] - len(image_detail['imageTags']).should.equal(3) - image_detail['imageTags'].should.be.equal(tags) + response = client.describe_images( + repositoryName="test_repository", imageIds=[{"imageTag": tag}] + ) + len(response["imageDetails"]).should.be(1) + image_detail = response["imageDetails"][0] + len(image_detail["imageTags"]).should.equal(3) + image_detail["imageTags"].should.be.equal(tags) @mock_ecr def test_describe_images_tags_should_not_contain_empty_tag2(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v2'] + tags = ["v1", "v2"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) client.put_image( - repositoryName='test_repository', - imageManifest=json.dumps(manifest) + repositoryName="test_repository", imageManifest=json.dumps(manifest) ) client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag='latest' + imageTag="latest", ) - response = client.describe_images(repositoryName='test_repository', imageIds=[{'imageTag': tag}]) - len(response['imageDetails']).should.be(1) - image_detail = response['imageDetails'][0] - len(image_detail['imageTags']).should.equal(3) - image_detail['imageTags'].should.be.equal(['v1', 'v2', 'latest']) + response = client.describe_images( + repositoryName="test_repository", imageIds=[{"imageTag": tag}] + ) + len(response["imageDetails"]).should.be(1) + image_detail = response["imageDetails"][0] + len(image_detail["imageTags"]).should.equal(3) + image_detail["imageTags"].should.be.equal(["v1", "v2", "latest"]) @mock_ecr def test_describe_repository_that_doesnt_exist(): - client = boto3.client('ecr', region_name='us-east-1') + client = boto3.client("ecr", region_name="us-east-1") error_msg = re.compile( r".*The repository with name 'repo-that-doesnt-exist' does not exist in the registry with id '123'.*", - re.MULTILINE) + re.MULTILINE, + ) client.describe_repositories.when.called_with( - repositoryNames=['repo-that-doesnt-exist'], - registryId='123', + repositoryNames=["repo-that-doesnt-exist"], registryId="123" ).should.throw(ClientError, error_msg) + @mock_ecr def test_describe_image_that_doesnt_exist(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository(repositoryName='test_repository') + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") error_msg1 = re.compile( r".*The image with imageId {imageDigest:'null', imageTag:'testtag'} does not exist within " r"the repository with name 'test_repository' in the registry with id '123'.*", - re.MULTILINE) + re.MULTILINE, + ) client.describe_images.when.called_with( - repositoryName='test_repository', imageIds=[{'imageTag': 'testtag'}], registryId='123', + repositoryName="test_repository", + imageIds=[{"imageTag": "testtag"}], + registryId="123", ).should.throw(ClientError, error_msg1) error_msg2 = re.compile( r".*The repository with name 'repo-that-doesnt-exist' does not exist in the registry with id '123'.*", - re.MULTILINE) + re.MULTILINE, + ) client.describe_images.when.called_with( - repositoryName='repo-that-doesnt-exist', imageIds=[{'imageTag': 'testtag'}], registryId='123', + repositoryName="repo-that-doesnt-exist", + imageIds=[{"imageTag": "testtag"}], + registryId="123", ).should.throw(ClientError, error_msg2) @mock_ecr def test_delete_repository_that_doesnt_exist(): - client = boto3.client('ecr', region_name='us-east-1') + client = boto3.client("ecr", region_name="us-east-1") error_msg = re.compile( r".*The repository with name 'repo-that-doesnt-exist' does not exist in the registry with id '123'.*", - re.MULTILINE) + re.MULTILINE, + ) client.delete_repository.when.called_with( - repositoryName='repo-that-doesnt-exist', - registryId='123').should.throw( - ClientError, error_msg) + repositoryName="repo-that-doesnt-exist", registryId="123" + ).should.throw(ClientError, error_msg) @mock_ecr def test_describe_images_by_digest(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") - tags = ['latest', 'v1', 'v2'] + tags = ["latest", "v1", "v2"] digest_map = {} for tag in tags: put_response = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag=tag + imageTag=tag, ) - digest_map[put_response['image']['imageId']['imageDigest']] = put_response['image'] + digest_map[put_response["image"]["imageId"]["imageDigest"]] = put_response[ + "image" + ] for digest, put_response in digest_map.items(): - response = client.describe_images(repositoryName='test_repository', - imageIds=[{'imageDigest': digest}]) - len(response['imageDetails']).should.be(1) - image_detail = response['imageDetails'][0] - image_detail['registryId'].should.equal("012345678910") - image_detail['repositoryName'].should.equal("test_repository") - image_detail['imageTags'].should.equal([put_response['imageId']['imageTag']]) - image_detail['imageDigest'].should.equal(digest) + response = client.describe_images( + repositoryName="test_repository", imageIds=[{"imageDigest": digest}] + ) + len(response["imageDetails"]).should.be(1) + image_detail = response["imageDetails"][0] + image_detail["registryId"].should.equal("012345678910") + image_detail["repositoryName"].should.equal("test_repository") + image_detail["imageTags"].should.equal([put_response["imageId"]["imageTag"]]) + image_detail["imageDigest"].should.equal(digest) @mock_ecr def test_get_authorization_token_assume_region(): - client = boto3.client('ecr', region_name='us-east-1') + client = boto3.client("ecr", region_name="us-east-1") auth_token_response = client.get_authorization_token() - auth_token_response.should.contain('authorizationData') - auth_token_response.should.contain('ResponseMetadata') - auth_token_response['authorizationData'].should.equal([ - { - 'authorizationToken': 'QVdTOjAxMjM0NTY3ODkxMC1hdXRoLXRva2Vu', - 'proxyEndpoint': 'https://012345678910.dkr.ecr.us-east-1.amazonaws.com', - 'expiresAt': datetime(2015, 1, 1, tzinfo=tzlocal()) - }, - ]) + auth_token_response.should.contain("authorizationData") + auth_token_response.should.contain("ResponseMetadata") + auth_token_response["authorizationData"].should.equal( + [ + { + "authorizationToken": "QVdTOjAxMjM0NTY3ODkxMC1hdXRoLXRva2Vu", + "proxyEndpoint": "https://012345678910.dkr.ecr.us-east-1.amazonaws.com", + "expiresAt": datetime(2015, 1, 1, tzinfo=tzlocal()), + } + ] + ) @mock_ecr def test_get_authorization_token_explicit_regions(): - client = boto3.client('ecr', region_name='us-east-1') - auth_token_response = client.get_authorization_token(registryIds=['10987654321', '878787878787']) + client = boto3.client("ecr", region_name="us-east-1") + auth_token_response = client.get_authorization_token( + registryIds=["10987654321", "878787878787"] + ) - auth_token_response.should.contain('authorizationData') - auth_token_response.should.contain('ResponseMetadata') - auth_token_response['authorizationData'].should.equal([ - { - 'authorizationToken': 'QVdTOjEwOTg3NjU0MzIxLWF1dGgtdG9rZW4=', - 'proxyEndpoint': 'https://10987654321.dkr.ecr.us-east-1.amazonaws.com', - 'expiresAt': datetime(2015, 1, 1, tzinfo=tzlocal()), - }, - { - 'authorizationToken': 'QVdTOjg3ODc4Nzg3ODc4Ny1hdXRoLXRva2Vu', - 'proxyEndpoint': 'https://878787878787.dkr.ecr.us-east-1.amazonaws.com', - 'expiresAt': datetime(2015, 1, 1, tzinfo=tzlocal()) - - } - ]) + auth_token_response.should.contain("authorizationData") + auth_token_response.should.contain("ResponseMetadata") + auth_token_response["authorizationData"].should.equal( + [ + { + "authorizationToken": "QVdTOjEwOTg3NjU0MzIxLWF1dGgtdG9rZW4=", + "proxyEndpoint": "https://10987654321.dkr.ecr.us-east-1.amazonaws.com", + "expiresAt": datetime(2015, 1, 1, tzinfo=tzlocal()), + }, + { + "authorizationToken": "QVdTOjg3ODc4Nzg3ODc4Ny1hdXRoLXRva2Vu", + "proxyEndpoint": "https://878787878787.dkr.ecr.us-east-1.amazonaws.com", + "expiresAt": datetime(2015, 1, 1, tzinfo=tzlocal()), + }, + ] + ) @mock_ecr def test_batch_get_image(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") + + _ = client.put_image( + repositoryName="test_repository", + imageManifest=json.dumps(_create_image_manifest()), + imageTag="latest", ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="v1", ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v1' - ) - - _ = client.put_image( - repositoryName='test_repository', - imageManifest=json.dumps(_create_image_manifest()), - imageTag='v2' + imageTag="v2", ) response = client.batch_get_image( - repositoryName='test_repository', - imageIds=[ - { - 'imageTag': 'v2' - }, - ], + repositoryName="test_repository", imageIds=[{"imageTag": "v2"}] ) - type(response['images']).should.be(list) - len(response['images']).should.be(1) + type(response["images"]).should.be(list) + len(response["images"]).should.be(1) - response['images'][0]['imageManifest'].should.contain("vnd.docker.distribution.manifest.v2+json") - response['images'][0]['registryId'].should.equal("012345678910") - response['images'][0]['repositoryName'].should.equal("test_repository") + response["images"][0]["imageManifest"].should.contain( + "vnd.docker.distribution.manifest.v2+json" + ) + response["images"][0]["registryId"].should.equal("012345678910") + response["images"][0]["repositoryName"].should.equal("test_repository") - response['images'][0]['imageId']['imageTag'].should.equal("v2") - response['images'][0]['imageId']['imageDigest'].should.contain("sha") + response["images"][0]["imageId"]["imageTag"].should.equal("v2") + response["images"][0]["imageId"]["imageDigest"].should.contain("sha") - type(response['failures']).should.be(list) - len(response['failures']).should.be(0) + type(response["failures"]).should.be(list) + len(response["failures"]).should.be(0) @mock_ecr def test_batch_get_image_that_doesnt_exist(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") + + _ = client.put_image( + repositoryName="test_repository", + imageManifest=json.dumps(_create_image_manifest()), + imageTag="latest", ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="v1", ) _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v1' - ) - - _ = client.put_image( - repositoryName='test_repository', - imageManifest=json.dumps(_create_image_manifest()), - imageTag='v2' + imageTag="v2", ) response = client.batch_get_image( - repositoryName='test_repository', - imageIds=[ - { - 'imageTag': 'v5' - }, - ], + repositoryName="test_repository", imageIds=[{"imageTag": "v5"}] ) - type(response['images']).should.be(list) - len(response['images']).should.be(0) + type(response["images"]).should.be(list) + len(response["images"]).should.be(0) - type(response['failures']).should.be(list) - len(response['failures']).should.be(1) - response['failures'][0]['failureReason'].should.equal("Requested image not found") - response['failures'][0]['failureCode'].should.equal("ImageNotFound") - response['failures'][0]['imageId']['imageTag'].should.equal("v5") + type(response["failures"]).should.be(list) + len(response["failures"]).should.be(1) + response["failures"][0]["failureReason"].should.equal("Requested image not found") + response["failures"][0]["failureCode"].should.equal("ImageNotFound") + response["failures"][0]["imageId"]["imageTag"].should.equal("v5") @mock_ecr def test_batch_get_image_no_tags(): - client = boto3.client('ecr', region_name='us-east-1') - _ = client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + _ = client.create_repository(repositoryName="test_repository") _ = client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='latest' + imageTag="latest", ) error_msg = re.compile( - r".*Missing required parameter in input: \"imageIds\".*", - re.MULTILINE) + r".*Missing required parameter in input: \"imageIds\".*", re.MULTILINE + ) client.batch_get_image.when.called_with( - repositoryName='test_repository').should.throw( - ParamValidationError, error_msg) + repositoryName="test_repository" + ).should.throw(ParamValidationError, error_msg) @mock_ecr def test_batch_delete_image_by_tag(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v1.0', 'latest'] + tags = ["v1", "v1.0", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), imageTag=tag, ) - describe_response1 = client.describe_images(repositoryName='test_repository') + describe_response1 = client.describe_images(repositoryName="test_repository") batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageTag': 'latest' - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageTag": "latest"}], ) - describe_response2 = client.describe_images(repositoryName='test_repository') + describe_response2 = client.describe_images(repositoryName="test_repository") - type(describe_response1['imageDetails'][0]['imageTags']).should.be(list) - len(describe_response1['imageDetails'][0]['imageTags']).should.be(3) + type(describe_response1["imageDetails"][0]["imageTags"]).should.be(list) + len(describe_response1["imageDetails"][0]["imageTags"]).should.be(3) - type(describe_response2['imageDetails'][0]['imageTags']).should.be(list) - len(describe_response2['imageDetails'][0]['imageTags']).should.be(2) + type(describe_response2["imageDetails"][0]["imageTags"]).should.be(list) + len(describe_response2["imageDetails"][0]["imageTags"]).should.be(2) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(1) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(1) - batch_delete_response['imageIds'][0]['imageTag'].should.equal("latest") + batch_delete_response["imageIds"][0]["imageTag"].should.equal("latest") - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(0) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(0) @mock_ecr def test_batch_delete_image_delete_last_tag(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(_create_image_manifest()), - imageTag='v1', + imageTag="v1", ) - describe_response1 = client.describe_images(repositoryName='test_repository') + describe_response1 = client.describe_images(repositoryName="test_repository") batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageTag': 'v1' - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageTag": "v1"}], ) - describe_response2 = client.describe_images(repositoryName='test_repository') + describe_response2 = client.describe_images(repositoryName="test_repository") - type(describe_response1['imageDetails'][0]['imageTags']).should.be(list) - len(describe_response1['imageDetails'][0]['imageTags']).should.be(1) + type(describe_response1["imageDetails"][0]["imageTags"]).should.be(list) + len(describe_response1["imageDetails"][0]["imageTags"]).should.be(1) - type(describe_response2['imageDetails']).should.be(list) - len(describe_response2['imageDetails']).should.be(0) + type(describe_response2["imageDetails"]).should.be(list) + len(describe_response2["imageDetails"]).should.be(0) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(1) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(1) - batch_delete_response['imageIds'][0]['imageTag'].should.equal("v1") + batch_delete_response["imageIds"][0]["imageTag"].should.equal("v1") - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(0) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(0) @mock_ecr def test_batch_delete_image_with_nonexistent_tag(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v1.0', 'latest'] + tags = ["v1", "v1.0", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), imageTag=tag, ) - describe_response = client.describe_images(repositoryName='test_repository') + describe_response = client.describe_images(repositoryName="test_repository") missing_tag = "missing-tag" batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageTag': missing_tag - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageTag": missing_tag}], ) - type(describe_response['imageDetails'][0]['imageTags']).should.be(list) - len(describe_response['imageDetails'][0]['imageTags']).should.be(3) + type(describe_response["imageDetails"][0]["imageTags"]).should.be(list) + len(describe_response["imageDetails"][0]["imageTags"]).should.be(3) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(0) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(0) - batch_delete_response['failures'][0]['imageId']['imageTag'].should.equal(missing_tag) - batch_delete_response['failures'][0]['failureCode'].should.equal("ImageNotFound") - batch_delete_response['failures'][0]['failureReason'].should.equal("Requested image not found") + batch_delete_response["failures"][0]["imageId"]["imageTag"].should.equal( + missing_tag + ) + batch_delete_response["failures"][0]["failureCode"].should.equal("ImageNotFound") + batch_delete_response["failures"][0]["failureReason"].should.equal( + "Requested image not found" + ) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(1) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(1) @mock_ecr def test_batch_delete_image_by_digest(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v2', 'latest'] + tags = ["v1", "v2", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) - describe_response = client.describe_images(repositoryName='test_repository') - image_digest = describe_response['imageDetails'][0]['imageDigest'] + describe_response = client.describe_images(repositoryName="test_repository") + image_digest = describe_response["imageDetails"][0]["imageDigest"] batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageDigest': image_digest - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageDigest": image_digest}], ) - describe_response = client.describe_images(repositoryName='test_repository') + describe_response = client.describe_images(repositoryName="test_repository") - type(describe_response['imageDetails']).should.be(list) - len(describe_response['imageDetails']).should.be(0) + type(describe_response["imageDetails"]).should.be(list) + len(describe_response["imageDetails"]).should.be(0) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(3) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(3) - batch_delete_response['imageIds'][0]['imageDigest'].should.equal(image_digest) - batch_delete_response['imageIds'][1]['imageDigest'].should.equal(image_digest) - batch_delete_response['imageIds'][2]['imageDigest'].should.equal(image_digest) + batch_delete_response["imageIds"][0]["imageDigest"].should.equal(image_digest) + batch_delete_response["imageIds"][1]["imageDigest"].should.equal(image_digest) + batch_delete_response["imageIds"][2]["imageDigest"].should.equal(image_digest) - set([ - batch_delete_response['imageIds'][0]['imageTag'], - batch_delete_response['imageIds'][1]['imageTag'], - batch_delete_response['imageIds'][2]['imageTag']]).should.equal(set(tags)) + set( + [ + batch_delete_response["imageIds"][0]["imageTag"], + batch_delete_response["imageIds"][1]["imageTag"], + batch_delete_response["imageIds"][2]["imageTag"], + ] + ).should.equal(set(tags)) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(0) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(0) @mock_ecr def test_batch_delete_image_with_invalid_digest(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v2', 'latest'] + tags = ["v1", "v2", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) - invalid_image_digest = 'sha256:invalid-digest' + invalid_image_digest = "sha256:invalid-digest" batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageDigest': invalid_image_digest - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageDigest": invalid_image_digest}], ) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(0) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(0) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(1) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(1) - batch_delete_response['failures'][0]['imageId']['imageDigest'].should.equal(invalid_image_digest) - batch_delete_response['failures'][0]['failureCode'].should.equal("InvalidImageDigest") - batch_delete_response['failures'][0]['failureReason'].should.equal("Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'") + batch_delete_response["failures"][0]["imageId"]["imageDigest"].should.equal( + invalid_image_digest + ) + batch_delete_response["failures"][0]["failureCode"].should.equal( + "InvalidImageDigest" + ) + batch_delete_response["failures"][0]["failureReason"].should.equal( + "Invalid request parameters: image digest should satisfy the regex '[a-zA-Z0-9-_+.]+:[a-fA-F0-9]+'" + ) @mock_ecr def test_batch_delete_image_with_missing_parameters(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - }, - ], + registryId="012345678910", repositoryName="test_repository", imageIds=[{}] ) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(0) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(0) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(1) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(1) - batch_delete_response['failures'][0]['failureCode'].should.equal("MissingDigestAndTag") - batch_delete_response['failures'][0]['failureReason'].should.equal("Invalid request parameters: both tag and digest cannot be null") + batch_delete_response["failures"][0]["failureCode"].should.equal( + "MissingDigestAndTag" + ) + batch_delete_response["failures"][0]["failureReason"].should.equal( + "Invalid request parameters: both tag and digest cannot be null" + ) @mock_ecr def test_batch_delete_image_with_matching_digest_and_tag(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'v1.0', 'latest'] + tags = ["v1", "v1.0", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) - describe_response = client.describe_images(repositoryName='test_repository') - image_digest = describe_response['imageDetails'][0]['imageDigest'] + describe_response = client.describe_images(repositoryName="test_repository") + image_digest = describe_response["imageDetails"][0]["imageDigest"] batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageDigest': image_digest, - 'imageTag': 'v1' - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageDigest": image_digest, "imageTag": "v1"}], ) - describe_response = client.describe_images(repositoryName='test_repository') + describe_response = client.describe_images(repositoryName="test_repository") - type(describe_response['imageDetails']).should.be(list) - len(describe_response['imageDetails']).should.be(0) + type(describe_response["imageDetails"]).should.be(list) + len(describe_response["imageDetails"]).should.be(0) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(3) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(3) - batch_delete_response['imageIds'][0]['imageDigest'].should.equal(image_digest) - batch_delete_response['imageIds'][1]['imageDigest'].should.equal(image_digest) - batch_delete_response['imageIds'][2]['imageDigest'].should.equal(image_digest) + batch_delete_response["imageIds"][0]["imageDigest"].should.equal(image_digest) + batch_delete_response["imageIds"][1]["imageDigest"].should.equal(image_digest) + batch_delete_response["imageIds"][2]["imageDigest"].should.equal(image_digest) - set([ - batch_delete_response['imageIds'][0]['imageTag'], - batch_delete_response['imageIds'][1]['imageTag'], - batch_delete_response['imageIds'][2]['imageTag']]).should.equal(set(tags)) + set( + [ + batch_delete_response["imageIds"][0]["imageTag"], + batch_delete_response["imageIds"][1]["imageTag"], + batch_delete_response["imageIds"][2]["imageTag"], + ] + ).should.equal(set(tags)) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(0) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(0) @mock_ecr def test_batch_delete_image_with_mismatched_digest_and_tag(): - client = boto3.client('ecr', region_name='us-east-1') - client.create_repository( - repositoryName='test_repository' - ) + client = boto3.client("ecr", region_name="us-east-1") + client.create_repository(repositoryName="test_repository") manifest = _create_image_manifest() - tags = ['v1', 'latest'] + tags = ["v1", "latest"] for tag in tags: client.put_image( - repositoryName='test_repository', + repositoryName="test_repository", imageManifest=json.dumps(manifest), - imageTag=tag + imageTag=tag, ) - describe_response = client.describe_images(repositoryName='test_repository') - image_digest = describe_response['imageDetails'][0]['imageDigest'] + describe_response = client.describe_images(repositoryName="test_repository") + image_digest = describe_response["imageDetails"][0]["imageDigest"] batch_delete_response = client.batch_delete_image( - registryId='012345678910', - repositoryName='test_repository', - imageIds=[ - { - 'imageDigest': image_digest, - 'imageTag': 'v2' - }, - ], + registryId="012345678910", + repositoryName="test_repository", + imageIds=[{"imageDigest": image_digest, "imageTag": "v2"}], ) - type(batch_delete_response['imageIds']).should.be(list) - len(batch_delete_response['imageIds']).should.be(0) + type(batch_delete_response["imageIds"]).should.be(list) + len(batch_delete_response["imageIds"]).should.be(0) - type(batch_delete_response['failures']).should.be(list) - len(batch_delete_response['failures']).should.be(1) + type(batch_delete_response["failures"]).should.be(list) + len(batch_delete_response["failures"]).should.be(1) - batch_delete_response['failures'][0]['imageId']['imageDigest'].should.equal(image_digest) - batch_delete_response['failures'][0]['imageId']['imageTag'].should.equal("v2") - batch_delete_response['failures'][0]['failureCode'].should.equal("ImageNotFound") - batch_delete_response['failures'][0]['failureReason'].should.equal("Requested image not found") + batch_delete_response["failures"][0]["imageId"]["imageDigest"].should.equal( + image_digest + ) + batch_delete_response["failures"][0]["imageId"]["imageTag"].should.equal("v2") + batch_delete_response["failures"][0]["failureCode"].should.equal("ImageNotFound") + batch_delete_response["failures"][0]["failureReason"].should.equal( + "Requested image not found" + ) diff --git a/tests/test_ecs/test_ecs_boto3.py b/tests/test_ecs/test_ecs_boto3.py index 16d7a4d0d..224e6935b 100644 --- a/tests/test_ecs/test_ecs_boto3.py +++ b/tests/test_ecs/test_ecs_boto3.py @@ -18,658 +18,658 @@ from nose.tools import assert_raises @mock_ecs def test_create_cluster(): - client = boto3.client('ecs', region_name='us-east-1') - response = client.create_cluster( - clusterName='test_ecs_cluster' + client = boto3.client("ecs", region_name="us-east-1") + response = client.create_cluster(clusterName="test_ecs_cluster") + response["cluster"]["clusterName"].should.equal("test_ecs_cluster") + response["cluster"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" ) - response['cluster']['clusterName'].should.equal('test_ecs_cluster') - response['cluster']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['cluster']['status'].should.equal('ACTIVE') - response['cluster']['registeredContainerInstancesCount'].should.equal(0) - response['cluster']['runningTasksCount'].should.equal(0) - response['cluster']['pendingTasksCount'].should.equal(0) - response['cluster']['activeServicesCount'].should.equal(0) + response["cluster"]["status"].should.equal("ACTIVE") + response["cluster"]["registeredContainerInstancesCount"].should.equal(0) + response["cluster"]["runningTasksCount"].should.equal(0) + response["cluster"]["pendingTasksCount"].should.equal(0) + response["cluster"]["activeServicesCount"].should.equal(0) @mock_ecs def test_list_clusters(): - client = boto3.client('ecs', region_name='us-east-2') - _ = client.create_cluster( - clusterName='test_cluster0' - ) - _ = client.create_cluster( - clusterName='test_cluster1' - ) + client = boto3.client("ecs", region_name="us-east-2") + _ = client.create_cluster(clusterName="test_cluster0") + _ = client.create_cluster(clusterName="test_cluster1") response = client.list_clusters() - response['clusterArns'].should.contain( - 'arn:aws:ecs:us-east-2:012345678910:cluster/test_cluster0') - response['clusterArns'].should.contain( - 'arn:aws:ecs:us-east-2:012345678910:cluster/test_cluster1') + response["clusterArns"].should.contain( + "arn:aws:ecs:us-east-2:012345678910:cluster/test_cluster0" + ) + response["clusterArns"].should.contain( + "arn:aws:ecs:us-east-2:012345678910:cluster/test_cluster1" + ) @mock_ecs def test_describe_clusters(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") response = client.describe_clusters(clusters=["some-cluster"]) - response['failures'].should.contain({ - 'arn': 'arn:aws:ecs:us-east-1:012345678910:cluster/some-cluster', - 'reason': 'MISSING' - }) + response["failures"].should.contain( + { + "arn": "arn:aws:ecs:us-east-1:012345678910:cluster/some-cluster", + "reason": "MISSING", + } + ) + @mock_ecs def test_delete_cluster(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") + response = client.delete_cluster(cluster="test_ecs_cluster") + response["cluster"]["clusterName"].should.equal("test_ecs_cluster") + response["cluster"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" ) - response = client.delete_cluster(cluster='test_ecs_cluster') - response['cluster']['clusterName'].should.equal('test_ecs_cluster') - response['cluster']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['cluster']['status'].should.equal('ACTIVE') - response['cluster']['registeredContainerInstancesCount'].should.equal(0) - response['cluster']['runningTasksCount'].should.equal(0) - response['cluster']['pendingTasksCount'].should.equal(0) - response['cluster']['activeServicesCount'].should.equal(0) + response["cluster"]["status"].should.equal("ACTIVE") + response["cluster"]["registeredContainerInstancesCount"].should.equal(0) + response["cluster"]["runningTasksCount"].should.equal(0) + response["cluster"]["pendingTasksCount"].should.equal(0) + response["cluster"]["activeServicesCount"].should.equal(0) response = client.list_clusters() - len(response['clusterArns']).should.equal(0) + len(response["clusterArns"]).should.equal(0) @mock_ecs def test_register_task_definition(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") response = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } ], tags=[ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'bar'}, - ] + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "bar"}, + ], ) - type(response['taskDefinition']).should.be(dict) - response['taskDefinition']['revision'].should.equal(1) - response['taskDefinition']['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['taskDefinition']['containerDefinitions'][ - 0]['name'].should.equal('hello_world') - response['taskDefinition']['containerDefinitions'][0][ - 'image'].should.equal('docker/hello-world:latest') - response['taskDefinition']['containerDefinitions'][ - 0]['cpu'].should.equal(1024) - response['taskDefinition']['containerDefinitions'][ - 0]['memory'].should.equal(400) - response['taskDefinition']['containerDefinitions'][ - 0]['essential'].should.equal(True) - response['taskDefinition']['containerDefinitions'][0][ - 'environment'][0]['name'].should.equal('AWS_ACCESS_KEY_ID') - response['taskDefinition']['containerDefinitions'][0][ - 'environment'][0]['value'].should.equal('SOME_ACCESS_KEY') - response['taskDefinition']['containerDefinitions'][0][ - 'logConfiguration']['logDriver'].should.equal('json-file') + type(response["taskDefinition"]).should.be(dict) + response["taskDefinition"]["revision"].should.equal(1) + response["taskDefinition"]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["taskDefinition"]["containerDefinitions"][0]["name"].should.equal( + "hello_world" + ) + response["taskDefinition"]["containerDefinitions"][0]["image"].should.equal( + "docker/hello-world:latest" + ) + response["taskDefinition"]["containerDefinitions"][0]["cpu"].should.equal(1024) + response["taskDefinition"]["containerDefinitions"][0]["memory"].should.equal(400) + response["taskDefinition"]["containerDefinitions"][0]["essential"].should.equal( + True + ) + response["taskDefinition"]["containerDefinitions"][0]["environment"][0][ + "name" + ].should.equal("AWS_ACCESS_KEY_ID") + response["taskDefinition"]["containerDefinitions"][0]["environment"][0][ + "value" + ].should.equal("SOME_ACCESS_KEY") + response["taskDefinition"]["containerDefinitions"][0]["logConfiguration"][ + "logDriver" + ].should.equal("json-file") @mock_ecs def test_list_task_definitions(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world2', - 'image': 'docker/hello-world2:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY2' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world2", + "image": "docker/hello-world2:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY2"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.list_task_definitions() - len(response['taskDefinitionArns']).should.equal(2) - response['taskDefinitionArns'][0].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['taskDefinitionArns'][1].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:2') + len(response["taskDefinitionArns"]).should.equal(2) + response["taskDefinitionArns"][0].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["taskDefinitionArns"][1].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:2" + ) @mock_ecs def test_describe_task_definition(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world2', - 'image': 'docker/hello-world2:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY2' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world2", + "image": "docker/hello-world2:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY2"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world3', - 'image': 'docker/hello-world3:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY3' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world3", + "image": "docker/hello-world3:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY3"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], + ) + response = client.describe_task_definition(taskDefinition="test_ecs_task") + response["taskDefinition"]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:3" ) - response = client.describe_task_definition(taskDefinition='test_ecs_task') - response['taskDefinition']['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:3') - response = client.describe_task_definition( - taskDefinition='test_ecs_task:2') - response['taskDefinition']['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:2') + response = client.describe_task_definition(taskDefinition="test_ecs_task:2") + response["taskDefinition"]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:2" + ) @mock_ecs def test_deregister_task_definition(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) - response = client.deregister_task_definition( - taskDefinition='test_ecs_task:1' + response = client.deregister_task_definition(taskDefinition="test_ecs_task:1") + type(response["taskDefinition"]).should.be(dict) + response["taskDefinition"]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" ) - type(response['taskDefinition']).should.be(dict) - response['taskDefinition']['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['taskDefinition']['containerDefinitions'][ - 0]['name'].should.equal('hello_world') - response['taskDefinition']['containerDefinitions'][0][ - 'image'].should.equal('docker/hello-world:latest') - response['taskDefinition']['containerDefinitions'][ - 0]['cpu'].should.equal(1024) - response['taskDefinition']['containerDefinitions'][ - 0]['memory'].should.equal(400) - response['taskDefinition']['containerDefinitions'][ - 0]['essential'].should.equal(True) - response['taskDefinition']['containerDefinitions'][0][ - 'environment'][0]['name'].should.equal('AWS_ACCESS_KEY_ID') - response['taskDefinition']['containerDefinitions'][0][ - 'environment'][0]['value'].should.equal('SOME_ACCESS_KEY') - response['taskDefinition']['containerDefinitions'][0][ - 'logConfiguration']['logDriver'].should.equal('json-file') + response["taskDefinition"]["containerDefinitions"][0]["name"].should.equal( + "hello_world" + ) + response["taskDefinition"]["containerDefinitions"][0]["image"].should.equal( + "docker/hello-world:latest" + ) + response["taskDefinition"]["containerDefinitions"][0]["cpu"].should.equal(1024) + response["taskDefinition"]["containerDefinitions"][0]["memory"].should.equal(400) + response["taskDefinition"]["containerDefinitions"][0]["essential"].should.equal( + True + ) + response["taskDefinition"]["containerDefinitions"][0]["environment"][0][ + "name" + ].should.equal("AWS_ACCESS_KEY_ID") + response["taskDefinition"]["containerDefinitions"][0]["environment"][0][ + "value" + ].should.equal("SOME_ACCESS_KEY") + response["taskDefinition"]["containerDefinitions"][0]["logConfiguration"][ + "logDriver" + ].should.equal("json-file") @mock_ecs def test_create_service(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, ) - response['service']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['service']['desiredCount'].should.equal(2) - len(response['service']['events']).should.equal(0) - len(response['service']['loadBalancers']).should.equal(0) - response['service']['pendingCount'].should.equal(0) - response['service']['runningCount'].should.equal(0) - response['service']['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service') - response['service']['serviceName'].should.equal('test_ecs_service') - response['service']['status'].should.equal('ACTIVE') - response['service']['taskDefinition'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['service']['schedulingStrategy'].should.equal('REPLICA') + response["service"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["service"]["desiredCount"].should.equal(2) + len(response["service"]["events"]).should.equal(0) + len(response["service"]["loadBalancers"]).should.equal(0) + response["service"]["pendingCount"].should.equal(0) + response["service"]["runningCount"].should.equal(0) + response["service"]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service" + ) + response["service"]["serviceName"].should.equal("test_ecs_service") + response["service"]["status"].should.equal("ACTIVE") + response["service"]["taskDefinition"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["service"]["schedulingStrategy"].should.equal("REPLICA") + @mock_ecs def test_create_service_scheduling_strategy(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", desiredCount=2, - schedulingStrategy='DAEMON', + schedulingStrategy="DAEMON", ) - response['service']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['service']['desiredCount'].should.equal(2) - len(response['service']['events']).should.equal(0) - len(response['service']['loadBalancers']).should.equal(0) - response['service']['pendingCount'].should.equal(0) - response['service']['runningCount'].should.equal(0) - response['service']['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service') - response['service']['serviceName'].should.equal('test_ecs_service') - response['service']['status'].should.equal('ACTIVE') - response['service']['taskDefinition'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['service']['schedulingStrategy'].should.equal('DAEMON') + response["service"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["service"]["desiredCount"].should.equal(2) + len(response["service"]["events"]).should.equal(0) + len(response["service"]["loadBalancers"]).should.equal(0) + response["service"]["pendingCount"].should.equal(0) + response["service"]["runningCount"].should.equal(0) + response["service"]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service" + ) + response["service"]["serviceName"].should.equal("test_ecs_service") + response["service"]["status"].should.equal("ACTIVE") + response["service"]["taskDefinition"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["service"]["schedulingStrategy"].should.equal("DAEMON") @mock_ecs def test_list_services(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service1', - taskDefinition='test_ecs_task', - schedulingStrategy='REPLICA', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service1", + taskDefinition="test_ecs_task", + schedulingStrategy="REPLICA", + desiredCount=2, ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service2', - taskDefinition='test_ecs_task', - schedulingStrategy='DAEMON', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service2", + taskDefinition="test_ecs_task", + schedulingStrategy="DAEMON", + desiredCount=2, ) - unfiltered_response = client.list_services( - cluster='test_ecs_cluster' + unfiltered_response = client.list_services(cluster="test_ecs_cluster") + len(unfiltered_response["serviceArns"]).should.equal(2) + unfiltered_response["serviceArns"][0].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1" + ) + unfiltered_response["serviceArns"][1].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2" ) - len(unfiltered_response['serviceArns']).should.equal(2) - unfiltered_response['serviceArns'][0].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1') - unfiltered_response['serviceArns'][1].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2') filtered_response = client.list_services( - cluster='test_ecs_cluster', - schedulingStrategy='REPLICA' + cluster="test_ecs_cluster", schedulingStrategy="REPLICA" ) - len(filtered_response['serviceArns']).should.equal(1) - filtered_response['serviceArns'][0].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1') + len(filtered_response["serviceArns"]).should.equal(1) + filtered_response["serviceArns"][0].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1" + ) + @mock_ecs def test_describe_services(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service1', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service1", + taskDefinition="test_ecs_task", + desiredCount=2, ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service2', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service2", + taskDefinition="test_ecs_task", + desiredCount=2, ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service3', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service3", + taskDefinition="test_ecs_task", + desiredCount=2, ) response = client.describe_services( - cluster='test_ecs_cluster', - services=['test_ecs_service1', - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2'] + cluster="test_ecs_cluster", + services=[ + "test_ecs_service1", + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2", + ], ) - len(response['services']).should.equal(2) - response['services'][0]['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1') - response['services'][0]['serviceName'].should.equal('test_ecs_service1') - response['services'][1]['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2') - response['services'][1]['serviceName'].should.equal('test_ecs_service2') + len(response["services"]).should.equal(2) + response["services"][0]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1" + ) + response["services"][0]["serviceName"].should.equal("test_ecs_service1") + response["services"][1]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2" + ) + response["services"][1]["serviceName"].should.equal("test_ecs_service2") - response['services'][0]['deployments'][0]['desiredCount'].should.equal(2) - response['services'][0]['deployments'][0]['pendingCount'].should.equal(2) - response['services'][0]['deployments'][0]['runningCount'].should.equal(0) - response['services'][0]['deployments'][0]['status'].should.equal('PRIMARY') - (datetime.now() - response['services'][0]['deployments'][0]["createdAt"].replace(tzinfo=None)).seconds.should.be.within(0, 10) - (datetime.now() - response['services'][0]['deployments'][0]["updatedAt"].replace(tzinfo=None)).seconds.should.be.within(0, 10) + response["services"][0]["deployments"][0]["desiredCount"].should.equal(2) + response["services"][0]["deployments"][0]["pendingCount"].should.equal(2) + response["services"][0]["deployments"][0]["runningCount"].should.equal(0) + response["services"][0]["deployments"][0]["status"].should.equal("PRIMARY") + ( + datetime.now() + - response["services"][0]["deployments"][0]["createdAt"].replace(tzinfo=None) + ).seconds.should.be.within(0, 10) + ( + datetime.now() + - response["services"][0]["deployments"][0]["updatedAt"].replace(tzinfo=None) + ).seconds.should.be.within(0, 10) @mock_ecs def test_describe_services_scheduling_strategy(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service1', - taskDefinition='test_ecs_task', - desiredCount=2 - ) - _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service2', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + serviceName="test_ecs_service1", + taskDefinition="test_ecs_task", desiredCount=2, - schedulingStrategy='DAEMON' ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service3', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service2", + taskDefinition="test_ecs_task", + desiredCount=2, + schedulingStrategy="DAEMON", + ) + _ = client.create_service( + cluster="test_ecs_cluster", + serviceName="test_ecs_service3", + taskDefinition="test_ecs_task", + desiredCount=2, ) response = client.describe_services( - cluster='test_ecs_cluster', - services=['test_ecs_service1', - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2', - 'test_ecs_service3'] + cluster="test_ecs_cluster", + services=[ + "test_ecs_service1", + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2", + "test_ecs_service3", + ], ) - len(response['services']).should.equal(3) - response['services'][0]['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1') - response['services'][0]['serviceName'].should.equal('test_ecs_service1') - response['services'][1]['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2') - response['services'][1]['serviceName'].should.equal('test_ecs_service2') + len(response["services"]).should.equal(3) + response["services"][0]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service1" + ) + response["services"][0]["serviceName"].should.equal("test_ecs_service1") + response["services"][1]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service2" + ) + response["services"][1]["serviceName"].should.equal("test_ecs_service2") - response['services'][0]['deployments'][0]['desiredCount'].should.equal(2) - response['services'][0]['deployments'][0]['pendingCount'].should.equal(2) - response['services'][0]['deployments'][0]['runningCount'].should.equal(0) - response['services'][0]['deployments'][0]['status'].should.equal('PRIMARY') + response["services"][0]["deployments"][0]["desiredCount"].should.equal(2) + response["services"][0]["deployments"][0]["pendingCount"].should.equal(2) + response["services"][0]["deployments"][0]["runningCount"].should.equal(0) + response["services"][0]["deployments"][0]["status"].should.equal("PRIMARY") - response['services'][0]['schedulingStrategy'].should.equal('REPLICA') - response['services'][1]['schedulingStrategy'].should.equal('DAEMON') - response['services'][2]['schedulingStrategy'].should.equal('REPLICA') + response["services"][0]["schedulingStrategy"].should.equal("REPLICA") + response["services"][1]["schedulingStrategy"].should.equal("DAEMON") + response["services"][2]["schedulingStrategy"].should.equal("REPLICA") @mock_ecs def test_update_service(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, ) - response['service']['desiredCount'].should.equal(2) + response["service"]["desiredCount"].should.equal(2) response = client.update_service( - cluster='test_ecs_cluster', - service='test_ecs_service', - taskDefinition='test_ecs_task', - desiredCount=0 + cluster="test_ecs_cluster", + service="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=0, ) - response['service']['desiredCount'].should.equal(0) - response['service']['schedulingStrategy'].should.equal('REPLICA') + response["service"]["desiredCount"].should.equal(0) + response["service"]["schedulingStrategy"].should.equal("REPLICA") @mock_ecs def test_update_missing_service(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") client.update_service.when.called_with( - cluster='test_ecs_cluster', - service='test_ecs_service', - taskDefinition='test_ecs_task', - desiredCount=0 + cluster="test_ecs_cluster", + service="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=0, ).should.throw(ClientError) @mock_ecs def test_delete_service(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, ) _ = client.update_service( - cluster='test_ecs_cluster', - service='test_ecs_service', - desiredCount=0 + cluster="test_ecs_cluster", service="test_ecs_service", desiredCount=0 ) response = client.delete_service( - cluster='test_ecs_cluster', - service='test_ecs_service' + cluster="test_ecs_cluster", service="test_ecs_service" + ) + response["service"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["service"]["desiredCount"].should.equal(0) + len(response["service"]["events"]).should.equal(0) + len(response["service"]["loadBalancers"]).should.equal(0) + response["service"]["pendingCount"].should.equal(0) + response["service"]["runningCount"].should.equal(0) + response["service"]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service" + ) + response["service"]["serviceName"].should.equal("test_ecs_service") + response["service"]["status"].should.equal("ACTIVE") + response["service"]["schedulingStrategy"].should.equal("REPLICA") + response["service"]["taskDefinition"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" ) - response['service']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['service']['desiredCount'].should.equal(0) - len(response['service']['events']).should.equal(0) - len(response['service']['loadBalancers']).should.equal(0) - response['service']['pendingCount'].should.equal(0) - response['service']['runningCount'].should.equal(0) - response['service']['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service') - response['service']['serviceName'].should.equal('test_ecs_service') - response['service']['status'].should.equal('ACTIVE') - response['service']['schedulingStrategy'].should.equal('REPLICA') - response['service']['taskDefinition'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') @mock_ecs def test_update_non_existant_service(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") try: client.update_service( - cluster="my-clustet", - service="my-service", - desiredCount=0, + cluster="my-clustet", service="my-service", desiredCount=0 ) except ClientError as exc: - error_code = exc.response['Error']['Code'] - error_code.should.equal('ServiceNotFoundException') + error_code = exc.response["Error"]["Code"] + error_code.should.equal("ServiceNotFoundException") else: raise Exception("Didn't raise ClientError") @@ -677,19 +677,15 @@ def test_update_non_existant_service(): @mock_ec2 @mock_ecs def test_register_container_instance(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + _ = ecs_client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -697,45 +693,37 @@ def test_register_container_instance(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - response['containerInstance'][ - 'ec2InstanceId'].should.equal(test_instance.id) - full_arn = response['containerInstance']['containerInstanceArn'] - arn_part = full_arn.split('/') - arn_part[0].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:container-instance') + response["containerInstance"]["ec2InstanceId"].should.equal(test_instance.id) + full_arn = response["containerInstance"]["containerInstanceArn"] + arn_part = full_arn.split("/") + arn_part[0].should.equal("arn:aws:ecs:us-east-1:012345678910:container-instance") arn_part[1].should.equal(str(UUID(arn_part[1]))) - response['containerInstance']['status'].should.equal('ACTIVE') - len(response['containerInstance']['registeredResources']).should.equal(4) - len(response['containerInstance']['remainingResources']).should.equal(4) - response['containerInstance']['agentConnected'].should.equal(True) - response['containerInstance']['versionInfo'][ - 'agentVersion'].should.equal('1.0.0') - response['containerInstance']['versionInfo'][ - 'agentHash'].should.equal('4023248') - response['containerInstance']['versionInfo'][ - 'dockerVersion'].should.equal('DockerVersion: 1.5.0') + response["containerInstance"]["status"].should.equal("ACTIVE") + len(response["containerInstance"]["registeredResources"]).should.equal(4) + len(response["containerInstance"]["remainingResources"]).should.equal(4) + response["containerInstance"]["agentConnected"].should.equal(True) + response["containerInstance"]["versionInfo"]["agentVersion"].should.equal("1.0.0") + response["containerInstance"]["versionInfo"]["agentHash"].should.equal("4023248") + response["containerInstance"]["versionInfo"]["dockerVersion"].should.equal( + "DockerVersion: 1.5.0" + ) @mock_ec2 @mock_ecs def test_deregister_container_instance(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + _ = ecs_client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -743,87 +731,76 @@ def test_deregister_container_instance(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - container_instance_id = response['containerInstance']['containerInstanceArn'] + container_instance_id = response["containerInstance"]["containerInstanceArn"] response = ecs_client.deregister_container_instance( - cluster=test_cluster_name, - containerInstance=container_instance_id + cluster=test_cluster_name, containerInstance=container_instance_id ) container_instances_response = ecs_client.list_container_instances( cluster=test_cluster_name ) - len(container_instances_response['containerInstanceArns']).should.equal(0) + len(container_instances_response["containerInstanceArns"]).should.equal(0) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - container_instance_id = response['containerInstance']['containerInstanceArn'] + container_instance_id = response["containerInstance"]["containerInstanceArn"] _ = ecs_client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = ecs_client.start_task( - cluster='test_ecs_cluster', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + taskDefinition="test_ecs_task", overrides={}, containerInstances=[container_instance_id], - startedBy='moto' + startedBy="moto", ) with assert_raises(Exception) as e: ecs_client.deregister_container_instance( - cluster=test_cluster_name, - containerInstance=container_instance_id + cluster=test_cluster_name, containerInstance=container_instance_id ).should.have.raised(Exception) container_instances_response = ecs_client.list_container_instances( cluster=test_cluster_name ) - len(container_instances_response['containerInstanceArns']).should.equal(1) + len(container_instances_response["containerInstanceArns"]).should.equal(1) ecs_client.deregister_container_instance( - cluster=test_cluster_name, - containerInstance=container_instance_id, - force=True + cluster=test_cluster_name, containerInstance=container_instance_id, force=True ) container_instances_response = ecs_client.list_container_instances( cluster=test_cluster_name ) - len(container_instances_response['containerInstanceArns']).should.equal(0) + len(container_instances_response["containerInstanceArns"]).should.equal(0) @mock_ec2 @mock_ecs def test_list_container_instances(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + test_cluster_name = "test_ecs_cluster" + _ = ecs_client.create_cluster(clusterName=test_cluster_name) instance_to_create = 3 test_instance_arns = [] for i in range(0, instance_to_create): test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -831,37 +808,32 @@ def test_list_container_instances(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document) + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document + ) - test_instance_arns.append(response['containerInstance'][ - 'containerInstanceArn']) + test_instance_arns.append(response["containerInstance"]["containerInstanceArn"]) response = ecs_client.list_container_instances(cluster=test_cluster_name) - len(response['containerInstanceArns']).should.equal(instance_to_create) + len(response["containerInstanceArns"]).should.equal(instance_to_create) for arn in test_instance_arns: - response['containerInstanceArns'].should.contain(arn) + response["containerInstanceArns"].should.contain(arn) @mock_ec2 @mock_ecs def test_describe_container_instances(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + test_cluster_name = "test_ecs_cluster" + _ = ecs_client.create_cluster(clusterName=test_cluster_name) instance_to_create = 3 test_instance_arns = [] for i in range(0, instance_to_create): test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -869,51 +841,46 @@ def test_describe_container_instances(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document) + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document + ) - test_instance_arns.append(response['containerInstance'][ - 'containerInstanceArn']) + test_instance_arns.append(response["containerInstance"]["containerInstanceArn"]) - test_instance_ids = list( - map((lambda x: x.split('/')[1]), test_instance_arns)) + test_instance_ids = list(map((lambda x: x.split("/")[1]), test_instance_arns)) response = ecs_client.describe_container_instances( - cluster=test_cluster_name, containerInstances=test_instance_ids) - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_arns = [ci['containerInstanceArn'] - for ci in response['containerInstances']] + cluster=test_cluster_name, containerInstances=test_instance_ids + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_arns = [ + ci["containerInstanceArn"] for ci in response["containerInstances"] + ] for arn in test_instance_arns: response_arns.should.contain(arn) - for instance in response['containerInstances']: - instance.keys().should.contain('runningTasksCount') - instance.keys().should.contain('pendingTasksCount') + for instance in response["containerInstances"]: + instance.keys().should.contain("runningTasksCount") + instance.keys().should.contain("pendingTasksCount") with assert_raises(ClientError) as e: ecs_client.describe_container_instances( - cluster=test_cluster_name, - containerInstances=[] + cluster=test_cluster_name, containerInstances=[] ) @mock_ec2 @mock_ecs def test_update_container_instances_state(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + test_cluster_name = "test_ecs_cluster" + _ = ecs_client.create_cluster(clusterName=test_cluster_name) instance_to_create = 3 test_instance_arns = [] for i in range(0, instance_to_create): test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -921,59 +888,61 @@ def test_update_container_instances_state(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document) + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document + ) - test_instance_arns.append(response['containerInstance']['containerInstanceArn']) + test_instance_arns.append(response["containerInstance"]["containerInstanceArn"]) - test_instance_ids = list(map((lambda x: x.split('/')[1]), test_instance_arns)) - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_ids, - status='DRAINING') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + test_instance_ids = list(map((lambda x: x.split("/")[1]), test_instance_arns)) + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, + containerInstances=test_instance_ids, + status="DRAINING", + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('DRAINING') - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_ids, - status='DRAINING') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + status.should.equal("DRAINING") + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, + containerInstances=test_instance_ids, + status="DRAINING", + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('DRAINING') - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_ids, - status='ACTIVE') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + status.should.equal("DRAINING") + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, containerInstances=test_instance_ids, status="ACTIVE" + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('ACTIVE') - ecs_client.update_container_instances_state.when.called_with(cluster=test_cluster_name, - containerInstances=test_instance_ids, - status='test_status').should.throw(Exception) + status.should.equal("ACTIVE") + ecs_client.update_container_instances_state.when.called_with( + cluster=test_cluster_name, + containerInstances=test_instance_ids, + status="test_status", + ).should.throw(Exception) @mock_ec2 @mock_ecs def test_update_container_instances_state_by_arn(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + test_cluster_name = "test_ecs_cluster" + _ = ecs_client.create_cluster(clusterName=test_cluster_name) instance_to_create = 3 test_instance_arns = [] for i in range(0, instance_to_create): test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -981,56 +950,60 @@ def test_update_container_instances_state_by_arn(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document) + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document + ) - test_instance_arns.append(response['containerInstance']['containerInstanceArn']) + test_instance_arns.append(response["containerInstance"]["containerInstanceArn"]) - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_arns, - status='DRAINING') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, + containerInstances=test_instance_arns, + status="DRAINING", + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('DRAINING') - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_arns, - status='DRAINING') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + status.should.equal("DRAINING") + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, + containerInstances=test_instance_arns, + status="DRAINING", + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('DRAINING') - response = ecs_client.update_container_instances_state(cluster=test_cluster_name, - containerInstances=test_instance_arns, - status='ACTIVE') - len(response['failures']).should.equal(0) - len(response['containerInstances']).should.equal(instance_to_create) - response_statuses = [ci['status'] for ci in response['containerInstances']] + status.should.equal("DRAINING") + response = ecs_client.update_container_instances_state( + cluster=test_cluster_name, + containerInstances=test_instance_arns, + status="ACTIVE", + ) + len(response["failures"]).should.equal(0) + len(response["containerInstances"]).should.equal(instance_to_create) + response_statuses = [ci["status"] for ci in response["containerInstances"]] for status in response_statuses: - status.should.equal('ACTIVE') - ecs_client.update_container_instances_state.when.called_with(cluster=test_cluster_name, - containerInstances=test_instance_arns, - status='test_status').should.throw(Exception) + status.should.equal("ACTIVE") + ecs_client.update_container_instances_state.when.called_with( + cluster=test_cluster_name, + containerInstances=test_instance_arns, + status="test_status", + ).should.throw(Exception) @mock_ec2 @mock_ecs def test_run_task(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1038,66 +1011,64 @@ def test_run_task(): ) response = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=2, - startedBy='moto' + startedBy="moto", ) - len(response['tasks']).should.equal(2) - response['tasks'][0]['taskArn'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:task/') - response['tasks'][0]['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['tasks'][0]['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['tasks'][0]['containerInstanceArn'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:container-instance/') - response['tasks'][0]['overrides'].should.equal({}) - response['tasks'][0]['lastStatus'].should.equal("RUNNING") - response['tasks'][0]['desiredStatus'].should.equal("RUNNING") - response['tasks'][0]['startedBy'].should.equal("moto") - response['tasks'][0]['stoppedReason'].should.equal("") + len(response["tasks"]).should.equal(2) + response["tasks"][0]["taskArn"].should.contain( + "arn:aws:ecs:us-east-1:012345678910:task/" + ) + response["tasks"][0]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["tasks"][0]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["tasks"][0]["containerInstanceArn"].should.contain( + "arn:aws:ecs:us-east-1:012345678910:container-instance/" + ) + response["tasks"][0]["overrides"].should.equal({}) + response["tasks"][0]["lastStatus"].should.equal("RUNNING") + response["tasks"][0]["desiredStatus"].should.equal("RUNNING") + response["tasks"][0]["startedBy"].should.equal("moto") + response["tasks"][0]["stoppedReason"].should.equal("") @mock_ec2 @mock_ecs def test_start_task(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1105,73 +1076,73 @@ def test_start_task(): ) response = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - container_instances = client.list_container_instances( - cluster=test_cluster_name) - container_instance_id = container_instances[ - 'containerInstanceArns'][0].split('/')[-1] + container_instances = client.list_container_instances(cluster=test_cluster_name) + container_instance_id = container_instances["containerInstanceArns"][0].split("/")[ + -1 + ] _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.start_task( - cluster='test_ecs_cluster', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + taskDefinition="test_ecs_task", overrides={}, containerInstances=[container_instance_id], - startedBy='moto' + startedBy="moto", ) - len(response['tasks']).should.equal(1) - response['tasks'][0]['taskArn'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:task/') - response['tasks'][0]['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['tasks'][0]['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['tasks'][0]['containerInstanceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:container-instance/{0}'.format(container_instance_id)) - response['tasks'][0]['overrides'].should.equal({}) - response['tasks'][0]['lastStatus'].should.equal("RUNNING") - response['tasks'][0]['desiredStatus'].should.equal("RUNNING") - response['tasks'][0]['startedBy'].should.equal("moto") - response['tasks'][0]['stoppedReason'].should.equal("") + len(response["tasks"]).should.equal(1) + response["tasks"][0]["taskArn"].should.contain( + "arn:aws:ecs:us-east-1:012345678910:task/" + ) + response["tasks"][0]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["tasks"][0]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["tasks"][0]["containerInstanceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:container-instance/{0}".format( + container_instance_id + ) + ) + response["tasks"][0]["overrides"].should.equal({}) + response["tasks"][0]["lastStatus"].should.equal("RUNNING") + response["tasks"][0]["desiredStatus"].should.equal("RUNNING") + response["tasks"][0]["startedBy"].should.equal("moto") + response["tasks"][0]["stoppedReason"].should.equal("") @mock_ec2 @mock_ecs def test_list_tasks(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1179,71 +1150,66 @@ def test_list_tasks(): ) _ = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - container_instances = client.list_container_instances( - cluster=test_cluster_name) - container_instance_id = container_instances[ - 'containerInstanceArns'][0].split('/')[-1] + container_instances = client.list_container_instances(cluster=test_cluster_name) + container_instance_id = container_instances["containerInstanceArns"][0].split("/")[ + -1 + ] _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) _ = client.start_task( - cluster='test_ecs_cluster', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + taskDefinition="test_ecs_task", overrides={}, containerInstances=[container_instance_id], - startedBy='foo' + startedBy="foo", ) _ = client.start_task( - cluster='test_ecs_cluster', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + taskDefinition="test_ecs_task", overrides={}, containerInstances=[container_instance_id], - startedBy='bar' + startedBy="bar", ) - assert len(client.list_tasks()['taskArns']).should.equal(2) - assert len(client.list_tasks(cluster='test_ecs_cluster') - ['taskArns']).should.equal(2) - assert len(client.list_tasks(startedBy='foo')['taskArns']).should.equal(1) + assert len(client.list_tasks()["taskArns"]).should.equal(2) + assert len(client.list_tasks(cluster="test_ecs_cluster")["taskArns"]).should.equal( + 2 + ) + assert len(client.list_tasks(startedBy="foo")["taskArns"]).should.equal(1) @mock_ec2 @mock_ecs def test_describe_tasks(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1251,96 +1217,85 @@ def test_describe_tasks(): ) response = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) tasks_arns = [ - task['taskArn'] for task in client.run_task( - cluster='test_ecs_cluster', + task["taskArn"] + for task in client.run_task( + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=2, - startedBy='moto' - )['tasks'] + startedBy="moto", + )["tasks"] ] - response = client.describe_tasks( - cluster='test_ecs_cluster', - tasks=tasks_arns - ) + response = client.describe_tasks(cluster="test_ecs_cluster", tasks=tasks_arns) - len(response['tasks']).should.equal(2) - set([response['tasks'][0]['taskArn'], response['tasks'] - [1]['taskArn']]).should.equal(set(tasks_arns)) + len(response["tasks"]).should.equal(2) + set( + [response["tasks"][0]["taskArn"], response["tasks"][1]["taskArn"]] + ).should.equal(set(tasks_arns)) # Test we can pass task ids instead of ARNs response = client.describe_tasks( - cluster='test_ecs_cluster', - tasks=[tasks_arns[0].split("/")[-1]] + cluster="test_ecs_cluster", tasks=[tasks_arns[0].split("/")[-1]] ) - len(response['tasks']).should.equal(1) + len(response["tasks"]).should.equal(1) @mock_ecs def describe_task_definition(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") container_definition = { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [{"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"}], + "logConfiguration": {"logDriver": "json-file"}, } task_definition = client.register_task_definition( - family='test_ecs_task', - containerDefinitions=[container_definition] + family="test_ecs_task", containerDefinitions=[container_definition] ) - family = task_definition['family'] + family = task_definition["family"] task = client.describe_task_definition(taskDefinition=family) - task['containerDefinitions'][0].should.equal(container_definition) - task['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task2:1') - task['volumes'].should.equal([]) + task["containerDefinitions"][0].should.equal(container_definition) + task["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task2:1" + ) + task["volumes"].should.equal([]) @mock_ec2 @mock_ecs def test_stop_task(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1348,63 +1303,58 @@ def test_stop_task(): ) _ = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) run_response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=1, - startedBy='moto' + startedBy="moto", ) stop_response = client.stop_task( - cluster='test_ecs_cluster', - task=run_response['tasks'][0].get('taskArn'), - reason='moto testing' + cluster="test_ecs_cluster", + task=run_response["tasks"][0].get("taskArn"), + reason="moto testing", ) - stop_response['task']['taskArn'].should.equal( - run_response['tasks'][0].get('taskArn')) - stop_response['task']['lastStatus'].should.equal('STOPPED') - stop_response['task']['desiredStatus'].should.equal('STOPPED') - stop_response['task']['stoppedReason'].should.equal('moto testing') + stop_response["task"]["taskArn"].should.equal( + run_response["tasks"][0].get("taskArn") + ) + stop_response["task"]["lastStatus"].should.equal("STOPPED") + stop_response["task"]["desiredStatus"].should.equal("STOPPED") + stop_response["task"]["stoppedReason"].should.equal("moto testing") @mock_ec2 @mock_ecs def test_resource_reservation_and_release(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1412,84 +1362,74 @@ def test_resource_reservation_and_release(): ) _ = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'}, - 'portMappings': [ - { - 'hostPort': 80, - 'containerPort': 8080 - } - ] + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + "portMappings": [{"hostPort": 80, "containerPort": 8080}], } - ] + ], ) run_response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=1, - startedBy='moto' + startedBy="moto", ) - container_instance_arn = run_response['tasks'][0].get('containerInstanceArn') + container_instance_arn = run_response["tasks"][0].get("containerInstanceArn") container_instance_description = client.describe_container_instances( - cluster='test_ecs_cluster', - containerInstances=[container_instance_arn] - )['containerInstances'][0] + cluster="test_ecs_cluster", containerInstances=[container_instance_arn] + )["containerInstances"][0] remaining_resources, registered_resources = _fetch_container_instance_resources( - container_instance_description) - remaining_resources['CPU'].should.equal(registered_resources['CPU'] - 1024) - remaining_resources['MEMORY'].should.equal(registered_resources['MEMORY'] - 400) - registered_resources['PORTS'].append('80') - remaining_resources['PORTS'].should.equal(registered_resources['PORTS']) - container_instance_description['runningTasksCount'].should.equal(1) + container_instance_description + ) + remaining_resources["CPU"].should.equal(registered_resources["CPU"] - 1024) + remaining_resources["MEMORY"].should.equal(registered_resources["MEMORY"] - 400) + registered_resources["PORTS"].append("80") + remaining_resources["PORTS"].should.equal(registered_resources["PORTS"]) + container_instance_description["runningTasksCount"].should.equal(1) client.stop_task( - cluster='test_ecs_cluster', - task=run_response['tasks'][0].get('taskArn'), - reason='moto testing' + cluster="test_ecs_cluster", + task=run_response["tasks"][0].get("taskArn"), + reason="moto testing", ) container_instance_description = client.describe_container_instances( - cluster='test_ecs_cluster', - containerInstances=[container_instance_arn] - )['containerInstances'][0] + cluster="test_ecs_cluster", containerInstances=[container_instance_arn] + )["containerInstances"][0] remaining_resources, registered_resources = _fetch_container_instance_resources( - container_instance_description) - remaining_resources['CPU'].should.equal(registered_resources['CPU']) - remaining_resources['MEMORY'].should.equal(registered_resources['MEMORY']) - remaining_resources['PORTS'].should.equal(registered_resources['PORTS']) - container_instance_description['runningTasksCount'].should.equal(0) + container_instance_description + ) + remaining_resources["CPU"].should.equal(registered_resources["CPU"]) + remaining_resources["MEMORY"].should.equal(registered_resources["MEMORY"]) + remaining_resources["PORTS"].should.equal(registered_resources["PORTS"]) + container_instance_description["runningTasksCount"].should.equal(0) + @mock_ec2 @mock_ecs def test_resource_reservation_and_release_memory_reservation(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1497,63 +1437,58 @@ def test_resource_reservation_and_release_memory_reservation(): ) _ = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'memoryReservation': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'}, - 'portMappings': [ - { - 'containerPort': 8080 - } - ] + "name": "hello_world", + "image": "docker/hello-world:latest", + "memoryReservation": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + "portMappings": [{"containerPort": 8080}], } - ] + ], ) run_response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=1, - startedBy='moto' + startedBy="moto", ) - container_instance_arn = run_response['tasks'][0].get('containerInstanceArn') + container_instance_arn = run_response["tasks"][0].get("containerInstanceArn") container_instance_description = client.describe_container_instances( - cluster='test_ecs_cluster', - containerInstances=[container_instance_arn] - )['containerInstances'][0] - remaining_resources, registered_resources = _fetch_container_instance_resources(container_instance_description) - remaining_resources['CPU'].should.equal(registered_resources['CPU']) - remaining_resources['MEMORY'].should.equal(registered_resources['MEMORY'] - 400) - remaining_resources['PORTS'].should.equal(registered_resources['PORTS']) - container_instance_description['runningTasksCount'].should.equal(1) + cluster="test_ecs_cluster", containerInstances=[container_instance_arn] + )["containerInstances"][0] + remaining_resources, registered_resources = _fetch_container_instance_resources( + container_instance_description + ) + remaining_resources["CPU"].should.equal(registered_resources["CPU"]) + remaining_resources["MEMORY"].should.equal(registered_resources["MEMORY"] - 400) + remaining_resources["PORTS"].should.equal(registered_resources["PORTS"]) + container_instance_description["runningTasksCount"].should.equal(1) client.stop_task( - cluster='test_ecs_cluster', - task=run_response['tasks'][0].get('taskArn'), - reason='moto testing' + cluster="test_ecs_cluster", + task=run_response["tasks"][0].get("taskArn"), + reason="moto testing", ) container_instance_description = client.describe_container_instances( - cluster='test_ecs_cluster', - containerInstances=[container_instance_arn] - )['containerInstances'][0] - remaining_resources, registered_resources = _fetch_container_instance_resources(container_instance_description) - remaining_resources['CPU'].should.equal(registered_resources['CPU']) - remaining_resources['MEMORY'].should.equal(registered_resources['MEMORY']) - remaining_resources['PORTS'].should.equal(registered_resources['PORTS']) - container_instance_description['runningTasksCount'].should.equal(0) - + cluster="test_ecs_cluster", containerInstances=[container_instance_arn] + )["containerInstances"][0] + remaining_resources, registered_resources = _fetch_container_instance_resources( + container_instance_description + ) + remaining_resources["CPU"].should.equal(registered_resources["CPU"]) + remaining_resources["MEMORY"].should.equal(registered_resources["MEMORY"]) + remaining_resources["PORTS"].should.equal(registered_resources["PORTS"]) + container_instance_description["runningTasksCount"].should.equal(0) @mock_ecs @@ -1565,26 +1500,21 @@ def test_create_cluster_through_cloudformation(): "Resources": { "testCluster": { "Type": "AWS::ECS::Cluster", - "Properties": { - "ClusterName": "testcluster" - } + "Properties": {"ClusterName": "testcluster"}, } - } + }, } template_json = json.dumps(template) - ecs_conn = boto3.client('ecs', region_name='us-west-1') + ecs_conn = boto3.client("ecs", region_name="us-west-1") resp = ecs_conn.list_clusters() - len(resp['clusterArns']).should.equal(0) + len(resp["clusterArns"]).should.equal(0) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template_json, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + cfn_conn.create_stack(StackName="test_stack", TemplateBody=template_json) resp = ecs_conn.list_clusters() - len(resp['clusterArns']).should.equal(1) + len(resp["clusterArns"]).should.equal(1) @mock_ecs @@ -1595,22 +1525,15 @@ def test_create_cluster_through_cloudformation_no_name(): template = { "AWSTemplateFormatVersion": "2010-09-09", "Description": "ECS Cluster Test CloudFormation", - "Resources": { - "testCluster": { - "Type": "AWS::ECS::Cluster", - } - } + "Resources": {"testCluster": {"Type": "AWS::ECS::Cluster"}}, } template_json = json.dumps(template) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template_json, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + cfn_conn.create_stack(StackName="test_stack", TemplateBody=template_json) - ecs_conn = boto3.client('ecs', region_name='us-west-1') + ecs_conn = boto3.client("ecs", region_name="us-west-1") resp = ecs_conn.list_clusters() - len(resp['clusterArns']).should.equal(1) + len(resp["clusterArns"]).should.equal(1) @mock_ecs @@ -1622,31 +1545,24 @@ def test_update_cluster_name_through_cloudformation_should_trigger_a_replacement "Resources": { "testCluster": { "Type": "AWS::ECS::Cluster", - "Properties": { - "ClusterName": "testcluster1" - } + "Properties": {"ClusterName": "testcluster1"}, } - } + }, } template2 = deepcopy(template1) - template2['Resources']['testCluster'][ - 'Properties']['ClusterName'] = 'testcluster2' + template2["Resources"]["testCluster"]["Properties"]["ClusterName"] = "testcluster2" template1_json = json.dumps(template1) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") stack_resp = cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template1_json, + StackName="test_stack", TemplateBody=template1_json ) template2_json = json.dumps(template2) - cfn_conn.update_stack( - StackName=stack_resp['StackId'], - TemplateBody=template2_json - ) - ecs_conn = boto3.client('ecs', region_name='us-west-1') + cfn_conn.update_stack(StackName=stack_resp["StackId"], TemplateBody=template2_json) + ecs_conn = boto3.client("ecs", region_name="us-west-1") resp = ecs_conn.list_clusters() - len(resp['clusterArns']).should.equal(1) - resp['clusterArns'][0].endswith('testcluster2').should.be.true + len(resp["clusterArns"]).should.equal(1) + resp["clusterArns"][0].endswith("testcluster2").should.be.true @mock_ecs @@ -1665,47 +1581,42 @@ def test_create_task_definition_through_cloudformation(): "Image": "amazon/amazon-ecs-sample", "Cpu": "200", "Memory": "500", - "Essential": "true" + "Essential": "true", } ], "Volumes": [], - } + }, } - } + }, } template_json = json.dumps(template) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - stack_name = 'test_stack' - cfn_conn.create_stack( - StackName=stack_name, - TemplateBody=template_json, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + stack_name = "test_stack" + cfn_conn.create_stack(StackName=stack_name, TemplateBody=template_json) - ecs_conn = boto3.client('ecs', region_name='us-west-1') + ecs_conn = boto3.client("ecs", region_name="us-west-1") resp = ecs_conn.list_task_definitions() - len(resp['taskDefinitionArns']).should.equal(1) - task_definition_arn = resp['taskDefinitionArns'][0] + len(resp["taskDefinitionArns"]).should.equal(1) + task_definition_arn = resp["taskDefinitionArns"][0] task_definition_details = cfn_conn.describe_stack_resource( - StackName=stack_name,LogicalResourceId='testTaskDefinition')['StackResourceDetail'] - task_definition_details['PhysicalResourceId'].should.equal(task_definition_arn) + StackName=stack_name, LogicalResourceId="testTaskDefinition" + )["StackResourceDetail"] + task_definition_details["PhysicalResourceId"].should.equal(task_definition_arn) + @mock_ec2 @mock_ecs def test_task_definitions_unable_to_be_placed(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1713,53 +1624,47 @@ def test_task_definitions_unable_to_be_placed(): ) response = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 5000, - 'memory': 40000, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 5000, + "memory": 40000, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=2, - startedBy='moto' + startedBy="moto", ) - len(response['tasks']).should.equal(0) + len(response["tasks"]).should.equal(0) @mock_ec2 @mock_ecs def test_task_definitions_with_port_clash(): - client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = client.create_cluster( - clusterName=test_cluster_name - ) + _ = client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -1767,54 +1672,51 @@ def test_task_definitions_with_port_clash(): ) response = client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 256, - 'memory': 512, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'}, - 'portMappings': [ - { - 'hostPort': 80, - 'containerPort': 8080 - } - ] + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 256, + "memory": 512, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, + "portMappings": [{"hostPort": 80, "containerPort": 8080}], } - ] + ], ) response = client.run_task( - cluster='test_ecs_cluster', + cluster="test_ecs_cluster", overrides={}, - taskDefinition='test_ecs_task', + taskDefinition="test_ecs_task", count=2, - startedBy='moto' + startedBy="moto", ) - len(response['tasks']).should.equal(1) - response['tasks'][0]['taskArn'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:task/') - response['tasks'][0]['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['tasks'][0]['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - response['tasks'][0]['containerInstanceArn'].should.contain( - 'arn:aws:ecs:us-east-1:012345678910:container-instance/') - response['tasks'][0]['overrides'].should.equal({}) - response['tasks'][0]['lastStatus'].should.equal("RUNNING") - response['tasks'][0]['desiredStatus'].should.equal("RUNNING") - response['tasks'][0]['startedBy'].should.equal("moto") - response['tasks'][0]['stoppedReason'].should.equal("") + len(response["tasks"]).should.equal(1) + response["tasks"][0]["taskArn"].should.contain( + "arn:aws:ecs:us-east-1:012345678910:task/" + ) + response["tasks"][0]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["tasks"][0]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" + ) + response["tasks"][0]["containerInstanceArn"].should.contain( + "arn:aws:ecs:us-east-1:012345678910:container-instance/" + ) + response["tasks"][0]["overrides"].should.equal({}) + response["tasks"][0]["lastStatus"].should.equal("RUNNING") + response["tasks"][0]["desiredStatus"].should.equal("RUNNING") + response["tasks"][0]["startedBy"].should.equal("moto") + response["tasks"][0]["stoppedReason"].should.equal("") @mock_ecs @@ -1834,35 +1736,29 @@ def test_update_task_definition_family_through_cloudformation_should_trigger_a_r "Image": "amazon/amazon-ecs-sample", "Cpu": "200", "Memory": "500", - "Essential": "true" + "Essential": "true", } ], "Volumes": [], - } + }, } - } + }, } template1_json = json.dumps(template1) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template1_json, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + cfn_conn.create_stack(StackName="test_stack", TemplateBody=template1_json) template2 = deepcopy(template1) - template2['Resources']['testTaskDefinition'][ - 'Properties']['Family'] = 'testTaskDefinition2' + template2["Resources"]["testTaskDefinition"]["Properties"][ + "Family" + ] = "testTaskDefinition2" template2_json = json.dumps(template2) - cfn_conn.update_stack( - StackName="test_stack", - TemplateBody=template2_json, - ) + cfn_conn.update_stack(StackName="test_stack", TemplateBody=template2_json) - ecs_conn = boto3.client('ecs', region_name='us-west-1') - resp = ecs_conn.list_task_definitions(familyPrefix='testTaskDefinition') - len(resp['taskDefinitionArns']).should.equal(1) - resp['taskDefinitionArns'][0].endswith( - 'testTaskDefinition2:1').should.be.true + ecs_conn = boto3.client("ecs", region_name="us-west-1") + resp = ecs_conn.list_task_definitions(familyPrefix="testTaskDefinition") + len(resp["taskDefinitionArns"]).should.equal(1) + resp["taskDefinitionArns"][0].endswith("testTaskDefinition2:1").should.be.true @mock_ecs @@ -1874,9 +1770,7 @@ def test_create_service_through_cloudformation(): "Resources": { "testCluster": { "Type": "AWS::ECS::Cluster", - "Properties": { - "ClusterName": "testcluster" - } + "Properties": {"ClusterName": "testcluster"}, }, "testTaskDefinition": { "Type": "AWS::ECS::TaskDefinition", @@ -1887,11 +1781,11 @@ def test_create_service_through_cloudformation(): "Image": "amazon/amazon-ecs-sample", "Cpu": "200", "Memory": "500", - "Essential": "true" + "Essential": "true", } ], "Volumes": [], - } + }, }, "testService": { "Type": "AWS::ECS::Service", @@ -1899,20 +1793,17 @@ def test_create_service_through_cloudformation(): "Cluster": {"Ref": "testCluster"}, "DesiredCount": 10, "TaskDefinition": {"Ref": "testTaskDefinition"}, - } - } - } + }, + }, + }, } template_json = json.dumps(template) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template_json, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + cfn_conn.create_stack(StackName="test_stack", TemplateBody=template_json) - ecs_conn = boto3.client('ecs', region_name='us-west-1') - resp = ecs_conn.list_services(cluster='testcluster') - len(resp['serviceArns']).should.equal(1) + ecs_conn = boto3.client("ecs", region_name="us-west-1") + resp = ecs_conn.list_services(cluster="testcluster") + len(resp["serviceArns"]).should.equal(1) @mock_ecs @@ -1924,9 +1815,7 @@ def test_update_service_through_cloudformation_should_trigger_replacement(): "Resources": { "testCluster": { "Type": "AWS::ECS::Cluster", - "Properties": { - "ClusterName": "testcluster" - } + "Properties": {"ClusterName": "testcluster"}, }, "testTaskDefinition": { "Type": "AWS::ECS::TaskDefinition", @@ -1937,11 +1826,11 @@ def test_update_service_through_cloudformation_should_trigger_replacement(): "Image": "amazon/amazon-ecs-sample", "Cpu": "200", "Memory": "500", - "Essential": "true" + "Essential": "true", } ], "Volumes": [], - } + }, }, "testService": { "Type": "AWS::ECS::Service", @@ -1949,47 +1838,37 @@ def test_update_service_through_cloudformation_should_trigger_replacement(): "Cluster": {"Ref": "testCluster"}, "TaskDefinition": {"Ref": "testTaskDefinition"}, "DesiredCount": 10, - } - } - } + }, + }, + }, } template_json1 = json.dumps(template1) - cfn_conn = boto3.client('cloudformation', region_name='us-west-1') - cfn_conn.create_stack( - StackName="test_stack", - TemplateBody=template_json1, - ) + cfn_conn = boto3.client("cloudformation", region_name="us-west-1") + cfn_conn.create_stack(StackName="test_stack", TemplateBody=template_json1) template2 = deepcopy(template1) - template2['Resources']['testService']['Properties']['DesiredCount'] = 5 + template2["Resources"]["testService"]["Properties"]["DesiredCount"] = 5 template2_json = json.dumps(template2) - cfn_conn.update_stack( - StackName="test_stack", - TemplateBody=template2_json, - ) + cfn_conn.update_stack(StackName="test_stack", TemplateBody=template2_json) - ecs_conn = boto3.client('ecs', region_name='us-west-1') - resp = ecs_conn.list_services(cluster='testcluster') - len(resp['serviceArns']).should.equal(1) + ecs_conn = boto3.client("ecs", region_name="us-west-1") + resp = ecs_conn.list_services(cluster="testcluster") + len(resp["serviceArns"]).should.equal(1) @mock_ec2 @mock_ecs def test_attributes(): # Combined put, list delete attributes into the same test due to the amount of setup - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + _ = ecs_client.create_cluster(clusterName=test_cluster_name) instances = [] test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instances.append(test_instance) @@ -1998,18 +1877,14 @@ def test_attributes(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - response['containerInstance'][ - 'ec2InstanceId'].should.equal(test_instance.id) - full_arn1 = response['containerInstance']['containerInstanceArn'] + response["containerInstance"]["ec2InstanceId"].should.equal(test_instance.id) + full_arn1 = response["containerInstance"]["containerInstanceArn"] test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instances.append(test_instance) @@ -2018,133 +1893,143 @@ def test_attributes(): ) response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - response['containerInstance'][ - 'ec2InstanceId'].should.equal(test_instance.id) - full_arn2 = response['containerInstance']['containerInstanceArn'] - partial_arn2 = full_arn2.rsplit('/', 1)[-1] + response["containerInstance"]["ec2InstanceId"].should.equal(test_instance.id) + full_arn2 = response["containerInstance"]["containerInstanceArn"] + partial_arn2 = full_arn2.rsplit("/", 1)[-1] - full_arn2.should_not.equal(full_arn1) # uuid1 isnt unique enough when the pc is fast ;-) + full_arn2.should_not.equal( + full_arn1 + ) # uuid1 isnt unique enough when the pc is fast ;-) # Ok set instance 1 with 1 attribute, instance 2 with another, and all of them with a 3rd. ecs_client.put_attributes( cluster=test_cluster_name, attributes=[ - {'name': 'env', 'value': 'prod'}, - {'name': 'attr1', 'value': 'instance1', 'targetId': full_arn1}, - {'name': 'attr1', 'value': 'instance2', 'targetId': partial_arn2, - 'targetType': 'container-instance'} - ] + {"name": "env", "value": "prod"}, + {"name": "attr1", "value": "instance1", "targetId": full_arn1}, + { + "name": "attr1", + "value": "instance2", + "targetId": partial_arn2, + "targetType": "container-instance", + }, + ], ) resp = ecs_client.list_attributes( - cluster=test_cluster_name, - targetType='container-instance' + cluster=test_cluster_name, targetType="container-instance" ) - attrs = resp['attributes'] + attrs = resp["attributes"] NUM_CUSTOM_ATTRIBUTES = 4 # 2 specific to individual machines and 1 global, going to both machines (2 + 1*2) NUM_DEFAULT_ATTRIBUTES = 4 - len(attrs).should.equal(NUM_CUSTOM_ATTRIBUTES + (NUM_DEFAULT_ATTRIBUTES * len(instances))) + len(attrs).should.equal( + NUM_CUSTOM_ATTRIBUTES + (NUM_DEFAULT_ATTRIBUTES * len(instances)) + ) # Tests that the attrs have been set properly - len(list(filter(lambda item: item['name'] == 'env', attrs))).should.equal(2) - len(list( - filter(lambda item: item['name'] == 'attr1' and item['value'] == 'instance1', attrs))).should.equal(1) + len(list(filter(lambda item: item["name"] == "env", attrs))).should.equal(2) + len( + list( + filter( + lambda item: item["name"] == "attr1" and item["value"] == "instance1", + attrs, + ) + ) + ).should.equal(1) ecs_client.delete_attributes( cluster=test_cluster_name, attributes=[ - {'name': 'attr1', 'value': 'instance2', 'targetId': partial_arn2, - 'targetType': 'container-instance'} - ] + { + "name": "attr1", + "value": "instance2", + "targetId": partial_arn2, + "targetType": "container-instance", + } + ], ) NUM_CUSTOM_ATTRIBUTES -= 1 resp = ecs_client.list_attributes( - cluster=test_cluster_name, - targetType='container-instance' + cluster=test_cluster_name, targetType="container-instance" + ) + attrs = resp["attributes"] + len(attrs).should.equal( + NUM_CUSTOM_ATTRIBUTES + (NUM_DEFAULT_ATTRIBUTES * len(instances)) ) - attrs = resp['attributes'] - len(attrs).should.equal(NUM_CUSTOM_ATTRIBUTES + (NUM_DEFAULT_ATTRIBUTES * len(instances))) @mock_ecs def test_poll_endpoint(): # Combined put, list delete attributes into the same test due to the amount of setup - ecs_client = boto3.client('ecs', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") # Just a placeholder until someone actually wants useless data, just testing it doesnt raise an exception - resp = ecs_client.discover_poll_endpoint(cluster='blah', containerInstance='blah') - resp.should.contain('endpoint') - resp.should.contain('telemetryEndpoint') + resp = ecs_client.discover_poll_endpoint(cluster="blah", containerInstance="blah") + resp.should.contain("endpoint") + resp.should.contain("telemetryEndpoint") @mock_ecs def test_list_task_definition_families(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) client.register_task_definition( - family='alt_test_ecs_task', + family="alt_test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) resp1 = client.list_task_definition_families() - resp2 = client.list_task_definition_families(familyPrefix='alt') + resp2 = client.list_task_definition_families(familyPrefix="alt") - len(resp1['families']).should.equal(2) - len(resp2['families']).should.equal(1) + len(resp1["families"]).should.equal(2) + len(resp2["families"]).should.equal(1) @mock_ec2 @mock_ecs def test_default_container_instance_attributes(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" # Create cluster and EC2 instance - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + _ = ecs_client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -2153,44 +2038,42 @@ def test_default_container_instance_attributes(): # Register container instance response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - response['containerInstance'][ - 'ec2InstanceId'].should.equal(test_instance.id) - full_arn = response['containerInstance']['containerInstanceArn'] - container_instance_id = full_arn.rsplit('/', 1)[-1] + response["containerInstance"]["ec2InstanceId"].should.equal(test_instance.id) + full_arn = response["containerInstance"]["containerInstanceArn"] + container_instance_id = full_arn.rsplit("/", 1)[-1] - default_attributes = response['containerInstance']['attributes'] + default_attributes = response["containerInstance"]["attributes"] assert len(default_attributes) == 4 expected_result = [ - {'name': 'ecs.availability-zone', 'value': test_instance.placement['AvailabilityZone']}, - {'name': 'ecs.ami-id', 'value': test_instance.image_id}, - {'name': 'ecs.instance-type', 'value': test_instance.instance_type}, - {'name': 'ecs.os-type', 'value': test_instance.platform or 'linux'} + { + "name": "ecs.availability-zone", + "value": test_instance.placement["AvailabilityZone"], + }, + {"name": "ecs.ami-id", "value": test_instance.image_id}, + {"name": "ecs.instance-type", "value": test_instance.instance_type}, + {"name": "ecs.os-type", "value": test_instance.platform or "linux"}, ] - assert sorted(default_attributes, key=lambda item: item['name']) == sorted(expected_result, - key=lambda item: item['name']) + assert sorted(default_attributes, key=lambda item: item["name"]) == sorted( + expected_result, key=lambda item: item["name"] + ) @mock_ec2 @mock_ecs def test_describe_container_instances_with_attributes(): - ecs_client = boto3.client('ecs', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + ecs_client = boto3.client("ecs", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - test_cluster_name = 'test_ecs_cluster' + test_cluster_name = "test_ecs_cluster" # Create cluster and EC2 instance - _ = ecs_client.create_cluster( - clusterName=test_cluster_name - ) + _ = ecs_client.create_cluster(clusterName=test_cluster_name) test_instance = ec2.create_instances( - ImageId="ami-1234abcd", - MinCount=1, - MaxCount=1, + ImageId="ami-1234abcd", MinCount=1, MaxCount=1 )[0] instance_id_document = json.dumps( @@ -2199,396 +2082,395 @@ def test_describe_container_instances_with_attributes(): # Register container instance response = ecs_client.register_container_instance( - cluster=test_cluster_name, - instanceIdentityDocument=instance_id_document + cluster=test_cluster_name, instanceIdentityDocument=instance_id_document ) - response['containerInstance'][ - 'ec2InstanceId'].should.equal(test_instance.id) - full_arn = response['containerInstance']['containerInstanceArn'] - container_instance_id = full_arn.rsplit('/', 1)[-1] - default_attributes = response['containerInstance']['attributes'] + response["containerInstance"]["ec2InstanceId"].should.equal(test_instance.id) + full_arn = response["containerInstance"]["containerInstanceArn"] + container_instance_id = full_arn.rsplit("/", 1)[-1] + default_attributes = response["containerInstance"]["attributes"] # Set attributes on container instance, one without a value attributes = [ - {'name': 'env', 'value': 'prod'}, - {'name': 'attr1', 'value': 'instance1', 'targetId': container_instance_id, - 'targetType': 'container-instance'}, - {'name': 'attr_without_value'} + {"name": "env", "value": "prod"}, + { + "name": "attr1", + "value": "instance1", + "targetId": container_instance_id, + "targetType": "container-instance", + }, + {"name": "attr_without_value"}, ] - ecs_client.put_attributes( - cluster=test_cluster_name, - attributes=attributes - ) + ecs_client.put_attributes(cluster=test_cluster_name, attributes=attributes) # Describe container instance, should have attributes previously set - described_instance = ecs_client.describe_container_instances(cluster=test_cluster_name, - containerInstances=[container_instance_id]) + described_instance = ecs_client.describe_container_instances( + cluster=test_cluster_name, containerInstances=[container_instance_id] + ) - assert len(described_instance['containerInstances']) == 1 - assert isinstance(described_instance['containerInstances'][0]['attributes'], list) + assert len(described_instance["containerInstances"]) == 1 + assert isinstance(described_instance["containerInstances"][0]["attributes"], list) # Remove additional info passed to put_attributes cleaned_attributes = [] for attribute in attributes: - attribute.pop('targetId', None) - attribute.pop('targetType', None) + attribute.pop("targetId", None) + attribute.pop("targetType", None) cleaned_attributes.append(attribute) - described_attributes = sorted(described_instance['containerInstances'][0]['attributes'], - key=lambda item: item['name']) - expected_attributes = sorted(default_attributes + cleaned_attributes, key=lambda item: item['name']) + described_attributes = sorted( + described_instance["containerInstances"][0]["attributes"], + key=lambda item: item["name"], + ) + expected_attributes = sorted( + default_attributes + cleaned_attributes, key=lambda item: item["name"] + ) assert described_attributes == expected_attributes def _fetch_container_instance_resources(container_instance_description): remaining_resources = {} registered_resources = {} - remaining_resources_list = container_instance_description['remainingResources'] - registered_resources_list = container_instance_description['registeredResources'] - remaining_resources['CPU'] = [x['integerValue'] for x in remaining_resources_list if x['name'] == 'CPU'][ - 0] - remaining_resources['MEMORY'] = \ - [x['integerValue'] for x in remaining_resources_list if x['name'] == 'MEMORY'][0] - remaining_resources['PORTS'] = \ - [x['stringSetValue'] for x in remaining_resources_list if x['name'] == 'PORTS'][0] - registered_resources['CPU'] = \ - [x['integerValue'] for x in registered_resources_list if x['name'] == 'CPU'][0] - registered_resources['MEMORY'] = \ - [x['integerValue'] for x in registered_resources_list if x['name'] == 'MEMORY'][0] - registered_resources['PORTS'] = \ - [x['stringSetValue'] for x in registered_resources_list if x['name'] == 'PORTS'][0] + remaining_resources_list = container_instance_description["remainingResources"] + registered_resources_list = container_instance_description["registeredResources"] + remaining_resources["CPU"] = [ + x["integerValue"] for x in remaining_resources_list if x["name"] == "CPU" + ][0] + remaining_resources["MEMORY"] = [ + x["integerValue"] for x in remaining_resources_list if x["name"] == "MEMORY" + ][0] + remaining_resources["PORTS"] = [ + x["stringSetValue"] for x in remaining_resources_list if x["name"] == "PORTS" + ][0] + registered_resources["CPU"] = [ + x["integerValue"] for x in registered_resources_list if x["name"] == "CPU" + ][0] + registered_resources["MEMORY"] = [ + x["integerValue"] for x in registered_resources_list if x["name"] == "MEMORY" + ][0] + registered_resources["PORTS"] = [ + x["stringSetValue"] for x in registered_resources_list if x["name"] == "PORTS" + ][0] return remaining_resources, registered_resources @mock_ecs def test_create_service_load_balancing(): - client = boto3.client('ecs', region_name='us-east-1') - client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + client.create_cluster(clusterName="test_ecs_cluster") client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", desiredCount=2, loadBalancers=[ { - 'targetGroupArn': 'test_target_group_arn', - 'loadBalancerName': 'test_load_balancer_name', - 'containerName': 'test_container_name', - 'containerPort': 123 + "targetGroupArn": "test_target_group_arn", + "loadBalancerName": "test_load_balancer_name", + "containerName": "test_container_name", + "containerPort": 123, } - ] + ], + ) + response["service"]["clusterArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster" + ) + response["service"]["desiredCount"].should.equal(2) + len(response["service"]["events"]).should.equal(0) + len(response["service"]["loadBalancers"]).should.equal(1) + response["service"]["loadBalancers"][0]["targetGroupArn"].should.equal( + "test_target_group_arn" + ) + response["service"]["loadBalancers"][0]["loadBalancerName"].should.equal( + "test_load_balancer_name" + ) + response["service"]["loadBalancers"][0]["containerName"].should.equal( + "test_container_name" + ) + response["service"]["loadBalancers"][0]["containerPort"].should.equal(123) + response["service"]["pendingCount"].should.equal(0) + response["service"]["runningCount"].should.equal(0) + response["service"]["serviceArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service" + ) + response["service"]["serviceName"].should.equal("test_ecs_service") + response["service"]["status"].should.equal("ACTIVE") + response["service"]["taskDefinition"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" ) - response['service']['clusterArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:cluster/test_ecs_cluster') - response['service']['desiredCount'].should.equal(2) - len(response['service']['events']).should.equal(0) - len(response['service']['loadBalancers']).should.equal(1) - response['service']['loadBalancers'][0]['targetGroupArn'].should.equal( - 'test_target_group_arn') - response['service']['loadBalancers'][0]['loadBalancerName'].should.equal( - 'test_load_balancer_name') - response['service']['loadBalancers'][0]['containerName'].should.equal( - 'test_container_name') - response['service']['loadBalancers'][0]['containerPort'].should.equal(123) - response['service']['pendingCount'].should.equal(0) - response['service']['runningCount'].should.equal(0) - response['service']['serviceArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:service/test_ecs_service') - response['service']['serviceName'].should.equal('test_ecs_service') - response['service']['status'].should.equal('ACTIVE') - response['service']['taskDefinition'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') @mock_ecs def test_list_tags_for_resource(): - client = boto3.client('ecs', region_name='us-east-1') + client = boto3.client("ecs", region_name="us-east-1") response = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } ], tags=[ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'bar'}, - ] + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "bar"}, + ], + ) + type(response["taskDefinition"]).should.be(dict) + response["taskDefinition"]["revision"].should.equal(1) + response["taskDefinition"]["taskDefinitionArn"].should.equal( + "arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1" ) - type(response['taskDefinition']).should.be(dict) - response['taskDefinition']['revision'].should.equal(1) - response['taskDefinition']['taskDefinitionArn'].should.equal( - 'arn:aws:ecs:us-east-1:012345678910:task-definition/test_ecs_task:1') - task_definition_arn = response['taskDefinition']['taskDefinitionArn'] + task_definition_arn = response["taskDefinition"]["taskDefinitionArn"] response = client.list_tags_for_resource(resourceArn=task_definition_arn) - type(response['tags']).should.be(list) - response['tags'].should.equal([ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'bar'}, - ]) + type(response["tags"]).should.be(list) + response["tags"].should.equal( + [{"key": "createdBy", "value": "moto-unittest"}, {"key": "foo", "value": "bar"}] + ) @mock_ecs def test_list_tags_for_resource_unknown(): - client = boto3.client('ecs', region_name='us-east-1') - task_definition_arn = 'arn:aws:ecs:us-east-1:012345678910:task-definition/unknown:1' + client = boto3.client("ecs", region_name="us-east-1") + task_definition_arn = "arn:aws:ecs:us-east-1:012345678910:task-definition/unknown:1" try: client.list_tags_for_resource(resourceArn=task_definition_arn) except ClientError as err: - err.response['Error']['Code'].should.equal('ClientException') + err.response["Error"]["Code"].should.equal("ClientException") @mock_ecs def test_list_tags_for_resource_ecs_service(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", desiredCount=2, tags=[ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'bar'}, - ] + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "bar"}, + ], + ) + response = client.list_tags_for_resource( + resourceArn=response["service"]["serviceArn"] + ) + type(response["tags"]).should.be(list) + response["tags"].should.equal( + [{"key": "createdBy", "value": "moto-unittest"}, {"key": "foo", "value": "bar"}] ) - response = client.list_tags_for_resource(resourceArn=response['service']['serviceArn']) - type(response['tags']).should.be(list) - response['tags'].should.equal([ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'bar'}, - ]) @mock_ecs def test_list_tags_for_resource_unknown_service(): - client = boto3.client('ecs', region_name='us-east-1') - service_arn = 'arn:aws:ecs:us-east-1:012345678910:service/unknown:1' + client = boto3.client("ecs", region_name="us-east-1") + service_arn = "arn:aws:ecs:us-east-1:012345678910:service/unknown:1" try: client.list_tags_for_resource(resourceArn=service_arn) except ClientError as err: - err.response['Error']['Code'].should.equal('ServiceNotFoundException') + err.response["Error"]["Code"].should.equal("ServiceNotFoundException") @mock_ecs def test_ecs_service_tag_resource(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', - desiredCount=2 + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", + desiredCount=2, ) client.tag_resource( - resourceArn=response['service']['serviceArn'], + resourceArn=response["service"]["serviceArn"], tags=[ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'bar'}, - ] + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "bar"}, + ], + ) + response = client.list_tags_for_resource( + resourceArn=response["service"]["serviceArn"] + ) + type(response["tags"]).should.be(list) + response["tags"].should.equal( + [{"key": "createdBy", "value": "moto-unittest"}, {"key": "foo", "value": "bar"}] ) - response = client.list_tags_for_resource(resourceArn=response['service']['serviceArn']) - type(response['tags']).should.be(list) - response['tags'].should.equal([ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'bar'}, - ]) @mock_ecs def test_ecs_service_tag_resource_overwrites_tag(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", desiredCount=2, - tags=[ - {'key': 'foo', 'value': 'bar'}, - ] + tags=[{"key": "foo", "value": "bar"}], ) client.tag_resource( - resourceArn=response['service']['serviceArn'], + resourceArn=response["service"]["serviceArn"], tags=[ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'hello world'}, + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "hello world"}, + ], + ) + response = client.list_tags_for_resource( + resourceArn=response["service"]["serviceArn"] + ) + type(response["tags"]).should.be(list) + response["tags"].should.equal( + [ + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "foo", "value": "hello world"}, ] ) - response = client.list_tags_for_resource(resourceArn=response['service']['serviceArn']) - type(response['tags']).should.be(list) - response['tags'].should.equal([ - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'foo', 'value': 'hello world'}, - ]) @mock_ecs def test_ecs_service_untag_resource(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", desiredCount=2, - tags=[ - {'key': 'foo', 'value': 'bar'}, - ] + tags=[{"key": "foo", "value": "bar"}], ) client.untag_resource( - resourceArn=response['service']['serviceArn'], - tagKeys=['foo'] + resourceArn=response["service"]["serviceArn"], tagKeys=["foo"] ) - response = client.list_tags_for_resource(resourceArn=response['service']['serviceArn']) - response['tags'].should.equal([]) + response = client.list_tags_for_resource( + resourceArn=response["service"]["serviceArn"] + ) + response["tags"].should.equal([]) @mock_ecs def test_ecs_service_untag_resource_multiple_tags(): - client = boto3.client('ecs', region_name='us-east-1') - _ = client.create_cluster( - clusterName='test_ecs_cluster' - ) + client = boto3.client("ecs", region_name="us-east-1") + _ = client.create_cluster(clusterName="test_ecs_cluster") _ = client.register_task_definition( - family='test_ecs_task', + family="test_ecs_task", containerDefinitions=[ { - 'name': 'hello_world', - 'image': 'docker/hello-world:latest', - 'cpu': 1024, - 'memory': 400, - 'essential': True, - 'environment': [{ - 'name': 'AWS_ACCESS_KEY_ID', - 'value': 'SOME_ACCESS_KEY' - }], - 'logConfiguration': {'logDriver': 'json-file'} + "name": "hello_world", + "image": "docker/hello-world:latest", + "cpu": 1024, + "memory": 400, + "essential": True, + "environment": [ + {"name": "AWS_ACCESS_KEY_ID", "value": "SOME_ACCESS_KEY"} + ], + "logConfiguration": {"logDriver": "json-file"}, } - ] + ], ) response = client.create_service( - cluster='test_ecs_cluster', - serviceName='test_ecs_service', - taskDefinition='test_ecs_task', + cluster="test_ecs_cluster", + serviceName="test_ecs_service", + taskDefinition="test_ecs_task", desiredCount=2, tags=[ - {'key': 'foo', 'value': 'bar'}, - {'key': 'createdBy', 'value': 'moto-unittest'}, - {'key': 'hello', 'value': 'world'}, - ] + {"key": "foo", "value": "bar"}, + {"key": "createdBy", "value": "moto-unittest"}, + {"key": "hello", "value": "world"}, + ], ) client.untag_resource( - resourceArn=response['service']['serviceArn'], - tagKeys=['foo', 'createdBy'] + resourceArn=response["service"]["serviceArn"], tagKeys=["foo", "createdBy"] ) - response = client.list_tags_for_resource(resourceArn=response['service']['serviceArn']) - response['tags'].should.equal([ - {'key': 'hello', 'value': 'world'}, - ]) + response = client.list_tags_for_resource( + resourceArn=response["service"]["serviceArn"] + ) + response["tags"].should.equal([{"key": "hello", "value": "world"}]) diff --git a/tests/test_elb/test_elb.py b/tests/test_elb/test_elb.py index 447896f15..d7a7b88cb 100644 --- a/tests/test_elb/test_elb.py +++ b/tests/test_elb/test_elb.py @@ -23,19 +23,20 @@ def test_create_load_balancer(): conn = boto.connect_elb() ec2 = boto.ec2.connect_to_region("us-east-1") - security_group = ec2.create_security_group('sg-abc987', 'description') + security_group = ec2.create_security_group("sg-abc987", "description") - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - conn.create_load_balancer('my-lb', zones, ports, scheme='internal', security_groups=[security_group.id]) + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + conn.create_load_balancer( + "my-lb", zones, ports, scheme="internal", security_groups=[security_group.id] + ) balancers = conn.get_all_load_balancers() balancer = balancers[0] balancer.name.should.equal("my-lb") balancer.scheme.should.equal("internal") list(balancer.security_groups).should.equal([security_group.id]) - set(balancer.availability_zones).should.equal( - set(['us-east-1a', 'us-east-1b'])) + set(balancer.availability_zones).should.equal(set(["us-east-1a", "us-east-1b"])) listener1 = balancer.listeners[0] listener1.load_balancer_port.should.equal(80) listener1.instance_port.should.equal(8080) @@ -50,19 +51,20 @@ def test_create_load_balancer(): def test_getting_missing_elb(): conn = boto.connect_elb() conn.get_all_load_balancers.when.called_with( - load_balancer_names='aaa').should.throw(BotoServerError) + load_balancer_names="aaa" + ).should.throw(BotoServerError) @mock_elb_deprecated def test_create_elb_in_multiple_region(): - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] west1_conn = boto.ec2.elb.connect_to_region("us-west-1") - west1_conn.create_load_balancer('my-lb', zones, ports) + west1_conn.create_load_balancer("my-lb", zones, ports) west2_conn = boto.ec2.elb.connect_to_region("us-west-2") - west2_conn.create_load_balancer('my-lb', zones, ports) + west2_conn.create_load_balancer("my-lb", zones, ports) list(west1_conn.get_all_load_balancers()).should.have.length_of(1) list(west2_conn.get_all_load_balancers()).should.have.length_of(1) @@ -72,117 +74,123 @@ def test_create_elb_in_multiple_region(): def test_create_load_balancer_with_certificate(): conn = boto.connect_elb() - zones = ['us-east-1a'] + zones = ["us-east-1a"] ports = [ - (443, 8443, 'https', 'arn:aws:iam:123456789012:server-certificate/test-cert')] - conn.create_load_balancer('my-lb', zones, ports) + (443, 8443, "https", "arn:aws:iam:123456789012:server-certificate/test-cert") + ] + conn.create_load_balancer("my-lb", zones, ports) balancers = conn.get_all_load_balancers() balancer = balancers[0] balancer.name.should.equal("my-lb") balancer.scheme.should.equal("internet-facing") - set(balancer.availability_zones).should.equal(set(['us-east-1a'])) + set(balancer.availability_zones).should.equal(set(["us-east-1a"])) listener = balancer.listeners[0] listener.load_balancer_port.should.equal(443) listener.instance_port.should.equal(8443) listener.protocol.should.equal("HTTPS") listener.ssl_certificate_id.should.equal( - 'arn:aws:iam:123456789012:server-certificate/test-cert') + "arn:aws:iam:123456789012:server-certificate/test-cert" + ) @mock_elb def test_create_and_delete_boto3_support(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - list(client.describe_load_balancers()[ - 'LoadBalancerDescriptions']).should.have.length_of(1) + list( + client.describe_load_balancers()["LoadBalancerDescriptions"] + ).should.have.length_of(1) - client.delete_load_balancer( - LoadBalancerName='my-lb' - ) - list(client.describe_load_balancers()[ - 'LoadBalancerDescriptions']).should.have.length_of(0) + client.delete_load_balancer(LoadBalancerName="my-lb") + list( + client.describe_load_balancers()["LoadBalancerDescriptions"] + ).should.have.length_of(0) @mock_elb def test_create_load_balancer_with_no_listeners_defined(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") with assert_raises(ClientError): client.create_load_balancer( - LoadBalancerName='my-lb', + LoadBalancerName="my-lb", Listeners=[], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + AvailabilityZones=["us-east-1a", "us-east-1b"], ) @mock_elb def test_describe_paginated_balancers(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") for i in range(51): client.create_load_balancer( - LoadBalancerName='my-lb%d' % i, + LoadBalancerName="my-lb%d" % i, Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + {"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080} + ], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) resp = client.describe_load_balancers() - resp['LoadBalancerDescriptions'].should.have.length_of(50) - resp['NextMarker'].should.equal(resp['LoadBalancerDescriptions'][-1]['LoadBalancerName']) - resp2 = client.describe_load_balancers(Marker=resp['NextMarker']) - resp2['LoadBalancerDescriptions'].should.have.length_of(1) - assert 'NextToken' not in resp2.keys() + resp["LoadBalancerDescriptions"].should.have.length_of(50) + resp["NextMarker"].should.equal( + resp["LoadBalancerDescriptions"][-1]["LoadBalancerName"] + ) + resp2 = client.describe_load_balancers(Marker=resp["NextMarker"]) + resp2["LoadBalancerDescriptions"].should.have.length_of(1) + assert "NextToken" not in resp2.keys() @mock_elb @mock_ec2 def test_apply_security_groups_to_load_balancer(): - client = boto3.client('elb', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") security_group = ec2.create_security_group( - GroupName='sg01', Description='Test security group sg01', VpcId=vpc.id) + GroupName="sg01", Description="Test security group sg01", VpcId=vpc.id + ) client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) response = client.apply_security_groups_to_load_balancer( - LoadBalancerName='my-lb', - SecurityGroups=[security_group.id]) + LoadBalancerName="my-lb", SecurityGroups=[security_group.id] + ) - assert response['SecurityGroups'] == [security_group.id] - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - assert balancer['SecurityGroups'] == [security_group.id] + assert response["SecurityGroups"] == [security_group.id] + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + assert balancer["SecurityGroups"] == [security_group.id] # Using a not-real security group raises an error with assert_raises(ClientError) as error: response = client.apply_security_groups_to_load_balancer( - LoadBalancerName='my-lb', - SecurityGroups=['not-really-a-security-group']) - assert "One or more of the specified security groups do not exist." in str(error.exception) + LoadBalancerName="my-lb", SecurityGroups=["not-really-a-security-group"] + ) + assert "One or more of the specified security groups do not exist." in str( + error.exception + ) @mock_elb_deprecated def test_add_listener(): conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http')] - conn.create_load_balancer('my-lb', zones, ports) - new_listener = (443, 8443, 'tcp') - conn.create_load_balancer_listeners('my-lb', [new_listener]) + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http")] + conn.create_load_balancer("my-lb", zones, ports) + new_listener = (443, 8443, "tcp") + conn.create_load_balancer_listeners("my-lb", [new_listener]) balancers = conn.get_all_load_balancers() balancer = balancers[0] listener1 = balancer.listeners[0] @@ -199,10 +207,10 @@ def test_add_listener(): def test_delete_listener(): conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - conn.create_load_balancer('my-lb', zones, ports) - conn.delete_load_balancer_listeners('my-lb', [443]) + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + conn.create_load_balancer("my-lb", zones, ports) + conn.delete_load_balancer_listeners("my-lb", [443]) balancers = conn.get_all_load_balancers() balancer = balancers[0] listener1 = balancer.listeners[0] @@ -214,61 +222,57 @@ def test_delete_listener(): @mock_elb def test_create_and_delete_listener_boto3_support(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[{'Protocol': 'http', - 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "http", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - list(client.describe_load_balancers()[ - 'LoadBalancerDescriptions']).should.have.length_of(1) + list( + client.describe_load_balancers()["LoadBalancerDescriptions"] + ).should.have.length_of(1) client.create_load_balancer_listeners( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 443, 'InstancePort': 8443}] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 443, "InstancePort": 8443}], ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - list(balancer['ListenerDescriptions']).should.have.length_of(2) - balancer['ListenerDescriptions'][0][ - 'Listener']['Protocol'].should.equal('HTTP') - balancer['ListenerDescriptions'][0]['Listener'][ - 'LoadBalancerPort'].should.equal(80) - balancer['ListenerDescriptions'][0]['Listener'][ - 'InstancePort'].should.equal(8080) - balancer['ListenerDescriptions'][1][ - 'Listener']['Protocol'].should.equal('TCP') - balancer['ListenerDescriptions'][1]['Listener'][ - 'LoadBalancerPort'].should.equal(443) - balancer['ListenerDescriptions'][1]['Listener'][ - 'InstancePort'].should.equal(8443) + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + list(balancer["ListenerDescriptions"]).should.have.length_of(2) + balancer["ListenerDescriptions"][0]["Listener"]["Protocol"].should.equal("HTTP") + balancer["ListenerDescriptions"][0]["Listener"]["LoadBalancerPort"].should.equal(80) + balancer["ListenerDescriptions"][0]["Listener"]["InstancePort"].should.equal(8080) + balancer["ListenerDescriptions"][1]["Listener"]["Protocol"].should.equal("TCP") + balancer["ListenerDescriptions"][1]["Listener"]["LoadBalancerPort"].should.equal( + 443 + ) + balancer["ListenerDescriptions"][1]["Listener"]["InstancePort"].should.equal(8443) # Creating this listener with an conflicting definition throws error with assert_raises(ClientError): client.create_load_balancer_listeners( - LoadBalancerName='my-lb', + LoadBalancerName="my-lb", Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 443, 'InstancePort': 1234}] + {"Protocol": "tcp", "LoadBalancerPort": 443, "InstancePort": 1234} + ], ) client.delete_load_balancer_listeners( - LoadBalancerName='my-lb', - LoadBalancerPorts=[443]) + LoadBalancerName="my-lb", LoadBalancerPorts=[443] + ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - list(balancer['ListenerDescriptions']).should.have.length_of(1) + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + list(balancer["ListenerDescriptions"]).should.have.length_of(1) @mock_elb_deprecated def test_set_sslcertificate(): conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(443, 8443, 'tcp')] - conn.create_load_balancer('my-lb', zones, ports) - conn.set_lb_listener_SSL_certificate('my-lb', '443', 'arn:certificate') + zones = ["us-east-1a", "us-east-1b"] + ports = [(443, 8443, "tcp")] + conn.create_load_balancer("my-lb", zones, ports) + conn.set_lb_listener_SSL_certificate("my-lb", "443", "arn:certificate") balancers = conn.get_all_load_balancers() balancer = balancers[0] listener1 = balancer.listeners[0] @@ -282,26 +286,26 @@ def test_set_sslcertificate(): def test_get_load_balancers_by_name(): conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - conn.create_load_balancer('my-lb1', zones, ports) - conn.create_load_balancer('my-lb2', zones, ports) - conn.create_load_balancer('my-lb3', zones, ports) + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + conn.create_load_balancer("my-lb1", zones, ports) + conn.create_load_balancer("my-lb2", zones, ports) + conn.create_load_balancer("my-lb3", zones, ports) conn.get_all_load_balancers().should.have.length_of(3) + conn.get_all_load_balancers(load_balancer_names=["my-lb1"]).should.have.length_of(1) conn.get_all_load_balancers( - load_balancer_names=['my-lb1']).should.have.length_of(1) - conn.get_all_load_balancers( - load_balancer_names=['my-lb1', 'my-lb2']).should.have.length_of(2) + load_balancer_names=["my-lb1", "my-lb2"] + ).should.have.length_of(2) @mock_elb_deprecated def test_delete_load_balancer(): conn = boto.connect_elb() - zones = ['us-east-1a'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - conn.create_load_balancer('my-lb', zones, ports) + zones = ["us-east-1a"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + conn.create_load_balancer("my-lb", zones, ports) balancers = conn.get_all_load_balancers() balancers.should.have.length_of(1) @@ -319,12 +323,12 @@ def test_create_health_check(): interval=20, healthy_threshold=3, unhealthy_threshold=5, - target='HTTP:8080/health', + target="HTTP:8080/health", timeout=23, ) - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) lb.configure_health_check(hc) balancer = conn.get_all_load_balancers()[0] @@ -332,50 +336,49 @@ def test_create_health_check(): health_check.interval.should.equal(20) health_check.healthy_threshold.should.equal(3) health_check.unhealthy_threshold.should.equal(5) - health_check.target.should.equal('HTTP:8080/health') + health_check.target.should.equal("HTTP:8080/health") health_check.timeout.should.equal(23) @mock_elb def test_create_health_check_boto3(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[{'Protocol': 'http', - 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "http", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) client.configure_health_check( - LoadBalancerName='my-lb', + LoadBalancerName="my-lb", HealthCheck={ - 'Target': 'HTTP:8080/health', - 'Interval': 20, - 'Timeout': 23, - 'HealthyThreshold': 3, - 'UnhealthyThreshold': 5 - } + "Target": "HTTP:8080/health", + "Interval": 20, + "Timeout": 23, + "HealthyThreshold": 3, + "UnhealthyThreshold": 5, + }, ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - balancer['HealthCheck']['Target'].should.equal('HTTP:8080/health') - balancer['HealthCheck']['Interval'].should.equal(20) - balancer['HealthCheck']['Timeout'].should.equal(23) - balancer['HealthCheck']['HealthyThreshold'].should.equal(3) - balancer['HealthCheck']['UnhealthyThreshold'].should.equal(5) + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + balancer["HealthCheck"]["Target"].should.equal("HTTP:8080/health") + balancer["HealthCheck"]["Interval"].should.equal(20) + balancer["HealthCheck"]["Timeout"].should.equal(23) + balancer["HealthCheck"]["HealthyThreshold"].should.equal(3) + balancer["HealthCheck"]["UnhealthyThreshold"].should.equal(5) @mock_ec2_deprecated @mock_elb_deprecated def test_register_instances(): ec2_conn = boto.connect_ec2() - reservation = ec2_conn.run_instances('ami-1234abcd', 2) + reservation = ec2_conn.run_instances("ami-1234abcd", 2) instance_id1 = reservation.instances[0].id instance_id2 = reservation.instances[1].id conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) lb.register_instances([instance_id1, instance_id2]) @@ -387,29 +390,23 @@ def test_register_instances(): @mock_ec2 @mock_elb def test_register_instances_boto3(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - response = ec2.create_instances( - ImageId='ami-1234abcd', MinCount=2, MaxCount=2) + ec2 = boto3.resource("ec2", region_name="us-east-1") + response = ec2.create_instances(ImageId="ami-1234abcd", MinCount=2, MaxCount=2) instance_id1 = response[0].id instance_id2 = response[1].id - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[{'Protocol': 'http', - 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "http", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) client.register_instances_with_load_balancer( - LoadBalancerName='my-lb', - Instances=[ - {'InstanceId': instance_id1}, - {'InstanceId': instance_id2} - ] + LoadBalancerName="my-lb", + Instances=[{"InstanceId": instance_id1}, {"InstanceId": instance_id2}], ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - instance_ids = [instance['InstanceId'] - for instance in balancer['Instances']] + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + instance_ids = [instance["InstanceId"] for instance in balancer["Instances"]] set(instance_ids).should.equal(set([instance_id1, instance_id2])) @@ -417,13 +414,13 @@ def test_register_instances_boto3(): @mock_elb_deprecated def test_deregister_instances(): ec2_conn = boto.connect_ec2() - reservation = ec2_conn.run_instances('ami-1234abcd', 2) + reservation = ec2_conn.run_instances("ami-1234abcd", 2) instance_id1 = reservation.instances[0].id instance_id2 = reservation.instances[1].id conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) lb.register_instances([instance_id1, instance_id2]) @@ -438,47 +435,39 @@ def test_deregister_instances(): @mock_ec2 @mock_elb def test_deregister_instances_boto3(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - response = ec2.create_instances( - ImageId='ami-1234abcd', MinCount=2, MaxCount=2) + ec2 = boto3.resource("ec2", region_name="us-east-1") + response = ec2.create_instances(ImageId="ami-1234abcd", MinCount=2, MaxCount=2) instance_id1 = response[0].id instance_id2 = response[1].id - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[{'Protocol': 'http', - 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "http", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) client.register_instances_with_load_balancer( - LoadBalancerName='my-lb', - Instances=[ - {'InstanceId': instance_id1}, - {'InstanceId': instance_id2} - ] + LoadBalancerName="my-lb", + Instances=[{"InstanceId": instance_id1}, {"InstanceId": instance_id2}], ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - balancer['Instances'].should.have.length_of(2) + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + balancer["Instances"].should.have.length_of(2) client.deregister_instances_from_load_balancer( - LoadBalancerName='my-lb', - Instances=[ - {'InstanceId': instance_id1} - ] + LoadBalancerName="my-lb", Instances=[{"InstanceId": instance_id1}] ) - balancer = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - balancer['Instances'].should.have.length_of(1) - balancer['Instances'][0]['InstanceId'].should.equal(instance_id2) + balancer = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + balancer["Instances"].should.have.length_of(1) + balancer["Instances"][0]["InstanceId"].should.equal(instance_id2) @mock_elb_deprecated def test_default_attributes(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) attributes = lb.get_attributes() attributes.cross_zone_load_balancing.enabled.should.be.false @@ -490,8 +479,8 @@ def test_default_attributes(): @mock_elb_deprecated def test_cross_zone_load_balancing_attribute(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) conn.modify_lb_attribute("my-lb", "CrossZoneLoadBalancing", True) attributes = lb.get_attributes(force=True) @@ -505,28 +494,25 @@ def test_cross_zone_load_balancing_attribute(): @mock_elb_deprecated def test_connection_draining_attribute(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) connection_draining = ConnectionDrainingAttribute() connection_draining.enabled = True connection_draining.timeout = 60 - conn.modify_lb_attribute( - "my-lb", "ConnectionDraining", connection_draining) + conn.modify_lb_attribute("my-lb", "ConnectionDraining", connection_draining) attributes = lb.get_attributes(force=True) attributes.connection_draining.enabled.should.be.true attributes.connection_draining.timeout.should.equal(60) connection_draining.timeout = 30 - conn.modify_lb_attribute( - "my-lb", "ConnectionDraining", connection_draining) + conn.modify_lb_attribute("my-lb", "ConnectionDraining", connection_draining) attributes = lb.get_attributes(force=True) attributes.connection_draining.timeout.should.equal(30) connection_draining.enabled = False - conn.modify_lb_attribute( - "my-lb", "ConnectionDraining", connection_draining) + conn.modify_lb_attribute("my-lb", "ConnectionDraining", connection_draining) attributes = lb.get_attributes(force=True) attributes.connection_draining.enabled.should.be.false @@ -534,13 +520,13 @@ def test_connection_draining_attribute(): @mock_elb_deprecated def test_access_log_attribute(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) access_log = AccessLogAttribute() access_log.enabled = True - access_log.s3_bucket_name = 'bucket' - access_log.s3_bucket_prefix = 'prefix' + access_log.s3_bucket_name = "bucket" + access_log.s3_bucket_prefix = "prefix" access_log.emit_interval = 60 conn.modify_lb_attribute("my-lb", "AccessLog", access_log) @@ -559,20 +545,18 @@ def test_access_log_attribute(): @mock_elb_deprecated def test_connection_settings_attribute(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) connection_settings = ConnectionSettingAttribute(conn) connection_settings.idle_timeout = 120 - conn.modify_lb_attribute( - "my-lb", "ConnectingSettings", connection_settings) + conn.modify_lb_attribute("my-lb", "ConnectingSettings", connection_settings) attributes = lb.get_attributes(force=True) attributes.connecting_settings.idle_timeout.should.equal(120) connection_settings.idle_timeout = 60 - conn.modify_lb_attribute( - "my-lb", "ConnectingSettings", connection_settings) + conn.modify_lb_attribute("my-lb", "ConnectingSettings", connection_settings) attributes = lb.get_attributes(force=True) attributes.connecting_settings.idle_timeout.should.equal(60) @@ -580,8 +564,8 @@ def test_connection_settings_attribute(): @mock_elb_deprecated def test_create_lb_cookie_stickiness_policy(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) cookie_expiration_period = 60 policy_name = "LBCookieStickinessPolicy" @@ -594,55 +578,49 @@ def test_create_lb_cookie_stickiness_policy(): # # To work around that, this value is converted to an int and checked. cookie_expiration_period_response_str = lb.policies.lb_cookie_stickiness_policies[ - 0].cookie_expiration_period - int(cookie_expiration_period_response_str).should.equal( - cookie_expiration_period) - lb.policies.lb_cookie_stickiness_policies[ - 0].policy_name.should.equal(policy_name) + 0 + ].cookie_expiration_period + int(cookie_expiration_period_response_str).should.equal(cookie_expiration_period) + lb.policies.lb_cookie_stickiness_policies[0].policy_name.should.equal(policy_name) @mock_elb_deprecated def test_create_lb_cookie_stickiness_policy_no_expiry(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) policy_name = "LBCookieStickinessPolicy" lb.create_cookie_stickiness_policy(None, policy_name) lb = conn.get_all_load_balancers()[0] - lb.policies.lb_cookie_stickiness_policies[ - 0].cookie_expiration_period.should.be.none - lb.policies.lb_cookie_stickiness_policies[ - 0].policy_name.should.equal(policy_name) + lb.policies.lb_cookie_stickiness_policies[0].cookie_expiration_period.should.be.none + lb.policies.lb_cookie_stickiness_policies[0].policy_name.should.equal(policy_name) @mock_elb_deprecated def test_create_app_cookie_stickiness_policy(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) cookie_name = "my-stickiness-policy" policy_name = "AppCookieStickinessPolicy" lb.create_app_cookie_stickiness_policy(cookie_name, policy_name) lb = conn.get_all_load_balancers()[0] - lb.policies.app_cookie_stickiness_policies[ - 0].cookie_name.should.equal(cookie_name) - lb.policies.app_cookie_stickiness_policies[ - 0].policy_name.should.equal(policy_name) + lb.policies.app_cookie_stickiness_policies[0].cookie_name.should.equal(cookie_name) + lb.policies.app_cookie_stickiness_policies[0].policy_name.should.equal(policy_name) @mock_elb_deprecated def test_create_lb_policy(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) policy_name = "ProxyPolicy" - lb.create_lb_policy(policy_name, 'ProxyProtocolPolicyType', { - 'ProxyProtocol': True}) + lb.create_lb_policy(policy_name, "ProxyProtocolPolicyType", {"ProxyProtocol": True}) lb = conn.get_all_load_balancers()[0] lb.policies.other_policies[0].policy_name.should.equal(policy_name) @@ -651,8 +629,8 @@ def test_create_lb_policy(): @mock_elb_deprecated def test_set_policies_of_listener(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) listener_port = 80 policy_name = "my-stickiness-policy" @@ -674,15 +652,14 @@ def test_set_policies_of_listener(): @mock_elb_deprecated def test_set_policies_of_backend_server(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', [], ports) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", [], ports) instance_port = 8080 policy_name = "ProxyPolicy" # in a real flow, it is necessary first to create a policy, # then to set that policy to the backend - lb.create_lb_policy(policy_name, 'ProxyProtocolPolicyType', { - 'ProxyProtocol': True}) + lb.create_lb_policy(policy_name, "ProxyProtocolPolicyType", {"ProxyProtocol": True}) lb.set_policies_of_backend_server(instance_port, [policy_name]) lb = conn.get_all_load_balancers()[0] @@ -696,287 +673,262 @@ def test_set_policies_of_backend_server(): @mock_elb_deprecated def test_describe_instance_health(): ec2_conn = boto.connect_ec2() - reservation = ec2_conn.run_instances('ami-1234abcd', 2) + reservation = ec2_conn.run_instances("ami-1234abcd", 2) instance_id1 = reservation.instances[0].id instance_id2 = reservation.instances[1].id conn = boto.connect_elb() - zones = ['us-east-1a', 'us-east-1b'] - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - lb = conn.create_load_balancer('my-lb', zones, ports) + zones = ["us-east-1a", "us-east-1b"] + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + lb = conn.create_load_balancer("my-lb", zones, ports) - instances_health = conn.describe_instance_health('my-lb') + instances_health = conn.describe_instance_health("my-lb") instances_health.should.be.empty lb.register_instances([instance_id1, instance_id2]) - instances_health = conn.describe_instance_health('my-lb') + instances_health = conn.describe_instance_health("my-lb") instances_health.should.have.length_of(2) for instance_health in instances_health: - instance_health.instance_id.should.be.within( - [instance_id1, instance_id2]) - instance_health.state.should.equal('InService') + instance_health.instance_id.should.be.within([instance_id1, instance_id2]) + instance_health.state.should.equal("InService") - instances_health = conn.describe_instance_health('my-lb', [instance_id1]) + instances_health = conn.describe_instance_health("my-lb", [instance_id1]) instances_health.should.have.length_of(1) instances_health[0].instance_id.should.equal(instance_id1) - instances_health[0].state.should.equal('InService') + instances_health[0].state.should.equal("InService") @mock_ec2 @mock_elb def test_describe_instance_health_boto3(): - elb = boto3.client('elb', region_name="us-east-1") - ec2 = boto3.client('ec2', region_name="us-east-1") - instances = ec2.run_instances(MinCount=2, MaxCount=2)['Instances'] + elb = boto3.client("elb", region_name="us-east-1") + ec2 = boto3.client("ec2", region_name="us-east-1") + instances = ec2.run_instances(MinCount=2, MaxCount=2)["Instances"] lb_name = "my_load_balancer" elb.create_load_balancer( - Listeners=[{ - 'InstancePort': 80, - 'LoadBalancerPort': 8080, - 'Protocol': 'HTTP' - }], + Listeners=[{"InstancePort": 80, "LoadBalancerPort": 8080, "Protocol": "HTTP"}], LoadBalancerName=lb_name, ) elb.register_instances_with_load_balancer( - LoadBalancerName=lb_name, - Instances=[{'InstanceId': instances[0]['InstanceId']}] + LoadBalancerName=lb_name, Instances=[{"InstanceId": instances[0]["InstanceId"]}] ) instances_health = elb.describe_instance_health( LoadBalancerName=lb_name, - Instances=[{'InstanceId': instance['InstanceId']} for instance in instances] + Instances=[{"InstanceId": instance["InstanceId"]} for instance in instances], ) - instances_health['InstanceStates'].should.have.length_of(2) - instances_health['InstanceStates'][0]['InstanceId'].\ - should.equal(instances[0]['InstanceId']) - instances_health['InstanceStates'][0]['State'].\ - should.equal('InService') - instances_health['InstanceStates'][1]['InstanceId'].\ - should.equal(instances[1]['InstanceId']) - instances_health['InstanceStates'][1]['State'].\ - should.equal('Unknown') + instances_health["InstanceStates"].should.have.length_of(2) + instances_health["InstanceStates"][0]["InstanceId"].should.equal( + instances[0]["InstanceId"] + ) + instances_health["InstanceStates"][0]["State"].should.equal("InService") + instances_health["InstanceStates"][1]["InstanceId"].should.equal( + instances[1]["InstanceId"] + ) + instances_health["InstanceStates"][1]["State"].should.equal("Unknown") @mock_elb def test_add_remove_tags(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") - client.add_tags.when.called_with(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }]).should.throw(botocore.exceptions.ClientError) + client.add_tags.when.called_with( + LoadBalancerNames=["my-lb"], Tags=[{"Key": "a", "Value": "b"}] + ).should.throw(botocore.exceptions.ClientError) client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - list(client.describe_load_balancers()[ - 'LoadBalancerDescriptions']).should.have.length_of(1) + list( + client.describe_load_balancers()["LoadBalancerDescriptions"] + ).should.have.length_of(1) - client.add_tags(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }]) + client.add_tags(LoadBalancerNames=["my-lb"], Tags=[{"Key": "a", "Value": "b"}]) - tags = dict([(d['Key'], d['Value']) for d in client.describe_tags( - LoadBalancerNames=['my-lb'])['TagDescriptions'][0]['Tags']]) - tags.should.have.key('a').which.should.equal('b') + tags = dict( + [ + (d["Key"], d["Value"]) + for d in client.describe_tags(LoadBalancerNames=["my-lb"])[ + "TagDescriptions" + ][0]["Tags"] + ] + ) + tags.should.have.key("a").which.should.equal("b") - client.add_tags(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }, { - 'Key': 'b', - 'Value': 'b' - }, { - 'Key': 'c', - 'Value': 'b' - }, { - 'Key': 'd', - 'Value': 'b' - }, { - 'Key': 'e', - 'Value': 'b' - }, { - 'Key': 'f', - 'Value': 'b' - }, { - 'Key': 'g', - 'Value': 'b' - }, { - 'Key': 'h', - 'Value': 'b' - }, { - 'Key': 'i', - 'Value': 'b' - }, { - 'Key': 'j', - 'Value': 'b' - }]) + client.add_tags( + LoadBalancerNames=["my-lb"], + Tags=[ + {"Key": "a", "Value": "b"}, + {"Key": "b", "Value": "b"}, + {"Key": "c", "Value": "b"}, + {"Key": "d", "Value": "b"}, + {"Key": "e", "Value": "b"}, + {"Key": "f", "Value": "b"}, + {"Key": "g", "Value": "b"}, + {"Key": "h", "Value": "b"}, + {"Key": "i", "Value": "b"}, + {"Key": "j", "Value": "b"}, + ], + ) - client.add_tags.when.called_with(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'k', - 'Value': 'b' - }]).should.throw(botocore.exceptions.ClientError) + client.add_tags.when.called_with( + LoadBalancerNames=["my-lb"], Tags=[{"Key": "k", "Value": "b"}] + ).should.throw(botocore.exceptions.ClientError) - client.add_tags(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'j', - 'Value': 'c' - }]) + client.add_tags(LoadBalancerNames=["my-lb"], Tags=[{"Key": "j", "Value": "c"}]) - tags = dict([(d['Key'], d['Value']) for d in client.describe_tags( - LoadBalancerNames=['my-lb'])['TagDescriptions'][0]['Tags']]) + tags = dict( + [ + (d["Key"], d["Value"]) + for d in client.describe_tags(LoadBalancerNames=["my-lb"])[ + "TagDescriptions" + ][0]["Tags"] + ] + ) - tags.should.have.key('a').which.should.equal('b') - tags.should.have.key('b').which.should.equal('b') - tags.should.have.key('c').which.should.equal('b') - tags.should.have.key('d').which.should.equal('b') - tags.should.have.key('e').which.should.equal('b') - tags.should.have.key('f').which.should.equal('b') - tags.should.have.key('g').which.should.equal('b') - tags.should.have.key('h').which.should.equal('b') - tags.should.have.key('i').which.should.equal('b') - tags.should.have.key('j').which.should.equal('c') - tags.shouldnt.have.key('k') + tags.should.have.key("a").which.should.equal("b") + tags.should.have.key("b").which.should.equal("b") + tags.should.have.key("c").which.should.equal("b") + tags.should.have.key("d").which.should.equal("b") + tags.should.have.key("e").which.should.equal("b") + tags.should.have.key("f").which.should.equal("b") + tags.should.have.key("g").which.should.equal("b") + tags.should.have.key("h").which.should.equal("b") + tags.should.have.key("i").which.should.equal("b") + tags.should.have.key("j").which.should.equal("c") + tags.shouldnt.have.key("k") - client.remove_tags(LoadBalancerNames=['my-lb'], - Tags=[{ - 'Key': 'a' - }]) + client.remove_tags(LoadBalancerNames=["my-lb"], Tags=[{"Key": "a"}]) - tags = dict([(d['Key'], d['Value']) for d in client.describe_tags( - LoadBalancerNames=['my-lb'])['TagDescriptions'][0]['Tags']]) + tags = dict( + [ + (d["Key"], d["Value"]) + for d in client.describe_tags(LoadBalancerNames=["my-lb"])[ + "TagDescriptions" + ][0]["Tags"] + ] + ) - tags.shouldnt.have.key('a') - tags.should.have.key('b').which.should.equal('b') - tags.should.have.key('c').which.should.equal('b') - tags.should.have.key('d').which.should.equal('b') - tags.should.have.key('e').which.should.equal('b') - tags.should.have.key('f').which.should.equal('b') - tags.should.have.key('g').which.should.equal('b') - tags.should.have.key('h').which.should.equal('b') - tags.should.have.key('i').which.should.equal('b') - tags.should.have.key('j').which.should.equal('c') + tags.shouldnt.have.key("a") + tags.should.have.key("b").which.should.equal("b") + tags.should.have.key("c").which.should.equal("b") + tags.should.have.key("d").which.should.equal("b") + tags.should.have.key("e").which.should.equal("b") + tags.should.have.key("f").which.should.equal("b") + tags.should.have.key("g").which.should.equal("b") + tags.should.have.key("h").which.should.equal("b") + tags.should.have.key("i").which.should.equal("b") + tags.should.have.key("j").which.should.equal("c") client.create_load_balancer( - LoadBalancerName='other-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 433, 'InstancePort': 8433}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="other-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 433, "InstancePort": 8433}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) - client.add_tags(LoadBalancerNames=['other-lb'], - Tags=[{ - 'Key': 'other', - 'Value': 'something' - }]) + client.add_tags( + LoadBalancerNames=["other-lb"], Tags=[{"Key": "other", "Value": "something"}] + ) - lb_tags = dict([(l['LoadBalancerName'], dict([(d['Key'], d['Value']) for d in l['Tags']])) - for l in client.describe_tags(LoadBalancerNames=['my-lb', 'other-lb'])['TagDescriptions']]) + lb_tags = dict( + [ + (l["LoadBalancerName"], dict([(d["Key"], d["Value"]) for d in l["Tags"]])) + for l in client.describe_tags(LoadBalancerNames=["my-lb", "other-lb"])[ + "TagDescriptions" + ] + ] + ) - lb_tags.should.have.key('my-lb') - lb_tags.should.have.key('other-lb') + lb_tags.should.have.key("my-lb") + lb_tags.should.have.key("other-lb") - lb_tags['my-lb'].shouldnt.have.key('other') - lb_tags[ - 'other-lb'].should.have.key('other').which.should.equal('something') + lb_tags["my-lb"].shouldnt.have.key("other") + lb_tags["other-lb"].should.have.key("other").which.should.equal("something") @mock_elb def test_create_with_tags(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'], - Tags=[{ - 'Key': 'k', - 'Value': 'v' - }] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], + Tags=[{"Key": "k", "Value": "v"}], ) - tags = dict((d['Key'], d['Value']) for d in client.describe_tags( - LoadBalancerNames=['my-lb'])['TagDescriptions'][0]['Tags']) - tags.should.have.key('k').which.should.equal('v') + tags = dict( + (d["Key"], d["Value"]) + for d in client.describe_tags(LoadBalancerNames=["my-lb"])["TagDescriptions"][ + 0 + ]["Tags"] + ) + tags.should.have.key("k").which.should.equal("v") @mock_elb def test_modify_attributes(): - client = boto3.client('elb', region_name='us-east-1') + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[{'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - AvailabilityZones=['us-east-1a', 'us-east-1b'] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + AvailabilityZones=["us-east-1a", "us-east-1b"], ) # Default ConnectionDraining timeout of 300 seconds client.modify_load_balancer_attributes( - LoadBalancerName='my-lb', - LoadBalancerAttributes={ - 'ConnectionDraining': {'Enabled': True}, - } + LoadBalancerName="my-lb", + LoadBalancerAttributes={"ConnectionDraining": {"Enabled": True}}, + ) + lb_attrs = client.describe_load_balancer_attributes(LoadBalancerName="my-lb") + lb_attrs["LoadBalancerAttributes"]["ConnectionDraining"]["Enabled"].should.equal( + True + ) + lb_attrs["LoadBalancerAttributes"]["ConnectionDraining"]["Timeout"].should.equal( + 300 ) - lb_attrs = client.describe_load_balancer_attributes(LoadBalancerName='my-lb') - lb_attrs['LoadBalancerAttributes']['ConnectionDraining']['Enabled'].should.equal(True) - lb_attrs['LoadBalancerAttributes']['ConnectionDraining']['Timeout'].should.equal(300) # specify a custom ConnectionDraining timeout client.modify_load_balancer_attributes( - LoadBalancerName='my-lb', - LoadBalancerAttributes={ - 'ConnectionDraining': { - 'Enabled': True, - 'Timeout': 45, - }, - } + LoadBalancerName="my-lb", + LoadBalancerAttributes={"ConnectionDraining": {"Enabled": True, "Timeout": 45}}, ) - lb_attrs = client.describe_load_balancer_attributes(LoadBalancerName='my-lb') - lb_attrs['LoadBalancerAttributes']['ConnectionDraining']['Enabled'].should.equal(True) - lb_attrs['LoadBalancerAttributes']['ConnectionDraining']['Timeout'].should.equal(45) + lb_attrs = client.describe_load_balancer_attributes(LoadBalancerName="my-lb") + lb_attrs["LoadBalancerAttributes"]["ConnectionDraining"]["Enabled"].should.equal( + True + ) + lb_attrs["LoadBalancerAttributes"]["ConnectionDraining"]["Timeout"].should.equal(45) @mock_ec2 @mock_elb def test_subnets(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - vpc = ec2.create_vpc( - CidrBlock='172.28.7.0/24', - InstanceTenancy='default' - ) - subnet = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26' - ) - client = boto3.client('elb', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") + subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock="172.28.7.192/26") + client = boto3.client("elb", region_name="us-east-1") client.create_load_balancer( - LoadBalancerName='my-lb', - Listeners=[ - {'Protocol': 'tcp', 'LoadBalancerPort': 80, 'InstancePort': 8080}], - Subnets=[subnet.id] + LoadBalancerName="my-lb", + Listeners=[{"Protocol": "tcp", "LoadBalancerPort": 80, "InstancePort": 8080}], + Subnets=[subnet.id], ) - lb = client.describe_load_balancers()['LoadBalancerDescriptions'][0] - lb.should.have.key('Subnets').which.should.have.length_of(1) - lb['Subnets'][0].should.equal(subnet.id) + lb = client.describe_load_balancers()["LoadBalancerDescriptions"][0] + lb.should.have.key("Subnets").which.should.have.length_of(1) + lb["Subnets"][0].should.equal(subnet.id) - lb.should.have.key('VPCId').which.should.equal(vpc.id) + lb.should.have.key("VPCId").which.should.equal(vpc.id) @mock_elb_deprecated def test_create_load_balancer_duplicate(): conn = boto.connect_elb() - ports = [(80, 8080, 'http'), (443, 8443, 'tcp')] - conn.create_load_balancer('my-lb', [], ports) - conn.create_load_balancer.when.called_with( - 'my-lb', [], ports).should.throw(BotoServerError) + ports = [(80, 8080, "http"), (443, 8443, "tcp")] + conn.create_load_balancer("my-lb", [], ports) + conn.create_load_balancer.when.called_with("my-lb", [], ports).should.throw( + BotoServerError + ) diff --git a/tests/test_elb/test_server.py b/tests/test_elb/test_server.py index 0033284d7..0f432cef4 100644 --- a/tests/test_elb/test_server.py +++ b/tests/test_elb/test_server.py @@ -3,15 +3,15 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_elb_describe_instances(): backend = server.create_backend_app("elb") test_client = backend.test_client() - res = test_client.get('/?Action=DescribeLoadBalancers&Version=2015-12-01') + res = test_client.get("/?Action=DescribeLoadBalancers&Version=2015-12-01") - res.data.should.contain(b'DescribeLoadBalancersResponse') + res.data.should.contain(b"DescribeLoadBalancersResponse") diff --git a/tests/test_elbv2/test_elbv2.py b/tests/test_elbv2/test_elbv2.py index 97b876fec..593ced43b 100644 --- a/tests/test_elbv2/test_elbv2.py +++ b/tests/test_elbv2/test_elbv2.py @@ -15,656 +15,650 @@ from moto.elbv2 import elbv2_backends @mock_elbv2 @mock_ec2 def test_create_load_balancer(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - lb = response.get('LoadBalancers')[0] + lb = response.get("LoadBalancers")[0] - lb.get('DNSName').should.equal("my-lb-1.us-east-1.elb.amazonaws.com") - lb.get('LoadBalancerArn').should.equal( - 'arn:aws:elasticloadbalancing:us-east-1:1:loadbalancer/my-lb/50dc6c495c0c9188') - lb.get('SecurityGroups').should.equal([security_group.id]) - lb.get('AvailabilityZones').should.equal([ - {'SubnetId': subnet1.id, 'ZoneName': 'us-east-1a'}, - {'SubnetId': subnet2.id, 'ZoneName': 'us-east-1b'}]) + lb.get("DNSName").should.equal("my-lb-1.us-east-1.elb.amazonaws.com") + lb.get("LoadBalancerArn").should.equal( + "arn:aws:elasticloadbalancing:us-east-1:1:loadbalancer/my-lb/50dc6c495c0c9188" + ) + lb.get("SecurityGroups").should.equal([security_group.id]) + lb.get("AvailabilityZones").should.equal( + [ + {"SubnetId": subnet1.id, "ZoneName": "us-east-1a"}, + {"SubnetId": subnet2.id, "ZoneName": "us-east-1b"}, + ] + ) # Ensure the tags persisted - response = conn.describe_tags(ResourceArns=[lb.get('LoadBalancerArn')]) - tags = {d['Key']: d['Value'] - for d in response['TagDescriptions'][0]['Tags']} - tags.should.equal({'key_name': 'a_value'}) + response = conn.describe_tags(ResourceArns=[lb.get("LoadBalancerArn")]) + tags = {d["Key"]: d["Value"] for d in response["TagDescriptions"][0]["Tags"]} + tags.should.equal({"key_name": "a_value"}) @mock_elbv2 @mock_ec2 def test_describe_load_balancers(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) response = conn.describe_load_balancers() - response.get('LoadBalancers').should.have.length_of(1) - lb = response.get('LoadBalancers')[0] - lb.get('LoadBalancerName').should.equal('my-lb') + response.get("LoadBalancers").should.have.length_of(1) + lb = response.get("LoadBalancers")[0] + lb.get("LoadBalancerName").should.equal("my-lb") response = conn.describe_load_balancers( - LoadBalancerArns=[lb.get('LoadBalancerArn')]) - response.get('LoadBalancers')[0].get( - 'LoadBalancerName').should.equal('my-lb') + LoadBalancerArns=[lb.get("LoadBalancerArn")] + ) + response.get("LoadBalancers")[0].get("LoadBalancerName").should.equal("my-lb") - response = conn.describe_load_balancers(Names=['my-lb']) - response.get('LoadBalancers')[0].get( - 'LoadBalancerName').should.equal('my-lb') + response = conn.describe_load_balancers(Names=["my-lb"]) + response.get("LoadBalancers")[0].get("LoadBalancerName").should.equal("my-lb") with assert_raises(ClientError): - conn.describe_load_balancers(LoadBalancerArns=['not-a/real/arn']) + conn.describe_load_balancers(LoadBalancerArns=["not-a/real/arn"]) with assert_raises(ClientError): - conn.describe_load_balancers(Names=['nope']) + conn.describe_load_balancers(Names=["nope"]) @mock_elbv2 @mock_ec2 def test_add_remove_tags(): - conn = boto3.client('elbv2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") - ec2 = boto3.resource('ec2', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1b" + ) conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - lbs = conn.describe_load_balancers()['LoadBalancers'] + lbs = conn.describe_load_balancers()["LoadBalancers"] lbs.should.have.length_of(1) lb = lbs[0] with assert_raises(ClientError): - conn.add_tags(ResourceArns=['missing-arn'], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }]) + conn.add_tags(ResourceArns=["missing-arn"], Tags=[{"Key": "a", "Value": "b"}]) - conn.add_tags(ResourceArns=[lb.get('LoadBalancerArn')], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }]) + conn.add_tags( + ResourceArns=[lb.get("LoadBalancerArn")], Tags=[{"Key": "a", "Value": "b"}] + ) - tags = {d['Key']: d['Value'] for d in conn.describe_tags( - ResourceArns=[lb.get('LoadBalancerArn')])['TagDescriptions'][0]['Tags']} - tags.should.have.key('a').which.should.equal('b') + tags = { + d["Key"]: d["Value"] + for d in conn.describe_tags(ResourceArns=[lb.get("LoadBalancerArn")])[ + "TagDescriptions" + ][0]["Tags"] + } + tags.should.have.key("a").which.should.equal("b") - conn.add_tags(ResourceArns=[lb.get('LoadBalancerArn')], - Tags=[{ - 'Key': 'a', - 'Value': 'b' - }, { - 'Key': 'b', - 'Value': 'b' - }, { - 'Key': 'c', - 'Value': 'b' - }, { - 'Key': 'd', - 'Value': 'b' - }, { - 'Key': 'e', - 'Value': 'b' - }, { - 'Key': 'f', - 'Value': 'b' - }, { - 'Key': 'g', - 'Value': 'b' - }, { - 'Key': 'h', - 'Value': 'b' - }, { - 'Key': 'j', - 'Value': 'b' - }]) + conn.add_tags( + ResourceArns=[lb.get("LoadBalancerArn")], + Tags=[ + {"Key": "a", "Value": "b"}, + {"Key": "b", "Value": "b"}, + {"Key": "c", "Value": "b"}, + {"Key": "d", "Value": "b"}, + {"Key": "e", "Value": "b"}, + {"Key": "f", "Value": "b"}, + {"Key": "g", "Value": "b"}, + {"Key": "h", "Value": "b"}, + {"Key": "j", "Value": "b"}, + ], + ) - conn.add_tags.when.called_with(ResourceArns=[lb.get('LoadBalancerArn')], - Tags=[{ - 'Key': 'k', - 'Value': 'b' - }]).should.throw(botocore.exceptions.ClientError) + conn.add_tags.when.called_with( + ResourceArns=[lb.get("LoadBalancerArn")], Tags=[{"Key": "k", "Value": "b"}] + ).should.throw(botocore.exceptions.ClientError) - conn.add_tags(ResourceArns=[lb.get('LoadBalancerArn')], - Tags=[{ - 'Key': 'j', - 'Value': 'c' - }]) + conn.add_tags( + ResourceArns=[lb.get("LoadBalancerArn")], Tags=[{"Key": "j", "Value": "c"}] + ) - tags = {d['Key']: d['Value'] for d in conn.describe_tags( - ResourceArns=[lb.get('LoadBalancerArn')])['TagDescriptions'][0]['Tags']} + tags = { + d["Key"]: d["Value"] + for d in conn.describe_tags(ResourceArns=[lb.get("LoadBalancerArn")])[ + "TagDescriptions" + ][0]["Tags"] + } - tags.should.have.key('a').which.should.equal('b') - tags.should.have.key('b').which.should.equal('b') - tags.should.have.key('c').which.should.equal('b') - tags.should.have.key('d').which.should.equal('b') - tags.should.have.key('e').which.should.equal('b') - tags.should.have.key('f').which.should.equal('b') - tags.should.have.key('g').which.should.equal('b') - tags.should.have.key('h').which.should.equal('b') - tags.should.have.key('j').which.should.equal('c') - tags.shouldnt.have.key('k') + tags.should.have.key("a").which.should.equal("b") + tags.should.have.key("b").which.should.equal("b") + tags.should.have.key("c").which.should.equal("b") + tags.should.have.key("d").which.should.equal("b") + tags.should.have.key("e").which.should.equal("b") + tags.should.have.key("f").which.should.equal("b") + tags.should.have.key("g").which.should.equal("b") + tags.should.have.key("h").which.should.equal("b") + tags.should.have.key("j").which.should.equal("c") + tags.shouldnt.have.key("k") - conn.remove_tags(ResourceArns=[lb.get('LoadBalancerArn')], - TagKeys=['a']) + conn.remove_tags(ResourceArns=[lb.get("LoadBalancerArn")], TagKeys=["a"]) - tags = {d['Key']: d['Value'] for d in conn.describe_tags( - ResourceArns=[lb.get('LoadBalancerArn')])['TagDescriptions'][0]['Tags']} + tags = { + d["Key"]: d["Value"] + for d in conn.describe_tags(ResourceArns=[lb.get("LoadBalancerArn")])[ + "TagDescriptions" + ][0]["Tags"] + } - tags.shouldnt.have.key('a') - tags.should.have.key('b').which.should.equal('b') - tags.should.have.key('c').which.should.equal('b') - tags.should.have.key('d').which.should.equal('b') - tags.should.have.key('e').which.should.equal('b') - tags.should.have.key('f').which.should.equal('b') - tags.should.have.key('g').which.should.equal('b') - tags.should.have.key('h').which.should.equal('b') - tags.should.have.key('j').which.should.equal('c') + tags.shouldnt.have.key("a") + tags.should.have.key("b").which.should.equal("b") + tags.should.have.key("c").which.should.equal("b") + tags.should.have.key("d").which.should.equal("b") + tags.should.have.key("e").which.should.equal("b") + tags.should.have.key("f").which.should.equal("b") + tags.should.have.key("g").which.should.equal("b") + tags.should.have.key("h").which.should.equal("b") + tags.should.have.key("j").which.should.equal("c") @mock_elbv2 @mock_ec2 def test_create_elb_in_multiple_region(): - for region in ['us-west-1', 'us-west-2']: - conn = boto3.client('elbv2', region_name=region) - ec2 = boto3.resource('ec2', region_name=region) + for region in ["us-west-1", "us-west-2"]: + conn = boto3.client("elbv2", region_name=region) + ec2 = boto3.resource("ec2", region_name=region) security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc( - CidrBlock='172.28.7.0/24', - InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone=region + 'a') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone=region + "a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone=region + 'b') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone=region + "b" + ) conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) list( - boto3.client( - 'elbv2', - region_name='us-west-1').describe_load_balancers().get('LoadBalancers') + boto3.client("elbv2", region_name="us-west-1") + .describe_load_balancers() + .get("LoadBalancers") ).should.have.length_of(1) list( - boto3.client( - 'elbv2', - region_name='us-west-2').describe_load_balancers().get('LoadBalancers') + boto3.client("elbv2", region_name="us-west-2") + .describe_load_balancers() + .get("LoadBalancers") ).should.have.length_of(1) @mock_elbv2 @mock_ec2 def test_create_target_group_and_listeners(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") # Can't create a target group with an invalid protocol with assert_raises(ClientError): conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='/HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="/HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] - target_group_arn = target_group['TargetGroupArn'] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] + target_group_arn = target_group["TargetGroupArn"] # Add tags to the target group - conn.add_tags(ResourceArns=[target_group_arn], Tags=[ - {'Key': 'target', 'Value': 'group'}]) - conn.describe_tags(ResourceArns=[target_group_arn])['TagDescriptions'][0]['Tags'].should.equal( - [{'Key': 'target', 'Value': 'group'}]) + conn.add_tags( + ResourceArns=[target_group_arn], Tags=[{"Key": "target", "Value": "group"}] + ) + conn.describe_tags(ResourceArns=[target_group_arn])["TagDescriptions"][0][ + "Tags" + ].should.equal([{"Key": "target", "Value": "group"}]) # Check it's in the describe_target_groups response response = conn.describe_target_groups() - response.get('TargetGroups').should.have.length_of(1) + response.get("TargetGroups").should.have.length_of(1) # Plain HTTP listener response = conn.create_listener( LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', + Protocol="HTTP", Port=80, - DefaultActions=[{'Type': 'forward', 'TargetGroupArn': target_group.get('TargetGroupArn')}]) - listener = response.get('Listeners')[0] - listener.get('Port').should.equal(80) - listener.get('Protocol').should.equal('HTTP') - listener.get('DefaultActions').should.equal([{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward'}]) - http_listener_arn = listener.get('ListenerArn') + DefaultActions=[ + {"Type": "forward", "TargetGroupArn": target_group.get("TargetGroupArn")} + ], + ) + listener = response.get("Listeners")[0] + listener.get("Port").should.equal(80) + listener.get("Protocol").should.equal("HTTP") + listener.get("DefaultActions").should.equal( + [{"TargetGroupArn": target_group.get("TargetGroupArn"), "Type": "forward"}] + ) + http_listener_arn = listener.get("ListenerArn") - response = conn.describe_target_groups(LoadBalancerArn=load_balancer_arn, - Names=['a-target']) - response.get('TargetGroups').should.have.length_of(1) + response = conn.describe_target_groups( + LoadBalancerArn=load_balancer_arn, Names=["a-target"] + ) + response.get("TargetGroups").should.have.length_of(1) # And another with SSL response = conn.create_listener( LoadBalancerArn=load_balancer_arn, - Protocol='HTTPS', + Protocol="HTTPS", Port=443, Certificates=[ - {'CertificateArn': 'arn:aws:iam:123456789012:server-certificate/test-cert'}], - DefaultActions=[{'Type': 'forward', 'TargetGroupArn': target_group.get('TargetGroupArn')}]) - listener = response.get('Listeners')[0] - listener.get('Port').should.equal(443) - listener.get('Protocol').should.equal('HTTPS') - listener.get('Certificates').should.equal([{ - 'CertificateArn': 'arn:aws:iam:123456789012:server-certificate/test-cert', - }]) - listener.get('DefaultActions').should.equal([{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward'}]) + {"CertificateArn": "arn:aws:iam:123456789012:server-certificate/test-cert"} + ], + DefaultActions=[ + {"Type": "forward", "TargetGroupArn": target_group.get("TargetGroupArn")} + ], + ) + listener = response.get("Listeners")[0] + listener.get("Port").should.equal(443) + listener.get("Protocol").should.equal("HTTPS") + listener.get("Certificates").should.equal( + [{"CertificateArn": "arn:aws:iam:123456789012:server-certificate/test-cert"}] + ) + listener.get("DefaultActions").should.equal( + [{"TargetGroupArn": target_group.get("TargetGroupArn"), "Type": "forward"}] + ) - https_listener_arn = listener.get('ListenerArn') + https_listener_arn = listener.get("ListenerArn") response = conn.describe_listeners(LoadBalancerArn=load_balancer_arn) - response.get('Listeners').should.have.length_of(2) + response.get("Listeners").should.have.length_of(2) response = conn.describe_listeners(ListenerArns=[https_listener_arn]) - response.get('Listeners').should.have.length_of(1) - listener = response.get('Listeners')[0] - listener.get('Port').should.equal(443) - listener.get('Protocol').should.equal('HTTPS') + response.get("Listeners").should.have.length_of(1) + listener = response.get("Listeners")[0] + listener.get("Port").should.equal(443) + listener.get("Protocol").should.equal("HTTPS") response = conn.describe_listeners( - ListenerArns=[ - http_listener_arn, - https_listener_arn]) - response.get('Listeners').should.have.length_of(2) + ListenerArns=[http_listener_arn, https_listener_arn] + ) + response.get("Listeners").should.have.length_of(2) # Try to delete the target group and it fails because there's a # listener referencing it with assert_raises(ClientError) as e: - conn.delete_target_group( - TargetGroupArn=target_group.get('TargetGroupArn')) - e.exception.operation_name.should.equal('DeleteTargetGroup') - e.exception.args.should.equal(("An error occurred (ResourceInUse) when calling the DeleteTargetGroup operation: The target group 'arn:aws:elasticloadbalancing:us-east-1:1:targetgroup/a-target/50dc6c495c0c9188' is currently in use by a listener or a rule", )) # NOQA + conn.delete_target_group(TargetGroupArn=target_group.get("TargetGroupArn")) + e.exception.operation_name.should.equal("DeleteTargetGroup") + e.exception.args.should.equal( + ( + "An error occurred (ResourceInUse) when calling the DeleteTargetGroup operation: The target group 'arn:aws:elasticloadbalancing:us-east-1:1:targetgroup/a-target/50dc6c495c0c9188' is currently in use by a listener or a rule", + ) + ) # NOQA # Delete one listener response = conn.describe_listeners(LoadBalancerArn=load_balancer_arn) - response.get('Listeners').should.have.length_of(2) + response.get("Listeners").should.have.length_of(2) conn.delete_listener(ListenerArn=http_listener_arn) response = conn.describe_listeners(LoadBalancerArn=load_balancer_arn) - response.get('Listeners').should.have.length_of(1) + response.get("Listeners").should.have.length_of(1) # Then delete the load balancer conn.delete_load_balancer(LoadBalancerArn=load_balancer_arn) # It's gone response = conn.describe_load_balancers() - response.get('LoadBalancers').should.have.length_of(0) + response.get("LoadBalancers").should.have.length_of(0) # And it deleted the remaining listener response = conn.describe_listeners( - ListenerArns=[ - http_listener_arn, - https_listener_arn]) - response.get('Listeners').should.have.length_of(0) + ListenerArns=[http_listener_arn, https_listener_arn] + ) + response.get("Listeners").should.have.length_of(0) # But not the target groups response = conn.describe_target_groups() - response.get('TargetGroups').should.have.length_of(1) + response.get("TargetGroups").should.have.length_of(1) # Which we'll now delete - conn.delete_target_group(TargetGroupArn=target_group.get('TargetGroupArn')) + conn.delete_target_group(TargetGroupArn=target_group.get("TargetGroupArn")) response = conn.describe_target_groups() - response.get('TargetGroups').should.have.length_of(0) + response.get("TargetGroups").should.have.length_of(0) @mock_elbv2 @mock_ec2 def test_create_target_group_without_non_required_parameters(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) # request without HealthCheckIntervalSeconds parameter # which is default to 30 seconds response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080' + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", ) - target_group = response.get('TargetGroups')[0] + target_group = response.get("TargetGroups")[0] target_group.should_not.be.none @mock_elbv2 @mock_ec2 def test_create_invalid_target_group(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") # Fail to create target group with name which length is 33 - long_name = 'A' * 33 + long_name = "A" * 33 with assert_raises(ClientError): conn.create_target_group( Name=long_name, - Protocol='HTTP', + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) - invalid_names = [ - '-name', - 'name-', - '-name-', - 'example.com', - 'test@test', - 'Na--me'] + invalid_names = ["-name", "name-", "-name-", "example.com", "test@test", "Na--me"] for name in invalid_names: with assert_raises(ClientError): conn.create_target_group( Name=name, - Protocol='HTTP', + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) - valid_names = ['name', 'Name', '000'] + valid_names = ["name", "Name", "000"] for name in valid_names: conn.create_target_group( Name=name, - Protocol='HTTP', + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) @mock_elbv2 @mock_ec2 def test_describe_paginated_balancers(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) for i in range(51): conn.create_load_balancer( - Name='my-lb%d' % i, + Name="my-lb%d" % i, Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) resp = conn.describe_load_balancers() - resp['LoadBalancers'].should.have.length_of(50) - resp['NextMarker'].should.equal( - resp['LoadBalancers'][-1]['LoadBalancerName']) - resp2 = conn.describe_load_balancers(Marker=resp['NextMarker']) - resp2['LoadBalancers'].should.have.length_of(1) - assert 'NextToken' not in resp2.keys() + resp["LoadBalancers"].should.have.length_of(50) + resp["NextMarker"].should.equal(resp["LoadBalancers"][-1]["LoadBalancerName"]) + resp2 = conn.describe_load_balancers(Marker=resp["NextMarker"]) + resp2["LoadBalancers"].should.have.length_of(1) + assert "NextToken" not in resp2.keys() @mock_elbv2 @mock_ec2 def test_delete_load_balancer(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - response.get('LoadBalancers').should.have.length_of(1) - lb = response.get('LoadBalancers')[0] + response.get("LoadBalancers").should.have.length_of(1) + lb = response.get("LoadBalancers")[0] - conn.delete_load_balancer(LoadBalancerArn=lb.get('LoadBalancerArn')) - balancers = conn.describe_load_balancers().get('LoadBalancers') + conn.delete_load_balancer(LoadBalancerArn=lb.get("LoadBalancerArn")) + balancers = conn.describe_load_balancers().get("LoadBalancers") balancers.should.have.length_of(0) @mock_ec2 @mock_elbv2 def test_register_targets(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] # No targets registered yet response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(0) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(0) - response = ec2.create_instances( - ImageId='ami-1234abcd', MinCount=2, MaxCount=2) + response = ec2.create_instances(ImageId="ami-1234abcd", MinCount=2, MaxCount=2) instance_id1 = response[0].id instance_id2 = response[1].id response = conn.register_targets( - TargetGroupArn=target_group.get('TargetGroupArn'), + TargetGroupArn=target_group.get("TargetGroupArn"), Targets=[ - { - 'Id': instance_id1, - 'Port': 5060, - }, - { - 'Id': instance_id2, - 'Port': 4030, - }, - ]) + {"Id": instance_id1, "Port": 5060}, + {"Id": instance_id2, "Port": 4030}, + ], + ) response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(2) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(2) response = conn.deregister_targets( - TargetGroupArn=target_group.get('TargetGroupArn'), - Targets=[{'Id': instance_id2}]) + TargetGroupArn=target_group.get("TargetGroupArn"), + Targets=[{"Id": instance_id2}], + ) response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(1) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(1) @mock_ec2 @@ -672,84 +666,84 @@ def test_register_targets(): def test_stopped_instance_target(): target_group_port = 8080 - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=target_group_port, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] # No targets registered yet response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(0) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(0) - response = ec2.create_instances( - ImageId='ami-1234abcd', MinCount=1, MaxCount=1) + response = ec2.create_instances(ImageId="ami-1234abcd", MinCount=1, MaxCount=1) instance = response[0] - target_dict = { - 'Id': instance.id, - 'Port': 500 - } + target_dict = {"Id": instance.id, "Port": 500} response = conn.register_targets( - TargetGroupArn=target_group.get('TargetGroupArn'), - Targets=[target_dict]) + TargetGroupArn=target_group.get("TargetGroupArn"), Targets=[target_dict] + ) response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(1) - target_health_description = response.get('TargetHealthDescriptions')[0] + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(1) + target_health_description = response.get("TargetHealthDescriptions")[0] - target_health_description['Target'].should.equal(target_dict) - target_health_description['HealthCheckPort'].should.equal(str(target_group_port)) - target_health_description['TargetHealth'].should.equal({ - 'State': 'healthy' - }) + target_health_description["Target"].should.equal(target_dict) + target_health_description["HealthCheckPort"].should.equal(str(target_group_port)) + target_health_description["TargetHealth"].should.equal({"State": "healthy"}) instance.stop() response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(1) - target_health_description = response.get('TargetHealthDescriptions')[0] - target_health_description['Target'].should.equal(target_dict) - target_health_description['HealthCheckPort'].should.equal(str(target_group_port)) - target_health_description['TargetHealth'].should.equal({ - 'State': 'unused', - 'Reason': 'Target.InvalidState', - 'Description': 'Target is in the stopped state' - }) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(1) + target_health_description = response.get("TargetHealthDescriptions")[0] + target_health_description["Target"].should.equal(target_dict) + target_health_description["HealthCheckPort"].should.equal(str(target_group_port)) + target_health_description["TargetHealth"].should.equal( + { + "State": "unused", + "Reason": "Target.InvalidState", + "Description": "Target is in the stopped state", + } + ) @mock_ec2 @@ -757,281 +751,265 @@ def test_stopped_instance_target(): def test_terminated_instance_target(): target_group_port = 8080 - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=target_group_port, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] # No targets registered yet response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(0) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(0) - response = ec2.create_instances( - ImageId='ami-1234abcd', MinCount=1, MaxCount=1) + response = ec2.create_instances(ImageId="ami-1234abcd", MinCount=1, MaxCount=1) instance = response[0] - target_dict = { - 'Id': instance.id, - 'Port': 500 - } + target_dict = {"Id": instance.id, "Port": 500} response = conn.register_targets( - TargetGroupArn=target_group.get('TargetGroupArn'), - Targets=[target_dict]) + TargetGroupArn=target_group.get("TargetGroupArn"), Targets=[target_dict] + ) response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(1) - target_health_description = response.get('TargetHealthDescriptions')[0] + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(1) + target_health_description = response.get("TargetHealthDescriptions")[0] - target_health_description['Target'].should.equal(target_dict) - target_health_description['HealthCheckPort'].should.equal(str(target_group_port)) - target_health_description['TargetHealth'].should.equal({ - 'State': 'healthy' - }) + target_health_description["Target"].should.equal(target_dict) + target_health_description["HealthCheckPort"].should.equal(str(target_group_port)) + target_health_description["TargetHealth"].should.equal({"State": "healthy"}) instance.terminate() response = conn.describe_target_health( - TargetGroupArn=target_group.get('TargetGroupArn')) - response.get('TargetHealthDescriptions').should.have.length_of(0) + TargetGroupArn=target_group.get("TargetGroupArn") + ) + response.get("TargetHealthDescriptions").should.have.length_of(0) @mock_ec2 @mock_elbv2 def test_target_group_attributes(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] # Check it's in the describe_target_groups response response = conn.describe_target_groups() - response.get('TargetGroups').should.have.length_of(1) - target_group_arn = target_group['TargetGroupArn'] + response.get("TargetGroups").should.have.length_of(1) + target_group_arn = target_group["TargetGroupArn"] # check if Names filter works response = conn.describe_target_groups(Names=[]) - response = conn.describe_target_groups(Names=['a-target']) - response.get('TargetGroups').should.have.length_of(1) - target_group_arn = target_group['TargetGroupArn'] + response = conn.describe_target_groups(Names=["a-target"]) + response.get("TargetGroups").should.have.length_of(1) + target_group_arn = target_group["TargetGroupArn"] # The attributes should start with the two defaults - response = conn.describe_target_group_attributes( - TargetGroupArn=target_group_arn) - response['Attributes'].should.have.length_of(2) - attributes = {attr['Key']: attr['Value'] - for attr in response['Attributes']} - attributes['deregistration_delay.timeout_seconds'].should.equal('300') - attributes['stickiness.enabled'].should.equal('false') + response = conn.describe_target_group_attributes(TargetGroupArn=target_group_arn) + response["Attributes"].should.have.length_of(2) + attributes = {attr["Key"]: attr["Value"] for attr in response["Attributes"]} + attributes["deregistration_delay.timeout_seconds"].should.equal("300") + attributes["stickiness.enabled"].should.equal("false") # Add cookie stickiness response = conn.modify_target_group_attributes( TargetGroupArn=target_group_arn, Attributes=[ - { - 'Key': 'stickiness.enabled', - 'Value': 'true', - }, - { - 'Key': 'stickiness.type', - 'Value': 'lb_cookie', - }, - ]) + {"Key": "stickiness.enabled", "Value": "true"}, + {"Key": "stickiness.type", "Value": "lb_cookie"}, + ], + ) # The response should have only the keys updated - response['Attributes'].should.have.length_of(2) - attributes = {attr['Key']: attr['Value'] - for attr in response['Attributes']} - attributes['stickiness.type'].should.equal('lb_cookie') - attributes['stickiness.enabled'].should.equal('true') + response["Attributes"].should.have.length_of(2) + attributes = {attr["Key"]: attr["Value"] for attr in response["Attributes"]} + attributes["stickiness.type"].should.equal("lb_cookie") + attributes["stickiness.enabled"].should.equal("true") # These new values should be in the full attribute list - response = conn.describe_target_group_attributes( - TargetGroupArn=target_group_arn) - response['Attributes'].should.have.length_of(3) - attributes = {attr['Key']: attr['Value'] - for attr in response['Attributes']} - attributes['stickiness.type'].should.equal('lb_cookie') - attributes['stickiness.enabled'].should.equal('true') + response = conn.describe_target_group_attributes(TargetGroupArn=target_group_arn) + response["Attributes"].should.have.length_of(3) + attributes = {attr["Key"]: attr["Value"] for attr in response["Attributes"]} + attributes["stickiness.type"].should.equal("lb_cookie") + attributes["stickiness.enabled"].should.equal("true") @mock_elbv2 @mock_ec2 def test_handle_listener_rules(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") # Can't create a target group with an invalid protocol with assert_raises(ClientError): conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='/HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="/HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] # Plain HTTP listener response = conn.create_listener( LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', + Protocol="HTTP", Port=80, - DefaultActions=[{'Type': 'forward', 'TargetGroupArn': target_group.get('TargetGroupArn')}]) - listener = response.get('Listeners')[0] - listener.get('Port').should.equal(80) - listener.get('Protocol').should.equal('HTTP') - listener.get('DefaultActions').should.equal([{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward'}]) - http_listener_arn = listener.get('ListenerArn') + DefaultActions=[ + {"Type": "forward", "TargetGroupArn": target_group.get("TargetGroupArn")} + ], + ) + listener = response.get("Listeners")[0] + listener.get("Port").should.equal(80) + listener.get("Protocol").should.equal("HTTP") + listener.get("DefaultActions").should.equal( + [{"TargetGroupArn": target_group.get("TargetGroupArn"), "Type": "forward"}] + ) + http_listener_arn = listener.get("ListenerArn") # create first rule priority = 100 - host = 'xxx.example.com' - path_pattern = 'foobar' + host = "xxx.example.com" + path_pattern = "foobar" created_rule = conn.create_rule( ListenerArn=http_listener_arn, Priority=priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host] - }, - { - 'Field': 'path-pattern', - 'Values': [path_pattern] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] - )['Rules'][0] - created_rule['Priority'].should.equal('100') + Conditions=[ + {"Field": "host-header", "Values": [host]}, + {"Field": "path-pattern", "Values": [path_pattern]}, + ], + Actions=[ + {"TargetGroupArn": target_group.get("TargetGroupArn"), "Type": "forward"} + ], + )["Rules"][0] + created_rule["Priority"].should.equal("100") # check if rules is sorted by priority priority = 50 - host = 'yyy.example.com' - path_pattern = 'foobar' + host = "yyy.example.com" + path_pattern = "foobar" rules = conn.create_rule( ListenerArn=http_listener_arn, Priority=priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host] - }, - { - 'Field': 'path-pattern', - 'Values': [path_pattern] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] + Conditions=[ + {"Field": "host-header", "Values": [host]}, + {"Field": "path-pattern", "Values": [path_pattern]}, + ], + Actions=[ + {"TargetGroupArn": target_group.get("TargetGroupArn"), "Type": "forward"} + ], ) # test for PriorityInUse @@ -1039,46 +1017,43 @@ def test_handle_listener_rules(): conn.create_rule( ListenerArn=http_listener_arn, Priority=priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host] - }, + Conditions=[ + {"Field": "host-header", "Values": [host]}, + {"Field": "path-pattern", "Values": [path_pattern]}, + ], + Actions=[ { - 'Field': 'path-pattern', - 'Values': [path_pattern] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] + "TargetGroupArn": target_group.get("TargetGroupArn"), + "Type": "forward", + } + ], ) # test for describe listeners obtained_rules = conn.describe_rules(ListenerArn=http_listener_arn) - len(obtained_rules['Rules']).should.equal(3) - priorities = [rule['Priority'] for rule in obtained_rules['Rules']] - priorities.should.equal(['50', '100', 'default']) + len(obtained_rules["Rules"]).should.equal(3) + priorities = [rule["Priority"] for rule in obtained_rules["Rules"]] + priorities.should.equal(["50", "100", "default"]) - first_rule = obtained_rules['Rules'][0] - second_rule = obtained_rules['Rules'][1] - obtained_rules = conn.describe_rules(RuleArns=[first_rule['RuleArn']]) - obtained_rules['Rules'].should.equal([first_rule]) + first_rule = obtained_rules["Rules"][0] + second_rule = obtained_rules["Rules"][1] + obtained_rules = conn.describe_rules(RuleArns=[first_rule["RuleArn"]]) + obtained_rules["Rules"].should.equal([first_rule]) # test for pagination - obtained_rules = conn.describe_rules( - ListenerArn=http_listener_arn, PageSize=1) - len(obtained_rules['Rules']).should.equal(1) - obtained_rules.should.have.key('NextMarker') - next_marker = obtained_rules['NextMarker'] + obtained_rules = conn.describe_rules(ListenerArn=http_listener_arn, PageSize=1) + len(obtained_rules["Rules"]).should.equal(1) + obtained_rules.should.have.key("NextMarker") + next_marker = obtained_rules["NextMarker"] following_rules = conn.describe_rules( - ListenerArn=http_listener_arn, - PageSize=1, - Marker=next_marker) - len(following_rules['Rules']).should.equal(1) - following_rules.should.have.key('NextMarker') - following_rules['Rules'][0]['RuleArn'].should_not.equal( - obtained_rules['Rules'][0]['RuleArn']) + ListenerArn=http_listener_arn, PageSize=1, Marker=next_marker + ) + len(following_rules["Rules"]).should.equal(1) + following_rules.should.have.key("NextMarker") + following_rules["Rules"][0]["RuleArn"].should_not.equal( + obtained_rules["Rules"][0]["RuleArn"] + ) # test for invalid describe rule request with assert_raises(ClientError): @@ -1087,52 +1062,50 @@ def test_handle_listener_rules(): conn.describe_rules(RuleArns=[]) with assert_raises(ClientError): conn.describe_rules( - ListenerArn=http_listener_arn, - RuleArns=[first_rule['RuleArn']] + ListenerArn=http_listener_arn, RuleArns=[first_rule["RuleArn"]] ) # modify rule partially - new_host = 'new.example.com' - new_path_pattern = 'new_path' + new_host = "new.example.com" + new_path_pattern = "new_path" modified_rule = conn.modify_rule( - RuleArn=first_rule['RuleArn'], - Conditions=[{ - 'Field': 'host-header', - 'Values': [new_host] - }, - { - 'Field': 'path-pattern', - 'Values': [new_path_pattern] - }] - )['Rules'][0] + RuleArn=first_rule["RuleArn"], + Conditions=[ + {"Field": "host-header", "Values": [new_host]}, + {"Field": "path-pattern", "Values": [new_path_pattern]}, + ], + )["Rules"][0] rules = conn.describe_rules(ListenerArn=http_listener_arn) - obtained_rule = rules['Rules'][0] + obtained_rule = rules["Rules"][0] modified_rule.should.equal(obtained_rule) - obtained_rule['Conditions'][0]['Values'][0].should.equal(new_host) - obtained_rule['Conditions'][1]['Values'][0].should.equal(new_path_pattern) - obtained_rule['Actions'][0]['TargetGroupArn'].should.equal( - target_group.get('TargetGroupArn')) + obtained_rule["Conditions"][0]["Values"][0].should.equal(new_host) + obtained_rule["Conditions"][1]["Values"][0].should.equal(new_path_pattern) + obtained_rule["Actions"][0]["TargetGroupArn"].should.equal( + target_group.get("TargetGroupArn") + ) # modify priority conn.set_rule_priorities( RulePriorities=[ - {'RuleArn': first_rule['RuleArn'], - 'Priority': int(first_rule['Priority']) - 1} + { + "RuleArn": first_rule["RuleArn"], + "Priority": int(first_rule["Priority"]) - 1, + } ] ) with assert_raises(ClientError): conn.set_rule_priorities( RulePriorities=[ - {'RuleArn': first_rule['RuleArn'], 'Priority': 999}, - {'RuleArn': second_rule['RuleArn'], 'Priority': 999} + {"RuleArn": first_rule["RuleArn"], "Priority": 999}, + {"RuleArn": second_rule["RuleArn"], "Priority": 999}, ] ) # delete - arn = first_rule['RuleArn'] + arn = first_rule["RuleArn"] conn.delete_rule(RuleArn=arn) - rules = conn.describe_rules(ListenerArn=http_listener_arn)['Rules'] + rules = conn.describe_rules(ListenerArn=http_listener_arn)["Rules"] len(rules).should.equal(2) # test for invalid action type @@ -1141,39 +1114,30 @@ def test_handle_listener_rules(): conn.create_rule( ListenerArn=http_listener_arn, Priority=safe_priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host] - }, + Conditions=[ + {"Field": "host-header", "Values": [host]}, + {"Field": "path-pattern", "Values": [path_pattern]}, + ], + Actions=[ { - 'Field': 'path-pattern', - 'Values': [path_pattern] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward2' - }] + "TargetGroupArn": target_group.get("TargetGroupArn"), + "Type": "forward2", + } + ], ) # test for invalid action type safe_priority = 2 - invalid_target_group_arn = target_group.get('TargetGroupArn') + 'x' + invalid_target_group_arn = target_group.get("TargetGroupArn") + "x" with assert_raises(ClientError): conn.create_rule( ListenerArn=http_listener_arn, Priority=safe_priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host] - }, - { - 'Field': 'path-pattern', - 'Values': [path_pattern] - }], - Actions=[{ - 'TargetGroupArn': invalid_target_group_arn, - 'Type': 'forward' - }] + Conditions=[ + {"Field": "host-header", "Values": [host]}, + {"Field": "path-pattern", "Values": [path_pattern]}, + ], + Actions=[{"TargetGroupArn": invalid_target_group_arn, "Type": "forward"}], ) # test for invalid condition field_name @@ -1182,14 +1146,13 @@ def test_handle_listener_rules(): conn.create_rule( ListenerArn=http_listener_arn, Priority=safe_priority, - Conditions=[{ - 'Field': 'xxxxxxx', - 'Values': [host] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] + Conditions=[{"Field": "xxxxxxx", "Values": [host]}], + Actions=[ + { + "TargetGroupArn": target_group.get("TargetGroupArn"), + "Type": "forward", + } + ], ) # test for emptry condition value @@ -1198,14 +1161,13 @@ def test_handle_listener_rules(): conn.create_rule( ListenerArn=http_listener_arn, Priority=safe_priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] + Conditions=[{"Field": "host-header", "Values": []}], + Actions=[ + { + "TargetGroupArn": target_group.get("TargetGroupArn"), + "Type": "forward", + } + ], ) # test for multiple condition value @@ -1214,444 +1176,440 @@ def test_handle_listener_rules(): conn.create_rule( ListenerArn=http_listener_arn, Priority=safe_priority, - Conditions=[{ - 'Field': 'host-header', - 'Values': [host, host] - }], - Actions=[{ - 'TargetGroupArn': target_group.get('TargetGroupArn'), - 'Type': 'forward' - }] + Conditions=[{"Field": "host-header", "Values": [host, host]}], + Actions=[ + { + "TargetGroupArn": target_group.get("TargetGroupArn"), + "Type": "forward", + } + ], ) @mock_elbv2 @mock_ec2 def test_describe_invalid_target_group(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - response.get('LoadBalancers')[0].get('LoadBalancerArn') + response.get("LoadBalancers")[0].get("LoadBalancerArn") response = conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) # Check error raises correctly with assert_raises(ClientError): - conn.describe_target_groups(Names=['invalid']) + conn.describe_target_groups(Names=["invalid"]) @mock_elbv2 @mock_ec2 def test_describe_target_groups_no_arguments(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - response.get('LoadBalancers')[0].get('LoadBalancerArn') + response.get("LoadBalancers")[0].get("LoadBalancerArn") conn.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) + Matcher={"HttpCode": "200"}, + ) - assert len(conn.describe_target_groups()['TargetGroups']) == 1 + assert len(conn.describe_target_groups()["TargetGroups"]) == 1 @mock_elbv2 def test_describe_account_limits(): - client = boto3.client('elbv2', region_name='eu-central-1') + client = boto3.client("elbv2", region_name="eu-central-1") resp = client.describe_account_limits() - resp['Limits'][0].should.contain('Name') - resp['Limits'][0].should.contain('Max') + resp["Limits"][0].should.contain("Name") + resp["Limits"][0].should.contain("Max") @mock_elbv2 def test_describe_ssl_policies(): - client = boto3.client('elbv2', region_name='eu-central-1') + client = boto3.client("elbv2", region_name="eu-central-1") resp = client.describe_ssl_policies() - len(resp['SslPolicies']).should.equal(5) + len(resp["SslPolicies"]).should.equal(5) - resp = client.describe_ssl_policies(Names=['ELBSecurityPolicy-TLS-1-2-2017-01', 'ELBSecurityPolicy-2016-08']) - len(resp['SslPolicies']).should.equal(2) + resp = client.describe_ssl_policies( + Names=["ELBSecurityPolicy-TLS-1-2-2017-01", "ELBSecurityPolicy-2016-08"] + ) + len(resp["SslPolicies"]).should.equal(2) @mock_elbv2 @mock_ec2 def test_set_ip_address_type(): - client = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = client.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - arn = response['LoadBalancers'][0]['LoadBalancerArn'] + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + arn = response["LoadBalancers"][0]["LoadBalancerArn"] # Internal LBs cant be dualstack yet with assert_raises(ClientError): - client.set_ip_address_type( - LoadBalancerArn=arn, - IpAddressType='dualstack' - ) + client.set_ip_address_type(LoadBalancerArn=arn, IpAddressType="dualstack") # Create internet facing one response = client.create_load_balancer( - Name='my-lb2', + Name="my-lb2", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internet-facing', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - arn = response['LoadBalancers'][0]['LoadBalancerArn'] - - client.set_ip_address_type( - LoadBalancerArn=arn, - IpAddressType='dualstack' + Scheme="internet-facing", + Tags=[{"Key": "key_name", "Value": "a_value"}], ) + arn = response["LoadBalancers"][0]["LoadBalancerArn"] + + client.set_ip_address_type(LoadBalancerArn=arn, IpAddressType="dualstack") @mock_elbv2 @mock_ec2 def test_set_security_groups(): - client = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') + GroupName="a-security-group", Description="First One" + ) security_group2 = ec2.create_security_group( - GroupName='b-security-group', Description='Second One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="b-security-group", Description="Second One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = client.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - arn = response['LoadBalancers'][0]['LoadBalancerArn'] + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + arn = response["LoadBalancers"][0]["LoadBalancerArn"] client.set_security_groups( - LoadBalancerArn=arn, - SecurityGroups=[security_group.id, security_group2.id] + LoadBalancerArn=arn, SecurityGroups=[security_group.id, security_group2.id] ) resp = client.describe_load_balancers(LoadBalancerArns=[arn]) - len(resp['LoadBalancers'][0]['SecurityGroups']).should.equal(2) + len(resp["LoadBalancers"][0]["SecurityGroups"]).should.equal(2) with assert_raises(ClientError): - client.set_security_groups( - LoadBalancerArn=arn, - SecurityGroups=['non_existant'] - ) + client.set_security_groups(LoadBalancerArn=arn, SecurityGroups=["non_existant"]) @mock_elbv2 @mock_ec2 def test_set_subnets(): - client = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.64/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.64/26", AvailabilityZone="us-east-1b" + ) subnet3 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1c') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1c" + ) response = client.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - arn = response['LoadBalancers'][0]['LoadBalancerArn'] + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + arn = response["LoadBalancers"][0]["LoadBalancerArn"] client.set_subnets( - LoadBalancerArn=arn, - Subnets=[subnet1.id, subnet2.id, subnet3.id] + LoadBalancerArn=arn, Subnets=[subnet1.id, subnet2.id, subnet3.id] ) resp = client.describe_load_balancers(LoadBalancerArns=[arn]) - len(resp['LoadBalancers'][0]['AvailabilityZones']).should.equal(3) + len(resp["LoadBalancers"][0]["AvailabilityZones"]).should.equal(3) # Only 1 AZ with assert_raises(ClientError): - client.set_subnets( - LoadBalancerArn=arn, - Subnets=[subnet1.id] - ) + client.set_subnets(LoadBalancerArn=arn, Subnets=[subnet1.id]) # Multiple subnets in same AZ with assert_raises(ClientError): client.set_subnets( - LoadBalancerArn=arn, - Subnets=[subnet1.id, subnet2.id, subnet2.id] + LoadBalancerArn=arn, Subnets=[subnet1.id, subnet2.id, subnet2.id] ) @mock_elbv2 @mock_ec2 def test_set_subnets(): - client = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) response = client.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - arn = response['LoadBalancers'][0]['LoadBalancerArn'] + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + arn = response["LoadBalancers"][0]["LoadBalancerArn"] client.modify_load_balancer_attributes( LoadBalancerArn=arn, - Attributes=[{'Key': 'idle_timeout.timeout_seconds', 'Value': '600'}] + Attributes=[{"Key": "idle_timeout.timeout_seconds", "Value": "600"}], ) # Check its 600 not 60 - response = client.describe_load_balancer_attributes( - LoadBalancerArn=arn - ) - idle_timeout = list(filter(lambda item: item['Key'] == 'idle_timeout.timeout_seconds', response['Attributes']))[0] - idle_timeout['Value'].should.equal('600') + response = client.describe_load_balancer_attributes(LoadBalancerArn=arn) + idle_timeout = list( + filter( + lambda item: item["Key"] == "idle_timeout.timeout_seconds", + response["Attributes"], + ) + )[0] + idle_timeout["Value"].should.equal("600") @mock_elbv2 @mock_ec2 def test_modify_target_group(): - client = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + client = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") response = client.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - arn = response.get('TargetGroups')[0]['TargetGroupArn'] + Matcher={"HttpCode": "200"}, + ) + arn = response.get("TargetGroups")[0]["TargetGroupArn"] client.modify_target_group( TargetGroupArn=arn, - HealthCheckProtocol='HTTPS', - HealthCheckPort='8081', - HealthCheckPath='/status', + HealthCheckProtocol="HTTPS", + HealthCheckPort="8081", + HealthCheckPath="/status", HealthCheckIntervalSeconds=10, HealthCheckTimeoutSeconds=10, HealthyThresholdCount=10, UnhealthyThresholdCount=4, - Matcher={'HttpCode': '200-399'} + Matcher={"HttpCode": "200-399"}, ) - response = client.describe_target_groups( - TargetGroupArns=[arn] - ) - response['TargetGroups'][0]['Matcher']['HttpCode'].should.equal('200-399') - response['TargetGroups'][0]['HealthCheckIntervalSeconds'].should.equal(10) - response['TargetGroups'][0]['HealthCheckPath'].should.equal('/status') - response['TargetGroups'][0]['HealthCheckPort'].should.equal('8081') - response['TargetGroups'][0]['HealthCheckProtocol'].should.equal('HTTPS') - response['TargetGroups'][0]['HealthCheckTimeoutSeconds'].should.equal(10) - response['TargetGroups'][0]['HealthyThresholdCount'].should.equal(10) - response['TargetGroups'][0]['UnhealthyThresholdCount'].should.equal(4) + response = client.describe_target_groups(TargetGroupArns=[arn]) + response["TargetGroups"][0]["Matcher"]["HttpCode"].should.equal("200-399") + response["TargetGroups"][0]["HealthCheckIntervalSeconds"].should.equal(10) + response["TargetGroups"][0]["HealthCheckPath"].should.equal("/status") + response["TargetGroups"][0]["HealthCheckPort"].should.equal("8081") + response["TargetGroups"][0]["HealthCheckProtocol"].should.equal("HTTPS") + response["TargetGroups"][0]["HealthCheckTimeoutSeconds"].should.equal(10) + response["TargetGroups"][0]["HealthyThresholdCount"].should.equal(10) + response["TargetGroups"][0]["UnhealthyThresholdCount"].should.equal(4) @mock_elbv2 @mock_ec2 @mock_acm def test_modify_listener_http_to_https(): - client = boto3.client('elbv2', region_name='eu-central-1') - acm = boto3.client('acm', region_name='eu-central-1') - ec2 = boto3.resource('ec2', region_name='eu-central-1') + client = boto3.client("elbv2", region_name="eu-central-1") + acm = boto3.client("acm", region_name="eu-central-1") + ec2 = boto3.resource("ec2", region_name="eu-central-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='eu-central-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="eu-central-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='eu-central-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="eu-central-1b" + ) response = client.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") response = client.create_target_group( - Name='a-target', - Protocol='HTTP', + Name="a-target", + Protocol="HTTP", Port=8080, VpcId=vpc.id, - HealthCheckProtocol='HTTP', - HealthCheckPort='8080', - HealthCheckPath='/', + HealthCheckProtocol="HTTP", + HealthCheckPort="8080", + HealthCheckPath="/", HealthCheckIntervalSeconds=5, HealthCheckTimeoutSeconds=5, HealthyThresholdCount=5, UnhealthyThresholdCount=2, - Matcher={'HttpCode': '200'}) - target_group = response.get('TargetGroups')[0] - target_group_arn = target_group['TargetGroupArn'] + Matcher={"HttpCode": "200"}, + ) + target_group = response.get("TargetGroups")[0] + target_group_arn = target_group["TargetGroupArn"] # Plain HTTP listener response = client.create_listener( LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', + Protocol="HTTP", Port=80, - DefaultActions=[{'Type': 'forward', 'TargetGroupArn': target_group_arn}] + DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], ) - listener_arn = response['Listeners'][0]['ListenerArn'] + listener_arn = response["Listeners"][0]["ListenerArn"] response = acm.request_certificate( - DomainName='google.com', - SubjectAlternativeNames=['google.com', 'www.google.com', 'mail.google.com'], + DomainName="google.com", + SubjectAlternativeNames=["google.com", "www.google.com", "mail.google.com"], ) - google_arn = response['CertificateArn'] + google_arn = response["CertificateArn"] response = acm.request_certificate( - DomainName='yahoo.com', - SubjectAlternativeNames=['yahoo.com', 'www.yahoo.com', 'mail.yahoo.com'], + DomainName="yahoo.com", + SubjectAlternativeNames=["yahoo.com", "www.yahoo.com", "mail.yahoo.com"], ) - yahoo_arn = response['CertificateArn'] + yahoo_arn = response["CertificateArn"] response = client.modify_listener( ListenerArn=listener_arn, Port=443, - Protocol='HTTPS', - SslPolicy='ELBSecurityPolicy-TLS-1-2-2017-01', + Protocol="HTTPS", + SslPolicy="ELBSecurityPolicy-TLS-1-2-2017-01", Certificates=[ - {'CertificateArn': google_arn, 'IsDefault': False}, - {'CertificateArn': yahoo_arn, 'IsDefault': True} + {"CertificateArn": google_arn, "IsDefault": False}, + {"CertificateArn": yahoo_arn, "IsDefault": True}, ], - DefaultActions=[ - {'Type': 'forward', 'TargetGroupArn': target_group_arn} - ] + DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], ) - response['Listeners'][0]['Port'].should.equal(443) - response['Listeners'][0]['Protocol'].should.equal('HTTPS') - response['Listeners'][0]['SslPolicy'].should.equal('ELBSecurityPolicy-TLS-1-2-2017-01') - len(response['Listeners'][0]['Certificates']).should.equal(2) + response["Listeners"][0]["Port"].should.equal(443) + response["Listeners"][0]["Protocol"].should.equal("HTTPS") + response["Listeners"][0]["SslPolicy"].should.equal( + "ELBSecurityPolicy-TLS-1-2-2017-01" + ) + len(response["Listeners"][0]["Certificates"]).should.equal(2) # Check default cert, can't do this in server mode - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'false': - listener = elbv2_backends['eu-central-1'].load_balancers[load_balancer_arn].listeners[listener_arn] + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false": + listener = ( + elbv2_backends["eu-central-1"] + .load_balancers[load_balancer_arn] + .listeners[listener_arn] + ) listener.certificate.should.equal(yahoo_arn) # No default cert @@ -1659,14 +1617,10 @@ def test_modify_listener_http_to_https(): client.modify_listener( ListenerArn=listener_arn, Port=443, - Protocol='HTTPS', - SslPolicy='ELBSecurityPolicy-TLS-1-2-2017-01', - Certificates=[ - {'CertificateArn': google_arn, 'IsDefault': False} - ], - DefaultActions=[ - {'Type': 'forward', 'TargetGroupArn': target_group_arn} - ] + Protocol="HTTPS", + SslPolicy="ELBSecurityPolicy-TLS-1-2-2017-01", + Certificates=[{"CertificateArn": google_arn, "IsDefault": False}], + DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], ) # Bad cert @@ -1674,14 +1628,10 @@ def test_modify_listener_http_to_https(): client.modify_listener( ListenerArn=listener_arn, Port=443, - Protocol='HTTPS', - SslPolicy='ELBSecurityPolicy-TLS-1-2-2017-01', - Certificates=[ - {'CertificateArn': 'lalala', 'IsDefault': True} - ], - DefaultActions=[ - {'Type': 'forward', 'TargetGroupArn': target_group_arn} - ] + Protocol="HTTPS", + SslPolicy="ELBSecurityPolicy-TLS-1-2-2017-01", + Certificates=[{"CertificateArn": "lalala", "IsDefault": True}], + DefaultActions=[{"Type": "forward", "TargetGroupArn": target_group_arn}], ) @@ -1689,8 +1639,8 @@ def test_modify_listener_http_to_https(): @mock_elbv2 @mock_cloudformation def test_create_target_groups_through_cloudformation(): - cfn_conn = boto3.client('cloudformation', region_name='us-east-1') - elbv2_client = boto3.client('elbv2', region_name='us-east-1') + cfn_conn = boto3.client("cloudformation", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") # test that setting a name manually as well as letting cloudformation create a name both work # this is a special case because test groups have a name length limit of 22 characters, and must be unique @@ -1701,9 +1651,7 @@ def test_create_target_groups_through_cloudformation(): "Resources": { "testVPC": { "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - }, + "Properties": {"CidrBlock": "10.0.0.0/16"}, }, "testGroup1": { "Type": "AWS::ElasticLoadBalancingV2::TargetGroup", @@ -1730,93 +1678,117 @@ def test_create_target_groups_through_cloudformation(): "VpcId": {"Ref": "testVPC"}, }, }, - } + }, } template_json = json.dumps(template) - cfn_conn.create_stack( - StackName="test-stack", - TemplateBody=template_json, - ) + cfn_conn.create_stack(StackName="test-stack", TemplateBody=template_json) describe_target_groups_response = elbv2_client.describe_target_groups() - target_group_dicts = describe_target_groups_response['TargetGroups'] + target_group_dicts = describe_target_groups_response["TargetGroups"] assert len(target_group_dicts) == 3 # there should be 2 target groups with the same prefix of 10 characters (since the random suffix is 12) # and one named MyTargetGroup - assert len([tg for tg in target_group_dicts if tg['TargetGroupName'] == 'MyTargetGroup']) == 1 - assert len( - [tg for tg in target_group_dicts if tg['TargetGroupName'].startswith('test-stack')] - ) == 2 + assert ( + len( + [ + tg + for tg in target_group_dicts + if tg["TargetGroupName"] == "MyTargetGroup" + ] + ) + == 1 + ) + assert ( + len( + [ + tg + for tg in target_group_dicts + if tg["TargetGroupName"].startswith("test-stack") + ] + ) + == 2 + ) @mock_elbv2 @mock_ec2 def test_redirect_action_listener_rule(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.128/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") - response = conn.create_listener(LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', - Port=80, - DefaultActions=[ - {'Type': 'redirect', - 'RedirectConfig': { - 'Protocol': 'HTTPS', - 'Port': '443', - 'StatusCode': 'HTTP_301' - }}]) + response = conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[ + { + "Type": "redirect", + "RedirectConfig": { + "Protocol": "HTTPS", + "Port": "443", + "StatusCode": "HTTP_301", + }, + } + ], + ) - listener = response.get('Listeners')[0] - expected_default_actions = [{ - 'Type': 'redirect', - 'RedirectConfig': { - 'Protocol': 'HTTPS', - 'Port': '443', - 'StatusCode': 'HTTP_301' + listener = response.get("Listeners")[0] + expected_default_actions = [ + { + "Type": "redirect", + "RedirectConfig": { + "Protocol": "HTTPS", + "Port": "443", + "StatusCode": "HTTP_301", + }, } - }] - listener.get('DefaultActions').should.equal(expected_default_actions) - listener_arn = listener.get('ListenerArn') + ] + listener.get("DefaultActions").should.equal(expected_default_actions) + listener_arn = listener.get("ListenerArn") describe_rules_response = conn.describe_rules(ListenerArn=listener_arn) - describe_rules_response['Rules'][0]['Actions'].should.equal(expected_default_actions) + describe_rules_response["Rules"][0]["Actions"].should.equal( + expected_default_actions + ) - describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn, ]) - describe_listener_actions = describe_listener_response['Listeners'][0]['DefaultActions'] + describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn]) + describe_listener_actions = describe_listener_response["Listeners"][0][ + "DefaultActions" + ] describe_listener_actions.should.equal(expected_default_actions) modify_listener_response = conn.modify_listener(ListenerArn=listener_arn, Port=81) - modify_listener_actions = modify_listener_response['Listeners'][0]['DefaultActions'] + modify_listener_actions = modify_listener_response["Listeners"][0]["DefaultActions"] modify_listener_actions.should.equal(expected_default_actions) @mock_elbv2 @mock_cloudformation def test_redirect_action_listener_rule_cloudformation(): - cnf_conn = boto3.client('cloudformation', region_name='us-east-1') - elbv2_client = boto3.client('elbv2', region_name='us-east-1') + cnf_conn = boto3.client("cloudformation", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") template = { "AWSTemplateFormatVersion": "2010-09-09", @@ -1824,9 +1796,7 @@ def test_redirect_action_listener_rule_cloudformation(): "Resources": { "testVPC": { "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - }, + "Properties": {"CidrBlock": "10.0.0.0/16"}, }, "subnet1": { "Type": "AWS::EC2::Subnet", @@ -1851,7 +1821,7 @@ def test_redirect_action_listener_rule_cloudformation(): "Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}], "Type": "application", "SecurityGroups": [], - } + }, }, "testListener": { "Type": "AWS::ElasticLoadBalancingV2::Listener", @@ -1859,93 +1829,110 @@ def test_redirect_action_listener_rule_cloudformation(): "LoadBalancerArn": {"Ref": "testLb"}, "Port": 80, "Protocol": "HTTP", - "DefaultActions": [{ - "Type": "redirect", - "RedirectConfig": { - "Port": "443", - "Protocol": "HTTPS", - "StatusCode": "HTTP_301", + "DefaultActions": [ + { + "Type": "redirect", + "RedirectConfig": { + "Port": "443", + "Protocol": "HTTPS", + "StatusCode": "HTTP_301", + }, } - }] - } - - } - } + ], + }, + }, + }, } template_json = json.dumps(template) cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json) - describe_load_balancers_response = elbv2_client.describe_load_balancers(Names=['my-lb',]) - describe_load_balancers_response['LoadBalancers'].should.have.length_of(1) - load_balancer_arn = describe_load_balancers_response['LoadBalancers'][0]['LoadBalancerArn'] + describe_load_balancers_response = elbv2_client.describe_load_balancers( + Names=["my-lb"] + ) + describe_load_balancers_response["LoadBalancers"].should.have.length_of(1) + load_balancer_arn = describe_load_balancers_response["LoadBalancers"][0][ + "LoadBalancerArn" + ] - describe_listeners_response = elbv2_client.describe_listeners(LoadBalancerArn=load_balancer_arn) + describe_listeners_response = elbv2_client.describe_listeners( + LoadBalancerArn=load_balancer_arn + ) - describe_listeners_response['Listeners'].should.have.length_of(1) - describe_listeners_response['Listeners'][0]['DefaultActions'].should.equal([{ - 'Type': 'redirect', - 'RedirectConfig': { - 'Port': '443', 'Protocol': 'HTTPS', 'StatusCode': 'HTTP_301', - } - },]) + describe_listeners_response["Listeners"].should.have.length_of(1) + describe_listeners_response["Listeners"][0]["DefaultActions"].should.equal( + [ + { + "Type": "redirect", + "RedirectConfig": { + "Port": "443", + "Protocol": "HTTPS", + "StatusCode": "HTTP_301", + }, + } + ] + ) @mock_elbv2 @mock_ec2 def test_cognito_action_listener_rule(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.128/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") action = { - 'Type': 'authenticate-cognito', - 'AuthenticateCognitoConfig': { - 'UserPoolArn': 'arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234', - 'UserPoolClientId': 'abcd1234abcd', - 'UserPoolDomain': 'testpool', - } + "Type": "authenticate-cognito", + "AuthenticateCognitoConfig": { + "UserPoolArn": "arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234", + "UserPoolClientId": "abcd1234abcd", + "UserPoolDomain": "testpool", + }, } - response = conn.create_listener(LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', - Port=80, - DefaultActions=[action]) + response = conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[action], + ) - listener = response.get('Listeners')[0] - listener.get('DefaultActions')[0].should.equal(action) - listener_arn = listener.get('ListenerArn') + listener = response.get("Listeners")[0] + listener.get("DefaultActions")[0].should.equal(action) + listener_arn = listener.get("ListenerArn") describe_rules_response = conn.describe_rules(ListenerArn=listener_arn) - describe_rules_response['Rules'][0]['Actions'][0].should.equal(action) + describe_rules_response["Rules"][0]["Actions"][0].should.equal(action) - describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn, ]) - describe_listener_actions = describe_listener_response['Listeners'][0]['DefaultActions'][0] + describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn]) + describe_listener_actions = describe_listener_response["Listeners"][0][ + "DefaultActions" + ][0] describe_listener_actions.should.equal(action) @mock_elbv2 @mock_cloudformation def test_cognito_action_listener_rule_cloudformation(): - cnf_conn = boto3.client('cloudformation', region_name='us-east-1') - elbv2_client = boto3.client('elbv2', region_name='us-east-1') + cnf_conn = boto3.client("cloudformation", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") template = { "AWSTemplateFormatVersion": "2010-09-09", @@ -1953,9 +1940,7 @@ def test_cognito_action_listener_rule_cloudformation(): "Resources": { "testVPC": { "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - }, + "Properties": {"CidrBlock": "10.0.0.0/16"}, }, "subnet1": { "Type": "AWS::EC2::Subnet", @@ -1980,7 +1965,7 @@ def test_cognito_action_listener_rule_cloudformation(): "Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}], "Type": "application", "SecurityGroups": [], - } + }, }, "testListener": { "Type": "AWS::ElasticLoadBalancingV2::Listener", @@ -1988,93 +1973,108 @@ def test_cognito_action_listener_rule_cloudformation(): "LoadBalancerArn": {"Ref": "testLb"}, "Port": 80, "Protocol": "HTTP", - "DefaultActions": [{ - "Type": "authenticate-cognito", - "AuthenticateCognitoConfig": { - 'UserPoolArn': 'arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234', - 'UserPoolClientId': 'abcd1234abcd', - 'UserPoolDomain': 'testpool', + "DefaultActions": [ + { + "Type": "authenticate-cognito", + "AuthenticateCognitoConfig": { + "UserPoolArn": "arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234", + "UserPoolClientId": "abcd1234abcd", + "UserPoolDomain": "testpool", + }, } - }] - } - - } - } + ], + }, + }, + }, } template_json = json.dumps(template) cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json) - describe_load_balancers_response = elbv2_client.describe_load_balancers(Names=['my-lb',]) - load_balancer_arn = describe_load_balancers_response['LoadBalancers'][0]['LoadBalancerArn'] - describe_listeners_response = elbv2_client.describe_listeners(LoadBalancerArn=load_balancer_arn) + describe_load_balancers_response = elbv2_client.describe_load_balancers( + Names=["my-lb"] + ) + load_balancer_arn = describe_load_balancers_response["LoadBalancers"][0][ + "LoadBalancerArn" + ] + describe_listeners_response = elbv2_client.describe_listeners( + LoadBalancerArn=load_balancer_arn + ) - describe_listeners_response['Listeners'].should.have.length_of(1) - describe_listeners_response['Listeners'][0]['DefaultActions'].should.equal([{ - 'Type': 'authenticate-cognito', - "AuthenticateCognitoConfig": { - 'UserPoolArn': 'arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234', - 'UserPoolClientId': 'abcd1234abcd', - 'UserPoolDomain': 'testpool', - } - },]) + describe_listeners_response["Listeners"].should.have.length_of(1) + describe_listeners_response["Listeners"][0]["DefaultActions"].should.equal( + [ + { + "Type": "authenticate-cognito", + "AuthenticateCognitoConfig": { + "UserPoolArn": "arn:aws:cognito-idp:us-east-1:123456789012:userpool/us-east-1_ABCD1234", + "UserPoolClientId": "abcd1234abcd", + "UserPoolDomain": "testpool", + }, + } + ] + ) @mock_elbv2 @mock_ec2 def test_fixed_response_action_listener_rule(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.128/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") action = { - 'Type': 'fixed-response', - 'FixedResponseConfig': { - 'ContentType': 'text/plain', - 'MessageBody': 'This page does not exist', - 'StatusCode': '404', - } + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + "StatusCode": "404", + }, } - response = conn.create_listener(LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', - Port=80, - DefaultActions=[action]) + response = conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[action], + ) - listener = response.get('Listeners')[0] - listener.get('DefaultActions')[0].should.equal(action) - listener_arn = listener.get('ListenerArn') + listener = response.get("Listeners")[0] + listener.get("DefaultActions")[0].should.equal(action) + listener_arn = listener.get("ListenerArn") describe_rules_response = conn.describe_rules(ListenerArn=listener_arn) - describe_rules_response['Rules'][0]['Actions'][0].should.equal(action) + describe_rules_response["Rules"][0]["Actions"][0].should.equal(action) - describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn, ]) - describe_listener_actions = describe_listener_response['Listeners'][0]['DefaultActions'][0] + describe_listener_response = conn.describe_listeners(ListenerArns=[listener_arn]) + describe_listener_actions = describe_listener_response["Listeners"][0][ + "DefaultActions" + ][0] describe_listener_actions.should.equal(action) @mock_elbv2 @mock_cloudformation def test_fixed_response_action_listener_rule_cloudformation(): - cnf_conn = boto3.client('cloudformation', region_name='us-east-1') - elbv2_client = boto3.client('elbv2', region_name='us-east-1') + cnf_conn = boto3.client("cloudformation", region_name="us-east-1") + elbv2_client = boto3.client("elbv2", region_name="us-east-1") template = { "AWSTemplateFormatVersion": "2010-09-09", @@ -2082,9 +2082,7 @@ def test_fixed_response_action_listener_rule_cloudformation(): "Resources": { "testVPC": { "Type": "AWS::EC2::VPC", - "Properties": { - "CidrBlock": "10.0.0.0/16", - }, + "Properties": {"CidrBlock": "10.0.0.0/16"}, }, "subnet1": { "Type": "AWS::EC2::Subnet", @@ -2109,7 +2107,7 @@ def test_fixed_response_action_listener_rule_cloudformation(): "Subnets": [{"Ref": "subnet1"}, {"Ref": "subnet2"}], "Type": "application", "SecurityGroups": [], - } + }, }, "testListener": { "Type": "AWS::ElasticLoadBalancingV2::Listener", @@ -2117,179 +2115,202 @@ def test_fixed_response_action_listener_rule_cloudformation(): "LoadBalancerArn": {"Ref": "testLb"}, "Port": 80, "Protocol": "HTTP", - "DefaultActions": [{ - "Type": "fixed-response", - "FixedResponseConfig": { - 'ContentType': 'text/plain', - 'MessageBody': 'This page does not exist', - 'StatusCode': '404', + "DefaultActions": [ + { + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + "StatusCode": "404", + }, } - }] - } - - } - } + ], + }, + }, + }, } template_json = json.dumps(template) cnf_conn.create_stack(StackName="test-stack", TemplateBody=template_json) - describe_load_balancers_response = elbv2_client.describe_load_balancers(Names=['my-lb',]) - load_balancer_arn = describe_load_balancers_response['LoadBalancers'][0]['LoadBalancerArn'] - describe_listeners_response = elbv2_client.describe_listeners(LoadBalancerArn=load_balancer_arn) + describe_load_balancers_response = elbv2_client.describe_load_balancers( + Names=["my-lb"] + ) + load_balancer_arn = describe_load_balancers_response["LoadBalancers"][0][ + "LoadBalancerArn" + ] + describe_listeners_response = elbv2_client.describe_listeners( + LoadBalancerArn=load_balancer_arn + ) - describe_listeners_response['Listeners'].should.have.length_of(1) - describe_listeners_response['Listeners'][0]['DefaultActions'].should.equal([{ - 'Type': 'fixed-response', - "FixedResponseConfig": { - 'ContentType': 'text/plain', - 'MessageBody': 'This page does not exist', - 'StatusCode': '404', - } - },]) + describe_listeners_response["Listeners"].should.have.length_of(1) + describe_listeners_response["Listeners"][0]["DefaultActions"].should.equal( + [ + { + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + "StatusCode": "404", + }, + } + ] + ) @mock_elbv2 @mock_ec2 def test_fixed_response_action_listener_rule_validates_status_code(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.128/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") missing_status_code_action = { - 'Type': 'fixed-response', - 'FixedResponseConfig': { - 'ContentType': 'text/plain', - 'MessageBody': 'This page does not exist', - } + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + }, } with assert_raises(ParamValidationError): - conn.create_listener(LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', - Port=80, - DefaultActions=[missing_status_code_action]) + conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[missing_status_code_action], + ) invalid_status_code_action = { - 'Type': 'fixed-response', - 'FixedResponseConfig': { - 'ContentType': 'text/plain', - 'MessageBody': 'This page does not exist', - 'StatusCode': '100' - } + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + "StatusCode": "100", + }, } @mock_elbv2 @mock_ec2 def test_fixed_response_action_listener_rule_validates_status_code(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.128/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") missing_status_code_action = { - 'Type': 'fixed-response', - 'FixedResponseConfig': { - 'ContentType': 'text/plain', - 'MessageBody': 'This page does not exist', - } + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + }, } with assert_raises(ParamValidationError): - conn.create_listener(LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', - Port=80, - DefaultActions=[missing_status_code_action]) + conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[missing_status_code_action], + ) invalid_status_code_action = { - 'Type': 'fixed-response', - 'FixedResponseConfig': { - 'ContentType': 'text/plain', - 'MessageBody': 'This page does not exist', - 'StatusCode': '100' - } + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "text/plain", + "MessageBody": "This page does not exist", + "StatusCode": "100", + }, } with assert_raises(ClientError) as invalid_status_code_exception: - conn.create_listener(LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', - Port=80, - DefaultActions=[invalid_status_code_action]) + conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[invalid_status_code_action], + ) - invalid_status_code_exception.exception.response['Error']['Code'].should.equal('ValidationError') + invalid_status_code_exception.exception.response["Error"]["Code"].should.equal( + "ValidationError" + ) @mock_elbv2 @mock_ec2 def test_fixed_response_action_listener_rule_validates_content_type(): - conn = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') + conn = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.128/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.128/26", AvailabilityZone="us-east-1b" + ) response = conn.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - Tags=[{'Key': 'key_name', 'Value': 'a_value'}]) - load_balancer_arn = response.get('LoadBalancers')[0].get('LoadBalancerArn') + Scheme="internal", + Tags=[{"Key": "key_name", "Value": "a_value"}], + ) + load_balancer_arn = response.get("LoadBalancers")[0].get("LoadBalancerArn") invalid_content_type_action = { - 'Type': 'fixed-response', - 'FixedResponseConfig': { - 'ContentType': 'Fake content type', - 'MessageBody': 'This page does not exist', - 'StatusCode': '200' - } + "Type": "fixed-response", + "FixedResponseConfig": { + "ContentType": "Fake content type", + "MessageBody": "This page does not exist", + "StatusCode": "200", + }, } with assert_raises(ClientError) as invalid_content_type_exception: - conn.create_listener(LoadBalancerArn=load_balancer_arn, - Protocol='HTTP', - Port=80, - DefaultActions=[invalid_content_type_action]) - invalid_content_type_exception.exception.response['Error']['Code'].should.equal('InvalidLoadBalancerAction') + conn.create_listener( + LoadBalancerArn=load_balancer_arn, + Protocol="HTTP", + Port=80, + DefaultActions=[invalid_content_type_action], + ) + invalid_content_type_exception.exception.response["Error"]["Code"].should.equal( + "InvalidLoadBalancerAction" + ) diff --git a/tests/test_elbv2/test_server.py b/tests/test_elbv2/test_server.py index ddd40a02d..7d2ce4b01 100644 --- a/tests/test_elbv2/test_server.py +++ b/tests/test_elbv2/test_server.py @@ -3,15 +3,15 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_elbv2_describe_load_balancers(): backend = server.create_backend_app("elbv2") test_client = backend.test_client() - res = test_client.get('/?Action=DescribeLoadBalancers&Version=2015-12-01') + res = test_client.get("/?Action=DescribeLoadBalancers&Version=2015-12-01") - res.data.should.contain(b'DescribeLoadBalancersResponse') + res.data.should.contain(b"DescribeLoadBalancersResponse") diff --git a/tests/test_emr/test_emr.py b/tests/test_emr/test_emr.py index 505c69b11..0dea23066 100644 --- a/tests/test_emr/test_emr.py +++ b/tests/test_emr/test_emr.py @@ -16,22 +16,22 @@ from tests.helpers import requires_boto_gte run_jobflow_args = dict( - job_flow_role='EMR_EC2_DefaultRole', + job_flow_role="EMR_EC2_DefaultRole", keep_alive=True, - log_uri='s3://some_bucket/jobflow_logs', - master_instance_type='c1.medium', - name='My jobflow', + log_uri="s3://some_bucket/jobflow_logs", + master_instance_type="c1.medium", + name="My jobflow", num_instances=2, - service_role='EMR_DefaultRole', - slave_instance_type='c1.medium', + service_role="EMR_DefaultRole", + slave_instance_type="c1.medium", ) input_instance_groups = [ - InstanceGroup(1, 'MASTER', 'c1.medium', 'ON_DEMAND', 'master'), - InstanceGroup(3, 'CORE', 'c1.medium', 'ON_DEMAND', 'core'), - InstanceGroup(6, 'TASK', 'c1.large', 'SPOT', 'task-1', '0.07'), - InstanceGroup(10, 'TASK', 'c1.xlarge', 'SPOT', 'task-2', '0.05'), + InstanceGroup(1, "MASTER", "c1.medium", "ON_DEMAND", "master"), + InstanceGroup(3, "CORE", "c1.medium", "ON_DEMAND", "core"), + InstanceGroup(6, "TASK", "c1.large", "SPOT", "task-1", "0.07"), + InstanceGroup(10, "TASK", "c1.xlarge", "SPOT", "task-2", "0.05"), ] @@ -39,72 +39,73 @@ input_instance_groups = [ def test_describe_cluster(): conn = boto.connect_emr() args = run_jobflow_args.copy() - args.update(dict( - api_params={ - 'Applications.member.1.Name': 'Spark', - 'Applications.member.1.Version': '2.4.2', - 'Configurations.member.1.Classification': 'yarn-site', - 'Configurations.member.1.Properties.entry.1.key': 'someproperty', - 'Configurations.member.1.Properties.entry.1.value': 'somevalue', - 'Configurations.member.1.Properties.entry.2.key': 'someotherproperty', - 'Configurations.member.1.Properties.entry.2.value': 'someothervalue', - 'Instances.EmrManagedMasterSecurityGroup': 'master-security-group', - 'Instances.Ec2SubnetId': 'subnet-8be41cec', - }, - availability_zone='us-east-2b', - ec2_keyname='mykey', - job_flow_role='EMR_EC2_DefaultRole', - keep_alive=False, - log_uri='s3://some_bucket/jobflow_logs', - name='My jobflow', - service_role='EMR_DefaultRole', - visible_to_all_users=True, - )) + args.update( + dict( + api_params={ + "Applications.member.1.Name": "Spark", + "Applications.member.1.Version": "2.4.2", + "Configurations.member.1.Classification": "yarn-site", + "Configurations.member.1.Properties.entry.1.key": "someproperty", + "Configurations.member.1.Properties.entry.1.value": "somevalue", + "Configurations.member.1.Properties.entry.2.key": "someotherproperty", + "Configurations.member.1.Properties.entry.2.value": "someothervalue", + "Instances.EmrManagedMasterSecurityGroup": "master-security-group", + "Instances.Ec2SubnetId": "subnet-8be41cec", + }, + availability_zone="us-east-2b", + ec2_keyname="mykey", + job_flow_role="EMR_EC2_DefaultRole", + keep_alive=False, + log_uri="s3://some_bucket/jobflow_logs", + name="My jobflow", + service_role="EMR_DefaultRole", + visible_to_all_users=True, + ) + ) cluster_id = conn.run_jobflow(**args) - input_tags = {'tag1': 'val1', 'tag2': 'val2'} + input_tags = {"tag1": "val1", "tag2": "val2"} conn.add_tags(cluster_id, input_tags) cluster = conn.describe_cluster(cluster_id) - cluster.applications[0].name.should.equal('Spark') - cluster.applications[0].version.should.equal('2.4.2') - cluster.autoterminate.should.equal('true') + cluster.applications[0].name.should.equal("Spark") + cluster.applications[0].version.should.equal("2.4.2") + cluster.autoterminate.should.equal("true") # configurations appear not be supplied as attributes? attrs = cluster.ec2instanceattributes # AdditionalMasterSecurityGroups # AdditionalSlaveSecurityGroups - attrs.ec2availabilityzone.should.equal(args['availability_zone']) - attrs.ec2keyname.should.equal(args['ec2_keyname']) - attrs.ec2subnetid.should.equal(args['api_params']['Instances.Ec2SubnetId']) + attrs.ec2availabilityzone.should.equal(args["availability_zone"]) + attrs.ec2keyname.should.equal(args["ec2_keyname"]) + attrs.ec2subnetid.should.equal(args["api_params"]["Instances.Ec2SubnetId"]) # EmrManagedMasterSecurityGroups # EmrManagedSlaveSecurityGroups - attrs.iaminstanceprofile.should.equal(args['job_flow_role']) + attrs.iaminstanceprofile.should.equal(args["job_flow_role"]) # ServiceAccessSecurityGroup cluster.id.should.equal(cluster_id) - cluster.loguri.should.equal(args['log_uri']) + cluster.loguri.should.equal(args["log_uri"]) cluster.masterpublicdnsname.should.be.a(six.string_types) - cluster.name.should.equal(args['name']) + cluster.name.should.equal(args["name"]) int(cluster.normalizedinstancehours).should.equal(0) # cluster.release_label - cluster.shouldnt.have.property('requestedamiversion') - cluster.runningamiversion.should.equal('1.0.0') + cluster.shouldnt.have.property("requestedamiversion") + cluster.runningamiversion.should.equal("1.0.0") # cluster.securityconfiguration - cluster.servicerole.should.equal(args['service_role']) + cluster.servicerole.should.equal(args["service_role"]) - cluster.status.state.should.equal('TERMINATED') + cluster.status.state.should.equal("TERMINATED") cluster.status.statechangereason.message.should.be.a(six.string_types) cluster.status.statechangereason.code.should.be.a(six.string_types) cluster.status.timeline.creationdatetime.should.be.a(six.string_types) # cluster.status.timeline.enddatetime.should.be.a(six.string_types) # cluster.status.timeline.readydatetime.should.be.a(six.string_types) - dict((item.key, item.value) - for item in cluster.tags).should.equal(input_tags) + dict((item.key, item.value) for item in cluster.tags).should.equal(input_tags) - cluster.terminationprotected.should.equal('false') - cluster.visibletoallusers.should.equal('true') + cluster.terminationprotected.should.equal("false") + cluster.visibletoallusers.should.equal("true") @mock_emr_deprecated @@ -114,13 +115,13 @@ def test_describe_jobflows(): expected = {} for idx in range(4): - cluster_name = 'cluster' + str(idx) - args['name'] = cluster_name + cluster_name = "cluster" + str(idx) + args["name"] = cluster_name cluster_id = conn.run_jobflow(**args) expected[cluster_id] = { - 'id': cluster_id, - 'name': cluster_name, - 'state': 'WAITING' + "id": cluster_id, + "name": cluster_name, + "state": "WAITING", } # need sleep since it appears the timestamp is always rounded to @@ -130,14 +131,14 @@ def test_describe_jobflows(): time.sleep(1) for idx in range(4, 6): - cluster_name = 'cluster' + str(idx) - args['name'] = cluster_name + cluster_name = "cluster" + str(idx) + args["name"] = cluster_name cluster_id = conn.run_jobflow(**args) conn.terminate_jobflow(cluster_id) expected[cluster_id] = { - 'id': cluster_id, - 'name': cluster_name, - 'state': 'TERMINATED' + "id": cluster_id, + "name": cluster_name, + "state": "TERMINATED", } jobs = conn.describe_jobflows() jobs.should.have.length_of(6) @@ -147,10 +148,10 @@ def test_describe_jobflows(): resp.should.have.length_of(1) resp[0].jobflowid.should.equal(cluster_id) - resp = conn.describe_jobflows(states=['WAITING']) + resp = conn.describe_jobflows(states=["WAITING"]) resp.should.have.length_of(4) for x in resp: - x.state.should.equal('WAITING') + x.state.should.equal("WAITING") resp = conn.describe_jobflows(created_before=timestamp) resp.should.have.length_of(4) @@ -163,83 +164,82 @@ def test_describe_jobflows(): def test_describe_jobflow(): conn = boto.connect_emr() args = run_jobflow_args.copy() - args.update(dict( - ami_version='3.8.1', - api_params={ - #'Applications.member.1.Name': 'Spark', - #'Applications.member.1.Version': '2.4.2', - #'Configurations.member.1.Classification': 'yarn-site', - #'Configurations.member.1.Properties.entry.1.key': 'someproperty', - #'Configurations.member.1.Properties.entry.1.value': 'somevalue', - #'Instances.EmrManagedMasterSecurityGroup': 'master-security-group', - 'Instances.Ec2SubnetId': 'subnet-8be41cec', - }, - ec2_keyname='mykey', - hadoop_version='2.4.0', - - name='My jobflow', - log_uri='s3://some_bucket/jobflow_logs', - keep_alive=True, - master_instance_type='c1.medium', - slave_instance_type='c1.medium', - num_instances=2, - - availability_zone='us-west-2b', - - job_flow_role='EMR_EC2_DefaultRole', - service_role='EMR_DefaultRole', - visible_to_all_users=True, - )) + args.update( + dict( + ami_version="3.8.1", + api_params={ + #'Applications.member.1.Name': 'Spark', + #'Applications.member.1.Version': '2.4.2', + #'Configurations.member.1.Classification': 'yarn-site', + #'Configurations.member.1.Properties.entry.1.key': 'someproperty', + #'Configurations.member.1.Properties.entry.1.value': 'somevalue', + #'Instances.EmrManagedMasterSecurityGroup': 'master-security-group', + "Instances.Ec2SubnetId": "subnet-8be41cec" + }, + ec2_keyname="mykey", + hadoop_version="2.4.0", + name="My jobflow", + log_uri="s3://some_bucket/jobflow_logs", + keep_alive=True, + master_instance_type="c1.medium", + slave_instance_type="c1.medium", + num_instances=2, + availability_zone="us-west-2b", + job_flow_role="EMR_EC2_DefaultRole", + service_role="EMR_DefaultRole", + visible_to_all_users=True, + ) + ) cluster_id = conn.run_jobflow(**args) jf = conn.describe_jobflow(cluster_id) - jf.amiversion.should.equal(args['ami_version']) + jf.amiversion.should.equal(args["ami_version"]) jf.bootstrapactions.should.equal(None) jf.creationdatetime.should.be.a(six.string_types) - jf.should.have.property('laststatechangereason') + jf.should.have.property("laststatechangereason") jf.readydatetime.should.be.a(six.string_types) jf.startdatetime.should.be.a(six.string_types) - jf.state.should.equal('WAITING') + jf.state.should.equal("WAITING") - jf.ec2keyname.should.equal(args['ec2_keyname']) + jf.ec2keyname.should.equal(args["ec2_keyname"]) # Ec2SubnetId - jf.hadoopversion.should.equal(args['hadoop_version']) + jf.hadoopversion.should.equal(args["hadoop_version"]) int(jf.instancecount).should.equal(2) for ig in jf.instancegroups: ig.creationdatetime.should.be.a(six.string_types) # ig.enddatetime.should.be.a(six.string_types) - ig.should.have.property('instancegroupid').being.a(six.string_types) + ig.should.have.property("instancegroupid").being.a(six.string_types) int(ig.instancerequestcount).should.equal(1) - ig.instancerole.should.be.within(['MASTER', 'CORE']) + ig.instancerole.should.be.within(["MASTER", "CORE"]) int(ig.instancerunningcount).should.equal(1) - ig.instancetype.should.equal('c1.medium') + ig.instancetype.should.equal("c1.medium") ig.laststatechangereason.should.be.a(six.string_types) - ig.market.should.equal('ON_DEMAND') + ig.market.should.equal("ON_DEMAND") ig.name.should.be.a(six.string_types) ig.readydatetime.should.be.a(six.string_types) ig.startdatetime.should.be.a(six.string_types) - ig.state.should.equal('RUNNING') + ig.state.should.equal("RUNNING") - jf.keepjobflowalivewhennosteps.should.equal('true') + jf.keepjobflowalivewhennosteps.should.equal("true") jf.masterinstanceid.should.be.a(six.string_types) - jf.masterinstancetype.should.equal(args['master_instance_type']) + jf.masterinstancetype.should.equal(args["master_instance_type"]) jf.masterpublicdnsname.should.be.a(six.string_types) int(jf.normalizedinstancehours).should.equal(0) - jf.availabilityzone.should.equal(args['availability_zone']) - jf.slaveinstancetype.should.equal(args['slave_instance_type']) - jf.terminationprotected.should.equal('false') + jf.availabilityzone.should.equal(args["availability_zone"]) + jf.slaveinstancetype.should.equal(args["slave_instance_type"]) + jf.terminationprotected.should.equal("false") jf.jobflowid.should.equal(cluster_id) # jf.jobflowrole.should.equal(args['job_flow_role']) - jf.loguri.should.equal(args['log_uri']) - jf.name.should.equal(args['name']) + jf.loguri.should.equal(args["log_uri"]) + jf.name.should.equal(args["name"]) # jf.servicerole.should.equal(args['service_role']) jf.steps.should.have.length_of(0) list(i.value for i in jf.supported_products).should.equal([]) - jf.visibletoallusers.should.equal('true') + jf.visibletoallusers.should.equal("true") @mock_emr_deprecated @@ -249,14 +249,14 @@ def test_list_clusters(): expected = {} for idx in range(40): - cluster_name = 'jobflow' + str(idx) - args['name'] = cluster_name + cluster_name = "jobflow" + str(idx) + args["name"] = cluster_name cluster_id = conn.run_jobflow(**args) expected[cluster_id] = { - 'id': cluster_id, - 'name': cluster_name, - 'normalizedinstancehours': '0', - 'state': 'WAITING' + "id": cluster_id, + "name": cluster_name, + "normalizedinstancehours": "0", + "state": "WAITING", } # need sleep since it appears the timestamp is always rounded to @@ -266,15 +266,15 @@ def test_list_clusters(): time.sleep(1) for idx in range(40, 70): - cluster_name = 'jobflow' + str(idx) - args['name'] = cluster_name + cluster_name = "jobflow" + str(idx) + args["name"] = cluster_name cluster_id = conn.run_jobflow(**args) conn.terminate_jobflow(cluster_id) expected[cluster_id] = { - 'id': cluster_id, - 'name': cluster_name, - 'normalizedinstancehours': '0', - 'state': 'TERMINATED' + "id": cluster_id, + "name": cluster_name, + "normalizedinstancehours": "0", + "state": "TERMINATED", } args = {} @@ -284,25 +284,24 @@ def test_list_clusters(): len(clusters).should.be.lower_than_or_equal_to(50) for x in clusters: y = expected[x.id] - x.id.should.equal(y['id']) - x.name.should.equal(y['name']) - x.normalizedinstancehours.should.equal( - y['normalizedinstancehours']) - x.status.state.should.equal(y['state']) + x.id.should.equal(y["id"]) + x.name.should.equal(y["name"]) + x.normalizedinstancehours.should.equal(y["normalizedinstancehours"]) + x.status.state.should.equal(y["state"]) x.status.timeline.creationdatetime.should.be.a(six.string_types) - if y['state'] == 'TERMINATED': + if y["state"] == "TERMINATED": x.status.timeline.enddatetime.should.be.a(six.string_types) else: - x.status.timeline.shouldnt.have.property('enddatetime') + x.status.timeline.shouldnt.have.property("enddatetime") x.status.timeline.readydatetime.should.be.a(six.string_types) - if not hasattr(resp, 'marker'): + if not hasattr(resp, "marker"): break - args = {'marker': resp.marker} + args = {"marker": resp.marker} - resp = conn.list_clusters(cluster_states=['TERMINATED']) + resp = conn.list_clusters(cluster_states=["TERMINATED"]) resp.clusters.should.have.length_of(30) for x in resp.clusters: - x.status.state.should.equal('TERMINATED') + x.status.state.should.equal("TERMINATED") resp = conn.list_clusters(created_before=timestamp) resp.clusters.should.have.length_of(40) @@ -317,13 +316,13 @@ def test_run_jobflow(): args = run_jobflow_args.copy() job_id = conn.run_jobflow(**args) job_flow = conn.describe_jobflow(job_id) - job_flow.state.should.equal('WAITING') + job_flow.state.should.equal("WAITING") job_flow.jobflowid.should.equal(job_id) - job_flow.name.should.equal(args['name']) - job_flow.masterinstancetype.should.equal(args['master_instance_type']) - job_flow.slaveinstancetype.should.equal(args['slave_instance_type']) - job_flow.loguri.should.equal(args['log_uri']) - job_flow.visibletoallusers.should.equal('false') + job_flow.name.should.equal(args["name"]) + job_flow.masterinstancetype.should.equal(args["master_instance_type"]) + job_flow.slaveinstancetype.should.equal(args["slave_instance_type"]) + job_flow.loguri.should.equal(args["log_uri"]) + job_flow.visibletoallusers.should.equal("false") int(job_flow.normalizedinstancehours).should.equal(0) job_flow.steps.should.have.length_of(0) @@ -331,16 +330,16 @@ def test_run_jobflow(): @mock_emr_deprecated def test_run_jobflow_in_multiple_regions(): regions = {} - for region in ['us-east-1', 'eu-west-1']: + for region in ["us-east-1", "eu-west-1"]: conn = boto.emr.connect_to_region(region) args = run_jobflow_args.copy() - args['name'] = region + args["name"] = region cluster_id = conn.run_jobflow(**args) - regions[region] = {'conn': conn, 'cluster_id': cluster_id} + regions[region] = {"conn": conn, "cluster_id": cluster_id} for region in regions.keys(): - conn = regions[region]['conn'] - jf = conn.describe_jobflow(regions[region]['cluster_id']) + conn = regions[region]["conn"] + jf = conn.describe_jobflow(regions[region]["cluster_id"]) jf.name.should.equal(region) @@ -357,10 +356,7 @@ def test_run_jobflow_with_new_params(): def test_run_jobflow_with_visible_to_all_users(): conn = boto.connect_emr() for expected in (True, False): - job_id = conn.run_jobflow( - visible_to_all_users=expected, - **run_jobflow_args - ) + job_id = conn.run_jobflow(visible_to_all_users=expected, **run_jobflow_args) job_flow = conn.describe_jobflow(job_id) job_flow.visibletoallusers.should.equal(str(expected).lower()) @@ -370,20 +366,19 @@ def test_run_jobflow_with_visible_to_all_users(): def test_run_jobflow_with_instance_groups(): input_groups = dict((g.name, g) for g in input_instance_groups) conn = boto.connect_emr() - job_id = conn.run_jobflow(instance_groups=input_instance_groups, - **run_jobflow_args) + job_id = conn.run_jobflow(instance_groups=input_instance_groups, **run_jobflow_args) job_flow = conn.describe_jobflow(job_id) int(job_flow.instancecount).should.equal( - sum(g.num_instances for g in input_instance_groups)) + sum(g.num_instances for g in input_instance_groups) + ) for instance_group in job_flow.instancegroups: expected = input_groups[instance_group.name] - instance_group.should.have.property('instancegroupid') - int(instance_group.instancerunningcount).should.equal( - expected.num_instances) + instance_group.should.have.property("instancegroupid") + int(instance_group.instancerunningcount).should.equal(expected.num_instances) instance_group.instancerole.should.equal(expected.role) instance_group.instancetype.should.equal(expected.type) instance_group.market.should.equal(expected.market) - if hasattr(expected, 'bidprice'): + if hasattr(expected, "bidprice"): instance_group.bidprice.should.equal(expected.bidprice) @@ -393,15 +388,15 @@ def test_set_termination_protection(): conn = boto.connect_emr() job_id = conn.run_jobflow(**run_jobflow_args) job_flow = conn.describe_jobflow(job_id) - job_flow.terminationprotected.should.equal('false') + job_flow.terminationprotected.should.equal("false") conn.set_termination_protection(job_id, True) job_flow = conn.describe_jobflow(job_id) - job_flow.terminationprotected.should.equal('true') + job_flow.terminationprotected.should.equal("true") conn.set_termination_protection(job_id, False) job_flow = conn.describe_jobflow(job_id) - job_flow.terminationprotected.should.equal('false') + job_flow.terminationprotected.should.equal("false") @requires_boto_gte("2.8") @@ -409,18 +404,18 @@ def test_set_termination_protection(): def test_set_visible_to_all_users(): conn = boto.connect_emr() args = run_jobflow_args.copy() - args['visible_to_all_users'] = False + args["visible_to_all_users"] = False job_id = conn.run_jobflow(**args) job_flow = conn.describe_jobflow(job_id) - job_flow.visibletoallusers.should.equal('false') + job_flow.visibletoallusers.should.equal("false") conn.set_visible_to_all_users(job_id, True) job_flow = conn.describe_jobflow(job_id) - job_flow.visibletoallusers.should.equal('true') + job_flow.visibletoallusers.should.equal("true") conn.set_visible_to_all_users(job_id, False) job_flow = conn.describe_jobflow(job_id) - job_flow.visibletoallusers.should.equal('false') + job_flow.visibletoallusers.should.equal("false") @mock_emr_deprecated @@ -428,32 +423,32 @@ def test_terminate_jobflow(): conn = boto.connect_emr() job_id = conn.run_jobflow(**run_jobflow_args) flow = conn.describe_jobflows()[0] - flow.state.should.equal('WAITING') + flow.state.should.equal("WAITING") conn.terminate_jobflow(job_id) flow = conn.describe_jobflows()[0] - flow.state.should.equal('TERMINATED') + flow.state.should.equal("TERMINATED") # testing multiple end points for each feature + @mock_emr_deprecated def test_bootstrap_actions(): bootstrap_actions = [ BootstrapAction( - name='bs1', - path='path/to/script', - bootstrap_action_args=['arg1', 'arg2&arg3']), + name="bs1", + path="path/to/script", + bootstrap_action_args=["arg1", "arg2&arg3"], + ), BootstrapAction( - name='bs2', - path='path/to/anotherscript', - bootstrap_action_args=[]) + name="bs2", path="path/to/anotherscript", bootstrap_action_args=[] + ), ] conn = boto.connect_emr() cluster_id = conn.run_jobflow( - bootstrap_actions=bootstrap_actions, - **run_jobflow_args + bootstrap_actions=bootstrap_actions, **run_jobflow_args ) jf = conn.describe_jobflow(cluster_id) @@ -476,9 +471,9 @@ def test_instance_groups(): conn = boto.connect_emr() args = run_jobflow_args.copy() - for key in ['master_instance_type', 'slave_instance_type', 'num_instances']: + for key in ["master_instance_type", "slave_instance_type", "num_instances"]: del args[key] - args['instance_groups'] = input_instance_groups[:2] + args["instance_groups"] = input_instance_groups[:2] job_id = conn.run_jobflow(**args) jf = conn.describe_jobflow(job_id) @@ -488,14 +483,15 @@ def test_instance_groups(): jf = conn.describe_jobflow(job_id) int(jf.instancecount).should.equal( - sum(g.num_instances for g in input_instance_groups)) + sum(g.num_instances for g in input_instance_groups) + ) for x in jf.instancegroups: y = input_groups[x.name] - if hasattr(y, 'bidprice'): + if hasattr(y, "bidprice"): x.bidprice.should.equal(y.bidprice) x.creationdatetime.should.be.a(six.string_types) # x.enddatetime.should.be.a(six.string_types) - x.should.have.property('instancegroupid') + x.should.have.property("instancegroupid") int(x.instancerequestcount).should.equal(y.num_instances) x.instancerole.should.equal(y.role) int(x.instancerunningcount).should.equal(y.num_instances) @@ -505,16 +501,16 @@ def test_instance_groups(): x.name.should.be.a(six.string_types) x.readydatetime.should.be.a(six.string_types) x.startdatetime.should.be.a(six.string_types) - x.state.should.equal('RUNNING') + x.state.should.equal("RUNNING") for x in conn.list_instance_groups(job_id).instancegroups: y = input_groups[x.name] - if hasattr(y, 'bidprice'): + if hasattr(y, "bidprice"): x.bidprice.should.equal(y.bidprice) # Configurations # EbsBlockDevices # EbsOptimized - x.should.have.property('id') + x.should.have.property("id") x.instancegrouptype.should.equal(y.role) x.instancetype.should.equal(y.type) x.market.should.equal(y.market) @@ -522,7 +518,7 @@ def test_instance_groups(): int(x.requestedinstancecount).should.equal(y.num_instances) int(x.runninginstancecount).should.equal(y.num_instances) # ShrinkPolicy - x.status.state.should.equal('RUNNING') + x.status.state.should.equal("RUNNING") x.status.statechangereason.code.should.be.a(six.string_types) x.status.statechangereason.message.should.be.a(six.string_types) x.status.timeline.creationdatetime.should.be.a(six.string_types) @@ -532,38 +528,38 @@ def test_instance_groups(): igs = dict((g.name, g) for g in jf.instancegroups) conn.modify_instance_groups( - [igs['task-1'].instancegroupid, igs['task-2'].instancegroupid], - [2, 3]) + [igs["task-1"].instancegroupid, igs["task-2"].instancegroupid], [2, 3] + ) jf = conn.describe_jobflow(job_id) int(jf.instancecount).should.equal(base_instance_count + 5) igs = dict((g.name, g) for g in jf.instancegroups) - int(igs['task-1'].instancerunningcount).should.equal(2) - int(igs['task-2'].instancerunningcount).should.equal(3) + int(igs["task-1"].instancerunningcount).should.equal(2) + int(igs["task-2"].instancerunningcount).should.equal(3) @mock_emr_deprecated def test_steps(): input_steps = [ StreamingStep( - name='My wordcount example', - mapper='s3n://elasticmapreduce/samples/wordcount/wordSplitter.py', - reducer='aggregate', - input='s3n://elasticmapreduce/samples/wordcount/input', - output='s3n://output_bucket/output/wordcount_output'), + name="My wordcount example", + mapper="s3n://elasticmapreduce/samples/wordcount/wordSplitter.py", + reducer="aggregate", + input="s3n://elasticmapreduce/samples/wordcount/input", + output="s3n://output_bucket/output/wordcount_output", + ), StreamingStep( - name='My wordcount example & co.', - mapper='s3n://elasticmapreduce/samples/wordcount/wordSplitter2.py', - reducer='aggregate', - input='s3n://elasticmapreduce/samples/wordcount/input2', - output='s3n://output_bucket/output/wordcount_output2') + name="My wordcount example & co.", + mapper="s3n://elasticmapreduce/samples/wordcount/wordSplitter2.py", + reducer="aggregate", + input="s3n://elasticmapreduce/samples/wordcount/input2", + output="s3n://output_bucket/output/wordcount_output2", + ), ] # TODO: implementation and test for cancel_steps conn = boto.connect_emr() - cluster_id = conn.run_jobflow( - steps=[input_steps[0]], - **run_jobflow_args) + cluster_id = conn.run_jobflow(steps=[input_steps[0]], **run_jobflow_args) jf = conn.describe_jobflow(cluster_id) jf.steps.should.have.length_of(1) @@ -573,18 +569,17 @@ def test_steps(): jf = conn.describe_jobflow(cluster_id) jf.steps.should.have.length_of(2) for step in jf.steps: - step.actiononfailure.should.equal('TERMINATE_JOB_FLOW') + step.actiononfailure.should.equal("TERMINATE_JOB_FLOW") list(arg.value for arg in step.args).should.have.length_of(8) step.creationdatetime.should.be.a(six.string_types) # step.enddatetime.should.be.a(six.string_types) - step.jar.should.equal( - '/home/hadoop/contrib/streaming/hadoop-streaming.jar') + step.jar.should.equal("/home/hadoop/contrib/streaming/hadoop-streaming.jar") step.laststatechangereason.should.be.a(six.string_types) - step.mainclass.should.equal('') + step.mainclass.should.equal("") step.name.should.be.a(six.string_types) # step.readydatetime.should.be.a(six.string_types) # step.startdatetime.should.be.a(six.string_types) - step.state.should.be.within(['STARTING', 'PENDING']) + step.state.should.be.within(["STARTING", "PENDING"]) expected = dict((s.name, s) for s in input_steps) @@ -592,52 +587,63 @@ def test_steps(): for x in steps: y = expected[x.name] # actiononfailure - list(arg.value for arg in x.config.args).should.equal([ - '-mapper', y.mapper, - '-reducer', y.reducer, - '-input', y.input, - '-output', y.output, - ]) - x.config.jar.should.equal( - '/home/hadoop/contrib/streaming/hadoop-streaming.jar') - x.config.mainclass.should.equal('') + list(arg.value for arg in x.config.args).should.equal( + [ + "-mapper", + y.mapper, + "-reducer", + y.reducer, + "-input", + y.input, + "-output", + y.output, + ] + ) + x.config.jar.should.equal("/home/hadoop/contrib/streaming/hadoop-streaming.jar") + x.config.mainclass.should.equal("") # properties - x.should.have.property('id').should.be.a(six.string_types) + x.should.have.property("id").should.be.a(six.string_types) x.name.should.equal(y.name) - x.status.state.should.be.within(['STARTING', 'PENDING']) + x.status.state.should.be.within(["STARTING", "PENDING"]) # x.status.statechangereason x.status.timeline.creationdatetime.should.be.a(six.string_types) # x.status.timeline.enddatetime.should.be.a(six.string_types) # x.status.timeline.startdatetime.should.be.a(six.string_types) x = conn.describe_step(cluster_id, x.id) - list(arg.value for arg in x.config.args).should.equal([ - '-mapper', y.mapper, - '-reducer', y.reducer, - '-input', y.input, - '-output', y.output, - ]) - x.config.jar.should.equal( - '/home/hadoop/contrib/streaming/hadoop-streaming.jar') - x.config.mainclass.should.equal('') + list(arg.value for arg in x.config.args).should.equal( + [ + "-mapper", + y.mapper, + "-reducer", + y.reducer, + "-input", + y.input, + "-output", + y.output, + ] + ) + x.config.jar.should.equal("/home/hadoop/contrib/streaming/hadoop-streaming.jar") + x.config.mainclass.should.equal("") # properties - x.should.have.property('id').should.be.a(six.string_types) + x.should.have.property("id").should.be.a(six.string_types) x.name.should.equal(y.name) - x.status.state.should.be.within(['STARTING', 'PENDING']) + x.status.state.should.be.within(["STARTING", "PENDING"]) # x.status.statechangereason x.status.timeline.creationdatetime.should.be.a(six.string_types) # x.status.timeline.enddatetime.should.be.a(six.string_types) # x.status.timeline.startdatetime.should.be.a(six.string_types) - @requires_boto_gte('2.39') + @requires_boto_gte("2.39") def test_list_steps_with_states(): # boto's list_steps prior to 2.39 has a bug that ignores # step_states argument. steps = conn.list_steps(cluster_id).steps step_id = steps[0].id - steps = conn.list_steps(cluster_id, step_states=['STARTING']).steps + steps = conn.list_steps(cluster_id, step_states=["STARTING"]).steps steps.should.have.length_of(1) steps[0].id.should.equal(step_id) + test_list_steps_with_states() diff --git a/tests/test_emr/test_emr_boto3.py b/tests/test_emr/test_emr_boto3.py index b9a5025d9..212444abf 100644 --- a/tests/test_emr/test_emr_boto3.py +++ b/tests/test_emr/test_emr_boto3.py @@ -16,158 +16,176 @@ from moto import mock_emr run_job_flow_args = dict( Instances={ - 'InstanceCount': 3, - 'KeepJobFlowAliveWhenNoSteps': True, - 'MasterInstanceType': 'c3.medium', - 'Placement': {'AvailabilityZone': 'us-east-1a'}, - 'SlaveInstanceType': 'c3.xlarge', + "InstanceCount": 3, + "KeepJobFlowAliveWhenNoSteps": True, + "MasterInstanceType": "c3.medium", + "Placement": {"AvailabilityZone": "us-east-1a"}, + "SlaveInstanceType": "c3.xlarge", }, - JobFlowRole='EMR_EC2_DefaultRole', - LogUri='s3://mybucket/log', - Name='cluster', - ServiceRole='EMR_DefaultRole', - VisibleToAllUsers=True) + JobFlowRole="EMR_EC2_DefaultRole", + LogUri="s3://mybucket/log", + Name="cluster", + ServiceRole="EMR_DefaultRole", + VisibleToAllUsers=True, +) input_instance_groups = [ - {'InstanceCount': 1, - 'InstanceRole': 'MASTER', - 'InstanceType': 'c1.medium', - 'Market': 'ON_DEMAND', - 'Name': 'master'}, - {'InstanceCount': 3, - 'InstanceRole': 'CORE', - 'InstanceType': 'c1.medium', - 'Market': 'ON_DEMAND', - 'Name': 'core'}, - {'InstanceCount': 6, - 'InstanceRole': 'TASK', - 'InstanceType': 'c1.large', - 'Market': 'SPOT', - 'Name': 'task-1', - 'BidPrice': '0.07'}, - {'InstanceCount': 10, - 'InstanceRole': 'TASK', - 'InstanceType': 'c1.xlarge', - 'Market': 'SPOT', - 'Name': 'task-2', - 'BidPrice': '0.05'}, + { + "InstanceCount": 1, + "InstanceRole": "MASTER", + "InstanceType": "c1.medium", + "Market": "ON_DEMAND", + "Name": "master", + }, + { + "InstanceCount": 3, + "InstanceRole": "CORE", + "InstanceType": "c1.medium", + "Market": "ON_DEMAND", + "Name": "core", + }, + { + "InstanceCount": 6, + "InstanceRole": "TASK", + "InstanceType": "c1.large", + "Market": "SPOT", + "Name": "task-1", + "BidPrice": "0.07", + }, + { + "InstanceCount": 10, + "InstanceRole": "TASK", + "InstanceType": "c1.xlarge", + "Market": "SPOT", + "Name": "task-2", + "BidPrice": "0.05", + }, ] @mock_emr def test_describe_cluster(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['Applications'] = [{'Name': 'Spark', 'Version': '2.4.2'}] - args['Configurations'] = [ - {'Classification': 'yarn-site', - 'Properties': {'someproperty': 'somevalue', - 'someotherproperty': 'someothervalue'}}, - {'Classification': 'nested-configs', - 'Properties': {}, - 'Configurations': [ - { - 'Classification': 'nested-config', - 'Properties': { - 'nested-property': 'nested-value' - } - } - ]} + args["Applications"] = [{"Name": "Spark", "Version": "2.4.2"}] + args["Configurations"] = [ + { + "Classification": "yarn-site", + "Properties": { + "someproperty": "somevalue", + "someotherproperty": "someothervalue", + }, + }, + { + "Classification": "nested-configs", + "Properties": {}, + "Configurations": [ + { + "Classification": "nested-config", + "Properties": {"nested-property": "nested-value"}, + } + ], + }, ] - args['Instances']['AdditionalMasterSecurityGroups'] = ['additional-master'] - args['Instances']['AdditionalSlaveSecurityGroups'] = ['additional-slave'] - args['Instances']['Ec2KeyName'] = 'mykey' - args['Instances']['Ec2SubnetId'] = 'subnet-8be41cec' - args['Instances']['EmrManagedMasterSecurityGroup'] = 'master-security-group' - args['Instances']['EmrManagedSlaveSecurityGroup'] = 'slave-security-group' - args['Instances']['KeepJobFlowAliveWhenNoSteps'] = False - args['Instances']['ServiceAccessSecurityGroup'] = 'service-access-security-group' - args['Tags'] = [{'Key': 'tag1', 'Value': 'val1'}, - {'Key': 'tag2', 'Value': 'val2'}] + args["Instances"]["AdditionalMasterSecurityGroups"] = ["additional-master"] + args["Instances"]["AdditionalSlaveSecurityGroups"] = ["additional-slave"] + args["Instances"]["Ec2KeyName"] = "mykey" + args["Instances"]["Ec2SubnetId"] = "subnet-8be41cec" + args["Instances"]["EmrManagedMasterSecurityGroup"] = "master-security-group" + args["Instances"]["EmrManagedSlaveSecurityGroup"] = "slave-security-group" + args["Instances"]["KeepJobFlowAliveWhenNoSteps"] = False + args["Instances"]["ServiceAccessSecurityGroup"] = "service-access-security-group" + args["Tags"] = [{"Key": "tag1", "Value": "val1"}, {"Key": "tag2", "Value": "val2"}] - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_id = client.run_job_flow(**args)["JobFlowId"] - cl = client.describe_cluster(ClusterId=cluster_id)['Cluster'] - cl['Applications'][0]['Name'].should.equal('Spark') - cl['Applications'][0]['Version'].should.equal('2.4.2') - cl['AutoTerminate'].should.equal(True) + cl = client.describe_cluster(ClusterId=cluster_id)["Cluster"] + cl["Applications"][0]["Name"].should.equal("Spark") + cl["Applications"][0]["Version"].should.equal("2.4.2") + cl["AutoTerminate"].should.equal(True) - config = cl['Configurations'][0] - config['Classification'].should.equal('yarn-site') - config['Properties'].should.equal(args['Configurations'][0]['Properties']) + config = cl["Configurations"][0] + config["Classification"].should.equal("yarn-site") + config["Properties"].should.equal(args["Configurations"][0]["Properties"]) - nested_config = cl['Configurations'][1] - nested_config['Classification'].should.equal('nested-configs') - nested_config['Properties'].should.equal(args['Configurations'][1]['Properties']) + nested_config = cl["Configurations"][1] + nested_config["Classification"].should.equal("nested-configs") + nested_config["Properties"].should.equal(args["Configurations"][1]["Properties"]) - attrs = cl['Ec2InstanceAttributes'] - attrs['AdditionalMasterSecurityGroups'].should.equal( - args['Instances']['AdditionalMasterSecurityGroups']) - attrs['AdditionalSlaveSecurityGroups'].should.equal( - args['Instances']['AdditionalSlaveSecurityGroups']) - attrs['Ec2AvailabilityZone'].should.equal('us-east-1a') - attrs['Ec2KeyName'].should.equal(args['Instances']['Ec2KeyName']) - attrs['Ec2SubnetId'].should.equal(args['Instances']['Ec2SubnetId']) - attrs['EmrManagedMasterSecurityGroup'].should.equal( - args['Instances']['EmrManagedMasterSecurityGroup']) - attrs['EmrManagedSlaveSecurityGroup'].should.equal( - args['Instances']['EmrManagedSlaveSecurityGroup']) - attrs['IamInstanceProfile'].should.equal(args['JobFlowRole']) - attrs['ServiceAccessSecurityGroup'].should.equal( - args['Instances']['ServiceAccessSecurityGroup']) - cl['Id'].should.equal(cluster_id) - cl['LogUri'].should.equal(args['LogUri']) - cl['MasterPublicDnsName'].should.be.a(six.string_types) - cl['Name'].should.equal(args['Name']) - cl['NormalizedInstanceHours'].should.equal(0) + attrs = cl["Ec2InstanceAttributes"] + attrs["AdditionalMasterSecurityGroups"].should.equal( + args["Instances"]["AdditionalMasterSecurityGroups"] + ) + attrs["AdditionalSlaveSecurityGroups"].should.equal( + args["Instances"]["AdditionalSlaveSecurityGroups"] + ) + attrs["Ec2AvailabilityZone"].should.equal("us-east-1a") + attrs["Ec2KeyName"].should.equal(args["Instances"]["Ec2KeyName"]) + attrs["Ec2SubnetId"].should.equal(args["Instances"]["Ec2SubnetId"]) + attrs["EmrManagedMasterSecurityGroup"].should.equal( + args["Instances"]["EmrManagedMasterSecurityGroup"] + ) + attrs["EmrManagedSlaveSecurityGroup"].should.equal( + args["Instances"]["EmrManagedSlaveSecurityGroup"] + ) + attrs["IamInstanceProfile"].should.equal(args["JobFlowRole"]) + attrs["ServiceAccessSecurityGroup"].should.equal( + args["Instances"]["ServiceAccessSecurityGroup"] + ) + cl["Id"].should.equal(cluster_id) + cl["LogUri"].should.equal(args["LogUri"]) + cl["MasterPublicDnsName"].should.be.a(six.string_types) + cl["Name"].should.equal(args["Name"]) + cl["NormalizedInstanceHours"].should.equal(0) # cl['ReleaseLabel'].should.equal('emr-5.0.0') - cl.shouldnt.have.key('RequestedAmiVersion') - cl['RunningAmiVersion'].should.equal('1.0.0') + cl.shouldnt.have.key("RequestedAmiVersion") + cl["RunningAmiVersion"].should.equal("1.0.0") # cl['SecurityConfiguration'].should.be.a(six.string_types) - cl['ServiceRole'].should.equal(args['ServiceRole']) + cl["ServiceRole"].should.equal(args["ServiceRole"]) - status = cl['Status'] - status['State'].should.equal('TERMINATED') + status = cl["Status"] + status["State"].should.equal("TERMINATED") # cluster['Status']['StateChangeReason'] - status['Timeline']['CreationDateTime'].should.be.a('datetime.datetime') + status["Timeline"]["CreationDateTime"].should.be.a("datetime.datetime") # status['Timeline']['EndDateTime'].should.equal(datetime(2014, 1, 24, 2, 19, 46, tzinfo=pytz.utc)) - status['Timeline']['ReadyDateTime'].should.be.a('datetime.datetime') + status["Timeline"]["ReadyDateTime"].should.be.a("datetime.datetime") - dict((t['Key'], t['Value']) for t in cl['Tags']).should.equal( - dict((t['Key'], t['Value']) for t in args['Tags'])) + dict((t["Key"], t["Value"]) for t in cl["Tags"]).should.equal( + dict((t["Key"], t["Value"]) for t in args["Tags"]) + ) - cl['TerminationProtected'].should.equal(False) - cl['VisibleToAllUsers'].should.equal(True) + cl["TerminationProtected"].should.equal(False) + cl["VisibleToAllUsers"].should.equal(True) @mock_emr def test_describe_cluster_not_found(): - conn = boto3.client('emr', region_name='us-east-1') + conn = boto3.client("emr", region_name="us-east-1") raised = False try: - cluster = conn.describe_cluster(ClusterId='DummyId') + cluster = conn.describe_cluster(ClusterId="DummyId") except ClientError as e: - if e.response['Error']['Code'] == "ResourceNotFoundException": + if e.response["Error"]["Code"] == "ResourceNotFoundException": raised = True raised.should.equal(True) @mock_emr def test_describe_job_flows(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) expected = {} for idx in range(4): - cluster_name = 'cluster' + str(idx) - args['Name'] = cluster_name - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_name = "cluster" + str(idx) + args["Name"] = cluster_name + cluster_id = client.run_job_flow(**args)["JobFlowId"] expected[cluster_id] = { - 'Id': cluster_id, - 'Name': cluster_name, - 'State': 'WAITING' + "Id": cluster_id, + "Name": cluster_name, + "State": "WAITING", } # need sleep since it appears the timestamp is always rounded to @@ -177,117 +195,119 @@ def test_describe_job_flows(): time.sleep(1) for idx in range(4, 6): - cluster_name = 'cluster' + str(idx) - args['Name'] = cluster_name - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_name = "cluster" + str(idx) + args["Name"] = cluster_name + cluster_id = client.run_job_flow(**args)["JobFlowId"] client.terminate_job_flows(JobFlowIds=[cluster_id]) expected[cluster_id] = { - 'Id': cluster_id, - 'Name': cluster_name, - 'State': 'TERMINATED' + "Id": cluster_id, + "Name": cluster_name, + "State": "TERMINATED", } resp = client.describe_job_flows() - resp['JobFlows'].should.have.length_of(6) + resp["JobFlows"].should.have.length_of(6) for cluster_id, y in expected.items(): resp = client.describe_job_flows(JobFlowIds=[cluster_id]) - resp['JobFlows'].should.have.length_of(1) - resp['JobFlows'][0]['JobFlowId'].should.equal(cluster_id) + resp["JobFlows"].should.have.length_of(1) + resp["JobFlows"][0]["JobFlowId"].should.equal(cluster_id) - resp = client.describe_job_flows(JobFlowStates=['WAITING']) - resp['JobFlows'].should.have.length_of(4) - for x in resp['JobFlows']: - x['ExecutionStatusDetail']['State'].should.equal('WAITING') + resp = client.describe_job_flows(JobFlowStates=["WAITING"]) + resp["JobFlows"].should.have.length_of(4) + for x in resp["JobFlows"]: + x["ExecutionStatusDetail"]["State"].should.equal("WAITING") resp = client.describe_job_flows(CreatedBefore=timestamp) - resp['JobFlows'].should.have.length_of(4) + resp["JobFlows"].should.have.length_of(4) resp = client.describe_job_flows(CreatedAfter=timestamp) - resp['JobFlows'].should.have.length_of(2) + resp["JobFlows"].should.have.length_of(2) @mock_emr def test_describe_job_flow(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['AmiVersion'] = '3.8.1' - args['Instances'].update( - {'Ec2KeyName': 'ec2keyname', - 'Ec2SubnetId': 'subnet-8be41cec', - 'HadoopVersion': '2.4.0'}) - args['VisibleToAllUsers'] = True + args["AmiVersion"] = "3.8.1" + args["Instances"].update( + { + "Ec2KeyName": "ec2keyname", + "Ec2SubnetId": "subnet-8be41cec", + "HadoopVersion": "2.4.0", + } + ) + args["VisibleToAllUsers"] = True - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_id = client.run_job_flow(**args)["JobFlowId"] - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] - jf['AmiVersion'].should.equal(args['AmiVersion']) - jf.shouldnt.have.key('BootstrapActions') - esd = jf['ExecutionStatusDetail'] - esd['CreationDateTime'].should.be.a('datetime.datetime') + jf["AmiVersion"].should.equal(args["AmiVersion"]) + jf.shouldnt.have.key("BootstrapActions") + esd = jf["ExecutionStatusDetail"] + esd["CreationDateTime"].should.be.a("datetime.datetime") # esd['EndDateTime'].should.be.a('datetime.datetime') # esd['LastStateChangeReason'].should.be.a(six.string_types) - esd['ReadyDateTime'].should.be.a('datetime.datetime') - esd['StartDateTime'].should.be.a('datetime.datetime') - esd['State'].should.equal('WAITING') - attrs = jf['Instances'] - attrs['Ec2KeyName'].should.equal(args['Instances']['Ec2KeyName']) - attrs['Ec2SubnetId'].should.equal(args['Instances']['Ec2SubnetId']) - attrs['HadoopVersion'].should.equal(args['Instances']['HadoopVersion']) - attrs['InstanceCount'].should.equal(args['Instances']['InstanceCount']) - for ig in attrs['InstanceGroups']: + esd["ReadyDateTime"].should.be.a("datetime.datetime") + esd["StartDateTime"].should.be.a("datetime.datetime") + esd["State"].should.equal("WAITING") + attrs = jf["Instances"] + attrs["Ec2KeyName"].should.equal(args["Instances"]["Ec2KeyName"]) + attrs["Ec2SubnetId"].should.equal(args["Instances"]["Ec2SubnetId"]) + attrs["HadoopVersion"].should.equal(args["Instances"]["HadoopVersion"]) + attrs["InstanceCount"].should.equal(args["Instances"]["InstanceCount"]) + for ig in attrs["InstanceGroups"]: # ig['BidPrice'] - ig['CreationDateTime'].should.be.a('datetime.datetime') + ig["CreationDateTime"].should.be.a("datetime.datetime") # ig['EndDateTime'].should.be.a('datetime.datetime') - ig['InstanceGroupId'].should.be.a(six.string_types) - ig['InstanceRequestCount'].should.be.a(int) - ig['InstanceRole'].should.be.within(['MASTER', 'CORE']) - ig['InstanceRunningCount'].should.be.a(int) - ig['InstanceType'].should.be.within(['c3.medium', 'c3.xlarge']) + ig["InstanceGroupId"].should.be.a(six.string_types) + ig["InstanceRequestCount"].should.be.a(int) + ig["InstanceRole"].should.be.within(["MASTER", "CORE"]) + ig["InstanceRunningCount"].should.be.a(int) + ig["InstanceType"].should.be.within(["c3.medium", "c3.xlarge"]) # ig['LastStateChangeReason'].should.be.a(six.string_types) - ig['Market'].should.equal('ON_DEMAND') - ig['Name'].should.be.a(six.string_types) - ig['ReadyDateTime'].should.be.a('datetime.datetime') - ig['StartDateTime'].should.be.a('datetime.datetime') - ig['State'].should.equal('RUNNING') - attrs['KeepJobFlowAliveWhenNoSteps'].should.equal(True) + ig["Market"].should.equal("ON_DEMAND") + ig["Name"].should.be.a(six.string_types) + ig["ReadyDateTime"].should.be.a("datetime.datetime") + ig["StartDateTime"].should.be.a("datetime.datetime") + ig["State"].should.equal("RUNNING") + attrs["KeepJobFlowAliveWhenNoSteps"].should.equal(True) # attrs['MasterInstanceId'].should.be.a(six.string_types) - attrs['MasterInstanceType'].should.equal( - args['Instances']['MasterInstanceType']) - attrs['MasterPublicDnsName'].should.be.a(six.string_types) - attrs['NormalizedInstanceHours'].should.equal(0) - attrs['Placement']['AvailabilityZone'].should.equal( - args['Instances']['Placement']['AvailabilityZone']) - attrs['SlaveInstanceType'].should.equal( - args['Instances']['SlaveInstanceType']) - attrs['TerminationProtected'].should.equal(False) - jf['JobFlowId'].should.equal(cluster_id) - jf['JobFlowRole'].should.equal(args['JobFlowRole']) - jf['LogUri'].should.equal(args['LogUri']) - jf['Name'].should.equal(args['Name']) - jf['ServiceRole'].should.equal(args['ServiceRole']) - jf['Steps'].should.equal([]) - jf['SupportedProducts'].should.equal([]) - jf['VisibleToAllUsers'].should.equal(True) + attrs["MasterInstanceType"].should.equal(args["Instances"]["MasterInstanceType"]) + attrs["MasterPublicDnsName"].should.be.a(six.string_types) + attrs["NormalizedInstanceHours"].should.equal(0) + attrs["Placement"]["AvailabilityZone"].should.equal( + args["Instances"]["Placement"]["AvailabilityZone"] + ) + attrs["SlaveInstanceType"].should.equal(args["Instances"]["SlaveInstanceType"]) + attrs["TerminationProtected"].should.equal(False) + jf["JobFlowId"].should.equal(cluster_id) + jf["JobFlowRole"].should.equal(args["JobFlowRole"]) + jf["LogUri"].should.equal(args["LogUri"]) + jf["Name"].should.equal(args["Name"]) + jf["ServiceRole"].should.equal(args["ServiceRole"]) + jf["Steps"].should.equal([]) + jf["SupportedProducts"].should.equal([]) + jf["VisibleToAllUsers"].should.equal(True) @mock_emr def test_list_clusters(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) expected = {} for idx in range(40): - cluster_name = 'jobflow' + str(idx) - args['Name'] = cluster_name - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_name = "jobflow" + str(idx) + args["Name"] = cluster_name + cluster_id = client.run_job_flow(**args)["JobFlowId"] expected[cluster_id] = { - 'Id': cluster_id, - 'Name': cluster_name, - 'NormalizedInstanceHours': 0, - 'State': 'WAITING' + "Id": cluster_id, + "Name": cluster_name, + "NormalizedInstanceHours": 0, + "State": "WAITING", } # need sleep since it appears the timestamp is always rounded to @@ -297,465 +317,484 @@ def test_list_clusters(): time.sleep(1) for idx in range(40, 70): - cluster_name = 'jobflow' + str(idx) - args['Name'] = cluster_name - cluster_id = client.run_job_flow(**args)['JobFlowId'] + cluster_name = "jobflow" + str(idx) + args["Name"] = cluster_name + cluster_id = client.run_job_flow(**args)["JobFlowId"] client.terminate_job_flows(JobFlowIds=[cluster_id]) expected[cluster_id] = { - 'Id': cluster_id, - 'Name': cluster_name, - 'NormalizedInstanceHours': 0, - 'State': 'TERMINATED' + "Id": cluster_id, + "Name": cluster_name, + "NormalizedInstanceHours": 0, + "State": "TERMINATED", } args = {} while 1: resp = client.list_clusters(**args) - clusters = resp['Clusters'] + clusters = resp["Clusters"] len(clusters).should.be.lower_than_or_equal_to(50) for x in clusters: - y = expected[x['Id']] - x['Id'].should.equal(y['Id']) - x['Name'].should.equal(y['Name']) - x['NormalizedInstanceHours'].should.equal( - y['NormalizedInstanceHours']) - x['Status']['State'].should.equal(y['State']) - x['Status']['Timeline'][ - 'CreationDateTime'].should.be.a('datetime.datetime') - if y['State'] == 'TERMINATED': - x['Status']['Timeline'][ - 'EndDateTime'].should.be.a('datetime.datetime') + y = expected[x["Id"]] + x["Id"].should.equal(y["Id"]) + x["Name"].should.equal(y["Name"]) + x["NormalizedInstanceHours"].should.equal(y["NormalizedInstanceHours"]) + x["Status"]["State"].should.equal(y["State"]) + x["Status"]["Timeline"]["CreationDateTime"].should.be.a("datetime.datetime") + if y["State"] == "TERMINATED": + x["Status"]["Timeline"]["EndDateTime"].should.be.a("datetime.datetime") else: - x['Status']['Timeline'].shouldnt.have.key('EndDateTime') - x['Status']['Timeline'][ - 'ReadyDateTime'].should.be.a('datetime.datetime') - marker = resp.get('Marker') + x["Status"]["Timeline"].shouldnt.have.key("EndDateTime") + x["Status"]["Timeline"]["ReadyDateTime"].should.be.a("datetime.datetime") + marker = resp.get("Marker") if marker is None: break - args = {'Marker': marker} + args = {"Marker": marker} - resp = client.list_clusters(ClusterStates=['TERMINATED']) - resp['Clusters'].should.have.length_of(30) - for x in resp['Clusters']: - x['Status']['State'].should.equal('TERMINATED') + resp = client.list_clusters(ClusterStates=["TERMINATED"]) + resp["Clusters"].should.have.length_of(30) + for x in resp["Clusters"]: + x["Status"]["State"].should.equal("TERMINATED") resp = client.list_clusters(CreatedBefore=timestamp) - resp['Clusters'].should.have.length_of(40) + resp["Clusters"].should.have.length_of(40) resp = client.list_clusters(CreatedAfter=timestamp) - resp['Clusters'].should.have.length_of(30) + resp["Clusters"].should.have.length_of(30) @mock_emr def test_run_job_flow(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - cluster_id = client.run_job_flow(**args)['JobFlowId'] - resp = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - resp['ExecutionStatusDetail']['State'].should.equal('WAITING') - resp['JobFlowId'].should.equal(cluster_id) - resp['Name'].should.equal(args['Name']) - resp['Instances']['MasterInstanceType'].should.equal( - args['Instances']['MasterInstanceType']) - resp['Instances']['SlaveInstanceType'].should.equal( - args['Instances']['SlaveInstanceType']) - resp['LogUri'].should.equal(args['LogUri']) - resp['VisibleToAllUsers'].should.equal(args['VisibleToAllUsers']) - resp['Instances']['NormalizedInstanceHours'].should.equal(0) - resp['Steps'].should.equal([]) + cluster_id = client.run_job_flow(**args)["JobFlowId"] + resp = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + resp["ExecutionStatusDetail"]["State"].should.equal("WAITING") + resp["JobFlowId"].should.equal(cluster_id) + resp["Name"].should.equal(args["Name"]) + resp["Instances"]["MasterInstanceType"].should.equal( + args["Instances"]["MasterInstanceType"] + ) + resp["Instances"]["SlaveInstanceType"].should.equal( + args["Instances"]["SlaveInstanceType"] + ) + resp["LogUri"].should.equal(args["LogUri"]) + resp["VisibleToAllUsers"].should.equal(args["VisibleToAllUsers"]) + resp["Instances"]["NormalizedInstanceHours"].should.equal(0) + resp["Steps"].should.equal([]) @mock_emr def test_run_job_flow_with_invalid_params(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") with assert_raises(ClientError) as ex: # cannot set both AmiVersion and ReleaseLabel args = deepcopy(run_job_flow_args) - args['AmiVersion'] = '2.4' - args['ReleaseLabel'] = 'emr-5.0.0' + args["AmiVersion"] = "2.4" + args["ReleaseLabel"] = "emr-5.0.0" client.run_job_flow(**args) - ex.exception.response['Error']['Code'].should.equal('ValidationException') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") @mock_emr def test_run_job_flow_in_multiple_regions(): regions = {} - for region in ['us-east-1', 'eu-west-1']: - client = boto3.client('emr', region_name=region) + for region in ["us-east-1", "eu-west-1"]: + client = boto3.client("emr", region_name=region) args = deepcopy(run_job_flow_args) - args['Name'] = region - cluster_id = client.run_job_flow(**args)['JobFlowId'] - regions[region] = {'client': client, 'cluster_id': cluster_id} + args["Name"] = region + cluster_id = client.run_job_flow(**args)["JobFlowId"] + regions[region] = {"client": client, "cluster_id": cluster_id} for region in regions.keys(): - client = regions[region]['client'] - resp = client.describe_cluster(ClusterId=regions[region]['cluster_id']) - resp['Cluster']['Name'].should.equal(region) + client = regions[region]["client"] + resp = client.describe_cluster(ClusterId=regions[region]["cluster_id"]) + resp["Cluster"]["Name"].should.equal(region) @mock_emr def test_run_job_flow_with_new_params(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") resp = client.run_job_flow(**run_job_flow_args) - resp.should.have.key('JobFlowId') + resp.should.have.key("JobFlowId") @mock_emr def test_run_job_flow_with_visible_to_all_users(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") for expected in (True, False): args = deepcopy(run_job_flow_args) - args['VisibleToAllUsers'] = expected + args["VisibleToAllUsers"] = expected resp = client.run_job_flow(**args) - cluster_id = resp['JobFlowId'] + cluster_id = resp["JobFlowId"] resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['VisibleToAllUsers'].should.equal(expected) + resp["Cluster"]["VisibleToAllUsers"].should.equal(expected) @mock_emr def test_run_job_flow_with_instance_groups(): - input_groups = dict((g['Name'], g) for g in input_instance_groups) - client = boto3.client('emr', region_name='us-east-1') + input_groups = dict((g["Name"], g) for g in input_instance_groups) + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['Instances'] = {'InstanceGroups': input_instance_groups} - cluster_id = client.run_job_flow(**args)['JobFlowId'] - groups = client.list_instance_groups(ClusterId=cluster_id)[ - 'InstanceGroups'] + args["Instances"] = {"InstanceGroups": input_instance_groups} + cluster_id = client.run_job_flow(**args)["JobFlowId"] + groups = client.list_instance_groups(ClusterId=cluster_id)["InstanceGroups"] for x in groups: - y = input_groups[x['Name']] - x.should.have.key('Id') - x['RequestedInstanceCount'].should.equal(y['InstanceCount']) - x['InstanceGroupType'].should.equal(y['InstanceRole']) - x['InstanceType'].should.equal(y['InstanceType']) - x['Market'].should.equal(y['Market']) - if 'BidPrice' in y: - x['BidPrice'].should.equal(y['BidPrice']) + y = input_groups[x["Name"]] + x.should.have.key("Id") + x["RequestedInstanceCount"].should.equal(y["InstanceCount"]) + x["InstanceGroupType"].should.equal(y["InstanceRole"]) + x["InstanceType"].should.equal(y["InstanceType"]) + x["Market"].should.equal(y["Market"]) + if "BidPrice" in y: + x["BidPrice"].should.equal(y["BidPrice"]) @mock_emr def test_run_job_flow_with_custom_ami(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") with assert_raises(ClientError) as ex: # CustomAmiId available in Amazon EMR 5.7.0 and later args = deepcopy(run_job_flow_args) - args['CustomAmiId'] = 'MyEmrCustomId' - args['ReleaseLabel'] = 'emr-5.6.0' + args["CustomAmiId"] = "MyEmrCustomId" + args["ReleaseLabel"] = "emr-5.6.0" client.run_job_flow(**args) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.equal('Custom AMI is not allowed') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal("Custom AMI is not allowed") with assert_raises(ClientError) as ex: args = deepcopy(run_job_flow_args) - args['CustomAmiId'] = 'MyEmrCustomId' - args['AmiVersion'] = '3.8.1' + args["CustomAmiId"] = "MyEmrCustomId" + args["AmiVersion"] = "3.8.1" client.run_job_flow(**args) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.equal( - 'Custom AMI is not supported in this version of EMR') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.equal( + "Custom AMI is not supported in this version of EMR" + ) with assert_raises(ClientError) as ex: # AMI version and release label exception raises before CustomAmi exception args = deepcopy(run_job_flow_args) - args['CustomAmiId'] = 'MyEmrCustomId' - args['ReleaseLabel'] = 'emr-5.6.0' - args['AmiVersion'] = '3.8.1' + args["CustomAmiId"] = "MyEmrCustomId" + args["ReleaseLabel"] = "emr-5.6.0" + args["AmiVersion"] = "3.8.1" client.run_job_flow(**args) - ex.exception.response['Error']['Code'].should.equal('ValidationException') - ex.exception.response['Error']['Message'].should.contain( - 'Only one AMI version and release label may be specified.') + ex.exception.response["Error"]["Code"].should.equal("ValidationException") + ex.exception.response["Error"]["Message"].should.contain( + "Only one AMI version and release label may be specified." + ) args = deepcopy(run_job_flow_args) - args['CustomAmiId'] = 'MyEmrCustomAmi' - args['ReleaseLabel'] = 'emr-5.7.0' - cluster_id = client.run_job_flow(**args)['JobFlowId'] + args["CustomAmiId"] = "MyEmrCustomAmi" + args["ReleaseLabel"] = "emr-5.7.0" + cluster_id = client.run_job_flow(**args)["JobFlowId"] resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['CustomAmiId'].should.equal('MyEmrCustomAmi') + resp["Cluster"]["CustomAmiId"].should.equal("MyEmrCustomAmi") @mock_emr def test_set_termination_protection(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['Instances']['TerminationProtected'] = False + args["Instances"]["TerminationProtected"] = False resp = client.run_job_flow(**args) - cluster_id = resp['JobFlowId'] + cluster_id = resp["JobFlowId"] resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['TerminationProtected'].should.equal(False) + resp["Cluster"]["TerminationProtected"].should.equal(False) for expected in (True, False): - resp = client.set_termination_protection(JobFlowIds=[cluster_id], - TerminationProtected=expected) + resp = client.set_termination_protection( + JobFlowIds=[cluster_id], TerminationProtected=expected + ) resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['TerminationProtected'].should.equal(expected) + resp["Cluster"]["TerminationProtected"].should.equal(expected) @mock_emr def test_set_visible_to_all_users(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['VisibleToAllUsers'] = False + args["VisibleToAllUsers"] = False resp = client.run_job_flow(**args) - cluster_id = resp['JobFlowId'] + cluster_id = resp["JobFlowId"] resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['VisibleToAllUsers'].should.equal(False) + resp["Cluster"]["VisibleToAllUsers"].should.equal(False) for expected in (True, False): - resp = client.set_visible_to_all_users(JobFlowIds=[cluster_id], - VisibleToAllUsers=expected) + resp = client.set_visible_to_all_users( + JobFlowIds=[cluster_id], VisibleToAllUsers=expected + ) resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['VisibleToAllUsers'].should.equal(expected) + resp["Cluster"]["VisibleToAllUsers"].should.equal(expected) @mock_emr def test_terminate_job_flows(): - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") resp = client.run_job_flow(**run_job_flow_args) - cluster_id = resp['JobFlowId'] + cluster_id = resp["JobFlowId"] resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['Status']['State'].should.equal('WAITING') + resp["Cluster"]["Status"]["State"].should.equal("WAITING") resp = client.terminate_job_flows(JobFlowIds=[cluster_id]) resp = client.describe_cluster(ClusterId=cluster_id) - resp['Cluster']['Status']['State'].should.equal('TERMINATED') + resp["Cluster"]["Status"]["State"].should.equal("TERMINATED") # testing multiple end points for each feature + @mock_emr def test_bootstrap_actions(): bootstrap_actions = [ - {'Name': 'bs1', - 'ScriptBootstrapAction': { - 'Args': ['arg1', 'arg2'], - 'Path': 's3://path/to/script'}}, - {'Name': 'bs2', - 'ScriptBootstrapAction': { - 'Args': [], - 'Path': 's3://path/to/anotherscript'}} + { + "Name": "bs1", + "ScriptBootstrapAction": { + "Args": ["arg1", "arg2"], + "Path": "s3://path/to/script", + }, + }, + { + "Name": "bs2", + "ScriptBootstrapAction": {"Args": [], "Path": "s3://path/to/anotherscript"}, + }, ] - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['BootstrapActions'] = bootstrap_actions - cluster_id = client.run_job_flow(**args)['JobFlowId'] + args["BootstrapActions"] = bootstrap_actions + cluster_id = client.run_job_flow(**args)["JobFlowId"] - cl = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - for x, y in zip(cl['BootstrapActions'], bootstrap_actions): - x['BootstrapActionConfig'].should.equal(y) + cl = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + for x, y in zip(cl["BootstrapActions"], bootstrap_actions): + x["BootstrapActionConfig"].should.equal(y) resp = client.list_bootstrap_actions(ClusterId=cluster_id) - for x, y in zip(resp['BootstrapActions'], bootstrap_actions): - x['Name'].should.equal(y['Name']) - if 'Args' in y['ScriptBootstrapAction']: - x['Args'].should.equal(y['ScriptBootstrapAction']['Args']) - x['ScriptPath'].should.equal(y['ScriptBootstrapAction']['Path']) + for x, y in zip(resp["BootstrapActions"], bootstrap_actions): + x["Name"].should.equal(y["Name"]) + if "Args" in y["ScriptBootstrapAction"]: + x["Args"].should.equal(y["ScriptBootstrapAction"]["Args"]) + x["ScriptPath"].should.equal(y["ScriptBootstrapAction"]["Path"]) @mock_emr def test_instance_groups(): - input_groups = dict((g['Name'], g) for g in input_instance_groups) + input_groups = dict((g["Name"], g) for g in input_instance_groups) - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - for key in ['MasterInstanceType', 'SlaveInstanceType', 'InstanceCount']: - del args['Instances'][key] - args['Instances']['InstanceGroups'] = input_instance_groups[:2] - cluster_id = client.run_job_flow(**args)['JobFlowId'] + for key in ["MasterInstanceType", "SlaveInstanceType", "InstanceCount"]: + del args["Instances"][key] + args["Instances"]["InstanceGroups"] = input_instance_groups[:2] + cluster_id = client.run_job_flow(**args)["JobFlowId"] - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - base_instance_count = jf['Instances']['InstanceCount'] + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + base_instance_count = jf["Instances"]["InstanceCount"] client.add_instance_groups( - JobFlowId=cluster_id, InstanceGroups=input_instance_groups[2:]) + JobFlowId=cluster_id, InstanceGroups=input_instance_groups[2:] + ) - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - jf['Instances']['InstanceCount'].should.equal( - sum(g['InstanceCount'] for g in input_instance_groups)) - for x in jf['Instances']['InstanceGroups']: - y = input_groups[x['Name']] - if hasattr(y, 'BidPrice'): - x['BidPrice'].should.equal('BidPrice') - x['CreationDateTime'].should.be.a('datetime.datetime') + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + jf["Instances"]["InstanceCount"].should.equal( + sum(g["InstanceCount"] for g in input_instance_groups) + ) + for x in jf["Instances"]["InstanceGroups"]: + y = input_groups[x["Name"]] + if hasattr(y, "BidPrice"): + x["BidPrice"].should.equal("BidPrice") + x["CreationDateTime"].should.be.a("datetime.datetime") # x['EndDateTime'].should.be.a('datetime.datetime') - x.should.have.key('InstanceGroupId') - x['InstanceRequestCount'].should.equal(y['InstanceCount']) - x['InstanceRole'].should.equal(y['InstanceRole']) - x['InstanceRunningCount'].should.equal(y['InstanceCount']) - x['InstanceType'].should.equal(y['InstanceType']) + x.should.have.key("InstanceGroupId") + x["InstanceRequestCount"].should.equal(y["InstanceCount"]) + x["InstanceRole"].should.equal(y["InstanceRole"]) + x["InstanceRunningCount"].should.equal(y["InstanceCount"]) + x["InstanceType"].should.equal(y["InstanceType"]) # x['LastStateChangeReason'].should.equal(y['LastStateChangeReason']) - x['Market'].should.equal(y['Market']) - x['Name'].should.equal(y['Name']) - x['ReadyDateTime'].should.be.a('datetime.datetime') - x['StartDateTime'].should.be.a('datetime.datetime') - x['State'].should.equal('RUNNING') + x["Market"].should.equal(y["Market"]) + x["Name"].should.equal(y["Name"]) + x["ReadyDateTime"].should.be.a("datetime.datetime") + x["StartDateTime"].should.be.a("datetime.datetime") + x["State"].should.equal("RUNNING") - groups = client.list_instance_groups(ClusterId=cluster_id)[ - 'InstanceGroups'] + groups = client.list_instance_groups(ClusterId=cluster_id)["InstanceGroups"] for x in groups: - y = input_groups[x['Name']] - if hasattr(y, 'BidPrice'): - x['BidPrice'].should.equal('BidPrice') + y = input_groups[x["Name"]] + if hasattr(y, "BidPrice"): + x["BidPrice"].should.equal("BidPrice") # Configurations # EbsBlockDevices # EbsOptimized - x.should.have.key('Id') - x['InstanceGroupType'].should.equal(y['InstanceRole']) - x['InstanceType'].should.equal(y['InstanceType']) - x['Market'].should.equal(y['Market']) - x['Name'].should.equal(y['Name']) - x['RequestedInstanceCount'].should.equal(y['InstanceCount']) - x['RunningInstanceCount'].should.equal(y['InstanceCount']) + x.should.have.key("Id") + x["InstanceGroupType"].should.equal(y["InstanceRole"]) + x["InstanceType"].should.equal(y["InstanceType"]) + x["Market"].should.equal(y["Market"]) + x["Name"].should.equal(y["Name"]) + x["RequestedInstanceCount"].should.equal(y["InstanceCount"]) + x["RunningInstanceCount"].should.equal(y["InstanceCount"]) # ShrinkPolicy - x['Status']['State'].should.equal('RUNNING') - x['Status']['StateChangeReason']['Code'].should.be.a(six.string_types) + x["Status"]["State"].should.equal("RUNNING") + x["Status"]["StateChangeReason"]["Code"].should.be.a(six.string_types) # x['Status']['StateChangeReason']['Message'].should.be.a(six.string_types) - x['Status']['Timeline'][ - 'CreationDateTime'].should.be.a('datetime.datetime') + x["Status"]["Timeline"]["CreationDateTime"].should.be.a("datetime.datetime") # x['Status']['Timeline']['EndDateTime'].should.be.a('datetime.datetime') - x['Status']['Timeline'][ - 'ReadyDateTime'].should.be.a('datetime.datetime') + x["Status"]["Timeline"]["ReadyDateTime"].should.be.a("datetime.datetime") - igs = dict((g['Name'], g) for g in groups) + igs = dict((g["Name"], g) for g in groups) client.modify_instance_groups( InstanceGroups=[ - {'InstanceGroupId': igs['task-1']['Id'], - 'InstanceCount': 2}, - {'InstanceGroupId': igs['task-2']['Id'], - 'InstanceCount': 3}]) - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - jf['Instances']['InstanceCount'].should.equal(base_instance_count + 5) - igs = dict((g['Name'], g) for g in jf['Instances']['InstanceGroups']) - igs['task-1']['InstanceRunningCount'].should.equal(2) - igs['task-2']['InstanceRunningCount'].should.equal(3) + {"InstanceGroupId": igs["task-1"]["Id"], "InstanceCount": 2}, + {"InstanceGroupId": igs["task-2"]["Id"], "InstanceCount": 3}, + ] + ) + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + jf["Instances"]["InstanceCount"].should.equal(base_instance_count + 5) + igs = dict((g["Name"], g) for g in jf["Instances"]["InstanceGroups"]) + igs["task-1"]["InstanceRunningCount"].should.equal(2) + igs["task-2"]["InstanceRunningCount"].should.equal(3) @mock_emr def test_steps(): - input_steps = [{ - 'HadoopJarStep': { - 'Args': [ - 'hadoop-streaming', - '-files', 's3://elasticmapreduce/samples/wordcount/wordSplitter.py#wordSplitter.py', - '-mapper', 'python wordSplitter.py', - '-input', 's3://elasticmapreduce/samples/wordcount/input', - '-output', 's3://output_bucket/output/wordcount_output', - '-reducer', 'aggregate' - ], - 'Jar': 'command-runner.jar', + input_steps = [ + { + "HadoopJarStep": { + "Args": [ + "hadoop-streaming", + "-files", + "s3://elasticmapreduce/samples/wordcount/wordSplitter.py#wordSplitter.py", + "-mapper", + "python wordSplitter.py", + "-input", + "s3://elasticmapreduce/samples/wordcount/input", + "-output", + "s3://output_bucket/output/wordcount_output", + "-reducer", + "aggregate", + ], + "Jar": "command-runner.jar", + }, + "Name": "My wordcount example", }, - 'Name': 'My wordcount example', - }, { - 'HadoopJarStep': { - 'Args': [ - 'hadoop-streaming', - '-files', 's3://elasticmapreduce/samples/wordcount/wordSplitter2.py#wordSplitter2.py', - '-mapper', 'python wordSplitter2.py', - '-input', 's3://elasticmapreduce/samples/wordcount/input2', - '-output', 's3://output_bucket/output/wordcount_output2', - '-reducer', 'aggregate' - ], - 'Jar': 'command-runner.jar', + { + "HadoopJarStep": { + "Args": [ + "hadoop-streaming", + "-files", + "s3://elasticmapreduce/samples/wordcount/wordSplitter2.py#wordSplitter2.py", + "-mapper", + "python wordSplitter2.py", + "-input", + "s3://elasticmapreduce/samples/wordcount/input2", + "-output", + "s3://output_bucket/output/wordcount_output2", + "-reducer", + "aggregate", + ], + "Jar": "command-runner.jar", + }, + "Name": "My wordcount example2", }, - 'Name': 'My wordcount example2', - }] + ] # TODO: implementation and test for cancel_steps - client = boto3.client('emr', region_name='us-east-1') + client = boto3.client("emr", region_name="us-east-1") args = deepcopy(run_job_flow_args) - args['Steps'] = [input_steps[0]] - cluster_id = client.run_job_flow(**args)['JobFlowId'] + args["Steps"] = [input_steps[0]] + cluster_id = client.run_job_flow(**args)["JobFlowId"] - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - jf['Steps'].should.have.length_of(1) + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + jf["Steps"].should.have.length_of(1) client.add_job_flow_steps(JobFlowId=cluster_id, Steps=[input_steps[1]]) - jf = client.describe_job_flows(JobFlowIds=[cluster_id])['JobFlows'][0] - jf['Steps'].should.have.length_of(2) - for idx, (x, y) in enumerate(zip(jf['Steps'], input_steps)): - x['ExecutionStatusDetail'].should.have.key('CreationDateTime') + jf = client.describe_job_flows(JobFlowIds=[cluster_id])["JobFlows"][0] + jf["Steps"].should.have.length_of(2) + for idx, (x, y) in enumerate(zip(jf["Steps"], input_steps)): + x["ExecutionStatusDetail"].should.have.key("CreationDateTime") # x['ExecutionStatusDetail'].should.have.key('EndDateTime') # x['ExecutionStatusDetail'].should.have.key('LastStateChangeReason') # x['ExecutionStatusDetail'].should.have.key('StartDateTime') - x['ExecutionStatusDetail']['State'].should.equal( - 'STARTING' if idx == 0 else 'PENDING') - x['StepConfig']['ActionOnFailure'].should.equal('TERMINATE_CLUSTER') - x['StepConfig']['HadoopJarStep'][ - 'Args'].should.equal(y['HadoopJarStep']['Args']) - x['StepConfig']['HadoopJarStep'][ - 'Jar'].should.equal(y['HadoopJarStep']['Jar']) - if 'MainClass' in y['HadoopJarStep']: - x['StepConfig']['HadoopJarStep']['MainClass'].should.equal( - y['HadoopJarStep']['MainClass']) - if 'Properties' in y['HadoopJarStep']: - x['StepConfig']['HadoopJarStep']['Properties'].should.equal( - y['HadoopJarStep']['Properties']) - x['StepConfig']['Name'].should.equal(y['Name']) + x["ExecutionStatusDetail"]["State"].should.equal( + "STARTING" if idx == 0 else "PENDING" + ) + x["StepConfig"]["ActionOnFailure"].should.equal("TERMINATE_CLUSTER") + x["StepConfig"]["HadoopJarStep"]["Args"].should.equal( + y["HadoopJarStep"]["Args"] + ) + x["StepConfig"]["HadoopJarStep"]["Jar"].should.equal(y["HadoopJarStep"]["Jar"]) + if "MainClass" in y["HadoopJarStep"]: + x["StepConfig"]["HadoopJarStep"]["MainClass"].should.equal( + y["HadoopJarStep"]["MainClass"] + ) + if "Properties" in y["HadoopJarStep"]: + x["StepConfig"]["HadoopJarStep"]["Properties"].should.equal( + y["HadoopJarStep"]["Properties"] + ) + x["StepConfig"]["Name"].should.equal(y["Name"]) - expected = dict((s['Name'], s) for s in input_steps) + expected = dict((s["Name"], s) for s in input_steps) - steps = client.list_steps(ClusterId=cluster_id)['Steps'] + steps = client.list_steps(ClusterId=cluster_id)["Steps"] steps.should.have.length_of(2) for x in steps: - y = expected[x['Name']] - x['ActionOnFailure'].should.equal('TERMINATE_CLUSTER') - x['Config']['Args'].should.equal(y['HadoopJarStep']['Args']) - x['Config']['Jar'].should.equal(y['HadoopJarStep']['Jar']) + y = expected[x["Name"]] + x["ActionOnFailure"].should.equal("TERMINATE_CLUSTER") + x["Config"]["Args"].should.equal(y["HadoopJarStep"]["Args"]) + x["Config"]["Jar"].should.equal(y["HadoopJarStep"]["Jar"]) # x['Config']['MainClass'].should.equal(y['HadoopJarStep']['MainClass']) # Properties - x['Id'].should.be.a(six.string_types) - x['Name'].should.equal(y['Name']) - x['Status']['State'].should.be.within(['STARTING', 'PENDING']) + x["Id"].should.be.a(six.string_types) + x["Name"].should.equal(y["Name"]) + x["Status"]["State"].should.be.within(["STARTING", "PENDING"]) # StateChangeReason - x['Status']['Timeline'][ - 'CreationDateTime'].should.be.a('datetime.datetime') + x["Status"]["Timeline"]["CreationDateTime"].should.be.a("datetime.datetime") # x['Status']['Timeline']['EndDateTime'].should.be.a('datetime.datetime') # x['Status']['Timeline']['StartDateTime'].should.be.a('datetime.datetime') - x = client.describe_step(ClusterId=cluster_id, StepId=x['Id'])['Step'] - x['ActionOnFailure'].should.equal('TERMINATE_CLUSTER') - x['Config']['Args'].should.equal(y['HadoopJarStep']['Args']) - x['Config']['Jar'].should.equal(y['HadoopJarStep']['Jar']) + x = client.describe_step(ClusterId=cluster_id, StepId=x["Id"])["Step"] + x["ActionOnFailure"].should.equal("TERMINATE_CLUSTER") + x["Config"]["Args"].should.equal(y["HadoopJarStep"]["Args"]) + x["Config"]["Jar"].should.equal(y["HadoopJarStep"]["Jar"]) # x['Config']['MainClass'].should.equal(y['HadoopJarStep']['MainClass']) # Properties - x['Id'].should.be.a(six.string_types) - x['Name'].should.equal(y['Name']) - x['Status']['State'].should.be.within(['STARTING', 'PENDING']) + x["Id"].should.be.a(six.string_types) + x["Name"].should.equal(y["Name"]) + x["Status"]["State"].should.be.within(["STARTING", "PENDING"]) # StateChangeReason - x['Status']['Timeline'][ - 'CreationDateTime'].should.be.a('datetime.datetime') + x["Status"]["Timeline"]["CreationDateTime"].should.be.a("datetime.datetime") # x['Status']['Timeline']['EndDateTime'].should.be.a('datetime.datetime') # x['Status']['Timeline']['StartDateTime'].should.be.a('datetime.datetime') - step_id = steps[0]['Id'] - steps = client.list_steps(ClusterId=cluster_id, StepIds=[step_id])['Steps'] + step_id = steps[0]["Id"] + steps = client.list_steps(ClusterId=cluster_id, StepIds=[step_id])["Steps"] steps.should.have.length_of(1) - steps[0]['Id'].should.equal(step_id) + steps[0]["Id"].should.equal(step_id) - steps = client.list_steps(ClusterId=cluster_id, - StepStates=['STARTING'])['Steps'] + steps = client.list_steps(ClusterId=cluster_id, StepStates=["STARTING"])["Steps"] steps.should.have.length_of(1) - steps[0]['Id'].should.equal(step_id) + steps[0]["Id"].should.equal(step_id) @mock_emr def test_tags(): - input_tags = [{'Key': 'newkey1', 'Value': 'newval1'}, - {'Key': 'newkey2', 'Value': 'newval2'}] + input_tags = [ + {"Key": "newkey1", "Value": "newval1"}, + {"Key": "newkey2", "Value": "newval2"}, + ] - client = boto3.client('emr', region_name='us-east-1') - cluster_id = client.run_job_flow(**run_job_flow_args)['JobFlowId'] + client = boto3.client("emr", region_name="us-east-1") + cluster_id = client.run_job_flow(**run_job_flow_args)["JobFlowId"] client.add_tags(ResourceId=cluster_id, Tags=input_tags) - resp = client.describe_cluster(ClusterId=cluster_id)['Cluster'] - resp['Tags'].should.have.length_of(2) - dict((t['Key'], t['Value']) for t in resp['Tags']).should.equal( - dict((t['Key'], t['Value']) for t in input_tags)) + resp = client.describe_cluster(ClusterId=cluster_id)["Cluster"] + resp["Tags"].should.have.length_of(2) + dict((t["Key"], t["Value"]) for t in resp["Tags"]).should.equal( + dict((t["Key"], t["Value"]) for t in input_tags) + ) - client.remove_tags(ResourceId=cluster_id, TagKeys=[ - t['Key'] for t in input_tags]) - resp = client.describe_cluster(ClusterId=cluster_id)['Cluster'] - resp['Tags'].should.equal([]) + client.remove_tags(ResourceId=cluster_id, TagKeys=[t["Key"] for t in input_tags]) + resp = client.describe_cluster(ClusterId=cluster_id)["Cluster"] + resp["Tags"].should.equal([]) diff --git a/tests/test_emr/test_server.py b/tests/test_emr/test_server.py index 56eba3ff8..4dbd02553 100644 --- a/tests/test_emr/test_server.py +++ b/tests/test_emr/test_server.py @@ -3,16 +3,16 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_describe_jobflows(): backend = server.create_backend_app("emr") test_client = backend.test_client() - res = test_client.get('/?Action=DescribeJobFlows') + res = test_client.get("/?Action=DescribeJobFlows") - res.data.should.contain(b'') - res.data.should.contain(b'') + res.data.should.contain(b"") + res.data.should.contain(b"") diff --git a/tests/test_events/test_events.py b/tests/test_events/test_events.py index e9e1d12c9..d5bfdf782 100644 --- a/tests/test_events/test_events.py +++ b/tests/test_events/test_events.py @@ -7,42 +7,42 @@ from botocore.exceptions import ClientError from nose.tools import assert_raises RULES = [ - {'Name': 'test1', 'ScheduleExpression': 'rate(5 minutes)'}, - {'Name': 'test2', 'ScheduleExpression': 'rate(1 minute)'}, - {'Name': 'test3', 'EventPattern': '{"source": ["test-source"]}'} + {"Name": "test1", "ScheduleExpression": "rate(5 minutes)"}, + {"Name": "test2", "ScheduleExpression": "rate(1 minute)"}, + {"Name": "test3", "EventPattern": '{"source": ["test-source"]}'}, ] TARGETS = { - 'test-target-1': { - 'Id': 'test-target-1', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-1', - 'Rules': ['test1', 'test2'] + "test-target-1": { + "Id": "test-target-1", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-1", + "Rules": ["test1", "test2"], }, - 'test-target-2': { - 'Id': 'test-target-2', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-2', - 'Rules': ['test1', 'test3'] + "test-target-2": { + "Id": "test-target-2", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-2", + "Rules": ["test1", "test3"], }, - 'test-target-3': { - 'Id': 'test-target-3', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-3', - 'Rules': ['test1', 'test2'] + "test-target-3": { + "Id": "test-target-3", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-3", + "Rules": ["test1", "test2"], }, - 'test-target-4': { - 'Id': 'test-target-4', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-4', - 'Rules': ['test1', 'test3'] + "test-target-4": { + "Id": "test-target-4", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-4", + "Rules": ["test1", "test3"], }, - 'test-target-5': { - 'Id': 'test-target-5', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-5', - 'Rules': ['test1', 'test2'] + "test-target-5": { + "Id": "test-target-5", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-5", + "Rules": ["test1", "test2"], + }, + "test-target-6": { + "Id": "test-target-6", + "Arn": "arn:aws:lambda:us-west-2:111111111111:function:test-function-6", + "Rules": ["test1", "test3"], }, - 'test-target-6': { - 'Id': 'test-target-6', - 'Arn': 'arn:aws:lambda:us-west-2:111111111111:function:test-function-6', - 'Rules': ['test1', 'test3'] - } } @@ -51,21 +51,21 @@ def get_random_rule(): def generate_environment(): - client = boto3.client('events', 'us-west-2') + client = boto3.client("events", "us-west-2") for rule in RULES: client.put_rule( - Name=rule['Name'], - ScheduleExpression=rule.get('ScheduleExpression', ''), - EventPattern=rule.get('EventPattern', '') + Name=rule["Name"], + ScheduleExpression=rule.get("ScheduleExpression", ""), + EventPattern=rule.get("EventPattern", ""), ) targets = [] for target in TARGETS: - if rule['Name'] in TARGETS[target].get('Rules'): - targets.append({'Id': target, 'Arn': TARGETS[target]['Arn']}) + if rule["Name"] in TARGETS[target].get("Rules"): + targets.append({"Id": target, "Arn": TARGETS[target]["Arn"]}) - client.put_targets(Rule=rule['Name'], Targets=targets) + client.put_targets(Rule=rule["Name"], Targets=targets) return client @@ -75,61 +75,63 @@ def test_list_rules(): client = generate_environment() response = client.list_rules() - assert(response is not None) - assert(len(response['Rules']) > 0) + assert response is not None + assert len(response["Rules"]) > 0 @mock_events def test_describe_rule(): - rule_name = get_random_rule()['Name'] + rule_name = get_random_rule()["Name"] client = generate_environment() response = client.describe_rule(Name=rule_name) - assert(response is not None) - assert(response.get('Name') == rule_name) - assert(response.get('Arn') == 'arn:aws:events:us-west-2:111111111111:rule/{0}'.format(rule_name)) + assert response is not None + assert response.get("Name") == rule_name + assert response.get( + "Arn" + ) == "arn:aws:events:us-west-2:111111111111:rule/{0}".format(rule_name) @mock_events def test_enable_disable_rule(): - rule_name = get_random_rule()['Name'] + rule_name = get_random_rule()["Name"] client = generate_environment() # Rules should start out enabled in these tests. rule = client.describe_rule(Name=rule_name) - assert(rule['State'] == 'ENABLED') + assert rule["State"] == "ENABLED" client.disable_rule(Name=rule_name) rule = client.describe_rule(Name=rule_name) - assert(rule['State'] == 'DISABLED') + assert rule["State"] == "DISABLED" client.enable_rule(Name=rule_name) rule = client.describe_rule(Name=rule_name) - assert(rule['State'] == 'ENABLED') + assert rule["State"] == "ENABLED" # Test invalid name try: - client.enable_rule(Name='junk') + client.enable_rule(Name="junk") except ClientError as ce: - assert ce.response['Error']['Code'] == 'ResourceNotFoundException' + assert ce.response["Error"]["Code"] == "ResourceNotFoundException" @mock_events def test_list_rule_names_by_target(): - test_1_target = TARGETS['test-target-1'] - test_2_target = TARGETS['test-target-2'] + test_1_target = TARGETS["test-target-1"] + test_2_target = TARGETS["test-target-2"] client = generate_environment() - rules = client.list_rule_names_by_target(TargetArn=test_1_target['Arn']) - assert(len(rules['RuleNames']) == len(test_1_target['Rules'])) - for rule in rules['RuleNames']: - assert(rule in test_1_target['Rules']) + rules = client.list_rule_names_by_target(TargetArn=test_1_target["Arn"]) + assert len(rules["RuleNames"]) == len(test_1_target["Rules"]) + for rule in rules["RuleNames"]: + assert rule in test_1_target["Rules"] - rules = client.list_rule_names_by_target(TargetArn=test_2_target['Arn']) - assert(len(rules['RuleNames']) == len(test_2_target['Rules'])) - for rule in rules['RuleNames']: - assert(rule in test_2_target['Rules']) + rules = client.list_rule_names_by_target(TargetArn=test_2_target["Arn"]) + assert len(rules["RuleNames"]) == len(test_2_target["Rules"]) + for rule in rules["RuleNames"]: + assert rule in test_2_target["Rules"] @mock_events @@ -137,80 +139,84 @@ def test_list_rules(): client = generate_environment() rules = client.list_rules() - assert(len(rules['Rules']) == len(RULES)) + assert len(rules["Rules"]) == len(RULES) @mock_events def test_delete_rule(): client = generate_environment() - client.delete_rule(Name=RULES[0]['Name']) + client.delete_rule(Name=RULES[0]["Name"]) rules = client.list_rules() - assert(len(rules['Rules']) == len(RULES) - 1) + assert len(rules["Rules"]) == len(RULES) - 1 @mock_events def test_list_targets_by_rule(): - rule_name = get_random_rule()['Name'] + rule_name = get_random_rule()["Name"] client = generate_environment() targets = client.list_targets_by_rule(Rule=rule_name) expected_targets = [] for target in TARGETS: - if rule_name in TARGETS[target].get('Rules'): + if rule_name in TARGETS[target].get("Rules"): expected_targets.append(target) - assert(len(targets['Targets']) == len(expected_targets)) + assert len(targets["Targets"]) == len(expected_targets) @mock_events def test_remove_targets(): - rule_name = get_random_rule()['Name'] + rule_name = get_random_rule()["Name"] client = generate_environment() - targets = client.list_targets_by_rule(Rule=rule_name)['Targets'] + targets = client.list_targets_by_rule(Rule=rule_name)["Targets"] targets_before = len(targets) - assert(targets_before > 0) + assert targets_before > 0 - client.remove_targets(Rule=rule_name, Ids=[targets[0]['Id']]) + client.remove_targets(Rule=rule_name, Ids=[targets[0]["Id"]]) - targets = client.list_targets_by_rule(Rule=rule_name)['Targets'] + targets = client.list_targets_by_rule(Rule=rule_name)["Targets"] targets_after = len(targets) - assert(targets_before - 1 == targets_after) + assert targets_before - 1 == targets_after @mock_events def test_permissions(): - client = boto3.client('events', 'eu-central-1') + client = boto3.client("events", "eu-central-1") - client.put_permission(Action='events:PutEvents', Principal='111111111111', StatementId='Account1') - client.put_permission(Action='events:PutEvents', Principal='222222222222', StatementId='Account2') + client.put_permission( + Action="events:PutEvents", Principal="111111111111", StatementId="Account1" + ) + client.put_permission( + Action="events:PutEvents", Principal="222222222222", StatementId="Account2" + ) resp = client.describe_event_bus() - resp_policy = json.loads(resp['Policy']) - assert len(resp_policy['Statement']) == 2 + resp_policy = json.loads(resp["Policy"]) + assert len(resp_policy["Statement"]) == 2 - client.remove_permission(StatementId='Account2') + client.remove_permission(StatementId="Account2") resp = client.describe_event_bus() - resp_policy = json.loads(resp['Policy']) - assert len(resp_policy['Statement']) == 1 - assert resp_policy['Statement'][0]['Sid'] == 'Account1' + resp_policy = json.loads(resp["Policy"]) + assert len(resp_policy["Statement"]) == 1 + assert resp_policy["Statement"][0]["Sid"] == "Account1" @mock_events def test_put_events(): - client = boto3.client('events', 'eu-central-1') + client = boto3.client("events", "eu-central-1") event = { "Source": "com.mycompany.myapp", "Detail": '{"key1": "value3", "key2": "value4"}', "Resources": ["resource1", "resource2"], - "DetailType": "myDetailType" + "DetailType": "myDetailType", } client.put_events(Entries=[event]) # Boto3 would error if it didn't return 200 OK with assert_raises(ClientError): - client.put_events(Entries=[event]*20) + client.put_events(Entries=[event] * 20) diff --git a/tests/test_glacier/test_glacier_jobs.py b/tests/test_glacier/test_glacier_jobs.py index 152aa14c8..11077d7f2 100644 --- a/tests/test_glacier/test_glacier_jobs.py +++ b/tests/test_glacier/test_glacier_jobs.py @@ -15,15 +15,14 @@ def test_init_glacier_job(): vault_name = "my_vault" conn.create_vault(vault_name) archive_id = conn.upload_archive( - vault_name, "some stuff", "", "", "some description") + vault_name, "some stuff", "", "", "some description" + ) - job_response = conn.initiate_job(vault_name, { - "ArchiveId": archive_id, - "Type": "archive-retrieval", - }) - job_id = job_response['JobId'] - job_response['Location'].should.equal( - "//vaults/my_vault/jobs/{0}".format(job_id)) + job_response = conn.initiate_job( + vault_name, {"ArchiveId": archive_id, "Type": "archive-retrieval"} + ) + job_id = job_response["JobId"] + job_response["Location"].should.equal("//vaults/my_vault/jobs/{0}".format(job_id)) @mock_glacier_deprecated @@ -32,19 +31,21 @@ def test_describe_job(): vault_name = "my_vault" conn.create_vault(vault_name) archive_id = conn.upload_archive( - vault_name, "some stuff", "", "", "some description") - job_response = conn.initiate_job(vault_name, { - "ArchiveId": archive_id, - "Type": "archive-retrieval", - }) - job_id = job_response['JobId'] + vault_name, "some stuff", "", "", "some description" + ) + job_response = conn.initiate_job( + vault_name, {"ArchiveId": archive_id, "Type": "archive-retrieval"} + ) + job_id = job_response["JobId"] job = conn.describe_job(vault_name, job_id) joboutput = json.loads(job.read().decode("utf-8")) - - joboutput.should.have.key('Tier').which.should.equal('Standard') - joboutput.should.have.key('StatusCode').which.should.equal('InProgress') - joboutput.should.have.key('VaultARN').which.should.equal('arn:aws:glacier:RegionInfo:us-west-2:012345678901:vaults/my_vault') + + joboutput.should.have.key("Tier").which.should.equal("Standard") + joboutput.should.have.key("StatusCode").which.should.equal("InProgress") + joboutput.should.have.key("VaultARN").which.should.equal( + "arn:aws:glacier:RegionInfo:us-west-2:012345678901:vaults/my_vault" + ) @mock_glacier_deprecated @@ -53,21 +54,21 @@ def test_list_glacier_jobs(): vault_name = "my_vault" conn.create_vault(vault_name) archive_id1 = conn.upload_archive( - vault_name, "some stuff", "", "", "some description")['ArchiveId'] + vault_name, "some stuff", "", "", "some description" + )["ArchiveId"] archive_id2 = conn.upload_archive( - vault_name, "some other stuff", "", "", "some description")['ArchiveId'] + vault_name, "some other stuff", "", "", "some description" + )["ArchiveId"] - conn.initiate_job(vault_name, { - "ArchiveId": archive_id1, - "Type": "archive-retrieval", - }) - conn.initiate_job(vault_name, { - "ArchiveId": archive_id2, - "Type": "archive-retrieval", - }) + conn.initiate_job( + vault_name, {"ArchiveId": archive_id1, "Type": "archive-retrieval"} + ) + conn.initiate_job( + vault_name, {"ArchiveId": archive_id2, "Type": "archive-retrieval"} + ) jobs = conn.list_jobs(vault_name) - len(jobs['JobList']).should.equal(2) + len(jobs["JobList"]).should.equal(2) @mock_glacier_deprecated @@ -76,15 +77,15 @@ def test_get_job_output(): vault_name = "my_vault" conn.create_vault(vault_name) archive_response = conn.upload_archive( - vault_name, "some stuff", "", "", "some description") - archive_id = archive_response['ArchiveId'] - job_response = conn.initiate_job(vault_name, { - "ArchiveId": archive_id, - "Type": "archive-retrieval", - }) - job_id = job_response['JobId'] + vault_name, "some stuff", "", "", "some description" + ) + archive_id = archive_response["ArchiveId"] + job_response = conn.initiate_job( + vault_name, {"ArchiveId": archive_id, "Type": "archive-retrieval"} + ) + job_id = job_response["JobId"] time.sleep(6) - + output = conn.get_job_output(vault_name, job_id) output.read().decode("utf-8").should.equal("some stuff") diff --git a/tests/test_glacier/test_glacier_server.py b/tests/test_glacier/test_glacier_server.py index fd8034421..d43dd4e8a 100644 --- a/tests/test_glacier/test_glacier_server.py +++ b/tests/test_glacier/test_glacier_server.py @@ -6,9 +6,9 @@ import sure # noqa import moto.server as server from moto import mock_glacier -''' +""" Test the different server responses -''' +""" @mock_glacier @@ -16,7 +16,6 @@ def test_list_vaults(): backend = server.create_backend_app("glacier") test_client = backend.test_client() - res = test_client.get('/1234bcd/vaults') + res = test_client.get("/1234bcd/vaults") - json.loads(res.data.decode("utf-8") - ).should.equal({u'Marker': None, u'VaultList': []}) + json.loads(res.data.decode("utf-8")).should.equal({"Marker": None, "VaultList": []}) diff --git a/tests/test_glue/fixtures/datacatalog.py b/tests/test_glue/fixtures/datacatalog.py index edad2f0f4..11cb30ca9 100644 --- a/tests/test_glue/fixtures/datacatalog.py +++ b/tests/test_glue/fixtures/datacatalog.py @@ -1,55 +1,54 @@ from __future__ import unicode_literals TABLE_INPUT = { - 'Owner': 'a_fake_owner', - 'Parameters': { - 'EXTERNAL': 'TRUE', - }, - 'Retention': 0, - 'StorageDescriptor': { - 'BucketColumns': [], - 'Compressed': False, - 'InputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat', - 'NumberOfBuckets': -1, - 'OutputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat', - 'Parameters': {}, - 'SerdeInfo': { - 'Parameters': { - 'serialization.format': '1' - }, - 'SerializationLibrary': 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + "Owner": "a_fake_owner", + "Parameters": {"EXTERNAL": "TRUE"}, + "Retention": 0, + "StorageDescriptor": { + "BucketColumns": [], + "Compressed": False, + "InputFormat": "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "NumberOfBuckets": -1, + "OutputFormat": "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "Parameters": {}, + "SerdeInfo": { + "Parameters": {"serialization.format": "1"}, + "SerializationLibrary": "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", }, - 'SkewedInfo': { - 'SkewedColumnNames': [], - 'SkewedColumnValueLocationMaps': {}, - 'SkewedColumnValues': [] + "SkewedInfo": { + "SkewedColumnNames": [], + "SkewedColumnValueLocationMaps": {}, + "SkewedColumnValues": [], }, - 'SortColumns': [], - 'StoredAsSubDirectories': False + "SortColumns": [], + "StoredAsSubDirectories": False, }, - 'TableType': 'EXTERNAL_TABLE', + "TableType": "EXTERNAL_TABLE", } PARTITION_INPUT = { # 'DatabaseName': 'dbname', - 'StorageDescriptor': { - 'BucketColumns': [], - 'Columns': [], - 'Compressed': False, - 'InputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat', - 'Location': 's3://.../partition=value', - 'NumberOfBuckets': -1, - 'OutputFormat': 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat', - 'Parameters': {}, - 'SerdeInfo': { - 'Parameters': {'path': 's3://...', 'serialization.format': '1'}, - 'SerializationLibrary': 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'}, - 'SkewedInfo': {'SkewedColumnNames': [], - 'SkewedColumnValueLocationMaps': {}, - 'SkewedColumnValues': []}, - 'SortColumns': [], - 'StoredAsSubDirectories': False, + "StorageDescriptor": { + "BucketColumns": [], + "Columns": [], + "Compressed": False, + "InputFormat": "org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat", + "Location": "s3://.../partition=value", + "NumberOfBuckets": -1, + "OutputFormat": "org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat", + "Parameters": {}, + "SerdeInfo": { + "Parameters": {"path": "s3://...", "serialization.format": "1"}, + "SerializationLibrary": "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe", + }, + "SkewedInfo": { + "SkewedColumnNames": [], + "SkewedColumnValueLocationMaps": {}, + "SkewedColumnValues": [], + }, + "SortColumns": [], + "StoredAsSubDirectories": False, }, # 'TableName': 'source_table', # 'Values': ['2018-06-26'], diff --git a/tests/test_glue/helpers.py b/tests/test_glue/helpers.py index 331b99867..130a879bc 100644 --- a/tests/test_glue/helpers.py +++ b/tests/test_glue/helpers.py @@ -6,11 +6,7 @@ from .fixtures.datacatalog import TABLE_INPUT, PARTITION_INPUT def create_database(client, database_name): - return client.create_database( - DatabaseInput={ - 'Name': database_name - } - ) + return client.create_database(DatabaseInput={"Name": database_name}) def get_database(client, database_name): @@ -19,12 +15,13 @@ def get_database(client, database_name): def create_table_input(database_name, table_name, columns=[], partition_keys=[]): table_input = copy.deepcopy(TABLE_INPUT) - table_input['Name'] = table_name - table_input['PartitionKeys'] = partition_keys - table_input['StorageDescriptor']['Columns'] = columns - table_input['StorageDescriptor']['Location'] = 's3://my-bucket/{database_name}/{table_name}'.format( - database_name=database_name, - table_name=table_name + table_input["Name"] = table_name + table_input["PartitionKeys"] = partition_keys + table_input["StorageDescriptor"]["Columns"] = columns + table_input["StorageDescriptor"][ + "Location" + ] = "s3://my-bucket/{database_name}/{table_name}".format( + database_name=database_name, table_name=table_name ) return table_input @@ -33,60 +30,43 @@ def create_table(client, database_name, table_name, table_input=None, **kwargs): if table_input is None: table_input = create_table_input(database_name, table_name, **kwargs) - return client.create_table( - DatabaseName=database_name, - TableInput=table_input - ) + return client.create_table(DatabaseName=database_name, TableInput=table_input) def update_table(client, database_name, table_name, table_input=None, **kwargs): if table_input is None: table_input = create_table_input(database_name, table_name, **kwargs) - return client.update_table( - DatabaseName=database_name, - TableInput=table_input, - ) + return client.update_table(DatabaseName=database_name, TableInput=table_input) def get_table(client, database_name, table_name): - return client.get_table( - DatabaseName=database_name, - Name=table_name - ) + return client.get_table(DatabaseName=database_name, Name=table_name) def get_tables(client, database_name): - return client.get_tables( - DatabaseName=database_name - ) + return client.get_tables(DatabaseName=database_name) def get_table_versions(client, database_name, table_name): - return client.get_table_versions( - DatabaseName=database_name, - TableName=table_name - ) + return client.get_table_versions(DatabaseName=database_name, TableName=table_name) def get_table_version(client, database_name, table_name, version_id): return client.get_table_version( - DatabaseName=database_name, - TableName=table_name, - VersionId=version_id, + DatabaseName=database_name, TableName=table_name, VersionId=version_id ) def create_partition_input(database_name, table_name, values=[], columns=[]): - root_path = 's3://my-bucket/{database_name}/{table_name}'.format( - database_name=database_name, - table_name=table_name + root_path = "s3://my-bucket/{database_name}/{table_name}".format( + database_name=database_name, table_name=table_name ) part_input = copy.deepcopy(PARTITION_INPUT) - part_input['Values'] = values - part_input['StorageDescriptor']['Columns'] = columns - part_input['StorageDescriptor']['SerdeInfo']['Parameters']['path'] = root_path + part_input["Values"] = values + part_input["StorageDescriptor"]["Columns"] = columns + part_input["StorageDescriptor"]["SerdeInfo"]["Parameters"]["path"] = root_path return part_input @@ -94,13 +74,13 @@ def create_partition(client, database_name, table_name, partiton_input=None, **k if partiton_input is None: partiton_input = create_partition_input(database_name, table_name, **kwargs) return client.create_partition( - DatabaseName=database_name, - TableName=table_name, - PartitionInput=partiton_input + DatabaseName=database_name, TableName=table_name, PartitionInput=partiton_input ) -def update_partition(client, database_name, table_name, old_values=[], partiton_input=None, **kwargs): +def update_partition( + client, database_name, table_name, old_values=[], partiton_input=None, **kwargs +): if partiton_input is None: partiton_input = create_partition_input(database_name, table_name, **kwargs) return client.update_partition( @@ -113,7 +93,5 @@ def update_partition(client, database_name, table_name, old_values=[], partiton_ def get_partition(client, database_name, table_name, values): return client.get_partition( - DatabaseName=database_name, - TableName=table_name, - PartitionValues=values, + DatabaseName=database_name, TableName=table_name, PartitionValues=values ) diff --git a/tests/test_glue/test_datacatalog.py b/tests/test_glue/test_datacatalog.py index 9034feb55..28281b18f 100644 --- a/tests/test_glue/test_datacatalog.py +++ b/tests/test_glue/test_datacatalog.py @@ -16,80 +16,82 @@ from . import helpers @mock_glue def test_create_database(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) response = helpers.get_database(client, database_name) - database = response['Database'] + database = response["Database"] - database.should.equal({'Name': database_name}) + database.should.equal({"Name": database_name}) @mock_glue def test_create_database_already_exists(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'cantcreatethisdatabasetwice' + client = boto3.client("glue", region_name="us-east-1") + database_name = "cantcreatethisdatabasetwice" helpers.create_database(client, database_name) with assert_raises(ClientError) as exc: helpers.create_database(client, database_name) - exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException') + exc.exception.response["Error"]["Code"].should.equal("AlreadyExistsException") @mock_glue def test_get_database_not_exits(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'nosuchdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "nosuchdatabase" with assert_raises(ClientError) as exc: helpers.get_database(client, database_name) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('Database nosuchdatabase not found') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match( + "Database nosuchdatabase not found" + ) @mock_glue def test_create_table(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_name = 'myspecialtable' + table_name = "myspecialtable" table_input = helpers.create_table_input(database_name, table_name) helpers.create_table(client, database_name, table_name, table_input) response = helpers.get_table(client, database_name, table_name) - table = response['Table'] + table = response["Table"] - table['Name'].should.equal(table_input['Name']) - table['StorageDescriptor'].should.equal(table_input['StorageDescriptor']) - table['PartitionKeys'].should.equal(table_input['PartitionKeys']) + table["Name"].should.equal(table_input["Name"]) + table["StorageDescriptor"].should.equal(table_input["StorageDescriptor"]) + table["PartitionKeys"].should.equal(table_input["PartitionKeys"]) @mock_glue def test_create_table_already_exists(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_name = 'cantcreatethistabletwice' + table_name = "cantcreatethistabletwice" helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: helpers.create_table(client, database_name, table_name) - exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException') + exc.exception.response["Error"]["Code"].should.equal("AlreadyExistsException") @mock_glue def test_get_tables(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_names = ['myfirsttable', 'mysecondtable', 'mythirdtable'] + table_names = ["myfirsttable", "mysecondtable", "mythirdtable"] table_inputs = {} for table_name in table_names: @@ -99,31 +101,33 @@ def test_get_tables(): response = helpers.get_tables(client, database_name) - tables = response['TableList'] + tables = response["TableList"] tables.should.have.length_of(3) for table in tables: - table_name = table['Name'] - table_name.should.equal(table_inputs[table_name]['Name']) - table['StorageDescriptor'].should.equal(table_inputs[table_name]['StorageDescriptor']) - table['PartitionKeys'].should.equal(table_inputs[table_name]['PartitionKeys']) + table_name = table["Name"] + table_name.should.equal(table_inputs[table_name]["Name"]) + table["StorageDescriptor"].should.equal( + table_inputs[table_name]["StorageDescriptor"] + ) + table["PartitionKeys"].should.equal(table_inputs[table_name]["PartitionKeys"]) @mock_glue def test_get_table_versions(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_name = 'myfirsttable' + table_name = "myfirsttable" version_inputs = {} table_input = helpers.create_table_input(database_name, table_name) helpers.create_table(client, database_name, table_name, table_input) version_inputs["1"] = table_input - columns = [{'Name': 'country', 'Type': 'string'}] + columns = [{"Name": "country", "Type": "string"}] table_input = helpers.create_table_input(database_name, table_name, columns=columns) helpers.update_table(client, database_name, table_name, table_input) version_inputs["2"] = table_input @@ -134,174 +138,189 @@ def test_get_table_versions(): response = helpers.get_table_versions(client, database_name, table_name) - vers = response['TableVersions'] + vers = response["TableVersions"] vers.should.have.length_of(3) - vers[0]['Table']['StorageDescriptor']['Columns'].should.equal([]) - vers[-1]['Table']['StorageDescriptor']['Columns'].should.equal(columns) + vers[0]["Table"]["StorageDescriptor"]["Columns"].should.equal([]) + vers[-1]["Table"]["StorageDescriptor"]["Columns"].should.equal(columns) for n, ver in enumerate(vers): n = str(n + 1) - ver['VersionId'].should.equal(n) - ver['Table']['Name'].should.equal(table_name) - ver['Table']['StorageDescriptor'].should.equal(version_inputs[n]['StorageDescriptor']) - ver['Table']['PartitionKeys'].should.equal(version_inputs[n]['PartitionKeys']) + ver["VersionId"].should.equal(n) + ver["Table"]["Name"].should.equal(table_name) + ver["Table"]["StorageDescriptor"].should.equal( + version_inputs[n]["StorageDescriptor"] + ) + ver["Table"]["PartitionKeys"].should.equal(version_inputs[n]["PartitionKeys"]) response = helpers.get_table_version(client, database_name, table_name, "3") - ver = response['TableVersion'] + ver = response["TableVersion"] - ver['VersionId'].should.equal("3") - ver['Table']['Name'].should.equal(table_name) - ver['Table']['StorageDescriptor']['Columns'].should.equal(columns) + ver["VersionId"].should.equal("3") + ver["Table"]["Name"].should.equal(table_name) + ver["Table"]["StorageDescriptor"]["Columns"].should.equal(columns) @mock_glue def test_get_table_version_not_found(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: - helpers.get_table_version(client, database_name, 'myfirsttable', "20") + helpers.get_table_version(client, database_name, "myfirsttable", "20") - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('version', re.I) + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match("version", re.I) @mock_glue def test_get_table_version_invalid_input(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: - helpers.get_table_version(client, database_name, 'myfirsttable', "10not-an-int") + helpers.get_table_version(client, database_name, "myfirsttable", "10not-an-int") - exc.exception.response['Error']['Code'].should.equal('InvalidInputException') + exc.exception.response["Error"]["Code"].should.equal("InvalidInputException") @mock_glue def test_get_table_not_exits(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) with assert_raises(ClientError) as exc: - helpers.get_table(client, database_name, 'myfirsttable') + helpers.get_table(client, database_name, "myfirsttable") - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('Table myfirsttable not found') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match( + "Table myfirsttable not found" + ) @mock_glue def test_get_table_when_database_not_exits(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'nosuchdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "nosuchdatabase" with assert_raises(ClientError) as exc: - helpers.get_table(client, database_name, 'myfirsttable') + helpers.get_table(client, database_name, "myfirsttable") - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('Database nosuchdatabase not found') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match( + "Database nosuchdatabase not found" + ) @mock_glue def test_delete_table(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_name = 'myspecialtable' + table_name = "myspecialtable" table_input = helpers.create_table_input(database_name, table_name) helpers.create_table(client, database_name, table_name, table_input) result = client.delete_table(DatabaseName=database_name, Name=table_name) - result['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + result["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # confirm table is deleted with assert_raises(ClientError) as exc: helpers.get_table(client, database_name, table_name) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('Table myspecialtable not found') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match( + "Table myspecialtable not found" + ) + @mock_glue def test_batch_delete_table(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" helpers.create_database(client, database_name) - table_name = 'myspecialtable' + table_name = "myspecialtable" table_input = helpers.create_table_input(database_name, table_name) helpers.create_table(client, database_name, table_name, table_input) - result = client.batch_delete_table(DatabaseName=database_name, TablesToDelete=[table_name]) - result['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + result = client.batch_delete_table( + DatabaseName=database_name, TablesToDelete=[table_name] + ) + result["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # confirm table is deleted with assert_raises(ClientError) as exc: helpers.get_table(client, database_name, table_name) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('Table myspecialtable not found') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match( + "Table myspecialtable not found" + ) @mock_glue def test_get_partitions_empty(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) response = client.get_partitions(DatabaseName=database_name, TableName=table_name) - response['Partitions'].should.have.length_of(0) + response["Partitions"].should.have.length_of(0) @mock_glue def test_create_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) before = datetime.now(pytz.utc) - part_input = helpers.create_partition_input(database_name, table_name, values=values) + part_input = helpers.create_partition_input( + database_name, table_name, values=values + ) helpers.create_partition(client, database_name, table_name, part_input) after = datetime.now(pytz.utc) response = client.get_partitions(DatabaseName=database_name, TableName=table_name) - partitions = response['Partitions'] + partitions = response["Partitions"] partitions.should.have.length_of(1) partition = partitions[0] - partition['TableName'].should.equal(table_name) - partition['StorageDescriptor'].should.equal(part_input['StorageDescriptor']) - partition['Values'].should.equal(values) - partition['CreationTime'].should.be.greater_than(before) - partition['CreationTime'].should.be.lower_than(after) + partition["TableName"].should.equal(table_name) + partition["StorageDescriptor"].should.equal(part_input["StorageDescriptor"]) + partition["Values"].should.equal(values) + partition["CreationTime"].should.be.greater_than(before) + partition["CreationTime"].should.be.lower_than(after) @mock_glue def test_create_partition_already_exist(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) @@ -311,15 +330,15 @@ def test_create_partition_already_exist(): with assert_raises(ClientError) as exc: helpers.create_partition(client, database_name, table_name, values=values) - exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException') + exc.exception.response["Error"]["Code"].should.equal("AlreadyExistsException") @mock_glue def test_get_partition_not_found(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) @@ -327,14 +346,15 @@ def test_get_partition_not_found(): with assert_raises(ClientError) as exc: helpers.get_partition(client, database_name, table_name, values) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('partition') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match("partition") + @mock_glue def test_batch_create_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) @@ -344,197 +364,221 @@ def test_batch_create_partition(): partition_inputs = [] for i in range(0, 20): values = ["2018-10-{:2}".format(i)] - part_input = helpers.create_partition_input(database_name, table_name, values=values) + part_input = helpers.create_partition_input( + database_name, table_name, values=values + ) partition_inputs.append(part_input) client.batch_create_partition( DatabaseName=database_name, TableName=table_name, - PartitionInputList=partition_inputs + PartitionInputList=partition_inputs, ) after = datetime.now(pytz.utc) response = client.get_partitions(DatabaseName=database_name, TableName=table_name) - partitions = response['Partitions'] + partitions = response["Partitions"] partitions.should.have.length_of(20) for idx, partition in enumerate(partitions): partition_input = partition_inputs[idx] - partition['TableName'].should.equal(table_name) - partition['StorageDescriptor'].should.equal(partition_input['StorageDescriptor']) - partition['Values'].should.equal(partition_input['Values']) - partition['CreationTime'].should.be.greater_than(before) - partition['CreationTime'].should.be.lower_than(after) + partition["TableName"].should.equal(table_name) + partition["StorageDescriptor"].should.equal( + partition_input["StorageDescriptor"] + ) + partition["Values"].should.equal(partition_input["Values"]) + partition["CreationTime"].should.be.greater_than(before) + partition["CreationTime"].should.be.lower_than(after) @mock_glue def test_batch_create_partition_already_exist(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) helpers.create_partition(client, database_name, table_name, values=values) - partition_input = helpers.create_partition_input(database_name, table_name, values=values) + partition_input = helpers.create_partition_input( + database_name, table_name, values=values + ) response = client.batch_create_partition( DatabaseName=database_name, TableName=table_name, - PartitionInputList=[partition_input] + PartitionInputList=[partition_input], ) - response.should.have.key('Errors') - response['Errors'].should.have.length_of(1) - response['Errors'][0]['PartitionValues'].should.equal(values) - response['Errors'][0]['ErrorDetail']['ErrorCode'].should.equal('AlreadyExistsException') + response.should.have.key("Errors") + response["Errors"].should.have.length_of(1) + response["Errors"][0]["PartitionValues"].should.equal(values) + response["Errors"][0]["ErrorDetail"]["ErrorCode"].should.equal( + "AlreadyExistsException" + ) @mock_glue def test_get_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) - values = [['2018-10-01'], ['2018-09-01']] + values = [["2018-10-01"], ["2018-09-01"]] helpers.create_partition(client, database_name, table_name, values=values[0]) helpers.create_partition(client, database_name, table_name, values=values[1]) - response = client.get_partition(DatabaseName=database_name, TableName=table_name, PartitionValues=values[1]) + response = client.get_partition( + DatabaseName=database_name, TableName=table_name, PartitionValues=values[1] + ) - partition = response['Partition'] + partition = response["Partition"] - partition['TableName'].should.equal(table_name) - partition['Values'].should.equal(values[1]) + partition["TableName"].should.equal(table_name) + partition["Values"].should.equal(values[1]) @mock_glue def test_batch_get_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) - values = [['2018-10-01'], ['2018-09-01']] + values = [["2018-10-01"], ["2018-09-01"]] helpers.create_partition(client, database_name, table_name, values=values[0]) helpers.create_partition(client, database_name, table_name, values=values[1]) - partitions_to_get = [ - {'Values': values[0]}, - {'Values': values[1]}, - ] - response = client.batch_get_partition(DatabaseName=database_name, TableName=table_name, PartitionsToGet=partitions_to_get) + partitions_to_get = [{"Values": values[0]}, {"Values": values[1]}] + response = client.batch_get_partition( + DatabaseName=database_name, + TableName=table_name, + PartitionsToGet=partitions_to_get, + ) - partitions = response['Partitions'] + partitions = response["Partitions"] partitions.should.have.length_of(2) partition = partitions[1] - partition['TableName'].should.equal(table_name) - partition['Values'].should.equal(values[1]) + partition["TableName"].should.equal(table_name) + partition["Values"].should.equal(values[1]) @mock_glue def test_batch_get_partition_missing_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) - values = [['2018-10-01'], ['2018-09-01'], ['2018-08-01']] + values = [["2018-10-01"], ["2018-09-01"], ["2018-08-01"]] helpers.create_partition(client, database_name, table_name, values=values[0]) helpers.create_partition(client, database_name, table_name, values=values[2]) partitions_to_get = [ - {'Values': values[0]}, - {'Values': values[1]}, - {'Values': values[2]}, + {"Values": values[0]}, + {"Values": values[1]}, + {"Values": values[2]}, ] - response = client.batch_get_partition(DatabaseName=database_name, TableName=table_name, PartitionsToGet=partitions_to_get) + response = client.batch_get_partition( + DatabaseName=database_name, + TableName=table_name, + PartitionsToGet=partitions_to_get, + ) - partitions = response['Partitions'] + partitions = response["Partitions"] partitions.should.have.length_of(2) - partitions[0]['Values'].should.equal(values[0]) - partitions[1]['Values'].should.equal(values[2]) - + partitions[0]["Values"].should.equal(values[0]) + partitions[1]["Values"].should.equal(values[2]) @mock_glue def test_update_partition_not_found_moving(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: - helpers.update_partition(client, database_name, table_name, old_values=['0000-00-00'], values=['2018-10-02']) + helpers.update_partition( + client, + database_name, + table_name, + old_values=["0000-00-00"], + values=["2018-10-02"], + ) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('partition') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match("partition") @mock_glue def test_update_partition_not_found_change_in_place(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: - helpers.update_partition(client, database_name, table_name, old_values=values, values=values) + helpers.update_partition( + client, database_name, table_name, old_values=values, values=values + ) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') - exc.exception.response['Error']['Message'].should.match('partition') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + exc.exception.response["Error"]["Message"].should.match("partition") @mock_glue def test_update_partition_cannot_overwrite(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) - values = [['2018-10-01'], ['2018-09-01']] + values = [["2018-10-01"], ["2018-09-01"]] helpers.create_partition(client, database_name, table_name, values=values[0]) helpers.create_partition(client, database_name, table_name, values=values[1]) with assert_raises(ClientError) as exc: - helpers.update_partition(client, database_name, table_name, old_values=values[0], values=values[1]) + helpers.update_partition( + client, database_name, table_name, old_values=values[0], values=values[1] + ) - exc.exception.response['Error']['Code'].should.equal('AlreadyExistsException') + exc.exception.response["Error"]["Code"].should.equal("AlreadyExistsException") @mock_glue def test_update_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) @@ -546,23 +590,27 @@ def test_update_partition(): table_name, old_values=values, values=values, - columns=[{'Name': 'country', 'Type': 'string'}], + columns=[{"Name": "country", "Type": "string"}], ) - response = client.get_partition(DatabaseName=database_name, TableName=table_name, PartitionValues=values) - partition = response['Partition'] + response = client.get_partition( + DatabaseName=database_name, TableName=table_name, PartitionValues=values + ) + partition = response["Partition"] - partition['TableName'].should.equal(table_name) - partition['StorageDescriptor']['Columns'].should.equal([{'Name': 'country', 'Type': 'string'}]) + partition["TableName"].should.equal(table_name) + partition["StorageDescriptor"]["Columns"].should.equal( + [{"Name": "country", "Type": "string"}] + ) @mock_glue def test_update_partition_move(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] - new_values = ['2018-09-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] + new_values = ["2018-09-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) @@ -574,79 +622,86 @@ def test_update_partition_move(): table_name, old_values=values, values=new_values, - columns=[{'Name': 'country', 'Type': 'string'}], + columns=[{"Name": "country", "Type": "string"}], ) with assert_raises(ClientError) as exc: helpers.get_partition(client, database_name, table_name, values) # Old partition shouldn't exist anymore - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") - response = client.get_partition(DatabaseName=database_name, TableName=table_name, PartitionValues=new_values) - partition = response['Partition'] + response = client.get_partition( + DatabaseName=database_name, TableName=table_name, PartitionValues=new_values + ) + partition = response["Partition"] + + partition["TableName"].should.equal(table_name) + partition["StorageDescriptor"]["Columns"].should.equal( + [{"Name": "country", "Type": "string"}] + ) - partition['TableName'].should.equal(table_name) - partition['StorageDescriptor']['Columns'].should.equal([{'Name': 'country', 'Type': 'string'}]) @mock_glue def test_delete_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) - part_input = helpers.create_partition_input(database_name, table_name, values=values) + part_input = helpers.create_partition_input( + database_name, table_name, values=values + ) helpers.create_partition(client, database_name, table_name, part_input) client.delete_partition( - DatabaseName=database_name, - TableName=table_name, - PartitionValues=values, + DatabaseName=database_name, TableName=table_name, PartitionValues=values ) response = client.get_partitions(DatabaseName=database_name, TableName=table_name) - partitions = response['Partitions'] + partitions = response["Partitions"] partitions.should.be.empty + @mock_glue def test_delete_partition_bad_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' - values = ['2018-10-01'] + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" + values = ["2018-10-01"] helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) with assert_raises(ClientError) as exc: client.delete_partition( - DatabaseName=database_name, - TableName=table_name, - PartitionValues=values, + DatabaseName=database_name, TableName=table_name, PartitionValues=values ) - exc.exception.response['Error']['Code'].should.equal('EntityNotFoundException') + exc.exception.response["Error"]["Code"].should.equal("EntityNotFoundException") + @mock_glue def test_batch_delete_partition(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) partition_inputs = [] for i in range(0, 20): values = ["2018-10-{:2}".format(i)] - part_input = helpers.create_partition_input(database_name, table_name, values=values) + part_input = helpers.create_partition_input( + database_name, table_name, values=values + ) partition_inputs.append(part_input) client.batch_create_partition( DatabaseName=database_name, TableName=table_name, - PartitionInputList=partition_inputs + PartitionInputList=partition_inputs, ) partition_values = [{"Values": p["Values"]} for p in partition_inputs] @@ -657,26 +712,29 @@ def test_batch_delete_partition(): PartitionsToDelete=partition_values, ) - response.should_not.have.key('Errors') + response.should_not.have.key("Errors") + @mock_glue def test_batch_delete_partition_with_bad_partitions(): - client = boto3.client('glue', region_name='us-east-1') - database_name = 'myspecialdatabase' - table_name = 'myfirsttable' + client = boto3.client("glue", region_name="us-east-1") + database_name = "myspecialdatabase" + table_name = "myfirsttable" helpers.create_database(client, database_name) helpers.create_table(client, database_name, table_name) partition_inputs = [] for i in range(0, 20): values = ["2018-10-{:2}".format(i)] - part_input = helpers.create_partition_input(database_name, table_name, values=values) + part_input = helpers.create_partition_input( + database_name, table_name, values=values + ) partition_inputs.append(part_input) client.batch_create_partition( DatabaseName=database_name, TableName=table_name, - PartitionInputList=partition_inputs + PartitionInputList=partition_inputs, ) partition_values = [{"Values": p["Values"]} for p in partition_inputs] @@ -691,9 +749,9 @@ def test_batch_delete_partition_with_bad_partitions(): PartitionsToDelete=partition_values, ) - response.should.have.key('Errors') - response['Errors'].should.have.length_of(3) - error_partitions = map(lambda x: x['PartitionValues'], response['Errors']) - ['2018-11-01'].should.be.within(error_partitions) - ['2018-11-02'].should.be.within(error_partitions) - ['2018-11-03'].should.be.within(error_partitions) + response.should.have.key("Errors") + response["Errors"].should.have.length_of(3) + error_partitions = map(lambda x: x["PartitionValues"], response["Errors"]) + ["2018-11-01"].should.be.within(error_partitions) + ["2018-11-02"].should.be.within(error_partitions) + ["2018-11-03"].should.be.within(error_partitions) diff --git a/tests/test_iam/test_iam.py b/tests/test_iam/test_iam.py index 4d6c37e83..c4fcda317 100644 --- a/tests/test_iam/test_iam.py +++ b/tests/test_iam/test_iam.py @@ -76,13 +76,13 @@ def test_get_all_server_certs(): conn = boto.connect_iam() conn.upload_server_cert("certname", "certbody", "privatekey") - certs = conn.get_all_server_certs()['list_server_certificates_response'][ - 'list_server_certificates_result']['server_certificate_metadata_list'] + certs = conn.get_all_server_certs()["list_server_certificates_response"][ + "list_server_certificates_result" + ]["server_certificate_metadata_list"] certs.should.have.length_of(1) cert1 = certs[0] cert1.server_certificate_name.should.equal("certname") - cert1.arn.should.equal( - "arn:aws:iam::123456789012:server-certificate/certname") + cert1.arn.should.equal("arn:aws:iam::123456789012:server-certificate/certname") @mock_iam_deprecated() @@ -100,8 +100,7 @@ def test_get_server_cert(): conn.upload_server_cert("certname", "certbody", "privatekey") cert = conn.get_server_certificate("certname") cert.server_certificate_name.should.equal("certname") - cert.arn.should.equal( - "arn:aws:iam::123456789012:server-certificate/certname") + cert.arn.should.equal("arn:aws:iam::123456789012:server-certificate/certname") @mock_iam_deprecated() @@ -111,8 +110,7 @@ def test_upload_server_cert(): conn.upload_server_cert("certname", "certbody", "privatekey") cert = conn.get_server_certificate("certname") cert.server_certificate_name.should.equal("certname") - cert.arn.should.equal( - "arn:aws:iam::123456789012:server-certificate/certname") + cert.arn.should.equal("arn:aws:iam::123456789012:server-certificate/certname") @mock_iam_deprecated() @@ -133,7 +131,7 @@ def test_delete_server_cert(): def test_get_role__should_throw__when_role_does_not_exist(): conn = boto.connect_iam() - conn.get_role('unexisting_role') + conn.get_role("unexisting_role") @mock_iam_deprecated() @@ -141,7 +139,7 @@ def test_get_role__should_throw__when_role_does_not_exist(): def test_get_instance_profile__should_throw__when_instance_profile_does_not_exist(): conn = boto.connect_iam() - conn.get_instance_profile('unexisting_instance_profile') + conn.get_instance_profile("unexisting_instance_profile") @mock_iam_deprecated() @@ -149,7 +147,8 @@ def test_create_role_and_instance_profile(): conn = boto.connect_iam() conn.create_instance_profile("my-profile", path="my-path") conn.create_role( - "my-role", assume_role_policy_document="some policy", path="my-path") + "my-role", assume_role_policy_document="some policy", path="my-path" + ) conn.add_role_to_instance_profile("my-profile", "my-role") @@ -160,26 +159,28 @@ def test_create_role_and_instance_profile(): profile = conn.get_instance_profile("my-profile") profile.path.should.equal("my-path") role_from_profile = list(profile.roles.values())[0] - role_from_profile['role_id'].should.equal(role.role_id) - role_from_profile['role_name'].should.equal("my-role") + role_from_profile["role_id"].should.equal(role.role_id) + role_from_profile["role_name"].should.equal("my-role") - conn.list_roles().roles[0].role_name.should.equal('my-role') + conn.list_roles().roles[0].role_name.should.equal("my-role") # Test with an empty path: - profile = conn.create_instance_profile('my-other-profile') - profile.path.should.equal('/') + profile = conn.create_instance_profile("my-other-profile") + profile.path.should.equal("/") + @mock_iam_deprecated() def test_remove_role_from_instance_profile(): conn = boto.connect_iam() conn.create_instance_profile("my-profile", path="my-path") conn.create_role( - "my-role", assume_role_policy_document="some policy", path="my-path") + "my-role", assume_role_policy_document="some policy", path="my-path" + ) conn.add_role_to_instance_profile("my-profile", "my-role") profile = conn.get_instance_profile("my-profile") role_from_profile = list(profile.roles.values())[0] - role_from_profile['role_name'].should.equal("my-role") + role_from_profile["role_name"].should.equal("my-role") conn.remove_role_from_instance_profile("my-profile", "my-role") @@ -189,49 +190,59 @@ def test_remove_role_from_instance_profile(): @mock_iam() def test_get_login_profile(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_user(UserName='my-user') - conn.create_login_profile(UserName='my-user', Password='my-pass') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_user(UserName="my-user") + conn.create_login_profile(UserName="my-user", Password="my-pass") - response = conn.get_login_profile(UserName='my-user') - response['LoginProfile']['UserName'].should.equal('my-user') + response = conn.get_login_profile(UserName="my-user") + response["LoginProfile"]["UserName"].should.equal("my-user") @mock_iam() def test_update_login_profile(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_user(UserName='my-user') - conn.create_login_profile(UserName='my-user', Password='my-pass') - response = conn.get_login_profile(UserName='my-user') - response['LoginProfile'].get('PasswordResetRequired').should.equal(None) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_user(UserName="my-user") + conn.create_login_profile(UserName="my-user", Password="my-pass") + response = conn.get_login_profile(UserName="my-user") + response["LoginProfile"].get("PasswordResetRequired").should.equal(None) - conn.update_login_profile(UserName='my-user', Password='new-pass', PasswordResetRequired=True) - response = conn.get_login_profile(UserName='my-user') - response['LoginProfile'].get('PasswordResetRequired').should.equal(True) + conn.update_login_profile( + UserName="my-user", Password="new-pass", PasswordResetRequired=True + ) + response = conn.get_login_profile(UserName="my-user") + response["LoginProfile"].get("PasswordResetRequired").should.equal(True) @mock_iam() def test_delete_role(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(conn.exceptions.NoSuchEntityException): conn.delete_role(RoleName="my-role") # Test deletion failure with a managed policy - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") - response = conn.create_policy(PolicyName="my-managed-policy", PolicyDocument=MOCK_POLICY) - conn.attach_role_policy(PolicyArn=response['Policy']['Arn'], RoleName="my-role") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) + response = conn.create_policy( + PolicyName="my-managed-policy", PolicyDocument=MOCK_POLICY + ) + conn.attach_role_policy(PolicyArn=response["Policy"]["Arn"], RoleName="my-role") with assert_raises(conn.exceptions.DeleteConflictException): conn.delete_role(RoleName="my-role") - conn.detach_role_policy(PolicyArn=response['Policy']['Arn'], RoleName="my-role") - conn.delete_policy(PolicyArn=response['Policy']['Arn']) + conn.detach_role_policy(PolicyArn=response["Policy"]["Arn"], RoleName="my-role") + conn.delete_policy(PolicyArn=response["Policy"]["Arn"]) conn.delete_role(RoleName="my-role") with assert_raises(conn.exceptions.NoSuchEntityException): conn.get_role(RoleName="my-role") # Test deletion failure with an inline policy - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") - conn.put_role_policy(RoleName="my-role", PolicyName="my-role-policy", PolicyDocument=MOCK_POLICY) + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) + conn.put_role_policy( + RoleName="my-role", PolicyName="my-role-policy", PolicyDocument=MOCK_POLICY + ) with assert_raises(conn.exceptions.DeleteConflictException): conn.delete_role(RoleName="my-role") conn.delete_role_policy(RoleName="my-role", PolicyName="my-role-policy") @@ -240,18 +251,26 @@ def test_delete_role(): conn.get_role(RoleName="my-role") # Test deletion failure with attachment to an instance profile - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) conn.create_instance_profile(InstanceProfileName="my-profile") - conn.add_role_to_instance_profile(InstanceProfileName="my-profile", RoleName="my-role") + conn.add_role_to_instance_profile( + InstanceProfileName="my-profile", RoleName="my-role" + ) with assert_raises(conn.exceptions.DeleteConflictException): conn.delete_role(RoleName="my-role") - conn.remove_role_from_instance_profile(InstanceProfileName="my-profile", RoleName="my-role") + conn.remove_role_from_instance_profile( + InstanceProfileName="my-profile", RoleName="my-role" + ) conn.delete_role(RoleName="my-role") with assert_raises(conn.exceptions.NoSuchEntityException): conn.get_role(RoleName="my-role") # Test deletion with no conflicts - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) conn.delete_role(RoleName="my-role") with assert_raises(conn.exceptions.NoSuchEntityException): conn.get_role(RoleName="my-role") @@ -276,37 +295,43 @@ def test_list_instance_profiles(): def test_list_instance_profiles_for_role(): conn = boto.connect_iam() - conn.create_role(role_name="my-role", - assume_role_policy_document="some policy", path="my-path") - conn.create_role(role_name="my-role2", - assume_role_policy_document="some policy2", path="my-path2") + conn.create_role( + role_name="my-role", assume_role_policy_document="some policy", path="my-path" + ) + conn.create_role( + role_name="my-role2", + assume_role_policy_document="some policy2", + path="my-path2", + ) - profile_name_list = ['my-profile', 'my-profile2'] - profile_path_list = ['my-path', 'my-path2'] + profile_name_list = ["my-profile", "my-profile2"] + profile_path_list = ["my-path", "my-path2"] for profile_count in range(0, 2): conn.create_instance_profile( - profile_name_list[profile_count], path=profile_path_list[profile_count]) + profile_name_list[profile_count], path=profile_path_list[profile_count] + ) for profile_count in range(0, 2): - conn.add_role_to_instance_profile( - profile_name_list[profile_count], "my-role") + conn.add_role_to_instance_profile(profile_name_list[profile_count], "my-role") profile_dump = conn.list_instance_profiles_for_role(role_name="my-role") - profile_list = profile_dump['list_instance_profiles_for_role_response'][ - 'list_instance_profiles_for_role_result']['instance_profiles'] + profile_list = profile_dump["list_instance_profiles_for_role_response"][ + "list_instance_profiles_for_role_result" + ]["instance_profiles"] for profile_count in range(0, len(profile_list)): - profile_name_list.remove(profile_list[profile_count][ - "instance_profile_name"]) + profile_name_list.remove(profile_list[profile_count]["instance_profile_name"]) profile_path_list.remove(profile_list[profile_count]["path"]) - profile_list[profile_count]["roles"]["member"][ - "role_name"].should.equal("my-role") + profile_list[profile_count]["roles"]["member"]["role_name"].should.equal( + "my-role" + ) len(profile_name_list).should.equal(0) len(profile_path_list).should.equal(0) profile_dump2 = conn.list_instance_profiles_for_role(role_name="my-role2") - profile_list = profile_dump2['list_instance_profiles_for_role_response'][ - 'list_instance_profiles_for_role_result']['instance_profiles'] + profile_list = profile_dump2["list_instance_profiles_for_role_response"][ + "list_instance_profiles_for_role_result" + ]["instance_profiles"] len(profile_list).should.equal(0) @@ -336,18 +361,21 @@ def test_list_role_policies(): def test_put_role_policy(): conn = boto.connect_iam() conn.create_role( - "my-role", assume_role_policy_document="some policy", path="my-path") + "my-role", assume_role_policy_document="some policy", path="my-path" + ) conn.put_role_policy("my-role", "test policy", MOCK_POLICY) - policy = conn.get_role_policy( - "my-role", "test policy")['get_role_policy_response']['get_role_policy_result']['policy_name'] + policy = conn.get_role_policy("my-role", "test policy")["get_role_policy_response"][ + "get_role_policy_result" + ]["policy_name"] policy.should.equal("test policy") @mock_iam def test_get_role_policy(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_role( - RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="my-path") + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="my-path" + ) with assert_raises(conn.exceptions.NoSuchEntityException): conn.get_role_policy(RoleName="my-role", PolicyName="does-not-exist") @@ -363,338 +391,361 @@ def test_update_assume_role_policy(): @mock_iam def test_create_policy(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") response = conn.create_policy( - PolicyName="TestCreatePolicy", - PolicyDocument=MOCK_POLICY) - response['Policy']['Arn'].should.equal("arn:aws:iam::123456789012:policy/TestCreatePolicy") + PolicyName="TestCreatePolicy", PolicyDocument=MOCK_POLICY + ) + response["Policy"]["Arn"].should.equal( + "arn:aws:iam::123456789012:policy/TestCreatePolicy" + ) @mock_iam def test_delete_policy(): - conn = boto3.client('iam', region_name='us-east-1') - response = conn.create_policy(PolicyName="TestCreatePolicy", PolicyDocument=MOCK_POLICY) - [pol['PolicyName'] for pol in conn.list_policies(Scope='Local')['Policies']].should.equal(['TestCreatePolicy']) - conn.delete_policy(PolicyArn=response['Policy']['Arn']) - assert conn.list_policies(Scope='Local')['Policies'].should.be.empty + conn = boto3.client("iam", region_name="us-east-1") + response = conn.create_policy( + PolicyName="TestCreatePolicy", PolicyDocument=MOCK_POLICY + ) + [ + pol["PolicyName"] for pol in conn.list_policies(Scope="Local")["Policies"] + ].should.equal(["TestCreatePolicy"]) + conn.delete_policy(PolicyArn=response["Policy"]["Arn"]) + assert conn.list_policies(Scope="Local")["Policies"].should.be.empty @mock_iam def test_create_policy_versions(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion", - PolicyDocument='{"some":"policy"}') - conn.create_policy( - PolicyName="TestCreatePolicyVersion", - PolicyDocument=MOCK_POLICY) + PolicyDocument='{"some":"policy"}', + ) + conn.create_policy(PolicyName="TestCreatePolicyVersion", PolicyDocument=MOCK_POLICY) version = conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion", PolicyDocument=MOCK_POLICY, - SetAsDefault=True) - version.get('PolicyVersion').get('Document').should.equal(json.loads(MOCK_POLICY)) - version.get('PolicyVersion').get('VersionId').should.equal("v2") - version.get('PolicyVersion').get('IsDefaultVersion').should.be.ok + SetAsDefault=True, + ) + version.get("PolicyVersion").get("Document").should.equal(json.loads(MOCK_POLICY)) + version.get("PolicyVersion").get("VersionId").should.equal("v2") + version.get("PolicyVersion").get("IsDefaultVersion").should.be.ok conn.delete_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion", - VersionId="v1") + VersionId="v1", + ) version = conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestCreatePolicyVersion", - PolicyDocument=MOCK_POLICY) - version.get('PolicyVersion').get('VersionId').should.equal("v3") - version.get('PolicyVersion').get('IsDefaultVersion').shouldnt.be.ok + PolicyDocument=MOCK_POLICY, + ) + version.get("PolicyVersion").get("VersionId").should.equal("v3") + version.get("PolicyVersion").get("IsDefaultVersion").shouldnt.be.ok @mock_iam def test_create_many_policy_versions(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_policy( - PolicyName="TestCreateManyPolicyVersions", - PolicyDocument=MOCK_POLICY) + PolicyName="TestCreateManyPolicyVersions", PolicyDocument=MOCK_POLICY + ) for _ in range(0, 4): conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestCreateManyPolicyVersions", - PolicyDocument=MOCK_POLICY) + PolicyDocument=MOCK_POLICY, + ) with assert_raises(ClientError): conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestCreateManyPolicyVersions", - PolicyDocument=MOCK_POLICY) + PolicyDocument=MOCK_POLICY, + ) @mock_iam def test_set_default_policy_version(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_policy( - PolicyName="TestSetDefaultPolicyVersion", - PolicyDocument=MOCK_POLICY) + PolicyName="TestSetDefaultPolicyVersion", PolicyDocument=MOCK_POLICY + ) conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestSetDefaultPolicyVersion", PolicyDocument=MOCK_POLICY_2, - SetAsDefault=True) + SetAsDefault=True, + ) conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestSetDefaultPolicyVersion", PolicyDocument=MOCK_POLICY_3, - SetAsDefault=True) + SetAsDefault=True, + ) versions = conn.list_policy_versions( - PolicyArn="arn:aws:iam::123456789012:policy/TestSetDefaultPolicyVersion") - versions.get('Versions')[0].get('Document').should.equal(json.loads(MOCK_POLICY)) - versions.get('Versions')[0].get('IsDefaultVersion').shouldnt.be.ok - versions.get('Versions')[1].get('Document').should.equal(json.loads(MOCK_POLICY_2)) - versions.get('Versions')[1].get('IsDefaultVersion').shouldnt.be.ok - versions.get('Versions')[2].get('Document').should.equal(json.loads(MOCK_POLICY_3)) - versions.get('Versions')[2].get('IsDefaultVersion').should.be.ok + PolicyArn="arn:aws:iam::123456789012:policy/TestSetDefaultPolicyVersion" + ) + versions.get("Versions")[0].get("Document").should.equal(json.loads(MOCK_POLICY)) + versions.get("Versions")[0].get("IsDefaultVersion").shouldnt.be.ok + versions.get("Versions")[1].get("Document").should.equal(json.loads(MOCK_POLICY_2)) + versions.get("Versions")[1].get("IsDefaultVersion").shouldnt.be.ok + versions.get("Versions")[2].get("Document").should.equal(json.loads(MOCK_POLICY_3)) + versions.get("Versions")[2].get("IsDefaultVersion").should.be.ok @mock_iam def test_get_policy(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") response = conn.create_policy( - PolicyName="TestGetPolicy", - PolicyDocument=MOCK_POLICY) - policy = conn.get_policy( - PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicy") - policy['Policy']['Arn'].should.equal("arn:aws:iam::123456789012:policy/TestGetPolicy") + PolicyName="TestGetPolicy", PolicyDocument=MOCK_POLICY + ) + policy = conn.get_policy(PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicy") + policy["Policy"]["Arn"].should.equal( + "arn:aws:iam::123456789012:policy/TestGetPolicy" + ) @mock_iam def test_get_aws_managed_policy(): - conn = boto3.client('iam', region_name='us-east-1') - managed_policy_arn = 'arn:aws:iam::aws:policy/IAMUserChangePassword' - managed_policy_create_date = datetime.strptime("2016-11-15T00:25:16+00:00", "%Y-%m-%dT%H:%M:%S+00:00") - policy = conn.get_policy( - PolicyArn=managed_policy_arn) - policy['Policy']['Arn'].should.equal(managed_policy_arn) - policy['Policy']['CreateDate'].replace(tzinfo=None).should.equal(managed_policy_create_date) + conn = boto3.client("iam", region_name="us-east-1") + managed_policy_arn = "arn:aws:iam::aws:policy/IAMUserChangePassword" + managed_policy_create_date = datetime.strptime( + "2016-11-15T00:25:16+00:00", "%Y-%m-%dT%H:%M:%S+00:00" + ) + policy = conn.get_policy(PolicyArn=managed_policy_arn) + policy["Policy"]["Arn"].should.equal(managed_policy_arn) + policy["Policy"]["CreateDate"].replace(tzinfo=None).should.equal( + managed_policy_create_date + ) @mock_iam def test_get_policy_version(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_policy( - PolicyName="TestGetPolicyVersion", - PolicyDocument=MOCK_POLICY) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_policy(PolicyName="TestGetPolicyVersion", PolicyDocument=MOCK_POLICY) version = conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicyVersion", - PolicyDocument=MOCK_POLICY) + PolicyDocument=MOCK_POLICY, + ) with assert_raises(ClientError): conn.get_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicyVersion", - VersionId='v2-does-not-exist') + VersionId="v2-does-not-exist", + ) retrieved = conn.get_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestGetPolicyVersion", - VersionId=version.get('PolicyVersion').get('VersionId')) - retrieved.get('PolicyVersion').get('Document').should.equal(json.loads(MOCK_POLICY)) - retrieved.get('PolicyVersion').get('IsDefaultVersion').shouldnt.be.ok + VersionId=version.get("PolicyVersion").get("VersionId"), + ) + retrieved.get("PolicyVersion").get("Document").should.equal(json.loads(MOCK_POLICY)) + retrieved.get("PolicyVersion").get("IsDefaultVersion").shouldnt.be.ok @mock_iam def test_get_aws_managed_policy_version(): - conn = boto3.client('iam', region_name='us-east-1') - managed_policy_arn = 'arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole' - managed_policy_version_create_date = datetime.strptime("2015-04-09T15:03:43+00:00", "%Y-%m-%dT%H:%M:%S+00:00") + conn = boto3.client("iam", region_name="us-east-1") + managed_policy_arn = ( + "arn:aws:iam::aws:policy/service-role/AWSLambdaBasicExecutionRole" + ) + managed_policy_version_create_date = datetime.strptime( + "2015-04-09T15:03:43+00:00", "%Y-%m-%dT%H:%M:%S+00:00" + ) with assert_raises(ClientError): conn.get_policy_version( - PolicyArn=managed_policy_arn, - VersionId='v2-does-not-exist') - retrieved = conn.get_policy_version( - PolicyArn=managed_policy_arn, - VersionId="v1") - retrieved['PolicyVersion']['CreateDate'].replace(tzinfo=None).should.equal(managed_policy_version_create_date) - retrieved['PolicyVersion']['Document'].should.be.an(dict) + PolicyArn=managed_policy_arn, VersionId="v2-does-not-exist" + ) + retrieved = conn.get_policy_version(PolicyArn=managed_policy_arn, VersionId="v1") + retrieved["PolicyVersion"]["CreateDate"].replace(tzinfo=None).should.equal( + managed_policy_version_create_date + ) + retrieved["PolicyVersion"]["Document"].should.be.an(dict) @mock_iam def test_get_aws_managed_policy_v4_version(): - conn = boto3.client('iam', region_name='us-east-1') - managed_policy_arn = 'arn:aws:iam::aws:policy/job-function/SystemAdministrator' - managed_policy_version_create_date = datetime.strptime("2018-10-08T21:33:45+00:00", "%Y-%m-%dT%H:%M:%S+00:00") + conn = boto3.client("iam", region_name="us-east-1") + managed_policy_arn = "arn:aws:iam::aws:policy/job-function/SystemAdministrator" + managed_policy_version_create_date = datetime.strptime( + "2018-10-08T21:33:45+00:00", "%Y-%m-%dT%H:%M:%S+00:00" + ) with assert_raises(ClientError): conn.get_policy_version( - PolicyArn=managed_policy_arn, - VersionId='v2-does-not-exist') - retrieved = conn.get_policy_version( - PolicyArn=managed_policy_arn, - VersionId="v4") - retrieved['PolicyVersion']['CreateDate'].replace(tzinfo=None).should.equal(managed_policy_version_create_date) - retrieved['PolicyVersion']['Document'].should.be.an(dict) + PolicyArn=managed_policy_arn, VersionId="v2-does-not-exist" + ) + retrieved = conn.get_policy_version(PolicyArn=managed_policy_arn, VersionId="v4") + retrieved["PolicyVersion"]["CreateDate"].replace(tzinfo=None).should.equal( + managed_policy_version_create_date + ) + retrieved["PolicyVersion"]["Document"].should.be.an(dict) @mock_iam def test_list_policy_versions(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): versions = conn.list_policy_versions( - PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions") - conn.create_policy( - PolicyName="TestListPolicyVersions", - PolicyDocument=MOCK_POLICY) + PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions" + ) + conn.create_policy(PolicyName="TestListPolicyVersions", PolicyDocument=MOCK_POLICY) versions = conn.list_policy_versions( - PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions") - versions.get('Versions')[0].get('VersionId').should.equal('v1') - versions.get('Versions')[0].get('IsDefaultVersion').should.be.ok + PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions" + ) + versions.get("Versions")[0].get("VersionId").should.equal("v1") + versions.get("Versions")[0].get("IsDefaultVersion").should.be.ok conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions", - PolicyDocument=MOCK_POLICY_2) + PolicyDocument=MOCK_POLICY_2, + ) conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions", - PolicyDocument=MOCK_POLICY_3) + PolicyDocument=MOCK_POLICY_3, + ) versions = conn.list_policy_versions( - PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions") - versions.get('Versions')[1].get('Document').should.equal(json.loads(MOCK_POLICY_2)) - versions.get('Versions')[1].get('IsDefaultVersion').shouldnt.be.ok - versions.get('Versions')[2].get('Document').should.equal(json.loads(MOCK_POLICY_3)) - versions.get('Versions')[2].get('IsDefaultVersion').shouldnt.be.ok + PolicyArn="arn:aws:iam::123456789012:policy/TestListPolicyVersions" + ) + versions.get("Versions")[1].get("Document").should.equal(json.loads(MOCK_POLICY_2)) + versions.get("Versions")[1].get("IsDefaultVersion").shouldnt.be.ok + versions.get("Versions")[2].get("Document").should.equal(json.loads(MOCK_POLICY_3)) + versions.get("Versions")[2].get("IsDefaultVersion").shouldnt.be.ok @mock_iam def test_delete_policy_version(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_policy( - PolicyName="TestDeletePolicyVersion", - PolicyDocument=MOCK_POLICY) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_policy(PolicyName="TestDeletePolicyVersion", PolicyDocument=MOCK_POLICY) conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", - PolicyDocument=MOCK_POLICY) + PolicyDocument=MOCK_POLICY, + ) with assert_raises(ClientError): conn.delete_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", - VersionId='v2-nope-this-does-not-exist') + VersionId="v2-nope-this-does-not-exist", + ) conn.delete_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", - VersionId='v2') + VersionId="v2", + ) versions = conn.list_policy_versions( - PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion") - len(versions.get('Versions')).should.equal(1) + PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion" + ) + len(versions.get("Versions")).should.equal(1) @mock_iam def test_delete_default_policy_version(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_policy( - PolicyName="TestDeletePolicyVersion", - PolicyDocument=MOCK_POLICY) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_policy(PolicyName="TestDeletePolicyVersion", PolicyDocument=MOCK_POLICY) conn.create_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", - PolicyDocument=MOCK_POLICY_2) + PolicyDocument=MOCK_POLICY_2, + ) with assert_raises(ClientError): conn.delete_policy_version( PolicyArn="arn:aws:iam::123456789012:policy/TestDeletePolicyVersion", - VersionId='v1') + VersionId="v1", + ) @mock_iam_deprecated() def test_create_user(): conn = boto.connect_iam() - conn.create_user('my-user') + conn.create_user("my-user") with assert_raises(BotoServerError): - conn.create_user('my-user') + conn.create_user("my-user") @mock_iam_deprecated() def test_get_user(): conn = boto.connect_iam() with assert_raises(BotoServerError): - conn.get_user('my-user') - conn.create_user('my-user') - conn.get_user('my-user') + conn.get_user("my-user") + conn.create_user("my-user") + conn.get_user("my-user") @mock_iam() def test_update_user(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(conn.exceptions.NoSuchEntityException): - conn.update_user(UserName='my-user') - conn.create_user(UserName='my-user') - conn.update_user(UserName='my-user', NewPath='/new-path/', NewUserName='new-user') - response = conn.get_user(UserName='new-user') - response['User'].get('Path').should.equal('/new-path/') + conn.update_user(UserName="my-user") + conn.create_user(UserName="my-user") + conn.update_user(UserName="my-user", NewPath="/new-path/", NewUserName="new-user") + response = conn.get_user(UserName="new-user") + response["User"].get("Path").should.equal("/new-path/") with assert_raises(conn.exceptions.NoSuchEntityException): - conn.get_user(UserName='my-user') + conn.get_user(UserName="my-user") @mock_iam_deprecated() def test_get_current_user(): """If no user is specific, IAM returns the current user""" conn = boto.connect_iam() - user = conn.get_user()['get_user_response']['get_user_result']['user'] - user['user_name'].should.equal('default_user') + user = conn.get_user()["get_user_response"]["get_user_result"]["user"] + user["user_name"].should.equal("default_user") @mock_iam() def test_list_users(): - path_prefix = '/' + path_prefix = "/" max_items = 10 - conn = boto3.client('iam', region_name='us-east-1') - conn.create_user(UserName='my-user') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_user(UserName="my-user") response = conn.list_users(PathPrefix=path_prefix, MaxItems=max_items) - user = response['Users'][0] - user['UserName'].should.equal('my-user') - user['Path'].should.equal('/') - user['Arn'].should.equal('arn:aws:iam::123456789012:user/my-user') + user = response["Users"][0] + user["UserName"].should.equal("my-user") + user["Path"].should.equal("/") + user["Arn"].should.equal("arn:aws:iam::123456789012:user/my-user") @mock_iam() def test_user_policies(): - policy_name = 'UserManagedPolicy' - user_name = 'my-user' - conn = boto3.client('iam', region_name='us-east-1') + policy_name = "UserManagedPolicy" + user_name = "my-user" + conn = boto3.client("iam", region_name="us-east-1") conn.create_user(UserName=user_name) conn.put_user_policy( - UserName=user_name, - PolicyName=policy_name, - PolicyDocument=MOCK_POLICY + UserName=user_name, PolicyName=policy_name, PolicyDocument=MOCK_POLICY ) - policy_doc = conn.get_user_policy( - UserName=user_name, - PolicyName=policy_name - ) - policy_doc['PolicyDocument'].should.equal(json.loads(MOCK_POLICY)) + policy_doc = conn.get_user_policy(UserName=user_name, PolicyName=policy_name) + policy_doc["PolicyDocument"].should.equal(json.loads(MOCK_POLICY)) policies = conn.list_user_policies(UserName=user_name) - len(policies['PolicyNames']).should.equal(1) - policies['PolicyNames'][0].should.equal(policy_name) + len(policies["PolicyNames"]).should.equal(1) + policies["PolicyNames"][0].should.equal(policy_name) - conn.delete_user_policy( - UserName=user_name, - PolicyName=policy_name - ) + conn.delete_user_policy(UserName=user_name, PolicyName=policy_name) policies = conn.list_user_policies(UserName=user_name) - len(policies['PolicyNames']).should.equal(0) + len(policies["PolicyNames"]).should.equal(0) @mock_iam_deprecated() def test_create_login_profile(): conn = boto.connect_iam() with assert_raises(BotoServerError): - conn.create_login_profile('my-user', 'my-pass') - conn.create_user('my-user') - conn.create_login_profile('my-user', 'my-pass') + conn.create_login_profile("my-user", "my-pass") + conn.create_user("my-user") + conn.create_login_profile("my-user", "my-pass") with assert_raises(BotoServerError): - conn.create_login_profile('my-user', 'my-pass') + conn.create_login_profile("my-user", "my-pass") @mock_iam_deprecated() def test_delete_login_profile(): conn = boto.connect_iam() - conn.create_user('my-user') + conn.create_user("my-user") with assert_raises(BotoServerError): - conn.delete_login_profile('my-user') - conn.create_login_profile('my-user', 'my-pass') - conn.delete_login_profile('my-user') + conn.delete_login_profile("my-user") + conn.create_login_profile("my-user", "my-pass") + conn.delete_login_profile("my-user") @mock_iam() def test_create_access_key(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): - conn.create_access_key(UserName='my-user') - conn.create_user(UserName='my-user') - access_key = conn.create_access_key(UserName='my-user')["AccessKey"] - (datetime.utcnow() - access_key["CreateDate"].replace(tzinfo=None)).seconds.should.be.within(0, 10) + conn.create_access_key(UserName="my-user") + conn.create_user(UserName="my-user") + access_key = conn.create_access_key(UserName="my-user")["AccessKey"] + ( + datetime.utcnow() - access_key["CreateDate"].replace(tzinfo=None) + ).seconds.should.be.within(0, 10) access_key["AccessKeyId"].should.have.length_of(20) access_key["SecretAccessKey"].should.have.length_of(40) assert access_key["AccessKeyId"].startswith("AKIA") @@ -705,1428 +756,1442 @@ def test_get_all_access_keys(): """If no access keys exist there should be none in the response, if an access key is present it should have the correct fields present""" conn = boto.connect_iam() - conn.create_user('my-user') - response = conn.get_all_access_keys('my-user') + conn.create_user("my-user") + response = conn.get_all_access_keys("my-user") assert_equals( - response['list_access_keys_response'][ - 'list_access_keys_result']['access_key_metadata'], - [] + response["list_access_keys_response"]["list_access_keys_result"][ + "access_key_metadata" + ], + [], ) - conn.create_access_key('my-user') - response = conn.get_all_access_keys('my-user') + conn.create_access_key("my-user") + response = conn.get_all_access_keys("my-user") assert_equals( - sorted(response['list_access_keys_response'][ - 'list_access_keys_result']['access_key_metadata'][0].keys()), - sorted(['status', 'create_date', 'user_name', 'access_key_id']) + sorted( + response["list_access_keys_response"]["list_access_keys_result"][ + "access_key_metadata" + ][0].keys() + ), + sorted(["status", "create_date", "user_name", "access_key_id"]), ) @mock_iam_deprecated() def test_delete_access_key(): conn = boto.connect_iam() - conn.create_user('my-user') - access_key_id = conn.create_access_key('my-user')['create_access_key_response'][ - 'create_access_key_result']['access_key']['access_key_id'] - conn.delete_access_key(access_key_id, 'my-user') + conn.create_user("my-user") + access_key_id = conn.create_access_key("my-user")["create_access_key_response"][ + "create_access_key_result" + ]["access_key"]["access_key_id"] + conn.delete_access_key(access_key_id, "my-user") @mock_iam() def test_mfa_devices(): # Test enable device - conn = boto3.client('iam', region_name='us-east-1') - conn.create_user(UserName='my-user') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_user(UserName="my-user") conn.enable_mfa_device( - UserName='my-user', - SerialNumber='123456789', - AuthenticationCode1='234567', - AuthenticationCode2='987654' + UserName="my-user", + SerialNumber="123456789", + AuthenticationCode1="234567", + AuthenticationCode2="987654", ) # Test list mfa devices - response = conn.list_mfa_devices(UserName='my-user') - device = response['MFADevices'][0] - device['SerialNumber'].should.equal('123456789') + response = conn.list_mfa_devices(UserName="my-user") + device = response["MFADevices"][0] + device["SerialNumber"].should.equal("123456789") # Test deactivate mfa device - conn.deactivate_mfa_device(UserName='my-user', SerialNumber='123456789') - response = conn.list_mfa_devices(UserName='my-user') - len(response['MFADevices']).should.equal(0) + conn.deactivate_mfa_device(UserName="my-user", SerialNumber="123456789") + response = conn.list_mfa_devices(UserName="my-user") + len(response["MFADevices"]).should.equal(0) @mock_iam def test_create_virtual_mfa_device(): - client = boto3.client('iam', region_name='us-east-1') - response = client.create_virtual_mfa_device( - VirtualMFADeviceName='test-device' - ) - device = response['VirtualMFADevice'] + client = boto3.client("iam", region_name="us-east-1") + response = client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + device = response["VirtualMFADevice"] - device['SerialNumber'].should.equal('arn:aws:iam::123456789012:mfa/test-device') - device['Base32StringSeed'].decode('ascii').should.match('[A-Z234567]') - device['QRCodePNG'].should_not.be.empty + device["SerialNumber"].should.equal("arn:aws:iam::123456789012:mfa/test-device") + device["Base32StringSeed"].decode("ascii").should.match("[A-Z234567]") + device["QRCodePNG"].should_not.be.empty response = client.create_virtual_mfa_device( - Path='/', - VirtualMFADeviceName='test-device-2' + Path="/", VirtualMFADeviceName="test-device-2" ) - device = response['VirtualMFADevice'] + device = response["VirtualMFADevice"] - device['SerialNumber'].should.equal('arn:aws:iam::123456789012:mfa/test-device-2') - device['Base32StringSeed'].decode('ascii').should.match('[A-Z234567]') - device['QRCodePNG'].should_not.be.empty + device["SerialNumber"].should.equal("arn:aws:iam::123456789012:mfa/test-device-2") + device["Base32StringSeed"].decode("ascii").should.match("[A-Z234567]") + device["QRCodePNG"].should_not.be.empty response = client.create_virtual_mfa_device( - Path='/test/', - VirtualMFADeviceName='test-device' + Path="/test/", VirtualMFADeviceName="test-device" ) - device = response['VirtualMFADevice'] + device = response["VirtualMFADevice"] - device['SerialNumber'].should.equal('arn:aws:iam::123456789012:mfa/test/test-device') - device['Base32StringSeed'].decode('ascii').should.match('[A-Z234567]') - device['QRCodePNG'].should_not.be.empty + device["SerialNumber"].should.equal( + "arn:aws:iam::123456789012:mfa/test/test-device" + ) + device["Base32StringSeed"].decode("ascii").should.match("[A-Z234567]") + device["QRCodePNG"].should_not.be.empty @mock_iam def test_create_virtual_mfa_device_errors(): - client = boto3.client('iam', region_name='us-east-1') - client.create_virtual_mfa_device( - VirtualMFADeviceName='test-device' + client = boto3.client("iam", region_name="us-east-1") + client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + + client.create_virtual_mfa_device.when.called_with( + VirtualMFADeviceName="test-device" + ).should.throw( + ClientError, "MFADevice entity at the same path and name already exists." ) client.create_virtual_mfa_device.when.called_with( - VirtualMFADeviceName='test-device' + Path="test", VirtualMFADeviceName="test-device" ).should.throw( ClientError, - 'MFADevice entity at the same path and name already exists.' + "The specified value for path is invalid. " + "It must begin and end with / and contain only alphanumeric characters and/or / characters.", ) client.create_virtual_mfa_device.when.called_with( - Path='test', - VirtualMFADeviceName='test-device' + Path="/test//test/", VirtualMFADeviceName="test-device" ).should.throw( ClientError, - 'The specified value for path is invalid. ' - 'It must begin and end with / and contain only alphanumeric characters and/or / characters.' + "The specified value for path is invalid. " + "It must begin and end with / and contain only alphanumeric characters and/or / characters.", ) + too_long_path = "/{}/".format("b" * 511) client.create_virtual_mfa_device.when.called_with( - Path='/test//test/', - VirtualMFADeviceName='test-device' + Path=too_long_path, VirtualMFADeviceName="test-device" ).should.throw( ClientError, - 'The specified value for path is invalid. ' - 'It must begin and end with / and contain only alphanumeric characters and/or / characters.' - ) - - too_long_path = '/{}/'.format('b' * 511) - client.create_virtual_mfa_device.when.called_with( - Path=too_long_path, - VirtualMFADeviceName='test-device' - ).should.throw( - ClientError, - '1 validation error detected: ' + "1 validation error detected: " 'Value "{}" at "path" failed to satisfy constraint: ' - 'Member must have length less than or equal to 512' + "Member must have length less than or equal to 512", ) @mock_iam def test_delete_virtual_mfa_device(): - client = boto3.client('iam', region_name='us-east-1') - response = client.create_virtual_mfa_device( - VirtualMFADeviceName='test-device' - ) - serial_number = response['VirtualMFADevice']['SerialNumber'] + client = boto3.client("iam", region_name="us-east-1") + response = client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + serial_number = response["VirtualMFADevice"]["SerialNumber"] - client.delete_virtual_mfa_device( - SerialNumber=serial_number - ) + client.delete_virtual_mfa_device(SerialNumber=serial_number) response = client.list_virtual_mfa_devices() - response['VirtualMFADevices'].should.have.length_of(0) - response['IsTruncated'].should_not.be.ok + response["VirtualMFADevices"].should.have.length_of(0) + response["IsTruncated"].should_not.be.ok @mock_iam def test_delete_virtual_mfa_device_errors(): - client = boto3.client('iam', region_name='us-east-1') + client = boto3.client("iam", region_name="us-east-1") - serial_number = 'arn:aws:iam::123456789012:mfa/not-existing' + serial_number = "arn:aws:iam::123456789012:mfa/not-existing" client.delete_virtual_mfa_device.when.called_with( SerialNumber=serial_number ).should.throw( ClientError, - 'VirtualMFADevice with serial number {0} doesn\'t exist.'.format(serial_number) + "VirtualMFADevice with serial number {0} doesn't exist.".format(serial_number), ) @mock_iam def test_list_virtual_mfa_devices(): - client = boto3.client('iam', region_name='us-east-1') - response = client.create_virtual_mfa_device( - VirtualMFADeviceName='test-device' - ) - serial_number_1 = response['VirtualMFADevice']['SerialNumber'] + client = boto3.client("iam", region_name="us-east-1") + response = client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + serial_number_1 = response["VirtualMFADevice"]["SerialNumber"] response = client.create_virtual_mfa_device( - Path='/test/', - VirtualMFADeviceName='test-device' + Path="/test/", VirtualMFADeviceName="test-device" ) - serial_number_2 = response['VirtualMFADevice']['SerialNumber'] + serial_number_2 = response["VirtualMFADevice"]["SerialNumber"] response = client.list_virtual_mfa_devices() - response['VirtualMFADevices'].should.equal([ - { - 'SerialNumber': serial_number_1 - }, - { - 'SerialNumber': serial_number_2 - } - ]) - response['IsTruncated'].should_not.be.ok + response["VirtualMFADevices"].should.equal( + [{"SerialNumber": serial_number_1}, {"SerialNumber": serial_number_2}] + ) + response["IsTruncated"].should_not.be.ok + + response = client.list_virtual_mfa_devices(AssignmentStatus="Assigned") + + response["VirtualMFADevices"].should.have.length_of(0) + response["IsTruncated"].should_not.be.ok + + response = client.list_virtual_mfa_devices(AssignmentStatus="Unassigned") + + response["VirtualMFADevices"].should.equal( + [{"SerialNumber": serial_number_1}, {"SerialNumber": serial_number_2}] + ) + response["IsTruncated"].should_not.be.ok + + response = client.list_virtual_mfa_devices(AssignmentStatus="Any", MaxItems=1) + + response["VirtualMFADevices"].should.equal([{"SerialNumber": serial_number_1}]) + response["IsTruncated"].should.be.ok + response["Marker"].should.equal("1") response = client.list_virtual_mfa_devices( - AssignmentStatus='Assigned' + AssignmentStatus="Any", Marker=response["Marker"] ) - response['VirtualMFADevices'].should.have.length_of(0) - response['IsTruncated'].should_not.be.ok - - response = client.list_virtual_mfa_devices( - AssignmentStatus='Unassigned' - ) - - response['VirtualMFADevices'].should.equal([ - { - 'SerialNumber': serial_number_1 - }, - { - 'SerialNumber': serial_number_2 - } - ]) - response['IsTruncated'].should_not.be.ok - - response = client.list_virtual_mfa_devices( - AssignmentStatus='Any', - MaxItems=1 - ) - - response['VirtualMFADevices'].should.equal([ - { - 'SerialNumber': serial_number_1 - } - ]) - response['IsTruncated'].should.be.ok - response['Marker'].should.equal('1') - - response = client.list_virtual_mfa_devices( - AssignmentStatus='Any', - Marker=response['Marker'] - ) - - response['VirtualMFADevices'].should.equal([ - { - 'SerialNumber': serial_number_2 - } - ]) - response['IsTruncated'].should_not.be.ok + response["VirtualMFADevices"].should.equal([{"SerialNumber": serial_number_2}]) + response["IsTruncated"].should_not.be.ok @mock_iam def test_list_virtual_mfa_devices_errors(): - client = boto3.client('iam', region_name='us-east-1') - client.create_virtual_mfa_device( - VirtualMFADeviceName='test-device' - ) + client = boto3.client("iam", region_name="us-east-1") + client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") - client.list_virtual_mfa_devices.when.called_with( - Marker='100' - ).should.throw( - ClientError, - 'Invalid Marker.' + client.list_virtual_mfa_devices.when.called_with(Marker="100").should.throw( + ClientError, "Invalid Marker." ) @mock_iam def test_enable_virtual_mfa_device(): - client = boto3.client('iam', region_name='us-east-1') - response = client.create_virtual_mfa_device( - VirtualMFADeviceName='test-device' - ) - serial_number = response['VirtualMFADevice']['SerialNumber'] + client = boto3.client("iam", region_name="us-east-1") + response = client.create_virtual_mfa_device(VirtualMFADeviceName="test-device") + serial_number = response["VirtualMFADevice"]["SerialNumber"] - client.create_user(UserName='test-user') + client.create_user(UserName="test-user") client.enable_mfa_device( - UserName='test-user', + UserName="test-user", SerialNumber=serial_number, - AuthenticationCode1='234567', - AuthenticationCode2='987654' + AuthenticationCode1="234567", + AuthenticationCode2="987654", ) - response = client.list_virtual_mfa_devices( - AssignmentStatus='Unassigned' - ) + response = client.list_virtual_mfa_devices(AssignmentStatus="Unassigned") - response['VirtualMFADevices'].should.have.length_of(0) - response['IsTruncated'].should_not.be.ok + response["VirtualMFADevices"].should.have.length_of(0) + response["IsTruncated"].should_not.be.ok - response = client.list_virtual_mfa_devices( - AssignmentStatus='Assigned' - ) + response = client.list_virtual_mfa_devices(AssignmentStatus="Assigned") - device = response['VirtualMFADevices'][0] - device['SerialNumber'].should.equal(serial_number) - device['User']['Path'].should.equal('/') - device['User']['UserName'].should.equal('test-user') - device['User']['UserId'].should_not.be.empty - device['User']['Arn'].should.equal('arn:aws:iam::123456789012:user/test-user') - device['User']['CreateDate'].should.be.a(datetime) - device['EnableDate'].should.be.a(datetime) - response['IsTruncated'].should_not.be.ok + device = response["VirtualMFADevices"][0] + device["SerialNumber"].should.equal(serial_number) + device["User"]["Path"].should.equal("/") + device["User"]["UserName"].should.equal("test-user") + device["User"]["UserId"].should_not.be.empty + device["User"]["Arn"].should.equal("arn:aws:iam::123456789012:user/test-user") + device["User"]["CreateDate"].should.be.a(datetime) + device["EnableDate"].should.be.a(datetime) + response["IsTruncated"].should_not.be.ok - client.deactivate_mfa_device( - UserName='test-user', - SerialNumber=serial_number - ) + client.deactivate_mfa_device(UserName="test-user", SerialNumber=serial_number) - response = client.list_virtual_mfa_devices( - AssignmentStatus='Assigned' - ) + response = client.list_virtual_mfa_devices(AssignmentStatus="Assigned") - response['VirtualMFADevices'].should.have.length_of(0) - response['IsTruncated'].should_not.be.ok + response["VirtualMFADevices"].should.have.length_of(0) + response["IsTruncated"].should_not.be.ok - response = client.list_virtual_mfa_devices( - AssignmentStatus = 'Unassigned' - ) + response = client.list_virtual_mfa_devices(AssignmentStatus="Unassigned") - response['VirtualMFADevices'].should.equal([ - { - 'SerialNumber': serial_number - } - ]) - response['IsTruncated'].should_not.be.ok + response["VirtualMFADevices"].should.equal([{"SerialNumber": serial_number}]) + response["IsTruncated"].should_not.be.ok @mock_iam_deprecated() def test_delete_user_deprecated(): conn = boto.connect_iam() with assert_raises(BotoServerError): - conn.delete_user('my-user') - conn.create_user('my-user') - conn.delete_user('my-user') + conn.delete_user("my-user") + conn.create_user("my-user") + conn.delete_user("my-user") @mock_iam() def test_delete_user(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(conn.exceptions.NoSuchEntityException): - conn.delete_user(UserName='my-user') + conn.delete_user(UserName="my-user") # Test deletion failure with a managed policy - conn.create_user(UserName='my-user') - response = conn.create_policy(PolicyName="my-managed-policy", PolicyDocument=MOCK_POLICY) - conn.attach_user_policy(PolicyArn=response['Policy']['Arn'], UserName="my-user") + conn.create_user(UserName="my-user") + response = conn.create_policy( + PolicyName="my-managed-policy", PolicyDocument=MOCK_POLICY + ) + conn.attach_user_policy(PolicyArn=response["Policy"]["Arn"], UserName="my-user") with assert_raises(conn.exceptions.DeleteConflictException): - conn.delete_user(UserName='my-user') - conn.detach_user_policy(PolicyArn=response['Policy']['Arn'], UserName="my-user") - conn.delete_policy(PolicyArn=response['Policy']['Arn']) - conn.delete_user(UserName='my-user') + conn.delete_user(UserName="my-user") + conn.detach_user_policy(PolicyArn=response["Policy"]["Arn"], UserName="my-user") + conn.delete_policy(PolicyArn=response["Policy"]["Arn"]) + conn.delete_user(UserName="my-user") with assert_raises(conn.exceptions.NoSuchEntityException): - conn.get_user(UserName='my-user') + conn.get_user(UserName="my-user") # Test deletion failure with an inline policy - conn.create_user(UserName='my-user') + conn.create_user(UserName="my-user") conn.put_user_policy( - UserName='my-user', - PolicyName='my-user-policy', - PolicyDocument=MOCK_POLICY + UserName="my-user", PolicyName="my-user-policy", PolicyDocument=MOCK_POLICY ) with assert_raises(conn.exceptions.DeleteConflictException): - conn.delete_user(UserName='my-user') - conn.delete_user_policy(UserName='my-user', PolicyName='my-user-policy') - conn.delete_user(UserName='my-user') + conn.delete_user(UserName="my-user") + conn.delete_user_policy(UserName="my-user", PolicyName="my-user-policy") + conn.delete_user(UserName="my-user") with assert_raises(conn.exceptions.NoSuchEntityException): - conn.get_user(UserName='my-user') + conn.get_user(UserName="my-user") # Test deletion with no conflicts - conn.create_user(UserName='my-user') - conn.delete_user(UserName='my-user') + conn.create_user(UserName="my-user") + conn.delete_user(UserName="my-user") with assert_raises(conn.exceptions.NoSuchEntityException): - conn.get_user(UserName='my-user') + conn.get_user(UserName="my-user") @mock_iam_deprecated() def test_generate_credential_report(): conn = boto.connect_iam() result = conn.generate_credential_report() - result['generate_credential_report_response'][ - 'generate_credential_report_result']['state'].should.equal('STARTED') + result["generate_credential_report_response"]["generate_credential_report_result"][ + "state" + ].should.equal("STARTED") result = conn.generate_credential_report() - result['generate_credential_report_response'][ - 'generate_credential_report_result']['state'].should.equal('COMPLETE') + result["generate_credential_report_response"]["generate_credential_report_result"][ + "state" + ].should.equal("COMPLETE") + @mock_iam def test_boto3_generate_credential_report(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") result = conn.generate_credential_report() - result['State'].should.equal('STARTED') + result["State"].should.equal("STARTED") result = conn.generate_credential_report() - result['State'].should.equal('COMPLETE') + result["State"].should.equal("COMPLETE") @mock_iam_deprecated() def test_get_credential_report(): conn = boto.connect_iam() - conn.create_user('my-user') + conn.create_user("my-user") with assert_raises(BotoServerError): conn.get_credential_report() result = conn.generate_credential_report() - while result['generate_credential_report_response']['generate_credential_report_result']['state'] != 'COMPLETE': + while ( + result["generate_credential_report_response"][ + "generate_credential_report_result" + ]["state"] + != "COMPLETE" + ): result = conn.generate_credential_report() result = conn.get_credential_report() - report = base64.b64decode(result['get_credential_report_response'][ - 'get_credential_report_result']['content'].encode('ascii')).decode('ascii') - report.should.match(r'.*my-user.*') + report = base64.b64decode( + result["get_credential_report_response"]["get_credential_report_result"][ + "content" + ].encode("ascii") + ).decode("ascii") + report.should.match(r".*my-user.*") @mock_iam def test_boto3_get_credential_report(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_user(UserName='my-user') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_user(UserName="my-user") with assert_raises(ClientError): conn.get_credential_report() result = conn.generate_credential_report() - while result['State'] != 'COMPLETE': + while result["State"] != "COMPLETE": result = conn.generate_credential_report() result = conn.get_credential_report() - report = result['Content'].decode('utf-8') - report.should.match(r'.*my-user.*') + report = result["Content"].decode("utf-8") + report.should.match(r".*my-user.*") -@requires_boto_gte('2.39') +@requires_boto_gte("2.39") @mock_iam_deprecated() def test_managed_policy(): conn = boto.connect_iam() - conn.create_policy(policy_name='UserManagedPolicy', - policy_document=MOCK_POLICY, - path='/mypolicy/', - description='my user managed policy') + conn.create_policy( + policy_name="UserManagedPolicy", + policy_document=MOCK_POLICY, + path="/mypolicy/", + description="my user managed policy", + ) marker = 0 aws_policies = [] while marker is not None: - response = conn.list_policies(scope='AWS', marker=marker)[ - 'list_policies_response']['list_policies_result'] - for policy in response['policies']: + response = conn.list_policies(scope="AWS", marker=marker)[ + "list_policies_response" + ]["list_policies_result"] + for policy in response["policies"]: aws_policies.append(policy) - marker = response.get('marker') + marker = response.get("marker") set(p.name for p in aws_managed_policies).should.equal( - set(p['policy_name'] for p in aws_policies)) + set(p["policy_name"] for p in aws_policies) + ) - user_policies = conn.list_policies(scope='Local')['list_policies_response'][ - 'list_policies_result']['policies'] - set(['UserManagedPolicy']).should.equal( - set(p['policy_name'] for p in user_policies)) + user_policies = conn.list_policies(scope="Local")["list_policies_response"][ + "list_policies_result" + ]["policies"] + set(["UserManagedPolicy"]).should.equal( + set(p["policy_name"] for p in user_policies) + ) marker = 0 all_policies = [] while marker is not None: - response = conn.list_policies(marker=marker)[ - 'list_policies_response']['list_policies_result'] - for policy in response['policies']: + response = conn.list_policies(marker=marker)["list_policies_response"][ + "list_policies_result" + ] + for policy in response["policies"]: all_policies.append(policy) - marker = response.get('marker') - set(p['policy_name'] for p in aws_policies + - user_policies).should.equal(set(p['policy_name'] for p in all_policies)) + marker = response.get("marker") + set(p["policy_name"] for p in aws_policies + user_policies).should.equal( + set(p["policy_name"] for p in all_policies) + ) - role_name = 'my-role' - conn.create_role(role_name, assume_role_policy_document={ - 'policy': 'test'}, path="my-path") - for policy_name in ['AmazonElasticMapReduceRole', - 'AmazonElasticMapReduceforEC2Role']: - policy_arn = 'arn:aws:iam::aws:policy/service-role/' + policy_name + role_name = "my-role" + conn.create_role( + role_name, assume_role_policy_document={"policy": "test"}, path="my-path" + ) + for policy_name in [ + "AmazonElasticMapReduceRole", + "AmazonElasticMapReduceforEC2Role", + ]: + policy_arn = "arn:aws:iam::aws:policy/service-role/" + policy_name conn.attach_role_policy(policy_arn, role_name) - rows = conn.list_policies(only_attached=True)['list_policies_response'][ - 'list_policies_result']['policies'] + rows = conn.list_policies(only_attached=True)["list_policies_response"][ + "list_policies_result" + ]["policies"] rows.should.have.length_of(2) for x in rows: - int(x['attachment_count']).should.be.greater_than(0) + int(x["attachment_count"]).should.be.greater_than(0) # boto has not implemented this end point but accessible this way - resp = conn.get_response('ListAttachedRolePolicies', - {'RoleName': role_name}, - list_marker='AttachedPolicies') - resp['list_attached_role_policies_response']['list_attached_role_policies_result'][ - 'attached_policies'].should.have.length_of(2) + resp = conn.get_response( + "ListAttachedRolePolicies", + {"RoleName": role_name}, + list_marker="AttachedPolicies", + ) + resp["list_attached_role_policies_response"]["list_attached_role_policies_result"][ + "attached_policies" + ].should.have.length_of(2) conn.detach_role_policy( - "arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceRole", - role_name) - rows = conn.list_policies(only_attached=True)['list_policies_response'][ - 'list_policies_result']['policies'] + "arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceRole", role_name + ) + rows = conn.list_policies(only_attached=True)["list_policies_response"][ + "list_policies_result" + ]["policies"] rows.should.have.length_of(1) for x in rows: - int(x['attachment_count']).should.be.greater_than(0) + int(x["attachment_count"]).should.be.greater_than(0) # boto has not implemented this end point but accessible this way - resp = conn.get_response('ListAttachedRolePolicies', - {'RoleName': role_name}, - list_marker='AttachedPolicies') - resp['list_attached_role_policies_response']['list_attached_role_policies_result'][ - 'attached_policies'].should.have.length_of(1) + resp = conn.get_response( + "ListAttachedRolePolicies", + {"RoleName": role_name}, + list_marker="AttachedPolicies", + ) + resp["list_attached_role_policies_response"]["list_attached_role_policies_result"][ + "attached_policies" + ].should.have.length_of(1) with assert_raises(BotoServerError): conn.detach_role_policy( - "arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceRole", - role_name) + "arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceRole", role_name + ) with assert_raises(BotoServerError): - conn.detach_role_policy( - "arn:aws:iam::aws:policy/Nonexistent", role_name) + conn.detach_role_policy("arn:aws:iam::aws:policy/Nonexistent", role_name) @mock_iam def test_boto3_create_login_profile(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): - conn.create_login_profile(UserName='my-user', Password='Password') + conn.create_login_profile(UserName="my-user", Password="Password") - conn.create_user(UserName='my-user') - conn.create_login_profile(UserName='my-user', Password='Password') + conn.create_user(UserName="my-user") + conn.create_login_profile(UserName="my-user", Password="Password") with assert_raises(ClientError): - conn.create_login_profile(UserName='my-user', Password='Password') + conn.create_login_profile(UserName="my-user", Password="Password") @mock_iam() def test_attach_detach_user_policy(): - iam = boto3.resource('iam', region_name='us-east-1') - client = boto3.client('iam', region_name='us-east-1') + iam = boto3.resource("iam", region_name="us-east-1") + client = boto3.client("iam", region_name="us-east-1") - user = iam.create_user(UserName='test-user') + user = iam.create_user(UserName="test-user") - policy_name = 'UserAttachedPolicy' - policy = iam.create_policy(PolicyName=policy_name, - PolicyDocument=MOCK_POLICY, - Path='/mypolicy/', - Description='my user attached policy') + policy_name = "UserAttachedPolicy" + policy = iam.create_policy( + PolicyName=policy_name, + PolicyDocument=MOCK_POLICY, + Path="/mypolicy/", + Description="my user attached policy", + ) client.attach_user_policy(UserName=user.name, PolicyArn=policy.arn) resp = client.list_attached_user_policies(UserName=user.name) - resp['AttachedPolicies'].should.have.length_of(1) - attached_policy = resp['AttachedPolicies'][0] - attached_policy['PolicyArn'].should.equal(policy.arn) - attached_policy['PolicyName'].should.equal(policy_name) + resp["AttachedPolicies"].should.have.length_of(1) + attached_policy = resp["AttachedPolicies"][0] + attached_policy["PolicyArn"].should.equal(policy.arn) + attached_policy["PolicyName"].should.equal(policy_name) client.detach_user_policy(UserName=user.name, PolicyArn=policy.arn) resp = client.list_attached_user_policies(UserName=user.name) - resp['AttachedPolicies'].should.have.length_of(0) + resp["AttachedPolicies"].should.have.length_of(0) @mock_iam def test_update_access_key(): - iam = boto3.resource('iam', region_name='us-east-1') + iam = boto3.resource("iam", region_name="us-east-1") client = iam.meta.client - username = 'test-user' + username = "test-user" iam.create_user(UserName=username) with assert_raises(ClientError): - client.update_access_key(UserName=username, - AccessKeyId='non-existent-key', - Status='Inactive') - key = client.create_access_key(UserName=username)['AccessKey'] - client.update_access_key(UserName=username, - AccessKeyId=key['AccessKeyId'], - Status='Inactive') + client.update_access_key( + UserName=username, AccessKeyId="non-existent-key", Status="Inactive" + ) + key = client.create_access_key(UserName=username)["AccessKey"] + client.update_access_key( + UserName=username, AccessKeyId=key["AccessKeyId"], Status="Inactive" + ) resp = client.list_access_keys(UserName=username) - resp['AccessKeyMetadata'][0]['Status'].should.equal('Inactive') + resp["AccessKeyMetadata"][0]["Status"].should.equal("Inactive") @mock_iam def test_get_access_key_last_used(): - iam = boto3.resource('iam', region_name='us-east-1') + iam = boto3.resource("iam", region_name="us-east-1") client = iam.meta.client - username = 'test-user' + username = "test-user" iam.create_user(UserName=username) with assert_raises(ClientError): - client.get_access_key_last_used(AccessKeyId='non-existent-key-id') - create_key_response = client.create_access_key(UserName=username)['AccessKey'] - resp = client.get_access_key_last_used(AccessKeyId=create_key_response['AccessKeyId']) + client.get_access_key_last_used(AccessKeyId="non-existent-key-id") + create_key_response = client.create_access_key(UserName=username)["AccessKey"] + resp = client.get_access_key_last_used( + AccessKeyId=create_key_response["AccessKeyId"] + ) - datetime.strftime(resp["AccessKeyLastUsed"]["LastUsedDate"], "%Y-%m-%d").should.equal(datetime.strftime( - datetime.utcnow(), - "%Y-%m-%d" - )) + datetime.strftime( + resp["AccessKeyLastUsed"]["LastUsedDate"], "%Y-%m-%d" + ).should.equal(datetime.strftime(datetime.utcnow(), "%Y-%m-%d")) resp["UserName"].should.equal(create_key_response["UserName"]) @mock_iam def test_get_account_authorization_details(): - test_policy = json.dumps({ - "Version": "2012-10-17", - "Statement": [ - { - "Action": "s3:ListBucket", - "Resource": "*", - "Effect": "Allow", - } - ] - }) + test_policy = json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + {"Action": "s3:ListBucket", "Resource": "*", "Effect": "Allow"} + ], + } + ) - conn = boto3.client('iam', region_name='us-east-1') - boundary = 'arn:aws:iam::123456789012:policy/boundary' - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/", Description='testing', PermissionsBoundary=boundary) - conn.create_user(Path='/', UserName='testUser') - conn.create_group(Path='/', GroupName='testGroup') + conn = boto3.client("iam", region_name="us-east-1") + boundary = "arn:aws:iam::123456789012:policy/boundary" + conn.create_role( + RoleName="my-role", + AssumeRolePolicyDocument="some policy", + Path="/my-path/", + Description="testing", + PermissionsBoundary=boundary, + ) + conn.create_user(Path="/", UserName="testUser") + conn.create_group(Path="/", GroupName="testGroup") conn.create_policy( - PolicyName='testPolicy', - Path='/', + PolicyName="testPolicy", + Path="/", PolicyDocument=test_policy, - Description='Test Policy' + Description="Test Policy", ) # Attach things to the user and group: - conn.put_user_policy(UserName='testUser', PolicyName='testPolicy', PolicyDocument=test_policy) - conn.put_group_policy(GroupName='testGroup', PolicyName='testPolicy', PolicyDocument=test_policy) + conn.put_user_policy( + UserName="testUser", PolicyName="testPolicy", PolicyDocument=test_policy + ) + conn.put_group_policy( + GroupName="testGroup", PolicyName="testPolicy", PolicyDocument=test_policy + ) - conn.attach_user_policy(UserName='testUser', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') - conn.attach_group_policy(GroupName='testGroup', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') + conn.attach_user_policy( + UserName="testUser", PolicyArn="arn:aws:iam::123456789012:policy/testPolicy" + ) + conn.attach_group_policy( + GroupName="testGroup", PolicyArn="arn:aws:iam::123456789012:policy/testPolicy" + ) - conn.add_user_to_group(UserName='testUser', GroupName='testGroup') + conn.add_user_to_group(UserName="testUser", GroupName="testGroup") # Add things to the role: - conn.create_instance_profile(InstanceProfileName='ipn') - conn.add_role_to_instance_profile(InstanceProfileName='ipn', RoleName='my-role') - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': 'somevalue' - }, - { - 'Key': 'someotherkey', - 'Value': 'someothervalue' - } - ]) - conn.put_role_policy(RoleName='my-role', PolicyName='test-policy', PolicyDocument=test_policy) - conn.attach_role_policy(RoleName='my-role', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') + conn.create_instance_profile(InstanceProfileName="ipn") + conn.add_role_to_instance_profile(InstanceProfileName="ipn", RoleName="my-role") + conn.tag_role( + RoleName="my-role", + Tags=[ + {"Key": "somekey", "Value": "somevalue"}, + {"Key": "someotherkey", "Value": "someothervalue"}, + ], + ) + conn.put_role_policy( + RoleName="my-role", PolicyName="test-policy", PolicyDocument=test_policy + ) + conn.attach_role_policy( + RoleName="my-role", PolicyArn="arn:aws:iam::123456789012:policy/testPolicy" + ) - result = conn.get_account_authorization_details(Filter=['Role']) - assert len(result['RoleDetailList']) == 1 - assert len(result['UserDetailList']) == 0 - assert len(result['GroupDetailList']) == 0 - assert len(result['Policies']) == 0 - assert len(result['RoleDetailList'][0]['InstanceProfileList']) == 1 - assert result['RoleDetailList'][0]['InstanceProfileList'][0]['Roles'][0]['Description'] == 'testing' - assert result['RoleDetailList'][0]['InstanceProfileList'][0]['Roles'][0]['PermissionsBoundary'] == { - 'PermissionsBoundaryType': 'PermissionsBoundaryPolicy', - 'PermissionsBoundaryArn': 'arn:aws:iam::123456789012:policy/boundary' + result = conn.get_account_authorization_details(Filter=["Role"]) + assert len(result["RoleDetailList"]) == 1 + assert len(result["UserDetailList"]) == 0 + assert len(result["GroupDetailList"]) == 0 + assert len(result["Policies"]) == 0 + assert len(result["RoleDetailList"][0]["InstanceProfileList"]) == 1 + assert ( + result["RoleDetailList"][0]["InstanceProfileList"][0]["Roles"][0]["Description"] + == "testing" + ) + assert result["RoleDetailList"][0]["InstanceProfileList"][0]["Roles"][0][ + "PermissionsBoundary" + ] == { + "PermissionsBoundaryType": "PermissionsBoundaryPolicy", + "PermissionsBoundaryArn": "arn:aws:iam::123456789012:policy/boundary", } - assert len(result['RoleDetailList'][0]['Tags']) == 2 - assert len(result['RoleDetailList'][0]['RolePolicyList']) == 1 - assert len(result['RoleDetailList'][0]['AttachedManagedPolicies']) == 1 - assert result['RoleDetailList'][0]['AttachedManagedPolicies'][0]['PolicyName'] == 'testPolicy' - assert result['RoleDetailList'][0]['AttachedManagedPolicies'][0]['PolicyArn'] == \ - 'arn:aws:iam::123456789012:policy/testPolicy' + assert len(result["RoleDetailList"][0]["Tags"]) == 2 + assert len(result["RoleDetailList"][0]["RolePolicyList"]) == 1 + assert len(result["RoleDetailList"][0]["AttachedManagedPolicies"]) == 1 + assert ( + result["RoleDetailList"][0]["AttachedManagedPolicies"][0]["PolicyName"] + == "testPolicy" + ) + assert ( + result["RoleDetailList"][0]["AttachedManagedPolicies"][0]["PolicyArn"] + == "arn:aws:iam::123456789012:policy/testPolicy" + ) - result = conn.get_account_authorization_details(Filter=['User']) - assert len(result['RoleDetailList']) == 0 - assert len(result['UserDetailList']) == 1 - assert len(result['UserDetailList'][0]['GroupList']) == 1 - assert len(result['UserDetailList'][0]['AttachedManagedPolicies']) == 1 - assert len(result['GroupDetailList']) == 0 - assert len(result['Policies']) == 0 - assert result['UserDetailList'][0]['AttachedManagedPolicies'][0]['PolicyName'] == 'testPolicy' - assert result['UserDetailList'][0]['AttachedManagedPolicies'][0]['PolicyArn'] == \ - 'arn:aws:iam::123456789012:policy/testPolicy' + result = conn.get_account_authorization_details(Filter=["User"]) + assert len(result["RoleDetailList"]) == 0 + assert len(result["UserDetailList"]) == 1 + assert len(result["UserDetailList"][0]["GroupList"]) == 1 + assert len(result["UserDetailList"][0]["AttachedManagedPolicies"]) == 1 + assert len(result["GroupDetailList"]) == 0 + assert len(result["Policies"]) == 0 + assert ( + result["UserDetailList"][0]["AttachedManagedPolicies"][0]["PolicyName"] + == "testPolicy" + ) + assert ( + result["UserDetailList"][0]["AttachedManagedPolicies"][0]["PolicyArn"] + == "arn:aws:iam::123456789012:policy/testPolicy" + ) - result = conn.get_account_authorization_details(Filter=['Group']) - assert len(result['RoleDetailList']) == 0 - assert len(result['UserDetailList']) == 0 - assert len(result['GroupDetailList']) == 1 - assert len(result['GroupDetailList'][0]['GroupPolicyList']) == 1 - assert len(result['GroupDetailList'][0]['AttachedManagedPolicies']) == 1 - assert len(result['Policies']) == 0 - assert result['GroupDetailList'][0]['AttachedManagedPolicies'][0]['PolicyName'] == 'testPolicy' - assert result['GroupDetailList'][0]['AttachedManagedPolicies'][0]['PolicyArn'] == \ - 'arn:aws:iam::123456789012:policy/testPolicy' + result = conn.get_account_authorization_details(Filter=["Group"]) + assert len(result["RoleDetailList"]) == 0 + assert len(result["UserDetailList"]) == 0 + assert len(result["GroupDetailList"]) == 1 + assert len(result["GroupDetailList"][0]["GroupPolicyList"]) == 1 + assert len(result["GroupDetailList"][0]["AttachedManagedPolicies"]) == 1 + assert len(result["Policies"]) == 0 + assert ( + result["GroupDetailList"][0]["AttachedManagedPolicies"][0]["PolicyName"] + == "testPolicy" + ) + assert ( + result["GroupDetailList"][0]["AttachedManagedPolicies"][0]["PolicyArn"] + == "arn:aws:iam::123456789012:policy/testPolicy" + ) - result = conn.get_account_authorization_details(Filter=['LocalManagedPolicy']) - assert len(result['RoleDetailList']) == 0 - assert len(result['UserDetailList']) == 0 - assert len(result['GroupDetailList']) == 0 - assert len(result['Policies']) == 1 - assert len(result['Policies'][0]['PolicyVersionList']) == 1 + result = conn.get_account_authorization_details(Filter=["LocalManagedPolicy"]) + assert len(result["RoleDetailList"]) == 0 + assert len(result["UserDetailList"]) == 0 + assert len(result["GroupDetailList"]) == 0 + assert len(result["Policies"]) == 1 + assert len(result["Policies"][0]["PolicyVersionList"]) == 1 # Check for greater than 1 since this should always be greater than one but might change. # See iam/aws_managed_policies.py - result = conn.get_account_authorization_details(Filter=['AWSManagedPolicy']) - assert len(result['RoleDetailList']) == 0 - assert len(result['UserDetailList']) == 0 - assert len(result['GroupDetailList']) == 0 - assert len(result['Policies']) > 1 + result = conn.get_account_authorization_details(Filter=["AWSManagedPolicy"]) + assert len(result["RoleDetailList"]) == 0 + assert len(result["UserDetailList"]) == 0 + assert len(result["GroupDetailList"]) == 0 + assert len(result["Policies"]) > 1 result = conn.get_account_authorization_details() - assert len(result['RoleDetailList']) == 1 - assert len(result['UserDetailList']) == 1 - assert len(result['GroupDetailList']) == 1 - assert len(result['Policies']) > 1 + assert len(result["RoleDetailList"]) == 1 + assert len(result["UserDetailList"]) == 1 + assert len(result["GroupDetailList"]) == 1 + assert len(result["Policies"]) > 1 @mock_iam def test_signing_certs(): - client = boto3.client('iam', region_name='us-east-1') + client = boto3.client("iam", region_name="us-east-1") # Create the IAM user first: - client.create_user(UserName='testing') + client.create_user(UserName="testing") # Upload the cert: - resp = client.upload_signing_certificate(UserName='testing', CertificateBody=MOCK_CERT)['Certificate'] - cert_id = resp['CertificateId'] + resp = client.upload_signing_certificate( + UserName="testing", CertificateBody=MOCK_CERT + )["Certificate"] + cert_id = resp["CertificateId"] - assert resp['UserName'] == 'testing' - assert resp['Status'] == 'Active' - assert resp['CertificateBody'] == MOCK_CERT - assert resp['CertificateId'] + assert resp["UserName"] == "testing" + assert resp["Status"] == "Active" + assert resp["CertificateBody"] == MOCK_CERT + assert resp["CertificateId"] # Upload a the cert with an invalid body: with assert_raises(ClientError) as ce: - client.upload_signing_certificate(UserName='testing', CertificateBody='notacert') - assert ce.exception.response['Error']['Code'] == 'MalformedCertificate' + client.upload_signing_certificate( + UserName="testing", CertificateBody="notacert" + ) + assert ce.exception.response["Error"]["Code"] == "MalformedCertificate" # Upload with an invalid user: with assert_raises(ClientError): - client.upload_signing_certificate(UserName='notauser', CertificateBody=MOCK_CERT) + client.upload_signing_certificate( + UserName="notauser", CertificateBody=MOCK_CERT + ) # Update: - client.update_signing_certificate(UserName='testing', CertificateId=cert_id, Status='Inactive') + client.update_signing_certificate( + UserName="testing", CertificateId=cert_id, Status="Inactive" + ) with assert_raises(ClientError): - client.update_signing_certificate(UserName='notauser', CertificateId=cert_id, Status='Inactive') + client.update_signing_certificate( + UserName="notauser", CertificateId=cert_id, Status="Inactive" + ) with assert_raises(ClientError) as ce: - client.update_signing_certificate(UserName='testing', CertificateId='x' * 32, Status='Inactive') + client.update_signing_certificate( + UserName="testing", CertificateId="x" * 32, Status="Inactive" + ) - assert ce.exception.response['Error']['Message'] == 'The Certificate with id {id} cannot be found.'.format( - id='x' * 32) + assert ce.exception.response["Error"][ + "Message" + ] == "The Certificate with id {id} cannot be found.".format(id="x" * 32) # List the certs: - resp = client.list_signing_certificates(UserName='testing')['Certificates'] + resp = client.list_signing_certificates(UserName="testing")["Certificates"] assert len(resp) == 1 - assert resp[0]['CertificateBody'] == MOCK_CERT - assert resp[0]['Status'] == 'Inactive' # Changed with the update call above. + assert resp[0]["CertificateBody"] == MOCK_CERT + assert resp[0]["Status"] == "Inactive" # Changed with the update call above. with assert_raises(ClientError): - client.list_signing_certificates(UserName='notauser') + client.list_signing_certificates(UserName="notauser") # Delete: - client.delete_signing_certificate(UserName='testing', CertificateId=cert_id) + client.delete_signing_certificate(UserName="testing", CertificateId=cert_id) with assert_raises(ClientError): - client.delete_signing_certificate(UserName='notauser', CertificateId=cert_id) + client.delete_signing_certificate(UserName="notauser", CertificateId=cert_id) @mock_iam() def test_create_saml_provider(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") response = conn.create_saml_provider( - Name="TestSAMLProvider", - SAMLMetadataDocument='a' * 1024 + Name="TestSAMLProvider", SAMLMetadataDocument="a" * 1024 + ) + response["SAMLProviderArn"].should.equal( + "arn:aws:iam::123456789012:saml-provider/TestSAMLProvider" ) - response['SAMLProviderArn'].should.equal("arn:aws:iam::123456789012:saml-provider/TestSAMLProvider") @mock_iam() def test_get_saml_provider(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") saml_provider_create = conn.create_saml_provider( - Name="TestSAMLProvider", - SAMLMetadataDocument='a' * 1024 + Name="TestSAMLProvider", SAMLMetadataDocument="a" * 1024 ) response = conn.get_saml_provider( - SAMLProviderArn=saml_provider_create['SAMLProviderArn'] + SAMLProviderArn=saml_provider_create["SAMLProviderArn"] ) - response['SAMLMetadataDocument'].should.equal('a' * 1024) + response["SAMLMetadataDocument"].should.equal("a" * 1024) @mock_iam() def test_list_saml_providers(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_saml_provider( - Name="TestSAMLProvider", - SAMLMetadataDocument='a' * 1024 - ) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_saml_provider(Name="TestSAMLProvider", SAMLMetadataDocument="a" * 1024) response = conn.list_saml_providers() - response['SAMLProviderList'][0]['Arn'].should.equal("arn:aws:iam::123456789012:saml-provider/TestSAMLProvider") + response["SAMLProviderList"][0]["Arn"].should.equal( + "arn:aws:iam::123456789012:saml-provider/TestSAMLProvider" + ) @mock_iam() def test_delete_saml_provider(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") saml_provider_create = conn.create_saml_provider( - Name="TestSAMLProvider", - SAMLMetadataDocument='a' * 1024 + Name="TestSAMLProvider", SAMLMetadataDocument="a" * 1024 ) response = conn.list_saml_providers() - len(response['SAMLProviderList']).should.equal(1) - conn.delete_saml_provider( - SAMLProviderArn=saml_provider_create['SAMLProviderArn'] - ) + len(response["SAMLProviderList"]).should.equal(1) + conn.delete_saml_provider(SAMLProviderArn=saml_provider_create["SAMLProviderArn"]) response = conn.list_saml_providers() - len(response['SAMLProviderList']).should.equal(0) - conn.create_user(UserName='testing') + len(response["SAMLProviderList"]).should.equal(0) + conn.create_user(UserName="testing") - cert_id = '123456789012345678901234' + cert_id = "123456789012345678901234" with assert_raises(ClientError) as ce: - conn.delete_signing_certificate(UserName='testing', CertificateId=cert_id) + conn.delete_signing_certificate(UserName="testing", CertificateId=cert_id) - assert ce.exception.response['Error']['Message'] == 'The Certificate with id {id} cannot be found.'.format( - id=cert_id) + assert ce.exception.response["Error"][ + "Message" + ] == "The Certificate with id {id} cannot be found.".format(id=cert_id) # Verify that it's not in the list: - resp = conn.list_signing_certificates(UserName='testing') - assert not resp['Certificates'] + resp = conn.list_signing_certificates(UserName="testing") + assert not resp["Certificates"] @mock_iam() def test_create_role_with_tags(): """Tests both the tag_role and get_role_tags capability""" - conn = boto3.client('iam', region_name='us-east-1') - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="{}", Tags=[ - { - 'Key': 'somekey', - 'Value': 'somevalue' - }, - { - 'Key': 'someotherkey', - 'Value': 'someothervalue' - } - ], Description='testing') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_role( + RoleName="my-role", + AssumeRolePolicyDocument="{}", + Tags=[ + {"Key": "somekey", "Value": "somevalue"}, + {"Key": "someotherkey", "Value": "someothervalue"}, + ], + Description="testing", + ) # Get role: - role = conn.get_role(RoleName='my-role')['Role'] - assert len(role['Tags']) == 2 - assert role['Tags'][0]['Key'] == 'somekey' - assert role['Tags'][0]['Value'] == 'somevalue' - assert role['Tags'][1]['Key'] == 'someotherkey' - assert role['Tags'][1]['Value'] == 'someothervalue' - assert role['Description'] == 'testing' + role = conn.get_role(RoleName="my-role")["Role"] + assert len(role["Tags"]) == 2 + assert role["Tags"][0]["Key"] == "somekey" + assert role["Tags"][0]["Value"] == "somevalue" + assert role["Tags"][1]["Key"] == "someotherkey" + assert role["Tags"][1]["Value"] == "someothervalue" + assert role["Description"] == "testing" # Empty is good: - conn.create_role(RoleName="my-role2", AssumeRolePolicyDocument="{}", Tags=[ - { - 'Key': 'somekey', - 'Value': '' - } - ]) - tags = conn.list_role_tags(RoleName='my-role2') - assert len(tags['Tags']) == 1 - assert tags['Tags'][0]['Key'] == 'somekey' - assert tags['Tags'][0]['Value'] == '' + conn.create_role( + RoleName="my-role2", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "somekey", "Value": ""}], + ) + tags = conn.list_role_tags(RoleName="my-role2") + assert len(tags["Tags"]) == 1 + assert tags["Tags"][0]["Key"] == "somekey" + assert tags["Tags"][0]["Value"] == "" # Test creating tags with invalid values: # With more than 50 tags: with assert_raises(ClientError) as ce: - too_many_tags = list(map(lambda x: {'Key': str(x), 'Value': str(x)}, range(0, 51))) - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=too_many_tags) - assert 'failed to satisfy constraint: Member must have length less than or equal to 50.' \ - in ce.exception.response['Error']['Message'] + too_many_tags = list( + map(lambda x: {"Key": str(x), "Value": str(x)}, range(0, 51)) + ) + conn.create_role( + RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=too_many_tags + ) + assert ( + "failed to satisfy constraint: Member must have length less than or equal to 50." + in ce.exception.response["Error"]["Message"] + ) # With a duplicate tag: with assert_raises(ClientError) as ce: - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=[{'Key': '0', 'Value': ''}, {'Key': '0', 'Value': ''}]) - assert 'Duplicate tag keys found. Please note that Tag keys are case insensitive.' \ - in ce.exception.response['Error']['Message'] + conn.create_role( + RoleName="my-role3", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "0", "Value": ""}, {"Key": "0", "Value": ""}], + ) + assert ( + "Duplicate tag keys found. Please note that Tag keys are case insensitive." + in ce.exception.response["Error"]["Message"] + ) # Duplicate tag with different casing: with assert_raises(ClientError) as ce: - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=[{'Key': 'a', 'Value': ''}, {'Key': 'A', 'Value': ''}]) - assert 'Duplicate tag keys found. Please note that Tag keys are case insensitive.' \ - in ce.exception.response['Error']['Message'] + conn.create_role( + RoleName="my-role3", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "a", "Value": ""}, {"Key": "A", "Value": ""}], + ) + assert ( + "Duplicate tag keys found. Please note that Tag keys are case insensitive." + in ce.exception.response["Error"]["Message"] + ) # With a really big key: with assert_raises(ClientError) as ce: - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=[{'Key': '0' * 129, 'Value': ''}]) - assert 'Member must have length less than or equal to 128.' in ce.exception.response['Error']['Message'] + conn.create_role( + RoleName="my-role3", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "0" * 129, "Value": ""}], + ) + assert ( + "Member must have length less than or equal to 128." + in ce.exception.response["Error"]["Message"] + ) # With a really big value: with assert_raises(ClientError) as ce: - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=[{'Key': '0', 'Value': '0' * 257}]) - assert 'Member must have length less than or equal to 256.' in ce.exception.response['Error']['Message'] + conn.create_role( + RoleName="my-role3", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "0", "Value": "0" * 257}], + ) + assert ( + "Member must have length less than or equal to 256." + in ce.exception.response["Error"]["Message"] + ) # With an invalid character: with assert_raises(ClientError) as ce: - conn.create_role(RoleName="my-role3", AssumeRolePolicyDocument="{}", Tags=[{'Key': 'NOWAY!', 'Value': ''}]) - assert 'Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+' \ - in ce.exception.response['Error']['Message'] + conn.create_role( + RoleName="my-role3", + AssumeRolePolicyDocument="{}", + Tags=[{"Key": "NOWAY!", "Value": ""}], + ) + assert ( + "Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+" + in ce.exception.response["Error"]["Message"] + ) @mock_iam() def test_tag_role(): """Tests both the tag_role and get_role_tags capability""" - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="{}") # Get without tags: - role = conn.get_role(RoleName='my-role')['Role'] - assert not role.get('Tags') + role = conn.get_role(RoleName="my-role")["Role"] + assert not role.get("Tags") # With proper tag values: - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': 'somevalue' - }, - { - 'Key': 'someotherkey', - 'Value': 'someothervalue' - } - ]) + conn.tag_role( + RoleName="my-role", + Tags=[ + {"Key": "somekey", "Value": "somevalue"}, + {"Key": "someotherkey", "Value": "someothervalue"}, + ], + ) # Get role: - role = conn.get_role(RoleName='my-role')['Role'] - assert len(role['Tags']) == 2 - assert role['Tags'][0]['Key'] == 'somekey' - assert role['Tags'][0]['Value'] == 'somevalue' - assert role['Tags'][1]['Key'] == 'someotherkey' - assert role['Tags'][1]['Value'] == 'someothervalue' + role = conn.get_role(RoleName="my-role")["Role"] + assert len(role["Tags"]) == 2 + assert role["Tags"][0]["Key"] == "somekey" + assert role["Tags"][0]["Value"] == "somevalue" + assert role["Tags"][1]["Key"] == "someotherkey" + assert role["Tags"][1]["Value"] == "someothervalue" # Same -- but for list_role_tags: - tags = conn.list_role_tags(RoleName='my-role') - assert len(tags['Tags']) == 2 - assert role['Tags'][0]['Key'] == 'somekey' - assert role['Tags'][0]['Value'] == 'somevalue' - assert role['Tags'][1]['Key'] == 'someotherkey' - assert role['Tags'][1]['Value'] == 'someothervalue' - assert not tags['IsTruncated'] - assert not tags.get('Marker') + tags = conn.list_role_tags(RoleName="my-role") + assert len(tags["Tags"]) == 2 + assert role["Tags"][0]["Key"] == "somekey" + assert role["Tags"][0]["Value"] == "somevalue" + assert role["Tags"][1]["Key"] == "someotherkey" + assert role["Tags"][1]["Value"] == "someothervalue" + assert not tags["IsTruncated"] + assert not tags.get("Marker") # Test pagination: - tags = conn.list_role_tags(RoleName='my-role', MaxItems=1) - assert len(tags['Tags']) == 1 - assert tags['IsTruncated'] - assert tags['Tags'][0]['Key'] == 'somekey' - assert tags['Tags'][0]['Value'] == 'somevalue' - assert tags['Marker'] == '1' + tags = conn.list_role_tags(RoleName="my-role", MaxItems=1) + assert len(tags["Tags"]) == 1 + assert tags["IsTruncated"] + assert tags["Tags"][0]["Key"] == "somekey" + assert tags["Tags"][0]["Value"] == "somevalue" + assert tags["Marker"] == "1" - tags = conn.list_role_tags(RoleName='my-role', Marker=tags['Marker']) - assert len(tags['Tags']) == 1 - assert tags['Tags'][0]['Key'] == 'someotherkey' - assert tags['Tags'][0]['Value'] == 'someothervalue' - assert not tags['IsTruncated'] - assert not tags.get('Marker') + tags = conn.list_role_tags(RoleName="my-role", Marker=tags["Marker"]) + assert len(tags["Tags"]) == 1 + assert tags["Tags"][0]["Key"] == "someotherkey" + assert tags["Tags"][0]["Value"] == "someothervalue" + assert not tags["IsTruncated"] + assert not tags.get("Marker") # Test updating an existing tag: - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': 'somenewvalue' - } - ]) - tags = conn.list_role_tags(RoleName='my-role') - assert len(tags['Tags']) == 2 - assert tags['Tags'][0]['Key'] == 'somekey' - assert tags['Tags'][0]['Value'] == 'somenewvalue' + conn.tag_role( + RoleName="my-role", Tags=[{"Key": "somekey", "Value": "somenewvalue"}] + ) + tags = conn.list_role_tags(RoleName="my-role") + assert len(tags["Tags"]) == 2 + assert tags["Tags"][0]["Key"] == "somekey" + assert tags["Tags"][0]["Value"] == "somenewvalue" # Empty is good: - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': '' - } - ]) - tags = conn.list_role_tags(RoleName='my-role') - assert len(tags['Tags']) == 2 - assert tags['Tags'][0]['Key'] == 'somekey' - assert tags['Tags'][0]['Value'] == '' + conn.tag_role(RoleName="my-role", Tags=[{"Key": "somekey", "Value": ""}]) + tags = conn.list_role_tags(RoleName="my-role") + assert len(tags["Tags"]) == 2 + assert tags["Tags"][0]["Key"] == "somekey" + assert tags["Tags"][0]["Value"] == "" # Test creating tags with invalid values: # With more than 50 tags: with assert_raises(ClientError) as ce: - too_many_tags = list(map(lambda x: {'Key': str(x), 'Value': str(x)}, range(0, 51))) - conn.tag_role(RoleName='my-role', Tags=too_many_tags) - assert 'failed to satisfy constraint: Member must have length less than or equal to 50.' \ - in ce.exception.response['Error']['Message'] + too_many_tags = list( + map(lambda x: {"Key": str(x), "Value": str(x)}, range(0, 51)) + ) + conn.tag_role(RoleName="my-role", Tags=too_many_tags) + assert ( + "failed to satisfy constraint: Member must have length less than or equal to 50." + in ce.exception.response["Error"]["Message"] + ) # With a duplicate tag: with assert_raises(ClientError) as ce: - conn.tag_role(RoleName='my-role', Tags=[{'Key': '0', 'Value': ''}, {'Key': '0', 'Value': ''}]) - assert 'Duplicate tag keys found. Please note that Tag keys are case insensitive.' \ - in ce.exception.response['Error']['Message'] + conn.tag_role( + RoleName="my-role", + Tags=[{"Key": "0", "Value": ""}, {"Key": "0", "Value": ""}], + ) + assert ( + "Duplicate tag keys found. Please note that Tag keys are case insensitive." + in ce.exception.response["Error"]["Message"] + ) # Duplicate tag with different casing: with assert_raises(ClientError) as ce: - conn.tag_role(RoleName='my-role', Tags=[{'Key': 'a', 'Value': ''}, {'Key': 'A', 'Value': ''}]) - assert 'Duplicate tag keys found. Please note that Tag keys are case insensitive.' \ - in ce.exception.response['Error']['Message'] + conn.tag_role( + RoleName="my-role", + Tags=[{"Key": "a", "Value": ""}, {"Key": "A", "Value": ""}], + ) + assert ( + "Duplicate tag keys found. Please note that Tag keys are case insensitive." + in ce.exception.response["Error"]["Message"] + ) # With a really big key: with assert_raises(ClientError) as ce: - conn.tag_role(RoleName='my-role', Tags=[{'Key': '0' * 129, 'Value': ''}]) - assert 'Member must have length less than or equal to 128.' in ce.exception.response['Error']['Message'] + conn.tag_role(RoleName="my-role", Tags=[{"Key": "0" * 129, "Value": ""}]) + assert ( + "Member must have length less than or equal to 128." + in ce.exception.response["Error"]["Message"] + ) # With a really big value: with assert_raises(ClientError) as ce: - conn.tag_role(RoleName='my-role', Tags=[{'Key': '0', 'Value': '0' * 257}]) - assert 'Member must have length less than or equal to 256.' in ce.exception.response['Error']['Message'] + conn.tag_role(RoleName="my-role", Tags=[{"Key": "0", "Value": "0" * 257}]) + assert ( + "Member must have length less than or equal to 256." + in ce.exception.response["Error"]["Message"] + ) # With an invalid character: with assert_raises(ClientError) as ce: - conn.tag_role(RoleName='my-role', Tags=[{'Key': 'NOWAY!', 'Value': ''}]) - assert 'Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+' \ - in ce.exception.response['Error']['Message'] + conn.tag_role(RoleName="my-role", Tags=[{"Key": "NOWAY!", "Value": ""}]) + assert ( + "Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+" + in ce.exception.response["Error"]["Message"] + ) # With a role that doesn't exist: with assert_raises(ClientError): - conn.tag_role(RoleName='notarole', Tags=[{'Key': 'some', 'Value': 'value'}]) + conn.tag_role(RoleName="notarole", Tags=[{"Key": "some", "Value": "value"}]) @mock_iam def test_untag_role(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="{}") # With proper tag values: - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': 'somevalue' - }, - { - 'Key': 'someotherkey', - 'Value': 'someothervalue' - } - ]) + conn.tag_role( + RoleName="my-role", + Tags=[ + {"Key": "somekey", "Value": "somevalue"}, + {"Key": "someotherkey", "Value": "someothervalue"}, + ], + ) # Remove them: - conn.untag_role(RoleName='my-role', TagKeys=['somekey']) - tags = conn.list_role_tags(RoleName='my-role') - assert len(tags['Tags']) == 1 - assert tags['Tags'][0]['Key'] == 'someotherkey' - assert tags['Tags'][0]['Value'] == 'someothervalue' + conn.untag_role(RoleName="my-role", TagKeys=["somekey"]) + tags = conn.list_role_tags(RoleName="my-role") + assert len(tags["Tags"]) == 1 + assert tags["Tags"][0]["Key"] == "someotherkey" + assert tags["Tags"][0]["Value"] == "someothervalue" # And again: - conn.untag_role(RoleName='my-role', TagKeys=['someotherkey']) - tags = conn.list_role_tags(RoleName='my-role') - assert not tags['Tags'] + conn.untag_role(RoleName="my-role", TagKeys=["someotherkey"]) + tags = conn.list_role_tags(RoleName="my-role") + assert not tags["Tags"] # Test removing tags with invalid values: # With more than 50 tags: with assert_raises(ClientError) as ce: - conn.untag_role(RoleName='my-role', TagKeys=[str(x) for x in range(0, 51)]) - assert 'failed to satisfy constraint: Member must have length less than or equal to 50.' \ - in ce.exception.response['Error']['Message'] - assert 'tagKeys' in ce.exception.response['Error']['Message'] + conn.untag_role(RoleName="my-role", TagKeys=[str(x) for x in range(0, 51)]) + assert ( + "failed to satisfy constraint: Member must have length less than or equal to 50." + in ce.exception.response["Error"]["Message"] + ) + assert "tagKeys" in ce.exception.response["Error"]["Message"] # With a really big key: with assert_raises(ClientError) as ce: - conn.untag_role(RoleName='my-role', TagKeys=['0' * 129]) - assert 'Member must have length less than or equal to 128.' in ce.exception.response['Error']['Message'] - assert 'tagKeys' in ce.exception.response['Error']['Message'] + conn.untag_role(RoleName="my-role", TagKeys=["0" * 129]) + assert ( + "Member must have length less than or equal to 128." + in ce.exception.response["Error"]["Message"] + ) + assert "tagKeys" in ce.exception.response["Error"]["Message"] # With an invalid character: with assert_raises(ClientError) as ce: - conn.untag_role(RoleName='my-role', TagKeys=['NOWAY!']) - assert 'Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+' \ - in ce.exception.response['Error']['Message'] - assert 'tagKeys' in ce.exception.response['Error']['Message'] + conn.untag_role(RoleName="my-role", TagKeys=["NOWAY!"]) + assert ( + "Member must satisfy regular expression pattern: [\\p{L}\\p{Z}\\p{N}_.:/=+\\-@]+" + in ce.exception.response["Error"]["Message"] + ) + assert "tagKeys" in ce.exception.response["Error"]["Message"] # With a role that doesn't exist: with assert_raises(ClientError): - conn.untag_role(RoleName='notarole', TagKeys=['somevalue']) + conn.untag_role(RoleName="notarole", TagKeys=["somevalue"]) @mock_iam() def test_update_role_description(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): conn.delete_role(RoleName="my-role") - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) response = conn.update_role_description(RoleName="my-role", Description="test") - assert response['Role']['RoleName'] == 'my-role' + assert response["Role"]["RoleName"] == "my-role" @mock_iam() def test_update_role(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): conn.delete_role(RoleName="my-role") - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) response = conn.update_role_description(RoleName="my-role", Description="test") - assert response['Role']['RoleName'] == 'my-role' + assert response["Role"]["RoleName"] == "my-role" @mock_iam() def test_update_role(): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError): conn.delete_role(RoleName="my-role") - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) response = conn.update_role(RoleName="my-role", Description="test") assert len(response.keys()) == 1 @mock_iam() def test_list_entities_for_policy(): - test_policy = json.dumps({ - "Version": "2012-10-17", - "Statement": [ - { - "Action": "s3:ListBucket", - "Resource": "*", - "Effect": "Allow", - } - ] - }) + test_policy = json.dumps( + { + "Version": "2012-10-17", + "Statement": [ + {"Action": "s3:ListBucket", "Resource": "*", "Effect": "Allow"} + ], + } + ) - conn = boto3.client('iam', region_name='us-east-1') - conn.create_role(RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/") - conn.create_user(Path='/', UserName='testUser') - conn.create_group(Path='/', GroupName='testGroup') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Path="/my-path/" + ) + conn.create_user(Path="/", UserName="testUser") + conn.create_group(Path="/", GroupName="testGroup") conn.create_policy( - PolicyName='testPolicy', - Path='/', + PolicyName="testPolicy", + Path="/", PolicyDocument=test_policy, - Description='Test Policy' + Description="Test Policy", ) # Attach things to the user and group: - conn.put_user_policy(UserName='testUser', PolicyName='testPolicy', PolicyDocument=test_policy) - conn.put_group_policy(GroupName='testGroup', PolicyName='testPolicy', PolicyDocument=test_policy) + conn.put_user_policy( + UserName="testUser", PolicyName="testPolicy", PolicyDocument=test_policy + ) + conn.put_group_policy( + GroupName="testGroup", PolicyName="testPolicy", PolicyDocument=test_policy + ) - conn.attach_user_policy(UserName='testUser', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') - conn.attach_group_policy(GroupName='testGroup', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') + conn.attach_user_policy( + UserName="testUser", PolicyArn="arn:aws:iam::123456789012:policy/testPolicy" + ) + conn.attach_group_policy( + GroupName="testGroup", PolicyArn="arn:aws:iam::123456789012:policy/testPolicy" + ) - conn.add_user_to_group(UserName='testUser', GroupName='testGroup') + conn.add_user_to_group(UserName="testUser", GroupName="testGroup") # Add things to the role: - conn.create_instance_profile(InstanceProfileName='ipn') - conn.add_role_to_instance_profile(InstanceProfileName='ipn', RoleName='my-role') - conn.tag_role(RoleName='my-role', Tags=[ - { - 'Key': 'somekey', - 'Value': 'somevalue' - }, - { - 'Key': 'someotherkey', - 'Value': 'someothervalue' - } - ]) - conn.put_role_policy(RoleName='my-role', PolicyName='test-policy', PolicyDocument=test_policy) - conn.attach_role_policy(RoleName='my-role', PolicyArn='arn:aws:iam::123456789012:policy/testPolicy') + conn.create_instance_profile(InstanceProfileName="ipn") + conn.add_role_to_instance_profile(InstanceProfileName="ipn", RoleName="my-role") + conn.tag_role( + RoleName="my-role", + Tags=[ + {"Key": "somekey", "Value": "somevalue"}, + {"Key": "someotherkey", "Value": "someothervalue"}, + ], + ) + conn.put_role_policy( + RoleName="my-role", PolicyName="test-policy", PolicyDocument=test_policy + ) + conn.attach_role_policy( + RoleName="my-role", PolicyArn="arn:aws:iam::123456789012:policy/testPolicy" + ) response = conn.list_entities_for_policy( - PolicyArn='arn:aws:iam::123456789012:policy/testPolicy', - EntityFilter='Role' + PolicyArn="arn:aws:iam::123456789012:policy/testPolicy", EntityFilter="Role" ) - assert response['PolicyRoles'] == [{'RoleName': 'my-role'}] + assert response["PolicyRoles"] == [{"RoleName": "my-role"}] response = conn.list_entities_for_policy( - PolicyArn='arn:aws:iam::123456789012:policy/testPolicy', - EntityFilter='User', + PolicyArn="arn:aws:iam::123456789012:policy/testPolicy", EntityFilter="User" ) - assert response['PolicyUsers'] == [{'UserName': 'testUser'}] + assert response["PolicyUsers"] == [{"UserName": "testUser"}] response = conn.list_entities_for_policy( - PolicyArn='arn:aws:iam::123456789012:policy/testPolicy', - EntityFilter='Group', + PolicyArn="arn:aws:iam::123456789012:policy/testPolicy", EntityFilter="Group" ) - assert response['PolicyGroups'] == [{'GroupName': 'testGroup'}] + assert response["PolicyGroups"] == [{"GroupName": "testGroup"}] response = conn.list_entities_for_policy( - PolicyArn='arn:aws:iam::123456789012:policy/testPolicy', - EntityFilter='LocalManagedPolicy', + PolicyArn="arn:aws:iam::123456789012:policy/testPolicy", + EntityFilter="LocalManagedPolicy", ) - assert response['PolicyGroups'] == [{'GroupName': 'testGroup'}] - assert response['PolicyUsers'] == [{'UserName': 'testUser'}] - assert response['PolicyRoles'] == [{'RoleName': 'my-role'}] + assert response["PolicyGroups"] == [{"GroupName": "testGroup"}] + assert response["PolicyUsers"] == [{"UserName": "testUser"}] + assert response["PolicyRoles"] == [{"RoleName": "my-role"}] @mock_iam() def test_create_role_no_path(): - conn = boto3.client('iam', region_name='us-east-1') - resp = conn.create_role(RoleName='my-role', AssumeRolePolicyDocument='some policy', Description='test') - resp.get('Role').get('Arn').should.equal('arn:aws:iam::123456789012:role/my-role') - resp.get('Role').should_not.have.key('PermissionsBoundary') - resp.get('Role').get('Description').should.equal('test') + conn = boto3.client("iam", region_name="us-east-1") + resp = conn.create_role( + RoleName="my-role", AssumeRolePolicyDocument="some policy", Description="test" + ) + resp.get("Role").get("Arn").should.equal("arn:aws:iam::123456789012:role/my-role") + resp.get("Role").should_not.have.key("PermissionsBoundary") + resp.get("Role").get("Description").should.equal("test") @mock_iam() def test_create_role_with_permissions_boundary(): - conn = boto3.client('iam', region_name='us-east-1') - boundary = 'arn:aws:iam::123456789012:policy/boundary' - resp = conn.create_role(RoleName='my-role', AssumeRolePolicyDocument='some policy', Description='test', PermissionsBoundary=boundary) + conn = boto3.client("iam", region_name="us-east-1") + boundary = "arn:aws:iam::123456789012:policy/boundary" + resp = conn.create_role( + RoleName="my-role", + AssumeRolePolicyDocument="some policy", + Description="test", + PermissionsBoundary=boundary, + ) expected = { - 'PermissionsBoundaryType': 'PermissionsBoundaryPolicy', - 'PermissionsBoundaryArn': boundary + "PermissionsBoundaryType": "PermissionsBoundaryPolicy", + "PermissionsBoundaryArn": boundary, } - resp.get('Role').get('PermissionsBoundary').should.equal(expected) - resp.get('Role').get('Description').should.equal('test') + resp.get("Role").get("PermissionsBoundary").should.equal(expected) + resp.get("Role").get("Description").should.equal("test") - invalid_boundary_arn = 'arn:aws:iam::123456789:not_a_boundary' + invalid_boundary_arn = "arn:aws:iam::123456789:not_a_boundary" with assert_raises(ClientError): - conn.create_role(RoleName='bad-boundary', AssumeRolePolicyDocument='some policy', Description='test', PermissionsBoundary=invalid_boundary_arn) + conn.create_role( + RoleName="bad-boundary", + AssumeRolePolicyDocument="some policy", + Description="test", + PermissionsBoundary=invalid_boundary_arn, + ) # Ensure the PermissionsBoundary is included in role listing as well - conn.list_roles().get('Roles')[0].get('PermissionsBoundary').should.equal(expected) + conn.list_roles().get("Roles")[0].get("PermissionsBoundary").should.equal(expected) @mock_iam def test_create_open_id_connect_provider(): - client = boto3.client('iam', region_name='us-east-1') + client = boto3.client("iam", region_name="us-east-1") response = client.create_open_id_connect_provider( - Url='https://example.com', - ThumbprintList=[] # even it is required to provide at least one thumbprint, AWS accepts an empty list + Url="https://example.com", + ThumbprintList=[], # even it is required to provide at least one thumbprint, AWS accepts an empty list ) - response['OpenIDConnectProviderArn'].should.equal( - 'arn:aws:iam::123456789012:oidc-provider/example.com' + response["OpenIDConnectProviderArn"].should.equal( + "arn:aws:iam::123456789012:oidc-provider/example.com" ) response = client.create_open_id_connect_provider( - Url='http://example.org', - ThumbprintList=[ - 'b' * 40 - ], - ClientIDList=[ - 'b' - ] + Url="http://example.org", ThumbprintList=["b" * 40], ClientIDList=["b"] ) - response['OpenIDConnectProviderArn'].should.equal( - 'arn:aws:iam::123456789012:oidc-provider/example.org' + response["OpenIDConnectProviderArn"].should.equal( + "arn:aws:iam::123456789012:oidc-provider/example.org" ) response = client.create_open_id_connect_provider( - Url='http://example.org/oidc', - ThumbprintList=[] + Url="http://example.org/oidc", ThumbprintList=[] ) - response['OpenIDConnectProviderArn'].should.equal( - 'arn:aws:iam::123456789012:oidc-provider/example.org/oidc' + response["OpenIDConnectProviderArn"].should.equal( + "arn:aws:iam::123456789012:oidc-provider/example.org/oidc" ) response = client.create_open_id_connect_provider( - Url='http://example.org/oidc-query?test=true', - ThumbprintList=[] + Url="http://example.org/oidc-query?test=true", ThumbprintList=[] ) - response['OpenIDConnectProviderArn'].should.equal( - 'arn:aws:iam::123456789012:oidc-provider/example.org/oidc-query' + response["OpenIDConnectProviderArn"].should.equal( + "arn:aws:iam::123456789012:oidc-provider/example.org/oidc-query" ) @mock_iam def test_create_open_id_connect_provider_errors(): - client = boto3.client('iam', region_name='us-east-1') - client.create_open_id_connect_provider( - Url='https://example.com', - ThumbprintList=[] - ) + client = boto3.client("iam", region_name="us-east-1") + client.create_open_id_connect_provider(Url="https://example.com", ThumbprintList=[]) client.create_open_id_connect_provider.when.called_with( - Url='https://example.com', - ThumbprintList=[] - ).should.throw( - ClientError, - 'Unknown' - ) + Url="https://example.com", ThumbprintList=[] + ).should.throw(ClientError, "Unknown") client.create_open_id_connect_provider.when.called_with( - Url='example.org', - ThumbprintList=[] - ).should.throw( - ClientError, - 'Invalid Open ID Connect Provider URL' - ) + Url="example.org", ThumbprintList=[] + ).should.throw(ClientError, "Invalid Open ID Connect Provider URL") client.create_open_id_connect_provider.when.called_with( - Url='example', - ThumbprintList=[] - ).should.throw( - ClientError, - 'Invalid Open ID Connect Provider URL' - ) + Url="example", ThumbprintList=[] + ).should.throw(ClientError, "Invalid Open ID Connect Provider URL") client.create_open_id_connect_provider.when.called_with( - Url='http://example.org', - ThumbprintList=[ - 'a' * 40, - 'b' * 40, - 'c' * 40, - 'd' * 40, - 'e' * 40, - 'f' * 40, - ] - ).should.throw( - ClientError, - 'Thumbprint list must contain fewer than 5 entries.' - ) + Url="http://example.org", + ThumbprintList=["a" * 40, "b" * 40, "c" * 40, "d" * 40, "e" * 40, "f" * 40], + ).should.throw(ClientError, "Thumbprint list must contain fewer than 5 entries.") - too_many_client_ids = ['{}'.format(i) for i in range(101)] + too_many_client_ids = ["{}".format(i) for i in range(101)] client.create_open_id_connect_provider.when.called_with( - Url='http://example.org', - ThumbprintList=[], - ClientIDList=too_many_client_ids + Url="http://example.org", ThumbprintList=[], ClientIDList=too_many_client_ids ).should.throw( - ClientError, - 'Cannot exceed quota for ClientIdsPerOpenIdConnectProvider: 100' + ClientError, "Cannot exceed quota for ClientIdsPerOpenIdConnectProvider: 100" ) - too_long_url = 'b' * 256 - too_long_thumbprint = 'b' * 41 - too_long_client_id = 'b' * 256 + too_long_url = "b" * 256 + too_long_thumbprint = "b" * 41 + too_long_client_id = "b" * 256 client.create_open_id_connect_provider.when.called_with( Url=too_long_url, - ThumbprintList=[ - too_long_thumbprint - ], - ClientIDList=[ - too_long_client_id - ] + ThumbprintList=[too_long_thumbprint], + ClientIDList=[too_long_client_id], ).should.throw( ClientError, - '3 validation errors detected: ' + "3 validation errors detected: " 'Value "{0}" at "clientIDList" failed to satisfy constraint: ' - 'Member must satisfy constraint: ' - '[Member must have length less than or equal to 255, ' - 'Member must have length greater than or equal to 1]; ' + "Member must satisfy constraint: " + "[Member must have length less than or equal to 255, " + "Member must have length greater than or equal to 1]; " 'Value "{1}" at "thumbprintList" failed to satisfy constraint: ' - 'Member must satisfy constraint: ' - '[Member must have length less than or equal to 40, ' - 'Member must have length greater than or equal to 40]; ' + "Member must satisfy constraint: " + "[Member must have length less than or equal to 40, " + "Member must have length greater than or equal to 40]; " 'Value "{2}" at "url" failed to satisfy constraint: ' - 'Member must have length less than or equal to 255'.format([too_long_client_id], [too_long_thumbprint], too_long_url) + "Member must have length less than or equal to 255".format( + [too_long_client_id], [too_long_thumbprint], too_long_url + ), ) @mock_iam def test_delete_open_id_connect_provider(): - client = boto3.client('iam', region_name='us-east-1') + client = boto3.client("iam", region_name="us-east-1") response = client.create_open_id_connect_provider( - Url='https://example.com', - ThumbprintList=[] + Url="https://example.com", ThumbprintList=[] ) - open_id_arn = response['OpenIDConnectProviderArn'] + open_id_arn = response["OpenIDConnectProviderArn"] - client.delete_open_id_connect_provider( - OpenIDConnectProviderArn=open_id_arn - ) + client.delete_open_id_connect_provider(OpenIDConnectProviderArn=open_id_arn) client.get_open_id_connect_provider.when.called_with( OpenIDConnectProviderArn=open_id_arn ).should.throw( - ClientError, - 'OpenIDConnect Provider not found for arn {}'.format(open_id_arn) + ClientError, "OpenIDConnect Provider not found for arn {}".format(open_id_arn) ) # deleting a non existing provider should be successful - client.delete_open_id_connect_provider( - OpenIDConnectProviderArn=open_id_arn - ) + client.delete_open_id_connect_provider(OpenIDConnectProviderArn=open_id_arn) @mock_iam def test_get_open_id_connect_provider(): - client = boto3.client('iam', region_name='us-east-1') + client = boto3.client("iam", region_name="us-east-1") response = client.create_open_id_connect_provider( - Url='https://example.com', - ThumbprintList=[ - 'b' * 40 - ], - ClientIDList=[ - 'b' - ] + Url="https://example.com", ThumbprintList=["b" * 40], ClientIDList=["b"] ) - open_id_arn = response['OpenIDConnectProviderArn'] + open_id_arn = response["OpenIDConnectProviderArn"] - response = client.get_open_id_connect_provider( - OpenIDConnectProviderArn=open_id_arn - ) + response = client.get_open_id_connect_provider(OpenIDConnectProviderArn=open_id_arn) - response['Url'].should.equal('example.com') - response['ThumbprintList'].should.equal([ - 'b' * 40 - ]) - response['ClientIDList'].should.equal([ - 'b' - ]) - response.should.have.key('CreateDate').should.be.a(datetime) + response["Url"].should.equal("example.com") + response["ThumbprintList"].should.equal(["b" * 40]) + response["ClientIDList"].should.equal(["b"]) + response.should.have.key("CreateDate").should.be.a(datetime) @mock_iam def test_get_open_id_connect_provider_errors(): - client = boto3.client('iam', region_name='us-east-1') + client = boto3.client("iam", region_name="us-east-1") response = client.create_open_id_connect_provider( - Url='https://example.com', - ThumbprintList=[ - 'b' * 40 - ], - ClientIDList=[ - 'b' - ] + Url="https://example.com", ThumbprintList=["b" * 40], ClientIDList=["b"] ) - open_id_arn = response['OpenIDConnectProviderArn'] + open_id_arn = response["OpenIDConnectProviderArn"] client.get_open_id_connect_provider.when.called_with( - OpenIDConnectProviderArn=open_id_arn + '-not-existing' + OpenIDConnectProviderArn=open_id_arn + "-not-existing" ).should.throw( ClientError, - 'OpenIDConnect Provider not found for arn {}'.format(open_id_arn + '-not-existing') + "OpenIDConnect Provider not found for arn {}".format( + open_id_arn + "-not-existing" + ), ) @mock_iam def test_list_open_id_connect_providers(): - client = boto3.client('iam', region_name='us-east-1') + client = boto3.client("iam", region_name="us-east-1") response = client.create_open_id_connect_provider( - Url='https://example.com', - ThumbprintList=[] + Url="https://example.com", ThumbprintList=[] ) - open_id_arn_1 = response['OpenIDConnectProviderArn'] + open_id_arn_1 = response["OpenIDConnectProviderArn"] response = client.create_open_id_connect_provider( - Url='http://example.org', - ThumbprintList=[ - 'b' * 40 - ], - ClientIDList=[ - 'b' - ] + Url="http://example.org", ThumbprintList=["b" * 40], ClientIDList=["b"] ) - open_id_arn_2 = response['OpenIDConnectProviderArn'] + open_id_arn_2 = response["OpenIDConnectProviderArn"] response = client.create_open_id_connect_provider( - Url='http://example.org/oidc', - ThumbprintList=[] + Url="http://example.org/oidc", ThumbprintList=[] ) - open_id_arn_3 = response['OpenIDConnectProviderArn'] + open_id_arn_3 = response["OpenIDConnectProviderArn"] response = client.list_open_id_connect_providers() - sorted(response['OpenIDConnectProviderList'], key=lambda i: i['Arn']).should.equal( - [ - { - 'Arn': open_id_arn_1 - }, - { - 'Arn': open_id_arn_2 - }, - { - 'Arn': open_id_arn_3 - } - ] + sorted(response["OpenIDConnectProviderList"], key=lambda i: i["Arn"]).should.equal( + [{"Arn": open_id_arn_1}, {"Arn": open_id_arn_2}, {"Arn": open_id_arn_3}] ) diff --git a/tests/test_iam/test_iam_account_aliases.py b/tests/test_iam/test_iam_account_aliases.py index 3d927038d..d01a72106 100644 --- a/tests/test_iam/test_iam_account_aliases.py +++ b/tests/test_iam/test_iam_account_aliases.py @@ -5,16 +5,16 @@ from moto import mock_iam @mock_iam() def test_account_aliases(): - client = boto3.client('iam', region_name='us-east-1') + client = boto3.client("iam", region_name="us-east-1") - alias = 'my-account-name' + alias = "my-account-name" aliases = client.list_account_aliases() - aliases.should.have.key('AccountAliases').which.should.equal([]) + aliases.should.have.key("AccountAliases").which.should.equal([]) client.create_account_alias(AccountAlias=alias) aliases = client.list_account_aliases() - aliases.should.have.key('AccountAliases').which.should.equal([alias]) + aliases.should.have.key("AccountAliases").which.should.equal([alias]) client.delete_account_alias(AccountAlias=alias) aliases = client.list_account_aliases() - aliases.should.have.key('AccountAliases').which.should.equal([]) + aliases.should.have.key("AccountAliases").which.should.equal([]) diff --git a/tests/test_iam/test_iam_groups.py b/tests/test_iam/test_iam_groups.py index 1ca9f2512..7fd299281 100644 --- a/tests/test_iam/test_iam_groups.py +++ b/tests/test_iam/test_iam_groups.py @@ -26,46 +26,50 @@ MOCK_POLICY = """ @mock_iam_deprecated() def test_create_group(): conn = boto.connect_iam() - conn.create_group('my-group') + conn.create_group("my-group") with assert_raises(BotoServerError): - conn.create_group('my-group') + conn.create_group("my-group") @mock_iam_deprecated() def test_get_group(): conn = boto.connect_iam() - conn.create_group('my-group') - conn.get_group('my-group') + conn.create_group("my-group") + conn.get_group("my-group") with assert_raises(BotoServerError): - conn.get_group('not-group') + conn.get_group("not-group") @mock_iam() def test_get_group_current(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_group(GroupName='my-group') - result = conn.get_group(GroupName='my-group') + conn = boto3.client("iam", region_name="us-east-1") + conn.create_group(GroupName="my-group") + result = conn.get_group(GroupName="my-group") - assert result['Group']['Path'] == '/' - assert result['Group']['GroupName'] == 'my-group' - assert isinstance(result['Group']['CreateDate'], datetime) - assert result['Group']['GroupId'] - assert result['Group']['Arn'] == 'arn:aws:iam::123456789012:group/my-group' - assert not result['Users'] + assert result["Group"]["Path"] == "/" + assert result["Group"]["GroupName"] == "my-group" + assert isinstance(result["Group"]["CreateDate"], datetime) + assert result["Group"]["GroupId"] + assert result["Group"]["Arn"] == "arn:aws:iam::123456789012:group/my-group" + assert not result["Users"] # Make a group with a different path: - other_group = conn.create_group(GroupName='my-other-group', Path='some/location') - assert other_group['Group']['Path'] == 'some/location' - assert other_group['Group']['Arn'] == 'arn:aws:iam::123456789012:group/some/location/my-other-group' + other_group = conn.create_group(GroupName="my-other-group", Path="some/location") + assert other_group["Group"]["Path"] == "some/location" + assert ( + other_group["Group"]["Arn"] + == "arn:aws:iam::123456789012:group/some/location/my-other-group" + ) @mock_iam_deprecated() def test_get_all_groups(): conn = boto.connect_iam() - conn.create_group('my-group1') - conn.create_group('my-group2') - groups = conn.get_all_groups()['list_groups_response'][ - 'list_groups_result']['groups'] + conn.create_group("my-group1") + conn.create_group("my-group2") + groups = conn.get_all_groups()["list_groups_response"]["list_groups_result"][ + "groups" + ] groups.should.have.length_of(2) @@ -73,95 +77,108 @@ def test_get_all_groups(): def test_add_user_to_group(): conn = boto.connect_iam() with assert_raises(BotoServerError): - conn.add_user_to_group('my-group', 'my-user') - conn.create_group('my-group') + conn.add_user_to_group("my-group", "my-user") + conn.create_group("my-group") with assert_raises(BotoServerError): - conn.add_user_to_group('my-group', 'my-user') - conn.create_user('my-user') - conn.add_user_to_group('my-group', 'my-user') + conn.add_user_to_group("my-group", "my-user") + conn.create_user("my-user") + conn.add_user_to_group("my-group", "my-user") @mock_iam_deprecated() def test_remove_user_from_group(): conn = boto.connect_iam() with assert_raises(BotoServerError): - conn.remove_user_from_group('my-group', 'my-user') - conn.create_group('my-group') - conn.create_user('my-user') + conn.remove_user_from_group("my-group", "my-user") + conn.create_group("my-group") + conn.create_user("my-user") with assert_raises(BotoServerError): - conn.remove_user_from_group('my-group', 'my-user') - conn.add_user_to_group('my-group', 'my-user') - conn.remove_user_from_group('my-group', 'my-user') + conn.remove_user_from_group("my-group", "my-user") + conn.add_user_to_group("my-group", "my-user") + conn.remove_user_from_group("my-group", "my-user") @mock_iam_deprecated() def test_get_groups_for_user(): conn = boto.connect_iam() - conn.create_group('my-group1') - conn.create_group('my-group2') - conn.create_group('other-group') - conn.create_user('my-user') - conn.add_user_to_group('my-group1', 'my-user') - conn.add_user_to_group('my-group2', 'my-user') + conn.create_group("my-group1") + conn.create_group("my-group2") + conn.create_group("other-group") + conn.create_user("my-user") + conn.add_user_to_group("my-group1", "my-user") + conn.add_user_to_group("my-group2", "my-user") - groups = conn.get_groups_for_user( - 'my-user')['list_groups_for_user_response']['list_groups_for_user_result']['groups'] + groups = conn.get_groups_for_user("my-user")["list_groups_for_user_response"][ + "list_groups_for_user_result" + ]["groups"] groups.should.have.length_of(2) @mock_iam_deprecated() def test_put_group_policy(): conn = boto.connect_iam() - conn.create_group('my-group') - conn.put_group_policy('my-group', 'my-policy', MOCK_POLICY) + conn.create_group("my-group") + conn.put_group_policy("my-group", "my-policy", MOCK_POLICY) @mock_iam def test_attach_group_policies(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_group(GroupName='my-group') - conn.list_attached_group_policies(GroupName='my-group')['AttachedPolicies'].should.be.empty - policy_arn = 'arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceforEC2Role' - conn.list_attached_group_policies(GroupName='my-group')['AttachedPolicies'].should.be.empty - conn.attach_group_policy(GroupName='my-group', PolicyArn=policy_arn) - conn.list_attached_group_policies(GroupName='my-group')['AttachedPolicies'].should.equal( - [ - { - 'PolicyName': 'AmazonElasticMapReduceforEC2Role', - 'PolicyArn': policy_arn, - } - ]) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_group(GroupName="my-group") + conn.list_attached_group_policies(GroupName="my-group")[ + "AttachedPolicies" + ].should.be.empty + policy_arn = "arn:aws:iam::aws:policy/service-role/AmazonElasticMapReduceforEC2Role" + conn.list_attached_group_policies(GroupName="my-group")[ + "AttachedPolicies" + ].should.be.empty + conn.attach_group_policy(GroupName="my-group", PolicyArn=policy_arn) + conn.list_attached_group_policies(GroupName="my-group")[ + "AttachedPolicies" + ].should.equal( + [{"PolicyName": "AmazonElasticMapReduceforEC2Role", "PolicyArn": policy_arn}] + ) - conn.detach_group_policy(GroupName='my-group', PolicyArn=policy_arn) - conn.list_attached_group_policies(GroupName='my-group')['AttachedPolicies'].should.be.empty + conn.detach_group_policy(GroupName="my-group", PolicyArn=policy_arn) + conn.list_attached_group_policies(GroupName="my-group")[ + "AttachedPolicies" + ].should.be.empty @mock_iam_deprecated() def test_get_group_policy(): conn = boto.connect_iam() - conn.create_group('my-group') + conn.create_group("my-group") with assert_raises(BotoServerError): - conn.get_group_policy('my-group', 'my-policy') + conn.get_group_policy("my-group", "my-policy") - conn.put_group_policy('my-group', 'my-policy', MOCK_POLICY) - conn.get_group_policy('my-group', 'my-policy') + conn.put_group_policy("my-group", "my-policy", MOCK_POLICY) + conn.get_group_policy("my-group", "my-policy") @mock_iam_deprecated() def test_get_all_group_policies(): conn = boto.connect_iam() - conn.create_group('my-group') - policies = conn.get_all_group_policies('my-group')['list_group_policies_response']['list_group_policies_result']['policy_names'] + conn.create_group("my-group") + policies = conn.get_all_group_policies("my-group")["list_group_policies_response"][ + "list_group_policies_result" + ]["policy_names"] assert policies == [] - conn.put_group_policy('my-group', 'my-policy', MOCK_POLICY) - policies = conn.get_all_group_policies('my-group')['list_group_policies_response']['list_group_policies_result']['policy_names'] - assert policies == ['my-policy'] + conn.put_group_policy("my-group", "my-policy", MOCK_POLICY) + policies = conn.get_all_group_policies("my-group")["list_group_policies_response"][ + "list_group_policies_result" + ]["policy_names"] + assert policies == ["my-policy"] @mock_iam() def test_list_group_policies(): - conn = boto3.client('iam', region_name='us-east-1') - conn.create_group(GroupName='my-group') - conn.list_group_policies(GroupName='my-group')['PolicyNames'].should.be.empty - conn.put_group_policy(GroupName='my-group', PolicyName='my-policy', PolicyDocument=MOCK_POLICY) - conn.list_group_policies(GroupName='my-group')['PolicyNames'].should.equal(['my-policy']) + conn = boto3.client("iam", region_name="us-east-1") + conn.create_group(GroupName="my-group") + conn.list_group_policies(GroupName="my-group")["PolicyNames"].should.be.empty + conn.put_group_policy( + GroupName="my-group", PolicyName="my-policy", PolicyDocument=MOCK_POLICY + ) + conn.list_group_policies(GroupName="my-group")["PolicyNames"].should.equal( + ["my-policy"] + ) diff --git a/tests/test_iam/test_iam_policies.py b/tests/test_iam/test_iam_policies.py index adb8bd990..6348b0cba 100644 --- a/tests/test_iam/test_iam_policies.py +++ b/tests/test_iam/test_iam_policies.py @@ -9,17 +9,17 @@ from moto import mock_iam invalid_policy_document_test_cases = [ { "document": "This is not a json document", - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", } }, - "error_message": 'Policy document must be version 2012-10-17 or greater.' + "error_message": "Policy document must be version 2012-10-17 or greater.", }, { "document": { @@ -27,10 +27,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Policy document must be version 2012-10-17 or greater.' + "error_message": "Policy document must be version 2012-10-17 or greater.", }, { "document": { @@ -38,35 +38,18 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17" - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": ["afd"] - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, - "Extra field": "value" }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", + }, + { + "document": {"Version": "2012-10-17"}, + "error_message": "Syntax errors in policy.", + }, + { + "document": {"Version": "2012-10-17", "Statement": ["afd"]}, + "error_message": "Syntax errors in policy.", }, { "document": { @@ -75,10 +58,22 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Extra field": "value" - } + }, + "Extra field": "value", }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Extra field": "value", + }, + }, + "error_message": "Syntax errors in policy.", }, { "document": { @@ -87,10 +82,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -99,10 +94,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -110,10 +105,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "invalid", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -121,46 +116,43 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "invalid", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc.' + "error_message": "Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "NotAction": "", + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc.' + "error_message": "Actions/Conditions must be prefaced by a vendor, e.g., iam, sdb, ec2, etc.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "a a:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "a a:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Vendor a a is not valid' + "error_message": "Vendor a a is not valid", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:List:Bucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:List:Bucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Actions/Condition can contain only one colon.' + "error_message": "Actions/Condition can contain only one colon.", }, { "document": { @@ -169,16 +161,16 @@ invalid_policy_document_test_cases = [ { "Effect": "Allow", "Action": "s3s:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, { "Effect": "Allow", "Action": "s:3s:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } - ] + "Resource": "arn:aws:s3:::example_bucket", + }, + ], }, - "error_message": 'Actions/Condition can contain only one colon.' + "error_message": "Actions/Condition can contain only one colon.", }, { "document": { @@ -186,10 +178,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "invalid resource" - } + "Resource": "invalid resource", + }, }, - "error_message": 'Resource invalid resource must be in ARN format or "*".' + "error_message": 'Resource invalid resource must be in ARN format or "*".', }, { "document": { @@ -198,39 +190,32 @@ invalid_policy_document_test_cases = [ { "Sid": "EnableDisableHongKong", "Effect": "Allow", - "Action": [ - "account:EnableRegion", - "account:DisableRegion" - ], + "Action": ["account:EnableRegion", "account:DisableRegion"], "Resource": "", "Condition": { "StringEquals": {"account:TargetRegion": "ap-east-1"} - } + }, }, { "Sid": "ViewConsole", "Effect": "Allow", - "Action": [ - "aws-portal:ViewAccount", - "account:ListRegions" - ], - "Resource": "" - } - ] + "Action": ["aws-portal:ViewAccount", "account:ListRegions"], + "Resource": "", + }, + ], }, - "error_message": 'Resource must be in ARN format or "*".' + "error_message": 'Resource must be in ARN format or "*".', }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s:3:ListBucket", - "Resource": "sdfsadf" - } + "Statement": { + "Effect": "Allow", + "Action": "s:3:ListBucket", + "Resource": "sdfsadf", + }, }, - "error_message": 'Resource sdfsadf must be in ARN format or "*".' + "error_message": 'Resource sdfsadf must be in ARN format or "*".', }, { "document": { @@ -238,10 +223,50 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": ["adf"] - } + "Resource": ["adf"], + }, }, - "error_message": 'Resource adf must be in ARN format or "*".' + "error_message": 'Resource adf must be in ARN format or "*".', + }, + { + "document": { + "Version": "2012-10-17", + "Statement": {"Effect": "Allow", "Action": "s3:ListBucket", "Resource": ""}, + }, + "error_message": 'Resource must be in ARN format or "*".', + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "NotAction": "s3s:ListBucket", + "Resource": "a:bsdfdsafsad", + }, + }, + "error_message": 'Partition "bsdfdsafsad" is not valid for resource "arn:bsdfdsafsad:*:*:*:*".', + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "NotAction": "s3s:ListBucket", + "Resource": "a:b:cadfsdf", + }, + }, + "error_message": 'Partition "b" is not valid for resource "arn:b:cadfsdf:*:*:*".', + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "NotAction": "s3s:ListBucket", + "Resource": "a:b:c:d:e:f:g:h", + }, + }, + "error_message": 'Partition "b" is not valid for resource "arn:b:c:d:e:f:g:h".', }, { "document": { @@ -249,57 +274,10 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "" - } + "Resource": "aws:s3:::example_bucket", + }, }, - "error_message": 'Resource must be in ARN format or "*".' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3s:ListBucket", - "Resource": "a:bsdfdsafsad" - } - }, - "error_message": 'Partition "bsdfdsafsad" is not valid for resource "arn:bsdfdsafsad:*:*:*:*".' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3s:ListBucket", - "Resource": "a:b:cadfsdf" - } - }, - "error_message": 'Partition "b" is not valid for resource "arn:b:cadfsdf:*:*:*".' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3s:ListBucket", - "Resource": "a:b:c:d:e:f:g:h" - } - }, - "error_message": 'Partition "b" is not valid for resource "arn:b:c:d:e:f:g:h".' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "aws:s3:::example_bucket" - } - }, - "error_message": 'Partition "s3" is not valid for resource "arn:s3:::example_bucket:*".' + "error_message": 'Partition "s3" is not valid for resource "arn:s3:::example_bucket:*".', }, { "document": { @@ -309,166 +287,133 @@ invalid_policy_document_test_cases = [ "Action": "s3:ListBucket", "Resource": [ "arn:error:s3:::example_bucket", - "arn:error:s3::example_bucket" - ] - } + "arn:error:s3::example_bucket", + ], + }, }, - "error_message": 'Partition "error" is not valid for resource "arn:error:s3:::example_bucket".' + "error_message": 'Partition "error" is not valid for resource "arn:error:s3:::example_bucket".', + }, + { + "document": {"Version": "2012-10-17", "Statement": []}, + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": [] + "Statement": {"Effect": "Allow", "Action": "s3:ListBucket"}, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Policy statement must contain resources.", }, { "document": { "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket" - } + "Statement": {"Effect": "Allow", "Action": "s3:ListBucket", "Resource": []}, }, - "error_message": 'Policy statement must contain resources.' + "error_message": "Policy statement must contain resources.", }, { "document": { "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": [] - } + "Statement": {"Effect": "Allow", "Action": "invalid"}, }, - "error_message": 'Policy statement must contain resources.' + "error_message": "Policy statement must contain resources.", }, { "document": { "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "invalid" - } + "Statement": {"Effect": "Allow", "Resource": "arn:aws:s3:::example_bucket"}, }, - "error_message": 'Policy statement must contain resources.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Resource": "arn:aws:s3:::example_bucket" - } - }, - "error_message": 'Policy statement must contain actions.' + "error_message": "Policy statement must contain actions.", }, { "document": { "Version": "2012-10-17", "Statement": { "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", + }, + { + "document": {"Version": "2012-10-17", "Statement": {"Effect": "Allow"}}, + "error_message": "Policy statement must contain actions.", }, { "document": { "Version": "2012-10-17", "Statement": { - "Effect": "Allow" - } + "Effect": "Allow", + "Action": [], + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Policy statement must contain actions.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": [], - "Resource": "arn:aws:s3:::example_bucket" - } - }, - "error_message": 'Policy statement must contain actions.' + "error_message": "Policy statement must contain actions.", }, { "document": { "Version": "2012-10-17", "Statement": [ + {"Effect": "Deny"}, { - "Effect": "Deny" + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", }, - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } - ] + ], }, - "error_message": 'Policy statement must contain actions.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:iam:::example_bucket" - } - }, - "error_message": 'IAM resource path must either be "*" or start with user/, federated-user/, role/, group/, instance-profile/, mfa/, server-certificate/, policy/, sms-mfa/, saml-provider/, oidc-provider/, report/, access-report/.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3::example_bucket" - } - }, - "error_message": 'The policy failed legacy parsing' + "error_message": "Policy statement must contain actions.", }, { "document": { "Version": "2012-10-17", "Statement": { "Effect": "Allow", - "Resource": "arn:aws:s3::example_bucket" - } + "Action": "s3:ListBucket", + "Resource": "arn:aws:iam:::example_bucket", + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": 'IAM resource path must either be "*" or start with user/, federated-user/, role/, group/, instance-profile/, mfa/, server-certificate/, policy/, sms-mfa/, saml-provider/, oidc-provider/, report/, access-report/.', }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3::example_bucket", + }, }, - "error_message": 'Resource vendor must be fully qualified and cannot contain regexes.' + "error_message": "The policy failed legacy parsing", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": { - "a": "arn:aws:s3:::example_bucket" - } - } + "Statement": {"Effect": "Allow", "Resource": "arn:aws:s3::example_bucket"}, }, - "error_message": 'Syntax errors in policy.' + "error_message": "The policy failed legacy parsing", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws", + }, + }, + "error_message": "Resource vendor must be fully qualified and cannot contain regexes.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": {"a": "arn:aws:s3:::example_bucket"}, + }, + }, + "error_message": "Syntax errors in policy.", }, { "document": { @@ -476,23 +421,22 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Deny", "Action": "s3:ListBucket", - "Resource": ["adfdf", {}] - } + "Resource": ["adfdf", {}], + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "NotResource": [] - } + "Statement": { + "Effect": "Allow", + "NotAction": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "NotResource": [], + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -500,135 +444,33 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Deny", "Action": [[]], - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3s:ListBucket", - "Action": [], - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "NotAction": "s3s:ListBucket", + "Action": [], + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": {}, - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": {}, + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": [] - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": "a" - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "a": "b" - } - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": "b" - } - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": [] - } - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"a": {}} - } - } - }, - "error_message": 'Syntax errors in policy.' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"a": {}} - } - } - }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -637,14 +479,10 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "x": { - "a": "1" - } - } - } + "Condition": [], + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -653,79 +491,153 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "ForAnyValue::StringEqualsIfExists": { - "a": "asf" - } - } - } + "Condition": "a", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": [ - {"ForAllValues:StringEquals": {"aws:TagKeys": "Department"}} - ] - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"a": "b"}, + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:iam:us-east-1::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": "b"}, + }, }, - "error_message": 'IAM resource arn:aws:iam:us-east-1::example_bucket cannot contain region information.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:us-east-1::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": []}, + }, }, - "error_message": 'Resource arn:aws:s3:us-east-1::example_bucket can not contain region information.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Sid": {}, - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": {}}}, + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Sid": [], - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": {}}}, + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"x": {"a": "1"}}, + }, + }, + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"ForAnyValue::StringEqualsIfExists": {"a": "asf"}}, + }, + }, + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": [ + {"ForAllValues:StringEquals": {"aws:TagKeys": "Department"}} + ], + }, + }, + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:iam:us-east-1::example_bucket", + }, + }, + "error_message": "IAM resource arn:aws:iam:us-east-1::example_bucket cannot contain region information.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:us-east-1::example_bucket", + }, + }, + "error_message": "Resource arn:aws:s3:us-east-1::example_bucket can not contain region information.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Sid": {}, + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, + }, + "error_message": "Syntax errors in policy.", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Sid": [], + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, + }, + "error_message": "Syntax errors in policy.", }, { "document": { @@ -735,15 +647,12 @@ invalid_policy_document_test_cases = [ "Sid": "sdf", "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, - { - "Sid": "sdf", - "Effect": "Allow" - } - ] + {"Sid": "sdf", "Effect": "Allow"}, + ], }, - "error_message": 'Statement IDs (SID) in a single policy must be unique.' + "error_message": "Statement IDs (SID) in a single policy must be unique.", }, { "document": { @@ -752,15 +661,12 @@ invalid_policy_document_test_cases = [ "Sid": "sdf", "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, - { - "Sid": "sdf", - "Effect": "Allow" - } + {"Sid": "sdf", "Effect": "Allow"}, ] }, - "error_message": 'Policy document must be version 2012-10-17 or greater.' + "error_message": "Policy document must be version 2012-10-17 or greater.", }, { "document": { @@ -769,10 +675,10 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "NotAction": "s3:ListBucket", "Action": "iam:dsf", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -781,10 +687,10 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "NotResource": "*" - } + "NotResource": "*", + }, }, - "error_message": 'Syntax errors in policy.' + "error_message": "Syntax errors in policy.", }, { "document": { @@ -792,85 +698,74 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "denY", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"a": "sdfdsf"} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": "sdfdsf"}}, + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"a": "sdfdsf"} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": "sdfdsf"}}, + } }, - "error_message": 'Policy document must be version 2012-10-17 or greater.' + "error_message": "Policy document must be version 2012-10-17 or greater.", }, { "document": { "Statement": { "Effect": "denY", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", } }, - "error_message": 'Policy document must be version 2012-10-17 or greater.' + "error_message": "Policy document must be version 2012-10-17 or greater.", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Condition": { - "DateGreaterThan": {"a": "sdfdsf"} - } - } + "Statement": { + "Effect": "Allow", + "Condition": {"DateGreaterThan": {"a": "sdfdsf"}}, + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3:ListBucket", - "Resource": "arn:aws::::example_bucket" - } + "Statement": { + "Effect": "Allow", + "NotAction": "s3:ListBucket", + "Resource": "arn:aws::::example_bucket", + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { "Version": "2012-10-17", - "Statement": - { - "Effect": "allow", - "Resource": "arn:aws:s3:us-east-1::example_bucket" - } + "Statement": { + "Effect": "allow", + "Resource": "arn:aws:s3:us-east-1::example_bucket", + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -880,15 +775,12 @@ invalid_policy_document_test_cases = [ "Sid": "sdf", "Effect": "aLLow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, - { - "Sid": "sdf", - "Effect": "Allow" - } - ] + {"Sid": "sdf", "Effect": "Allow"}, + ], }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -896,10 +788,22 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "NotResource": "arn:aws:s3::example_bucket" - } + "NotResource": "arn:aws:s3::example_bucket", + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", + }, + { + "document": { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateLessThanEquals": {"a": "234-13"}}, + }, + }, + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -909,13 +813,11 @@ invalid_policy_document_test_cases = [ "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", "Condition": { - "DateLessThanEquals": { - "a": "234-13" - } - } - } + "DateLessThanEquals": {"a": "2016-12-13t2:00:00.593194+1"} + }, + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -925,13 +827,11 @@ invalid_policy_document_test_cases = [ "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", "Condition": { - "DateLessThanEquals": { - "a": "2016-12-13t2:00:00.593194+1" - } - } - } + "DateLessThanEquals": {"a": "2016-12-13t2:00:00.1999999999+10:59"} + }, + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -940,30 +840,10 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThanEquals": { - "a": "2016-12-13t2:00:00.1999999999+10:59" - } - } - } + "Condition": {"DateLessThan": {"a": "9223372036854775808"}}, + }, }, - "error_message": 'The policy failed legacy parsing' - }, - { - "document": { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThan": { - "a": "9223372036854775808" - } - } - } - }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -972,14 +852,10 @@ invalid_policy_document_test_cases = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:error:s3:::example_bucket", - "Condition": { - "DateGreaterThan": { - "a": "sdfdsf" - } - } - } + "Condition": {"DateGreaterThan": {"a": "sdfdsf"}}, + }, }, - "error_message": 'The policy failed legacy parsing' + "error_message": "The policy failed legacy parsing", }, { "document": { @@ -987,11 +863,11 @@ invalid_policy_document_test_cases = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws::fdsasf" - } + "Resource": "arn:aws::fdsasf", + }, }, - "error_message": 'The policy failed legacy parsing' - } + "error_message": "The policy failed legacy parsing", + }, ] valid_policy_documents = [ @@ -1000,37 +876,32 @@ valid_policy_documents = [ "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": [ - "arn:aws:s3:::example_bucket" - ] - } + "Resource": ["arn:aws:s3:::example_bucket"], + }, }, { "Version": "2012-10-17", "Statement": { "Effect": "Allow", "Action": "iam: asdf safdsf af ", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", "Statement": { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": [ - "arn:aws:s3:::example_bucket", - "*" - ] - } + "Resource": ["arn:aws:s3:::example_bucket", "*"], + }, }, { "Version": "2012-10-17", "Statement": { "Effect": "Allow", "Action": "*", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", @@ -1038,9 +909,9 @@ valid_policy_documents = [ { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", } - ] + ], }, { "Version": "2012-10-17", @@ -1050,160 +921,139 @@ valid_policy_documents = [ "Resource": "*", "Condition": { "DateGreaterThan": {"aws:CurrentTime": "2017-07-01T00:00:00Z"}, - "DateLessThan": {"aws:CurrentTime": "2017-12-31T23:59:59Z"} - } - } + "DateLessThan": {"aws:CurrentTime": "2017-12-31T23:59:59Z"}, + }, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "fsx:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "fsx:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:iam:::user/example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:iam:::user/example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s33:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s33:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:fdsasf" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:fdsasf", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": {} - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {}, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": {"ForAllValues:StringEquals": {"aws:TagKeys": "Department"}} - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"ForAllValues:StringEquals": {"aws:TagKeys": "Department"}}, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:cloudwatch:us-east-1::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:cloudwatch:us-east-1::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:ec2:us-east-1::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:ec2:us-east-1::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:invalid-service:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:invalid-service:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:invalid-service:us-east-1::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:invalid-service:us-east-1::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"aws:CurrentTime": "2017-07-01T00:00:00Z"}, - "DateLessThan": {"aws:CurrentTime": "2017-12-31T23:59:59Z"} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": { + "DateGreaterThan": {"aws:CurrentTime": "2017-07-01T00:00:00Z"}, + "DateLessThan": {"aws:CurrentTime": "2017-12-31T23:59:59Z"}, + }, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {}}, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateGreaterThan": {"a": []} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": []}}, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "a": {} - } - } + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"a": {}}, + }, }, { "Version": "2012-10-17", - "Statement": - { - "Sid": "dsfsdfsdfsdfsdfsadfsd", - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Sid": "dsfsdfsdfsdfsdfsadfsd", + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", @@ -1217,37 +1067,29 @@ valid_policy_documents = [ "iam:ListRoles", "iam:ListRoleTags", "iam:ListUsers", - "iam:ListUserTags" + "iam:ListUserTags", ], - "Resource": "*" + "Resource": "*", }, { "Sid": "AddTag", "Effect": "Allow", - "Action": [ - "iam:TagUser", - "iam:TagRole" - ], + "Action": ["iam:TagUser", "iam:TagRole"], "Resource": "*", "Condition": { - "StringEquals": { - "aws:RequestTag/CostCenter": [ - "A-123", - "B-456" - ] - }, - "ForAllValues:StringEquals": {"aws:TagKeys": "CostCenter"} - } - } - ] + "StringEquals": {"aws:RequestTag/CostCenter": ["A-123", "B-456"]}, + "ForAllValues:StringEquals": {"aws:TagKeys": "CostCenter"}, + }, + }, + ], }, { "Version": "2012-10-17", "Statement": { "Effect": "Allow", "NotAction": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", @@ -1256,9 +1098,9 @@ valid_policy_documents = [ "Action": "s3:*", "NotResource": [ "arn:aws:s3:::HRBucket/Payroll", - "arn:aws:s3:::HRBucket/Payroll/*" - ] - } + "arn:aws:s3:::HRBucket/Payroll/*", + ], + }, }, { "Version": "2012-10-17", @@ -1266,44 +1108,40 @@ valid_policy_documents = [ "Statement": { "Effect": "Allow", "NotAction": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "aaaaaadsfdsafsadfsadfaaaaa:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "aaaaaadsfdsafsadfsadfaaaaa:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3-s:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3-s:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "Action": "s3.s:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } + "Statement": { + "Effect": "Allow", + "Action": "s3.s:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + }, }, { "Version": "2012-10-17", - "Statement": - { - "Effect": "Allow", - "NotAction": "s3:ListBucket", - "NotResource": "*" - } + "Statement": { + "Effect": "Allow", + "NotAction": "s3:ListBucket", + "NotResource": "*", + }, }, { "Version": "2012-10-17", @@ -1312,14 +1150,59 @@ valid_policy_documents = [ "Sid": "sdf", "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" + "Resource": "arn:aws:s3:::example_bucket", }, { "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket" - } - ] + "Resource": "arn:aws:s3:::example_bucket", + }, + ], + }, + { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateGreaterThan": {"a": "01T"}}, + }, + }, + { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"x": {}, "y": {}}, + }, + }, + { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"StringEqualsIfExists": {"a": "asf"}}, + }, + }, + { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"ForAnyValue:StringEqualsIfExists": {"a": "asf"}}, + }, + }, + { + "Version": "2012-10-17", + "Statement": { + "Effect": "Allow", + "Action": "s3:ListBucket", + "Resource": "arn:aws:s3:::example_bucket", + "Condition": {"DateLessThanEquals": {"a": "2019-07-01T13:20:15Z"}}, + }, }, { "Version": "2012-10-17", @@ -1328,11 +1211,9 @@ valid_policy_documents = [ "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", "Condition": { - "DateGreaterThan": { - "a": "01T" - } - } - } + "DateLessThanEquals": {"a": "2016-12-13T21:20:37.593194+00:00"} + }, + }, }, { "Version": "2012-10-17", @@ -1340,12 +1221,8 @@ valid_policy_documents = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "x": { - }, - "y": {} - } - } + "Condition": {"DateLessThanEquals": {"a": "2016-12-13t2:00:00.593194+23"}}, + }, }, { "Version": "2012-10-17", @@ -1353,77 +1230,8 @@ valid_policy_documents = [ "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "StringEqualsIfExists": { - "a": "asf" - } - } - } - }, - { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "ForAnyValue:StringEqualsIfExists": { - "a": "asf" - } - } - } - }, - { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThanEquals": { - "a": "2019-07-01T13:20:15Z" - } - } - } - }, - { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThanEquals": { - "a": "2016-12-13T21:20:37.593194+00:00" - } - } - } - }, - { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThanEquals": { - "a": "2016-12-13t2:00:00.593194+23" - } - } - } - }, - { - "Version": "2012-10-17", - "Statement": { - "Effect": "Allow", - "Action": "s3:ListBucket", - "Resource": "arn:aws:s3:::example_bucket", - "Condition": { - "DateLessThan": { - "a": "-292275054" - } - } - } + "Condition": {"DateLessThan": {"a": "-292275054"}}, + }, }, { "Version": "2012-10-17", @@ -1434,18 +1242,15 @@ valid_policy_documents = [ "Action": [ "iam:GetAccountPasswordPolicy", "iam:GetAccountSummary", - "iam:ListVirtualMFADevices" + "iam:ListVirtualMFADevices", ], - "Resource": "*" + "Resource": "*", }, { "Sid": "AllowManageOwnPasswords", "Effect": "Allow", - "Action": [ - "iam:ChangePassword", - "iam:GetUser" - ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Action": ["iam:ChangePassword", "iam:GetUser"], + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnAccessKeys", @@ -1454,9 +1259,9 @@ valid_policy_documents = [ "iam:CreateAccessKey", "iam:DeleteAccessKey", "iam:ListAccessKeys", - "iam:UpdateAccessKey" + "iam:UpdateAccessKey", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnSigningCertificates", @@ -1465,9 +1270,9 @@ valid_policy_documents = [ "iam:DeleteSigningCertificate", "iam:ListSigningCertificates", "iam:UpdateSigningCertificate", - "iam:UploadSigningCertificate" + "iam:UploadSigningCertificate", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnSSHPublicKeys", @@ -1477,9 +1282,9 @@ valid_policy_documents = [ "iam:GetSSHPublicKey", "iam:ListSSHPublicKeys", "iam:UpdateSSHPublicKey", - "iam:UploadSSHPublicKey" + "iam:UploadSSHPublicKey", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnGitCredentials", @@ -1489,18 +1294,15 @@ valid_policy_documents = [ "iam:DeleteServiceSpecificCredential", "iam:ListServiceSpecificCredentials", "iam:ResetServiceSpecificCredential", - "iam:UpdateServiceSpecificCredential" + "iam:UpdateServiceSpecificCredential", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnVirtualMFADevice", "Effect": "Allow", - "Action": [ - "iam:CreateVirtualMFADevice", - "iam:DeleteVirtualMFADevice" - ], - "Resource": "arn:aws:iam::*:mfa/${aws:username}" + "Action": ["iam:CreateVirtualMFADevice", "iam:DeleteVirtualMFADevice"], + "Resource": "arn:aws:iam::*:mfa/${aws:username}", }, { "Sid": "AllowManageOwnUserMFA", @@ -1509,9 +1311,9 @@ valid_policy_documents = [ "iam:DeactivateMFADevice", "iam:EnableMFADevice", "iam:ListMFADevices", - "iam:ResyncMFADevice" + "iam:ResyncMFADevice", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "DenyAllExceptListedIfNoMFA", @@ -1523,16 +1325,12 @@ valid_policy_documents = [ "iam:ListMFADevices", "iam:ListVirtualMFADevices", "iam:ResyncMFADevice", - "sts:GetSessionToken" + "sts:GetSessionToken", ], "Resource": "*", - "Condition": { - "BoolIfExists": { - "aws:MultiFactorAuthPresent": "false" - } - } - } - ] + "Condition": {"BoolIfExists": {"aws:MultiFactorAuthPresent": "false"}}, + }, + ], }, { "Version": "2012-10-17", @@ -1544,9 +1342,9 @@ valid_policy_documents = [ "dynamodb:List*", "dynamodb:DescribeReservedCapacity*", "dynamodb:DescribeLimits", - "dynamodb:DescribeTimeToLive" + "dynamodb:DescribeTimeToLive", ], - "Resource": "*" + "Resource": "*", }, { "Sid": "SpecificTable", @@ -1562,57 +1360,47 @@ valid_policy_documents = [ "dynamodb:CreateTable", "dynamodb:Delete*", "dynamodb:Update*", - "dynamodb:PutItem" + "dynamodb:PutItem", ], - "Resource": "arn:aws:dynamodb:*:*:table/MyTable" - } - ] + "Resource": "arn:aws:dynamodb:*:*:table/MyTable", + }, + ], }, { "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", - "Action": [ - "ec2:AttachVolume", - "ec2:DetachVolume" - ], - "Resource": [ - "arn:aws:ec2:*:*:volume/*", - "arn:aws:ec2:*:*:instance/*" - ], + "Action": ["ec2:AttachVolume", "ec2:DetachVolume"], + "Resource": ["arn:aws:ec2:*:*:volume/*", "arn:aws:ec2:*:*:instance/*"], "Condition": { - "ArnEquals": {"ec2:SourceInstanceARN": "arn:aws:ec2:*:*:instance/instance-id"} - } + "ArnEquals": { + "ec2:SourceInstanceARN": "arn:aws:ec2:*:*:instance/instance-id" + } + }, } - ] + ], }, { "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", - "Action": [ - "ec2:AttachVolume", - "ec2:DetachVolume" - ], + "Action": ["ec2:AttachVolume", "ec2:DetachVolume"], "Resource": "arn:aws:ec2:*:*:instance/*", "Condition": { "StringEquals": {"ec2:ResourceTag/Department": "Development"} - } + }, }, { "Effect": "Allow", - "Action": [ - "ec2:AttachVolume", - "ec2:DetachVolume" - ], + "Action": ["ec2:AttachVolume", "ec2:DetachVolume"], "Resource": "arn:aws:ec2:*:*:volume/*", "Condition": { "StringEquals": {"ec2:ResourceTag/VolumeUser": "${aws:username}"} - } - } - ] + }, + }, + ], }, { "Version": "2012-10-17", @@ -1623,17 +1411,17 @@ valid_policy_documents = [ "Action": [ "ec2:StartInstances", "ec2:StopInstances", - "ec2:DescribeTags" + "ec2:DescribeTags", ], "Resource": "arn:aws:ec2:region:account-id:instance/*", "Condition": { "StringEquals": { "ec2:ResourceTag/Project": "DataAnalytics", - "aws:PrincipalTag/Department": "Data" + "aws:PrincipalTag/Department": "Data", } - } + }, } - ] + ], }, { "Version": "2012-10-17", @@ -1645,59 +1433,48 @@ valid_policy_documents = [ "Resource": ["arn:aws:s3:::bucket-name"], "Condition": { "StringLike": { - "s3:prefix": ["cognito/application-name/${cognito-identity.amazonaws.com:sub}"] + "s3:prefix": [ + "cognito/application-name/${cognito-identity.amazonaws.com:sub}" + ] } - } + }, }, { "Sid": "ReadWriteDeleteYourObjects", "Effect": "Allow", - "Action": [ - "s3:GetObject", - "s3:PutObject", - "s3:DeleteObject" - ], + "Action": ["s3:GetObject", "s3:PutObject", "s3:DeleteObject"], "Resource": [ "arn:aws:s3:::bucket-name/cognito/application-name/${cognito-identity.amazonaws.com:sub}", - "arn:aws:s3:::bucket-name/cognito/application-name/${cognito-identity.amazonaws.com:sub}/*" - ] - } - ] + "arn:aws:s3:::bucket-name/cognito/application-name/${cognito-identity.amazonaws.com:sub}/*", + ], + }, + ], }, { "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", - "Action": [ - "s3:ListAllMyBuckets", - "s3:GetBucketLocation" - ], - "Resource": "*" + "Action": ["s3:ListAllMyBuckets", "s3:GetBucketLocation"], + "Resource": "*", }, { "Effect": "Allow", "Action": "s3:ListBucket", "Resource": "arn:aws:s3:::bucket-name", "Condition": { - "StringLike": { - "s3:prefix": [ - "", - "home/", - "home/${aws:userid}/*" - ] - } - } + "StringLike": {"s3:prefix": ["", "home/", "home/${aws:userid}/*"]} + }, }, { "Effect": "Allow", "Action": "s3:*", "Resource": [ "arn:aws:s3:::bucket-name/home/${aws:userid}", - "arn:aws:s3:::bucket-name/home/${aws:userid}/*" - ] - } - ] + "arn:aws:s3:::bucket-name/home/${aws:userid}/*", + ], + }, + ], }, { "Version": "2012-10-17", @@ -1711,23 +1488,23 @@ valid_policy_documents = [ "s3:GetBucketLocation", "s3:GetBucketPolicyStatus", "s3:GetBucketPublicAccessBlock", - "s3:ListAllMyBuckets" + "s3:ListAllMyBuckets", ], - "Resource": "*" + "Resource": "*", }, { "Sid": "ListObjectsInBucket", "Effect": "Allow", "Action": "s3:ListBucket", - "Resource": ["arn:aws:s3:::bucket-name"] + "Resource": ["arn:aws:s3:::bucket-name"], }, { "Sid": "AllObjectActions", "Effect": "Allow", "Action": "s3:*Object", - "Resource": ["arn:aws:s3:::bucket-name/*"] - } - ] + "Resource": ["arn:aws:s3:::bucket-name/*"], + }, + ], }, { "Version": "2012-10-17", @@ -1735,20 +1512,14 @@ valid_policy_documents = [ { "Sid": "AllowViewAccountInfo", "Effect": "Allow", - "Action": [ - "iam:GetAccountPasswordPolicy", - "iam:GetAccountSummary" - ], - "Resource": "*" + "Action": ["iam:GetAccountPasswordPolicy", "iam:GetAccountSummary"], + "Resource": "*", }, { "Sid": "AllowManageOwnPasswords", "Effect": "Allow", - "Action": [ - "iam:ChangePassword", - "iam:GetUser" - ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Action": ["iam:ChangePassword", "iam:GetUser"], + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnAccessKeys", @@ -1757,9 +1528,9 @@ valid_policy_documents = [ "iam:CreateAccessKey", "iam:DeleteAccessKey", "iam:ListAccessKeys", - "iam:UpdateAccessKey" + "iam:UpdateAccessKey", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnSigningCertificates", @@ -1768,9 +1539,9 @@ valid_policy_documents = [ "iam:DeleteSigningCertificate", "iam:ListSigningCertificates", "iam:UpdateSigningCertificate", - "iam:UploadSigningCertificate" + "iam:UploadSigningCertificate", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnSSHPublicKeys", @@ -1780,9 +1551,9 @@ valid_policy_documents = [ "iam:GetSSHPublicKey", "iam:ListSSHPublicKeys", "iam:UpdateSSHPublicKey", - "iam:UploadSSHPublicKey" + "iam:UploadSSHPublicKey", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" + "Resource": "arn:aws:iam::*:user/${aws:username}", }, { "Sid": "AllowManageOwnGitCredentials", @@ -1792,11 +1563,11 @@ valid_policy_documents = [ "iam:DeleteServiceSpecificCredential", "iam:ListServiceSpecificCredentials", "iam:ResetServiceSpecificCredential", - "iam:UpdateServiceSpecificCredential" + "iam:UpdateServiceSpecificCredential", ], - "Resource": "arn:aws:iam::*:user/${aws:username}" - } - ] + "Resource": "arn:aws:iam::*:user/${aws:username}", + }, + ], }, { "Version": "2012-10-17", @@ -1805,13 +1576,9 @@ valid_policy_documents = [ "Action": "ec2:*", "Resource": "*", "Effect": "Allow", - "Condition": { - "StringEquals": { - "ec2:Region": "region" - } - } + "Condition": {"StringEquals": {"ec2:Region": "region"}}, } - ] + ], }, { "Version": "2012-10-17", @@ -1819,14 +1586,10 @@ valid_policy_documents = [ { "Effect": "Allow", "Action": "rds:*", - "Resource": ["arn:aws:rds:region:*:*"] + "Resource": ["arn:aws:rds:region:*:*"], }, - { - "Effect": "Allow", - "Action": ["rds:Describe*"], - "Resource": ["*"] - } - ] + {"Effect": "Allow", "Action": ["rds:Describe*"], "Resource": ["*"]}, + ], }, { "Version": "2012-10-17", @@ -1835,16 +1598,16 @@ valid_policy_documents = [ "Sid": "", "Effect": "Allow", "Action": "rds:*", - "Resource": ["arn:aws:rds:region:*:*"] + "Resource": ["arn:aws:rds:region:*:*"], }, { "Sid": "", "Effect": "Allow", "Action": ["rds:Describe*"], - "Resource": ["*"] - } - ] - } + "Resource": ["*"], + }, + ], + }, ] @@ -1860,19 +1623,20 @@ def test_create_policy_with_valid_policy_documents(): @mock_iam def check_create_policy_with_invalid_policy_document(test_case): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") with assert_raises(ClientError) as ex: conn.create_policy( PolicyName="TestCreatePolicy", - PolicyDocument=json.dumps(test_case["document"])) - ex.exception.response['Error']['Code'].should.equal('MalformedPolicyDocument') - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(400) - ex.exception.response['Error']['Message'].should.equal(test_case["error_message"]) + PolicyDocument=json.dumps(test_case["document"]), + ) + ex.exception.response["Error"]["Code"].should.equal("MalformedPolicyDocument") + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(400) + ex.exception.response["Error"]["Message"].should.equal(test_case["error_message"]) @mock_iam def check_create_policy_with_valid_policy_document(valid_policy_document): - conn = boto3.client('iam', region_name='us-east-1') + conn = boto3.client("iam", region_name="us-east-1") conn.create_policy( - PolicyName="TestCreatePolicy", - PolicyDocument=json.dumps(valid_policy_document)) + PolicyName="TestCreatePolicy", PolicyDocument=json.dumps(valid_policy_document) + ) diff --git a/tests/test_iam/test_server.py b/tests/test_iam/test_server.py index 59aaf1462..4d1698424 100644 --- a/tests/test_iam/test_server.py +++ b/tests/test_iam/test_server.py @@ -7,9 +7,9 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_iam_server_get(): @@ -17,7 +17,8 @@ def test_iam_server_get(): test_client = backend.test_client() group_data = test_client.action_data( - "CreateGroup", GroupName="test group", Path="/") + "CreateGroup", GroupName="test group", Path="/" + ) group_id = re.search("(.*)", group_data).groups()[0] groups_data = test_client.action_data("ListGroups") diff --git a/tests/test_iot/test_iot.py b/tests/test_iot/test_iot.py index 23d4e7876..713dc2977 100644 --- a/tests/test_iot/test_iot.py +++ b/tests/test_iot/test_iot.py @@ -11,280 +11,318 @@ from nose.tools import assert_raises @mock_iot def test_things(): - client = boto3.client('iot', region_name='ap-northeast-1') - name = 'my-thing' - type_name = 'my-type-name' + client = boto3.client("iot", region_name="ap-northeast-1") + name = "my-thing" + type_name = "my-type-name" # thing type thing_type = client.create_thing_type(thingTypeName=type_name) - thing_type.should.have.key('thingTypeName').which.should.equal(type_name) - thing_type.should.have.key('thingTypeArn') + thing_type.should.have.key("thingTypeName").which.should.equal(type_name) + thing_type.should.have.key("thingTypeArn") res = client.list_thing_types() - res.should.have.key('thingTypes').which.should.have.length_of(1) - for thing_type in res['thingTypes']: - thing_type.should.have.key('thingTypeName').which.should_not.be.none + res.should.have.key("thingTypes").which.should.have.length_of(1) + for thing_type in res["thingTypes"]: + thing_type.should.have.key("thingTypeName").which.should_not.be.none thing_type = client.describe_thing_type(thingTypeName=type_name) - thing_type.should.have.key('thingTypeName').which.should.equal(type_name) - thing_type.should.have.key('thingTypeProperties') - thing_type.should.have.key('thingTypeMetadata') + thing_type.should.have.key("thingTypeName").which.should.equal(type_name) + thing_type.should.have.key("thingTypeProperties") + thing_type.should.have.key("thingTypeMetadata") # thing thing = client.create_thing(thingName=name, thingTypeName=type_name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") res = client.list_things() - res.should.have.key('things').which.should.have.length_of(1) - for thing in res['things']: - thing.should.have.key('thingName').which.should_not.be.none - thing.should.have.key('thingArn').which.should_not.be.none + res.should.have.key("things").which.should.have.length_of(1) + for thing in res["things"]: + thing.should.have.key("thingName").which.should_not.be.none + thing.should.have.key("thingArn").which.should_not.be.none - thing = client.update_thing(thingName=name, attributePayload={'attributes': {'k1': 'v1'}}) + thing = client.update_thing( + thingName=name, attributePayload={"attributes": {"k1": "v1"}} + ) res = client.list_things() - res.should.have.key('things').which.should.have.length_of(1) - for thing in res['things']: - thing.should.have.key('thingName').which.should_not.be.none - thing.should.have.key('thingArn').which.should_not.be.none - res['things'][0]['attributes'].should.have.key('k1').which.should.equal('v1') + res.should.have.key("things").which.should.have.length_of(1) + for thing in res["things"]: + thing.should.have.key("thingName").which.should_not.be.none + thing.should.have.key("thingArn").which.should_not.be.none + res["things"][0]["attributes"].should.have.key("k1").which.should.equal("v1") thing = client.describe_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('defaultClientId') - thing.should.have.key('thingTypeName') - thing.should.have.key('attributes') - thing.should.have.key('version') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("defaultClientId") + thing.should.have.key("thingTypeName") + thing.should.have.key("attributes") + thing.should.have.key("version") # delete thing client.delete_thing(thingName=name) res = client.list_things() - res.should.have.key('things').which.should.have.length_of(0) + res.should.have.key("things").which.should.have.length_of(0) # delete thing type client.delete_thing_type(thingTypeName=type_name) res = client.list_thing_types() - res.should.have.key('thingTypes').which.should.have.length_of(0) + res.should.have.key("thingTypes").which.should.have.length_of(0) @mock_iot def test_list_thing_types(): - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") for i in range(0, 100): client.create_thing_type(thingTypeName=str(i + 1)) thing_types = client.list_thing_types() - thing_types.should.have.key('nextToken') - thing_types.should.have.key('thingTypes').which.should.have.length_of(50) - thing_types['thingTypes'][0]['thingTypeName'].should.equal('1') - thing_types['thingTypes'][-1]['thingTypeName'].should.equal('50') + thing_types.should.have.key("nextToken") + thing_types.should.have.key("thingTypes").which.should.have.length_of(50) + thing_types["thingTypes"][0]["thingTypeName"].should.equal("1") + thing_types["thingTypes"][-1]["thingTypeName"].should.equal("50") - thing_types = client.list_thing_types(nextToken=thing_types['nextToken']) - thing_types.should.have.key('thingTypes').which.should.have.length_of(50) - thing_types.should_not.have.key('nextToken') - thing_types['thingTypes'][0]['thingTypeName'].should.equal('51') - thing_types['thingTypes'][-1]['thingTypeName'].should.equal('100') + thing_types = client.list_thing_types(nextToken=thing_types["nextToken"]) + thing_types.should.have.key("thingTypes").which.should.have.length_of(50) + thing_types.should_not.have.key("nextToken") + thing_types["thingTypes"][0]["thingTypeName"].should.equal("51") + thing_types["thingTypes"][-1]["thingTypeName"].should.equal("100") @mock_iot def test_list_thing_types_with_typename_filter(): - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") - client.create_thing_type(thingTypeName='thing') - client.create_thing_type(thingTypeName='thingType') - client.create_thing_type(thingTypeName='thingTypeName') - client.create_thing_type(thingTypeName='thingTypeNameGroup') - client.create_thing_type(thingTypeName='shouldNotFind') - client.create_thing_type(thingTypeName='find me it shall not') + client.create_thing_type(thingTypeName="thing") + client.create_thing_type(thingTypeName="thingType") + client.create_thing_type(thingTypeName="thingTypeName") + client.create_thing_type(thingTypeName="thingTypeNameGroup") + client.create_thing_type(thingTypeName="shouldNotFind") + client.create_thing_type(thingTypeName="find me it shall not") - thing_types = client.list_thing_types(thingTypeName='thing') - thing_types.should_not.have.key('nextToken') - thing_types.should.have.key('thingTypes').which.should.have.length_of(4) - thing_types['thingTypes'][0]['thingTypeName'].should.equal('thing') - thing_types['thingTypes'][-1]['thingTypeName'].should.equal('thingTypeNameGroup') + thing_types = client.list_thing_types(thingTypeName="thing") + thing_types.should_not.have.key("nextToken") + thing_types.should.have.key("thingTypes").which.should.have.length_of(4) + thing_types["thingTypes"][0]["thingTypeName"].should.equal("thing") + thing_types["thingTypes"][-1]["thingTypeName"].should.equal("thingTypeNameGroup") - thing_types = client.list_thing_types(thingTypeName='thingTypeName') - thing_types.should_not.have.key('nextToken') - thing_types.should.have.key('thingTypes').which.should.have.length_of(2) - thing_types['thingTypes'][0]['thingTypeName'].should.equal('thingTypeName') - thing_types['thingTypes'][-1]['thingTypeName'].should.equal('thingTypeNameGroup') + thing_types = client.list_thing_types(thingTypeName="thingTypeName") + thing_types.should_not.have.key("nextToken") + thing_types.should.have.key("thingTypes").which.should.have.length_of(2) + thing_types["thingTypes"][0]["thingTypeName"].should.equal("thingTypeName") + thing_types["thingTypes"][-1]["thingTypeName"].should.equal("thingTypeNameGroup") @mock_iot def test_list_things_with_next_token(): - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") for i in range(0, 200): client.create_thing(thingName=str(i + 1)) things = client.list_things() - things.should.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('1') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/1') - things['things'][-1]['thingName'].should.equal('50') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/50') + things.should.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("1") + things["things"][0]["thingArn"].should.equal("arn:aws:iot:ap-northeast-1:1:thing/1") + things["things"][-1]["thingName"].should.equal("50") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/50" + ) - things = client.list_things(nextToken=things['nextToken']) - things.should.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('51') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/51') - things['things'][-1]['thingName'].should.equal('100') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/100') + things = client.list_things(nextToken=things["nextToken"]) + things.should.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("51") + things["things"][0]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/51" + ) + things["things"][-1]["thingName"].should.equal("100") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/100" + ) - things = client.list_things(nextToken=things['nextToken']) - things.should.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('101') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/101') - things['things'][-1]['thingName'].should.equal('150') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/150') + things = client.list_things(nextToken=things["nextToken"]) + things.should.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("101") + things["things"][0]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/101" + ) + things["things"][-1]["thingName"].should.equal("150") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/150" + ) - things = client.list_things(nextToken=things['nextToken']) - things.should_not.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('151') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/151') - things['things'][-1]['thingName'].should.equal('200') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/200') + things = client.list_things(nextToken=things["nextToken"]) + things.should_not.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("151") + things["things"][0]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/151" + ) + things["things"][-1]["thingName"].should.equal("200") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/200" + ) @mock_iot def test_list_things_with_attribute_and_thing_type_filter_and_next_token(): - client = boto3.client('iot', region_name='ap-northeast-1') - client.create_thing_type(thingTypeName='my-thing-type') + client = boto3.client("iot", region_name="ap-northeast-1") + client.create_thing_type(thingTypeName="my-thing-type") for i in range(0, 200): if not (i + 1) % 3: - attribute_payload = { - 'attributes': { - 'foo': 'bar' - } - } + attribute_payload = {"attributes": {"foo": "bar"}} elif not (i + 1) % 5: - attribute_payload = { - 'attributes': { - 'bar': 'foo' - } - } + attribute_payload = {"attributes": {"bar": "foo"}} else: attribute_payload = {} if not (i + 1) % 2: - thing_type_name = 'my-thing-type' - client.create_thing(thingName=str(i + 1), thingTypeName=thing_type_name, attributePayload=attribute_payload) + thing_type_name = "my-thing-type" + client.create_thing( + thingName=str(i + 1), + thingTypeName=thing_type_name, + attributePayload=attribute_payload, + ) else: - client.create_thing(thingName=str(i + 1), attributePayload=attribute_payload) + client.create_thing( + thingName=str(i + 1), attributePayload=attribute_payload + ) # Test filter for thingTypeName things = client.list_things(thingTypeName=thing_type_name) - things.should.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('2') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/2') - things['things'][-1]['thingName'].should.equal('100') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/100') - all(item['thingTypeName'] == thing_type_name for item in things['things']) + things.should.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("2") + things["things"][0]["thingArn"].should.equal("arn:aws:iot:ap-northeast-1:1:thing/2") + things["things"][-1]["thingName"].should.equal("100") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/100" + ) + all(item["thingTypeName"] == thing_type_name for item in things["things"]) - things = client.list_things(nextToken=things['nextToken'], thingTypeName=thing_type_name) - things.should_not.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('102') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/102') - things['things'][-1]['thingName'].should.equal('200') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/200') - all(item['thingTypeName'] == thing_type_name for item in things['things']) + things = client.list_things( + nextToken=things["nextToken"], thingTypeName=thing_type_name + ) + things.should_not.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("102") + things["things"][0]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/102" + ) + things["things"][-1]["thingName"].should.equal("200") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/200" + ) + all(item["thingTypeName"] == thing_type_name for item in things["things"]) # Test filter for attributes - things = client.list_things(attributeName='foo', attributeValue='bar') - things.should.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(50) - things['things'][0]['thingName'].should.equal('3') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/3') - things['things'][-1]['thingName'].should.equal('150') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/150') - all(item['attributes'] == {'foo': 'bar'} for item in things['things']) + things = client.list_things(attributeName="foo", attributeValue="bar") + things.should.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(50) + things["things"][0]["thingName"].should.equal("3") + things["things"][0]["thingArn"].should.equal("arn:aws:iot:ap-northeast-1:1:thing/3") + things["things"][-1]["thingName"].should.equal("150") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/150" + ) + all(item["attributes"] == {"foo": "bar"} for item in things["things"]) - things = client.list_things(nextToken=things['nextToken'], attributeName='foo', attributeValue='bar') - things.should_not.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(16) - things['things'][0]['thingName'].should.equal('153') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/153') - things['things'][-1]['thingName'].should.equal('198') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/198') - all(item['attributes'] == {'foo': 'bar'} for item in things['things']) + things = client.list_things( + nextToken=things["nextToken"], attributeName="foo", attributeValue="bar" + ) + things.should_not.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(16) + things["things"][0]["thingName"].should.equal("153") + things["things"][0]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/153" + ) + things["things"][-1]["thingName"].should.equal("198") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/198" + ) + all(item["attributes"] == {"foo": "bar"} for item in things["things"]) # Test filter for attributes and thingTypeName - things = client.list_things(thingTypeName=thing_type_name, attributeName='foo', attributeValue='bar') - things.should_not.have.key('nextToken') - things.should.have.key('things').which.should.have.length_of(33) - things['things'][0]['thingName'].should.equal('6') - things['things'][0]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/6') - things['things'][-1]['thingName'].should.equal('198') - things['things'][-1]['thingArn'].should.equal('arn:aws:iot:ap-northeast-1:1:thing/198') - all(item['attributes'] == {'foo': 'bar'} and item['thingTypeName'] == thing_type_name for item in things['things']) + things = client.list_things( + thingTypeName=thing_type_name, attributeName="foo", attributeValue="bar" + ) + things.should_not.have.key("nextToken") + things.should.have.key("things").which.should.have.length_of(33) + things["things"][0]["thingName"].should.equal("6") + things["things"][0]["thingArn"].should.equal("arn:aws:iot:ap-northeast-1:1:thing/6") + things["things"][-1]["thingName"].should.equal("198") + things["things"][-1]["thingArn"].should.equal( + "arn:aws:iot:ap-northeast-1:1:thing/198" + ) + all( + item["attributes"] == {"foo": "bar"} + and item["thingTypeName"] == thing_type_name + for item in things["things"] + ) @mock_iot def test_certs(): - client = boto3.client('iot', region_name='us-east-1') + client = boto3.client("iot", region_name="us-east-1") cert = client.create_keys_and_certificate(setAsActive=True) - cert.should.have.key('certificateArn').which.should_not.be.none - cert.should.have.key('certificateId').which.should_not.be.none - cert.should.have.key('certificatePem').which.should_not.be.none - cert.should.have.key('keyPair') - cert['keyPair'].should.have.key('PublicKey').which.should_not.be.none - cert['keyPair'].should.have.key('PrivateKey').which.should_not.be.none - cert_id = cert['certificateId'] + cert.should.have.key("certificateArn").which.should_not.be.none + cert.should.have.key("certificateId").which.should_not.be.none + cert.should.have.key("certificatePem").which.should_not.be.none + cert.should.have.key("keyPair") + cert["keyPair"].should.have.key("PublicKey").which.should_not.be.none + cert["keyPair"].should.have.key("PrivateKey").which.should_not.be.none + cert_id = cert["certificateId"] cert = client.describe_certificate(certificateId=cert_id) - cert.should.have.key('certificateDescription') - cert_desc = cert['certificateDescription'] - cert_desc.should.have.key('certificateArn').which.should_not.be.none - cert_desc.should.have.key('certificateId').which.should_not.be.none - cert_desc.should.have.key('certificatePem').which.should_not.be.none - cert_desc.should.have.key('status').which.should.equal('ACTIVE') - cert_pem = cert_desc['certificatePem'] + cert.should.have.key("certificateDescription") + cert_desc = cert["certificateDescription"] + cert_desc.should.have.key("certificateArn").which.should_not.be.none + cert_desc.should.have.key("certificateId").which.should_not.be.none + cert_desc.should.have.key("certificatePem").which.should_not.be.none + cert_desc.should.have.key("status").which.should.equal("ACTIVE") + cert_pem = cert_desc["certificatePem"] res = client.list_certificates() - for cert in res['certificates']: - cert.should.have.key('certificateArn').which.should_not.be.none - cert.should.have.key('certificateId').which.should_not.be.none - cert.should.have.key('status').which.should_not.be.none - cert.should.have.key('creationDate').which.should_not.be.none + for cert in res["certificates"]: + cert.should.have.key("certificateArn").which.should_not.be.none + cert.should.have.key("certificateId").which.should_not.be.none + cert.should.have.key("status").which.should_not.be.none + cert.should.have.key("creationDate").which.should_not.be.none - client.update_certificate(certificateId=cert_id, newStatus='REVOKED') + client.update_certificate(certificateId=cert_id, newStatus="REVOKED") cert = client.describe_certificate(certificateId=cert_id) - cert_desc = cert['certificateDescription'] - cert_desc.should.have.key('status').which.should.equal('REVOKED') + cert_desc = cert["certificateDescription"] + cert_desc.should.have.key("status").which.should.equal("REVOKED") client.delete_certificate(certificateId=cert_id) res = client.list_certificates() - res.should.have.key('certificates') + res.should.have.key("certificates") # Test register_certificate flow cert = client.register_certificate(certificatePem=cert_pem, setAsActive=True) - cert.should.have.key('certificateId').which.should_not.be.none - cert.should.have.key('certificateArn').which.should_not.be.none - cert_id = cert['certificateId'] + cert.should.have.key("certificateId").which.should_not.be.none + cert.should.have.key("certificateArn").which.should_not.be.none + cert_id = cert["certificateId"] res = client.list_certificates() - res.should.have.key('certificates').which.should.have.length_of(1) - for cert in res['certificates']: - cert.should.have.key('certificateArn').which.should_not.be.none - cert.should.have.key('certificateId').which.should_not.be.none - cert.should.have.key('status').which.should_not.be.none - cert.should.have.key('creationDate').which.should_not.be.none + res.should.have.key("certificates").which.should.have.length_of(1) + for cert in res["certificates"]: + cert.should.have.key("certificateArn").which.should_not.be.none + cert.should.have.key("certificateId").which.should_not.be.none + cert.should.have.key("status").which.should_not.be.none + cert.should.have.key("creationDate").which.should_not.be.none - client.update_certificate(certificateId=cert_id, newStatus='REVOKED') + client.update_certificate(certificateId=cert_id, newStatus="REVOKED") cert = client.describe_certificate(certificateId=cert_id) - cert_desc = cert['certificateDescription'] - cert_desc.should.have.key('status').which.should.equal('REVOKED') + cert_desc = cert["certificateDescription"] + cert_desc.should.have.key("status").which.should.equal("REVOKED") client.delete_certificate(certificateId=cert_id) res = client.list_certificates() - res.should.have.key('certificates') + res.should.have.key("certificates") @mock_iot @@ -302,24 +340,26 @@ def test_delete_policy_validation(): ] } """ - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") cert = client.create_keys_and_certificate(setAsActive=True) - cert_arn = cert['certificateArn'] - policy_name = 'my-policy' + cert_arn = cert["certificateArn"] + policy_name = "my-policy" client.create_policy(policyName=policy_name, policyDocument=doc) client.attach_principal_policy(policyName=policy_name, principal=cert_arn) with assert_raises(ClientError) as e: client.delete_policy(policyName=policy_name) - e.exception.response['Error']['Message'].should.contain( - 'The policy cannot be deleted as the policy is attached to one or more principals (name=%s)' % policy_name) + e.exception.response["Error"]["Message"].should.contain( + "The policy cannot be deleted as the policy is attached to one or more principals (name=%s)" + % policy_name + ) res = client.list_policies() - res.should.have.key('policies').which.should.have.length_of(1) + res.should.have.key("policies").which.should.have.length_of(1) client.detach_principal_policy(policyName=policy_name, principal=cert_arn) client.delete_policy(policyName=policy_name) res = client.list_policies() - res.should.have.key('policies').which.should.have.length_of(0) + res.should.have.key("policies").which.should.have.length_of(0) @mock_iot @@ -337,12 +377,12 @@ def test_delete_certificate_validation(): ] } """ - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") cert = client.create_keys_and_certificate(setAsActive=True) - cert_id = cert['certificateId'] - cert_arn = cert['certificateArn'] - policy_name = 'my-policy' - thing_name = 'thing-1' + cert_id = cert["certificateId"] + cert_arn = cert["certificateArn"] + policy_name = "my-policy" + thing_name = "thing-1" client.create_policy(policyName=policy_name, policyDocument=doc) client.attach_principal_policy(policyName=policy_name, principal=cert_arn) client.create_thing(thingName=thing_name) @@ -350,189 +390,192 @@ def test_delete_certificate_validation(): with assert_raises(ClientError) as e: client.delete_certificate(certificateId=cert_id) - e.exception.response['Error']['Message'].should.contain( - 'Certificate must be deactivated (not ACTIVE) before deletion.') + e.exception.response["Error"]["Message"].should.contain( + "Certificate must be deactivated (not ACTIVE) before deletion." + ) res = client.list_certificates() - res.should.have.key('certificates').which.should.have.length_of(1) + res.should.have.key("certificates").which.should.have.length_of(1) - client.update_certificate(certificateId=cert_id, newStatus='REVOKED') + client.update_certificate(certificateId=cert_id, newStatus="REVOKED") with assert_raises(ClientError) as e: client.delete_certificate(certificateId=cert_id) - e.exception.response['Error']['Message'].should.contain( - 'Things must be detached before deletion (arn: %s)' % cert_arn) + e.exception.response["Error"]["Message"].should.contain( + "Things must be detached before deletion (arn: %s)" % cert_arn + ) res = client.list_certificates() - res.should.have.key('certificates').which.should.have.length_of(1) + res.should.have.key("certificates").which.should.have.length_of(1) client.detach_thing_principal(thingName=thing_name, principal=cert_arn) with assert_raises(ClientError) as e: client.delete_certificate(certificateId=cert_id) - e.exception.response['Error']['Message'].should.contain( - 'Certificate policies must be detached before deletion (arn: %s)' % cert_arn) + e.exception.response["Error"]["Message"].should.contain( + "Certificate policies must be detached before deletion (arn: %s)" % cert_arn + ) res = client.list_certificates() - res.should.have.key('certificates').which.should.have.length_of(1) + res.should.have.key("certificates").which.should.have.length_of(1) client.detach_principal_policy(policyName=policy_name, principal=cert_arn) client.delete_certificate(certificateId=cert_id) res = client.list_certificates() - res.should.have.key('certificates').which.should.have.length_of(0) + res.should.have.key("certificates").which.should.have.length_of(0) @mock_iot def test_certs_create_inactive(): - client = boto3.client('iot', region_name='ap-northeast-1') + client = boto3.client("iot", region_name="ap-northeast-1") cert = client.create_keys_and_certificate(setAsActive=False) - cert_id = cert['certificateId'] + cert_id = cert["certificateId"] cert = client.describe_certificate(certificateId=cert_id) - cert.should.have.key('certificateDescription') - cert_desc = cert['certificateDescription'] - cert_desc.should.have.key('status').which.should.equal('INACTIVE') + cert.should.have.key("certificateDescription") + cert_desc = cert["certificateDescription"] + cert_desc.should.have.key("status").which.should.equal("INACTIVE") - client.update_certificate(certificateId=cert_id, newStatus='ACTIVE') + client.update_certificate(certificateId=cert_id, newStatus="ACTIVE") cert = client.describe_certificate(certificateId=cert_id) - cert.should.have.key('certificateDescription') - cert_desc = cert['certificateDescription'] - cert_desc.should.have.key('status').which.should.equal('ACTIVE') + cert.should.have.key("certificateDescription") + cert_desc = cert["certificateDescription"] + cert_desc.should.have.key("status").which.should.equal("ACTIVE") @mock_iot def test_policy(): - client = boto3.client('iot', region_name='ap-northeast-1') - name = 'my-policy' - doc = '{}' + client = boto3.client("iot", region_name="ap-northeast-1") + name = "my-policy" + doc = "{}" policy = client.create_policy(policyName=name, policyDocument=doc) - policy.should.have.key('policyName').which.should.equal(name) - policy.should.have.key('policyArn').which.should_not.be.none - policy.should.have.key('policyDocument').which.should.equal(doc) - policy.should.have.key('policyVersionId').which.should.equal('1') + policy.should.have.key("policyName").which.should.equal(name) + policy.should.have.key("policyArn").which.should_not.be.none + policy.should.have.key("policyDocument").which.should.equal(doc) + policy.should.have.key("policyVersionId").which.should.equal("1") policy = client.get_policy(policyName=name) - policy.should.have.key('policyName').which.should.equal(name) - policy.should.have.key('policyArn').which.should_not.be.none - policy.should.have.key('policyDocument').which.should.equal(doc) - policy.should.have.key('defaultVersionId').which.should.equal('1') + policy.should.have.key("policyName").which.should.equal(name) + policy.should.have.key("policyArn").which.should_not.be.none + policy.should.have.key("policyDocument").which.should.equal(doc) + policy.should.have.key("defaultVersionId").which.should.equal("1") res = client.list_policies() - res.should.have.key('policies').which.should.have.length_of(1) - for policy in res['policies']: - policy.should.have.key('policyName').which.should_not.be.none - policy.should.have.key('policyArn').which.should_not.be.none + res.should.have.key("policies").which.should.have.length_of(1) + for policy in res["policies"]: + policy.should.have.key("policyName").which.should_not.be.none + policy.should.have.key("policyArn").which.should_not.be.none client.delete_policy(policyName=name) res = client.list_policies() - res.should.have.key('policies').which.should.have.length_of(0) + res.should.have.key("policies").which.should.have.length_of(0) @mock_iot def test_principal_policy(): - client = boto3.client('iot', region_name='ap-northeast-1') - policy_name = 'my-policy' - doc = '{}' + client = boto3.client("iot", region_name="ap-northeast-1") + policy_name = "my-policy" + doc = "{}" client.create_policy(policyName=policy_name, policyDocument=doc) cert = client.create_keys_and_certificate(setAsActive=True) - cert_arn = cert['certificateArn'] + cert_arn = cert["certificateArn"] client.attach_policy(policyName=policy_name, target=cert_arn) res = client.list_principal_policies(principal=cert_arn) - res.should.have.key('policies').which.should.have.length_of(1) - for policy in res['policies']: - policy.should.have.key('policyName').which.should_not.be.none - policy.should.have.key('policyArn').which.should_not.be.none + res.should.have.key("policies").which.should.have.length_of(1) + for policy in res["policies"]: + policy.should.have.key("policyName").which.should_not.be.none + policy.should.have.key("policyArn").which.should_not.be.none # do nothing if policy have already attached to certificate client.attach_policy(policyName=policy_name, target=cert_arn) res = client.list_principal_policies(principal=cert_arn) - res.should.have.key('policies').which.should.have.length_of(1) - for policy in res['policies']: - policy.should.have.key('policyName').which.should_not.be.none - policy.should.have.key('policyArn').which.should_not.be.none + res.should.have.key("policies").which.should.have.length_of(1) + for policy in res["policies"]: + policy.should.have.key("policyName").which.should_not.be.none + policy.should.have.key("policyArn").which.should_not.be.none res = client.list_policy_principals(policyName=policy_name) - res.should.have.key('principals').which.should.have.length_of(1) - for principal in res['principals']: + res.should.have.key("principals").which.should.have.length_of(1) + for principal in res["principals"]: principal.should_not.be.none client.detach_policy(policyName=policy_name, target=cert_arn) res = client.list_principal_policies(principal=cert_arn) - res.should.have.key('policies').which.should.have.length_of(0) + res.should.have.key("policies").which.should.have.length_of(0) res = client.list_policy_principals(policyName=policy_name) - res.should.have.key('principals').which.should.have.length_of(0) + res.should.have.key("principals").which.should.have.length_of(0) with assert_raises(ClientError) as e: client.detach_policy(policyName=policy_name, target=cert_arn) - e.exception.response['Error']['Code'].should.equal('ResourceNotFoundException') + e.exception.response["Error"]["Code"].should.equal("ResourceNotFoundException") @mock_iot def test_principal_policy_deprecated(): - client = boto3.client('iot', region_name='ap-northeast-1') - policy_name = 'my-policy' - doc = '{}' + client = boto3.client("iot", region_name="ap-northeast-1") + policy_name = "my-policy" + doc = "{}" policy = client.create_policy(policyName=policy_name, policyDocument=doc) cert = client.create_keys_and_certificate(setAsActive=True) - cert_arn = cert['certificateArn'] + cert_arn = cert["certificateArn"] client.attach_principal_policy(policyName=policy_name, principal=cert_arn) res = client.list_principal_policies(principal=cert_arn) - res.should.have.key('policies').which.should.have.length_of(1) - for policy in res['policies']: - policy.should.have.key('policyName').which.should_not.be.none - policy.should.have.key('policyArn').which.should_not.be.none + res.should.have.key("policies").which.should.have.length_of(1) + for policy in res["policies"]: + policy.should.have.key("policyName").which.should_not.be.none + policy.should.have.key("policyArn").which.should_not.be.none res = client.list_policy_principals(policyName=policy_name) - res.should.have.key('principals').which.should.have.length_of(1) - for principal in res['principals']: + res.should.have.key("principals").which.should.have.length_of(1) + for principal in res["principals"]: principal.should_not.be.none client.detach_principal_policy(policyName=policy_name, principal=cert_arn) res = client.list_principal_policies(principal=cert_arn) - res.should.have.key('policies').which.should.have.length_of(0) + res.should.have.key("policies").which.should.have.length_of(0) res = client.list_policy_principals(policyName=policy_name) - res.should.have.key('principals').which.should.have.length_of(0) + res.should.have.key("principals").which.should.have.length_of(0) @mock_iot def test_principal_thing(): - client = boto3.client('iot', region_name='ap-northeast-1') - thing_name = 'my-thing' + client = boto3.client("iot", region_name="ap-northeast-1") + thing_name = "my-thing" thing = client.create_thing(thingName=thing_name) cert = client.create_keys_and_certificate(setAsActive=True) - cert_arn = cert['certificateArn'] + cert_arn = cert["certificateArn"] client.attach_thing_principal(thingName=thing_name, principal=cert_arn) res = client.list_principal_things(principal=cert_arn) - res.should.have.key('things').which.should.have.length_of(1) - for thing in res['things']: + res.should.have.key("things").which.should.have.length_of(1) + for thing in res["things"]: thing.should_not.be.none res = client.list_thing_principals(thingName=thing_name) - res.should.have.key('principals').which.should.have.length_of(1) - for principal in res['principals']: + res.should.have.key("principals").which.should.have.length_of(1) + for principal in res["principals"]: principal.should_not.be.none client.detach_thing_principal(thingName=thing_name, principal=cert_arn) res = client.list_principal_things(principal=cert_arn) - res.should.have.key('things').which.should.have.length_of(0) + res.should.have.key("things").which.should.have.length_of(0) res = client.list_thing_principals(thingName=thing_name) - res.should.have.key('principals').which.should.have.length_of(0) + res.should.have.key("principals").which.should.have.length_of(0) @mock_iot def test_delete_principal_thing(): - client = boto3.client('iot', region_name='ap-northeast-1') - thing_name = 'my-thing' + client = boto3.client("iot", region_name="ap-northeast-1") + thing_name = "my-thing" thing = client.create_thing(thingName=thing_name) cert = client.create_keys_and_certificate(setAsActive=True) - cert_arn = cert['certificateArn'] - cert_id = cert['certificateId'] + cert_arn = cert["certificateArn"] + cert_id = cert["certificateId"] client.attach_thing_principal(thingName=thing_name, principal=cert_arn) client.delete_thing(thingName=thing_name) res = client.list_principal_things(principal=cert_arn) - res.should.have.key('things').which.should.have.length_of(0) + res.should.have.key("things").which.should.have.length_of(0) client.update_certificate(certificateId=cert_id, newStatus="INACTIVE") client.delete_certificate(certificateId=cert_id) @@ -540,206 +583,152 @@ def test_delete_principal_thing(): @mock_iot def test_thing_groups(): - client = boto3.client('iot', region_name='ap-northeast-1') - group_name = 'my-group-name' + client = boto3.client("iot", region_name="ap-northeast-1") + group_name = "my-group-name" # thing group thing_group = client.create_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupName').which.should.equal(group_name) - thing_group.should.have.key('thingGroupArn') + thing_group.should.have.key("thingGroupName").which.should.equal(group_name) + thing_group.should.have.key("thingGroupArn") res = client.list_thing_groups() - res.should.have.key('thingGroups').which.should.have.length_of(1) - for thing_group in res['thingGroups']: - thing_group.should.have.key('groupName').which.should_not.be.none - thing_group.should.have.key('groupArn').which.should_not.be.none + res.should.have.key("thingGroups").which.should.have.length_of(1) + for thing_group in res["thingGroups"]: + thing_group.should.have.key("groupName").which.should_not.be.none + thing_group.should.have.key("groupArn").which.should_not.be.none thing_group = client.describe_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupName').which.should.equal(group_name) - thing_group.should.have.key('thingGroupProperties') - thing_group.should.have.key('thingGroupMetadata') - thing_group.should.have.key('version') + thing_group.should.have.key("thingGroupName").which.should.equal(group_name) + thing_group.should.have.key("thingGroupProperties") + thing_group.should.have.key("thingGroupMetadata") + thing_group.should.have.key("version") # delete thing group client.delete_thing_group(thingGroupName=group_name) res = client.list_thing_groups() - res.should.have.key('thingGroups').which.should.have.length_of(0) + res.should.have.key("thingGroups").which.should.have.length_of(0) # props create test props = { - 'thingGroupDescription': 'my first thing group', - 'attributePayload': { - 'attributes': { - 'key1': 'val01', - 'Key02': 'VAL2' - } - } + "thingGroupDescription": "my first thing group", + "attributePayload": {"attributes": {"key1": "val01", "Key02": "VAL2"}}, } - thing_group = client.create_thing_group(thingGroupName=group_name, thingGroupProperties=props) - thing_group.should.have.key('thingGroupName').which.should.equal(group_name) - thing_group.should.have.key('thingGroupArn') + thing_group = client.create_thing_group( + thingGroupName=group_name, thingGroupProperties=props + ) + thing_group.should.have.key("thingGroupName").which.should.equal(group_name) + thing_group.should.have.key("thingGroupArn") thing_group = client.describe_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupProperties') \ - .which.should.have.key('attributePayload') \ - .which.should.have.key('attributes') - res_props = thing_group['thingGroupProperties']['attributePayload']['attributes'] - res_props.should.have.key('key1').which.should.equal('val01') - res_props.should.have.key('Key02').which.should.equal('VAL2') + thing_group.should.have.key("thingGroupProperties").which.should.have.key( + "attributePayload" + ).which.should.have.key("attributes") + res_props = thing_group["thingGroupProperties"]["attributePayload"]["attributes"] + res_props.should.have.key("key1").which.should.equal("val01") + res_props.should.have.key("Key02").which.should.equal("VAL2") # props update test with merge - new_props = { - 'attributePayload': { - 'attributes': { - 'k3': 'v3' - }, - 'merge': True - } - } - client.update_thing_group( - thingGroupName=group_name, - thingGroupProperties=new_props - ) + new_props = {"attributePayload": {"attributes": {"k3": "v3"}, "merge": True}} + client.update_thing_group(thingGroupName=group_name, thingGroupProperties=new_props) thing_group = client.describe_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupProperties') \ - .which.should.have.key('attributePayload') \ - .which.should.have.key('attributes') - res_props = thing_group['thingGroupProperties']['attributePayload']['attributes'] - res_props.should.have.key('key1').which.should.equal('val01') - res_props.should.have.key('Key02').which.should.equal('VAL2') + thing_group.should.have.key("thingGroupProperties").which.should.have.key( + "attributePayload" + ).which.should.have.key("attributes") + res_props = thing_group["thingGroupProperties"]["attributePayload"]["attributes"] + res_props.should.have.key("key1").which.should.equal("val01") + res_props.should.have.key("Key02").which.should.equal("VAL2") - res_props.should.have.key('k3').which.should.equal('v3') + res_props.should.have.key("k3").which.should.equal("v3") # props update test - new_props = { - 'attributePayload': { - 'attributes': { - 'k4': 'v4' - } - } - } - client.update_thing_group( - thingGroupName=group_name, - thingGroupProperties=new_props - ) + new_props = {"attributePayload": {"attributes": {"k4": "v4"}}} + client.update_thing_group(thingGroupName=group_name, thingGroupProperties=new_props) thing_group = client.describe_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupProperties') \ - .which.should.have.key('attributePayload') \ - .which.should.have.key('attributes') - res_props = thing_group['thingGroupProperties']['attributePayload']['attributes'] - res_props.should.have.key('k4').which.should.equal('v4') - res_props.should_not.have.key('key1') + thing_group.should.have.key("thingGroupProperties").which.should.have.key( + "attributePayload" + ).which.should.have.key("attributes") + res_props = thing_group["thingGroupProperties"]["attributePayload"]["attributes"] + res_props.should.have.key("k4").which.should.equal("v4") + res_props.should_not.have.key("key1") @mock_iot def test_thing_group_relations(): - client = boto3.client('iot', region_name='ap-northeast-1') - name = 'my-thing' - group_name = 'my-group-name' + client = boto3.client("iot", region_name="ap-northeast-1") + name = "my-thing" + group_name = "my-group-name" # thing group thing_group = client.create_thing_group(thingGroupName=group_name) - thing_group.should.have.key('thingGroupName').which.should.equal(group_name) - thing_group.should.have.key('thingGroupArn') + thing_group.should.have.key("thingGroupName").which.should.equal(group_name) + thing_group.should.have.key("thingGroupArn") # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # add in 4 way + client.add_thing_to_thing_group(thingGroupName=group_name, thingName=name) client.add_thing_to_thing_group( - thingGroupName=group_name, - thingName=name + thingGroupArn=thing_group["thingGroupArn"], thingArn=thing["thingArn"] ) client.add_thing_to_thing_group( - thingGroupArn=thing_group['thingGroupArn'], - thingArn=thing['thingArn'] + thingGroupName=group_name, thingArn=thing["thingArn"] ) client.add_thing_to_thing_group( - thingGroupName=group_name, - thingArn=thing['thingArn'] - ) - client.add_thing_to_thing_group( - thingGroupArn=thing_group['thingGroupArn'], - thingName=name + thingGroupArn=thing_group["thingGroupArn"], thingName=name ) - things = client.list_things_in_thing_group( - thingGroupName=group_name - ) - things.should.have.key('things') - things['things'].should.have.length_of(1) + things = client.list_things_in_thing_group(thingGroupName=group_name) + things.should.have.key("things") + things["things"].should.have.length_of(1) - thing_groups = client.list_thing_groups_for_thing( - thingName=name - ) - thing_groups.should.have.key('thingGroups') - thing_groups['thingGroups'].should.have.length_of(1) + thing_groups = client.list_thing_groups_for_thing(thingName=name) + thing_groups.should.have.key("thingGroups") + thing_groups["thingGroups"].should.have.length_of(1) # remove in 4 way + client.remove_thing_from_thing_group(thingGroupName=group_name, thingName=name) client.remove_thing_from_thing_group( - thingGroupName=group_name, - thingName=name + thingGroupArn=thing_group["thingGroupArn"], thingArn=thing["thingArn"] ) client.remove_thing_from_thing_group( - thingGroupArn=thing_group['thingGroupArn'], - thingArn=thing['thingArn'] + thingGroupName=group_name, thingArn=thing["thingArn"] ) client.remove_thing_from_thing_group( - thingGroupName=group_name, - thingArn=thing['thingArn'] + thingGroupArn=thing_group["thingGroupArn"], thingName=name ) - client.remove_thing_from_thing_group( - thingGroupArn=thing_group['thingGroupArn'], - thingName=name - ) - things = client.list_things_in_thing_group( - thingGroupName=group_name - ) - things.should.have.key('things') - things['things'].should.have.length_of(0) + things = client.list_things_in_thing_group(thingGroupName=group_name) + things.should.have.key("things") + things["things"].should.have.length_of(0) # update thing group for thing - client.update_thing_groups_for_thing( - thingName=name, - thingGroupsToAdd=[ - group_name - ] - ) - things = client.list_things_in_thing_group( - thingGroupName=group_name - ) - things.should.have.key('things') - things['things'].should.have.length_of(1) + client.update_thing_groups_for_thing(thingName=name, thingGroupsToAdd=[group_name]) + things = client.list_things_in_thing_group(thingGroupName=group_name) + things.should.have.key("things") + things["things"].should.have.length_of(1) client.update_thing_groups_for_thing( - thingName=name, - thingGroupsToRemove=[ - group_name - ] + thingName=name, thingGroupsToRemove=[group_name] ) - things = client.list_things_in_thing_group( - thingGroupName=group_name - ) - things.should.have.key('things') - things['things'].should.have.length_of(0) + things = client.list_things_in_thing_group(thingGroupName=group_name) + things.should.have.key("things") + things["things"].should.have.length_of(0) @mock_iot def test_create_job(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # job document - job_document = { - "field": "value" - } + job_document = {"field": "value"} job = client.create_job( jobId=job_id, @@ -747,113 +736,119 @@ def test_create_job(): document=json.dumps(job_document), description="Description", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123, }, targetSelection="CONTINUOUS", - jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 - } + jobExecutionsRolloutConfig={"maximumPerMinute": 10}, ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') - job.should.have.key('description') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") + job.should.have.key("description") @mock_iot def test_describe_job(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") job = client.create_job( jobId=job_id, targets=[thing["thingArn"]], documentSource="https://s3-eu-west-1.amazonaws.com/bucket-name/job_document.json", presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123, }, targetSelection="CONTINUOUS", - jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 - } + jobExecutionsRolloutConfig={"maximumPerMinute": 10}, ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") job = client.describe_job(jobId=job_id) - job.should.have.key('documentSource') - job.should.have.key('job') - job.should.have.key('job').which.should.have.key("jobArn") - job.should.have.key('job').which.should.have.key("jobId").which.should.equal(job_id) - job.should.have.key('job').which.should.have.key("targets") - job.should.have.key('job').which.should.have.key("jobProcessDetails") - job.should.have.key('job').which.should.have.key("lastUpdatedAt") - job.should.have.key('job').which.should.have.key("createdAt") - job.should.have.key('job').which.should.have.key("jobExecutionsRolloutConfig") - job.should.have.key('job').which.should.have.key("targetSelection").which.should.equal("CONTINUOUS") - job.should.have.key('job').which.should.have.key("presignedUrlConfig") - job.should.have.key('job').which.should.have.key("presignedUrlConfig").which.should.have.key( - "roleArn").which.should.equal('arn:aws:iam::1:role/service-role/iot_job_role') - job.should.have.key('job').which.should.have.key("presignedUrlConfig").which.should.have.key( - "expiresInSec").which.should.equal(123) - job.should.have.key('job').which.should.have.key("jobExecutionsRolloutConfig").which.should.have.key( - "maximumPerMinute").which.should.equal(10) + job.should.have.key("documentSource") + job.should.have.key("job") + job.should.have.key("job").which.should.have.key("jobArn") + job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("job").which.should.have.key("targets") + job.should.have.key("job").which.should.have.key("jobProcessDetails") + job.should.have.key("job").which.should.have.key("lastUpdatedAt") + job.should.have.key("job").which.should.have.key("createdAt") + job.should.have.key("job").which.should.have.key("jobExecutionsRolloutConfig") + job.should.have.key("job").which.should.have.key( + "targetSelection" + ).which.should.equal("CONTINUOUS") + job.should.have.key("job").which.should.have.key("presignedUrlConfig") + job.should.have.key("job").which.should.have.key( + "presignedUrlConfig" + ).which.should.have.key("roleArn").which.should.equal( + "arn:aws:iam::1:role/service-role/iot_job_role" + ) + job.should.have.key("job").which.should.have.key( + "presignedUrlConfig" + ).which.should.have.key("expiresInSec").which.should.equal(123) + job.should.have.key("job").which.should.have.key( + "jobExecutionsRolloutConfig" + ).which.should.have.key("maximumPerMinute").which.should.equal(10) @mock_iot def test_describe_job_1(): - client = boto3.client('iot', region_name='eu-west-1') + client = boto3.client("iot", region_name="eu-west-1") name = "my-thing" job_id = "TestJob" # thing thing = client.create_thing(thingName=name) - thing.should.have.key('thingName').which.should.equal(name) - thing.should.have.key('thingArn') + thing.should.have.key("thingName").which.should.equal(name) + thing.should.have.key("thingArn") # job document - job_document = { - "field": "value" - } + job_document = {"field": "value"} job = client.create_job( jobId=job_id, targets=[thing["thingArn"]], document=json.dumps(job_document), presignedUrlConfig={ - 'roleArn': 'arn:aws:iam::1:role/service-role/iot_job_role', - 'expiresInSec': 123 + "roleArn": "arn:aws:iam::1:role/service-role/iot_job_role", + "expiresInSec": 123, }, targetSelection="CONTINUOUS", - jobExecutionsRolloutConfig={ - 'maximumPerMinute': 10 - } + jobExecutionsRolloutConfig={"maximumPerMinute": 10}, ) - job.should.have.key('jobId').which.should.equal(job_id) - job.should.have.key('jobArn') + job.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("jobArn") job = client.describe_job(jobId=job_id) - job.should.have.key('job') - job.should.have.key('job').which.should.have.key("jobArn") - job.should.have.key('job').which.should.have.key("jobId").which.should.equal(job_id) - job.should.have.key('job').which.should.have.key("targets") - job.should.have.key('job').which.should.have.key("jobProcessDetails") - job.should.have.key('job').which.should.have.key("lastUpdatedAt") - job.should.have.key('job').which.should.have.key("createdAt") - job.should.have.key('job').which.should.have.key("jobExecutionsRolloutConfig") - job.should.have.key('job').which.should.have.key("targetSelection").which.should.equal("CONTINUOUS") - job.should.have.key('job').which.should.have.key("presignedUrlConfig") - job.should.have.key('job').which.should.have.key("presignedUrlConfig").which.should.have.key( - "roleArn").which.should.equal('arn:aws:iam::1:role/service-role/iot_job_role') - job.should.have.key('job').which.should.have.key("presignedUrlConfig").which.should.have.key( - "expiresInSec").which.should.equal(123) - job.should.have.key('job').which.should.have.key("jobExecutionsRolloutConfig").which.should.have.key( - "maximumPerMinute").which.should.equal(10) + job.should.have.key("job") + job.should.have.key("job").which.should.have.key("jobArn") + job.should.have.key("job").which.should.have.key("jobId").which.should.equal(job_id) + job.should.have.key("job").which.should.have.key("targets") + job.should.have.key("job").which.should.have.key("jobProcessDetails") + job.should.have.key("job").which.should.have.key("lastUpdatedAt") + job.should.have.key("job").which.should.have.key("createdAt") + job.should.have.key("job").which.should.have.key("jobExecutionsRolloutConfig") + job.should.have.key("job").which.should.have.key( + "targetSelection" + ).which.should.equal("CONTINUOUS") + job.should.have.key("job").which.should.have.key("presignedUrlConfig") + job.should.have.key("job").which.should.have.key( + "presignedUrlConfig" + ).which.should.have.key("roleArn").which.should.equal( + "arn:aws:iam::1:role/service-role/iot_job_role" + ) + job.should.have.key("job").which.should.have.key( + "presignedUrlConfig" + ).which.should.have.key("expiresInSec").which.should.equal(123) + job.should.have.key("job").which.should.have.key( + "jobExecutionsRolloutConfig" + ).which.should.have.key("maximumPerMinute").which.should.equal(10) diff --git a/tests/test_iot/test_server.py b/tests/test_iot/test_server.py index 47091531a..b04f4d8ea 100644 --- a/tests/test_iot/test_server.py +++ b/tests/test_iot/test_server.py @@ -5,9 +5,10 @@ import sure # noqa import moto.server as server from moto import mock_iot -''' +""" Test the different server responses -''' +""" + @mock_iot def test_iot_list(): @@ -15,5 +16,5 @@ def test_iot_list(): test_client = backend.test_client() # just making sure that server is up - res = test_client.get('/things') + res = test_client.get("/things") res.status_code.should.equal(404) diff --git a/tests/test_iotdata/test_iotdata.py b/tests/test_iotdata/test_iotdata.py index 1cedcaa72..ac0a04244 100644 --- a/tests/test_iotdata/test_iotdata.py +++ b/tests/test_iotdata/test_iotdata.py @@ -11,9 +11,9 @@ from moto import mock_iotdata, mock_iot @mock_iot @mock_iotdata def test_basic(): - iot_client = boto3.client('iot', region_name='ap-northeast-1') - client = boto3.client('iot-data', region_name='ap-northeast-1') - name = 'my-thing' + iot_client = boto3.client("iot", region_name="ap-northeast-1") + client = boto3.client("iot-data", region_name="ap-northeast-1") + name = "my-thing" raw_payload = b'{"state": {"desired": {"led": "on"}}}' iot_client.create_thing(thingName=name) @@ -22,20 +22,24 @@ def test_basic(): res = client.update_thing_shadow(thingName=name, payload=raw_payload) - payload = json.loads(res['payload'].read()) + payload = json.loads(res["payload"].read()) expected_state = '{"desired": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('desired').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(1) - payload.should.have.key('timestamp') + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "desired" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(1) + payload.should.have.key("timestamp") res = client.get_thing_shadow(thingName=name) - payload = json.loads(res['payload'].read()) + payload = json.loads(res["payload"].read()) expected_state = b'{"desired": {"led": "on"}, "delta": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('desired').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(1) - payload.should.have.key('timestamp') + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "desired" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(1) + payload.should.have.key("timestamp") client.delete_thing_shadow(thingName=name) with assert_raises(ClientError): @@ -45,55 +49,63 @@ def test_basic(): @mock_iot @mock_iotdata def test_update(): - iot_client = boto3.client('iot', region_name='ap-northeast-1') - client = boto3.client('iot-data', region_name='ap-northeast-1') - name = 'my-thing' + iot_client = boto3.client("iot", region_name="ap-northeast-1") + client = boto3.client("iot-data", region_name="ap-northeast-1") + name = "my-thing" raw_payload = b'{"state": {"desired": {"led": "on"}}}' iot_client.create_thing(thingName=name) # first update res = client.update_thing_shadow(thingName=name, payload=raw_payload) - payload = json.loads(res['payload'].read()) + payload = json.loads(res["payload"].read()) expected_state = '{"desired": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('desired').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(1) - payload.should.have.key('timestamp') + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "desired" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(1) + payload.should.have.key("timestamp") res = client.get_thing_shadow(thingName=name) - payload = json.loads(res['payload'].read()) + payload = json.loads(res["payload"].read()) expected_state = b'{"desired": {"led": "on"}, "delta": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('desired').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(1) - payload.should.have.key('timestamp') + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "desired" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(1) + payload.should.have.key("timestamp") # reporting new state new_payload = b'{"state": {"reported": {"led": "on"}}}' res = client.update_thing_shadow(thingName=name, payload=new_payload) - payload = json.loads(res['payload'].read()) + payload = json.loads(res["payload"].read()) expected_state = '{"reported": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('reported').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(2) - payload.should.have.key('timestamp') + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "reported" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(2) + payload.should.have.key("timestamp") res = client.get_thing_shadow(thingName=name) - payload = json.loads(res['payload'].read()) + payload = json.loads(res["payload"].read()) expected_state = b'{"desired": {"led": "on"}, "reported": {"led": "on"}}' - payload.should.have.key('state').which.should.equal(json.loads(expected_state)) - payload.should.have.key('metadata').which.should.have.key('desired').which.should.have.key('led') - payload.should.have.key('version').which.should.equal(2) - payload.should.have.key('timestamp') + payload.should.have.key("state").which.should.equal(json.loads(expected_state)) + payload.should.have.key("metadata").which.should.have.key( + "desired" + ).which.should.have.key("led") + payload.should.have.key("version").which.should.equal(2) + payload.should.have.key("timestamp") raw_payload = b'{"state": {"desired": {"led": "on"}}, "version": 1}' with assert_raises(ClientError) as ex: client.update_thing_shadow(thingName=name, payload=raw_payload) - ex.exception.response['ResponseMetadata']['HTTPStatusCode'].should.equal(409) - ex.exception.response['Error']['Message'].should.equal('Version conflict') + ex.exception.response["ResponseMetadata"]["HTTPStatusCode"].should.equal(409) + ex.exception.response["Error"]["Message"].should.equal("Version conflict") @mock_iotdata def test_publish(): - client = boto3.client('iot-data', region_name='ap-northeast-1') - client.publish(topic='test/topic', qos=1, payload=b'') + client = boto3.client("iot-data", region_name="ap-northeast-1") + client.publish(topic="test/topic", qos=1, payload=b"") diff --git a/tests/test_iotdata/test_server.py b/tests/test_iotdata/test_server.py index 42a5c5f22..bbced67b6 100644 --- a/tests/test_iotdata/test_server.py +++ b/tests/test_iotdata/test_server.py @@ -5,9 +5,10 @@ import sure # noqa import moto.server as server from moto import mock_iotdata -''' +""" Test the different server responses -''' +""" + @mock_iotdata def test_iotdata_list(): @@ -15,6 +16,6 @@ def test_iotdata_list(): test_client = backend.test_client() # just making sure that server is up - thing_name = 'nothing' - res = test_client.get('/things/{}/shadow'.format(thing_name)) + thing_name = "nothing" + res = test_client.get("/things/{}/shadow".format(thing_name)) res.status_code.should.equal(404) diff --git a/tests/test_kinesis/test_firehose.py b/tests/test_kinesis/test_firehose.py index 91c1038d3..7101c4eaf 100644 --- a/tests/test_kinesis/test_firehose.py +++ b/tests/test_kinesis/test_firehose.py @@ -14,263 +14,238 @@ def create_s3_delivery_stream(client, stream_name): DeliveryStreamName=stream_name, DeliveryStreamType="DirectPut", ExtendedS3DestinationConfiguration={ - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'BucketARN': 'arn:aws:s3:::kinesis-test', - 'Prefix': 'myFolder/', - 'CompressionFormat': 'UNCOMPRESSED', - 'DataFormatConversionConfiguration': { - 'Enabled': True, - 'InputFormatConfiguration': { - 'Deserializer': { - 'HiveJsonSerDe': { - }, - }, + "RoleARN": "arn:aws:iam::123456789012:role/firehose_delivery_role", + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "CompressionFormat": "UNCOMPRESSED", + "DataFormatConversionConfiguration": { + "Enabled": True, + "InputFormatConfiguration": {"Deserializer": {"HiveJsonSerDe": {}}}, + "OutputFormatConfiguration": { + "Serializer": {"ParquetSerDe": {"Compression": "SNAPPY"}} }, - 'OutputFormatConfiguration': { - 'Serializer': { - 'ParquetSerDe': { - 'Compression': 'SNAPPY', - }, - }, - }, - 'SchemaConfiguration': { - 'DatabaseName': stream_name, - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'TableName': 'outputTable', + "SchemaConfiguration": { + "DatabaseName": stream_name, + "RoleARN": "arn:aws:iam::123456789012:role/firehose_delivery_role", + "TableName": "outputTable", }, }, - }) - + }, + ) def create_redshift_delivery_stream(client, stream_name): return client.create_delivery_stream( DeliveryStreamName=stream_name, RedshiftDestinationConfiguration={ - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'ClusterJDBCURL': 'jdbc:redshift://host.amazonaws.com:5439/database', - 'CopyCommand': { - 'DataTableName': 'outputTable', - 'CopyOptions': "CSV DELIMITER ',' NULL '\\0'" + "RoleARN": "arn:aws:iam::123456789012:role/firehose_delivery_role", + "ClusterJDBCURL": "jdbc:redshift://host.amazonaws.com:5439/database", + "CopyCommand": { + "DataTableName": "outputTable", + "CopyOptions": "CSV DELIMITER ',' NULL '\\0'", }, - 'Username': 'username', - 'Password': 'password', - 'S3Configuration': { - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'BucketARN': 'arn:aws:s3:::kinesis-test', - 'Prefix': 'myFolder/', - 'BufferingHints': { - 'SizeInMBs': 123, - 'IntervalInSeconds': 124 - }, - 'CompressionFormat': 'UNCOMPRESSED', - } - } + "Username": "username", + "Password": "password", + "S3Configuration": { + "RoleARN": "arn:aws:iam::123456789012:role/firehose_delivery_role", + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "BufferingHints": {"SizeInMBs": 123, "IntervalInSeconds": 124}, + "CompressionFormat": "UNCOMPRESSED", + }, + }, ) @mock_kinesis def test_create_redshift_delivery_stream(): - client = boto3.client('firehose', region_name='us-east-1') + client = boto3.client("firehose", region_name="us-east-1") - response = create_redshift_delivery_stream(client, 'stream1') - stream_arn = response['DeliveryStreamARN'] + response = create_redshift_delivery_stream(client, "stream1") + stream_arn = response["DeliveryStreamARN"] - response = client.describe_delivery_stream(DeliveryStreamName='stream1') - stream_description = response['DeliveryStreamDescription'] + response = client.describe_delivery_stream(DeliveryStreamName="stream1") + stream_description = response["DeliveryStreamDescription"] # Sure and Freezegun don't play nicely together - _ = stream_description.pop('CreateTimestamp') - _ = stream_description.pop('LastUpdateTimestamp') + _ = stream_description.pop("CreateTimestamp") + _ = stream_description.pop("LastUpdateTimestamp") - stream_description.should.equal({ - 'DeliveryStreamName': 'stream1', - 'DeliveryStreamARN': stream_arn, - 'DeliveryStreamStatus': 'ACTIVE', - 'VersionId': 'string', - 'Destinations': [ - { - 'DestinationId': 'string', - 'RedshiftDestinationDescription': { - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'ClusterJDBCURL': 'jdbc:redshift://host.amazonaws.com:5439/database', - 'CopyCommand': { - 'DataTableName': 'outputTable', - 'CopyOptions': "CSV DELIMITER ',' NULL '\\0'" - }, - 'Username': 'username', - 'S3DestinationDescription': { - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'BucketARN': 'arn:aws:s3:::kinesis-test', - 'Prefix': 'myFolder/', - 'BufferingHints': { - 'SizeInMBs': 123, - 'IntervalInSeconds': 124 + stream_description.should.equal( + { + "DeliveryStreamName": "stream1", + "DeliveryStreamARN": stream_arn, + "DeliveryStreamStatus": "ACTIVE", + "VersionId": "string", + "Destinations": [ + { + "DestinationId": "string", + "RedshiftDestinationDescription": { + "RoleARN": "arn:aws:iam::123456789012:role/firehose_delivery_role", + "ClusterJDBCURL": "jdbc:redshift://host.amazonaws.com:5439/database", + "CopyCommand": { + "DataTableName": "outputTable", + "CopyOptions": "CSV DELIMITER ',' NULL '\\0'", }, - 'CompressionFormat': 'UNCOMPRESSED', - } + "Username": "username", + "S3DestinationDescription": { + "RoleARN": "arn:aws:iam::123456789012:role/firehose_delivery_role", + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "BufferingHints": { + "SizeInMBs": 123, + "IntervalInSeconds": 124, + }, + "CompressionFormat": "UNCOMPRESSED", + }, + }, } - }, - ], - "HasMoreDestinations": False, - }) + ], + "HasMoreDestinations": False, + } + ) @mock_kinesis def test_create_s3_delivery_stream(): - client = boto3.client('firehose', region_name='us-east-1') + client = boto3.client("firehose", region_name="us-east-1") - response = create_s3_delivery_stream(client, 'stream1') - stream_arn = response['DeliveryStreamARN'] + response = create_s3_delivery_stream(client, "stream1") + stream_arn = response["DeliveryStreamARN"] - response = client.describe_delivery_stream(DeliveryStreamName='stream1') - stream_description = response['DeliveryStreamDescription'] + response = client.describe_delivery_stream(DeliveryStreamName="stream1") + stream_description = response["DeliveryStreamDescription"] # Sure and Freezegun don't play nicely together - _ = stream_description.pop('CreateTimestamp') - _ = stream_description.pop('LastUpdateTimestamp') + _ = stream_description.pop("CreateTimestamp") + _ = stream_description.pop("LastUpdateTimestamp") - stream_description.should.equal({ - 'DeliveryStreamName': 'stream1', - 'DeliveryStreamARN': stream_arn, - 'DeliveryStreamStatus': 'ACTIVE', - 'VersionId': 'string', - 'Destinations': [ - { - 'DestinationId': 'string', - 'ExtendedS3DestinationDescription': { - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'BucketARN': 'arn:aws:s3:::kinesis-test', - 'Prefix': 'myFolder/', - 'CompressionFormat': 'UNCOMPRESSED', - 'DataFormatConversionConfiguration': { - 'Enabled': True, - 'InputFormatConfiguration': { - 'Deserializer': { - 'HiveJsonSerDe': { - }, + stream_description.should.equal( + { + "DeliveryStreamName": "stream1", + "DeliveryStreamARN": stream_arn, + "DeliveryStreamStatus": "ACTIVE", + "VersionId": "string", + "Destinations": [ + { + "DestinationId": "string", + "ExtendedS3DestinationDescription": { + "RoleARN": "arn:aws:iam::123456789012:role/firehose_delivery_role", + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "CompressionFormat": "UNCOMPRESSED", + "DataFormatConversionConfiguration": { + "Enabled": True, + "InputFormatConfiguration": { + "Deserializer": {"HiveJsonSerDe": {}} }, - }, - 'OutputFormatConfiguration': { - 'Serializer': { - 'ParquetSerDe': { - 'Compression': 'SNAPPY', - }, + "OutputFormatConfiguration": { + "Serializer": { + "ParquetSerDe": {"Compression": "SNAPPY"} + } + }, + "SchemaConfiguration": { + "DatabaseName": "stream1", + "RoleARN": "arn:aws:iam::123456789012:role/firehose_delivery_role", + "TableName": "outputTable", }, - }, - 'SchemaConfiguration': { - 'DatabaseName': 'stream1', - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'TableName': 'outputTable', }, }, - }, - }, - ], - "HasMoreDestinations": False, - }) + } + ], + "HasMoreDestinations": False, + } + ) + @mock_kinesis def test_create_stream_without_redshift(): - client = boto3.client('firehose', region_name='us-east-1') + client = boto3.client("firehose", region_name="us-east-1") response = client.create_delivery_stream( DeliveryStreamName="stream1", S3DestinationConfiguration={ - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'BucketARN': 'arn:aws:s3:::kinesis-test', - 'Prefix': 'myFolder/', - 'BufferingHints': { - 'SizeInMBs': 123, - 'IntervalInSeconds': 124 - }, - 'CompressionFormat': 'UNCOMPRESSED', - } + "RoleARN": "arn:aws:iam::123456789012:role/firehose_delivery_role", + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "BufferingHints": {"SizeInMBs": 123, "IntervalInSeconds": 124}, + "CompressionFormat": "UNCOMPRESSED", + }, ) - stream_arn = response['DeliveryStreamARN'] + stream_arn = response["DeliveryStreamARN"] - response = client.describe_delivery_stream(DeliveryStreamName='stream1') - stream_description = response['DeliveryStreamDescription'] + response = client.describe_delivery_stream(DeliveryStreamName="stream1") + stream_description = response["DeliveryStreamDescription"] # Sure and Freezegun don't play nicely together - _ = stream_description.pop('CreateTimestamp') - _ = stream_description.pop('LastUpdateTimestamp') + _ = stream_description.pop("CreateTimestamp") + _ = stream_description.pop("LastUpdateTimestamp") - stream_description.should.equal({ - 'DeliveryStreamName': 'stream1', - 'DeliveryStreamARN': stream_arn, - 'DeliveryStreamStatus': 'ACTIVE', - 'VersionId': 'string', - 'Destinations': [ - { - 'DestinationId': 'string', - 'S3DestinationDescription': { - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'RoleARN': 'arn:aws:iam::123456789012:role/firehose_delivery_role', - 'BucketARN': 'arn:aws:s3:::kinesis-test', - 'Prefix': 'myFolder/', - 'BufferingHints': { - 'SizeInMBs': 123, - 'IntervalInSeconds': 124 + stream_description.should.equal( + { + "DeliveryStreamName": "stream1", + "DeliveryStreamARN": stream_arn, + "DeliveryStreamStatus": "ACTIVE", + "VersionId": "string", + "Destinations": [ + { + "DestinationId": "string", + "S3DestinationDescription": { + "RoleARN": "arn:aws:iam::123456789012:role/firehose_delivery_role", + "RoleARN": "arn:aws:iam::123456789012:role/firehose_delivery_role", + "BucketARN": "arn:aws:s3:::kinesis-test", + "Prefix": "myFolder/", + "BufferingHints": {"SizeInMBs": 123, "IntervalInSeconds": 124}, + "CompressionFormat": "UNCOMPRESSED", }, - 'CompressionFormat': 'UNCOMPRESSED', } - }, - ], - "HasMoreDestinations": False, - }) + ], + "HasMoreDestinations": False, + } + ) @mock_kinesis def test_deescribe_non_existant_stream(): - client = boto3.client('firehose', region_name='us-east-1') + client = boto3.client("firehose", region_name="us-east-1") client.describe_delivery_stream.when.called_with( - DeliveryStreamName='not-a-stream').should.throw(ClientError) + DeliveryStreamName="not-a-stream" + ).should.throw(ClientError) @mock_kinesis def test_list_and_delete_stream(): - client = boto3.client('firehose', region_name='us-east-1') + client = boto3.client("firehose", region_name="us-east-1") - create_redshift_delivery_stream(client, 'stream1') - create_redshift_delivery_stream(client, 'stream2') + create_redshift_delivery_stream(client, "stream1") + create_redshift_delivery_stream(client, "stream2") - set(client.list_delivery_streams()['DeliveryStreamNames']).should.equal( - set(['stream1', 'stream2'])) + set(client.list_delivery_streams()["DeliveryStreamNames"]).should.equal( + set(["stream1", "stream2"]) + ) - client.delete_delivery_stream(DeliveryStreamName='stream1') + client.delete_delivery_stream(DeliveryStreamName="stream1") - set(client.list_delivery_streams()[ - 'DeliveryStreamNames']).should.equal(set(['stream2'])) + set(client.list_delivery_streams()["DeliveryStreamNames"]).should.equal( + set(["stream2"]) + ) @mock_kinesis def test_put_record(): - client = boto3.client('firehose', region_name='us-east-1') + client = boto3.client("firehose", region_name="us-east-1") - create_redshift_delivery_stream(client, 'stream1') - client.put_record( - DeliveryStreamName='stream1', - Record={ - 'Data': 'some data' - } - ) + create_redshift_delivery_stream(client, "stream1") + client.put_record(DeliveryStreamName="stream1", Record={"Data": "some data"}) @mock_kinesis def test_put_record_batch(): - client = boto3.client('firehose', region_name='us-east-1') + client = boto3.client("firehose", region_name="us-east-1") - create_redshift_delivery_stream(client, 'stream1') + create_redshift_delivery_stream(client, "stream1") client.put_record_batch( - DeliveryStreamName='stream1', - Records=[ - { - 'Data': 'some data1' - }, - { - 'Data': 'some data2' - }, - ] + DeliveryStreamName="stream1", + Records=[{"Data": "some data1"}, {"Data": "some data2"}], ) diff --git a/tests/test_kinesis/test_kinesis.py b/tests/test_kinesis/test_kinesis.py index e2de866fc..308100d8b 100644 --- a/tests/test_kinesis/test_kinesis.py +++ b/tests/test_kinesis/test_kinesis.py @@ -5,8 +5,7 @@ import time import boto.kinesis import boto3 -from boto.kinesis.exceptions import ResourceNotFoundException, \ - InvalidArgumentException +from boto.kinesis.exceptions import ResourceNotFoundException, InvalidArgumentException from moto import mock_kinesis, mock_kinesis_deprecated @@ -22,19 +21,19 @@ def test_create_cluster(): stream = stream_response["StreamDescription"] stream["StreamName"].should.equal("my_stream") stream["HasMoreShards"].should.equal(False) - stream["StreamARN"].should.equal( - "arn:aws:kinesis:us-west-2:123456789012:my_stream") + stream["StreamARN"].should.equal("arn:aws:kinesis:us-west-2:123456789012:my_stream") stream["StreamStatus"].should.equal("ACTIVE") - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(3) @mock_kinesis_deprecated def test_describe_non_existant_stream(): conn = boto.kinesis.connect_to_region("us-east-1") - conn.describe_stream.when.called_with( - "not-a-stream").should.throw(ResourceNotFoundException) + conn.describe_stream.when.called_with("not-a-stream").should.throw( + ResourceNotFoundException + ) @mock_kinesis_deprecated @@ -44,20 +43,21 @@ def test_list_and_delete_stream(): conn.create_stream("stream1", 1) conn.create_stream("stream2", 1) - conn.list_streams()['StreamNames'].should.have.length_of(2) + conn.list_streams()["StreamNames"].should.have.length_of(2) conn.delete_stream("stream2") - conn.list_streams()['StreamNames'].should.have.length_of(1) + conn.list_streams()["StreamNames"].should.have.length_of(1) # Delete invalid id - conn.delete_stream.when.called_with( - "not-a-stream").should.throw(ResourceNotFoundException) + conn.delete_stream.when.called_with("not-a-stream").should.throw( + ResourceNotFoundException + ) @mock_kinesis def test_list_many_streams(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") for i in range(11): conn.create_stream(StreamName="stream%d" % i, ShardCount=1) @@ -76,8 +76,8 @@ def test_list_many_streams(): @mock_kinesis def test_describe_stream_summary(): - conn = boto3.client('kinesis', region_name="us-west-2") - stream_name = 'my_stream_summary' + conn = boto3.client("kinesis", region_name="us-west-2") + stream_name = "my_stream_summary" shard_count = 5 conn.create_stream(StreamName=stream_name, ShardCount=shard_count) @@ -87,7 +87,8 @@ def test_describe_stream_summary(): stream["StreamName"].should.equal(stream_name) stream["OpenShardCount"].should.equal(shard_count) stream["StreamARN"].should.equal( - "arn:aws:kinesis:us-west-2:123456789012:{}".format(stream_name)) + "arn:aws:kinesis:us-west-2:123456789012:{}".format(stream_name) + ) stream["StreamStatus"].should.equal("ACTIVE") @@ -99,15 +100,15 @@ def test_basic_shard_iterator(): conn.create_stream(stream_name, 1) response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] response = conn.get_records(shard_iterator) - shard_iterator = response['NextShardIterator'] - response['Records'].should.equal([]) - response['MillisBehindLatest'].should.equal(0) + shard_iterator = response["NextShardIterator"] + response["Records"].should.equal([]) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis_deprecated @@ -118,8 +119,8 @@ def test_get_invalid_shard_iterator(): conn.create_stream(stream_name, 1) conn.get_shard_iterator.when.called_with( - stream_name, "123", 'TRIM_HORIZON').should.throw( - ResourceNotFoundException) + stream_name, "123", "TRIM_HORIZON" + ).should.throw(ResourceNotFoundException) @mock_kinesis_deprecated @@ -132,21 +133,22 @@ def test_put_records(): data = "hello world" partition_key = "1234" - conn.put_record.when.called_with( - stream_name, data, 1234).should.throw(InvalidArgumentException) + conn.put_record.when.called_with(stream_name, data, 1234).should.throw( + InvalidArgumentException + ) conn.put_record(stream_name, data, partition_key) response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] response = conn.get_records(shard_iterator) - shard_iterator = response['NextShardIterator'] - response['Records'].should.have.length_of(1) - record = response['Records'][0] + shard_iterator = response["NextShardIterator"] + response["Records"].should.have.length_of(1) + record = response["Records"][0] record["Data"].should.equal("hello world") record["PartitionKey"].should.equal("1234") @@ -168,18 +170,18 @@ def test_get_records_limit(): # Get a shard iterator response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] # Retrieve only 3 records response = conn.get_records(shard_iterator, limit=3) - response['Records'].should.have.length_of(3) + response["Records"].should.have.length_of(3) # Then get the rest of the results - next_shard_iterator = response['NextShardIterator'] + next_shard_iterator = response["NextShardIterator"] response = conn.get_records(next_shard_iterator) - response['Records'].should.have.length_of(2) + response["Records"].should.have.length_of(2) @mock_kinesis_deprecated @@ -196,23 +198,24 @@ def test_get_records_at_sequence_number(): # Get a shard iterator response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] # Get the second record response = conn.get_records(shard_iterator, limit=2) - second_sequence_id = response['Records'][1]['SequenceNumber'] + second_sequence_id = response["Records"][1]["SequenceNumber"] # Then get a new iterator starting at that id response = conn.get_shard_iterator( - stream_name, shard_id, 'AT_SEQUENCE_NUMBER', second_sequence_id) - shard_iterator = response['ShardIterator'] + stream_name, shard_id, "AT_SEQUENCE_NUMBER", second_sequence_id + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(shard_iterator) # And the first result returned should be the second item - response['Records'][0]['SequenceNumber'].should.equal(second_sequence_id) - response['Records'][0]['Data'].should.equal('2') + response["Records"][0]["SequenceNumber"].should.equal(second_sequence_id) + response["Records"][0]["Data"].should.equal("2") @mock_kinesis_deprecated @@ -229,23 +232,24 @@ def test_get_records_after_sequence_number(): # Get a shard iterator response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] # Get the second record response = conn.get_records(shard_iterator, limit=2) - second_sequence_id = response['Records'][1]['SequenceNumber'] + second_sequence_id = response["Records"][1]["SequenceNumber"] # Then get a new iterator starting after that id response = conn.get_shard_iterator( - stream_name, shard_id, 'AFTER_SEQUENCE_NUMBER', second_sequence_id) - shard_iterator = response['ShardIterator'] + stream_name, shard_id, "AFTER_SEQUENCE_NUMBER", second_sequence_id + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(shard_iterator) # And the first result returned should be the third item - response['Records'][0]['Data'].should.equal('3') - response['MillisBehindLatest'].should.equal(0) + response["Records"][0]["Data"].should.equal("3") + response["MillisBehindLatest"].should.equal(0) @mock_kinesis_deprecated @@ -262,42 +266,43 @@ def test_get_records_latest(): # Get a shard iterator response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(stream_name, shard_id, 'TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator(stream_name, shard_id, "TRIM_HORIZON") + shard_iterator = response["ShardIterator"] # Get the second record response = conn.get_records(shard_iterator, limit=2) - second_sequence_id = response['Records'][1]['SequenceNumber'] + second_sequence_id = response["Records"][1]["SequenceNumber"] # Then get a new iterator starting after that id response = conn.get_shard_iterator( - stream_name, shard_id, 'LATEST', second_sequence_id) - shard_iterator = response['ShardIterator'] + stream_name, shard_id, "LATEST", second_sequence_id + ) + shard_iterator = response["ShardIterator"] # Write some more data conn.put_record(stream_name, "last_record", "last_record") response = conn.get_records(shard_iterator) # And the only result returned should be the new item - response['Records'].should.have.length_of(1) - response['Records'][0]['PartitionKey'].should.equal('last_record') - response['Records'][0]['Data'].should.equal('last_record') - response['MillisBehindLatest'].should.equal(0) + response["Records"].should.have.length_of(1) + response["Records"][0]["PartitionKey"].should.equal("last_record") + response["Records"][0]["Data"].should.equal("last_record") + response["MillisBehindLatest"].should.equal(0) @mock_kinesis def test_get_records_at_timestamp(): # AT_TIMESTAMP - Read the first record at or after the specified timestamp - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) # Create some data for index in range(1, 5): - conn.put_record(StreamName=stream_name, - Data=str(index), - PartitionKey=str(index)) + conn.put_record( + StreamName=stream_name, Data=str(index), PartitionKey=str(index) + ) # When boto3 floors the timestamp that we pass to get_shard_iterator to # second precision even though AWS supports ms precision: @@ -309,148 +314,143 @@ def test_get_records_at_timestamp(): keys = [str(i) for i in range(5, 10)] for k in keys: - conn.put_record(StreamName=stream_name, - Data=k, - PartitionKey=k) + conn.put_record(StreamName=stream_name, Data=k, PartitionKey=k) # Get a shard iterator response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='AT_TIMESTAMP', - Timestamp=timestamp) - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, + ShardId=shard_id, + ShardIteratorType="AT_TIMESTAMP", + Timestamp=timestamp, + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator) - response['Records'].should.have.length_of(len(keys)) - partition_keys = [r['PartitionKey'] for r in response['Records']] + response["Records"].should.have.length_of(len(keys)) + partition_keys = [r["PartitionKey"] for r in response["Records"]] partition_keys.should.equal(keys) - response['MillisBehindLatest'].should.equal(0) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis def test_get_records_at_very_old_timestamp(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) # Create some data keys = [str(i) for i in range(1, 5)] for k in keys: - conn.put_record(StreamName=stream_name, - Data=k, - PartitionKey=k) + conn.put_record(StreamName=stream_name, Data=k, PartitionKey=k) # Get a shard iterator response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='AT_TIMESTAMP', - Timestamp=1) - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, + ShardId=shard_id, + ShardIteratorType="AT_TIMESTAMP", + Timestamp=1, + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator) - response['Records'].should.have.length_of(len(keys)) - partition_keys = [r['PartitionKey'] for r in response['Records']] + response["Records"].should.have.length_of(len(keys)) + partition_keys = [r["PartitionKey"] for r in response["Records"]] partition_keys.should.equal(keys) - response['MillisBehindLatest'].should.equal(0) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis def test_get_records_timestamp_filtering(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) - conn.put_record(StreamName=stream_name, - Data='0', - PartitionKey='0') + conn.put_record(StreamName=stream_name, Data="0", PartitionKey="0") time.sleep(1.0) timestamp = datetime.datetime.utcnow() - conn.put_record(StreamName=stream_name, - Data='1', - PartitionKey='1') + conn.put_record(StreamName=stream_name, Data="1", PartitionKey="1") response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='AT_TIMESTAMP', - Timestamp=timestamp) - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, + ShardId=shard_id, + ShardIteratorType="AT_TIMESTAMP", + Timestamp=timestamp, + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator) - response['Records'].should.have.length_of(1) - response['Records'][0]['PartitionKey'].should.equal('1') - response['Records'][0]['ApproximateArrivalTimestamp'].should.be. \ - greater_than(timestamp) - response['MillisBehindLatest'].should.equal(0) + response["Records"].should.have.length_of(1) + response["Records"][0]["PartitionKey"].should.equal("1") + response["Records"][0]["ApproximateArrivalTimestamp"].should.be.greater_than( + timestamp + ) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis def test_get_records_millis_behind_latest(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) - conn.put_record(StreamName=stream_name, - Data='0', - PartitionKey='0') + conn.put_record(StreamName=stream_name, Data="0", PartitionKey="0") time.sleep(1.0) - conn.put_record(StreamName=stream_name, - Data='1', - PartitionKey='1') + conn.put_record(StreamName=stream_name, Data="1", PartitionKey="1") response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='TRIM_HORIZON') - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, ShardId=shard_id, ShardIteratorType="TRIM_HORIZON" + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator, Limit=1) - response['Records'].should.have.length_of(1) - response['MillisBehindLatest'].should.be.greater_than(0) + response["Records"].should.have.length_of(1) + response["MillisBehindLatest"].should.be.greater_than(0) @mock_kinesis def test_get_records_at_very_new_timestamp(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) # Create some data keys = [str(i) for i in range(1, 5)] for k in keys: - conn.put_record(StreamName=stream_name, - Data=k, - PartitionKey=k) + conn.put_record(StreamName=stream_name, Data=k, PartitionKey=k) timestamp = datetime.datetime.utcnow() + datetime.timedelta(seconds=1) # Get a shard iterator response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='AT_TIMESTAMP', - Timestamp=timestamp) - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, + ShardId=shard_id, + ShardIteratorType="AT_TIMESTAMP", + Timestamp=timestamp, + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator) - response['Records'].should.have.length_of(0) - response['MillisBehindLatest'].should.equal(0) + response["Records"].should.have.length_of(0) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis def test_get_records_from_empty_stream_at_timestamp(): - conn = boto3.client('kinesis', region_name="us-west-2") + conn = boto3.client("kinesis", region_name="us-west-2") stream_name = "my_stream" conn.create_stream(StreamName=stream_name, ShardCount=1) @@ -458,17 +458,19 @@ def test_get_records_from_empty_stream_at_timestamp(): # Get a shard iterator response = conn.describe_stream(StreamName=stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] - response = conn.get_shard_iterator(StreamName=stream_name, - ShardId=shard_id, - ShardIteratorType='AT_TIMESTAMP', - Timestamp=timestamp) - shard_iterator = response['ShardIterator'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] + response = conn.get_shard_iterator( + StreamName=stream_name, + ShardId=shard_id, + ShardIteratorType="AT_TIMESTAMP", + Timestamp=timestamp, + ) + shard_iterator = response["ShardIterator"] response = conn.get_records(ShardIterator=shard_iterator) - response['Records'].should.have.length_of(0) - response['MillisBehindLatest'].should.equal(0) + response["Records"].should.have.length_of(0) + response["MillisBehindLatest"].should.equal(0) @mock_kinesis_deprecated @@ -478,10 +480,10 @@ def test_invalid_shard_iterator_type(): conn.create_stream(stream_name, 1) response = conn.describe_stream(stream_name) - shard_id = response['StreamDescription']['Shards'][0]['ShardId'] + shard_id = response["StreamDescription"]["Shards"][0]["ShardId"] response = conn.get_shard_iterator.when.called_with( - stream_name, shard_id, 'invalid-type').should.throw( - InvalidArgumentException) + stream_name, shard_id, "invalid-type" + ).should.throw(InvalidArgumentException) @mock_kinesis_deprecated @@ -491,10 +493,10 @@ def test_add_tags(): conn.create_stream(stream_name, 1) conn.describe_stream(stream_name) - conn.add_tags_to_stream(stream_name, {'tag1': 'val1'}) - conn.add_tags_to_stream(stream_name, {'tag2': 'val2'}) - conn.add_tags_to_stream(stream_name, {'tag1': 'val3'}) - conn.add_tags_to_stream(stream_name, {'tag2': 'val4'}) + conn.add_tags_to_stream(stream_name, {"tag1": "val1"}) + conn.add_tags_to_stream(stream_name, {"tag2": "val2"}) + conn.add_tags_to_stream(stream_name, {"tag1": "val3"}) + conn.add_tags_to_stream(stream_name, {"tag2": "val4"}) @mock_kinesis_deprecated @@ -504,22 +506,38 @@ def test_list_tags(): conn.create_stream(stream_name, 1) conn.describe_stream(stream_name) - conn.add_tags_to_stream(stream_name, {'tag1': 'val1'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag1').should.equal('val1') - conn.add_tags_to_stream(stream_name, {'tag2': 'val2'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag2').should.equal('val2') - conn.add_tags_to_stream(stream_name, {'tag1': 'val3'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag1').should.equal('val3') - conn.add_tags_to_stream(stream_name, {'tag2': 'val4'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag2').should.equal('val4') + conn.add_tags_to_stream(stream_name, {"tag1": "val1"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag1").should.equal("val1") + conn.add_tags_to_stream(stream_name, {"tag2": "val2"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag2").should.equal("val2") + conn.add_tags_to_stream(stream_name, {"tag1": "val3"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag1").should.equal("val3") + conn.add_tags_to_stream(stream_name, {"tag2": "val4"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag2").should.equal("val4") @mock_kinesis_deprecated @@ -529,29 +547,45 @@ def test_remove_tags(): conn.create_stream(stream_name, 1) conn.describe_stream(stream_name) - conn.add_tags_to_stream(stream_name, {'tag1': 'val1'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag1').should.equal('val1') - conn.remove_tags_from_stream(stream_name, ['tag1']) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag1').should.equal(None) + conn.add_tags_to_stream(stream_name, {"tag1": "val1"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag1").should.equal("val1") + conn.remove_tags_from_stream(stream_name, ["tag1"]) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag1").should.equal(None) - conn.add_tags_to_stream(stream_name, {'tag2': 'val2'}) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag2').should.equal('val2') - conn.remove_tags_from_stream(stream_name, ['tag2']) - tags = dict([(tag['Key'], tag['Value']) - for tag in conn.list_tags_for_stream(stream_name)['Tags']]) - tags.get('tag2').should.equal(None) + conn.add_tags_to_stream(stream_name, {"tag2": "val2"}) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag2").should.equal("val2") + conn.remove_tags_from_stream(stream_name, ["tag2"]) + tags = dict( + [ + (tag["Key"], tag["Value"]) + for tag in conn.list_tags_for_stream(stream_name)["Tags"] + ] + ) + tags.get("tag2").should.equal(None) @mock_kinesis_deprecated def test_split_shard(): conn = boto.kinesis.connect_to_region("us-west-2") - stream_name = 'my_stream' + stream_name = "my_stream" conn.create_stream(stream_name, 2) @@ -562,44 +596,47 @@ def test_split_shard(): stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(2) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) - shard_range = shards[0]['HashKeyRange'] + shard_range = shards[0]["HashKeyRange"] new_starting_hash = ( - int(shard_range['EndingHashKey']) + int( - shard_range['StartingHashKey'])) // 2 - conn.split_shard("my_stream", shards[0]['ShardId'], str(new_starting_hash)) + int(shard_range["EndingHashKey"]) + int(shard_range["StartingHashKey"]) + ) // 2 + conn.split_shard("my_stream", shards[0]["ShardId"], str(new_starting_hash)) stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(3) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) - shard_range = shards[2]['HashKeyRange'] + shard_range = shards[2]["HashKeyRange"] new_starting_hash = ( - int(shard_range['EndingHashKey']) + int( - shard_range['StartingHashKey'])) // 2 - conn.split_shard("my_stream", shards[2]['ShardId'], str(new_starting_hash)) + int(shard_range["EndingHashKey"]) + int(shard_range["StartingHashKey"]) + ) // 2 + conn.split_shard("my_stream", shards[2]["ShardId"], str(new_starting_hash)) stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(4) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) @mock_kinesis_deprecated def test_merge_shards(): conn = boto.kinesis.connect_to_region("us-west-2") - stream_name = 'my_stream' + stream_name = "my_stream" conn.create_stream(stream_name, 4) @@ -610,38 +647,39 @@ def test_merge_shards(): stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(4) conn.merge_shards.when.called_with( - stream_name, 'shardId-000000000000', - 'shardId-000000000002').should.throw(InvalidArgumentException) + stream_name, "shardId-000000000000", "shardId-000000000002" + ).should.throw(InvalidArgumentException) stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(4) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) - conn.merge_shards(stream_name, 'shardId-000000000000', - 'shardId-000000000001') + conn.merge_shards(stream_name, "shardId-000000000000", "shardId-000000000001") stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(3) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) - conn.merge_shards(stream_name, 'shardId-000000000002', - 'shardId-000000000000') + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) + conn.merge_shards(stream_name, "shardId-000000000002", "shardId-000000000000") stream_response = conn.describe_stream(stream_name) stream = stream_response["StreamDescription"] - shards = stream['Shards'] + shards = stream["Shards"] shards.should.have.length_of(2) - sum([shard['SequenceNumberRange']['EndingSequenceNumber'] - for shard in shards]).should.equal(99) + sum( + [shard["SequenceNumberRange"]["EndingSequenceNumber"] for shard in shards] + ).should.equal(99) diff --git a/tests/test_kinesis/test_server.py b/tests/test_kinesis/test_server.py index 527310d75..3d7fdeee4 100644 --- a/tests/test_kinesis/test_server.py +++ b/tests/test_kinesis/test_server.py @@ -6,9 +6,9 @@ import sure # noqa import moto.server as server from moto import mock_kinesis -''' +""" Test the different server responses -''' +""" @mock_kinesis @@ -16,10 +16,7 @@ def test_list_streams(): backend = server.create_backend_app("kinesis") test_client = backend.test_client() - res = test_client.get('/?Action=ListStreams') + res = test_client.get("/?Action=ListStreams") json_data = json.loads(res.data.decode("utf-8")) - json_data.should.equal({ - "HasMoreStreams": False, - "StreamNames": [], - }) + json_data.should.equal({"HasMoreStreams": False, "StreamNames": []}) diff --git a/tests/test_kms/test_kms.py b/tests/test_kms/test_kms.py index 99f2f15ae..70fa68787 100644 --- a/tests/test_kms/test_kms.py +++ b/tests/test_kms/test_kms.py @@ -24,7 +24,7 @@ from moto import mock_kms, mock_kms_deprecated PLAINTEXT_VECTORS = ( (b"some encodeable plaintext",), (b"some unencodeable plaintext \xec\x8a\xcf\xb6r\xe9\xb5\xeb\xff\xa23\x16",), - (u"some unicode characters ø˚∆øˆˆ∆ßçøˆˆçßøˆ¨¥",), + ("some unicode characters ø˚∆øˆˆ∆ßçøˆˆçßøˆ¨¥",), ) @@ -55,7 +55,9 @@ def test_create_key(): @mock_kms_deprecated def test_describe_key(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) key_id = key["KeyMetadata"]["KeyId"] key = conn.describe_key(key_id) @@ -66,8 +68,12 @@ def test_describe_key(): @mock_kms_deprecated def test_describe_key_via_alias(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") - conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + conn.create_alias( + alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"] + ) alias_key = conn.describe_key("alias/my-key-alias") alias_key["KeyMetadata"]["Description"].should.equal("my key") @@ -78,17 +84,25 @@ def test_describe_key_via_alias(): @mock_kms_deprecated def test_describe_key_via_alias_not_found(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") - conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + conn.create_alias( + alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"] + ) - conn.describe_key.when.called_with("alias/not-found-alias").should.throw(NotFoundException) + conn.describe_key.when.called_with("alias/not-found-alias").should.throw( + NotFoundException + ) -@parameterized(( +@parameterized( + ( ("alias/does-not-exist",), ("arn:aws:kms:us-east-1:012345678912:alias/does-not-exist",), ("invalid",), -)) + ) +) @mock_kms def test_describe_key_via_alias_invalid_alias(key_id): client = boto3.client("kms", region_name="us-east-1") @@ -101,7 +115,9 @@ def test_describe_key_via_alias_invalid_alias(key_id): @mock_kms_deprecated def test_describe_key_via_arn(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) arn = key["KeyMetadata"]["Arn"] the_key = conn.describe_key(arn) @@ -120,8 +136,12 @@ def test_describe_missing_key(): def test_list_keys(): conn = boto.kms.connect_to_region("us-west-2") - conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") - conn.create_key(policy="my policy", description="my key2", key_usage="ENCRYPT_DECRYPT") + conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) + conn.create_key( + policy="my policy", description="my key2", key_usage="ENCRYPT_DECRYPT" + ) keys = conn.list_keys() keys["Keys"].should.have.length_of(2) @@ -131,7 +151,9 @@ def test_list_keys(): def test_enable_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) key_id = key["KeyMetadata"]["KeyId"] conn.enable_key_rotation(key_id) @@ -143,7 +165,9 @@ def test_enable_key_rotation(): def test_enable_key_rotation_via_arn(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) key_id = key["KeyMetadata"]["Arn"] conn.enable_key_rotation(key_id) @@ -154,26 +178,36 @@ def test_enable_key_rotation_via_arn(): @mock_kms_deprecated def test_enable_key_rotation_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.enable_key_rotation.when.called_with("not-a-key").should.throw(NotFoundException) + conn.enable_key_rotation.when.called_with("not-a-key").should.throw( + NotFoundException + ) @mock_kms_deprecated def test_enable_key_rotation_with_alias_name_should_fail(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") - conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) + conn.create_alias( + alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"] + ) alias_key = conn.describe_key("alias/my-key-alias") alias_key["KeyMetadata"]["Arn"].should.equal(key["KeyMetadata"]["Arn"]) - conn.enable_key_rotation.when.called_with("alias/my-alias").should.throw(NotFoundException) + conn.enable_key_rotation.when.called_with("alias/my-alias").should.throw( + NotFoundException + ) @mock_kms_deprecated def test_disable_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) key_id = key["KeyMetadata"]["KeyId"] conn.enable_key_rotation(key_id) @@ -187,7 +221,9 @@ def test_disable_key_rotation(): def test_generate_data_key(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) key_id = key["KeyMetadata"]["KeyId"] key_arn = key["KeyMetadata"]["Arn"] @@ -271,20 +307,26 @@ def test_decrypt(plaintext): @mock_kms_deprecated def test_disable_key_rotation_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.disable_key_rotation.when.called_with("not-a-key").should.throw(NotFoundException) + conn.disable_key_rotation.when.called_with("not-a-key").should.throw( + NotFoundException + ) @mock_kms_deprecated def test_get_key_rotation_status_with_missing_key(): conn = boto.kms.connect_to_region("us-west-2") - conn.get_key_rotation_status.when.called_with("not-a-key").should.throw(NotFoundException) + conn.get_key_rotation_status.when.called_with("not-a-key").should.throw( + NotFoundException + ) @mock_kms_deprecated def test_get_key_rotation_status(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) key_id = key["KeyMetadata"]["KeyId"] conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(False) @@ -294,7 +336,9 @@ def test_get_key_rotation_status(): def test_create_key_defaults_key_rotation(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key", key_usage="ENCRYPT_DECRYPT" + ) key_id = key["KeyMetadata"]["KeyId"] conn.get_key_rotation_status(key_id)["KeyRotationEnabled"].should.equal(False) @@ -304,7 +348,9 @@ def test_create_key_defaults_key_rotation(): def test_get_key_policy(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) key_id = key["KeyMetadata"]["KeyId"] policy = conn.get_key_policy(key_id, "default") @@ -315,7 +361,9 @@ def test_get_key_policy(): def test_get_key_policy_via_arn(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) policy = conn.get_key_policy(key["KeyMetadata"]["Arn"], "default") policy["Policy"].should.equal("my policy") @@ -325,7 +373,9 @@ def test_get_key_policy_via_arn(): def test_put_key_policy(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) key_id = key["KeyMetadata"]["KeyId"] conn.put_key_policy(key_id, "default", "new policy") @@ -337,7 +387,9 @@ def test_put_key_policy(): def test_put_key_policy_via_arn(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) key_id = key["KeyMetadata"]["Arn"] conn.put_key_policy(key_id, "default", "new policy") @@ -349,10 +401,16 @@ def test_put_key_policy_via_arn(): def test_put_key_policy_via_alias_should_not_update(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") - conn.create_alias(alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"]) + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) + conn.create_alias( + alias_name="alias/my-key-alias", target_key_id=key["KeyMetadata"]["KeyId"] + ) - conn.put_key_policy.when.called_with("alias/my-key-alias", "default", "new policy").should.throw(NotFoundException) + conn.put_key_policy.when.called_with( + "alias/my-key-alias", "default", "new policy" + ).should.throw(NotFoundException) policy = conn.get_key_policy(key["KeyMetadata"]["KeyId"], "default") policy["Policy"].should.equal("my policy") @@ -362,7 +420,9 @@ def test_put_key_policy_via_alias_should_not_update(): def test_put_key_policy(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) conn.put_key_policy(key["KeyMetadata"]["Arn"], "default", "new policy") policy = conn.get_key_policy(key["KeyMetadata"]["KeyId"], "default") @@ -373,7 +433,9 @@ def test_put_key_policy(): def test_list_key_policies(): conn = boto.kms.connect_to_region("us-west-2") - key = conn.create_key(policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT") + key = conn.create_key( + policy="my policy", description="my key1", key_usage="ENCRYPT_DECRYPT" + ) key_id = key["KeyMetadata"]["KeyId"] policies = conn.list_key_policies(key_id) @@ -397,7 +459,12 @@ def test__create_alias__raises_if_reserved_alias(): create_resp = kms.create_key() key_id = create_resp["KeyMetadata"]["KeyId"] - reserved_aliases = ["alias/aws/ebs", "alias/aws/s3", "alias/aws/redshift", "alias/aws/rds"] + reserved_aliases = [ + "alias/aws/ebs", + "alias/aws/s3", + "alias/aws/redshift", + "alias/aws/rds", + ] for alias_name in reserved_aliases: with assert_raises(JSONResponseError) as err: @@ -434,7 +501,9 @@ def test__create_alias__raises_if_wrong_prefix(): ex = err.exception ex.error_message.should.equal("Invalid identifier") ex.error_code.should.equal("ValidationException") - ex.body.should.equal({"message": "Invalid identifier", "__type": "ValidationException"}) + ex.body.should.equal( + {"message": "Invalid identifier", "__type": "ValidationException"} + ) ex.reason.should.equal("Bad Request") ex.status.should.equal(400) @@ -454,13 +523,17 @@ def test__create_alias__raises_if_duplicate(): ex = err.exception ex.error_message.should.match( - r"An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists".format(**locals()) + r"An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists".format( + **locals() + ) ) ex.error_code.should.be.none ex.box_usage.should.be.none ex.request_id.should.be.none ex.body["message"].should.match( - r"An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists".format(**locals()) + r"An alias with the name arn:aws:kms:{region}:\d{{12}}:{alias} already exists".format( + **locals() + ) ) ex.body["__type"].should.equal("AlreadyExistsException") ex.reason.should.equal("Bad Request") @@ -473,7 +546,11 @@ def test__create_alias__raises_if_alias_has_restricted_characters(): create_resp = kms.create_key() key_id = create_resp["KeyMetadata"]["KeyId"] - alias_names_with_restricted_characters = ["alias/my-alias!", "alias/my-alias$", "alias/my-alias@"] + alias_names_with_restricted_characters = [ + "alias/my-alias!", + "alias/my-alias$", + "alias/my-alias@", + ] for alias_name in alias_names_with_restricted_characters: with assert_raises(JSONResponseError) as err: @@ -510,17 +587,18 @@ def test__create_alias__raises_if_alias_has_colon_character(): kms.create_alias(alias_name, key_id) ex = err.exception ex.body["__type"].should.equal("ValidationException") - ex.body["message"].should.equal("{alias_name} contains invalid characters for an alias".format(**locals())) + ex.body["message"].should.equal( + "{alias_name} contains invalid characters for an alias".format(**locals()) + ) ex.error_code.should.equal("ValidationException") - ex.message.should.equal("{alias_name} contains invalid characters for an alias".format(**locals())) + ex.message.should.equal( + "{alias_name} contains invalid characters for an alias".format(**locals()) + ) ex.reason.should.equal("Bad Request") ex.status.should.equal(400) -@parameterized(( - ("alias/my-alias_/",), - ("alias/my_alias-/",), -)) +@parameterized((("alias/my-alias_/",), ("alias/my_alias-/",))) @mock_kms_deprecated def test__create_alias__accepted_characters(alias_name): kms = boto.connect_kms() @@ -601,8 +679,7 @@ def test__delete_alias__raises_if_alias_is_not_found(): kms.delete_alias(alias_name) expected_message_match = r"Alias arn:aws:kms:{region}:[0-9]{{12}}:{alias_name} is not found.".format( - region=region, - alias_name=alias_name + region=region, alias_name=alias_name ) ex = err.exception ex.body["__type"].should.equal("NotFoundException") @@ -636,39 +713,78 @@ def test__list_aliases(): alias_name = alias_obj["AliasName"] alias_arn = alias_obj["AliasArn"] return re.match( - r"arn:aws:kms:{region}:\d{{12}}:{alias_name}".format(region=region, alias_name=alias_name), alias_arn + r"arn:aws:kms:{region}:\d{{12}}:{alias_name}".format( + region=region, alias_name=alias_name + ), + alias_arn, ) - len([alias for alias in aliases if has_correct_arn(alias) and "alias/aws/ebs" == alias["AliasName"]]).should.equal( - 1 - ) - len([alias for alias in aliases if has_correct_arn(alias) and "alias/aws/rds" == alias["AliasName"]]).should.equal( - 1 - ) len( - [alias for alias in aliases if has_correct_arn(alias) and "alias/aws/redshift" == alias["AliasName"]] - ).should.equal(1) - len([alias for alias in aliases if has_correct_arn(alias) and "alias/aws/s3" == alias["AliasName"]]).should.equal(1) - - len( - [alias for alias in aliases if has_correct_arn(alias) and "alias/my-alias1" == alias["AliasName"]] + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/aws/ebs" == alias["AliasName"] + ] ).should.equal(1) len( - [alias for alias in aliases if has_correct_arn(alias) and "alias/my-alias2" == alias["AliasName"]] + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/aws/rds" == alias["AliasName"] + ] + ).should.equal(1) + len( + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/aws/redshift" == alias["AliasName"] + ] + ).should.equal(1) + len( + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/aws/s3" == alias["AliasName"] + ] ).should.equal(1) - len([alias for alias in aliases if "TargetKeyId" in alias and key_id == alias["TargetKeyId"]]).should.equal(3) + len( + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/my-alias1" == alias["AliasName"] + ] + ).should.equal(1) + len( + [ + alias + for alias in aliases + if has_correct_arn(alias) and "alias/my-alias2" == alias["AliasName"] + ] + ).should.equal(1) + + len( + [ + alias + for alias in aliases + if "TargetKeyId" in alias and key_id == alias["TargetKeyId"] + ] + ).should.equal(3) len(aliases).should.equal(7) -@parameterized(( +@parameterized( + ( ("not-a-uuid",), ("alias/DoesNotExist",), ("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",), ("d25652e4-d2d2-49f7-929a-671ccda580c6",), - ("arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",), -)) + ( + "arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6", + ), + ) +) @mock_kms def test_invalid_key_ids(key_id): client = boto3.client("kms", region_name="us-east-1") @@ -681,8 +797,12 @@ def test_invalid_key_ids(key_id): def test__assert_default_policy(): from moto.kms.responses import _assert_default_policy - _assert_default_policy.when.called_with("not-default").should.throw(MotoNotFoundException) - _assert_default_policy.when.called_with("default").should_not.throw(MotoNotFoundException) + _assert_default_policy.when.called_with("not-default").should.throw( + MotoNotFoundException + ) + _assert_default_policy.when.called_with("default").should_not.throw( + MotoNotFoundException + ) @parameterized(PLAINTEXT_VECTORS) @@ -727,7 +847,9 @@ def test_schedule_key_deletion(): with freeze_time("2015-01-01 12:00:00"): response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) assert response["KeyId"] == key["KeyMetadata"]["KeyId"] - assert response["DeletionDate"] == datetime(2015, 1, 31, 12, 0, tzinfo=tzutc()) + assert response["DeletionDate"] == datetime( + 2015, 1, 31, 12, 0, tzinfo=tzutc() + ) else: # Can't manipulate time in server mode response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) @@ -745,12 +867,18 @@ def test_schedule_key_deletion_custom(): key = client.create_key(Description="schedule-key-deletion") if os.environ.get("TEST_SERVER_MODE", "false").lower() == "false": with freeze_time("2015-01-01 12:00:00"): - response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7) + response = client.schedule_key_deletion( + KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7 + ) assert response["KeyId"] == key["KeyMetadata"]["KeyId"] - assert response["DeletionDate"] == datetime(2015, 1, 8, 12, 0, tzinfo=tzutc()) + assert response["DeletionDate"] == datetime( + 2015, 1, 8, 12, 0, tzinfo=tzutc() + ) else: # Can't manipulate time in server mode - response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7) + response = client.schedule_key_deletion( + KeyId=key["KeyMetadata"]["KeyId"], PendingWindowInDays=7 + ) assert response["KeyId"] == key["KeyMetadata"]["KeyId"] result = client.describe_key(KeyId=key["KeyMetadata"]["KeyId"]) @@ -790,7 +918,9 @@ def test_tag_resource(): response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) keyid = response["KeyId"] - response = client.tag_resource(KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}]) + response = client.tag_resource( + KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}] + ) # Shouldn't have any data, just header assert len(response.keys()) == 1 @@ -803,20 +933,24 @@ def test_list_resource_tags(): response = client.schedule_key_deletion(KeyId=key["KeyMetadata"]["KeyId"]) keyid = response["KeyId"] - response = client.tag_resource(KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}]) + response = client.tag_resource( + KeyId=keyid, Tags=[{"TagKey": "string", "TagValue": "string"}] + ) response = client.list_resource_tags(KeyId=keyid) assert response["Tags"][0]["TagKey"] == "string" assert response["Tags"][0]["TagValue"] == "string" -@parameterized(( +@parameterized( + ( (dict(KeySpec="AES_256"), 32), (dict(KeySpec="AES_128"), 16), (dict(NumberOfBytes=64), 64), (dict(NumberOfBytes=1), 1), (dict(NumberOfBytes=1024), 1024), -)) + ) +) @mock_kms def test_generate_data_key_sizes(kwargs, expected_key_length): client = boto3.client("kms", region_name="us-east-1") @@ -832,34 +966,44 @@ def test_generate_data_key_decrypt(): client = boto3.client("kms", region_name="us-east-1") key = client.create_key(Description="generate-data-key-decrypt") - resp1 = client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256") + resp1 = client.generate_data_key( + KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256" + ) resp2 = client.decrypt(CiphertextBlob=resp1["CiphertextBlob"]) assert resp1["Plaintext"] == resp2["Plaintext"] -@parameterized(( +@parameterized( + ( (dict(KeySpec="AES_257"),), (dict(KeySpec="AES_128", NumberOfBytes=16),), (dict(NumberOfBytes=2048),), (dict(NumberOfBytes=0),), (dict(),), -)) + ) +) @mock_kms def test_generate_data_key_invalid_size_params(kwargs): client = boto3.client("kms", region_name="us-east-1") key = client.create_key(Description="generate-data-key-size") - with assert_raises((botocore.exceptions.ClientError, botocore.exceptions.ParamValidationError)) as err: + with assert_raises( + (botocore.exceptions.ClientError, botocore.exceptions.ParamValidationError) + ) as err: client.generate_data_key(KeyId=key["KeyMetadata"]["KeyId"], **kwargs) -@parameterized(( +@parameterized( + ( ("alias/DoesNotExist",), ("arn:aws:kms:us-east-1:012345678912:alias/DoesNotExist",), ("d25652e4-d2d2-49f7-929a-671ccda580c6",), - ("arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6",), -)) + ( + "arn:aws:kms:us-east-1:012345678912:key/d25652e4-d2d2-49f7-929a-671ccda580c6", + ), + ) +) @mock_kms def test_generate_data_key_invalid_key(key_id): client = boto3.client("kms", region_name="us-east-1") @@ -868,12 +1012,14 @@ def test_generate_data_key_invalid_key(key_id): client.generate_data_key(KeyId=key_id, KeySpec="AES_256") -@parameterized(( +@parameterized( + ( ("alias/DoesExist", False), ("arn:aws:kms:us-east-1:012345678912:alias/DoesExist", False), ("", True), ("arn:aws:kms:us-east-1:012345678912:key/", True), -)) + ) +) @mock_kms def test_generate_data_key_all_valid_key_ids(prefix, append_key_id): client = boto3.client("kms", region_name="us-east-1") @@ -893,7 +1039,9 @@ def test_generate_data_key_without_plaintext_decrypt(): client = boto3.client("kms", region_name="us-east-1") key = client.create_key(Description="generate-data-key-decrypt") - resp1 = client.generate_data_key_without_plaintext(KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256") + resp1 = client.generate_data_key_without_plaintext( + KeyId=key["KeyMetadata"]["KeyId"], KeySpec="AES_256" + ) assert "Plaintext" not in resp1 @@ -911,9 +1059,7 @@ def test_re_encrypt_decrypt(plaintext): key_2_arn = key_2["KeyMetadata"]["Arn"] encrypt_response = client.encrypt( - KeyId=key_1_id, - Plaintext=plaintext, - EncryptionContext={"encryption": "context"}, + KeyId=key_1_id, Plaintext=plaintext, EncryptionContext={"encryption": "context"} ) re_encrypt_response = client.re_encrypt( @@ -954,10 +1100,7 @@ def test_re_encrypt_to_invalid_destination(): key = client.create_key(Description="key 1") key_id = key["KeyMetadata"]["KeyId"] - encrypt_response = client.encrypt( - KeyId=key_id, - Plaintext=b"some plaintext", - ) + encrypt_response = client.encrypt(KeyId=key_id, Plaintext=b"some plaintext") with assert_raises(client.exceptions.NotFoundException): client.re_encrypt( @@ -977,13 +1120,15 @@ def test_generate_random(number_of_bytes): len(response["Plaintext"]).should.equal(number_of_bytes) -@parameterized(( +@parameterized( + ( (2048, botocore.exceptions.ClientError), (1025, botocore.exceptions.ClientError), (0, botocore.exceptions.ParamValidationError), (-1, botocore.exceptions.ParamValidationError), - (-1024, botocore.exceptions.ParamValidationError) -)) + (-1024, botocore.exceptions.ParamValidationError), + ) +) @mock_kms def test_generate_random_invalid_number_of_bytes(number_of_bytes, error_type): client = boto3.client("kms", region_name="us-west-2") @@ -1053,7 +1198,9 @@ def test_get_key_policy_key_not_found(): client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.get_key_policy(KeyId="12366f9b-1230-123d-123e-123e6ae60c02", PolicyName="default") + client.get_key_policy( + KeyId="12366f9b-1230-123d-123e-123e6ae60c02", PolicyName="default" + ) @mock_kms @@ -1069,4 +1216,8 @@ def test_put_key_policy_key_not_found(): client = boto3.client("kms", region_name="us-east-1") with assert_raises(client.exceptions.NotFoundException): - client.put_key_policy(KeyId="00000000-0000-0000-0000-000000000000", PolicyName="default", Policy="new policy") + client.put_key_policy( + KeyId="00000000-0000-0000-0000-000000000000", + PolicyName="default", + Policy="new policy", + ) diff --git a/tests/test_kms/test_server.py b/tests/test_kms/test_server.py index 7b8f74e3b..083f9d18a 100644 --- a/tests/test_kms/test_server.py +++ b/tests/test_kms/test_server.py @@ -6,9 +6,9 @@ import sure # noqa import moto.server as server from moto import mock_kms -''' +""" Test the different server responses -''' +""" @mock_kms @@ -16,10 +16,8 @@ def test_list_keys(): backend = server.create_backend_app("kms") test_client = backend.test_client() - res = test_client.get('/?Action=ListKeys') + res = test_client.get("/?Action=ListKeys") - json.loads(res.data.decode("utf-8")).should.equal({ - "Keys": [], - "NextMarker": None, - "Truncated": False, - }) + json.loads(res.data.decode("utf-8")).should.equal( + {"Keys": [], "NextMarker": None, "Truncated": False} + ) diff --git a/tests/test_kms/test_utils.py b/tests/test_kms/test_utils.py index 73d7d3580..f5478e0ef 100644 --- a/tests/test_kms/test_utils.py +++ b/tests/test_kms/test_utils.py @@ -4,7 +4,11 @@ import sure # noqa from nose.tools import assert_raises from parameterized import parameterized -from moto.kms.exceptions import AccessDeniedException, InvalidCiphertextException, NotFoundException +from moto.kms.exceptions import ( + AccessDeniedException, + InvalidCiphertextException, + NotFoundException, +) from moto.kms.models import Key from moto.kms.utils import ( _deserialize_ciphertext_blob, @@ -19,8 +23,14 @@ from moto.kms.utils import ( ) ENCRYPTION_CONTEXT_VECTORS = ( - ({"this": "is", "an": "encryption", "context": "example"}, b"an" b"encryption" b"context" b"example" b"this" b"is"), - ({"a_this": "one", "b_is": "actually", "c_in": "order"}, b"a_this" b"one" b"b_is" b"actually" b"c_in" b"order"), + ( + {"this": "is", "an": "encryption", "context": "example"}, + b"an" b"encryption" b"context" b"example" b"this" b"is", + ), + ( + {"a_this": "one", "b_is": "actually", "c_in": "order"}, + b"a_this" b"one" b"b_is" b"actually" b"c_in" b"order", + ), ) CIPHERTEXT_BLOB_VECTORS = ( ( @@ -30,7 +40,10 @@ CIPHERTEXT_BLOB_VECTORS = ( ciphertext=b"some ciphertext", tag=b"1234567890123456", ), - b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext", + b"d25652e4-d2d2-49f7-929a-671ccda580c6" + b"123456789012" + b"1234567890123456" + b"some ciphertext", ), ( Ciphertext( @@ -93,12 +106,17 @@ def test_encrypt_decrypt_cycle(encryption_context): master_key_map = {master_key.id: master_key} ciphertext_blob = encrypt( - master_keys=master_key_map, key_id=master_key.id, plaintext=plaintext, encryption_context=encryption_context + master_keys=master_key_map, + key_id=master_key.id, + plaintext=plaintext, + encryption_context=encryption_context, ) ciphertext_blob.should_not.equal(plaintext) decrypted, decrypting_key_id = decrypt( - master_keys=master_key_map, ciphertext_blob=ciphertext_blob, encryption_context=encryption_context + master_keys=master_key_map, + ciphertext_blob=ciphertext_blob, + encryption_context=encryption_context, ) decrypted.should.equal(plaintext) decrypting_key_id.should.equal(master_key.id) @@ -106,7 +124,12 @@ def test_encrypt_decrypt_cycle(encryption_context): def test_encrypt_unknown_key_id(): with assert_raises(NotFoundException): - encrypt(master_keys={}, key_id="anything", plaintext=b"secrets", encryption_context={}) + encrypt( + master_keys={}, + key_id="anything", + plaintext=b"secrets", + encryption_context={}, + ) def test_decrypt_invalid_ciphertext_format(): @@ -118,7 +141,12 @@ def test_decrypt_invalid_ciphertext_format(): def test_decrypt_unknwown_key_id(): - ciphertext_blob = b"d25652e4-d2d2-49f7-929a-671ccda580c6" b"123456789012" b"1234567890123456" b"some ciphertext" + ciphertext_blob = ( + b"d25652e4-d2d2-49f7-929a-671ccda580c6" + b"123456789012" + b"1234567890123456" + b"some ciphertext" + ) with assert_raises(AccessDeniedException): decrypt(master_keys={}, ciphertext_blob=ciphertext_blob, encryption_context={}) @@ -127,7 +155,11 @@ def test_decrypt_unknwown_key_id(): def test_decrypt_invalid_ciphertext(): master_key = Key("nop", "nop", "nop", [], "nop") master_key_map = {master_key.id: master_key} - ciphertext_blob = master_key.id.encode("utf-8") + b"123456789012" b"1234567890123456" b"some ciphertext" + ciphertext_blob = ( + master_key.id.encode("utf-8") + b"123456789012" + b"1234567890123456" + b"some ciphertext" + ) with assert_raises(InvalidCiphertextException): decrypt( diff --git a/tests/test_logs/test_logs.py b/tests/test_logs/test_logs.py index 5a843a657..57f457381 100644 --- a/tests/test_logs/test_logs.py +++ b/tests/test_logs/test_logs.py @@ -8,93 +8,73 @@ from moto import mock_logs, settings from nose.tools import assert_raises from nose import SkipTest -_logs_region = 'us-east-1' if settings.TEST_SERVER_MODE else 'us-west-2' +_logs_region = "us-east-1" if settings.TEST_SERVER_MODE else "us-west-2" @mock_logs def test_log_group_create(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" response = conn.create_log_group(logGroupName=log_group_name) response = conn.describe_log_groups(logGroupNamePrefix=log_group_name) - assert len(response['logGroups']) == 1 + assert len(response["logGroups"]) == 1 # AWS defaults to Never Expire for log group retention - assert response['logGroups'][0].get('retentionInDays') == None + assert response["logGroups"][0].get("retentionInDays") == None response = conn.delete_log_group(logGroupName=log_group_name) @mock_logs def test_exceptions(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' - log_stream_name = 'dummp-stream' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + log_stream_name = "dummp-stream" conn.create_log_group(logGroupName=log_group_name) with assert_raises(ClientError): conn.create_log_group(logGroupName=log_group_name) # descrine_log_groups is not implemented yet - conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name - ) + conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) with assert_raises(ClientError): conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name + logGroupName=log_group_name, logStreamName=log_stream_name ) conn.put_log_events( logGroupName=log_group_name, logStreamName=log_stream_name, - logEvents=[ - { - 'timestamp': 0, - 'message': 'line' - }, - ], + logEvents=[{"timestamp": 0, "message": "line"}], ) with assert_raises(ClientError): conn.put_log_events( logGroupName=log_group_name, logStreamName="invalid-stream", - logEvents=[ - { - 'timestamp': 0, - 'message': 'line' - }, - ], + logEvents=[{"timestamp": 0, "message": "line"}], ) @mock_logs def test_put_logs(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' - log_stream_name = 'stream' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + log_stream_name = "stream" conn.create_log_group(logGroupName=log_group_name) - conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name - ) + conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) messages = [ - {'timestamp': 0, 'message': 'hello'}, - {'timestamp': 0, 'message': 'world'} + {"timestamp": 0, "message": "hello"}, + {"timestamp": 0, "message": "world"}, ] putRes = conn.put_log_events( - logGroupName=log_group_name, - logStreamName=log_stream_name, - logEvents=messages + logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=messages ) res = conn.get_log_events( - logGroupName=log_group_name, - logStreamName=log_stream_name + logGroupName=log_group_name, logStreamName=log_stream_name ) - events = res['events'] - nextSequenceToken = putRes['nextSequenceToken'] + events = res["events"] + nextSequenceToken = putRes["nextSequenceToken"] assert isinstance(nextSequenceToken, six.string_types) == True assert len(nextSequenceToken) == 56 events.should.have.length_of(2) @@ -102,55 +82,43 @@ def test_put_logs(): @mock_logs def test_filter_logs_interleaved(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' - log_stream_name = 'stream' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + log_stream_name = "stream" conn.create_log_group(logGroupName=log_group_name) - conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name - ) + conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) messages = [ - {'timestamp': 0, 'message': 'hello'}, - {'timestamp': 0, 'message': 'world'} + {"timestamp": 0, "message": "hello"}, + {"timestamp": 0, "message": "world"}, ] conn.put_log_events( - logGroupName=log_group_name, - logStreamName=log_stream_name, - logEvents=messages + logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=messages ) res = conn.filter_log_events( - logGroupName=log_group_name, - logStreamNames=[log_stream_name], - interleaved=True, + logGroupName=log_group_name, logStreamNames=[log_stream_name], interleaved=True ) - events = res['events'] + events = res["events"] for original_message, resulting_event in zip(messages, events): - resulting_event['eventId'].should.equal(str(resulting_event['eventId'])) - resulting_event['timestamp'].should.equal(original_message['timestamp']) - resulting_event['message'].should.equal(original_message['message']) + resulting_event["eventId"].should.equal(str(resulting_event["eventId"])) + resulting_event["timestamp"].should.equal(original_message["timestamp"]) + resulting_event["message"].should.equal(original_message["message"]) @mock_logs def test_filter_logs_raises_if_filter_pattern(): - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': - raise SkipTest('Does not work in server mode due to error in Workzeug') - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' - log_stream_name = 'stream' + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Does not work in server mode due to error in Workzeug") + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + log_stream_name = "stream" conn.create_log_group(logGroupName=log_group_name) - conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name - ) + conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) messages = [ - {'timestamp': 0, 'message': 'hello'}, - {'timestamp': 0, 'message': 'world'} + {"timestamp": 0, "message": "hello"}, + {"timestamp": 0, "message": "world"}, ] conn.put_log_events( - logGroupName=log_group_name, - logStreamName=log_stream_name, - logEvents=messages + logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=messages ) with assert_raises(NotImplementedError): conn.filter_log_events( @@ -162,161 +130,169 @@ def test_filter_logs_raises_if_filter_pattern(): @mock_logs def test_put_retention_policy(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" response = conn.create_log_group(logGroupName=log_group_name) response = conn.put_retention_policy(logGroupName=log_group_name, retentionInDays=7) response = conn.describe_log_groups(logGroupNamePrefix=log_group_name) - assert len(response['logGroups']) == 1 - assert response['logGroups'][0].get('retentionInDays') == 7 + assert len(response["logGroups"]) == 1 + assert response["logGroups"][0].get("retentionInDays") == 7 response = conn.delete_log_group(logGroupName=log_group_name) @mock_logs def test_delete_retention_policy(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" response = conn.create_log_group(logGroupName=log_group_name) response = conn.put_retention_policy(logGroupName=log_group_name, retentionInDays=7) response = conn.describe_log_groups(logGroupNamePrefix=log_group_name) - assert len(response['logGroups']) == 1 - assert response['logGroups'][0].get('retentionInDays') == 7 + assert len(response["logGroups"]) == 1 + assert response["logGroups"][0].get("retentionInDays") == 7 response = conn.delete_retention_policy(logGroupName=log_group_name) response = conn.describe_log_groups(logGroupNamePrefix=log_group_name) - assert len(response['logGroups']) == 1 - assert response['logGroups'][0].get('retentionInDays') == None + assert len(response["logGroups"]) == 1 + assert response["logGroups"][0].get("retentionInDays") == None response = conn.delete_log_group(logGroupName=log_group_name) @mock_logs def test_get_log_events(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'test' - log_stream_name = 'stream' + conn = boto3.client("logs", "us-west-2") + log_group_name = "test" + log_stream_name = "stream" conn.create_log_group(logGroupName=log_group_name) - conn.create_log_stream( - logGroupName=log_group_name, - logStreamName=log_stream_name - ) + conn.create_log_stream(logGroupName=log_group_name, logStreamName=log_stream_name) - events = [{'timestamp': x, 'message': str(x)} for x in range(20)] + events = [{"timestamp": x, "message": str(x)} for x in range(20)] conn.put_log_events( - logGroupName=log_group_name, - logStreamName=log_stream_name, - logEvents=events + logGroupName=log_group_name, logStreamName=log_stream_name, logEvents=events ) resp = conn.get_log_events( - logGroupName=log_group_name, - logStreamName=log_stream_name, - limit=10) + logGroupName=log_group_name, logStreamName=log_stream_name, limit=10 + ) - resp['events'].should.have.length_of(10) - resp.should.have.key('nextForwardToken') - resp.should.have.key('nextBackwardToken') - resp['nextForwardToken'].should.equal('f/00000000000000000000000000000000000000000000000000000010') - resp['nextBackwardToken'].should.equal('b/00000000000000000000000000000000000000000000000000000000') + resp["events"].should.have.length_of(10) + resp.should.have.key("nextForwardToken") + resp.should.have.key("nextBackwardToken") + resp["nextForwardToken"].should.equal( + "f/00000000000000000000000000000000000000000000000000000010" + ) + resp["nextBackwardToken"].should.equal( + "b/00000000000000000000000000000000000000000000000000000000" + ) for i in range(10): - resp['events'][i]['timestamp'].should.equal(i) - resp['events'][i]['message'].should.equal(str(i)) + resp["events"][i]["timestamp"].should.equal(i) + resp["events"][i]["message"].should.equal(str(i)) - next_token = resp['nextForwardToken'] + next_token = resp["nextForwardToken"] resp = conn.get_log_events( logGroupName=log_group_name, logStreamName=log_stream_name, nextToken=next_token, - limit=10) + limit=10, + ) - resp['events'].should.have.length_of(10) - resp.should.have.key('nextForwardToken') - resp.should.have.key('nextBackwardToken') - resp['nextForwardToken'].should.equal('f/00000000000000000000000000000000000000000000000000000020') - resp['nextBackwardToken'].should.equal('b/00000000000000000000000000000000000000000000000000000000') + resp["events"].should.have.length_of(10) + resp.should.have.key("nextForwardToken") + resp.should.have.key("nextBackwardToken") + resp["nextForwardToken"].should.equal( + "f/00000000000000000000000000000000000000000000000000000020" + ) + resp["nextBackwardToken"].should.equal( + "b/00000000000000000000000000000000000000000000000000000000" + ) for i in range(10): - resp['events'][i]['timestamp'].should.equal(i+10) - resp['events'][i]['message'].should.equal(str(i+10)) + resp["events"][i]["timestamp"].should.equal(i + 10) + resp["events"][i]["message"].should.equal(str(i + 10)) resp = conn.get_log_events( logGroupName=log_group_name, logStreamName=log_stream_name, - nextToken=resp['nextBackwardToken'], - limit=10) + nextToken=resp["nextBackwardToken"], + limit=10, + ) - resp['events'].should.have.length_of(10) - resp.should.have.key('nextForwardToken') - resp.should.have.key('nextBackwardToken') + resp["events"].should.have.length_of(10) + resp.should.have.key("nextForwardToken") + resp.should.have.key("nextBackwardToken") for i in range(10): - resp['events'][i]['timestamp'].should.equal(i) - resp['events'][i]['message'].should.equal(str(i)) + resp["events"][i]["timestamp"].should.equal(i) + resp["events"][i]["message"].should.equal(str(i)) @mock_logs def test_list_tags_log_group(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' - tags = {'tag_key_1': 'tag_value_1', 'tag_key_2': 'tag_value_2'} + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + tags = {"tag_key_1": "tag_value_1", "tag_key_2": "tag_value_2"} response = conn.create_log_group(logGroupName=log_group_name) response = conn.list_tags_log_group(logGroupName=log_group_name) - assert response['tags'] == {} + assert response["tags"] == {} response = conn.delete_log_group(logGroupName=log_group_name) response = conn.create_log_group(logGroupName=log_group_name, tags=tags) response = conn.list_tags_log_group(logGroupName=log_group_name) - assert response['tags'] == tags + assert response["tags"] == tags response = conn.delete_log_group(logGroupName=log_group_name) @mock_logs def test_tag_log_group(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' - tags = {'tag_key_1': 'tag_value_1'} + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" + tags = {"tag_key_1": "tag_value_1"} response = conn.create_log_group(logGroupName=log_group_name) response = conn.tag_log_group(logGroupName=log_group_name, tags=tags) response = conn.list_tags_log_group(logGroupName=log_group_name) - assert response['tags'] == tags + assert response["tags"] == tags - tags_with_added_value = {'tag_key_1': 'tag_value_1', 'tag_key_2': 'tag_value_2'} - response = conn.tag_log_group(logGroupName=log_group_name, tags={'tag_key_2': 'tag_value_2'}) + tags_with_added_value = {"tag_key_1": "tag_value_1", "tag_key_2": "tag_value_2"} + response = conn.tag_log_group( + logGroupName=log_group_name, tags={"tag_key_2": "tag_value_2"} + ) response = conn.list_tags_log_group(logGroupName=log_group_name) - assert response['tags'] == tags_with_added_value + assert response["tags"] == tags_with_added_value - tags_with_updated_value = {'tag_key_1': 'tag_value_XX', 'tag_key_2': 'tag_value_2'} - response = conn.tag_log_group(logGroupName=log_group_name, tags={'tag_key_1': 'tag_value_XX'}) + tags_with_updated_value = {"tag_key_1": "tag_value_XX", "tag_key_2": "tag_value_2"} + response = conn.tag_log_group( + logGroupName=log_group_name, tags={"tag_key_1": "tag_value_XX"} + ) response = conn.list_tags_log_group(logGroupName=log_group_name) - assert response['tags'] == tags_with_updated_value + assert response["tags"] == tags_with_updated_value response = conn.delete_log_group(logGroupName=log_group_name) @mock_logs def test_untag_log_group(): - conn = boto3.client('logs', 'us-west-2') - log_group_name = 'dummy' + conn = boto3.client("logs", "us-west-2") + log_group_name = "dummy" response = conn.create_log_group(logGroupName=log_group_name) - tags = {'tag_key_1': 'tag_value_1', 'tag_key_2': 'tag_value_2'} + tags = {"tag_key_1": "tag_value_1", "tag_key_2": "tag_value_2"} response = conn.tag_log_group(logGroupName=log_group_name, tags=tags) response = conn.list_tags_log_group(logGroupName=log_group_name) - assert response['tags'] == tags + assert response["tags"] == tags - tags_to_remove = ['tag_key_1'] - remaining_tags = {'tag_key_2': 'tag_value_2'} + tags_to_remove = ["tag_key_1"] + remaining_tags = {"tag_key_2": "tag_value_2"} response = conn.untag_log_group(logGroupName=log_group_name, tags=tags_to_remove) response = conn.list_tags_log_group(logGroupName=log_group_name) - assert response['tags'] == remaining_tags + assert response["tags"] == remaining_tags response = conn.delete_log_group(logGroupName=log_group_name) diff --git a/tests/test_opsworks/test_apps.py b/tests/test_opsworks/test_apps.py index 37d0f2fe4..417140df2 100644 --- a/tests/test_opsworks/test_apps.py +++ b/tests/test_opsworks/test_apps.py @@ -10,19 +10,15 @@ from moto import mock_opsworks @freeze_time("2015-01-01") @mock_opsworks def test_create_app_response(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") stack_id = client.create_stack( Name="test_stack_1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] - response = client.create_app( - StackId=stack_id, - Type="other", - Name="TestApp" - ) + response = client.create_app(StackId=stack_id, Type="other", Name="TestApp") response.should.contain("AppId") @@ -30,73 +26,51 @@ def test_create_app_response(): Name="test_stack_2", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] - response = client.create_app( - StackId=second_stack_id, - Type="other", - Name="TestApp" - ) + response = client.create_app(StackId=second_stack_id, Type="other", Name="TestApp") response.should.contain("AppId") # ClientError client.create_app.when.called_with( - StackId=stack_id, - Type="other", - Name="TestApp" - ).should.throw( - Exception, re.compile(r'already an app named "TestApp"') - ) + StackId=stack_id, Type="other", Name="TestApp" + ).should.throw(Exception, re.compile(r'already an app named "TestApp"')) # ClientError client.create_app.when.called_with( - StackId="nothere", - Type="other", - Name="TestApp" - ).should.throw( - Exception, "nothere" - ) + StackId="nothere", Type="other", Name="TestApp" + ).should.throw(Exception, "nothere") + @freeze_time("2015-01-01") @mock_opsworks def test_describe_apps(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") stack_id = client.create_stack( Name="test_stack_1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] - app_id = client.create_app( - StackId=stack_id, - Type="other", - Name="TestApp" - )['AppId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] + app_id = client.create_app(StackId=stack_id, Type="other", Name="TestApp")["AppId"] rv1 = client.describe_apps(StackId=stack_id) rv2 = client.describe_apps(AppIds=[app_id]) - rv1['Apps'].should.equal(rv2['Apps']) + rv1["Apps"].should.equal(rv2["Apps"]) - rv1['Apps'][0]['Name'].should.equal("TestApp") + rv1["Apps"][0]["Name"].should.equal("TestApp") # ClientError client.describe_apps.when.called_with( - StackId=stack_id, - AppIds=[app_id] - ).should.throw( - Exception, "Please provide one or more app IDs or a stack ID" - ) + StackId=stack_id, AppIds=[app_id] + ).should.throw(Exception, "Please provide one or more app IDs or a stack ID") # ClientError - client.describe_apps.when.called_with( - StackId="nothere" - ).should.throw( + client.describe_apps.when.called_with(StackId="nothere").should.throw( Exception, "Unable to find stack with ID nothere" ) # ClientError - client.describe_apps.when.called_with( - AppIds=["nothere"] - ).should.throw( + client.describe_apps.when.called_with(AppIds=["nothere"]).should.throw( Exception, "nothere" ) diff --git a/tests/test_opsworks/test_instances.py b/tests/test_opsworks/test_instances.py index f594a87c8..55d23f08e 100644 --- a/tests/test_opsworks/test_instances.py +++ b/tests/test_opsworks/test_instances.py @@ -8,34 +8,34 @@ from moto import mock_ec2 @mock_opsworks def test_create_instance(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") stack_id = client.create_stack( Name="test_stack_1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] layer_id = client.create_layer( StackId=stack_id, Type="custom", Name="TestLayer", - Shortname="TestLayerShortName" - )['LayerId'] + Shortname="TestLayerShortName", + )["LayerId"] second_stack_id = client.create_stack( Name="test_stack_2", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] second_layer_id = client.create_layer( StackId=second_stack_id, Type="custom", Name="SecondTestLayer", - Shortname="SecondTestLayerShortName" - )['LayerId'] + Shortname="SecondTestLayerShortName", + )["LayerId"] response = client.create_instance( StackId=stack_id, LayerIds=[layer_id], InstanceType="t2.micro" @@ -55,9 +55,9 @@ def test_create_instance(): StackId=stack_id, LayerIds=[second_layer_id], InstanceType="t2.micro" ).should.throw(Exception, "Please only provide layer IDs from the same stack") # ClientError - client.start_instance.when.called_with( - InstanceId="nothere" - ).should.throw(Exception, "Unable to find instance with ID nothere") + client.start_instance.when.called_with(InstanceId="nothere").should.throw( + Exception, "Unable to find instance with ID nothere" + ) @mock_opsworks @@ -70,112 +70,95 @@ def test_describe_instances(): populate S2L2 with 3 instances (S2L2_i1..2) """ - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") S1 = client.create_stack( Name="S1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] S1L1 = client.create_layer( - StackId=S1, - Type="custom", - Name="S1L1", - Shortname="S1L1" - )['LayerId'] + StackId=S1, Type="custom", Name="S1L1", Shortname="S1L1" + )["LayerId"] S2 = client.create_stack( Name="S2", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] S2L1 = client.create_layer( - StackId=S2, - Type="custom", - Name="S2L1", - Shortname="S2L1" - )['LayerId'] + StackId=S2, Type="custom", Name="S2L1", Shortname="S2L1" + )["LayerId"] S2L2 = client.create_layer( - StackId=S2, - Type="custom", - Name="S2L2", - Shortname="S2L2" - )['LayerId'] + StackId=S2, Type="custom", Name="S2L2", Shortname="S2L2" + )["LayerId"] S1L1_i1 = client.create_instance( StackId=S1, LayerIds=[S1L1], InstanceType="t2.micro" - )['InstanceId'] + )["InstanceId"] S1L1_i2 = client.create_instance( StackId=S1, LayerIds=[S1L1], InstanceType="t2.micro" - )['InstanceId'] + )["InstanceId"] S2L1_i1 = client.create_instance( StackId=S2, LayerIds=[S2L1], InstanceType="t2.micro" - )['InstanceId'] + )["InstanceId"] S2L2_i1 = client.create_instance( StackId=S2, LayerIds=[S2L2], InstanceType="t2.micro" - )['InstanceId'] + )["InstanceId"] S2L2_i2 = client.create_instance( StackId=S2, LayerIds=[S2L2], InstanceType="t2.micro" - )['InstanceId'] + )["InstanceId"] # instances in Stack 1 - response = client.describe_instances(StackId=S1)['Instances'] + response = client.describe_instances(StackId=S1)["Instances"] response.should.have.length_of(2) S1L1_i1.should.be.within([i["InstanceId"] for i in response]) S1L1_i2.should.be.within([i["InstanceId"] for i in response]) - response2 = client.describe_instances( - InstanceIds=[S1L1_i1, S1L1_i2])['Instances'] - sorted(response2, key=lambda d: d['InstanceId']).should.equal( - sorted(response, key=lambda d: d['InstanceId'])) + response2 = client.describe_instances(InstanceIds=[S1L1_i1, S1L1_i2])["Instances"] + sorted(response2, key=lambda d: d["InstanceId"]).should.equal( + sorted(response, key=lambda d: d["InstanceId"]) + ) - response3 = client.describe_instances(LayerId=S1L1)['Instances'] - sorted(response3, key=lambda d: d['InstanceId']).should.equal( - sorted(response, key=lambda d: d['InstanceId'])) + response3 = client.describe_instances(LayerId=S1L1)["Instances"] + sorted(response3, key=lambda d: d["InstanceId"]).should.equal( + sorted(response, key=lambda d: d["InstanceId"]) + ) - response = client.describe_instances(StackId=S1)['Instances'] + response = client.describe_instances(StackId=S1)["Instances"] response.should.have.length_of(2) S1L1_i1.should.be.within([i["InstanceId"] for i in response]) S1L1_i2.should.be.within([i["InstanceId"] for i in response]) # instances in Stack 2 - response = client.describe_instances(StackId=S2)['Instances'] + response = client.describe_instances(StackId=S2)["Instances"] response.should.have.length_of(3) S2L1_i1.should.be.within([i["InstanceId"] for i in response]) S2L2_i1.should.be.within([i["InstanceId"] for i in response]) S2L2_i2.should.be.within([i["InstanceId"] for i in response]) - response = client.describe_instances(LayerId=S2L1)['Instances'] + response = client.describe_instances(LayerId=S2L1)["Instances"] response.should.have.length_of(1) S2L1_i1.should.be.within([i["InstanceId"] for i in response]) - response = client.describe_instances(LayerId=S2L2)['Instances'] + response = client.describe_instances(LayerId=S2L2)["Instances"] response.should.have.length_of(2) S2L1_i1.should_not.be.within([i["InstanceId"] for i in response]) # ClientError - client.describe_instances.when.called_with( - StackId=S1, - LayerId=S1L1 - ).should.throw( + client.describe_instances.when.called_with(StackId=S1, LayerId=S1L1).should.throw( Exception, "Please provide either one or more" ) # ClientError - client.describe_instances.when.called_with( - StackId="nothere" - ).should.throw( + client.describe_instances.when.called_with(StackId="nothere").should.throw( Exception, "nothere" ) # ClientError - client.describe_instances.when.called_with( - LayerId="nothere" - ).should.throw( + client.describe_instances.when.called_with(LayerId="nothere").should.throw( Exception, "nothere" ) # ClientError - client.describe_instances.when.called_with( - InstanceIds=["nothere"] - ).should.throw( + client.describe_instances.when.called_with(InstanceIds=["nothere"]).should.throw( Exception, "nothere" ) @@ -187,38 +170,37 @@ def test_ec2_integration(): instances created via OpsWorks should be discoverable via ec2 """ - opsworks = boto3.client('opsworks', region_name='us-east-1') + opsworks = boto3.client("opsworks", region_name="us-east-1") stack_id = opsworks.create_stack( Name="S1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] layer_id = opsworks.create_layer( - StackId=stack_id, - Type="custom", - Name="S1L1", - Shortname="S1L1" - )['LayerId'] + StackId=stack_id, Type="custom", Name="S1L1", Shortname="S1L1" + )["LayerId"] instance_id = opsworks.create_instance( - StackId=stack_id, LayerIds=[layer_id], InstanceType="t2.micro", SshKeyName="testSSH" - )['InstanceId'] + StackId=stack_id, + LayerIds=[layer_id], + InstanceType="t2.micro", + SshKeyName="testSSH", + )["InstanceId"] - ec2 = boto3.client('ec2', region_name='us-east-1') + ec2 = boto3.client("ec2", region_name="us-east-1") # Before starting the instance, it shouldn't be discoverable via ec2 - reservations = ec2.describe_instances()['Reservations'] + reservations = ec2.describe_instances()["Reservations"] assert reservations.should.be.empty # After starting the instance, it should be discoverable via ec2 opsworks.start_instance(InstanceId=instance_id) - reservations = ec2.describe_instances()['Reservations'] - reservations[0]['Instances'].should.have.length_of(1) - instance = reservations[0]['Instances'][0] - opsworks_instance = opsworks.describe_instances(StackId=stack_id)[ - 'Instances'][0] + reservations = ec2.describe_instances()["Reservations"] + reservations[0]["Instances"].should.have.length_of(1) + instance = reservations[0]["Instances"][0] + opsworks_instance = opsworks.describe_instances(StackId=stack_id)["Instances"][0] - instance['InstanceId'].should.equal(opsworks_instance['Ec2InstanceId']) - instance['PrivateIpAddress'].should.equal(opsworks_instance['PrivateIp']) + instance["InstanceId"].should.equal(opsworks_instance["Ec2InstanceId"]) + instance["PrivateIpAddress"].should.equal(opsworks_instance["PrivateIp"]) diff --git a/tests/test_opsworks/test_layers.py b/tests/test_opsworks/test_layers.py index 9c640dfc3..850666381 100644 --- a/tests/test_opsworks/test_layers.py +++ b/tests/test_opsworks/test_layers.py @@ -10,19 +10,19 @@ from moto import mock_opsworks @freeze_time("2015-01-01") @mock_opsworks def test_create_layer_response(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") stack_id = client.create_stack( Name="test_stack_1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] response = client.create_layer( StackId=stack_id, Type="custom", Name="TestLayer", - Shortname="TestLayerShortName" + Shortname="TestLayerShortName", ) response.should.contain("LayerId") @@ -31,87 +31,66 @@ def test_create_layer_response(): Name="test_stack_2", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] response = client.create_layer( StackId=second_stack_id, Type="custom", Name="TestLayer", - Shortname="TestLayerShortName" + Shortname="TestLayerShortName", ) response.should.contain("LayerId") # ClientError client.create_layer.when.called_with( - StackId=stack_id, - Type="custom", - Name="TestLayer", - Shortname="_" + StackId=stack_id, Type="custom", Name="TestLayer", Shortname="_" + ).should.throw(Exception, re.compile(r'already a layer named "TestLayer"')) + # ClientError + client.create_layer.when.called_with( + StackId=stack_id, Type="custom", Name="_", Shortname="TestLayerShortName" ).should.throw( - Exception, re.compile(r'already a layer named "TestLayer"') + Exception, re.compile(r'already a layer with shortname "TestLayerShortName"') ) # ClientError client.create_layer.when.called_with( - StackId=stack_id, - Type="custom", - Name="_", - Shortname="TestLayerShortName" - ).should.throw( - Exception, re.compile( - r'already a layer with shortname "TestLayerShortName"') - ) - # ClientError - client.create_layer.when.called_with( - StackId="nothere", - Type="custom", - Name="TestLayer", - Shortname="_" - ).should.throw( - Exception, "nothere" - ) + StackId="nothere", Type="custom", Name="TestLayer", Shortname="_" + ).should.throw(Exception, "nothere") @freeze_time("2015-01-01") @mock_opsworks def test_describe_layers(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") stack_id = client.create_stack( Name="test_stack_1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" - )['StackId'] + DefaultInstanceProfileArn="profile_arn", + )["StackId"] layer_id = client.create_layer( StackId=stack_id, Type="custom", Name="TestLayer", - Shortname="TestLayerShortName" - )['LayerId'] + Shortname="TestLayerShortName", + )["LayerId"] rv1 = client.describe_layers(StackId=stack_id) rv2 = client.describe_layers(LayerIds=[layer_id]) - rv1['Layers'].should.equal(rv2['Layers']) + rv1["Layers"].should.equal(rv2["Layers"]) - rv1['Layers'][0]['Name'].should.equal("TestLayer") + rv1["Layers"][0]["Name"].should.equal("TestLayer") # ClientError client.describe_layers.when.called_with( - StackId=stack_id, - LayerIds=[layer_id] - ).should.throw( - Exception, "Please provide one or more layer IDs or a stack ID" - ) + StackId=stack_id, LayerIds=[layer_id] + ).should.throw(Exception, "Please provide one or more layer IDs or a stack ID") # ClientError - client.describe_layers.when.called_with( - StackId="nothere" - ).should.throw( + client.describe_layers.when.called_with(StackId="nothere").should.throw( Exception, "Unable to find stack with ID nothere" ) # ClientError - client.describe_layers.when.called_with( - LayerIds=["nothere"] - ).should.throw( + client.describe_layers.when.called_with(LayerIds=["nothere"]).should.throw( Exception, "nothere" ) diff --git a/tests/test_opsworks/test_stack.py b/tests/test_opsworks/test_stack.py index 5913ce6d5..277eda1ec 100644 --- a/tests/test_opsworks/test_stack.py +++ b/tests/test_opsworks/test_stack.py @@ -8,39 +8,39 @@ from moto import mock_opsworks @mock_opsworks def test_create_stack_response(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") response = client.create_stack( Name="test_stack_1", Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" + DefaultInstanceProfileArn="profile_arn", ) response.should.contain("StackId") @mock_opsworks def test_describe_stacks(): - client = boto3.client('opsworks', region_name='us-east-1') + client = boto3.client("opsworks", region_name="us-east-1") for i in range(1, 4): client.create_stack( Name="test_stack_{0}".format(i), Region="us-east-1", ServiceRoleArn="service_arn", - DefaultInstanceProfileArn="profile_arn" + DefaultInstanceProfileArn="profile_arn", ) response = client.describe_stacks() - response['Stacks'].should.have.length_of(3) - for stack in response['Stacks']: - stack['ServiceRoleArn'].should.equal("service_arn") - stack['DefaultInstanceProfileArn'].should.equal("profile_arn") + response["Stacks"].should.have.length_of(3) + for stack in response["Stacks"]: + stack["ServiceRoleArn"].should.equal("service_arn") + stack["DefaultInstanceProfileArn"].should.equal("profile_arn") - _id = response['Stacks'][0]['StackId'] + _id = response["Stacks"][0]["StackId"] response = client.describe_stacks(StackIds=[_id]) - response['Stacks'].should.have.length_of(1) - response['Stacks'][0]['Arn'].should.contain(_id) + response["Stacks"].should.have.length_of(1) + response["Stacks"][0]["Arn"].should.contain(_id) # ClientError/ResourceNotFoundException client.describe_stacks.when.called_with(StackIds=["foo"]).should.throw( - Exception, re.compile(r'foo') + Exception, re.compile(r"foo") ) diff --git a/tests/test_organizations/organizations_test_utils.py b/tests/test_organizations/organizations_test_utils.py index 83b60b877..12189c530 100644 --- a/tests/test_organizations/organizations_test_utils.py +++ b/tests/test_organizations/organizations_test_utils.py @@ -37,115 +37,108 @@ def test_make_random_service_control_policy_id(): def validate_organization(response): - org = response['Organization'] - sorted(org.keys()).should.equal([ - 'Arn', - 'AvailablePolicyTypes', - 'FeatureSet', - 'Id', - 'MasterAccountArn', - 'MasterAccountEmail', - 'MasterAccountId', - ]) - org['Id'].should.match(utils.ORG_ID_REGEX) - org['MasterAccountId'].should.equal(utils.MASTER_ACCOUNT_ID) - org['MasterAccountArn'].should.equal(utils.MASTER_ACCOUNT_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - )) - org['Arn'].should.equal(utils.ORGANIZATION_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - )) - org['MasterAccountEmail'].should.equal(utils.MASTER_ACCOUNT_EMAIL) - org['FeatureSet'].should.be.within(['ALL', 'CONSOLIDATED_BILLING']) - org['AvailablePolicyTypes'].should.equal([{ - 'Type': 'SERVICE_CONTROL_POLICY', - 'Status': 'ENABLED' - }]) + org = response["Organization"] + sorted(org.keys()).should.equal( + [ + "Arn", + "AvailablePolicyTypes", + "FeatureSet", + "Id", + "MasterAccountArn", + "MasterAccountEmail", + "MasterAccountId", + ] + ) + org["Id"].should.match(utils.ORG_ID_REGEX) + org["MasterAccountId"].should.equal(utils.MASTER_ACCOUNT_ID) + org["MasterAccountArn"].should.equal( + utils.MASTER_ACCOUNT_ARN_FORMAT.format(org["MasterAccountId"], org["Id"]) + ) + org["Arn"].should.equal( + utils.ORGANIZATION_ARN_FORMAT.format(org["MasterAccountId"], org["Id"]) + ) + org["MasterAccountEmail"].should.equal(utils.MASTER_ACCOUNT_EMAIL) + org["FeatureSet"].should.be.within(["ALL", "CONSOLIDATED_BILLING"]) + org["AvailablePolicyTypes"].should.equal( + [{"Type": "SERVICE_CONTROL_POLICY", "Status": "ENABLED"}] + ) def validate_roots(org, response): - response.should.have.key('Roots').should.be.a(list) - response['Roots'].should_not.be.empty - root = response['Roots'][0] - root.should.have.key('Id').should.match(utils.ROOT_ID_REGEX) - root.should.have.key('Arn').should.equal(utils.ROOT_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - root['Id'], - )) - root.should.have.key('Name').should.be.a(six.string_types) - root.should.have.key('PolicyTypes').should.be.a(list) - root['PolicyTypes'][0].should.have.key('Type').should.equal('SERVICE_CONTROL_POLICY') - root['PolicyTypes'][0].should.have.key('Status').should.equal('ENABLED') + response.should.have.key("Roots").should.be.a(list) + response["Roots"].should_not.be.empty + root = response["Roots"][0] + root.should.have.key("Id").should.match(utils.ROOT_ID_REGEX) + root.should.have.key("Arn").should.equal( + utils.ROOT_ARN_FORMAT.format(org["MasterAccountId"], org["Id"], root["Id"]) + ) + root.should.have.key("Name").should.be.a(six.string_types) + root.should.have.key("PolicyTypes").should.be.a(list) + root["PolicyTypes"][0].should.have.key("Type").should.equal( + "SERVICE_CONTROL_POLICY" + ) + root["PolicyTypes"][0].should.have.key("Status").should.equal("ENABLED") def validate_organizational_unit(org, response): - response.should.have.key('OrganizationalUnit').should.be.a(dict) - ou = response['OrganizationalUnit'] - ou.should.have.key('Id').should.match(utils.OU_ID_REGEX) - ou.should.have.key('Arn').should.equal(utils.OU_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - ou['Id'], - )) - ou.should.have.key('Name').should.be.a(six.string_types) + response.should.have.key("OrganizationalUnit").should.be.a(dict) + ou = response["OrganizationalUnit"] + ou.should.have.key("Id").should.match(utils.OU_ID_REGEX) + ou.should.have.key("Arn").should.equal( + utils.OU_ARN_FORMAT.format(org["MasterAccountId"], org["Id"], ou["Id"]) + ) + ou.should.have.key("Name").should.be.a(six.string_types) def validate_account(org, account): - sorted(account.keys()).should.equal([ - 'Arn', - 'Email', - 'Id', - 'JoinedMethod', - 'JoinedTimestamp', - 'Name', - 'Status', - ]) - account['Id'].should.match(utils.ACCOUNT_ID_REGEX) - account['Arn'].should.equal(utils.ACCOUNT_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - account['Id'], - )) - account['Email'].should.match(utils.EMAIL_REGEX) - account['JoinedMethod'].should.be.within(['INVITED', 'CREATED']) - account['Status'].should.be.within(['ACTIVE', 'SUSPENDED']) - account['Name'].should.be.a(six.string_types) - account['JoinedTimestamp'].should.be.a(datetime.datetime) + sorted(account.keys()).should.equal( + ["Arn", "Email", "Id", "JoinedMethod", "JoinedTimestamp", "Name", "Status"] + ) + account["Id"].should.match(utils.ACCOUNT_ID_REGEX) + account["Arn"].should.equal( + utils.ACCOUNT_ARN_FORMAT.format( + org["MasterAccountId"], org["Id"], account["Id"] + ) + ) + account["Email"].should.match(utils.EMAIL_REGEX) + account["JoinedMethod"].should.be.within(["INVITED", "CREATED"]) + account["Status"].should.be.within(["ACTIVE", "SUSPENDED"]) + account["Name"].should.be.a(six.string_types) + account["JoinedTimestamp"].should.be.a(datetime.datetime) def validate_create_account_status(create_status): - sorted(create_status.keys()).should.equal([ - 'AccountId', - 'AccountName', - 'CompletedTimestamp', - 'Id', - 'RequestedTimestamp', - 'State', - ]) - create_status['Id'].should.match(utils.CREATE_ACCOUNT_STATUS_ID_REGEX) - create_status['AccountId'].should.match(utils.ACCOUNT_ID_REGEX) - create_status['AccountName'].should.be.a(six.string_types) - create_status['State'].should.equal('SUCCEEDED') - create_status['RequestedTimestamp'].should.be.a(datetime.datetime) - create_status['CompletedTimestamp'].should.be.a(datetime.datetime) + sorted(create_status.keys()).should.equal( + [ + "AccountId", + "AccountName", + "CompletedTimestamp", + "Id", + "RequestedTimestamp", + "State", + ] + ) + create_status["Id"].should.match(utils.CREATE_ACCOUNT_STATUS_ID_REGEX) + create_status["AccountId"].should.match(utils.ACCOUNT_ID_REGEX) + create_status["AccountName"].should.be.a(six.string_types) + create_status["State"].should.equal("SUCCEEDED") + create_status["RequestedTimestamp"].should.be.a(datetime.datetime) + create_status["CompletedTimestamp"].should.be.a(datetime.datetime) + def validate_policy_summary(org, summary): summary.should.be.a(dict) - summary.should.have.key('Id').should.match(utils.SCP_ID_REGEX) - summary.should.have.key('Arn').should.equal(utils.SCP_ARN_FORMAT.format( - org['MasterAccountId'], - org['Id'], - summary['Id'], - )) - summary.should.have.key('Name').should.be.a(six.string_types) - summary.should.have.key('Description').should.be.a(six.string_types) - summary.should.have.key('Type').should.equal('SERVICE_CONTROL_POLICY') - summary.should.have.key('AwsManaged').should.be.a(bool) + summary.should.have.key("Id").should.match(utils.SCP_ID_REGEX) + summary.should.have.key("Arn").should.equal( + utils.SCP_ARN_FORMAT.format(org["MasterAccountId"], org["Id"], summary["Id"]) + ) + summary.should.have.key("Name").should.be.a(six.string_types) + summary.should.have.key("Description").should.be.a(six.string_types) + summary.should.have.key("Type").should.equal("SERVICE_CONTROL_POLICY") + summary.should.have.key("AwsManaged").should.be.a(bool) + def validate_service_control_policy(org, response): - response.should.have.key('PolicySummary').should.be.a(dict) - response.should.have.key('Content').should.be.a(six.string_types) - validate_policy_summary(org, response['PolicySummary']) + response.should.have.key("PolicySummary").should.be.a(dict) + response.should.have.key("Content").should.be.a(six.string_types) + validate_policy_summary(org, response["PolicySummary"]) diff --git a/tests/test_organizations/test_organizations_boto3.py b/tests/test_organizations/test_organizations_boto3.py index 28f8cca91..f8eb1328e 100644 --- a/tests/test_organizations/test_organizations_boto3.py +++ b/tests/test_organizations/test_organizations_boto3.py @@ -21,593 +21,576 @@ from .organizations_test_utils import ( @mock_organizations def test_create_organization(): - client = boto3.client('organizations', region_name='us-east-1') - response = client.create_organization(FeatureSet='ALL') + client = boto3.client("organizations", region_name="us-east-1") + response = client.create_organization(FeatureSet="ALL") validate_organization(response) - response['Organization']['FeatureSet'].should.equal('ALL') + response["Organization"]["FeatureSet"].should.equal("ALL") response = client.list_accounts() - len(response['Accounts']).should.equal(1) - response['Accounts'][0]['Name'].should.equal('master') - response['Accounts'][0]['Id'].should.equal(utils.MASTER_ACCOUNT_ID) - response['Accounts'][0]['Email'].should.equal(utils.MASTER_ACCOUNT_EMAIL) + len(response["Accounts"]).should.equal(1) + response["Accounts"][0]["Name"].should.equal("master") + response["Accounts"][0]["Id"].should.equal(utils.MASTER_ACCOUNT_ID) + response["Accounts"][0]["Email"].should.equal(utils.MASTER_ACCOUNT_EMAIL) - response = client.list_policies(Filter='SERVICE_CONTROL_POLICY') - len(response['Policies']).should.equal(1) - response['Policies'][0]['Name'].should.equal('FullAWSAccess') - response['Policies'][0]['Id'].should.equal(utils.DEFAULT_POLICY_ID) - response['Policies'][0]['AwsManaged'].should.equal(True) + response = client.list_policies(Filter="SERVICE_CONTROL_POLICY") + len(response["Policies"]).should.equal(1) + response["Policies"][0]["Name"].should.equal("FullAWSAccess") + response["Policies"][0]["Id"].should.equal(utils.DEFAULT_POLICY_ID) + response["Policies"][0]["AwsManaged"].should.equal(True) response = client.list_targets_for_policy(PolicyId=utils.DEFAULT_POLICY_ID) - len(response['Targets']).should.equal(2) - root_ou = [t for t in response['Targets'] if t['Type'] == 'ROOT'][0] - root_ou['Name'].should.equal('Root') - master_account = [t for t in response['Targets'] if t['Type'] == 'ACCOUNT'][0] - master_account['Name'].should.equal('master') + len(response["Targets"]).should.equal(2) + root_ou = [t for t in response["Targets"] if t["Type"] == "ROOT"][0] + root_ou["Name"].should.equal("Root") + master_account = [t for t in response["Targets"] if t["Type"] == "ACCOUNT"][0] + master_account["Name"].should.equal("master") @mock_organizations def test_describe_organization(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL') + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL") response = client.describe_organization() validate_organization(response) @mock_organizations def test_describe_organization_exception(): - client = boto3.client('organizations', region_name='us-east-1') + client = boto3.client("organizations", region_name="us-east-1") with assert_raises(ClientError) as e: response = client.describe_organization() ex = e.exception - ex.operation_name.should.equal('DescribeOrganization') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('AWSOrganizationsNotInUseException') + ex.operation_name.should.equal("DescribeOrganization") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("AWSOrganizationsNotInUseException") # Organizational Units + @mock_organizations def test_list_roots(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] response = client.list_roots() validate_roots(org, response) @mock_organizations def test_create_organizational_unit(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou_name = 'ou01' - response = client.create_organizational_unit( - ParentId=root_id, - Name=ou_name, - ) + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou_name = "ou01" + response = client.create_organizational_unit(ParentId=root_id, Name=ou_name) validate_organizational_unit(org, response) - response['OrganizationalUnit']['Name'].should.equal(ou_name) + response["OrganizationalUnit"]["Name"].should.equal(ou_name) @mock_organizations def test_describe_organizational_unit(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou_id = client.create_organizational_unit( - ParentId=root_id, - Name='ou01', - )['OrganizationalUnit']['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou_id = client.create_organizational_unit(ParentId=root_id, Name="ou01")[ + "OrganizationalUnit" + ]["Id"] response = client.describe_organizational_unit(OrganizationalUnitId=ou_id) validate_organizational_unit(org, response) @mock_organizations def test_describe_organizational_unit_exception(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] with assert_raises(ClientError) as e: response = client.describe_organizational_unit( OrganizationalUnitId=utils.make_random_root_id() ) ex = e.exception - ex.operation_name.should.equal('DescribeOrganizationalUnit') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('OrganizationalUnitNotFoundException') + ex.operation_name.should.equal("DescribeOrganizationalUnit") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain( + "OrganizationalUnitNotFoundException" + ) @mock_organizations def test_list_organizational_units_for_parent(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - client.create_organizational_unit(ParentId=root_id, Name='ou01') - client.create_organizational_unit(ParentId=root_id, Name='ou02') - client.create_organizational_unit(ParentId=root_id, Name='ou03') + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + client.create_organizational_unit(ParentId=root_id, Name="ou01") + client.create_organizational_unit(ParentId=root_id, Name="ou02") + client.create_organizational_unit(ParentId=root_id, Name="ou03") response = client.list_organizational_units_for_parent(ParentId=root_id) - response.should.have.key('OrganizationalUnits').should.be.a(list) - for ou in response['OrganizationalUnits']: + response.should.have.key("OrganizationalUnits").should.be.a(list) + for ou in response["OrganizationalUnits"]: validate_organizational_unit(org, dict(OrganizationalUnit=ou)) @mock_organizations def test_list_organizational_units_for_parent_exception(): - client = boto3.client('organizations', region_name='us-east-1') + client = boto3.client("organizations", region_name="us-east-1") with assert_raises(ClientError) as e: response = client.list_organizational_units_for_parent( ParentId=utils.make_random_root_id() ) ex = e.exception - ex.operation_name.should.equal('ListOrganizationalUnitsForParent') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('ParentNotFoundException') + ex.operation_name.should.equal("ListOrganizationalUnitsForParent") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("ParentNotFoundException") # Accounts -mockname = 'mock-account' -mockdomain = 'moto-example.org' -mockemail = '@'.join([mockname, mockdomain]) +mockname = "mock-account" +mockdomain = "moto-example.org" +mockemail = "@".join([mockname, mockdomain]) @mock_organizations def test_create_account(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL') - create_status = client.create_account( - AccountName=mockname, Email=mockemail - )['CreateAccountStatus'] + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL") + create_status = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ] validate_create_account_status(create_status) - create_status['AccountName'].should.equal(mockname) + create_status["AccountName"].should.equal(mockname) @mock_organizations def test_describe_account(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - account_id = client.create_account( - AccountName=mockname, Email=mockemail - )['CreateAccountStatus']['AccountId'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] response = client.describe_account(AccountId=account_id) - validate_account(org, response['Account']) - response['Account']['Name'].should.equal(mockname) - response['Account']['Email'].should.equal(mockemail) + validate_account(org, response["Account"]) + response["Account"]["Name"].should.equal(mockname) + response["Account"]["Email"].should.equal(mockemail) @mock_organizations def test_describe_account_exception(): - client = boto3.client('organizations', region_name='us-east-1') + client = boto3.client("organizations", region_name="us-east-1") with assert_raises(ClientError) as e: response = client.describe_account(AccountId=utils.make_random_account_id()) ex = e.exception - ex.operation_name.should.equal('DescribeAccount') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('AccountNotFoundException') + ex.operation_name.should.equal("DescribeAccount") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("AccountNotFoundException") @mock_organizations def test_list_accounts(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] for i in range(5): name = mockname + str(i) - email = name + '@' + mockdomain + email = name + "@" + mockdomain client.create_account(AccountName=name, Email=email) response = client.list_accounts() - response.should.have.key('Accounts') - accounts = response['Accounts'] + response.should.have.key("Accounts") + accounts = response["Accounts"] len(accounts).should.equal(6) for account in accounts: validate_account(org, account) - accounts[4]['Name'].should.equal(mockname + '3') - accounts[3]['Email'].should.equal(mockname + '2' + '@' + mockdomain) + accounts[4]["Name"].should.equal(mockname + "3") + accounts[3]["Email"].should.equal(mockname + "2" + "@" + mockdomain) @mock_organizations def test_list_accounts_for_parent(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - account_id = client.create_account( - AccountName=mockname, - Email=mockemail, - )['CreateAccountStatus']['AccountId'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] response = client.list_accounts_for_parent(ParentId=root_id) - account_id.should.be.within([account['Id'] for account in response['Accounts']]) + account_id.should.be.within([account["Id"] for account in response["Accounts"]]) @mock_organizations def test_move_account(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - account_id = client.create_account( - AccountName=mockname, Email=mockemail - )['CreateAccountStatus']['AccountId'] - ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01') - ou01_id = ou01['OrganizationalUnit']['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] + ou01 = client.create_organizational_unit(ParentId=root_id, Name="ou01") + ou01_id = ou01["OrganizationalUnit"]["Id"] client.move_account( - AccountId=account_id, - SourceParentId=root_id, - DestinationParentId=ou01_id, + AccountId=account_id, SourceParentId=root_id, DestinationParentId=ou01_id ) response = client.list_accounts_for_parent(ParentId=ou01_id) - account_id.should.be.within([account['Id'] for account in response['Accounts']]) + account_id.should.be.within([account["Id"] for account in response["Accounts"]]) @mock_organizations def test_list_parents_for_ou(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01') - ou01_id = ou01['OrganizationalUnit']['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou01 = client.create_organizational_unit(ParentId=root_id, Name="ou01") + ou01_id = ou01["OrganizationalUnit"]["Id"] response01 = client.list_parents(ChildId=ou01_id) - response01.should.have.key('Parents').should.be.a(list) - response01['Parents'][0].should.have.key('Id').should.equal(root_id) - response01['Parents'][0].should.have.key('Type').should.equal('ROOT') - ou02 = client.create_organizational_unit(ParentId=ou01_id, Name='ou02') - ou02_id = ou02['OrganizationalUnit']['Id'] + response01.should.have.key("Parents").should.be.a(list) + response01["Parents"][0].should.have.key("Id").should.equal(root_id) + response01["Parents"][0].should.have.key("Type").should.equal("ROOT") + ou02 = client.create_organizational_unit(ParentId=ou01_id, Name="ou02") + ou02_id = ou02["OrganizationalUnit"]["Id"] response02 = client.list_parents(ChildId=ou02_id) - response02.should.have.key('Parents').should.be.a(list) - response02['Parents'][0].should.have.key('Id').should.equal(ou01_id) - response02['Parents'][0].should.have.key('Type').should.equal('ORGANIZATIONAL_UNIT') + response02.should.have.key("Parents").should.be.a(list) + response02["Parents"][0].should.have.key("Id").should.equal(ou01_id) + response02["Parents"][0].should.have.key("Type").should.equal("ORGANIZATIONAL_UNIT") @mock_organizations def test_list_parents_for_accounts(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01') - ou01_id = ou01['OrganizationalUnit']['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou01 = client.create_organizational_unit(ParentId=root_id, Name="ou01") + ou01_id = ou01["OrganizationalUnit"]["Id"] account01_id = client.create_account( - AccountName='account01', - Email='account01@moto-example.org' - )['CreateAccountStatus']['AccountId'] + AccountName="account01", Email="account01@moto-example.org" + )["CreateAccountStatus"]["AccountId"] account02_id = client.create_account( - AccountName='account02', - Email='account02@moto-example.org' - )['CreateAccountStatus']['AccountId'] + AccountName="account02", Email="account02@moto-example.org" + )["CreateAccountStatus"]["AccountId"] client.move_account( - AccountId=account02_id, - SourceParentId=root_id, - DestinationParentId=ou01_id, + AccountId=account02_id, SourceParentId=root_id, DestinationParentId=ou01_id ) response01 = client.list_parents(ChildId=account01_id) - response01.should.have.key('Parents').should.be.a(list) - response01['Parents'][0].should.have.key('Id').should.equal(root_id) - response01['Parents'][0].should.have.key('Type').should.equal('ROOT') + response01.should.have.key("Parents").should.be.a(list) + response01["Parents"][0].should.have.key("Id").should.equal(root_id) + response01["Parents"][0].should.have.key("Type").should.equal("ROOT") response02 = client.list_parents(ChildId=account02_id) - response02.should.have.key('Parents').should.be.a(list) - response02['Parents'][0].should.have.key('Id').should.equal(ou01_id) - response02['Parents'][0].should.have.key('Type').should.equal('ORGANIZATIONAL_UNIT') + response02.should.have.key("Parents").should.be.a(list) + response02["Parents"][0].should.have.key("Id").should.equal(ou01_id) + response02["Parents"][0].should.have.key("Type").should.equal("ORGANIZATIONAL_UNIT") @mock_organizations def test_list_children(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou01 = client.create_organizational_unit(ParentId=root_id, Name='ou01') - ou01_id = ou01['OrganizationalUnit']['Id'] - ou02 = client.create_organizational_unit(ParentId=ou01_id, Name='ou02') - ou02_id = ou02['OrganizationalUnit']['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou01 = client.create_organizational_unit(ParentId=root_id, Name="ou01") + ou01_id = ou01["OrganizationalUnit"]["Id"] + ou02 = client.create_organizational_unit(ParentId=ou01_id, Name="ou02") + ou02_id = ou02["OrganizationalUnit"]["Id"] account01_id = client.create_account( - AccountName='account01', - Email='account01@moto-example.org' - )['CreateAccountStatus']['AccountId'] + AccountName="account01", Email="account01@moto-example.org" + )["CreateAccountStatus"]["AccountId"] account02_id = client.create_account( - AccountName='account02', - Email='account02@moto-example.org' - )['CreateAccountStatus']['AccountId'] + AccountName="account02", Email="account02@moto-example.org" + )["CreateAccountStatus"]["AccountId"] client.move_account( - AccountId=account02_id, - SourceParentId=root_id, - DestinationParentId=ou01_id, + AccountId=account02_id, SourceParentId=root_id, DestinationParentId=ou01_id ) - response01 = client.list_children(ParentId=root_id, ChildType='ACCOUNT') - response02 = client.list_children(ParentId=root_id, ChildType='ORGANIZATIONAL_UNIT') - response03 = client.list_children(ParentId=ou01_id, ChildType='ACCOUNT') - response04 = client.list_children(ParentId=ou01_id, ChildType='ORGANIZATIONAL_UNIT') - response01['Children'][0]['Id'].should.equal(utils.MASTER_ACCOUNT_ID) - response01['Children'][0]['Type'].should.equal('ACCOUNT') - response01['Children'][1]['Id'].should.equal(account01_id) - response01['Children'][1]['Type'].should.equal('ACCOUNT') - response02['Children'][0]['Id'].should.equal(ou01_id) - response02['Children'][0]['Type'].should.equal('ORGANIZATIONAL_UNIT') - response03['Children'][0]['Id'].should.equal(account02_id) - response03['Children'][0]['Type'].should.equal('ACCOUNT') - response04['Children'][0]['Id'].should.equal(ou02_id) - response04['Children'][0]['Type'].should.equal('ORGANIZATIONAL_UNIT') + response01 = client.list_children(ParentId=root_id, ChildType="ACCOUNT") + response02 = client.list_children(ParentId=root_id, ChildType="ORGANIZATIONAL_UNIT") + response03 = client.list_children(ParentId=ou01_id, ChildType="ACCOUNT") + response04 = client.list_children(ParentId=ou01_id, ChildType="ORGANIZATIONAL_UNIT") + response01["Children"][0]["Id"].should.equal(utils.MASTER_ACCOUNT_ID) + response01["Children"][0]["Type"].should.equal("ACCOUNT") + response01["Children"][1]["Id"].should.equal(account01_id) + response01["Children"][1]["Type"].should.equal("ACCOUNT") + response02["Children"][0]["Id"].should.equal(ou01_id) + response02["Children"][0]["Type"].should.equal("ORGANIZATIONAL_UNIT") + response03["Children"][0]["Id"].should.equal(account02_id) + response03["Children"][0]["Type"].should.equal("ACCOUNT") + response04["Children"][0]["Id"].should.equal(ou02_id) + response04["Children"][0]["Type"].should.equal("ORGANIZATIONAL_UNIT") @mock_organizations def test_list_children_exception(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] with assert_raises(ClientError) as e: response = client.list_children( - ParentId=utils.make_random_root_id(), - ChildType='ACCOUNT' + ParentId=utils.make_random_root_id(), ChildType="ACCOUNT" ) ex = e.exception - ex.operation_name.should.equal('ListChildren') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('ParentNotFoundException') + ex.operation_name.should.equal("ListChildren") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("ParentNotFoundException") with assert_raises(ClientError) as e: - response = client.list_children( - ParentId=root_id, - ChildType='BLEE' - ) + response = client.list_children(ParentId=root_id, ChildType="BLEE") ex = e.exception - ex.operation_name.should.equal('ListChildren') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('InvalidInputException') + ex.operation_name.should.equal("ListChildren") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("InvalidInputException") # Service Control Policies policy_doc01 = dict( - Version='2012-10-17', - Statement=[dict( - Sid='MockPolicyStatement', - Effect='Allow', - Action='s3:*', - Resource='*', - )] + Version="2012-10-17", + Statement=[ + dict(Sid="MockPolicyStatement", Effect="Allow", Action="s3:*", Resource="*") + ], ) + @mock_organizations def test_create_policy(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] policy = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"] validate_service_control_policy(org, policy) - policy['PolicySummary']['Name'].should.equal('MockServiceControlPolicy') - policy['PolicySummary']['Description'].should.equal('A dummy service control policy') - policy['Content'].should.equal(json.dumps(policy_doc01)) + policy["PolicySummary"]["Name"].should.equal("MockServiceControlPolicy") + policy["PolicySummary"]["Description"].should.equal( + "A dummy service control policy" + ) + policy["Content"].should.equal(json.dumps(policy_doc01)) @mock_organizations def test_describe_policy(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] policy_id = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy']['PolicySummary']['Id'] - policy = client.describe_policy(PolicyId=policy_id)['Policy'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"]["PolicySummary"]["Id"] + policy = client.describe_policy(PolicyId=policy_id)["Policy"] validate_service_control_policy(org, policy) - policy['PolicySummary']['Name'].should.equal('MockServiceControlPolicy') - policy['PolicySummary']['Description'].should.equal('A dummy service control policy') - policy['Content'].should.equal(json.dumps(policy_doc01)) + policy["PolicySummary"]["Name"].should.equal("MockServiceControlPolicy") + policy["PolicySummary"]["Description"].should.equal( + "A dummy service control policy" + ) + policy["Content"].should.equal(json.dumps(policy_doc01)) @mock_organizations def test_describe_policy_exception(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL')['Organization'] - policy_id = 'p-47fhe9s3' + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL")["Organization"] + policy_id = "p-47fhe9s3" with assert_raises(ClientError) as e: response = client.describe_policy(PolicyId=policy_id) ex = e.exception - ex.operation_name.should.equal('DescribePolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('PolicyNotFoundException') + ex.operation_name.should.equal("DescribePolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("PolicyNotFoundException") with assert_raises(ClientError) as e: - response = client.describe_policy(PolicyId='meaninglessstring') + response = client.describe_policy(PolicyId="meaninglessstring") ex = e.exception - ex.operation_name.should.equal('DescribePolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('InvalidInputException') + ex.operation_name.should.equal("DescribePolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("InvalidInputException") @mock_organizations def test_attach_policy(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou_id = client.create_organizational_unit( - ParentId=root_id, - Name='ou01', - )['OrganizationalUnit']['Id'] - account_id = client.create_account( - AccountName=mockname, - Email=mockemail, - )['CreateAccountStatus']['AccountId'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou_id = client.create_organizational_unit(ParentId=root_id, Name="ou01")[ + "OrganizationalUnit" + ]["Id"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] policy_id = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy']['PolicySummary']['Id'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"]["PolicySummary"]["Id"] response = client.attach_policy(PolicyId=policy_id, TargetId=root_id) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) response = client.attach_policy(PolicyId=policy_id, TargetId=ou_id) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) response = client.attach_policy(PolicyId=policy_id, TargetId=account_id) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) @mock_organizations def test_attach_policy_exception(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL')['Organization'] - root_id='r-dj873' - ou_id='ou-gi99-i7r8eh2i2' - account_id='126644886543' + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL")["Organization"] + root_id = "r-dj873" + ou_id = "ou-gi99-i7r8eh2i2" + account_id = "126644886543" policy_id = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy']['PolicySummary']['Id'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"]["PolicySummary"]["Id"] with assert_raises(ClientError) as e: response = client.attach_policy(PolicyId=policy_id, TargetId=root_id) ex = e.exception - ex.operation_name.should.equal('AttachPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('OrganizationalUnitNotFoundException') + ex.operation_name.should.equal("AttachPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain( + "OrganizationalUnitNotFoundException" + ) with assert_raises(ClientError) as e: response = client.attach_policy(PolicyId=policy_id, TargetId=ou_id) ex = e.exception - ex.operation_name.should.equal('AttachPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('OrganizationalUnitNotFoundException') + ex.operation_name.should.equal("AttachPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain( + "OrganizationalUnitNotFoundException" + ) with assert_raises(ClientError) as e: response = client.attach_policy(PolicyId=policy_id, TargetId=account_id) ex = e.exception - ex.operation_name.should.equal('AttachPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('AccountNotFoundException') + ex.operation_name.should.equal("AttachPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("AccountNotFoundException") with assert_raises(ClientError) as e: - response = client.attach_policy(PolicyId=policy_id, TargetId='meaninglessstring') + response = client.attach_policy( + PolicyId=policy_id, TargetId="meaninglessstring" + ) ex = e.exception - ex.operation_name.should.equal('AttachPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('InvalidInputException') + ex.operation_name.should.equal("AttachPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("InvalidInputException") @mock_organizations def test_list_polices(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - for i in range(0,4): + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + for i in range(0, 4): client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy' + str(i), - Type='SERVICE_CONTROL_POLICY' + Description="A dummy service control policy", + Name="MockServiceControlPolicy" + str(i), + Type="SERVICE_CONTROL_POLICY", ) - response = client.list_policies(Filter='SERVICE_CONTROL_POLICY') - for policy in response['Policies']: + response = client.list_policies(Filter="SERVICE_CONTROL_POLICY") + for policy in response["Policies"]: validate_policy_summary(org, policy) - + @mock_organizations def test_list_policies_for_target(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou_id = client.create_organizational_unit( - ParentId=root_id, - Name='ou01', - )['OrganizationalUnit']['Id'] - account_id = client.create_account( - AccountName=mockname, - Email=mockemail, - )['CreateAccountStatus']['AccountId'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou_id = client.create_organizational_unit(ParentId=root_id, Name="ou01")[ + "OrganizationalUnit" + ]["Id"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] policy_id = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy']['PolicySummary']['Id'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"]["PolicySummary"]["Id"] client.attach_policy(PolicyId=policy_id, TargetId=ou_id) response = client.list_policies_for_target( - TargetId=ou_id, - Filter='SERVICE_CONTROL_POLICY', + TargetId=ou_id, Filter="SERVICE_CONTROL_POLICY" ) - for policy in response['Policies']: + for policy in response["Policies"]: validate_policy_summary(org, policy) client.attach_policy(PolicyId=policy_id, TargetId=account_id) response = client.list_policies_for_target( - TargetId=account_id, - Filter='SERVICE_CONTROL_POLICY', + TargetId=account_id, Filter="SERVICE_CONTROL_POLICY" ) - for policy in response['Policies']: + for policy in response["Policies"]: validate_policy_summary(org, policy) @mock_organizations def test_list_policies_for_target_exception(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL')['Organization'] - ou_id='ou-gi99-i7r8eh2i2' - account_id='126644886543' + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL")["Organization"] + ou_id = "ou-gi99-i7r8eh2i2" + account_id = "126644886543" with assert_raises(ClientError) as e: response = client.list_policies_for_target( - TargetId=ou_id, - Filter='SERVICE_CONTROL_POLICY', + TargetId=ou_id, Filter="SERVICE_CONTROL_POLICY" ) ex = e.exception - ex.operation_name.should.equal('ListPoliciesForTarget') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('OrganizationalUnitNotFoundException') + ex.operation_name.should.equal("ListPoliciesForTarget") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain( + "OrganizationalUnitNotFoundException" + ) with assert_raises(ClientError) as e: response = client.list_policies_for_target( - TargetId=account_id, - Filter='SERVICE_CONTROL_POLICY', + TargetId=account_id, Filter="SERVICE_CONTROL_POLICY" ) ex = e.exception - ex.operation_name.should.equal('ListPoliciesForTarget') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('AccountNotFoundException') + ex.operation_name.should.equal("ListPoliciesForTarget") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("AccountNotFoundException") with assert_raises(ClientError) as e: response = client.list_policies_for_target( - TargetId='meaninglessstring', - Filter='SERVICE_CONTROL_POLICY', + TargetId="meaninglessstring", Filter="SERVICE_CONTROL_POLICY" ) ex = e.exception - ex.operation_name.should.equal('ListPoliciesForTarget') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('InvalidInputException') - + ex.operation_name.should.equal("ListPoliciesForTarget") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("InvalidInputException") + @mock_organizations def test_list_targets_for_policy(): - client = boto3.client('organizations', region_name='us-east-1') - org = client.create_organization(FeatureSet='ALL')['Organization'] - root_id = client.list_roots()['Roots'][0]['Id'] - ou_id = client.create_organizational_unit( - ParentId=root_id, - Name='ou01', - )['OrganizationalUnit']['Id'] - account_id = client.create_account( - AccountName=mockname, - Email=mockemail, - )['CreateAccountStatus']['AccountId'] + client = boto3.client("organizations", region_name="us-east-1") + org = client.create_organization(FeatureSet="ALL")["Organization"] + root_id = client.list_roots()["Roots"][0]["Id"] + ou_id = client.create_organizational_unit(ParentId=root_id, Name="ou01")[ + "OrganizationalUnit" + ]["Id"] + account_id = client.create_account(AccountName=mockname, Email=mockemail)[ + "CreateAccountStatus" + ]["AccountId"] policy_id = client.create_policy( Content=json.dumps(policy_doc01), - Description='A dummy service control policy', - Name='MockServiceControlPolicy', - Type='SERVICE_CONTROL_POLICY' - )['Policy']['PolicySummary']['Id'] + Description="A dummy service control policy", + Name="MockServiceControlPolicy", + Type="SERVICE_CONTROL_POLICY", + )["Policy"]["PolicySummary"]["Id"] client.attach_policy(PolicyId=policy_id, TargetId=root_id) client.attach_policy(PolicyId=policy_id, TargetId=ou_id) client.attach_policy(PolicyId=policy_id, TargetId=account_id) response = client.list_targets_for_policy(PolicyId=policy_id) - for target in response['Targets']: + for target in response["Targets"]: target.should.be.a(dict) - target.should.have.key('Name').should.be.a(six.string_types) - target.should.have.key('Arn').should.be.a(six.string_types) - target.should.have.key('TargetId').should.be.a(six.string_types) - target.should.have.key('Type').should.be.within( - ['ROOT', 'ORGANIZATIONAL_UNIT', 'ACCOUNT'] + target.should.have.key("Name").should.be.a(six.string_types) + target.should.have.key("Arn").should.be.a(six.string_types) + target.should.have.key("TargetId").should.be.a(six.string_types) + target.should.have.key("Type").should.be.within( + ["ROOT", "ORGANIZATIONAL_UNIT", "ACCOUNT"] ) @mock_organizations def test_list_targets_for_policy_exception(): - client = boto3.client('organizations', region_name='us-east-1') - client.create_organization(FeatureSet='ALL')['Organization'] - policy_id = 'p-47fhe9s3' + client = boto3.client("organizations", region_name="us-east-1") + client.create_organization(FeatureSet="ALL")["Organization"] + policy_id = "p-47fhe9s3" with assert_raises(ClientError) as e: response = client.list_targets_for_policy(PolicyId=policy_id) ex = e.exception - ex.operation_name.should.equal('ListTargetsForPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('PolicyNotFoundException') + ex.operation_name.should.equal("ListTargetsForPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("PolicyNotFoundException") with assert_raises(ClientError) as e: - response = client.list_targets_for_policy(PolicyId='meaninglessstring') + response = client.list_targets_for_policy(PolicyId="meaninglessstring") ex = e.exception - ex.operation_name.should.equal('ListTargetsForPolicy') - ex.response['Error']['Code'].should.equal('400') - ex.response['Error']['Message'].should.contain('InvalidInputException') + ex.operation_name.should.equal("ListTargetsForPolicy") + ex.response["Error"]["Code"].should.equal("400") + ex.response["Error"]["Message"].should.contain("InvalidInputException") diff --git a/tests/test_packages/__init__.py b/tests/test_packages/__init__.py index bf582e0b3..05b1d476b 100644 --- a/tests/test_packages/__init__.py +++ b/tests/test_packages/__init__.py @@ -1,8 +1,9 @@ from __future__ import unicode_literals import logging + # Disable extra logging for tests -logging.getLogger('boto').setLevel(logging.CRITICAL) -logging.getLogger('boto3').setLevel(logging.CRITICAL) -logging.getLogger('botocore').setLevel(logging.CRITICAL) -logging.getLogger('nose').setLevel(logging.CRITICAL) +logging.getLogger("boto").setLevel(logging.CRITICAL) +logging.getLogger("boto3").setLevel(logging.CRITICAL) +logging.getLogger("botocore").setLevel(logging.CRITICAL) +logging.getLogger("nose").setLevel(logging.CRITICAL) diff --git a/tests/test_packages/test_httpretty.py b/tests/test_packages/test_httpretty.py index 48277a2de..ccf9b98ef 100644 --- a/tests/test_packages/test_httpretty.py +++ b/tests/test_packages/test_httpretty.py @@ -3,35 +3,42 @@ from __future__ import unicode_literals import mock -from moto.packages.httpretty.core import HTTPrettyRequest, fake_gethostname, fake_gethostbyname +from moto.packages.httpretty.core import ( + HTTPrettyRequest, + fake_gethostname, + fake_gethostbyname, +) def test_parse_querystring(): - core = HTTPrettyRequest(headers='test test HTTP/1.1') + core = HTTPrettyRequest(headers="test test HTTP/1.1") - qs = 'test test' + qs = "test test" response = core.parse_querystring(qs) assert response == {} -def test_parse_request_body(): - core = HTTPrettyRequest(headers='test test HTTP/1.1') - qs = 'test' +def test_parse_request_body(): + core = HTTPrettyRequest(headers="test test HTTP/1.1") + + qs = "test" response = core.parse_request_body(qs) - assert response == 'test' + assert response == "test" + def test_fake_gethostname(): - response = fake_gethostname() + response = fake_gethostname() + + assert response == "localhost" - assert response == 'localhost' def test_fake_gethostbyname(): - host = 'test' + host = "test" response = fake_gethostbyname(host=host) - assert response == '127.0.0.1' \ No newline at end of file + assert response == "127.0.0.1" diff --git a/tests/test_polly/test_polly.py b/tests/test_polly/test_polly.py index c5c864835..e172b98d0 100644 --- a/tests/test_polly/test_polly.py +++ b/tests/test_polly/test_polly.py @@ -7,7 +7,7 @@ from nose.tools import assert_raises from moto import mock_polly # Polly only available in a few regions -DEFAULT_REGION = 'eu-west-1' +DEFAULT_REGION = "eu-west-1" LEXICON_XML = """ @mock_polly def test_describe_voices(): - client = boto3.client('polly', region_name=DEFAULT_REGION) + client = boto3.client("polly", region_name=DEFAULT_REGION) resp = client.describe_voices() - len(resp['Voices']).should.be.greater_than(1) + len(resp["Voices"]).should.be.greater_than(1) - resp = client.describe_voices(LanguageCode='en-GB') - len(resp['Voices']).should.equal(3) + resp = client.describe_voices(LanguageCode="en-GB") + len(resp["Voices"]).should.equal(3) try: - client.describe_voices(LanguageCode='SOME_LANGUAGE') + client.describe_voices(LanguageCode="SOME_LANGUAGE") except ClientError as err: - err.response['Error']['Code'].should.equal('400') + err.response["Error"]["Code"].should.equal("400") else: - raise RuntimeError('Should of raised an exception') + raise RuntimeError("Should of raised an exception") @mock_polly def test_put_list_lexicon(): - client = boto3.client('polly', region_name=DEFAULT_REGION) + client = boto3.client("polly", region_name=DEFAULT_REGION) # Return nothing - client.put_lexicon( - Name='test', - Content=LEXICON_XML - ) + client.put_lexicon(Name="test", Content=LEXICON_XML) resp = client.list_lexicons() - len(resp['Lexicons']).should.equal(1) + len(resp["Lexicons"]).should.equal(1) @mock_polly def test_put_get_lexicon(): - client = boto3.client('polly', region_name=DEFAULT_REGION) + client = boto3.client("polly", region_name=DEFAULT_REGION) # Return nothing - client.put_lexicon( - Name='test', - Content=LEXICON_XML - ) + client.put_lexicon(Name="test", Content=LEXICON_XML) - resp = client.get_lexicon(Name='test') - resp.should.contain('Lexicon') - resp.should.contain('LexiconAttributes') + resp = client.get_lexicon(Name="test") + resp.should.contain("Lexicon") + resp.should.contain("LexiconAttributes") @mock_polly def test_put_lexicon_bad_name(): - client = boto3.client('polly', region_name=DEFAULT_REGION) + client = boto3.client("polly", region_name=DEFAULT_REGION) try: - client.put_lexicon( - Name='test-invalid', - Content=LEXICON_XML - ) + client.put_lexicon(Name="test-invalid", Content=LEXICON_XML) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") else: - raise RuntimeError('Should of raised an exception') + raise RuntimeError("Should of raised an exception") @mock_polly def test_synthesize_speech(): - client = boto3.client('polly', region_name=DEFAULT_REGION) + client = boto3.client("polly", region_name=DEFAULT_REGION) # Return nothing - client.put_lexicon( - Name='test', - Content=LEXICON_XML - ) + client.put_lexicon(Name="test", Content=LEXICON_XML) - tests = ( - ('pcm', 'audio/pcm'), - ('mp3', 'audio/mpeg'), - ('ogg_vorbis', 'audio/ogg'), - ) + tests = (("pcm", "audio/pcm"), ("mp3", "audio/mpeg"), ("ogg_vorbis", "audio/ogg")) for output_format, content_type in tests: resp = client.synthesize_speech( - LexiconNames=['test'], + LexiconNames=["test"], OutputFormat=output_format, - SampleRate='16000', - Text='test1234', - TextType='text', - VoiceId='Astrid' + SampleRate="16000", + Text="test1234", + TextType="text", + VoiceId="Astrid", ) - resp['ContentType'].should.equal(content_type) + resp["ContentType"].should.equal(content_type) @mock_polly def test_synthesize_speech_bad_lexicon(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test2'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234', - TextType='text', - VoiceId='Astrid' + LexiconNames=["test2"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234", + TextType="text", + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('LexiconNotFoundException') + err.response["Error"]["Code"].should.equal("LexiconNotFoundException") else: - raise RuntimeError('Should of raised LexiconNotFoundException') + raise RuntimeError("Should of raised LexiconNotFoundException") @mock_polly def test_synthesize_speech_bad_output_format(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='invalid', - SampleRate='16000', - Text='test1234', - TextType='text', - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="invalid", + SampleRate="16000", + Text="test1234", + TextType="text", + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_bad_sample_rate(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='18000', - Text='test1234', - TextType='text', - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="18000", + Text="test1234", + TextType="text", + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidSampleRateException') + err.response["Error"]["Code"].should.equal("InvalidSampleRateException") else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_bad_text_type(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234', - TextType='invalid', - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234", + TextType="invalid", + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_bad_voice_id(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234', - TextType='text', - VoiceId='Luke' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234", + TextType="text", + VoiceId="Luke", ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_text_too_long(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234'*376, # = 3008 characters - TextType='text', - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234" * 376, # = 3008 characters + TextType="text", + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('TextLengthExceededException') + err.response["Error"]["Code"].should.equal("TextLengthExceededException") else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_bad_speech_marks1(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234', - TextType='text', - SpeechMarkTypes=['word'], - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234", + TextType="text", + SpeechMarkTypes=["word"], + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('MarksNotSupportedForFormatException') + err.response["Error"]["Code"].should.equal( + "MarksNotSupportedForFormatException" + ) else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") @mock_polly def test_synthesize_speech_bad_speech_marks2(): - client = boto3.client('polly', region_name=DEFAULT_REGION) - client.put_lexicon(Name='test', Content=LEXICON_XML) + client = boto3.client("polly", region_name=DEFAULT_REGION) + client.put_lexicon(Name="test", Content=LEXICON_XML) try: client.synthesize_speech( - LexiconNames=['test'], - OutputFormat='pcm', - SampleRate='16000', - Text='test1234', - TextType='ssml', - SpeechMarkTypes=['word'], - VoiceId='Astrid' + LexiconNames=["test"], + OutputFormat="pcm", + SampleRate="16000", + Text="test1234", + TextType="ssml", + SpeechMarkTypes=["word"], + VoiceId="Astrid", ) except ClientError as err: - err.response['Error']['Code'].should.equal('MarksNotSupportedForFormatException') + err.response["Error"]["Code"].should.equal( + "MarksNotSupportedForFormatException" + ) else: - raise RuntimeError('Should of raised ') + raise RuntimeError("Should of raised ") diff --git a/tests/test_polly/test_server.py b/tests/test_polly/test_server.py index 3ae7f2254..756c9d7e4 100644 --- a/tests/test_polly/test_server.py +++ b/tests/test_polly/test_server.py @@ -5,9 +5,9 @@ import sure # noqa import moto.server as server from moto import mock_polly -''' +""" Test the different server responses -''' +""" @mock_polly @@ -15,5 +15,5 @@ def test_polly_list(): backend = server.create_backend_app("polly") test_client = backend.test_client() - res = test_client.get('/v1/lexicons') + res = test_client.get("/v1/lexicons") res.status_code.should.equal(200) diff --git a/tests/test_rds/test_rds.py b/tests/test_rds/test_rds.py index af330e672..4ebea0cf3 100644 --- a/tests/test_rds/test_rds.py +++ b/tests/test_rds/test_rds.py @@ -14,17 +14,19 @@ from tests.helpers import disable_on_py3 def test_create_database(): conn = boto.rds.connect_to_region("us-west-2") - database = conn.create_dbinstance("db-master-1", 10, 'db.m1.small', 'root', 'hunter2', - security_groups=["my_sg"]) + database = conn.create_dbinstance( + "db-master-1", 10, "db.m1.small", "root", "hunter2", security_groups=["my_sg"] + ) - database.status.should.equal('available') + database.status.should.equal("available") database.id.should.equal("db-master-1") database.allocated_storage.should.equal(10) database.instance_class.should.equal("db.m1.small") database.master_username.should.equal("root") database.endpoint.should.equal( - ('db-master-1.aaaaaaaaaa.us-west-2.rds.amazonaws.com', 3306)) - database.security_groups[0].name.should.equal('my_sg') + ("db-master-1.aaaaaaaaaa.us-west-2.rds.amazonaws.com", 3306) + ) + database.security_groups[0].name.should.equal("my_sg") @mock_rds_deprecated @@ -33,8 +35,8 @@ def test_get_databases(): list(conn.get_all_dbinstances()).should.have.length_of(0) - conn.create_dbinstance("db-master-1", 10, 'db.m1.small', 'root', 'hunter2') - conn.create_dbinstance("db-master-2", 10, 'db.m1.small', 'root', 'hunter2') + conn.create_dbinstance("db-master-1", 10, "db.m1.small", "root", "hunter2") + conn.create_dbinstance("db-master-2", 10, "db.m1.small", "root", "hunter2") list(conn.get_all_dbinstances()).should.have.length_of(2) @@ -46,18 +48,20 @@ def test_get_databases(): @mock_rds def test_get_databases_paginated(): - conn = boto3.client('rds', region_name="us-west-2") + conn = boto3.client("rds", region_name="us-west-2") for i in range(51): - conn.create_db_instance(AllocatedStorage=5, - Port=5432, - DBInstanceIdentifier='rds%d' % i, - DBInstanceClass='db.t1.micro', - Engine='postgres') + conn.create_db_instance( + AllocatedStorage=5, + Port=5432, + DBInstanceIdentifier="rds%d" % i, + DBInstanceClass="db.t1.micro", + Engine="postgres", + ) resp = conn.describe_db_instances() resp["DBInstances"].should.have.length_of(50) - resp["Marker"].should.equal(resp["DBInstances"][-1]['DBInstanceIdentifier']) + resp["Marker"].should.equal(resp["DBInstances"][-1]["DBInstanceIdentifier"]) resp2 = conn.describe_db_instances(Marker=resp["Marker"]) resp2["DBInstances"].should.have.length_of(1) @@ -66,8 +70,7 @@ def test_get_databases_paginated(): @mock_rds_deprecated def test_describe_non_existant_database(): conn = boto.rds.connect_to_region("us-west-2") - conn.get_all_dbinstances.when.called_with( - "not-a-db").should.throw(BotoServerError) + conn.get_all_dbinstances.when.called_with("not-a-db").should.throw(BotoServerError) @mock_rds_deprecated @@ -75,7 +78,7 @@ def test_delete_database(): conn = boto.rds.connect_to_region("us-west-2") list(conn.get_all_dbinstances()).should.have.length_of(0) - conn.create_dbinstance("db-master-1", 10, 'db.m1.small', 'root', 'hunter2') + conn.create_dbinstance("db-master-1", 10, "db.m1.small", "root", "hunter2") list(conn.get_all_dbinstances()).should.have.length_of(1) conn.delete_dbinstance("db-master-1") @@ -85,16 +88,15 @@ def test_delete_database(): @mock_rds_deprecated def test_delete_non_existant_database(): conn = boto.rds.connect_to_region("us-west-2") - conn.delete_dbinstance.when.called_with( - "not-a-db").should.throw(BotoServerError) + conn.delete_dbinstance.when.called_with("not-a-db").should.throw(BotoServerError) @mock_rds_deprecated def test_create_database_security_group(): conn = boto.rds.connect_to_region("us-west-2") - security_group = conn.create_dbsecurity_group('db_sg', 'DB Security Group') - security_group.name.should.equal('db_sg') + security_group = conn.create_dbsecurity_group("db_sg", "DB Security Group") + security_group.name.should.equal("db_sg") security_group.description.should.equal("DB Security Group") list(security_group.ip_ranges).should.equal([]) @@ -105,8 +107,8 @@ def test_get_security_groups(): list(conn.get_all_dbsecurity_groups()).should.have.length_of(0) - conn.create_dbsecurity_group('db_sg1', 'DB Security Group') - conn.create_dbsecurity_group('db_sg2', 'DB Security Group') + conn.create_dbsecurity_group("db_sg1", "DB Security Group") + conn.create_dbsecurity_group("db_sg2", "DB Security Group") list(conn.get_all_dbsecurity_groups()).should.have.length_of(2) @@ -119,14 +121,15 @@ def test_get_security_groups(): @mock_rds_deprecated def test_get_non_existant_security_group(): conn = boto.rds.connect_to_region("us-west-2") - conn.get_all_dbsecurity_groups.when.called_with( - "not-a-sg").should.throw(BotoServerError) + conn.get_all_dbsecurity_groups.when.called_with("not-a-sg").should.throw( + BotoServerError + ) @mock_rds_deprecated def test_delete_database_security_group(): conn = boto.rds.connect_to_region("us-west-2") - conn.create_dbsecurity_group('db_sg', 'DB Security Group') + conn.create_dbsecurity_group("db_sg", "DB Security Group") list(conn.get_all_dbsecurity_groups()).should.have.length_of(1) @@ -137,21 +140,22 @@ def test_delete_database_security_group(): @mock_rds_deprecated def test_delete_non_existant_security_group(): conn = boto.rds.connect_to_region("us-west-2") - conn.delete_dbsecurity_group.when.called_with( - "not-a-db").should.throw(BotoServerError) + conn.delete_dbsecurity_group.when.called_with("not-a-db").should.throw( + BotoServerError + ) @disable_on_py3() @mock_rds_deprecated def test_security_group_authorize(): conn = boto.rds.connect_to_region("us-west-2") - security_group = conn.create_dbsecurity_group('db_sg', 'DB Security Group') + security_group = conn.create_dbsecurity_group("db_sg", "DB Security Group") list(security_group.ip_ranges).should.equal([]) - security_group.authorize(cidr_ip='10.3.2.45/32') + security_group.authorize(cidr_ip="10.3.2.45/32") security_group = conn.get_all_dbsecurity_groups()[0] list(security_group.ip_ranges).should.have.length_of(1) - security_group.ip_ranges[0].cidr_ip.should.equal('10.3.2.45/32') + security_group.ip_ranges[0].cidr_ip.should.equal("10.3.2.45/32") @mock_rds_deprecated @@ -159,8 +163,9 @@ def test_add_security_group_to_database(): conn = boto.rds.connect_to_region("us-west-2") database = conn.create_dbinstance( - "db-master-1", 10, 'db.m1.small', 'root', 'hunter2') - security_group = conn.create_dbsecurity_group('db_sg', 'DB Security Group') + "db-master-1", 10, "db.m1.small", "root", "hunter2" + ) + security_group = conn.create_dbsecurity_group("db_sg", "DB Security Group") database.modify(security_groups=[security_group]) database = conn.get_all_dbinstances()[0] @@ -179,9 +184,8 @@ def test_add_database_subnet_group(): subnet_ids = [subnet1.id, subnet2.id] conn = boto.rds.connect_to_region("us-west-2") - subnet_group = conn.create_db_subnet_group( - "db_subnet", "my db subnet", subnet_ids) - subnet_group.name.should.equal('db_subnet') + subnet_group = conn.create_db_subnet_group("db_subnet", "my db subnet", subnet_ids) + subnet_group.name.should.equal("db_subnet") subnet_group.description.should.equal("my db subnet") list(subnet_group.subnet_ids).should.equal(subnet_ids) @@ -200,8 +204,9 @@ def test_describe_database_subnet_group(): list(conn.get_all_db_subnet_groups()).should.have.length_of(2) list(conn.get_all_db_subnet_groups("db_subnet1")).should.have.length_of(1) - conn.get_all_db_subnet_groups.when.called_with( - "not-a-subnet").should.throw(BotoServerError) + conn.get_all_db_subnet_groups.when.called_with("not-a-subnet").should.throw( + BotoServerError + ) @mock_ec2_deprecated @@ -218,8 +223,9 @@ def test_delete_database_subnet_group(): conn.delete_db_subnet_group("db_subnet1") list(conn.get_all_db_subnet_groups()).should.have.length_of(0) - conn.delete_db_subnet_group.when.called_with( - "db_subnet1").should.throw(BotoServerError) + conn.delete_db_subnet_group.when.called_with("db_subnet1").should.throw( + BotoServerError + ) @mock_ec2_deprecated @@ -232,8 +238,14 @@ def test_create_database_in_subnet_group(): conn = boto.rds.connect_to_region("us-west-2") conn.create_db_subnet_group("db_subnet1", "my db subnet", [subnet.id]) - database = conn.create_dbinstance("db-master-1", 10, 'db.m1.small', - 'root', 'hunter2', db_subnet_group_name="db_subnet1") + database = conn.create_dbinstance( + "db-master-1", + 10, + "db.m1.small", + "root", + "hunter2", + db_subnet_group_name="db_subnet1", + ) database = conn.get_all_dbinstances("db-master-1")[0] database.subnet_group.name.should.equal("db_subnet1") @@ -244,16 +256,18 @@ def test_create_database_replica(): conn = boto.rds.connect_to_region("us-west-2") primary = conn.create_dbinstance( - "db-master-1", 10, 'db.m1.small', 'root', 'hunter2') + "db-master-1", 10, "db.m1.small", "root", "hunter2" + ) replica = conn.create_dbinstance_read_replica( - "replica", "db-master-1", "db.m1.small") + "replica", "db-master-1", "db.m1.small" + ) replica.id.should.equal("replica") replica.instance_class.should.equal("db.m1.small") status_info = replica.status_infos[0] status_info.normal.should.equal(True) - status_info.status_type.should.equal('read replication') - status_info.status.should.equal('replicating') + status_info.status_type.should.equal("read replication") + status_info.status.should.equal("replicating") primary = conn.get_all_dbinstances("db-master-1")[0] primary.read_replica_dbinstance_identifiers[0].should.equal("replica") @@ -270,13 +284,12 @@ def test_create_cross_region_database_replica(): west_2_conn = boto.rds.connect_to_region("us-west-2") primary = west_1_conn.create_dbinstance( - "db-master-1", 10, 'db.m1.small', 'root', 'hunter2') + "db-master-1", 10, "db.m1.small", "root", "hunter2" + ) primary_arn = "arn:aws:rds:us-west-1:1234567890:db:db-master-1" replica = west_2_conn.create_dbinstance_read_replica( - "replica", - primary_arn, - "db.m1.small", + "replica", primary_arn, "db.m1.small" ) primary = west_1_conn.get_all_dbinstances("db-master-1")[0] @@ -298,17 +311,19 @@ def test_connecting_to_us_east_1(): # https://github.com/boto/boto/blob/e271ff09364ea18d9d8b6f4d63d6b0ac6cbc9b75/boto/endpoints.json#L285 conn = boto.rds.connect_to_region("us-east-1") - database = conn.create_dbinstance("db-master-1", 10, 'db.m1.small', 'root', 'hunter2', - security_groups=["my_sg"]) + database = conn.create_dbinstance( + "db-master-1", 10, "db.m1.small", "root", "hunter2", security_groups=["my_sg"] + ) - database.status.should.equal('available') + database.status.should.equal("available") database.id.should.equal("db-master-1") database.allocated_storage.should.equal(10) database.instance_class.should.equal("db.m1.small") database.master_username.should.equal("root") database.endpoint.should.equal( - ('db-master-1.aaaaaaaaaa.us-east-1.rds.amazonaws.com', 3306)) - database.security_groups[0].name.should.equal('my_sg') + ("db-master-1.aaaaaaaaaa.us-east-1.rds.amazonaws.com", 3306) + ) + database.security_groups[0].name.should.equal("my_sg") @mock_rds_deprecated @@ -316,9 +331,10 @@ def test_create_database_with_iops(): conn = boto.rds.connect_to_region("us-west-2") database = conn.create_dbinstance( - "db-master-1", 10, 'db.m1.small', 'root', 'hunter2', iops=6000) + "db-master-1", 10, "db.m1.small", "root", "hunter2", iops=6000 + ) - database.status.should.equal('available') + database.status.should.equal("available") database.iops.should.equal(6000) # boto>2.36.0 may change the following property name to `storage_type` - database.StorageType.should.equal('io1') + database.StorageType.should.equal("io1") diff --git a/tests/test_rds/test_server.py b/tests/test_rds/test_server.py index 224704a0b..ab53e83b4 100644 --- a/tests/test_rds/test_server.py +++ b/tests/test_rds/test_server.py @@ -5,9 +5,9 @@ import sure # noqa import moto.server as server from moto import mock_rds -''' +""" Test the different server responses -''' +""" @mock_rds @@ -15,6 +15,6 @@ def test_list_databases(): backend = server.create_backend_app("rds") test_client = backend.test_client() - res = test_client.get('/?Action=DescribeDBInstances') + res = test_client.get("/?Action=DescribeDBInstances") res.data.decode("utf-8").should.contain("") diff --git a/tests/test_rds2/test_rds2.py b/tests/test_rds2/test_rds2.py index 911f682a8..47b45539d 100644 --- a/tests/test_rds2/test_rds2.py +++ b/tests/test_rds2/test_rds2.py @@ -8,244 +8,301 @@ from moto import mock_ec2, mock_kms, mock_rds2 @mock_rds2 def test_create_database(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - LicenseModel='license-included', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"], - VpcSecurityGroupIds=['sg-123456']) - db_instance = database['DBInstance'] - db_instance['AllocatedStorage'].should.equal(10) - db_instance['DBInstanceClass'].should.equal("db.m1.small") - db_instance['LicenseModel'].should.equal("license-included") - db_instance['MasterUsername'].should.equal("root") - db_instance['DBSecurityGroups'][0][ - 'DBSecurityGroupName'].should.equal('my_sg') - db_instance['DBInstanceArn'].should.equal( - 'arn:aws:rds:us-west-2:1234567890:db:db-master-1') - db_instance['DBInstanceStatus'].should.equal('available') - db_instance['DBName'].should.equal('staging-postgres') - db_instance['DBInstanceIdentifier'].should.equal("db-master-1") - db_instance['IAMDatabaseAuthenticationEnabled'].should.equal(False) - db_instance['DbiResourceId'].should.contain("db-") - db_instance['CopyTagsToSnapshot'].should.equal(False) - db_instance['InstanceCreateTime'].should.be.a("datetime.datetime") - db_instance['VpcSecurityGroups'][0]['VpcSecurityGroupId'].should.equal('sg-123456') + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + LicenseModel="license-included", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + VpcSecurityGroupIds=["sg-123456"], + ) + db_instance = database["DBInstance"] + db_instance["AllocatedStorage"].should.equal(10) + db_instance["DBInstanceClass"].should.equal("db.m1.small") + db_instance["LicenseModel"].should.equal("license-included") + db_instance["MasterUsername"].should.equal("root") + db_instance["DBSecurityGroups"][0]["DBSecurityGroupName"].should.equal("my_sg") + db_instance["DBInstanceArn"].should.equal( + "arn:aws:rds:us-west-2:1234567890:db:db-master-1" + ) + db_instance["DBInstanceStatus"].should.equal("available") + db_instance["DBName"].should.equal("staging-postgres") + db_instance["DBInstanceIdentifier"].should.equal("db-master-1") + db_instance["IAMDatabaseAuthenticationEnabled"].should.equal(False) + db_instance["DbiResourceId"].should.contain("db-") + db_instance["CopyTagsToSnapshot"].should.equal(False) + db_instance["InstanceCreateTime"].should.be.a("datetime.datetime") + db_instance["VpcSecurityGroups"][0]["VpcSecurityGroupId"].should.equal("sg-123456") @mock_rds2 def test_create_database_no_allocated_storage(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") database = conn.create_db_instance( - DBInstanceIdentifier='db-master-1', - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small') - db_instance = database['DBInstance'] - db_instance['Engine'].should.equal('postgres') - db_instance['StorageType'].should.equal('gp2') - db_instance['AllocatedStorage'].should.equal(20) + DBInstanceIdentifier="db-master-1", + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + ) + db_instance = database["DBInstance"] + db_instance["Engine"].should.equal("postgres") + db_instance["StorageType"].should.equal("gp2") + db_instance["AllocatedStorage"].should.equal(20) @mock_rds2 def test_create_database_non_existing_option_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") database = conn.create_db_instance.when.called_with( - DBInstanceIdentifier='db-master-1', + DBInstanceIdentifier="db-master-1", AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - OptionGroupName='non-existing').should.throw(ClientError) + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + OptionGroupName="non-existing", + ).should.throw(ClientError) @mock_rds2 def test_create_database_with_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='my-og', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - OptionGroupName='my-og') - db_instance = database['DBInstance'] - db_instance['AllocatedStorage'].should.equal(10) - db_instance['DBInstanceClass'].should.equal('db.m1.small') - db_instance['DBName'].should.equal('staging-postgres') - db_instance['OptionGroupMemberships'][0]['OptionGroupName'].should.equal('my-og') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="my-og", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + OptionGroupName="my-og", + ) + db_instance = database["DBInstance"] + db_instance["AllocatedStorage"].should.equal(10) + db_instance["DBInstanceClass"].should.equal("db.m1.small") + db_instance["DBName"].should.equal("staging-postgres") + db_instance["OptionGroupMemberships"][0]["OptionGroupName"].should.equal("my-og") @mock_rds2 def test_stop_database(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - LicenseModel='license-included', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - mydb = conn.describe_db_instances(DBInstanceIdentifier=database['DBInstance']['DBInstanceIdentifier'])['DBInstances'][0] - mydb['DBInstanceStatus'].should.equal('available') + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + LicenseModel="license-included", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + mydb = conn.describe_db_instances( + DBInstanceIdentifier=database["DBInstance"]["DBInstanceIdentifier"] + )["DBInstances"][0] + mydb["DBInstanceStatus"].should.equal("available") # test stopping database should shutdown - response = conn.stop_db_instance(DBInstanceIdentifier=mydb['DBInstanceIdentifier']) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - response['DBInstance']['DBInstanceStatus'].should.equal('stopped') + response = conn.stop_db_instance(DBInstanceIdentifier=mydb["DBInstanceIdentifier"]) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["DBInstance"]["DBInstanceStatus"].should.equal("stopped") # test rdsclient error when trying to stop an already stopped database - conn.stop_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.stop_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) # test stopping a stopped database with snapshot should error and no snapshot should exist for that call - conn.stop_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier'], DBSnapshotIdentifier='rocky4570-rds-snap').should.throw(ClientError) + conn.stop_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"], + DBSnapshotIdentifier="rocky4570-rds-snap", + ).should.throw(ClientError) response = conn.describe_db_snapshots() - response['DBSnapshots'].should.equal([]) + response["DBSnapshots"].should.equal([]) @mock_rds2 def test_start_database(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - LicenseModel='license-included', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - mydb = conn.describe_db_instances(DBInstanceIdentifier=database['DBInstance']['DBInstanceIdentifier'])['DBInstances'][0] - mydb['DBInstanceStatus'].should.equal('available') + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + LicenseModel="license-included", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + mydb = conn.describe_db_instances( + DBInstanceIdentifier=database["DBInstance"]["DBInstanceIdentifier"] + )["DBInstances"][0] + mydb["DBInstanceStatus"].should.equal("available") # test starting an already started database should error - conn.start_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.start_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) # stop and test start - should go from stopped to available, create snapshot and check snapshot - response = conn.stop_db_instance(DBInstanceIdentifier=mydb['DBInstanceIdentifier'], DBSnapshotIdentifier='rocky4570-rds-snap') - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - response['DBInstance']['DBInstanceStatus'].should.equal('stopped') + response = conn.stop_db_instance( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"], + DBSnapshotIdentifier="rocky4570-rds-snap", + ) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["DBInstance"]["DBInstanceStatus"].should.equal("stopped") response = conn.describe_db_snapshots() - response['DBSnapshots'][0]['DBSnapshotIdentifier'].should.equal('rocky4570-rds-snap') - response = conn.start_db_instance(DBInstanceIdentifier=mydb['DBInstanceIdentifier']) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - response['DBInstance']['DBInstanceStatus'].should.equal('available') + response["DBSnapshots"][0]["DBSnapshotIdentifier"].should.equal( + "rocky4570-rds-snap" + ) + response = conn.start_db_instance(DBInstanceIdentifier=mydb["DBInstanceIdentifier"]) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["DBInstance"]["DBInstanceStatus"].should.equal("available") # starting database should not remove snapshot response = conn.describe_db_snapshots() - response['DBSnapshots'][0]['DBSnapshotIdentifier'].should.equal('rocky4570-rds-snap') + response["DBSnapshots"][0]["DBSnapshotIdentifier"].should.equal( + "rocky4570-rds-snap" + ) # test stopping database, create snapshot with existing snapshot already created should throw error - conn.stop_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier'], DBSnapshotIdentifier='rocky4570-rds-snap').should.throw(ClientError) + conn.stop_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"], + DBSnapshotIdentifier="rocky4570-rds-snap", + ).should.throw(ClientError) # test stopping database not invoking snapshot should succeed. - response = conn.stop_db_instance(DBInstanceIdentifier=mydb['DBInstanceIdentifier']) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - response['DBInstance']['DBInstanceStatus'].should.equal('stopped') + response = conn.stop_db_instance(DBInstanceIdentifier=mydb["DBInstanceIdentifier"]) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["DBInstance"]["DBInstanceStatus"].should.equal("stopped") @mock_rds2 def test_fail_to_stop_multi_az(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - LicenseModel='license-included', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"], - MultiAZ=True) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + LicenseModel="license-included", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + MultiAZ=True, + ) - mydb = conn.describe_db_instances(DBInstanceIdentifier=database['DBInstance']['DBInstanceIdentifier'])['DBInstances'][0] - mydb['DBInstanceStatus'].should.equal('available') + mydb = conn.describe_db_instances( + DBInstanceIdentifier=database["DBInstance"]["DBInstanceIdentifier"] + )["DBInstances"][0] + mydb["DBInstanceStatus"].should.equal("available") # multi-az databases arent allowed to be shutdown at this time. - conn.stop_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.stop_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) # multi-az databases arent allowed to be started up at this time. - conn.start_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.start_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) @mock_rds2 def test_fail_to_stop_readreplica(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - LicenseModel='license-included', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + LicenseModel="license-included", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) - replica = conn.create_db_instance_read_replica(DBInstanceIdentifier="db-replica-1", - SourceDBInstanceIdentifier="db-master-1", - DBInstanceClass="db.m1.small") + replica = conn.create_db_instance_read_replica( + DBInstanceIdentifier="db-replica-1", + SourceDBInstanceIdentifier="db-master-1", + DBInstanceClass="db.m1.small", + ) - mydb = conn.describe_db_instances(DBInstanceIdentifier=replica['DBInstance']['DBInstanceIdentifier'])['DBInstances'][0] - mydb['DBInstanceStatus'].should.equal('available') + mydb = conn.describe_db_instances( + DBInstanceIdentifier=replica["DBInstance"]["DBInstanceIdentifier"] + )["DBInstances"][0] + mydb["DBInstanceStatus"].should.equal("available") # read-replicas are not allowed to be stopped at this time. - conn.stop_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.stop_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) # read-replicas are not allowed to be started at this time. - conn.start_db_instance.when.called_with(DBInstanceIdentifier=mydb['DBInstanceIdentifier']).should.throw(ClientError) + conn.start_db_instance.when.called_with( + DBInstanceIdentifier=mydb["DBInstanceIdentifier"] + ).should.throw(ClientError) @mock_rds2 def test_get_databases(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") instances = conn.describe_db_instances() - list(instances['DBInstances']).should.have.length_of(0) + list(instances["DBInstances"]).should.have.length_of(0) - conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) - conn.create_db_instance(DBInstanceIdentifier='db-master-2', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) + conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + conn.create_db_instance( + DBInstanceIdentifier="db-master-2", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) instances = conn.describe_db_instances() - list(instances['DBInstances']).should.have.length_of(2) + list(instances["DBInstances"]).should.have.length_of(2) instances = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") - list(instances['DBInstances']).should.have.length_of(1) - instances['DBInstances'][0][ - 'DBInstanceIdentifier'].should.equal("db-master-1") - instances['DBInstances'][0]['DBInstanceArn'].should.equal( - 'arn:aws:rds:us-west-2:1234567890:db:db-master-1') + list(instances["DBInstances"]).should.have.length_of(1) + instances["DBInstances"][0]["DBInstanceIdentifier"].should.equal("db-master-1") + instances["DBInstances"][0]["DBInstanceArn"].should.equal( + "arn:aws:rds:us-west-2:1234567890:db:db-master-1" + ) @mock_rds2 def test_get_databases_paginated(): - conn = boto3.client('rds', region_name="us-west-2") + conn = boto3.client("rds", region_name="us-west-2") for i in range(51): - conn.create_db_instance(AllocatedStorage=5, - Port=5432, - DBInstanceIdentifier='rds%d' % i, - DBInstanceClass='db.t1.micro', - Engine='postgres') + conn.create_db_instance( + AllocatedStorage=5, + Port=5432, + DBInstanceIdentifier="rds%d" % i, + DBInstanceClass="db.t1.micro", + Engine="postgres", + ) resp = conn.describe_db_instances() resp["DBInstances"].should.have.length_of(50) - resp["Marker"].should.equal(resp["DBInstances"][-1]['DBInstanceIdentifier']) + resp["Marker"].should.equal(resp["DBInstances"][-1]["DBInstanceIdentifier"]) resp2 = conn.describe_db_instances(Marker=resp["Marker"]) resp2["DBInstances"].should.have.length_of(1) @@ -256,1269 +313,1379 @@ def test_get_databases_paginated(): @mock_rds2 def test_describe_non_existant_database(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.describe_db_instances.when.called_with( - DBInstanceIdentifier="not-a-db").should.throw(ClientError) + DBInstanceIdentifier="not-a-db" + ).should.throw(ClientError) @mock_rds2 def test_modify_db_instance(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) - instances = conn.describe_db_instances(DBInstanceIdentifier='db-master-1') - instances['DBInstances'][0]['AllocatedStorage'].should.equal(10) - conn.modify_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=20, - ApplyImmediately=True, - VpcSecurityGroupIds=['sg-123456']) - instances = conn.describe_db_instances(DBInstanceIdentifier='db-master-1') - instances['DBInstances'][0]['AllocatedStorage'].should.equal(20) - instances['DBInstances'][0]['VpcSecurityGroups'][0]['VpcSecurityGroupId'].should.equal('sg-123456') + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + instances = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") + instances["DBInstances"][0]["AllocatedStorage"].should.equal(10) + conn.modify_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=20, + ApplyImmediately=True, + VpcSecurityGroupIds=["sg-123456"], + ) + instances = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") + instances["DBInstances"][0]["AllocatedStorage"].should.equal(20) + instances["DBInstances"][0]["VpcSecurityGroups"][0][ + "VpcSecurityGroupId" + ].should.equal("sg-123456") @mock_rds2 def test_rename_db_instance(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) instances = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") - list(instances['DBInstances']).should.have.length_of(1) - conn.describe_db_instances.when.called_with(DBInstanceIdentifier="db-master-2").should.throw(ClientError) - conn.modify_db_instance(DBInstanceIdentifier='db-master-1', - NewDBInstanceIdentifier='db-master-2', - ApplyImmediately=True) - conn.describe_db_instances.when.called_with(DBInstanceIdentifier="db-master-1").should.throw(ClientError) + list(instances["DBInstances"]).should.have.length_of(1) + conn.describe_db_instances.when.called_with( + DBInstanceIdentifier="db-master-2" + ).should.throw(ClientError) + conn.modify_db_instance( + DBInstanceIdentifier="db-master-1", + NewDBInstanceIdentifier="db-master-2", + ApplyImmediately=True, + ) + conn.describe_db_instances.when.called_with( + DBInstanceIdentifier="db-master-1" + ).should.throw(ClientError) instances = conn.describe_db_instances(DBInstanceIdentifier="db-master-2") - list(instances['DBInstances']).should.have.length_of(1) + list(instances["DBInstances"]).should.have.length_of(1) @mock_rds2 def test_modify_non_existant_database(): - conn = boto3.client('rds', region_name='us-west-2') - conn.modify_db_instance.when.called_with(DBInstanceIdentifier='not-a-db', - AllocatedStorage=20, - ApplyImmediately=True).should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.modify_db_instance.when.called_with( + DBInstanceIdentifier="not-a-db", AllocatedStorage=20, ApplyImmediately=True + ).should.throw(ClientError) @mock_rds2 def test_reboot_db_instance(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) - database = conn.reboot_db_instance(DBInstanceIdentifier='db-master-1') - database['DBInstance']['DBInstanceIdentifier'].should.equal("db-master-1") + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + database = conn.reboot_db_instance(DBInstanceIdentifier="db-master-1") + database["DBInstance"]["DBInstanceIdentifier"].should.equal("db-master-1") @mock_rds2 def test_reboot_non_existant_database(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.reboot_db_instance.when.called_with( - DBInstanceIdentifier="not-a-db").should.throw(ClientError) + DBInstanceIdentifier="not-a-db" + ).should.throw(ClientError) @mock_rds2 def test_delete_database(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") instances = conn.describe_db_instances() - list(instances['DBInstances']).should.have.length_of(0) - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg']) + list(instances["DBInstances"]).should.have.length_of(0) + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) instances = conn.describe_db_instances() - list(instances['DBInstances']).should.have.length_of(1) + list(instances["DBInstances"]).should.have.length_of(1) - conn.delete_db_instance(DBInstanceIdentifier="db-primary-1", - FinalDBSnapshotIdentifier='primary-1-snapshot') + conn.delete_db_instance( + DBInstanceIdentifier="db-primary-1", + FinalDBSnapshotIdentifier="primary-1-snapshot", + ) instances = conn.describe_db_instances() - list(instances['DBInstances']).should.have.length_of(0) + list(instances["DBInstances"]).should.have.length_of(0) # Saved the snapshot - snapshots = conn.describe_db_snapshots(DBInstanceIdentifier="db-primary-1").get('DBSnapshots') - snapshots[0].get('Engine').should.equal('postgres') + snapshots = conn.describe_db_snapshots(DBInstanceIdentifier="db-primary-1").get( + "DBSnapshots" + ) + snapshots[0].get("Engine").should.equal("postgres") @mock_rds2 def test_delete_non_existant_database(): - conn = boto3.client('rds2', region_name="us-west-2") + conn = boto3.client("rds2", region_name="us-west-2") conn.delete_db_instance.when.called_with( - DBInstanceIdentifier="not-a-db").should.throw(ClientError) + DBInstanceIdentifier="not-a-db" + ).should.throw(ClientError) @mock_rds2 def test_create_db_snapshots(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.create_db_snapshot.when.called_with( - DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-1').should.throw(ClientError) + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-1" + ).should.throw(ClientError) - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) - snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='g-1').get('DBSnapshot') + snapshot = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="g-1" + ).get("DBSnapshot") - snapshot.get('Engine').should.equal('postgres') - snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1') - snapshot.get('DBSnapshotIdentifier').should.equal('g-1') - result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshotArn']) - result['TagList'].should.equal([]) + snapshot.get("Engine").should.equal("postgres") + snapshot.get("DBInstanceIdentifier").should.equal("db-primary-1") + snapshot.get("DBSnapshotIdentifier").should.equal("g-1") + result = conn.list_tags_for_resource(ResourceName=snapshot["DBSnapshotArn"]) + result["TagList"].should.equal([]) @mock_rds2 def test_create_db_snapshots_copy_tags(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.create_db_snapshot.when.called_with( - DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-1').should.throw(ClientError) + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-1" + ).should.throw(ClientError) - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"], - CopyTagsToSnapshot=True, - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + CopyTagsToSnapshot=True, + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) - snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='g-1').get('DBSnapshot') + snapshot = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="g-1" + ).get("DBSnapshot") - snapshot.get('Engine').should.equal('postgres') - snapshot.get('DBInstanceIdentifier').should.equal('db-primary-1') - snapshot.get('DBSnapshotIdentifier').should.equal('g-1') - result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshotArn']) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + snapshot.get("Engine").should.equal("postgres") + snapshot.get("DBInstanceIdentifier").should.equal("db-primary-1") + snapshot.get("DBSnapshotIdentifier").should.equal("g-1") + result = conn.list_tags_for_resource(ResourceName=snapshot["DBSnapshotArn"]) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_rds2 def test_describe_db_snapshots(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) - created = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-1').get('DBSnapshot') + created = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-1" + ).get("DBSnapshot") - created.get('Engine').should.equal('postgres') + created.get("Engine").should.equal("postgres") - by_database_id = conn.describe_db_snapshots(DBInstanceIdentifier='db-primary-1').get('DBSnapshots') - by_snapshot_id = conn.describe_db_snapshots(DBSnapshotIdentifier='snapshot-1').get('DBSnapshots') + by_database_id = conn.describe_db_snapshots( + DBInstanceIdentifier="db-primary-1" + ).get("DBSnapshots") + by_snapshot_id = conn.describe_db_snapshots(DBSnapshotIdentifier="snapshot-1").get( + "DBSnapshots" + ) by_snapshot_id.should.equal(by_database_id) snapshot = by_snapshot_id[0] snapshot.should.equal(created) - snapshot.get('Engine').should.equal('postgres') + snapshot.get("Engine").should.equal("postgres") - conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-2') - snapshots = conn.describe_db_snapshots(DBInstanceIdentifier='db-primary-1').get('DBSnapshots') + conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-2" + ) + snapshots = conn.describe_db_snapshots(DBInstanceIdentifier="db-primary-1").get( + "DBSnapshots" + ) snapshots.should.have.length_of(2) @mock_rds2 def test_delete_db_snapshot(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-1') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", DBSnapshotIdentifier="snapshot-1" + ) - conn.describe_db_snapshots(DBSnapshotIdentifier='snapshot-1').get('DBSnapshots')[0] - conn.delete_db_snapshot(DBSnapshotIdentifier='snapshot-1') + conn.describe_db_snapshots(DBSnapshotIdentifier="snapshot-1").get("DBSnapshots")[0] + conn.delete_db_snapshot(DBSnapshotIdentifier="snapshot-1") conn.describe_db_snapshots.when.called_with( - DBSnapshotIdentifier='snapshot-1').should.throw(ClientError) + DBSnapshotIdentifier="snapshot-1" + ).should.throw(ClientError) @mock_rds2 def test_create_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - option_group = conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') - option_group['OptionGroup']['OptionGroupName'].should.equal('test') - option_group['OptionGroup']['EngineName'].should.equal('mysql') - option_group['OptionGroup'][ - 'OptionGroupDescription'].should.equal('test option group') - option_group['OptionGroup']['MajorEngineVersion'].should.equal('5.6') + conn = boto3.client("rds", region_name="us-west-2") + option_group = conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + option_group["OptionGroup"]["OptionGroupName"].should.equal("test") + option_group["OptionGroup"]["EngineName"].should.equal("mysql") + option_group["OptionGroup"]["OptionGroupDescription"].should.equal( + "test option group" + ) + option_group["OptionGroup"]["MajorEngineVersion"].should.equal("5.6") @mock_rds2 def test_create_option_group_bad_engine_name(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group.when.called_with(OptionGroupName='test', - EngineName='invalid_engine', - MajorEngineVersion='5.6', - OptionGroupDescription='test invalid engine').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group.when.called_with( + OptionGroupName="test", + EngineName="invalid_engine", + MajorEngineVersion="5.6", + OptionGroupDescription="test invalid engine", + ).should.throw(ClientError) @mock_rds2 def test_create_option_group_bad_engine_major_version(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group.when.called_with(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='6.6.6', - OptionGroupDescription='test invalid engine version').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group.when.called_with( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="6.6.6", + OptionGroupDescription="test invalid engine version", + ).should.throw(ClientError) @mock_rds2 def test_create_option_group_empty_description(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group.when.called_with(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group.when.called_with( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="", + ).should.throw(ClientError) @mock_rds2 def test_create_option_group_duplicate(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') - conn.create_option_group.when.called_with(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + conn.create_option_group.when.called_with( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ).should.throw(ClientError) @mock_rds2 def test_describe_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') - option_groups = conn.describe_option_groups(OptionGroupName='test') - option_groups['OptionGroupsList'][0][ - 'OptionGroupName'].should.equal('test') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + option_groups = conn.describe_option_groups(OptionGroupName="test") + option_groups["OptionGroupsList"][0]["OptionGroupName"].should.equal("test") @mock_rds2 def test_describe_non_existant_option_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.describe_option_groups.when.called_with( - OptionGroupName="not-a-option-group").should.throw(ClientError) + OptionGroupName="not-a-option-group" + ).should.throw(ClientError) @mock_rds2 def test_delete_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') - option_groups = conn.describe_option_groups(OptionGroupName='test') - option_groups['OptionGroupsList'][0][ - 'OptionGroupName'].should.equal('test') - conn.delete_option_group(OptionGroupName='test') - conn.describe_option_groups.when.called_with( - OptionGroupName='test').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + option_groups = conn.describe_option_groups(OptionGroupName="test") + option_groups["OptionGroupsList"][0]["OptionGroupName"].should.equal("test") + conn.delete_option_group(OptionGroupName="test") + conn.describe_option_groups.when.called_with(OptionGroupName="test").should.throw( + ClientError + ) @mock_rds2 def test_delete_non_existant_option_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.delete_option_group.when.called_with( - OptionGroupName='non-existant').should.throw(ClientError) + OptionGroupName="non-existant" + ).should.throw(ClientError) @mock_rds2 def test_describe_option_group_options(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") + option_group_options = conn.describe_option_group_options(EngineName="sqlserver-ee") + len(option_group_options["OptionGroupOptions"]).should.equal(4) option_group_options = conn.describe_option_group_options( - EngineName='sqlserver-ee') - len(option_group_options['OptionGroupOptions']).should.equal(4) + EngineName="sqlserver-ee", MajorEngineVersion="11.00" + ) + len(option_group_options["OptionGroupOptions"]).should.equal(2) option_group_options = conn.describe_option_group_options( - EngineName='sqlserver-ee', MajorEngineVersion='11.00') - len(option_group_options['OptionGroupOptions']).should.equal(2) - option_group_options = conn.describe_option_group_options( - EngineName='mysql', MajorEngineVersion='5.6') - len(option_group_options['OptionGroupOptions']).should.equal(1) + EngineName="mysql", MajorEngineVersion="5.6" + ) + len(option_group_options["OptionGroupOptions"]).should.equal(1) conn.describe_option_group_options.when.called_with( - EngineName='non-existent').should.throw(ClientError) + EngineName="non-existent" + ).should.throw(ClientError) conn.describe_option_group_options.when.called_with( - EngineName='mysql', MajorEngineVersion='non-existent').should.throw(ClientError) + EngineName="mysql", MajorEngineVersion="non-existent" + ).should.throw(ClientError) @mock_rds2 def test_modify_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', EngineName='mysql', - MajorEngineVersion='5.6', OptionGroupDescription='test option group') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) # TODO: create option and validate before deleting. # if Someone can tell me how the hell to use this function # to add options to an option_group, I can finish coding this. - result = conn.modify_option_group(OptionGroupName='test', OptionsToInclude=[ - ], OptionsToRemove=['MEMCACHED'], ApplyImmediately=True) - result['OptionGroup']['EngineName'].should.equal('mysql') - result['OptionGroup']['Options'].should.equal([]) - result['OptionGroup']['OptionGroupName'].should.equal('test') + result = conn.modify_option_group( + OptionGroupName="test", + OptionsToInclude=[], + OptionsToRemove=["MEMCACHED"], + ApplyImmediately=True, + ) + result["OptionGroup"]["EngineName"].should.equal("mysql") + result["OptionGroup"]["Options"].should.equal([]) + result["OptionGroup"]["OptionGroupName"].should.equal("test") @mock_rds2 def test_modify_option_group_no_options(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', EngineName='mysql', - MajorEngineVersion='5.6', OptionGroupDescription='test option group') - conn.modify_option_group.when.called_with( - OptionGroupName='test').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) + conn.modify_option_group.when.called_with(OptionGroupName="test").should.throw( + ClientError + ) @mock_rds2 def test_modify_non_existant_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.modify_option_group.when.called_with(OptionGroupName='non-existant', OptionsToInclude=[( - 'OptionName', 'Port', 'DBSecurityGroupMemberships', 'VpcSecurityGroupMemberships', 'OptionSettings')]).should.throw(ParamValidationError) + conn = boto3.client("rds", region_name="us-west-2") + conn.modify_option_group.when.called_with( + OptionGroupName="non-existant", + OptionsToInclude=[ + ( + "OptionName", + "Port", + "DBSecurityGroupMemberships", + "VpcSecurityGroupMemberships", + "OptionSettings", + ) + ], + ).should.throw(ParamValidationError) @mock_rds2 def test_delete_non_existant_database(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.delete_db_instance.when.called_with( - DBInstanceIdentifier="not-a-db").should.throw(ClientError) + DBInstanceIdentifier="not-a-db" + ).should.throw(ClientError) @mock_rds2 def test_list_tags_invalid_arn(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.list_tags_for_resource.when.called_with( - ResourceName='arn:aws:rds:bad-arn').should.throw(ClientError) + ResourceName="arn:aws:rds:bad-arn" + ).should.throw(ClientError) @mock_rds2 def test_list_tags_db(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:foo') - result['TagList'].should.equal([]) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:foo" + ) + result["TagList"].should.equal([]) test_instance = conn.create_db_instance( - DBInstanceIdentifier='db-with-tags', + DBInstanceIdentifier="db-with-tags", AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", Port=1234, - DBSecurityGroups=['my_sg'], - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + DBSecurityGroups=["my_sg"], + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) result = conn.list_tags_for_resource( - ResourceName=test_instance['DBInstance']['DBInstanceArn']) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + ResourceName=test_instance["DBInstance"]["DBInstanceArn"] + ) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_rds2 def test_add_tags_db(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-without-tags', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg'], - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-without-tags", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-without-tags') - list(result['TagList']).should.have.length_of(2) - conn.add_tags_to_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-without-tags', - Tags=[ - { - 'Key': 'foo', - 'Value': 'fish', - }, - { - 'Key': 'foo2', - 'Value': 'bar2', - }, - ]) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-without-tags" + ) + list(result["TagList"]).should.have.length_of(2) + conn.add_tags_to_resource( + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-without-tags", + Tags=[{"Key": "foo", "Value": "fish"}, {"Key": "foo2", "Value": "bar2"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-without-tags') - list(result['TagList']).should.have.length_of(3) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-without-tags" + ) + list(result["TagList"]).should.have.length_of(3) @mock_rds2 def test_remove_tags_db(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-with-tags', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=['my_sg'], - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-with-tags", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-with-tags') - list(result['TagList']).should.have.length_of(2) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-with-tags" + ) + list(result["TagList"]).should.have.length_of(2) conn.remove_tags_from_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-with-tags', TagKeys=['foo']) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-with-tags", TagKeys=["foo"] + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:db:db-with-tags') - len(result['TagList']).should.equal(1) + ResourceName="arn:aws:rds:us-west-2:1234567890:db:db-with-tags" + ) + len(result["TagList"]).should.equal(1) @mock_rds2 def test_list_tags_snapshot(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:foo') - result['TagList'].should.equal([]) - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-with-tags', - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) - result = conn.list_tags_for_resource(ResourceName=snapshot['DBSnapshot']['DBSnapshotArn']) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:foo" + ) + result["TagList"].should.equal([]) + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + snapshot = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", + DBSnapshotIdentifier="snapshot-with-tags", + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) + result = conn.list_tags_for_resource( + ResourceName=snapshot["DBSnapshot"]["DBSnapshotArn"] + ) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_rds2 def test_add_tags_snapshot(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-without-tags', - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + snapshot = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", + DBSnapshotIdentifier="snapshot-without-tags", + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags') - list(result['TagList']).should.have.length_of(2) - conn.add_tags_to_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags', - Tags=[ - { - 'Key': 'foo', - 'Value': 'fish', - }, - { - 'Key': 'foo2', - 'Value': 'bar2', - }, - ]) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags" + ) + list(result["TagList"]).should.have.length_of(2) + conn.add_tags_to_resource( + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags", + Tags=[{"Key": "foo", "Value": "fish"}, {"Key": "foo2", "Value": "bar2"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags') - list(result['TagList']).should.have.length_of(3) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-without-tags" + ) + list(result["TagList"]).should.have.length_of(3) @mock_rds2 def test_remove_tags_snapshot(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_instance(DBInstanceIdentifier='db-primary-1', - AllocatedStorage=10, - Engine='postgres', - DBName='staging-postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) - snapshot = conn.create_db_snapshot(DBInstanceIdentifier='db-primary-1', - DBSnapshotIdentifier='snapshot-with-tags', - Tags=[ - { - 'Key': 'foo', - 'Value': 'bar', - }, - { - 'Key': 'foo1', - 'Value': 'bar1', - }, - ]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_instance( + DBInstanceIdentifier="db-primary-1", + AllocatedStorage=10, + Engine="postgres", + DBName="staging-postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) + snapshot = conn.create_db_snapshot( + DBInstanceIdentifier="db-primary-1", + DBSnapshotIdentifier="snapshot-with-tags", + Tags=[{"Key": "foo", "Value": "bar"}, {"Key": "foo1", "Value": "bar1"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags') - list(result['TagList']).should.have.length_of(2) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags" + ) + list(result["TagList"]).should.have.length_of(2) conn.remove_tags_from_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags', TagKeys=['foo']) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags", + TagKeys=["foo"], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags') - len(result['TagList']).should.equal(1) + ResourceName="arn:aws:rds:us-west-2:1234567890:snapshot:snapshot-with-tags" + ) + len(result["TagList"]).should.equal(1) @mock_rds2 def test_add_tags_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:og:test') - list(result['TagList']).should.have.length_of(0) - conn.add_tags_to_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:og:test', - Tags=[ - { - 'Key': 'foo', - 'Value': 'fish', - }, - { - 'Key': 'foo2', - 'Value': 'bar2', - }]) + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test" + ) + list(result["TagList"]).should.have.length_of(0) + conn.add_tags_to_resource( + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test", + Tags=[{"Key": "foo", "Value": "fish"}, {"Key": "foo2", "Value": "bar2"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:og:test') - list(result['TagList']).should.have.length_of(2) + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test" + ) + list(result["TagList"]).should.have.length_of(2) @mock_rds2 def test_remove_tags_option_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_option_group(OptionGroupName='test', - EngineName='mysql', - MajorEngineVersion='5.6', - OptionGroupDescription='test option group') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_option_group( + OptionGroupName="test", + EngineName="mysql", + MajorEngineVersion="5.6", + OptionGroupDescription="test option group", + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:og:test') - conn.add_tags_to_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:og:test', - Tags=[ - { - 'Key': 'foo', - 'Value': 'fish', - }, - { - 'Key': 'foo2', - 'Value': 'bar2', - }]) + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test" + ) + conn.add_tags_to_resource( + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test", + Tags=[{"Key": "foo", "Value": "fish"}, {"Key": "foo2", "Value": "bar2"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:og:test') - list(result['TagList']).should.have.length_of(2) - conn.remove_tags_from_resource(ResourceName='arn:aws:rds:us-west-2:1234567890:og:test', - TagKeys=['foo']) + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test" + ) + list(result["TagList"]).should.have.length_of(2) + conn.remove_tags_from_resource( + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test", TagKeys=["foo"] + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:og:test') - list(result['TagList']).should.have.length_of(1) + ResourceName="arn:aws:rds:us-west-2:1234567890:og:test" + ) + list(result["TagList"]).should.have.length_of(1) @mock_rds2 def test_create_database_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.create_db_security_group( - DBSecurityGroupName='db_sg', DBSecurityGroupDescription='DB Security Group') - result['DBSecurityGroup']['DBSecurityGroupName'].should.equal("db_sg") - result['DBSecurityGroup'][ - 'DBSecurityGroupDescription'].should.equal("DB Security Group") - result['DBSecurityGroup']['IPRanges'].should.equal([]) + DBSecurityGroupName="db_sg", DBSecurityGroupDescription="DB Security Group" + ) + result["DBSecurityGroup"]["DBSecurityGroupName"].should.equal("db_sg") + result["DBSecurityGroup"]["DBSecurityGroupDescription"].should.equal( + "DB Security Group" + ) + result["DBSecurityGroup"]["IPRanges"].should.equal([]) @mock_rds2 def test_get_security_groups(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_security_groups() - result['DBSecurityGroups'].should.have.length_of(0) + result["DBSecurityGroups"].should.have.length_of(0) conn.create_db_security_group( - DBSecurityGroupName='db_sg1', DBSecurityGroupDescription='DB Security Group') + DBSecurityGroupName="db_sg1", DBSecurityGroupDescription="DB Security Group" + ) conn.create_db_security_group( - DBSecurityGroupName='db_sg2', DBSecurityGroupDescription='DB Security Group') + DBSecurityGroupName="db_sg2", DBSecurityGroupDescription="DB Security Group" + ) result = conn.describe_db_security_groups() - result['DBSecurityGroups'].should.have.length_of(2) + result["DBSecurityGroups"].should.have.length_of(2) result = conn.describe_db_security_groups(DBSecurityGroupName="db_sg1") - result['DBSecurityGroups'].should.have.length_of(1) - result['DBSecurityGroups'][0]['DBSecurityGroupName'].should.equal("db_sg1") + result["DBSecurityGroups"].should.have.length_of(1) + result["DBSecurityGroups"][0]["DBSecurityGroupName"].should.equal("db_sg1") @mock_rds2 def test_get_non_existant_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.describe_db_security_groups.when.called_with( - DBSecurityGroupName="not-a-sg").should.throw(ClientError) + DBSecurityGroupName="not-a-sg" + ).should.throw(ClientError) @mock_rds2 def test_delete_database_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.create_db_security_group( - DBSecurityGroupName='db_sg', DBSecurityGroupDescription='DB Security Group') + DBSecurityGroupName="db_sg", DBSecurityGroupDescription="DB Security Group" + ) result = conn.describe_db_security_groups() - result['DBSecurityGroups'].should.have.length_of(1) + result["DBSecurityGroups"].should.have.length_of(1) conn.delete_db_security_group(DBSecurityGroupName="db_sg") result = conn.describe_db_security_groups() - result['DBSecurityGroups'].should.have.length_of(0) + result["DBSecurityGroups"].should.have.length_of(0) @mock_rds2 def test_delete_non_existant_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.delete_db_security_group.when.called_with( - DBSecurityGroupName="not-a-db").should.throw(ClientError) + DBSecurityGroupName="not-a-db" + ).should.throw(ClientError) @mock_rds2 def test_security_group_authorize(): - conn = boto3.client('rds', region_name='us-west-2') - security_group = conn.create_db_security_group(DBSecurityGroupName='db_sg', - DBSecurityGroupDescription='DB Security Group') - security_group['DBSecurityGroup']['IPRanges'].should.equal([]) + conn = boto3.client("rds", region_name="us-west-2") + security_group = conn.create_db_security_group( + DBSecurityGroupName="db_sg", DBSecurityGroupDescription="DB Security Group" + ) + security_group["DBSecurityGroup"]["IPRanges"].should.equal([]) - conn.authorize_db_security_group_ingress(DBSecurityGroupName='db_sg', - CIDRIP='10.3.2.45/32') + conn.authorize_db_security_group_ingress( + DBSecurityGroupName="db_sg", CIDRIP="10.3.2.45/32" + ) result = conn.describe_db_security_groups(DBSecurityGroupName="db_sg") - result['DBSecurityGroups'][0]['IPRanges'].should.have.length_of(1) - result['DBSecurityGroups'][0]['IPRanges'].should.equal( - [{'Status': 'authorized', 'CIDRIP': '10.3.2.45/32'}]) + result["DBSecurityGroups"][0]["IPRanges"].should.have.length_of(1) + result["DBSecurityGroups"][0]["IPRanges"].should.equal( + [{"Status": "authorized", "CIDRIP": "10.3.2.45/32"}] + ) - conn.authorize_db_security_group_ingress(DBSecurityGroupName='db_sg', - CIDRIP='10.3.2.46/32') + conn.authorize_db_security_group_ingress( + DBSecurityGroupName="db_sg", CIDRIP="10.3.2.46/32" + ) result = conn.describe_db_security_groups(DBSecurityGroupName="db_sg") - result['DBSecurityGroups'][0]['IPRanges'].should.have.length_of(2) - result['DBSecurityGroups'][0]['IPRanges'].should.equal([ - {'Status': 'authorized', 'CIDRIP': '10.3.2.45/32'}, - {'Status': 'authorized', 'CIDRIP': '10.3.2.46/32'}, - ]) + result["DBSecurityGroups"][0]["IPRanges"].should.have.length_of(2) + result["DBSecurityGroups"][0]["IPRanges"].should.equal( + [ + {"Status": "authorized", "CIDRIP": "10.3.2.45/32"}, + {"Status": "authorized", "CIDRIP": "10.3.2.46/32"}, + ] + ) @mock_rds2 def test_add_security_group_to_database(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") - conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - DBInstanceClass='postgres', - Engine='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234) + conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + DBInstanceClass="postgres", + Engine="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + ) result = conn.describe_db_instances() - result['DBInstances'][0]['DBSecurityGroups'].should.equal([]) - conn.create_db_security_group(DBSecurityGroupName='db_sg', - DBSecurityGroupDescription='DB Security Group') - conn.modify_db_instance(DBInstanceIdentifier='db-master-1', - DBSecurityGroups=['db_sg']) + result["DBInstances"][0]["DBSecurityGroups"].should.equal([]) + conn.create_db_security_group( + DBSecurityGroupName="db_sg", DBSecurityGroupDescription="DB Security Group" + ) + conn.modify_db_instance( + DBInstanceIdentifier="db-master-1", DBSecurityGroups=["db_sg"] + ) result = conn.describe_db_instances() - result['DBInstances'][0]['DBSecurityGroups'][0][ - 'DBSecurityGroupName'].should.equal('db_sg') + result["DBInstances"][0]["DBSecurityGroups"][0]["DBSecurityGroupName"].should.equal( + "db_sg" + ) @mock_rds2 def test_list_tags_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - security_group = conn.create_db_security_group(DBSecurityGroupName="db_sg", - DBSecurityGroupDescription='DB Security Group', - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}])['DBSecurityGroup']['DBSecurityGroupName'] - resource = 'arn:aws:rds:us-west-2:1234567890:secgrp:{0}'.format( - security_group) + security_group = conn.create_db_security_group( + DBSecurityGroupName="db_sg", + DBSecurityGroupDescription="DB Security Group", + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + )["DBSecurityGroup"]["DBSecurityGroupName"] + resource = "arn:aws:rds:us-west-2:1234567890:secgrp:{0}".format(security_group) result = conn.list_tags_for_resource(ResourceName=resource) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_rds2 def test_add_tags_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - security_group = conn.create_db_security_group(DBSecurityGroupName="db_sg", - DBSecurityGroupDescription='DB Security Group')['DBSecurityGroup']['DBSecurityGroupName'] + security_group = conn.create_db_security_group( + DBSecurityGroupName="db_sg", DBSecurityGroupDescription="DB Security Group" + )["DBSecurityGroup"]["DBSecurityGroupName"] - resource = 'arn:aws:rds:us-west-2:1234567890:secgrp:{0}'.format( - security_group) - conn.add_tags_to_resource(ResourceName=resource, - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + resource = "arn:aws:rds:us-west-2:1234567890:secgrp:{0}".format(security_group) + conn.add_tags_to_resource( + ResourceName=resource, + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + ) result = conn.list_tags_for_resource(ResourceName=resource) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_rds2 def test_remove_tags_security_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - security_group = conn.create_db_security_group(DBSecurityGroupName="db_sg", - DBSecurityGroupDescription='DB Security Group', - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}])['DBSecurityGroup']['DBSecurityGroupName'] + security_group = conn.create_db_security_group( + DBSecurityGroupName="db_sg", + DBSecurityGroupDescription="DB Security Group", + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + )["DBSecurityGroup"]["DBSecurityGroupName"] - resource = 'arn:aws:rds:us-west-2:1234567890:secgrp:{0}'.format( - security_group) - conn.remove_tags_from_resource(ResourceName=resource, TagKeys=['foo']) + resource = "arn:aws:rds:us-west-2:1234567890:secgrp:{0}".format(security_group) + conn.remove_tags_from_resource(ResourceName=resource, TagKeys=["foo"]) result = conn.list_tags_for_resource(ResourceName=resource) - result['TagList'].should.equal([{'Value': 'bar1', 'Key': 'foo1'}]) + result["TagList"].should.equal([{"Value": "bar1", "Key": "foo1"}]) @mock_ec2 @mock_rds2 def test_create_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet1 = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] - subnet2 = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.2.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet1 = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] + subnet2 = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.2.0/24")[ + "Subnet" + ] - subnet_ids = [subnet1['SubnetId'], subnet2['SubnetId']] - conn = boto3.client('rds', region_name='us-west-2') - result = conn.create_db_subnet_group(DBSubnetGroupName='db_subnet', - DBSubnetGroupDescription='my db subnet', - SubnetIds=subnet_ids) - result['DBSubnetGroup']['DBSubnetGroupName'].should.equal("db_subnet") - result['DBSubnetGroup'][ - 'DBSubnetGroupDescription'].should.equal("my db subnet") - subnets = result['DBSubnetGroup']['Subnets'] - subnet_group_ids = [subnets[0]['SubnetIdentifier'], - subnets[1]['SubnetIdentifier']] + subnet_ids = [subnet1["SubnetId"], subnet2["SubnetId"]] + conn = boto3.client("rds", region_name="us-west-2") + result = conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet", + DBSubnetGroupDescription="my db subnet", + SubnetIds=subnet_ids, + ) + result["DBSubnetGroup"]["DBSubnetGroupName"].should.equal("db_subnet") + result["DBSubnetGroup"]["DBSubnetGroupDescription"].should.equal("my db subnet") + subnets = result["DBSubnetGroup"]["Subnets"] + subnet_group_ids = [subnets[0]["SubnetIdentifier"], subnets[1]["SubnetIdentifier"]] list(subnet_group_ids).should.equal(subnet_ids) @mock_ec2 @mock_rds2 def test_create_database_in_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_subnet_group(DBSubnetGroupName='db_subnet1', - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']]) - conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSubnetGroupName='db_subnet1') - result = conn.describe_db_instances(DBInstanceIdentifier='db-master-1') - result['DBInstances'][0]['DBSubnetGroup'][ - 'DBSubnetGroupName'].should.equal('db_subnet1') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + ) + conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSubnetGroupName="db_subnet1", + ) + result = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") + result["DBInstances"][0]["DBSubnetGroup"]["DBSubnetGroupName"].should.equal( + "db_subnet1" + ) @mock_ec2 @mock_rds2 def test_describe_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_subnet_group(DBSubnetGroupName="db_subnet1", - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']]) - conn.create_db_subnet_group(DBSubnetGroupName='db_subnet2', - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + ) + conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet2", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + ) resp = conn.describe_db_subnet_groups() - resp['DBSubnetGroups'].should.have.length_of(2) + resp["DBSubnetGroups"].should.have.length_of(2) - subnets = resp['DBSubnetGroups'][0]['Subnets'] + subnets = resp["DBSubnetGroups"][0]["Subnets"] subnets.should.have.length_of(1) - list(conn.describe_db_subnet_groups(DBSubnetGroupName="db_subnet1") - ['DBSubnetGroups']).should.have.length_of(1) + list( + conn.describe_db_subnet_groups(DBSubnetGroupName="db_subnet1")["DBSubnetGroups"] + ).should.have.length_of(1) conn.describe_db_subnet_groups.when.called_with( - DBSubnetGroupName="not-a-subnet").should.throw(ClientError) + DBSubnetGroupName="not-a-subnet" + ).should.throw(ClientError) @mock_ec2 @mock_rds2 def test_delete_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - conn.create_db_subnet_group(DBSubnetGroupName="db_subnet1", - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']]) + conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + ) result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(1) + result["DBSubnetGroups"].should.have.length_of(1) conn.delete_db_subnet_group(DBSubnetGroupName="db_subnet1") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) conn.delete_db_subnet_group.when.called_with( - DBSubnetGroupName="db_subnet1").should.throw(ClientError) + DBSubnetGroupName="db_subnet1" + ).should.throw(ClientError) @mock_ec2 @mock_rds2 def test_list_tags_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - subnet = conn.create_db_subnet_group(DBSubnetGroupName="db_subnet1", - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']], - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}])['DBSubnetGroup']['DBSubnetGroupName'] + subnet = conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + )["DBSubnetGroup"]["DBSubnetGroupName"] result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:subgrp:{0}'.format(subnet)) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + ResourceName="arn:aws:rds:us-west-2:1234567890:subgrp:{0}".format(subnet) + ) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_ec2 @mock_rds2 def test_add_tags_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - subnet = conn.create_db_subnet_group(DBSubnetGroupName="db_subnet1", - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']], - Tags=[])['DBSubnetGroup']['DBSubnetGroupName'] - resource = 'arn:aws:rds:us-west-2:1234567890:subgrp:{0}'.format(subnet) + subnet = conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + Tags=[], + )["DBSubnetGroup"]["DBSubnetGroupName"] + resource = "arn:aws:rds:us-west-2:1234567890:subgrp:{0}".format(subnet) - conn.add_tags_to_resource(ResourceName=resource, - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + conn.add_tags_to_resource( + ResourceName=resource, + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + ) result = conn.list_tags_for_resource(ResourceName=resource) - result['TagList'].should.equal([{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}]) + result["TagList"].should.equal( + [{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}] + ) @mock_ec2 @mock_rds2 def test_remove_tags_database_subnet_group(): - vpc_conn = boto3.client('ec2', 'us-west-2') - vpc = vpc_conn.create_vpc(CidrBlock='10.0.0.0/16')['Vpc'] - subnet = vpc_conn.create_subnet( - VpcId=vpc['VpcId'], CidrBlock='10.0.1.0/24')['Subnet'] + vpc_conn = boto3.client("ec2", "us-west-2") + vpc = vpc_conn.create_vpc(CidrBlock="10.0.0.0/16")["Vpc"] + subnet = vpc_conn.create_subnet(VpcId=vpc["VpcId"], CidrBlock="10.0.1.0/24")[ + "Subnet" + ] - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") result = conn.describe_db_subnet_groups() - result['DBSubnetGroups'].should.have.length_of(0) + result["DBSubnetGroups"].should.have.length_of(0) - subnet = conn.create_db_subnet_group(DBSubnetGroupName="db_subnet1", - DBSubnetGroupDescription='my db subnet', - SubnetIds=[subnet['SubnetId']], - Tags=[{'Value': 'bar', - 'Key': 'foo'}, - {'Value': 'bar1', - 'Key': 'foo1'}])['DBSubnetGroup']['DBSubnetGroupName'] - resource = 'arn:aws:rds:us-west-2:1234567890:subgrp:{0}'.format(subnet) + subnet = conn.create_db_subnet_group( + DBSubnetGroupName="db_subnet1", + DBSubnetGroupDescription="my db subnet", + SubnetIds=[subnet["SubnetId"]], + Tags=[{"Value": "bar", "Key": "foo"}, {"Value": "bar1", "Key": "foo1"}], + )["DBSubnetGroup"]["DBSubnetGroupName"] + resource = "arn:aws:rds:us-west-2:1234567890:subgrp:{0}".format(subnet) - conn.remove_tags_from_resource(ResourceName=resource, TagKeys=['foo']) + conn.remove_tags_from_resource(ResourceName=resource, TagKeys=["foo"]) result = conn.list_tags_for_resource(ResourceName=resource) - result['TagList'].should.equal([{'Value': 'bar1', 'Key': 'foo1'}]) + result["TagList"].should.equal([{"Value": "bar1", "Key": "foo1"}]) @mock_rds2 def test_create_database_replica(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"]) + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + ) - replica = conn.create_db_instance_read_replica(DBInstanceIdentifier="db-replica-1", - SourceDBInstanceIdentifier="db-master-1", - DBInstanceClass="db.m1.small") - replica['DBInstance'][ - 'ReadReplicaSourceDBInstanceIdentifier'].should.equal('db-master-1') - replica['DBInstance']['DBInstanceClass'].should.equal('db.m1.small') - replica['DBInstance']['DBInstanceIdentifier'].should.equal('db-replica-1') + replica = conn.create_db_instance_read_replica( + DBInstanceIdentifier="db-replica-1", + SourceDBInstanceIdentifier="db-master-1", + DBInstanceClass="db.m1.small", + ) + replica["DBInstance"]["ReadReplicaSourceDBInstanceIdentifier"].should.equal( + "db-master-1" + ) + replica["DBInstance"]["DBInstanceClass"].should.equal("db.m1.small") + replica["DBInstance"]["DBInstanceIdentifier"].should.equal("db-replica-1") master = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") - master['DBInstances'][0]['ReadReplicaDBInstanceIdentifiers'].should.equal([ - 'db-replica-1']) + master["DBInstances"][0]["ReadReplicaDBInstanceIdentifiers"].should.equal( + ["db-replica-1"] + ) - conn.delete_db_instance( - DBInstanceIdentifier="db-replica-1", SkipFinalSnapshot=True) + conn.delete_db_instance(DBInstanceIdentifier="db-replica-1", SkipFinalSnapshot=True) master = conn.describe_db_instances(DBInstanceIdentifier="db-master-1") - master['DBInstances'][0][ - 'ReadReplicaDBInstanceIdentifiers'].should.equal([]) + master["DBInstances"][0]["ReadReplicaDBInstanceIdentifiers"].should.equal([]) @mock_rds2 @mock_kms def test_create_database_with_encrypted_storage(): - kms_conn = boto3.client('kms', region_name='us-west-2') - key = kms_conn.create_key(Policy='my RDS encryption policy', - Description='RDS encryption key', - KeyUsage='ENCRYPT_DECRYPT') + kms_conn = boto3.client("kms", region_name="us-west-2") + key = kms_conn.create_key( + Policy="my RDS encryption policy", + Description="RDS encryption key", + KeyUsage="ENCRYPT_DECRYPT", + ) - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234, - DBSecurityGroups=["my_sg"], - StorageEncrypted=True, - KmsKeyId=key['KeyMetadata']['KeyId']) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + DBSecurityGroups=["my_sg"], + StorageEncrypted=True, + KmsKeyId=key["KeyMetadata"]["KeyId"], + ) - database['DBInstance']['StorageEncrypted'].should.equal(True) - database['DBInstance']['KmsKeyId'].should.equal( - key['KeyMetadata']['KeyId']) + database["DBInstance"]["StorageEncrypted"].should.equal(True) + database["DBInstance"]["KmsKeyId"].should.equal(key["KeyMetadata"]["KeyId"]) @mock_rds2 def test_create_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - db_parameter_group = conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') + conn = boto3.client("rds", region_name="us-west-2") + db_parameter_group = conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) - db_parameter_group['DBParameterGroup'][ - 'DBParameterGroupName'].should.equal('test') - db_parameter_group['DBParameterGroup'][ - 'DBParameterGroupFamily'].should.equal('mysql5.6') - db_parameter_group['DBParameterGroup'][ - 'Description'].should.equal('test parameter group') + db_parameter_group["DBParameterGroup"]["DBParameterGroupName"].should.equal("test") + db_parameter_group["DBParameterGroup"]["DBParameterGroupFamily"].should.equal( + "mysql5.6" + ) + db_parameter_group["DBParameterGroup"]["Description"].should.equal( + "test parameter group" + ) @mock_rds2 def test_create_db_instance_with_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - db_parameter_group = conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') + conn = boto3.client("rds", region_name="us-west-2") + db_parameter_group = conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='mysql', - DBInstanceClass='db.m1.small', - DBParameterGroupName='test', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234) + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="mysql", + DBInstanceClass="db.m1.small", + DBParameterGroupName="test", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + ) - len(database['DBInstance']['DBParameterGroups']).should.equal(1) - database['DBInstance']['DBParameterGroups'][0][ - 'DBParameterGroupName'].should.equal('test') - database['DBInstance']['DBParameterGroups'][0][ - 'ParameterApplyStatus'].should.equal('in-sync') + len(database["DBInstance"]["DBParameterGroups"]).should.equal(1) + database["DBInstance"]["DBParameterGroups"][0]["DBParameterGroupName"].should.equal( + "test" + ) + database["DBInstance"]["DBParameterGroups"][0]["ParameterApplyStatus"].should.equal( + "in-sync" + ) @mock_rds2 def test_create_database_with_default_port(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='postgres', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - DBSecurityGroups=["my_sg"]) - database['DBInstance']['Endpoint']['Port'].should.equal(5432) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="postgres", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + DBSecurityGroups=["my_sg"], + ) + database["DBInstance"]["Endpoint"]["Port"].should.equal(5432) @mock_rds2 def test_modify_db_instance_with_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - database = conn.create_db_instance(DBInstanceIdentifier='db-master-1', - AllocatedStorage=10, - Engine='mysql', - DBInstanceClass='db.m1.small', - MasterUsername='root', - MasterUserPassword='hunter2', - Port=1234) + conn = boto3.client("rds", region_name="us-west-2") + database = conn.create_db_instance( + DBInstanceIdentifier="db-master-1", + AllocatedStorage=10, + Engine="mysql", + DBInstanceClass="db.m1.small", + MasterUsername="root", + MasterUserPassword="hunter2", + Port=1234, + ) - len(database['DBInstance']['DBParameterGroups']).should.equal(1) - database['DBInstance']['DBParameterGroups'][0][ - 'DBParameterGroupName'].should.equal('default.mysql5.6') - database['DBInstance']['DBParameterGroups'][0][ - 'ParameterApplyStatus'].should.equal('in-sync') + len(database["DBInstance"]["DBParameterGroups"]).should.equal(1) + database["DBInstance"]["DBParameterGroups"][0]["DBParameterGroupName"].should.equal( + "default.mysql5.6" + ) + database["DBInstance"]["DBParameterGroups"][0]["ParameterApplyStatus"].should.equal( + "in-sync" + ) - db_parameter_group = conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') - conn.modify_db_instance(DBInstanceIdentifier='db-master-1', - DBParameterGroupName='test', - ApplyImmediately=True) + db_parameter_group = conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) + conn.modify_db_instance( + DBInstanceIdentifier="db-master-1", + DBParameterGroupName="test", + ApplyImmediately=True, + ) - database = conn.describe_db_instances( - DBInstanceIdentifier='db-master-1')['DBInstances'][0] - len(database['DBParameterGroups']).should.equal(1) - database['DBParameterGroups'][0][ - 'DBParameterGroupName'].should.equal('test') - database['DBParameterGroups'][0][ - 'ParameterApplyStatus'].should.equal('in-sync') + database = conn.describe_db_instances(DBInstanceIdentifier="db-master-1")[ + "DBInstances" + ][0] + len(database["DBParameterGroups"]).should.equal(1) + database["DBParameterGroups"][0]["DBParameterGroupName"].should.equal("test") + database["DBParameterGroups"][0]["ParameterApplyStatus"].should.equal("in-sync") @mock_rds2 def test_create_db_parameter_group_empty_description(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group.when.called_with(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group.when.called_with( + DBParameterGroupName="test", DBParameterGroupFamily="mysql5.6", Description="" + ).should.throw(ClientError) @mock_rds2 def test_create_db_parameter_group_duplicate(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') - conn.create_db_parameter_group.when.called_with(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group').should.throw(ClientError) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) + conn.create_db_parameter_group.when.called_with( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ).should.throw(ClientError) @mock_rds2 def test_describe_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') - db_parameter_groups = conn.describe_db_parameter_groups( - DBParameterGroupName='test') - db_parameter_groups['DBParameterGroups'][0][ - 'DBParameterGroupName'].should.equal('test') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) + db_parameter_groups = conn.describe_db_parameter_groups(DBParameterGroupName="test") + db_parameter_groups["DBParameterGroups"][0]["DBParameterGroupName"].should.equal( + "test" + ) @mock_rds2 def test_describe_non_existant_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - db_parameter_groups = conn.describe_db_parameter_groups( - DBParameterGroupName='test') - len(db_parameter_groups['DBParameterGroups']).should.equal(0) + conn = boto3.client("rds", region_name="us-west-2") + db_parameter_groups = conn.describe_db_parameter_groups(DBParameterGroupName="test") + len(db_parameter_groups["DBParameterGroups"]).should.equal(0) @mock_rds2 def test_delete_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') - db_parameter_groups = conn.describe_db_parameter_groups( - DBParameterGroupName='test') - db_parameter_groups['DBParameterGroups'][0][ - 'DBParameterGroupName'].should.equal('test') - conn.delete_db_parameter_group(DBParameterGroupName='test') - db_parameter_groups = conn.describe_db_parameter_groups( - DBParameterGroupName='test') - len(db_parameter_groups['DBParameterGroups']).should.equal(0) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) + db_parameter_groups = conn.describe_db_parameter_groups(DBParameterGroupName="test") + db_parameter_groups["DBParameterGroups"][0]["DBParameterGroupName"].should.equal( + "test" + ) + conn.delete_db_parameter_group(DBParameterGroupName="test") + db_parameter_groups = conn.describe_db_parameter_groups(DBParameterGroupName="test") + len(db_parameter_groups["DBParameterGroups"]).should.equal(0) @mock_rds2 def test_modify_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group') + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + ) - modify_result = conn.modify_db_parameter_group(DBParameterGroupName='test', - Parameters=[{ - 'ParameterName': 'foo', - 'ParameterValue': 'foo_val', - 'Description': 'test param', - 'ApplyMethod': 'immediate' - }] - ) + modify_result = conn.modify_db_parameter_group( + DBParameterGroupName="test", + Parameters=[ + { + "ParameterName": "foo", + "ParameterValue": "foo_val", + "Description": "test param", + "ApplyMethod": "immediate", + } + ], + ) - modify_result['DBParameterGroupName'].should.equal('test') + modify_result["DBParameterGroupName"].should.equal("test") - db_parameters = conn.describe_db_parameters(DBParameterGroupName='test') - db_parameters['Parameters'][0]['ParameterName'].should.equal('foo') - db_parameters['Parameters'][0]['ParameterValue'].should.equal('foo_val') - db_parameters['Parameters'][0]['Description'].should.equal('test param') - db_parameters['Parameters'][0]['ApplyMethod'].should.equal('immediate') + db_parameters = conn.describe_db_parameters(DBParameterGroupName="test") + db_parameters["Parameters"][0]["ParameterName"].should.equal("foo") + db_parameters["Parameters"][0]["ParameterValue"].should.equal("foo_val") + db_parameters["Parameters"][0]["Description"].should.equal("test param") + db_parameters["Parameters"][0]["ApplyMethod"].should.equal("immediate") @mock_rds2 def test_delete_non_existant_db_parameter_group(): - conn = boto3.client('rds', region_name='us-west-2') + conn = boto3.client("rds", region_name="us-west-2") conn.delete_db_parameter_group.when.called_with( - DBParameterGroupName='non-existant').should.throw(ClientError) + DBParameterGroupName="non-existant" + ).should.throw(ClientError) @mock_rds2 def test_create_parameter_group_with_tags(): - conn = boto3.client('rds', region_name='us-west-2') - conn.create_db_parameter_group(DBParameterGroupName='test', - DBParameterGroupFamily='mysql5.6', - Description='test parameter group', - Tags=[{ - 'Key': 'foo', - 'Value': 'bar', - }]) + conn = boto3.client("rds", region_name="us-west-2") + conn.create_db_parameter_group( + DBParameterGroupName="test", + DBParameterGroupFamily="mysql5.6", + Description="test parameter group", + Tags=[{"Key": "foo", "Value": "bar"}], + ) result = conn.list_tags_for_resource( - ResourceName='arn:aws:rds:us-west-2:1234567890:pg:test') - result['TagList'].should.equal([{'Value': 'bar', 'Key': 'foo'}]) + ResourceName="arn:aws:rds:us-west-2:1234567890:pg:test" + ) + result["TagList"].should.equal([{"Value": "bar", "Key": "foo"}]) diff --git a/tests/test_rds2/test_server.py b/tests/test_rds2/test_server.py index f9489e054..dade82c9c 100644 --- a/tests/test_rds2/test_server.py +++ b/tests/test_rds2/test_server.py @@ -5,12 +5,12 @@ import sure # noqa import moto.server as server from moto import mock_rds2 -''' +""" Test the different server responses -''' +""" -#@mock_rds2 +# @mock_rds2 # def test_list_databases(): # backend = server.create_backend_app("rds2") # test_client = backend.test_client() diff --git a/tests/test_redshift/test_redshift.py b/tests/test_redshift/test_redshift.py index 79e283e5b..528eaa5e0 100644 --- a/tests/test_redshift/test_redshift.py +++ b/tests/test_redshift/test_redshift.py @@ -9,11 +9,9 @@ from boto.redshift.exceptions import ( ClusterParameterGroupNotFound, ClusterSecurityGroupNotFound, ClusterSubnetGroupNotFound, - InvalidSubnet -) -from botocore.exceptions import ( - ClientError + InvalidSubnet, ) +from botocore.exceptions import ClientError import sure # noqa from moto import mock_ec2 @@ -24,83 +22,84 @@ from moto import mock_redshift_deprecated @mock_redshift def test_create_cluster_boto3(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") response = client.create_cluster( - DBName='test', - ClusterIdentifier='test', - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='user', - MasterUserPassword='password', + DBName="test", + ClusterIdentifier="test", + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="user", + MasterUserPassword="password", ) - response['Cluster']['NodeType'].should.equal('ds2.xlarge') - create_time = response['Cluster']['ClusterCreateTime'] + response["Cluster"]["NodeType"].should.equal("ds2.xlarge") + create_time = response["Cluster"]["ClusterCreateTime"] create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo)) - create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1)) - response['Cluster']['EnhancedVpcRouting'].should.equal(False) + create_time.should.be.greater_than( + datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1) + ) + response["Cluster"]["EnhancedVpcRouting"].should.equal(False) + @mock_redshift def test_create_cluster_boto3(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") response = client.create_cluster( - DBName='test', - ClusterIdentifier='test', - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='user', - MasterUserPassword='password', - EnhancedVpcRouting=True + DBName="test", + ClusterIdentifier="test", + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="user", + MasterUserPassword="password", + EnhancedVpcRouting=True, ) - response['Cluster']['NodeType'].should.equal('ds2.xlarge') - create_time = response['Cluster']['ClusterCreateTime'] + response["Cluster"]["NodeType"].should.equal("ds2.xlarge") + create_time = response["Cluster"]["ClusterCreateTime"] create_time.should.be.lower_than(datetime.datetime.now(create_time.tzinfo)) - create_time.should.be.greater_than(datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1)) - response['Cluster']['EnhancedVpcRouting'].should.equal(True) + create_time.should.be.greater_than( + datetime.datetime.now(create_time.tzinfo) - datetime.timedelta(minutes=1) + ) + response["Cluster"]["EnhancedVpcRouting"].should.equal(True) @mock_redshift def test_create_snapshot_copy_grant(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") grants = client.create_snapshot_copy_grant( - SnapshotCopyGrantName='test-us-east-1', - KmsKeyId='fake', + SnapshotCopyGrantName="test-us-east-1", KmsKeyId="fake" ) - grants['SnapshotCopyGrant']['SnapshotCopyGrantName'].should.equal('test-us-east-1') - grants['SnapshotCopyGrant']['KmsKeyId'].should.equal('fake') + grants["SnapshotCopyGrant"]["SnapshotCopyGrantName"].should.equal("test-us-east-1") + grants["SnapshotCopyGrant"]["KmsKeyId"].should.equal("fake") - client.delete_snapshot_copy_grant( - SnapshotCopyGrantName='test-us-east-1', - ) + client.delete_snapshot_copy_grant(SnapshotCopyGrantName="test-us-east-1") client.describe_snapshot_copy_grants.when.called_with( - SnapshotCopyGrantName='test-us-east-1', + SnapshotCopyGrantName="test-us-east-1" ).should.throw(Exception) @mock_redshift def test_create_many_snapshot_copy_grants(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") for i in range(10): client.create_snapshot_copy_grant( - SnapshotCopyGrantName='test-us-east-1-{0}'.format(i), - KmsKeyId='fake', + SnapshotCopyGrantName="test-us-east-1-{0}".format(i), KmsKeyId="fake" ) response = client.describe_snapshot_copy_grants() - len(response['SnapshotCopyGrants']).should.equal(10) + len(response["SnapshotCopyGrants"]).should.equal(10) @mock_redshift def test_no_snapshot_copy_grants(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") response = client.describe_snapshot_copy_grants() - len(response['SnapshotCopyGrants']).should.equal(0) + len(response["SnapshotCopyGrants"]).should.equal(0) @mock_redshift_deprecated def test_create_cluster(): conn = boto.redshift.connect_to_region("us-east-1") - cluster_identifier = 'my_cluster' + cluster_identifier = "my_cluster" cluster_response = conn.create_cluster( cluster_identifier, @@ -117,36 +116,40 @@ def test_create_cluster(): allow_version_upgrade=True, number_of_nodes=3, ) - cluster_response['CreateClusterResponse']['CreateClusterResult'][ - 'Cluster']['ClusterStatus'].should.equal('creating') + cluster_response["CreateClusterResponse"]["CreateClusterResult"]["Cluster"][ + "ClusterStatus" + ].should.equal("creating") cluster_response = conn.describe_clusters(cluster_identifier) - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] - cluster['ClusterIdentifier'].should.equal(cluster_identifier) - cluster['NodeType'].should.equal("dw.hs1.xlarge") - cluster['MasterUsername'].should.equal("username") - cluster['DBName'].should.equal("my_db") - cluster['ClusterSecurityGroups'][0][ - 'ClusterSecurityGroupName'].should.equal("Default") - cluster['VpcSecurityGroups'].should.equal([]) - cluster['ClusterSubnetGroupName'].should.equal(None) - cluster['AvailabilityZone'].should.equal("us-east-1d") - cluster['PreferredMaintenanceWindow'].should.equal("Mon:03:00-Mon:11:00") - cluster['ClusterParameterGroups'][0][ - 'ParameterGroupName'].should.equal("default.redshift-1.0") - cluster['AutomatedSnapshotRetentionPeriod'].should.equal(10) - cluster['Port'].should.equal(1234) - cluster['ClusterVersion'].should.equal("1.0") - cluster['AllowVersionUpgrade'].should.equal(True) - cluster['NumberOfNodes'].should.equal(3) + cluster["ClusterIdentifier"].should.equal(cluster_identifier) + cluster["NodeType"].should.equal("dw.hs1.xlarge") + cluster["MasterUsername"].should.equal("username") + cluster["DBName"].should.equal("my_db") + cluster["ClusterSecurityGroups"][0]["ClusterSecurityGroupName"].should.equal( + "Default" + ) + cluster["VpcSecurityGroups"].should.equal([]) + cluster["ClusterSubnetGroupName"].should.equal(None) + cluster["AvailabilityZone"].should.equal("us-east-1d") + cluster["PreferredMaintenanceWindow"].should.equal("Mon:03:00-Mon:11:00") + cluster["ClusterParameterGroups"][0]["ParameterGroupName"].should.equal( + "default.redshift-1.0" + ) + cluster["AutomatedSnapshotRetentionPeriod"].should.equal(10) + cluster["Port"].should.equal(1234) + cluster["ClusterVersion"].should.equal("1.0") + cluster["AllowVersionUpgrade"].should.equal(True) + cluster["NumberOfNodes"].should.equal(3) @mock_redshift_deprecated def test_create_single_node_cluster(): conn = boto.redshift.connect_to_region("us-east-1") - cluster_identifier = 'my_cluster' + cluster_identifier = "my_cluster" conn.create_cluster( cluster_identifier, @@ -158,20 +161,21 @@ def test_create_single_node_cluster(): ) cluster_response = conn.describe_clusters(cluster_identifier) - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] - cluster['ClusterIdentifier'].should.equal(cluster_identifier) - cluster['NodeType'].should.equal("dw.hs1.xlarge") - cluster['MasterUsername'].should.equal("username") - cluster['DBName'].should.equal("my_db") - cluster['NumberOfNodes'].should.equal(1) + cluster["ClusterIdentifier"].should.equal(cluster_identifier) + cluster["NodeType"].should.equal("dw.hs1.xlarge") + cluster["MasterUsername"].should.equal("username") + cluster["DBName"].should.equal("my_db") + cluster["NumberOfNodes"].should.equal(1) @mock_redshift_deprecated def test_default_cluster_attributes(): conn = boto.redshift.connect_to_region("us-east-1") - cluster_identifier = 'my_cluster' + cluster_identifier = "my_cluster" conn.create_cluster( cluster_identifier, @@ -181,29 +185,31 @@ def test_default_cluster_attributes(): ) cluster_response = conn.describe_clusters(cluster_identifier) - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] - cluster['DBName'].should.equal("dev") - cluster['ClusterSubnetGroupName'].should.equal(None) - assert "us-east-" in cluster['AvailabilityZone'] - cluster['PreferredMaintenanceWindow'].should.equal("Mon:03:00-Mon:03:30") - cluster['ClusterParameterGroups'][0][ - 'ParameterGroupName'].should.equal("default.redshift-1.0") - cluster['AutomatedSnapshotRetentionPeriod'].should.equal(1) - cluster['Port'].should.equal(5439) - cluster['ClusterVersion'].should.equal("1.0") - cluster['AllowVersionUpgrade'].should.equal(True) - cluster['NumberOfNodes'].should.equal(1) + cluster["DBName"].should.equal("dev") + cluster["ClusterSubnetGroupName"].should.equal(None) + assert "us-east-" in cluster["AvailabilityZone"] + cluster["PreferredMaintenanceWindow"].should.equal("Mon:03:00-Mon:03:30") + cluster["ClusterParameterGroups"][0]["ParameterGroupName"].should.equal( + "default.redshift-1.0" + ) + cluster["AutomatedSnapshotRetentionPeriod"].should.equal(1) + cluster["Port"].should.equal(5439) + cluster["ClusterVersion"].should.equal("1.0") + cluster["AllowVersionUpgrade"].should.equal(True) + cluster["NumberOfNodes"].should.equal(1) @mock_redshift @mock_ec2 def test_create_cluster_in_subnet_group(): - ec2 = boto3.resource('ec2', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24") - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster_subnet_group( ClusterSubnetGroupName="my_subnet_group", Description="This is my subnet group", @@ -215,25 +221,25 @@ def test_create_cluster_in_subnet_group(): NodeType="dw.hs1.xlarge", MasterUsername="username", MasterUserPassword="password", - ClusterSubnetGroupName='my_subnet_group', + ClusterSubnetGroupName="my_subnet_group", ) cluster_response = client.describe_clusters(ClusterIdentifier="my_cluster") - cluster = cluster_response['Clusters'][0] - cluster['ClusterSubnetGroupName'].should.equal('my_subnet_group') + cluster = cluster_response["Clusters"][0] + cluster["ClusterSubnetGroupName"].should.equal("my_subnet_group") @mock_redshift @mock_ec2 def test_create_cluster_in_subnet_group_boto3(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') - subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock='10.0.0.0/24') - client = boto3.client('redshift', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24") + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster_subnet_group( - ClusterSubnetGroupName='my_subnet_group', - Description='This is my subnet group', - SubnetIds=[subnet.id] + ClusterSubnetGroupName="my_subnet_group", + Description="This is my subnet group", + SubnetIds=[subnet.id], ) client.create_cluster( @@ -241,46 +247,42 @@ def test_create_cluster_in_subnet_group_boto3(): NodeType="dw.hs1.xlarge", MasterUsername="username", MasterUserPassword="password", - ClusterSubnetGroupName='my_subnet_group', + ClusterSubnetGroupName="my_subnet_group", ) cluster_response = client.describe_clusters(ClusterIdentifier="my_cluster") - cluster = cluster_response['Clusters'][0] - cluster['ClusterSubnetGroupName'].should.equal('my_subnet_group') + cluster = cluster_response["Clusters"][0] + cluster["ClusterSubnetGroupName"].should.equal("my_subnet_group") @mock_redshift_deprecated def test_create_cluster_with_security_group(): conn = boto.redshift.connect_to_region("us-east-1") - conn.create_cluster_security_group( - "security_group1", - "This is my security group", - ) - conn.create_cluster_security_group( - "security_group2", - "This is my security group", - ) + conn.create_cluster_security_group("security_group1", "This is my security group") + conn.create_cluster_security_group("security_group2", "This is my security group") - cluster_identifier = 'my_cluster' + cluster_identifier = "my_cluster" conn.create_cluster( cluster_identifier, node_type="dw.hs1.xlarge", master_username="username", master_user_password="password", - cluster_security_groups=["security_group1", "security_group2"] + cluster_security_groups=["security_group1", "security_group2"], ) cluster_response = conn.describe_clusters(cluster_identifier) - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] - group_names = [group['ClusterSecurityGroupName'] - for group in cluster['ClusterSecurityGroups']] + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] + group_names = [ + group["ClusterSecurityGroupName"] for group in cluster["ClusterSecurityGroups"] + ] set(group_names).should.equal(set(["security_group1", "security_group2"])) @mock_redshift def test_create_cluster_with_security_group_boto3(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster_security_group( ClusterSecurityGroupName="security_group1", Description="This is my security group", @@ -290,18 +292,19 @@ def test_create_cluster_with_security_group_boto3(): Description="This is my security group", ) - cluster_identifier = 'my_cluster' + cluster_identifier = "my_cluster" client.create_cluster( ClusterIdentifier=cluster_identifier, NodeType="dw.hs1.xlarge", MasterUsername="username", MasterUserPassword="password", - ClusterSecurityGroups=["security_group1", "security_group2"] + ClusterSecurityGroups=["security_group1", "security_group2"], ) response = client.describe_clusters(ClusterIdentifier=cluster_identifier) - cluster = response['Clusters'][0] - group_names = [group['ClusterSecurityGroupName'] - for group in cluster['ClusterSecurityGroups']] + cluster = response["Clusters"][0] + group_names = [ + group["ClusterSecurityGroupName"] for group in cluster["ClusterSecurityGroups"] + ] set(group_names).should.equal({"security_group1", "security_group2"}) @@ -313,7 +316,8 @@ def test_create_cluster_with_vpc_security_groups(): redshift_conn = boto.connect_redshift() vpc = vpc_conn.create_vpc("10.0.0.0/16") security_group = ec2_conn.create_security_group( - "vpc_security_group", "a group", vpc_id=vpc.id) + "vpc_security_group", "a group", vpc_id=vpc.id + ) redshift_conn.create_cluster( "my_cluster", @@ -324,24 +328,23 @@ def test_create_cluster_with_vpc_security_groups(): ) cluster_response = redshift_conn.describe_clusters("my_cluster") - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] - group_ids = [group['VpcSecurityGroupId'] - for group in cluster['VpcSecurityGroups']] + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] + group_ids = [group["VpcSecurityGroupId"] for group in cluster["VpcSecurityGroups"]] list(group_ids).should.equal([security_group.id]) @mock_redshift @mock_ec2 def test_create_cluster_with_vpc_security_groups_boto3(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') - client = boto3.client('redshift', region_name='us-east-1') - cluster_id = 'my_cluster' + ec2 = boto3.resource("ec2", region_name="us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + client = boto3.client("redshift", region_name="us-east-1") + cluster_id = "my_cluster" security_group = ec2.create_security_group( - Description="vpc_security_group", - GroupName="a group", - VpcId=vpc.id) + Description="vpc_security_group", GroupName="a group", VpcId=vpc.id + ) client.create_cluster( ClusterIdentifier=cluster_id, NodeType="dw.hs1.xlarge", @@ -350,27 +353,26 @@ def test_create_cluster_with_vpc_security_groups_boto3(): VpcSecurityGroupIds=[security_group.id], ) response = client.describe_clusters(ClusterIdentifier=cluster_id) - cluster = response['Clusters'][0] - group_ids = [group['VpcSecurityGroupId'] - for group in cluster['VpcSecurityGroups']] + cluster = response["Clusters"][0] + group_ids = [group["VpcSecurityGroupId"] for group in cluster["VpcSecurityGroups"]] list(group_ids).should.equal([security_group.id]) @mock_redshift def test_create_cluster_with_iam_roles(): - iam_roles_arn = ['arn:aws:iam:::role/my-iam-role', ] - client = boto3.client('redshift', region_name='us-east-1') - cluster_id = 'my_cluster' + iam_roles_arn = ["arn:aws:iam:::role/my-iam-role"] + client = boto3.client("redshift", region_name="us-east-1") + cluster_id = "my_cluster" client.create_cluster( ClusterIdentifier=cluster_id, NodeType="dw.hs1.xlarge", MasterUsername="username", MasterUserPassword="password", - IamRoles=iam_roles_arn + IamRoles=iam_roles_arn, ) response = client.describe_clusters(ClusterIdentifier=cluster_id) - cluster = response['Clusters'][0] - iam_roles = [role['IamRoleArn'] for role in cluster['IamRoles']] + cluster = response["Clusters"][0] + iam_roles = [role["IamRoleArn"] for role in cluster["IamRoles"]] iam_roles_arn.should.equal(iam_roles) @@ -378,9 +380,7 @@ def test_create_cluster_with_iam_roles(): def test_create_cluster_with_parameter_group(): conn = boto.connect_redshift() conn.create_cluster_parameter_group( - "my_parameter_group", - "redshift-1.0", - "This is my parameter group", + "my_parameter_group", "redshift-1.0", "This is my parameter group" ) conn.create_cluster( @@ -388,21 +388,25 @@ def test_create_cluster_with_parameter_group(): node_type="dw.hs1.xlarge", master_username="username", master_user_password="password", - cluster_parameter_group_name='my_parameter_group', + cluster_parameter_group_name="my_parameter_group", ) cluster_response = conn.describe_clusters("my_cluster") - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] - cluster['ClusterParameterGroups'][0][ - 'ParameterGroupName'].should.equal("my_parameter_group") + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] + cluster["ClusterParameterGroups"][0]["ParameterGroupName"].should.equal( + "my_parameter_group" + ) @mock_redshift_deprecated def test_describe_non_existent_cluster(): conn = boto.redshift.connect_to_region("us-east-1") - conn.describe_clusters.when.called_with( - "not-a-cluster").should.throw(ClusterNotFound) + conn.describe_clusters.when.called_with("not-a-cluster").should.throw( + ClusterNotFound + ) + @mock_redshift_deprecated def test_delete_cluster(): @@ -417,61 +421,68 @@ def test_delete_cluster(): master_user_password="password", ) - conn.delete_cluster.when.called_with(cluster_identifier, False).should.throw(AttributeError) + conn.delete_cluster.when.called_with(cluster_identifier, False).should.throw( + AttributeError + ) - clusters = conn.describe_clusters()['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'] + clusters = conn.describe_clusters()["DescribeClustersResponse"][ + "DescribeClustersResult" + ]["Clusters"] list(clusters).should.have.length_of(1) conn.delete_cluster( cluster_identifier=cluster_identifier, skip_final_cluster_snapshot=False, - final_cluster_snapshot_identifier=snapshot_identifier - ) + final_cluster_snapshot_identifier=snapshot_identifier, + ) - clusters = conn.describe_clusters()['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'] + clusters = conn.describe_clusters()["DescribeClustersResponse"][ + "DescribeClustersResult" + ]["Clusters"] list(clusters).should.have.length_of(0) snapshots = conn.describe_cluster_snapshots()["DescribeClusterSnapshotsResponse"][ - "DescribeClusterSnapshotsResult"]["Snapshots"] + "DescribeClusterSnapshotsResult" + ]["Snapshots"] list(snapshots).should.have.length_of(1) assert snapshot_identifier in snapshots[0]["SnapshotIdentifier"] # Delete invalid id - conn.delete_cluster.when.called_with( - "not-a-cluster").should.throw(ClusterNotFound) + conn.delete_cluster.when.called_with("not-a-cluster").should.throw(ClusterNotFound) @mock_redshift def test_modify_cluster_vpc_routing(): - iam_roles_arn = ['arn:aws:iam:::role/my-iam-role', ] - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' + iam_roles_arn = ["arn:aws:iam:::role/my-iam-role"] + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" client.create_cluster( ClusterIdentifier=cluster_identifier, NodeType="single-node", MasterUsername="username", MasterUserPassword="password", - IamRoles=iam_roles_arn + IamRoles=iam_roles_arn, ) cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier) - cluster = cluster_response['Clusters'][0] - cluster['EnhancedVpcRouting'].should.equal(False) + cluster = cluster_response["Clusters"][0] + cluster["EnhancedVpcRouting"].should.equal(False) - client.create_cluster_security_group(ClusterSecurityGroupName='security_group', - Description='security_group') + client.create_cluster_security_group( + ClusterSecurityGroupName="security_group", Description="security_group" + ) - client.create_cluster_parameter_group(ParameterGroupName='my_parameter_group', - ParameterGroupFamily='redshift-1.0', - Description='my_parameter_group') + client.create_cluster_parameter_group( + ParameterGroupName="my_parameter_group", + ParameterGroupFamily="redshift-1.0", + Description="my_parameter_group", + ) client.modify_cluster( ClusterIdentifier=cluster_identifier, - ClusterType='multi-node', + ClusterType="multi-node", NodeType="ds2.8xlarge", NumberOfNodes=3, ClusterSecurityGroups=["security_group"], @@ -481,45 +492,42 @@ def test_modify_cluster_vpc_routing(): PreferredMaintenanceWindow="Tue:03:00-Tue:11:00", AllowVersionUpgrade=False, NewClusterIdentifier=cluster_identifier, - EnhancedVpcRouting=True + EnhancedVpcRouting=True, ) cluster_response = client.describe_clusters(ClusterIdentifier=cluster_identifier) - cluster = cluster_response['Clusters'][0] - cluster['ClusterIdentifier'].should.equal(cluster_identifier) - cluster['NodeType'].should.equal("ds2.8xlarge") - cluster['PreferredMaintenanceWindow'].should.equal("Tue:03:00-Tue:11:00") - cluster['AutomatedSnapshotRetentionPeriod'].should.equal(7) - cluster['AllowVersionUpgrade'].should.equal(False) + cluster = cluster_response["Clusters"][0] + cluster["ClusterIdentifier"].should.equal(cluster_identifier) + cluster["NodeType"].should.equal("ds2.8xlarge") + cluster["PreferredMaintenanceWindow"].should.equal("Tue:03:00-Tue:11:00") + cluster["AutomatedSnapshotRetentionPeriod"].should.equal(7) + cluster["AllowVersionUpgrade"].should.equal(False) # This one should remain unmodified. - cluster['NumberOfNodes'].should.equal(3) - cluster['EnhancedVpcRouting'].should.equal(True) + cluster["NumberOfNodes"].should.equal(3) + cluster["EnhancedVpcRouting"].should.equal(True) @mock_redshift_deprecated def test_modify_cluster(): conn = boto.connect_redshift() - cluster_identifier = 'my_cluster' - conn.create_cluster_security_group( - "security_group", - "This is my security group", - ) + cluster_identifier = "my_cluster" + conn.create_cluster_security_group("security_group", "This is my security group") conn.create_cluster_parameter_group( - "my_parameter_group", - "redshift-1.0", - "This is my parameter group", + "my_parameter_group", "redshift-1.0", "This is my parameter group" ) conn.create_cluster( cluster_identifier, - node_type='single-node', + node_type="single-node", master_username="username", master_user_password="password", ) cluster_response = conn.describe_clusters(cluster_identifier) - cluster = cluster_response['DescribeClustersResponse']['DescribeClustersResult']['Clusters'][0] - cluster['EnhancedVpcRouting'].should.equal(False) + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] + cluster["EnhancedVpcRouting"].should.equal(False) conn.modify_cluster( cluster_identifier, @@ -535,44 +543,47 @@ def test_modify_cluster(): ) cluster_response = conn.describe_clusters(cluster_identifier) - cluster = cluster_response['DescribeClustersResponse'][ - 'DescribeClustersResult']['Clusters'][0] - cluster['ClusterIdentifier'].should.equal(cluster_identifier) - cluster['NodeType'].should.equal("dw.hs1.xlarge") - cluster['ClusterSecurityGroups'][0][ - 'ClusterSecurityGroupName'].should.equal("security_group") - cluster['PreferredMaintenanceWindow'].should.equal("Tue:03:00-Tue:11:00") - cluster['ClusterParameterGroups'][0][ - 'ParameterGroupName'].should.equal("my_parameter_group") - cluster['AutomatedSnapshotRetentionPeriod'].should.equal(7) - cluster['AllowVersionUpgrade'].should.equal(False) + cluster = cluster_response["DescribeClustersResponse"]["DescribeClustersResult"][ + "Clusters" + ][0] + cluster["ClusterIdentifier"].should.equal(cluster_identifier) + cluster["NodeType"].should.equal("dw.hs1.xlarge") + cluster["ClusterSecurityGroups"][0]["ClusterSecurityGroupName"].should.equal( + "security_group" + ) + cluster["PreferredMaintenanceWindow"].should.equal("Tue:03:00-Tue:11:00") + cluster["ClusterParameterGroups"][0]["ParameterGroupName"].should.equal( + "my_parameter_group" + ) + cluster["AutomatedSnapshotRetentionPeriod"].should.equal(7) + cluster["AllowVersionUpgrade"].should.equal(False) # This one should remain unmodified. - cluster['NumberOfNodes'].should.equal(1) + cluster["NumberOfNodes"].should.equal(1) @mock_redshift @mock_ec2 def test_create_cluster_subnet_group(): - ec2 = boto3.resource('ec2', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet1 = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24") subnet2 = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.1.0/24") - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster_subnet_group( - ClusterSubnetGroupName='my_subnet_group', - Description='This is my subnet group', + ClusterSubnetGroupName="my_subnet_group", + Description="This is my subnet group", SubnetIds=[subnet1.id, subnet2.id], ) subnets_response = client.describe_cluster_subnet_groups( - ClusterSubnetGroupName="my_subnet_group") - my_subnet = subnets_response['ClusterSubnetGroups'][0] + ClusterSubnetGroupName="my_subnet_group" + ) + my_subnet = subnets_response["ClusterSubnetGroups"][0] - my_subnet['ClusterSubnetGroupName'].should.equal("my_subnet_group") - my_subnet['Description'].should.equal("This is my subnet group") - subnet_ids = [subnet['SubnetIdentifier'] - for subnet in my_subnet['Subnets']] + my_subnet["ClusterSubnetGroupName"].should.equal("my_subnet_group") + my_subnet["Description"].should.equal("This is my subnet group") + subnet_ids = [subnet["SubnetIdentifier"] for subnet in my_subnet["Subnets"]] set(subnet_ids).should.equal(set([subnet1.id, subnet2.id])) @@ -581,9 +592,7 @@ def test_create_cluster_subnet_group(): def test_create_invalid_cluster_subnet_group(): redshift_conn = boto.connect_redshift() redshift_conn.create_cluster_subnet_group.when.called_with( - "my_subnet", - "This is my subnet group", - subnet_ids=["subnet-1234"], + "my_subnet", "This is my subnet group", subnet_ids=["subnet-1234"] ).should.throw(InvalidSubnet) @@ -591,754 +600,733 @@ def test_create_invalid_cluster_subnet_group(): def test_describe_non_existent_subnet_group(): conn = boto.redshift.connect_to_region("us-east-1") conn.describe_cluster_subnet_groups.when.called_with( - "not-a-subnet-group").should.throw(ClusterSubnetGroupNotFound) + "not-a-subnet-group" + ).should.throw(ClusterSubnetGroupNotFound) @mock_redshift @mock_ec2 def test_delete_cluster_subnet_group(): - ec2 = boto3.resource('ec2', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24") - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster_subnet_group( - ClusterSubnetGroupName='my_subnet_group', - Description='This is my subnet group', + ClusterSubnetGroupName="my_subnet_group", + Description="This is my subnet group", SubnetIds=[subnet.id], ) subnets_response = client.describe_cluster_subnet_groups() - subnets = subnets_response['ClusterSubnetGroups'] + subnets = subnets_response["ClusterSubnetGroups"] subnets.should.have.length_of(1) client.delete_cluster_subnet_group(ClusterSubnetGroupName="my_subnet_group") subnets_response = client.describe_cluster_subnet_groups() - subnets = subnets_response['ClusterSubnetGroups'] + subnets = subnets_response["ClusterSubnetGroups"] subnets.should.have.length_of(0) # Delete invalid id client.delete_cluster_subnet_group.when.called_with( - ClusterSubnetGroupName="not-a-subnet-group").should.throw(ClientError) + ClusterSubnetGroupName="not-a-subnet-group" + ).should.throw(ClientError) @mock_redshift_deprecated def test_create_cluster_security_group(): conn = boto.connect_redshift() - conn.create_cluster_security_group( - "my_security_group", - "This is my security group", - ) + conn.create_cluster_security_group("my_security_group", "This is my security group") - groups_response = conn.describe_cluster_security_groups( - "my_security_group") - my_group = groups_response['DescribeClusterSecurityGroupsResponse'][ - 'DescribeClusterSecurityGroupsResult']['ClusterSecurityGroups'][0] + groups_response = conn.describe_cluster_security_groups("my_security_group") + my_group = groups_response["DescribeClusterSecurityGroupsResponse"][ + "DescribeClusterSecurityGroupsResult" + ]["ClusterSecurityGroups"][0] - my_group['ClusterSecurityGroupName'].should.equal("my_security_group") - my_group['Description'].should.equal("This is my security group") - list(my_group['IPRanges']).should.equal([]) + my_group["ClusterSecurityGroupName"].should.equal("my_security_group") + my_group["Description"].should.equal("This is my security group") + list(my_group["IPRanges"]).should.equal([]) @mock_redshift_deprecated def test_describe_non_existent_security_group(): conn = boto.redshift.connect_to_region("us-east-1") conn.describe_cluster_security_groups.when.called_with( - "not-a-security-group").should.throw(ClusterSecurityGroupNotFound) + "not-a-security-group" + ).should.throw(ClusterSecurityGroupNotFound) @mock_redshift_deprecated def test_delete_cluster_security_group(): conn = boto.connect_redshift() - conn.create_cluster_security_group( - "my_security_group", - "This is my security group", - ) + conn.create_cluster_security_group("my_security_group", "This is my security group") groups_response = conn.describe_cluster_security_groups() - groups = groups_response['DescribeClusterSecurityGroupsResponse'][ - 'DescribeClusterSecurityGroupsResult']['ClusterSecurityGroups'] + groups = groups_response["DescribeClusterSecurityGroupsResponse"][ + "DescribeClusterSecurityGroupsResult" + ]["ClusterSecurityGroups"] groups.should.have.length_of(2) # The default group already exists conn.delete_cluster_security_group("my_security_group") groups_response = conn.describe_cluster_security_groups() - groups = groups_response['DescribeClusterSecurityGroupsResponse'][ - 'DescribeClusterSecurityGroupsResult']['ClusterSecurityGroups'] + groups = groups_response["DescribeClusterSecurityGroupsResponse"][ + "DescribeClusterSecurityGroupsResult" + ]["ClusterSecurityGroups"] groups.should.have.length_of(1) # Delete invalid id conn.delete_cluster_security_group.when.called_with( - "not-a-security-group").should.throw(ClusterSecurityGroupNotFound) + "not-a-security-group" + ).should.throw(ClusterSecurityGroupNotFound) @mock_redshift_deprecated def test_create_cluster_parameter_group(): conn = boto.connect_redshift() conn.create_cluster_parameter_group( - "my_parameter_group", - "redshift-1.0", - "This is my parameter group", + "my_parameter_group", "redshift-1.0", "This is my parameter group" ) - groups_response = conn.describe_cluster_parameter_groups( - "my_parameter_group") - my_group = groups_response['DescribeClusterParameterGroupsResponse'][ - 'DescribeClusterParameterGroupsResult']['ParameterGroups'][0] + groups_response = conn.describe_cluster_parameter_groups("my_parameter_group") + my_group = groups_response["DescribeClusterParameterGroupsResponse"][ + "DescribeClusterParameterGroupsResult" + ]["ParameterGroups"][0] - my_group['ParameterGroupName'].should.equal("my_parameter_group") - my_group['ParameterGroupFamily'].should.equal("redshift-1.0") - my_group['Description'].should.equal("This is my parameter group") + my_group["ParameterGroupName"].should.equal("my_parameter_group") + my_group["ParameterGroupFamily"].should.equal("redshift-1.0") + my_group["Description"].should.equal("This is my parameter group") @mock_redshift_deprecated def test_describe_non_existent_parameter_group(): conn = boto.redshift.connect_to_region("us-east-1") conn.describe_cluster_parameter_groups.when.called_with( - "not-a-parameter-group").should.throw(ClusterParameterGroupNotFound) + "not-a-parameter-group" + ).should.throw(ClusterParameterGroupNotFound) @mock_redshift_deprecated def test_delete_cluster_parameter_group(): conn = boto.connect_redshift() conn.create_cluster_parameter_group( - "my_parameter_group", - "redshift-1.0", - "This is my parameter group", + "my_parameter_group", "redshift-1.0", "This is my parameter group" ) groups_response = conn.describe_cluster_parameter_groups() - groups = groups_response['DescribeClusterParameterGroupsResponse'][ - 'DescribeClusterParameterGroupsResult']['ParameterGroups'] + groups = groups_response["DescribeClusterParameterGroupsResponse"][ + "DescribeClusterParameterGroupsResult" + ]["ParameterGroups"] groups.should.have.length_of(2) # The default group already exists conn.delete_cluster_parameter_group("my_parameter_group") groups_response = conn.describe_cluster_parameter_groups() - groups = groups_response['DescribeClusterParameterGroupsResponse'][ - 'DescribeClusterParameterGroupsResult']['ParameterGroups'] + groups = groups_response["DescribeClusterParameterGroupsResponse"][ + "DescribeClusterParameterGroupsResult" + ]["ParameterGroups"] groups.should.have.length_of(1) # Delete invalid id conn.delete_cluster_parameter_group.when.called_with( - "not-a-parameter-group").should.throw(ClusterParameterGroupNotFound) + "not-a-parameter-group" + ).should.throw(ClusterParameterGroupNotFound) @mock_redshift def test_create_cluster_snapshot_of_non_existent_cluster(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'non-existent-cluster-id' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "non-existent-cluster-id" client.create_cluster_snapshot.when.called_with( - SnapshotIdentifier='snapshot-id', - ClusterIdentifier=cluster_identifier, - ).should.throw(ClientError, 'Cluster {} not found.'.format(cluster_identifier)) + SnapshotIdentifier="snapshot-id", ClusterIdentifier=cluster_identifier + ).should.throw(ClientError, "Cluster {} not found.".format(cluster_identifier)) @mock_redshift def test_create_cluster_snapshot(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - snapshot_identifier = 'my_snapshot' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + snapshot_identifier = "my_snapshot" cluster_response = client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', - EnhancedVpcRouting=True + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + EnhancedVpcRouting=True, ) - cluster_response['Cluster']['NodeType'].should.equal('ds2.xlarge') + cluster_response["Cluster"]["NodeType"].should.equal("ds2.xlarge") snapshot_response = client.create_cluster_snapshot( SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier, - Tags=[{'Key': 'test-tag-key', - 'Value': 'test-tag-value'}] + Tags=[{"Key": "test-tag-key", "Value": "test-tag-value"}], ) - snapshot = snapshot_response['Snapshot'] - snapshot['SnapshotIdentifier'].should.equal(snapshot_identifier) - snapshot['ClusterIdentifier'].should.equal(cluster_identifier) - snapshot['NumberOfNodes'].should.equal(1) - snapshot['NodeType'].should.equal('ds2.xlarge') - snapshot['MasterUsername'].should.equal('username') + snapshot = snapshot_response["Snapshot"] + snapshot["SnapshotIdentifier"].should.equal(snapshot_identifier) + snapshot["ClusterIdentifier"].should.equal(cluster_identifier) + snapshot["NumberOfNodes"].should.equal(1) + snapshot["NodeType"].should.equal("ds2.xlarge") + snapshot["MasterUsername"].should.equal("username") @mock_redshift def test_describe_cluster_snapshots(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - snapshot_identifier_1 = 'my_snapshot_1' - snapshot_identifier_2 = 'my_snapshot_2' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + snapshot_identifier_1 = "my_snapshot_1" + snapshot_identifier_2 = "my_snapshot_2" client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", ) client.create_cluster_snapshot( - SnapshotIdentifier=snapshot_identifier_1, - ClusterIdentifier=cluster_identifier, + SnapshotIdentifier=snapshot_identifier_1, ClusterIdentifier=cluster_identifier ) client.create_cluster_snapshot( - SnapshotIdentifier=snapshot_identifier_2, - ClusterIdentifier=cluster_identifier, + SnapshotIdentifier=snapshot_identifier_2, ClusterIdentifier=cluster_identifier ) - resp_snap_1 = client.describe_cluster_snapshots(SnapshotIdentifier=snapshot_identifier_1) - snapshot_1 = resp_snap_1['Snapshots'][0] - snapshot_1['SnapshotIdentifier'].should.equal(snapshot_identifier_1) - snapshot_1['ClusterIdentifier'].should.equal(cluster_identifier) - snapshot_1['NumberOfNodes'].should.equal(1) - snapshot_1['NodeType'].should.equal('ds2.xlarge') - snapshot_1['MasterUsername'].should.equal('username') + resp_snap_1 = client.describe_cluster_snapshots( + SnapshotIdentifier=snapshot_identifier_1 + ) + snapshot_1 = resp_snap_1["Snapshots"][0] + snapshot_1["SnapshotIdentifier"].should.equal(snapshot_identifier_1) + snapshot_1["ClusterIdentifier"].should.equal(cluster_identifier) + snapshot_1["NumberOfNodes"].should.equal(1) + snapshot_1["NodeType"].should.equal("ds2.xlarge") + snapshot_1["MasterUsername"].should.equal("username") - resp_snap_2 = client.describe_cluster_snapshots(SnapshotIdentifier=snapshot_identifier_2) - snapshot_2 = resp_snap_2['Snapshots'][0] - snapshot_2['SnapshotIdentifier'].should.equal(snapshot_identifier_2) - snapshot_2['ClusterIdentifier'].should.equal(cluster_identifier) - snapshot_2['NumberOfNodes'].should.equal(1) - snapshot_2['NodeType'].should.equal('ds2.xlarge') - snapshot_2['MasterUsername'].should.equal('username') + resp_snap_2 = client.describe_cluster_snapshots( + SnapshotIdentifier=snapshot_identifier_2 + ) + snapshot_2 = resp_snap_2["Snapshots"][0] + snapshot_2["SnapshotIdentifier"].should.equal(snapshot_identifier_2) + snapshot_2["ClusterIdentifier"].should.equal(cluster_identifier) + snapshot_2["NumberOfNodes"].should.equal(1) + snapshot_2["NodeType"].should.equal("ds2.xlarge") + snapshot_2["MasterUsername"].should.equal("username") resp_clust = client.describe_cluster_snapshots(ClusterIdentifier=cluster_identifier) - resp_clust['Snapshots'][0].should.equal(resp_snap_1['Snapshots'][0]) - resp_clust['Snapshots'][1].should.equal(resp_snap_2['Snapshots'][0]) + resp_clust["Snapshots"][0].should.equal(resp_snap_1["Snapshots"][0]) + resp_clust["Snapshots"][1].should.equal(resp_snap_2["Snapshots"][0]) @mock_redshift def test_describe_cluster_snapshots_not_found_error(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - snapshot_identifier = 'my_snapshot' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + snapshot_identifier = "my_snapshot" client.describe_cluster_snapshots.when.called_with( - ClusterIdentifier=cluster_identifier, - ).should.throw(ClientError, 'Cluster {} not found.'.format(cluster_identifier)) + ClusterIdentifier=cluster_identifier + ).should.throw(ClientError, "Cluster {} not found.".format(cluster_identifier)) client.describe_cluster_snapshots.when.called_with( SnapshotIdentifier=snapshot_identifier - ).should.throw(ClientError, 'Snapshot {} not found.'.format(snapshot_identifier)) + ).should.throw(ClientError, "Snapshot {} not found.".format(snapshot_identifier)) @mock_redshift def test_delete_cluster_snapshot(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - snapshot_identifier = 'my_snapshot' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + snapshot_identifier = "my_snapshot" client.create_cluster( ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", ) client.create_cluster_snapshot( - SnapshotIdentifier=snapshot_identifier, - ClusterIdentifier=cluster_identifier + SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier ) - snapshots = client.describe_cluster_snapshots()['Snapshots'] + snapshots = client.describe_cluster_snapshots()["Snapshots"] list(snapshots).should.have.length_of(1) - client.delete_cluster_snapshot(SnapshotIdentifier=snapshot_identifier)[ - 'Snapshot']['Status'].should.equal('deleted') + client.delete_cluster_snapshot(SnapshotIdentifier=snapshot_identifier)["Snapshot"][ + "Status" + ].should.equal("deleted") - snapshots = client.describe_cluster_snapshots()['Snapshots'] + snapshots = client.describe_cluster_snapshots()["Snapshots"] list(snapshots).should.have.length_of(0) # Delete invalid id client.delete_cluster_snapshot.when.called_with( - SnapshotIdentifier="not-a-snapshot").should.throw(ClientError) + SnapshotIdentifier="not-a-snapshot" + ).should.throw(ClientError) @mock_redshift def test_cluster_snapshot_already_exists(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - snapshot_identifier = 'my_snapshot' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + snapshot_identifier = "my_snapshot" client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", ) client.create_cluster_snapshot( - SnapshotIdentifier=snapshot_identifier, - ClusterIdentifier=cluster_identifier + SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier ) client.create_cluster_snapshot.when.called_with( - SnapshotIdentifier=snapshot_identifier, - ClusterIdentifier=cluster_identifier + SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier ).should.throw(ClientError) @mock_redshift def test_create_cluster_from_snapshot(): - client = boto3.client('redshift', region_name='us-east-1') - original_cluster_identifier = 'original-cluster' - original_snapshot_identifier = 'original-snapshot' - new_cluster_identifier = 'new-cluster' + client = boto3.client("redshift", region_name="us-east-1") + original_cluster_identifier = "original-cluster" + original_snapshot_identifier = "original-snapshot" + new_cluster_identifier = "new-cluster" client.create_cluster( ClusterIdentifier=original_cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", EnhancedVpcRouting=True, ) client.create_cluster_snapshot( SnapshotIdentifier=original_snapshot_identifier, - ClusterIdentifier=original_cluster_identifier + ClusterIdentifier=original_cluster_identifier, ) response = client.restore_from_cluster_snapshot( ClusterIdentifier=new_cluster_identifier, SnapshotIdentifier=original_snapshot_identifier, - Port=1234 + Port=1234, ) - response['Cluster']['ClusterStatus'].should.equal('creating') + response["Cluster"]["ClusterStatus"].should.equal("creating") + + response = client.describe_clusters(ClusterIdentifier=new_cluster_identifier) + new_cluster = response["Clusters"][0] + new_cluster["NodeType"].should.equal("ds2.xlarge") + new_cluster["MasterUsername"].should.equal("username") + new_cluster["Endpoint"]["Port"].should.equal(1234) + new_cluster["EnhancedVpcRouting"].should.equal(True) - response = client.describe_clusters( - ClusterIdentifier=new_cluster_identifier - ) - new_cluster = response['Clusters'][0] - new_cluster['NodeType'].should.equal('ds2.xlarge') - new_cluster['MasterUsername'].should.equal('username') - new_cluster['Endpoint']['Port'].should.equal(1234) - new_cluster['EnhancedVpcRouting'].should.equal(True) @mock_redshift def test_create_cluster_from_snapshot_with_waiter(): - client = boto3.client('redshift', region_name='us-east-1') - original_cluster_identifier = 'original-cluster' - original_snapshot_identifier = 'original-snapshot' - new_cluster_identifier = 'new-cluster' + client = boto3.client("redshift", region_name="us-east-1") + original_cluster_identifier = "original-cluster" + original_snapshot_identifier = "original-snapshot" + new_cluster_identifier = "new-cluster" client.create_cluster( ClusterIdentifier=original_cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', - EnhancedVpcRouting=True + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + EnhancedVpcRouting=True, ) client.create_cluster_snapshot( SnapshotIdentifier=original_snapshot_identifier, - ClusterIdentifier=original_cluster_identifier + ClusterIdentifier=original_cluster_identifier, ) response = client.restore_from_cluster_snapshot( ClusterIdentifier=new_cluster_identifier, SnapshotIdentifier=original_snapshot_identifier, - Port=1234 + Port=1234, ) - response['Cluster']['ClusterStatus'].should.equal('creating') + response["Cluster"]["ClusterStatus"].should.equal("creating") - client.get_waiter('cluster_restored').wait( + client.get_waiter("cluster_restored").wait( ClusterIdentifier=new_cluster_identifier, - WaiterConfig={ - 'Delay': 1, - 'MaxAttempts': 2, - } + WaiterConfig={"Delay": 1, "MaxAttempts": 2}, ) - response = client.describe_clusters( - ClusterIdentifier=new_cluster_identifier - ) - new_cluster = response['Clusters'][0] - new_cluster['NodeType'].should.equal('ds2.xlarge') - new_cluster['MasterUsername'].should.equal('username') - new_cluster['EnhancedVpcRouting'].should.equal(True) - new_cluster['Endpoint']['Port'].should.equal(1234) + response = client.describe_clusters(ClusterIdentifier=new_cluster_identifier) + new_cluster = response["Clusters"][0] + new_cluster["NodeType"].should.equal("ds2.xlarge") + new_cluster["MasterUsername"].should.equal("username") + new_cluster["EnhancedVpcRouting"].should.equal(True) + new_cluster["Endpoint"]["Port"].should.equal(1234) @mock_redshift def test_create_cluster_from_non_existent_snapshot(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.restore_from_cluster_snapshot.when.called_with( - ClusterIdentifier='cluster-id', - SnapshotIdentifier='non-existent-snapshot', - ).should.throw(ClientError, 'Snapshot non-existent-snapshot not found.') + ClusterIdentifier="cluster-id", SnapshotIdentifier="non-existent-snapshot" + ).should.throw(ClientError, "Snapshot non-existent-snapshot not found.") @mock_redshift def test_create_cluster_status_update(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'test-cluster' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "test-cluster" response = client.create_cluster( ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", ) - response['Cluster']['ClusterStatus'].should.equal('creating') + response["Cluster"]["ClusterStatus"].should.equal("creating") - response = client.describe_clusters( - ClusterIdentifier=cluster_identifier - ) - response['Clusters'][0]['ClusterStatus'].should.equal('available') + response = client.describe_clusters(ClusterIdentifier=cluster_identifier) + response["Clusters"][0]["ClusterStatus"].should.equal("available") @mock_redshift def test_describe_tags_with_resource_type(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'my_cluster' - cluster_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'cluster:{}'.format(cluster_identifier) - snapshot_identifier = 'my_snapshot' - snapshot_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'snapshot:{}/{}'.format(cluster_identifier, - snapshot_identifier) - tag_key = 'test-tag-key' - tag_value = 'test-tag-value' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "my_cluster" + cluster_arn = "arn:aws:redshift:us-east-1:123456789012:" "cluster:{}".format( + cluster_identifier + ) + snapshot_identifier = "my_snapshot" + snapshot_arn = "arn:aws:redshift:us-east-1:123456789012:" "snapshot:{}/{}".format( + cluster_identifier, snapshot_identifier + ) + tag_key = "test-tag-key" + tag_value = "test-tag-value" client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', - Tags=[{'Key': tag_key, - 'Value': tag_value}] + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + Tags=[{"Key": tag_key, "Value": tag_value}], ) - tags_response = client.describe_tags(ResourceType='cluster') - tagged_resources = tags_response['TaggedResources'] + tags_response = client.describe_tags(ResourceType="cluster") + tagged_resources = tags_response["TaggedResources"] list(tagged_resources).should.have.length_of(1) - tagged_resources[0]['ResourceType'].should.equal('cluster') - tagged_resources[0]['ResourceName'].should.equal(cluster_arn) - tag = tagged_resources[0]['Tag'] - tag['Key'].should.equal(tag_key) - tag['Value'].should.equal(tag_value) + tagged_resources[0]["ResourceType"].should.equal("cluster") + tagged_resources[0]["ResourceName"].should.equal(cluster_arn) + tag = tagged_resources[0]["Tag"] + tag["Key"].should.equal(tag_key) + tag["Value"].should.equal(tag_value) client.create_cluster_snapshot( SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier, - Tags=[{'Key': tag_key, - 'Value': tag_value}] + Tags=[{"Key": tag_key, "Value": tag_value}], ) - tags_response = client.describe_tags(ResourceType='snapshot') - tagged_resources = tags_response['TaggedResources'] + tags_response = client.describe_tags(ResourceType="snapshot") + tagged_resources = tags_response["TaggedResources"] list(tagged_resources).should.have.length_of(1) - tagged_resources[0]['ResourceType'].should.equal('snapshot') - tagged_resources[0]['ResourceName'].should.equal(snapshot_arn) - tag = tagged_resources[0]['Tag'] - tag['Key'].should.equal(tag_key) - tag['Value'].should.equal(tag_value) + tagged_resources[0]["ResourceType"].should.equal("snapshot") + tagged_resources[0]["ResourceName"].should.equal(snapshot_arn) + tag = tagged_resources[0]["Tag"] + tag["Key"].should.equal(tag_key) + tag["Value"].should.equal(tag_value) @mock_redshift def test_describe_tags_cannot_specify_resource_type_and_resource_name(): - client = boto3.client('redshift', region_name='us-east-1') - resource_name = 'arn:aws:redshift:us-east-1:123456789012:cluster:cluster-id' - resource_type = 'cluster' + client = boto3.client("redshift", region_name="us-east-1") + resource_name = "arn:aws:redshift:us-east-1:123456789012:cluster:cluster-id" + resource_type = "cluster" client.describe_tags.when.called_with( - ResourceName=resource_name, - ResourceType=resource_type - ).should.throw(ClientError, 'using either an ARN or a resource type') + ResourceName=resource_name, ResourceType=resource_type + ).should.throw(ClientError, "using either an ARN or a resource type") @mock_redshift def test_describe_tags_with_resource_name(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'cluster-id' - cluster_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'cluster:{}'.format(cluster_identifier) - snapshot_identifier = 'snapshot-id' - snapshot_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'snapshot:{}/{}'.format(cluster_identifier, - snapshot_identifier) - tag_key = 'test-tag-key' - tag_value = 'test-tag-value' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "cluster-id" + cluster_arn = "arn:aws:redshift:us-east-1:123456789012:" "cluster:{}".format( + cluster_identifier + ) + snapshot_identifier = "snapshot-id" + snapshot_arn = "arn:aws:redshift:us-east-1:123456789012:" "snapshot:{}/{}".format( + cluster_identifier, snapshot_identifier + ) + tag_key = "test-tag-key" + tag_value = "test-tag-value" client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', - Tags=[{'Key': tag_key, - 'Value': tag_value}] + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + Tags=[{"Key": tag_key, "Value": tag_value}], ) tags_response = client.describe_tags(ResourceName=cluster_arn) - tagged_resources = tags_response['TaggedResources'] + tagged_resources = tags_response["TaggedResources"] list(tagged_resources).should.have.length_of(1) - tagged_resources[0]['ResourceType'].should.equal('cluster') - tagged_resources[0]['ResourceName'].should.equal(cluster_arn) - tag = tagged_resources[0]['Tag'] - tag['Key'].should.equal(tag_key) - tag['Value'].should.equal(tag_value) + tagged_resources[0]["ResourceType"].should.equal("cluster") + tagged_resources[0]["ResourceName"].should.equal(cluster_arn) + tag = tagged_resources[0]["Tag"] + tag["Key"].should.equal(tag_key) + tag["Value"].should.equal(tag_value) client.create_cluster_snapshot( SnapshotIdentifier=snapshot_identifier, ClusterIdentifier=cluster_identifier, - Tags=[{'Key': tag_key, - 'Value': tag_value}] + Tags=[{"Key": tag_key, "Value": tag_value}], ) tags_response = client.describe_tags(ResourceName=snapshot_arn) - tagged_resources = tags_response['TaggedResources'] + tagged_resources = tags_response["TaggedResources"] list(tagged_resources).should.have.length_of(1) - tagged_resources[0]['ResourceType'].should.equal('snapshot') - tagged_resources[0]['ResourceName'].should.equal(snapshot_arn) - tag = tagged_resources[0]['Tag'] - tag['Key'].should.equal(tag_key) - tag['Value'].should.equal(tag_value) + tagged_resources[0]["ResourceType"].should.equal("snapshot") + tagged_resources[0]["ResourceName"].should.equal(snapshot_arn) + tag = tagged_resources[0]["Tag"] + tag["Key"].should.equal(tag_key) + tag["Value"].should.equal(tag_value) @mock_redshift def test_create_tags(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'cluster-id' - cluster_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'cluster:{}'.format(cluster_identifier) - tag_key = 'test-tag-key' - tag_value = 'test-tag-value' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "cluster-id" + cluster_arn = "arn:aws:redshift:us-east-1:123456789012:" "cluster:{}".format( + cluster_identifier + ) + tag_key = "test-tag-key" + tag_value = "test-tag-value" num_tags = 5 tags = [] for i in range(0, num_tags): - tag = {'Key': '{}-{}'.format(tag_key, i), - 'Value': '{}-{}'.format(tag_value, i)} + tag = {"Key": "{}-{}".format(tag_key, i), "Value": "{}-{}".format(tag_value, i)} tags.append(tag) client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', - ) - client.create_tags( - ResourceName=cluster_arn, - Tags=tags + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", ) + client.create_tags(ResourceName=cluster_arn, Tags=tags) response = client.describe_clusters(ClusterIdentifier=cluster_identifier) - cluster = response['Clusters'][0] - list(cluster['Tags']).should.have.length_of(num_tags) + cluster = response["Clusters"][0] + list(cluster["Tags"]).should.have.length_of(num_tags) response = client.describe_tags(ResourceName=cluster_arn) - list(response['TaggedResources']).should.have.length_of(num_tags) + list(response["TaggedResources"]).should.have.length_of(num_tags) @mock_redshift def test_delete_tags(): - client = boto3.client('redshift', region_name='us-east-1') - cluster_identifier = 'cluster-id' - cluster_arn = 'arn:aws:redshift:us-east-1:123456789012:' \ - 'cluster:{}'.format(cluster_identifier) - tag_key = 'test-tag-key' - tag_value = 'test-tag-value' + client = boto3.client("redshift", region_name="us-east-1") + cluster_identifier = "cluster-id" + cluster_arn = "arn:aws:redshift:us-east-1:123456789012:" "cluster:{}".format( + cluster_identifier + ) + tag_key = "test-tag-key" + tag_value = "test-tag-value" tags = [] for i in range(1, 2): - tag = {'Key': '{}-{}'.format(tag_key, i), - 'Value': '{}-{}'.format(tag_value, i)} + tag = {"Key": "{}-{}".format(tag_key, i), "Value": "{}-{}".format(tag_value, i)} tags.append(tag) client.create_cluster( - DBName='test-db', + DBName="test-db", ClusterIdentifier=cluster_identifier, - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='username', - MasterUserPassword='password', - Tags=tags + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="username", + MasterUserPassword="password", + Tags=tags, ) client.delete_tags( ResourceName=cluster_arn, - TagKeys=[tag['Key'] for tag in tags - if tag['Key'] != '{}-1'.format(tag_key)] + TagKeys=[tag["Key"] for tag in tags if tag["Key"] != "{}-1".format(tag_key)], ) response = client.describe_clusters(ClusterIdentifier=cluster_identifier) - cluster = response['Clusters'][0] - list(cluster['Tags']).should.have.length_of(1) + cluster = response["Clusters"][0] + list(cluster["Tags"]).should.have.length_of(1) response = client.describe_tags(ResourceName=cluster_arn) - list(response['TaggedResources']).should.have.length_of(1) + list(response["TaggedResources"]).should.have.length_of(1) @mock_ec2 @mock_redshift def test_describe_tags_all_resource_types(): - ec2 = boto3.resource('ec2', region_name='us-east-1') - vpc = ec2.create_vpc(CidrBlock='10.0.0.0/16') - subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock='10.0.0.0/24') - client = boto3.client('redshift', region_name='us-east-1') + ec2 = boto3.resource("ec2", region_name="us-east-1") + vpc = ec2.create_vpc(CidrBlock="10.0.0.0/16") + subnet = ec2.create_subnet(VpcId=vpc.id, CidrBlock="10.0.0.0/24") + client = boto3.client("redshift", region_name="us-east-1") response = client.describe_tags() - list(response['TaggedResources']).should.have.length_of(0) + list(response["TaggedResources"]).should.have.length_of(0) client.create_cluster_subnet_group( - ClusterSubnetGroupName='my_subnet_group', - Description='This is my subnet group', + ClusterSubnetGroupName="my_subnet_group", + Description="This is my subnet group", SubnetIds=[subnet.id], - Tags=[{'Key': 'tag_key', - 'Value': 'tag_value'}] + Tags=[{"Key": "tag_key", "Value": "tag_value"}], ) client.create_cluster_security_group( ClusterSecurityGroupName="security_group1", Description="This is my security group", - Tags=[{'Key': 'tag_key', - 'Value': 'tag_value'}] + Tags=[{"Key": "tag_key", "Value": "tag_value"}], ) client.create_cluster( - DBName='test', - ClusterIdentifier='my_cluster', - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='user', - MasterUserPassword='password', - Tags=[{'Key': 'tag_key', - 'Value': 'tag_value'}] + DBName="test", + ClusterIdentifier="my_cluster", + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="user", + MasterUserPassword="password", + Tags=[{"Key": "tag_key", "Value": "tag_value"}], ) client.create_cluster_snapshot( - SnapshotIdentifier='my_snapshot', - ClusterIdentifier='my_cluster', - Tags=[{'Key': 'tag_key', - 'Value': 'tag_value'}] + SnapshotIdentifier="my_snapshot", + ClusterIdentifier="my_cluster", + Tags=[{"Key": "tag_key", "Value": "tag_value"}], ) client.create_cluster_parameter_group( ParameterGroupName="my_parameter_group", ParameterGroupFamily="redshift-1.0", Description="This is my parameter group", - Tags=[{'Key': 'tag_key', - 'Value': 'tag_value'}] + Tags=[{"Key": "tag_key", "Value": "tag_value"}], ) response = client.describe_tags() - expected_types = ['cluster', 'parametergroup', 'securitygroup', 'snapshot', 'subnetgroup'] - tagged_resources = response['TaggedResources'] - returned_types = [resource['ResourceType'] for resource in tagged_resources] + expected_types = [ + "cluster", + "parametergroup", + "securitygroup", + "snapshot", + "subnetgroup", + ] + tagged_resources = response["TaggedResources"] + returned_types = [resource["ResourceType"] for resource in tagged_resources] list(tagged_resources).should.have.length_of(len(expected_types)) set(returned_types).should.equal(set(expected_types)) @mock_redshift def test_tagged_resource_not_found_error(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") - cluster_arn = 'arn:aws:redshift:us-east-1::cluster:fake' - client.describe_tags.when.called_with( - ResourceName=cluster_arn - ).should.throw(ClientError, 'cluster (fake) not found.') + cluster_arn = "arn:aws:redshift:us-east-1::cluster:fake" + client.describe_tags.when.called_with(ResourceName=cluster_arn).should.throw( + ClientError, "cluster (fake) not found." + ) - snapshot_arn = 'arn:aws:redshift:us-east-1::snapshot:cluster-id/snap-id' + snapshot_arn = "arn:aws:redshift:us-east-1::snapshot:cluster-id/snap-id" client.delete_tags.when.called_with( - ResourceName=snapshot_arn, - TagKeys=['test'] - ).should.throw(ClientError, 'snapshot (snap-id) not found.') + ResourceName=snapshot_arn, TagKeys=["test"] + ).should.throw(ClientError, "snapshot (snap-id) not found.") - client.describe_tags.when.called_with( - ResourceType='cluster' - ).should.throw(ClientError, "resource of type 'cluster' not found.") + client.describe_tags.when.called_with(ResourceType="cluster").should.throw( + ClientError, "resource of type 'cluster' not found." + ) - client.describe_tags.when.called_with( - ResourceName='bad:arn' - ).should.throw(ClientError, "Tagging is not supported for this type of resource") + client.describe_tags.when.called_with(ResourceName="bad:arn").should.throw( + ClientError, "Tagging is not supported for this type of resource" + ) @mock_redshift def test_enable_snapshot_copy(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster( - ClusterIdentifier='test', - ClusterType='single-node', - DBName='test', + ClusterIdentifier="test", + ClusterType="single-node", + DBName="test", Encrypted=True, - MasterUsername='user', - MasterUserPassword='password', - NodeType='ds2.xlarge', + MasterUsername="user", + MasterUserPassword="password", + NodeType="ds2.xlarge", ) client.enable_snapshot_copy( - ClusterIdentifier='test', - DestinationRegion='us-west-2', + ClusterIdentifier="test", + DestinationRegion="us-west-2", RetentionPeriod=3, - SnapshotCopyGrantName='copy-us-east-1-to-us-west-2' + SnapshotCopyGrantName="copy-us-east-1-to-us-west-2", + ) + response = client.describe_clusters(ClusterIdentifier="test") + cluster_snapshot_copy_status = response["Clusters"][0]["ClusterSnapshotCopyStatus"] + cluster_snapshot_copy_status["RetentionPeriod"].should.equal(3) + cluster_snapshot_copy_status["DestinationRegion"].should.equal("us-west-2") + cluster_snapshot_copy_status["SnapshotCopyGrantName"].should.equal( + "copy-us-east-1-to-us-west-2" ) - response = client.describe_clusters(ClusterIdentifier='test') - cluster_snapshot_copy_status = response['Clusters'][0]['ClusterSnapshotCopyStatus'] - cluster_snapshot_copy_status['RetentionPeriod'].should.equal(3) - cluster_snapshot_copy_status['DestinationRegion'].should.equal('us-west-2') - cluster_snapshot_copy_status['SnapshotCopyGrantName'].should.equal('copy-us-east-1-to-us-west-2') @mock_redshift def test_enable_snapshot_copy_unencrypted(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster( - ClusterIdentifier='test', - ClusterType='single-node', - DBName='test', - MasterUsername='user', - MasterUserPassword='password', - NodeType='ds2.xlarge', + ClusterIdentifier="test", + ClusterType="single-node", + DBName="test", + MasterUsername="user", + MasterUserPassword="password", + NodeType="ds2.xlarge", ) - client.enable_snapshot_copy( - ClusterIdentifier='test', - DestinationRegion='us-west-2', - ) - response = client.describe_clusters(ClusterIdentifier='test') - cluster_snapshot_copy_status = response['Clusters'][0]['ClusterSnapshotCopyStatus'] - cluster_snapshot_copy_status['RetentionPeriod'].should.equal(7) - cluster_snapshot_copy_status['DestinationRegion'].should.equal('us-west-2') + client.enable_snapshot_copy(ClusterIdentifier="test", DestinationRegion="us-west-2") + response = client.describe_clusters(ClusterIdentifier="test") + cluster_snapshot_copy_status = response["Clusters"][0]["ClusterSnapshotCopyStatus"] + cluster_snapshot_copy_status["RetentionPeriod"].should.equal(7) + cluster_snapshot_copy_status["DestinationRegion"].should.equal("us-west-2") @mock_redshift def test_disable_snapshot_copy(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster( - DBName='test', - ClusterIdentifier='test', - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='user', - MasterUserPassword='password', + DBName="test", + ClusterIdentifier="test", + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="user", + MasterUserPassword="password", ) client.enable_snapshot_copy( - ClusterIdentifier='test', - DestinationRegion='us-west-2', + ClusterIdentifier="test", + DestinationRegion="us-west-2", RetentionPeriod=3, - SnapshotCopyGrantName='copy-us-east-1-to-us-west-2', + SnapshotCopyGrantName="copy-us-east-1-to-us-west-2", ) - client.disable_snapshot_copy( - ClusterIdentifier='test', - ) - response = client.describe_clusters(ClusterIdentifier='test') - response['Clusters'][0].shouldnt.contain('ClusterSnapshotCopyStatus') + client.disable_snapshot_copy(ClusterIdentifier="test") + response = client.describe_clusters(ClusterIdentifier="test") + response["Clusters"][0].shouldnt.contain("ClusterSnapshotCopyStatus") @mock_redshift def test_modify_snapshot_copy_retention_period(): - client = boto3.client('redshift', region_name='us-east-1') + client = boto3.client("redshift", region_name="us-east-1") client.create_cluster( - DBName='test', - ClusterIdentifier='test', - ClusterType='single-node', - NodeType='ds2.xlarge', - MasterUsername='user', - MasterUserPassword='password', + DBName="test", + ClusterIdentifier="test", + ClusterType="single-node", + NodeType="ds2.xlarge", + MasterUsername="user", + MasterUserPassword="password", ) client.enable_snapshot_copy( - ClusterIdentifier='test', - DestinationRegion='us-west-2', + ClusterIdentifier="test", + DestinationRegion="us-west-2", RetentionPeriod=3, - SnapshotCopyGrantName='copy-us-east-1-to-us-west-2', + SnapshotCopyGrantName="copy-us-east-1-to-us-west-2", ) client.modify_snapshot_copy_retention_period( - ClusterIdentifier='test', - RetentionPeriod=5, + ClusterIdentifier="test", RetentionPeriod=5 ) - response = client.describe_clusters(ClusterIdentifier='test') - cluster_snapshot_copy_status = response['Clusters'][0]['ClusterSnapshotCopyStatus'] - cluster_snapshot_copy_status['RetentionPeriod'].should.equal(5) + response = client.describe_clusters(ClusterIdentifier="test") + cluster_snapshot_copy_status = response["Clusters"][0]["ClusterSnapshotCopyStatus"] + cluster_snapshot_copy_status["RetentionPeriod"].should.equal(5) diff --git a/tests/test_redshift/test_server.py b/tests/test_redshift/test_server.py index c37e9cab7..f4eee85e8 100644 --- a/tests/test_redshift/test_server.py +++ b/tests/test_redshift/test_server.py @@ -6,9 +6,9 @@ import sure # noqa import moto.server as server from moto import mock_redshift -''' +""" Test the different server responses -''' +""" @mock_redshift @@ -16,7 +16,7 @@ def test_describe_clusters(): backend = server.create_backend_app("redshift") test_client = backend.test_client() - res = test_client.get('/?Action=DescribeClusters') + res = test_client.get("/?Action=DescribeClusters") result = res.data.decode("utf-8") result.should.contain("") diff --git a/tests/test_resourcegroups/test_resourcegroups.py b/tests/test_resourcegroups/test_resourcegroups.py index bb3624413..29af9aad7 100644 --- a/tests/test_resourcegroups/test_resourcegroups.py +++ b/tests/test_resourcegroups/test_resourcegroups.py @@ -25,11 +25,13 @@ def test_create_group(): } ), }, - Tags={"resource_group_tag_key": "resource_group_tag_value"} + Tags={"resource_group_tag_key": "resource_group_tag_value"}, ) response["Group"]["Name"].should.contain("test_resource_group") response["ResourceQuery"]["Type"].should.contain("TAG_FILTERS_1_0") - response["Tags"]["resource_group_tag_key"].should.contain("resource_group_tag_value") + response["Tags"]["resource_group_tag_key"].should.contain( + "resource_group_tag_value" + ) @mock_resourcegroups @@ -76,7 +78,9 @@ def test_get_tags(): response = resource_groups.get_tags(Arn=response["Group"]["GroupArn"]) response["Tags"].should.have.length_of(1) - response["Tags"]["resource_group_tag_key"].should.contain("resource_group_tag_value") + response["Tags"]["resource_group_tag_key"].should.contain( + "resource_group_tag_value" + ) return response @@ -100,13 +104,17 @@ def test_tag(): response = resource_groups.tag( Arn=response["Arn"], - Tags={"resource_group_tag_key_2": "resource_group_tag_value_2"} + Tags={"resource_group_tag_key_2": "resource_group_tag_value_2"}, + ) + response["Tags"]["resource_group_tag_key_2"].should.contain( + "resource_group_tag_value_2" ) - response["Tags"]["resource_group_tag_key_2"].should.contain("resource_group_tag_value_2") response = resource_groups.get_tags(Arn=response["Arn"]) response["Tags"].should.have.length_of(2) - response["Tags"]["resource_group_tag_key_2"].should.contain("resource_group_tag_value_2") + response["Tags"]["resource_group_tag_key_2"].should.contain( + "resource_group_tag_value_2" + ) @mock_resourcegroups @@ -115,7 +123,9 @@ def test_untag(): response = test_get_tags() - response = resource_groups.untag(Arn=response["Arn"], Keys=["resource_group_tag_key"]) + response = resource_groups.untag( + Arn=response["Arn"], Keys=["resource_group_tag_key"] + ) response["Keys"].should.contain("resource_group_tag_key") response = resource_groups.get_tags(Arn=response["Arn"]) @@ -129,8 +139,7 @@ def test_update_group(): test_get_group() response = resource_groups.update_group( - GroupName="test_resource_group", - Description="description_2", + GroupName="test_resource_group", Description="description_2" ) response["Group"]["Description"].should.contain("description_2") @@ -154,12 +163,16 @@ def test_update_group_query(): "StackIdentifier": ( "arn:aws:cloudformation:eu-west-1:012345678912:stack/" "test_stack/c223eca0-e744-11e8-8910-500c41f59083" - ) + ), } ), }, ) - response["GroupQuery"]["ResourceQuery"]["Type"].should.contain("CLOUDFORMATION_STACK_1_0") + response["GroupQuery"]["ResourceQuery"]["Type"].should.contain( + "CLOUDFORMATION_STACK_1_0" + ) response = resource_groups.get_group_query(GroupName="test_resource_group") - response["GroupQuery"]["ResourceQuery"]["Type"].should.contain("CLOUDFORMATION_STACK_1_0") + response["GroupQuery"]["ResourceQuery"]["Type"].should.contain( + "CLOUDFORMATION_STACK_1_0" + ) diff --git a/tests/test_resourcegroupstaggingapi/test_resourcegroupstaggingapi.py b/tests/test_resourcegroupstaggingapi/test_resourcegroupstaggingapi.py index 1e42dfe55..84f7a8b86 100644 --- a/tests/test_resourcegroupstaggingapi/test_resourcegroupstaggingapi.py +++ b/tests/test_resourcegroupstaggingapi/test_resourcegroupstaggingapi.py @@ -13,7 +13,7 @@ from moto import mock_s3 @mock_resourcegroupstaggingapi def test_get_resources_s3(): # Tests pagination - s3_client = boto3.client('s3', region_name='eu-central-1') + s3_client = boto3.client("s3", region_name="eu-central-1") # Will end up having key1,key2,key3,key4 response_keys = set() @@ -21,26 +21,25 @@ def test_get_resources_s3(): # Create 4 buckets for i in range(1, 5): i_str = str(i) - s3_client.create_bucket(Bucket='test_bucket' + i_str) + s3_client.create_bucket(Bucket="test_bucket" + i_str) s3_client.put_bucket_tagging( - Bucket='test_bucket' + i_str, - Tagging={'TagSet': [{'Key': 'key' + i_str, 'Value': 'value' + i_str}]} + Bucket="test_bucket" + i_str, + Tagging={"TagSet": [{"Key": "key" + i_str, "Value": "value" + i_str}]}, ) - response_keys.add('key' + i_str) + response_keys.add("key" + i_str) - rtapi = boto3.client('resourcegroupstaggingapi', region_name='eu-central-1') + rtapi = boto3.client("resourcegroupstaggingapi", region_name="eu-central-1") resp = rtapi.get_resources(ResourcesPerPage=2) - for resource in resp['ResourceTagMappingList']: - response_keys.remove(resource['Tags'][0]['Key']) + for resource in resp["ResourceTagMappingList"]: + response_keys.remove(resource["Tags"][0]["Key"]) response_keys.should.have.length_of(2) resp = rtapi.get_resources( - ResourcesPerPage=2, - PaginationToken=resp['PaginationToken'] + ResourcesPerPage=2, PaginationToken=resp["PaginationToken"] ) - for resource in resp['ResourceTagMappingList']: - response_keys.remove(resource['Tags'][0]['Key']) + for resource in resp["ResourceTagMappingList"]: + response_keys.remove(resource["Tags"][0]["Key"]) response_keys.should.have.length_of(0) @@ -48,109 +47,86 @@ def test_get_resources_s3(): @mock_ec2 @mock_resourcegroupstaggingapi def test_get_resources_ec2(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") instances = client.run_instances( - ImageId='ami-123', + ImageId="ami-123", MinCount=1, MaxCount=1, - InstanceType='t2.micro', + InstanceType="t2.micro", TagSpecifications=[ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG1', - 'Value': 'MY_VALUE1', - }, - { - 'Key': 'MY_TAG2', - 'Value': 'MY_VALUE2', - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "MY_TAG1", "Value": "MY_VALUE1"}, + {"Key": "MY_TAG2", "Value": "MY_VALUE2"}, ], }, { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG3', - 'Value': 'MY_VALUE3', - }, - ] + "ResourceType": "instance", + "Tags": [{"Key": "MY_TAG3", "Value": "MY_VALUE3"}], }, ], ) - instance_id = instances['Instances'][0]['InstanceId'] - image_id = client.create_image(Name='testami', InstanceId=instance_id)['ImageId'] + instance_id = instances["Instances"][0]["InstanceId"] + image_id = client.create_image(Name="testami", InstanceId=instance_id)["ImageId"] - client.create_tags( - Resources=[image_id], - Tags=[{'Key': 'ami', 'Value': 'test'}] - ) + client.create_tags(Resources=[image_id], Tags=[{"Key": "ami", "Value": "test"}]) - rtapi = boto3.client('resourcegroupstaggingapi', region_name='eu-central-1') + rtapi = boto3.client("resourcegroupstaggingapi", region_name="eu-central-1") resp = rtapi.get_resources() # Check we have 1 entry for Instance, 1 Entry for AMI - resp['ResourceTagMappingList'].should.have.length_of(2) + resp["ResourceTagMappingList"].should.have.length_of(2) # 1 Entry for AMI - resp = rtapi.get_resources(ResourceTypeFilters=['ec2:image']) - resp['ResourceTagMappingList'].should.have.length_of(1) - resp['ResourceTagMappingList'][0]['ResourceARN'].should.contain('image/') + resp = rtapi.get_resources(ResourceTypeFilters=["ec2:image"]) + resp["ResourceTagMappingList"].should.have.length_of(1) + resp["ResourceTagMappingList"][0]["ResourceARN"].should.contain("image/") # As were iterating the same data, this rules out that the test above was a fluke - resp = rtapi.get_resources(ResourceTypeFilters=['ec2:instance']) - resp['ResourceTagMappingList'].should.have.length_of(1) - resp['ResourceTagMappingList'][0]['ResourceARN'].should.contain('instance/') + resp = rtapi.get_resources(ResourceTypeFilters=["ec2:instance"]) + resp["ResourceTagMappingList"].should.have.length_of(1) + resp["ResourceTagMappingList"][0]["ResourceARN"].should.contain("instance/") # Basic test of tag filters - resp = rtapi.get_resources(TagFilters=[{'Key': 'MY_TAG1', 'Values': ['MY_VALUE1', 'some_other_value']}]) - resp['ResourceTagMappingList'].should.have.length_of(1) - resp['ResourceTagMappingList'][0]['ResourceARN'].should.contain('instance/') + resp = rtapi.get_resources( + TagFilters=[{"Key": "MY_TAG1", "Values": ["MY_VALUE1", "some_other_value"]}] + ) + resp["ResourceTagMappingList"].should.have.length_of(1) + resp["ResourceTagMappingList"][0]["ResourceARN"].should.contain("instance/") @mock_ec2 @mock_resourcegroupstaggingapi def test_get_tag_keys_ec2(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") client.run_instances( - ImageId='ami-123', + ImageId="ami-123", MinCount=1, MaxCount=1, - InstanceType='t2.micro', + InstanceType="t2.micro", TagSpecifications=[ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG1', - 'Value': 'MY_VALUE1', - }, - { - 'Key': 'MY_TAG2', - 'Value': 'MY_VALUE2', - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "MY_TAG1", "Value": "MY_VALUE1"}, + {"Key": "MY_TAG2", "Value": "MY_VALUE2"}, ], }, { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG3', - 'Value': 'MY_VALUE3', - }, - ] + "ResourceType": "instance", + "Tags": [{"Key": "MY_TAG3", "Value": "MY_VALUE3"}], }, ], ) - rtapi = boto3.client('resourcegroupstaggingapi', region_name='eu-central-1') + rtapi = boto3.client("resourcegroupstaggingapi", region_name="eu-central-1") resp = rtapi.get_tag_keys() - resp['TagKeys'].should.contain('MY_TAG1') - resp['TagKeys'].should.contain('MY_TAG2') - resp['TagKeys'].should.contain('MY_TAG3') + resp["TagKeys"].should.contain("MY_TAG1") + resp["TagKeys"].should.contain("MY_TAG2") + resp["TagKeys"].should.contain("MY_TAG3") # TODO test pagenation @@ -158,148 +134,114 @@ def test_get_tag_keys_ec2(): @mock_ec2 @mock_resourcegroupstaggingapi def test_get_tag_values_ec2(): - client = boto3.client('ec2', region_name='eu-central-1') + client = boto3.client("ec2", region_name="eu-central-1") client.run_instances( - ImageId='ami-123', + ImageId="ami-123", MinCount=1, MaxCount=1, - InstanceType='t2.micro', + InstanceType="t2.micro", TagSpecifications=[ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG1', - 'Value': 'MY_VALUE1', - }, - { - 'Key': 'MY_TAG2', - 'Value': 'MY_VALUE2', - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "MY_TAG1", "Value": "MY_VALUE1"}, + {"Key": "MY_TAG2", "Value": "MY_VALUE2"}, ], }, { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG3', - 'Value': 'MY_VALUE3', - }, - ] + "ResourceType": "instance", + "Tags": [{"Key": "MY_TAG3", "Value": "MY_VALUE3"}], }, ], ) client.run_instances( - ImageId='ami-123', + ImageId="ami-123", MinCount=1, MaxCount=1, - InstanceType='t2.micro', + InstanceType="t2.micro", TagSpecifications=[ { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG1', - 'Value': 'MY_VALUE4', - }, - { - 'Key': 'MY_TAG2', - 'Value': 'MY_VALUE5', - }, + "ResourceType": "instance", + "Tags": [ + {"Key": "MY_TAG1", "Value": "MY_VALUE4"}, + {"Key": "MY_TAG2", "Value": "MY_VALUE5"}, ], }, { - 'ResourceType': 'instance', - 'Tags': [ - { - 'Key': 'MY_TAG3', - 'Value': 'MY_VALUE6', - }, - ] + "ResourceType": "instance", + "Tags": [{"Key": "MY_TAG3", "Value": "MY_VALUE6"}], }, ], ) - rtapi = boto3.client('resourcegroupstaggingapi', region_name='eu-central-1') - resp = rtapi.get_tag_values(Key='MY_TAG1') + rtapi = boto3.client("resourcegroupstaggingapi", region_name="eu-central-1") + resp = rtapi.get_tag_values(Key="MY_TAG1") + + resp["TagValues"].should.contain("MY_VALUE1") + resp["TagValues"].should.contain("MY_VALUE4") - resp['TagValues'].should.contain('MY_VALUE1') - resp['TagValues'].should.contain('MY_VALUE4') @mock_ec2 @mock_elbv2 @mock_kms @mock_resourcegroupstaggingapi def test_get_many_resources(): - elbv2 = boto3.client('elbv2', region_name='us-east-1') - ec2 = boto3.resource('ec2', region_name='us-east-1') - kms = boto3.client('kms', region_name='us-east-1') + elbv2 = boto3.client("elbv2", region_name="us-east-1") + ec2 = boto3.resource("ec2", region_name="us-east-1") + kms = boto3.client("kms", region_name="us-east-1") security_group = ec2.create_security_group( - GroupName='a-security-group', Description='First One') - vpc = ec2.create_vpc(CidrBlock='172.28.7.0/24', InstanceTenancy='default') + GroupName="a-security-group", Description="First One" + ) + vpc = ec2.create_vpc(CidrBlock="172.28.7.0/24", InstanceTenancy="default") subnet1 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.192/26', - AvailabilityZone='us-east-1a') + VpcId=vpc.id, CidrBlock="172.28.7.192/26", AvailabilityZone="us-east-1a" + ) subnet2 = ec2.create_subnet( - VpcId=vpc.id, - CidrBlock='172.28.7.0/26', - AvailabilityZone='us-east-1b') + VpcId=vpc.id, CidrBlock="172.28.7.0/26", AvailabilityZone="us-east-1b" + ) elbv2.create_load_balancer( - Name='my-lb', + Name="my-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', + Scheme="internal", Tags=[ - { - 'Key': 'key_name', - 'Value': 'a_value' - }, - { - 'Key': 'key_2', - 'Value': 'val2' - } - ] - ) + {"Key": "key_name", "Value": "a_value"}, + {"Key": "key_2", "Value": "val2"}, + ], + ) elbv2.create_load_balancer( - Name='my-other-lb', + Name="my-other-lb", Subnets=[subnet1.id, subnet2.id], SecurityGroups=[security_group.id], - Scheme='internal', - ) + Scheme="internal", + ) kms.create_key( - KeyUsage='ENCRYPT_DECRYPT', + KeyUsage="ENCRYPT_DECRYPT", Tags=[ - { - 'TagKey': 'key_name', - 'TagValue': 'a_value' - }, - { - 'TagKey': 'key_2', - 'TagValue': 'val2' - } - ] - ) + {"TagKey": "key_name", "TagValue": "a_value"}, + {"TagKey": "key_2", "TagValue": "val2"}, + ], + ) - rtapi = boto3.client('resourcegroupstaggingapi', region_name='us-east-1') + rtapi = boto3.client("resourcegroupstaggingapi", region_name="us-east-1") - resp = rtapi.get_resources(ResourceTypeFilters=['elasticloadbalancer:loadbalancer']) + resp = rtapi.get_resources(ResourceTypeFilters=["elasticloadbalancer:loadbalancer"]) - resp['ResourceTagMappingList'].should.have.length_of(2) - resp['ResourceTagMappingList'][0]['ResourceARN'].should.contain('loadbalancer/') + resp["ResourceTagMappingList"].should.have.length_of(2) + resp["ResourceTagMappingList"][0]["ResourceARN"].should.contain("loadbalancer/") resp = rtapi.get_resources( - ResourceTypeFilters=['elasticloadbalancer:loadbalancer'], - TagFilters=[{ - 'Key': 'key_name' - }] - ) + ResourceTypeFilters=["elasticloadbalancer:loadbalancer"], + TagFilters=[{"Key": "key_name"}], + ) - resp['ResourceTagMappingList'].should.have.length_of(1) - resp['ResourceTagMappingList'][0]['Tags'].should.contain({'Key': 'key_name', 'Value': 'a_value'}) + resp["ResourceTagMappingList"].should.have.length_of(1) + resp["ResourceTagMappingList"][0]["Tags"].should.contain( + {"Key": "key_name", "Value": "a_value"} + ) # TODO test pagenation diff --git a/tests/test_resourcegroupstaggingapi/test_server.py b/tests/test_resourcegroupstaggingapi/test_server.py index 311b1f03e..836fa5828 100644 --- a/tests/test_resourcegroupstaggingapi/test_server.py +++ b/tests/test_resourcegroupstaggingapi/test_server.py @@ -4,9 +4,9 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_resourcegroupstaggingapi_list(): @@ -15,10 +15,10 @@ def test_resourcegroupstaggingapi_list(): # do test headers = { - 'X-Amz-Target': 'ResourceGroupsTaggingAPI_20170126.GetResources', - 'X-Amz-Date': '20171114T234623Z' + "X-Amz-Target": "ResourceGroupsTaggingAPI_20170126.GetResources", + "X-Amz-Date": "20171114T234623Z", } - resp = test_client.post('/', headers=headers, data='{}') + resp = test_client.post("/", headers=headers, data="{}") assert resp.status_code == 200 - assert b'ResourceTagMappingList' in resp.data + assert b"ResourceTagMappingList" in resp.data diff --git a/tests/test_route53/test_route53.py b/tests/test_route53/test_route53.py index babd54d26..0e9a1e2c0 100644 --- a/tests/test_route53/test_route53.py +++ b/tests/test_route53/test_route53.py @@ -17,7 +17,7 @@ from moto import mock_route53, mock_route53_deprecated @mock_route53_deprecated def test_hosted_zone(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") firstzone = conn.create_hosted_zone("testdns.aws.com") zones = conn.get_all_hosted_zones() len(zones["ListHostedZonesResponse"]["HostedZones"]).should.equal(1) @@ -26,30 +26,29 @@ def test_hosted_zone(): zones = conn.get_all_hosted_zones() len(zones["ListHostedZonesResponse"]["HostedZones"]).should.equal(2) - id1 = firstzone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + id1 = firstzone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] zone = conn.get_hosted_zone(id1) - zone["GetHostedZoneResponse"]["HostedZone"][ - "Name"].should.equal("testdns.aws.com.") + zone["GetHostedZoneResponse"]["HostedZone"]["Name"].should.equal("testdns.aws.com.") conn.delete_hosted_zone(id1) zones = conn.get_all_hosted_zones() len(zones["ListHostedZonesResponse"]["HostedZones"]).should.equal(1) conn.get_hosted_zone.when.called_with("abcd").should.throw( - boto.route53.exception.DNSServerError, "404 Not Found") + boto.route53.exception.DNSServerError, "404 Not Found" + ) @mock_route53_deprecated def test_rrset(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") conn.get_all_rrsets.when.called_with("abcd", type="A").should.throw( - boto.route53.exception.DNSServerError, "404 Not Found") + boto.route53.exception.DNSServerError, "404 Not Found" + ) zone = conn.create_hosted_zone("testdns.aws.com") - zoneid = zone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + zoneid = zone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] changes = ResourceRecordSets(conn, zoneid) change = changes.add_change("CREATE", "foo.bar.testdns.aws.com", "A") @@ -58,7 +57,7 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('1.2.3.4') + rrsets[0].resource_records[0].should.equal("1.2.3.4") rrsets = conn.get_all_rrsets(zoneid, type="CNAME") rrsets.should.have.length_of(0) @@ -71,7 +70,7 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('5.6.7.8') + rrsets[0].resource_records[0].should.equal("5.6.7.8") changes = ResourceRecordSets(conn, zoneid) changes.add_change("DELETE", "foo.bar.testdns.aws.com", "A") @@ -87,7 +86,7 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('1.2.3.4') + rrsets[0].resource_records[0].should.equal("1.2.3.4") changes = ResourceRecordSets(conn, zoneid) change = changes.add_change("UPSERT", "foo.bar.testdns.aws.com", "A") @@ -96,7 +95,7 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('5.6.7.8') + rrsets[0].resource_records[0].should.equal("5.6.7.8") changes = ResourceRecordSets(conn, zoneid) change = changes.add_change("UPSERT", "foo.bar.testdns.aws.com", "TXT") @@ -105,8 +104,8 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid) rrsets.should.have.length_of(2) - rrsets[0].resource_records[0].should.equal('5.6.7.8') - rrsets[1].resource_records[0].should.equal('foo') + rrsets[0].resource_records[0].should.equal("5.6.7.8") + rrsets[1].resource_records[0].should.equal("foo") changes = ResourceRecordSets(conn, zoneid) changes.add_change("DELETE", "foo.bar.testdns.aws.com", "A") @@ -123,29 +122,25 @@ def test_rrset(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(2) - rrsets = conn.get_all_rrsets( - zoneid, name="bar.foo.testdns.aws.com", type="A") + rrsets = conn.get_all_rrsets(zoneid, name="bar.foo.testdns.aws.com", type="A") rrsets.should.have.length_of(1) - rrsets[0].resource_records[0].should.equal('5.6.7.8') + rrsets[0].resource_records[0].should.equal("5.6.7.8") - rrsets = conn.get_all_rrsets( - zoneid, name="foo.bar.testdns.aws.com", type="A") + rrsets = conn.get_all_rrsets(zoneid, name="foo.bar.testdns.aws.com", type="A") rrsets.should.have.length_of(2) resource_records = [rr for rr_set in rrsets for rr in rr_set.resource_records] - resource_records.should.contain('1.2.3.4') - resource_records.should.contain('5.6.7.8') + resource_records.should.contain("1.2.3.4") + resource_records.should.contain("5.6.7.8") - rrsets = conn.get_all_rrsets( - zoneid, name="foo.foo.testdns.aws.com", type="A") + rrsets = conn.get_all_rrsets(zoneid, name="foo.foo.testdns.aws.com", type="A") rrsets.should.have.length_of(0) @mock_route53_deprecated def test_rrset_with_multiple_values(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") zone = conn.create_hosted_zone("testdns.aws.com") - zoneid = zone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + zoneid = zone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] changes = ResourceRecordSets(conn, zoneid) change = changes.add_change("CREATE", "foo.bar.testdns.aws.com", "A") @@ -155,39 +150,48 @@ def test_rrset_with_multiple_values(): rrsets = conn.get_all_rrsets(zoneid, type="A") rrsets.should.have.length_of(1) - set(rrsets[0].resource_records).should.equal(set(['1.2.3.4', '5.6.7.8'])) + set(rrsets[0].resource_records).should.equal(set(["1.2.3.4", "5.6.7.8"])) @mock_route53_deprecated def test_alias_rrset(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") zone = conn.create_hosted_zone("testdns.aws.com") - zoneid = zone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + zoneid = zone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] changes = ResourceRecordSets(conn, zoneid) - changes.add_change("CREATE", "foo.alias.testdns.aws.com", "A", - alias_hosted_zone_id="Z3DG6IL3SJCGPX", alias_dns_name="foo.testdns.aws.com") - changes.add_change("CREATE", "bar.alias.testdns.aws.com", "CNAME", - alias_hosted_zone_id="Z3DG6IL3SJCGPX", alias_dns_name="bar.testdns.aws.com") + changes.add_change( + "CREATE", + "foo.alias.testdns.aws.com", + "A", + alias_hosted_zone_id="Z3DG6IL3SJCGPX", + alias_dns_name="foo.testdns.aws.com", + ) + changes.add_change( + "CREATE", + "bar.alias.testdns.aws.com", + "CNAME", + alias_hosted_zone_id="Z3DG6IL3SJCGPX", + alias_dns_name="bar.testdns.aws.com", + ) changes.commit() rrsets = conn.get_all_rrsets(zoneid, type="A") alias_targets = [rr_set.alias_dns_name for rr_set in rrsets] alias_targets.should.have.length_of(2) - alias_targets.should.contain('foo.testdns.aws.com') - alias_targets.should.contain('bar.testdns.aws.com') - rrsets[0].alias_dns_name.should.equal('foo.testdns.aws.com') + alias_targets.should.contain("foo.testdns.aws.com") + alias_targets.should.contain("bar.testdns.aws.com") + rrsets[0].alias_dns_name.should.equal("foo.testdns.aws.com") rrsets[0].resource_records.should.have.length_of(0) rrsets = conn.get_all_rrsets(zoneid, type="CNAME") rrsets.should.have.length_of(1) - rrsets[0].alias_dns_name.should.equal('bar.testdns.aws.com') + rrsets[0].alias_dns_name.should.equal("bar.testdns.aws.com") rrsets[0].resource_records.should.have.length_of(0) @mock_route53_deprecated def test_create_health_check(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") check = HealthCheck( ip_addr="10.0.0.25", @@ -201,65 +205,51 @@ def test_create_health_check(): ) conn.create_health_check(check) - checks = conn.get_list_health_checks()['ListHealthChecksResponse'][ - 'HealthChecks'] + checks = conn.get_list_health_checks()["ListHealthChecksResponse"]["HealthChecks"] list(checks).should.have.length_of(1) check = checks[0] - config = check['HealthCheckConfig'] - config['IPAddress'].should.equal("10.0.0.25") - config['Port'].should.equal("80") - config['Type'].should.equal("HTTP") - config['ResourcePath'].should.equal("/") - config['FullyQualifiedDomainName'].should.equal("example.com") - config['SearchString'].should.equal("a good response") - config['RequestInterval'].should.equal("10") - config['FailureThreshold'].should.equal("2") + config = check["HealthCheckConfig"] + config["IPAddress"].should.equal("10.0.0.25") + config["Port"].should.equal("80") + config["Type"].should.equal("HTTP") + config["ResourcePath"].should.equal("/") + config["FullyQualifiedDomainName"].should.equal("example.com") + config["SearchString"].should.equal("a good response") + config["RequestInterval"].should.equal("10") + config["FailureThreshold"].should.equal("2") @mock_route53_deprecated def test_delete_health_check(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") - check = HealthCheck( - ip_addr="10.0.0.25", - port=80, - hc_type="HTTP", - resource_path="/", - ) + check = HealthCheck(ip_addr="10.0.0.25", port=80, hc_type="HTTP", resource_path="/") conn.create_health_check(check) - checks = conn.get_list_health_checks()['ListHealthChecksResponse'][ - 'HealthChecks'] + checks = conn.get_list_health_checks()["ListHealthChecksResponse"]["HealthChecks"] list(checks).should.have.length_of(1) - health_check_id = checks[0]['Id'] + health_check_id = checks[0]["Id"] conn.delete_health_check(health_check_id) - checks = conn.get_list_health_checks()['ListHealthChecksResponse'][ - 'HealthChecks'] + checks = conn.get_list_health_checks()["ListHealthChecksResponse"]["HealthChecks"] list(checks).should.have.length_of(0) @mock_route53_deprecated def test_use_health_check_in_resource_record_set(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") - check = HealthCheck( - ip_addr="10.0.0.25", - port=80, - hc_type="HTTP", - resource_path="/", - ) - check = conn.create_health_check( - check)['CreateHealthCheckResponse']['HealthCheck'] - check_id = check['Id'] + check = HealthCheck(ip_addr="10.0.0.25", port=80, hc_type="HTTP", resource_path="/") + check = conn.create_health_check(check)["CreateHealthCheckResponse"]["HealthCheck"] + check_id = check["Id"] zone = conn.create_hosted_zone("testdns.aws.com") - zone_id = zone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + zone_id = zone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] changes = ResourceRecordSets(conn, zone_id) change = changes.add_change( - "CREATE", "foo.bar.testdns.aws.com", "A", health_check=check_id) + "CREATE", "foo.bar.testdns.aws.com", "A", health_check=check_id + ) change.add_value("1.2.3.4") changes.commit() @@ -269,20 +259,20 @@ def test_use_health_check_in_resource_record_set(): @mock_route53_deprecated def test_hosted_zone_comment_preserved(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") - firstzone = conn.create_hosted_zone( - "testdns.aws.com.", comment="test comment") - zone_id = firstzone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + firstzone = conn.create_hosted_zone("testdns.aws.com.", comment="test comment") + zone_id = firstzone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] hosted_zone = conn.get_hosted_zone(zone_id) - hosted_zone["GetHostedZoneResponse"]["HostedZone"][ - "Config"]["Comment"].should.equal("test comment") + hosted_zone["GetHostedZoneResponse"]["HostedZone"]["Config"][ + "Comment" + ].should.equal("test comment") hosted_zones = conn.get_all_hosted_zones() - hosted_zones["ListHostedZonesResponse"]["HostedZones"][ - 0]["Config"]["Comment"].should.equal("test comment") + hosted_zones["ListHostedZonesResponse"]["HostedZones"][0]["Config"][ + "Comment" + ].should.equal("test comment") zone = conn.get_zone("testdns.aws.com.") zone.config["Comment"].should.equal("test comment") @@ -295,21 +285,22 @@ def test_deleting_weighted_route(): conn.create_hosted_zone("testdns.aws.com.") zone = conn.get_zone("testdns.aws.com.") - zone.add_cname("cname.testdns.aws.com", "example.com", - identifier=('success-test-foo', '50')) - zone.add_cname("cname.testdns.aws.com", "example.com", - identifier=('success-test-bar', '50')) + zone.add_cname( + "cname.testdns.aws.com", "example.com", identifier=("success-test-foo", "50") + ) + zone.add_cname( + "cname.testdns.aws.com", "example.com", identifier=("success-test-bar", "50") + ) - cnames = zone.get_cname('cname.testdns.aws.com.', all=True) + cnames = zone.get_cname("cname.testdns.aws.com.", all=True) cnames.should.have.length_of(2) - foo_cname = [cname for cname in cnames if cname.identifier == - 'success-test-foo'][0] + foo_cname = [cname for cname in cnames if cname.identifier == "success-test-foo"][0] zone.delete_record(foo_cname) - cname = zone.get_cname('cname.testdns.aws.com.', all=True) + cname = zone.get_cname("cname.testdns.aws.com.", all=True) # When get_cname only had one result, it returns just that result instead # of a list. - cname.identifier.should.equal('success-test-bar') + cname.identifier.should.equal("success-test-bar") @mock_route53_deprecated @@ -319,59 +310,63 @@ def test_deleting_latency_route(): conn.create_hosted_zone("testdns.aws.com.") zone = conn.get_zone("testdns.aws.com.") - zone.add_cname("cname.testdns.aws.com", "example.com", - identifier=('success-test-foo', 'us-west-2')) - zone.add_cname("cname.testdns.aws.com", "example.com", - identifier=('success-test-bar', 'us-west-1')) + zone.add_cname( + "cname.testdns.aws.com", + "example.com", + identifier=("success-test-foo", "us-west-2"), + ) + zone.add_cname( + "cname.testdns.aws.com", + "example.com", + identifier=("success-test-bar", "us-west-1"), + ) - cnames = zone.get_cname('cname.testdns.aws.com.', all=True) + cnames = zone.get_cname("cname.testdns.aws.com.", all=True) cnames.should.have.length_of(2) - foo_cname = [cname for cname in cnames if cname.identifier == - 'success-test-foo'][0] - foo_cname.region.should.equal('us-west-2') + foo_cname = [cname for cname in cnames if cname.identifier == "success-test-foo"][0] + foo_cname.region.should.equal("us-west-2") zone.delete_record(foo_cname) - cname = zone.get_cname('cname.testdns.aws.com.', all=True) + cname = zone.get_cname("cname.testdns.aws.com.", all=True) # When get_cname only had one result, it returns just that result instead # of a list. - cname.identifier.should.equal('success-test-bar') - cname.region.should.equal('us-west-1') + cname.identifier.should.equal("success-test-bar") + cname.region.should.equal("us-west-1") @mock_route53_deprecated def test_hosted_zone_private_zone_preserved(): - conn = boto.connect_route53('the_key', 'the_secret') + conn = boto.connect_route53("the_key", "the_secret") firstzone = conn.create_hosted_zone( - "testdns.aws.com.", private_zone=True, vpc_id='vpc-fake', vpc_region='us-east-1') - zone_id = firstzone["CreateHostedZoneResponse"][ - "HostedZone"]["Id"].split("/")[-1] + "testdns.aws.com.", private_zone=True, vpc_id="vpc-fake", vpc_region="us-east-1" + ) + zone_id = firstzone["CreateHostedZoneResponse"]["HostedZone"]["Id"].split("/")[-1] hosted_zone = conn.get_hosted_zone(zone_id) # in (original) boto, these bools returned as strings. - hosted_zone["GetHostedZoneResponse"]["HostedZone"][ - "Config"]["PrivateZone"].should.equal('True') + hosted_zone["GetHostedZoneResponse"]["HostedZone"]["Config"][ + "PrivateZone" + ].should.equal("True") hosted_zones = conn.get_all_hosted_zones() - hosted_zones["ListHostedZonesResponse"]["HostedZones"][ - 0]["Config"]["PrivateZone"].should.equal('True') + hosted_zones["ListHostedZonesResponse"]["HostedZones"][0]["Config"][ + "PrivateZone" + ].should.equal("True") zone = conn.get_zone("testdns.aws.com.") - zone.config["PrivateZone"].should.equal('True') + zone.config["PrivateZone"].should.equal("True") @mock_route53 def test_hosted_zone_private_zone_preserved_boto3(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") # TODO: actually create_hosted_zone statements with PrivateZone=True, but without # a _valid_ vpc-id should fail. firstzone = conn.create_hosted_zone( Name="testdns.aws.com.", - CallerReference=str(hash('foo')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="Test", - ) + CallerReference=str(hash("foo")), + HostedZoneConfig=dict(PrivateZone=True, Comment="Test"), ) zone_id = firstzone["HostedZone"]["Id"].split("/")[-1] @@ -389,24 +384,25 @@ def test_hosted_zone_private_zone_preserved_boto3(): @mock_route53 def test_list_or_change_tags_for_resource_request(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") health_check = conn.create_health_check( - CallerReference='foobar', + CallerReference="foobar", HealthCheckConfig={ - 'IPAddress': '192.0.2.44', - 'Port': 123, - 'Type': 'HTTP', - 'ResourcePath': '/', - 'RequestInterval': 30, - 'FailureThreshold': 123, - 'HealthThreshold': 123, - } + "IPAddress": "192.0.2.44", + "Port": 123, + "Type": "HTTP", + "ResourcePath": "/", + "RequestInterval": 30, + "FailureThreshold": 123, + "HealthThreshold": 123, + }, ) - healthcheck_id = health_check['HealthCheck']['Id'] + healthcheck_id = health_check["HealthCheck"]["Id"] # confirm this works for resources with zero tags response = conn.list_tags_for_resource( - ResourceType="healthcheck", ResourceId=healthcheck_id) + ResourceType="healthcheck", ResourceId=healthcheck_id + ) response["ResourceTagSet"]["Tags"].should.be.empty tag1 = {"Key": "Deploy", "Value": "True"} @@ -414,92 +410,83 @@ def test_list_or_change_tags_for_resource_request(): # Test adding a tag for a resource id conn.change_tags_for_resource( - ResourceType='healthcheck', - ResourceId=healthcheck_id, - AddTags=[tag1, tag2] + ResourceType="healthcheck", ResourceId=healthcheck_id, AddTags=[tag1, tag2] ) # Check to make sure that the response has the 'ResourceTagSet' key response = conn.list_tags_for_resource( - ResourceType='healthcheck', ResourceId=healthcheck_id) - response.should.contain('ResourceTagSet') + ResourceType="healthcheck", ResourceId=healthcheck_id + ) + response.should.contain("ResourceTagSet") # Validate that each key was added - response['ResourceTagSet']['Tags'].should.contain(tag1) - response['ResourceTagSet']['Tags'].should.contain(tag2) + response["ResourceTagSet"]["Tags"].should.contain(tag1) + response["ResourceTagSet"]["Tags"].should.contain(tag2) - len(response['ResourceTagSet']['Tags']).should.equal(2) + len(response["ResourceTagSet"]["Tags"]).should.equal(2) # Try to remove the tags conn.change_tags_for_resource( - ResourceType='healthcheck', + ResourceType="healthcheck", ResourceId=healthcheck_id, - RemoveTagKeys=[tag1['Key']] + RemoveTagKeys=[tag1["Key"]], ) # Check to make sure that the response has the 'ResourceTagSet' key response = conn.list_tags_for_resource( - ResourceType='healthcheck', ResourceId=healthcheck_id) - response.should.contain('ResourceTagSet') - response['ResourceTagSet']['Tags'].should_not.contain(tag1) - response['ResourceTagSet']['Tags'].should.contain(tag2) + ResourceType="healthcheck", ResourceId=healthcheck_id + ) + response.should.contain("ResourceTagSet") + response["ResourceTagSet"]["Tags"].should_not.contain(tag1) + response["ResourceTagSet"]["Tags"].should.contain(tag2) # Remove the second tag conn.change_tags_for_resource( - ResourceType='healthcheck', + ResourceType="healthcheck", ResourceId=healthcheck_id, - RemoveTagKeys=[tag2['Key']] + RemoveTagKeys=[tag2["Key"]], ) response = conn.list_tags_for_resource( - ResourceType='healthcheck', ResourceId=healthcheck_id) - response['ResourceTagSet']['Tags'].should_not.contain(tag2) + ResourceType="healthcheck", ResourceId=healthcheck_id + ) + response["ResourceTagSet"]["Tags"].should_not.contain(tag2) # Re-add the tags conn.change_tags_for_resource( - ResourceType='healthcheck', - ResourceId=healthcheck_id, - AddTags=[tag1, tag2] + ResourceType="healthcheck", ResourceId=healthcheck_id, AddTags=[tag1, tag2] ) # Remove both conn.change_tags_for_resource( - ResourceType='healthcheck', + ResourceType="healthcheck", ResourceId=healthcheck_id, - RemoveTagKeys=[tag1['Key'], tag2['Key']] + RemoveTagKeys=[tag1["Key"], tag2["Key"]], ) response = conn.list_tags_for_resource( - ResourceType='healthcheck', ResourceId=healthcheck_id) - response['ResourceTagSet']['Tags'].should.be.empty + ResourceType="healthcheck", ResourceId=healthcheck_id + ) + response["ResourceTagSet"]["Tags"].should.be.empty @mock_route53 def test_list_hosted_zones_by_name(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") conn.create_hosted_zone( Name="test.b.com.", - CallerReference=str(hash('foo')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="test com", - ) + CallerReference=str(hash("foo")), + HostedZoneConfig=dict(PrivateZone=True, Comment="test com"), ) conn.create_hosted_zone( Name="test.a.org.", - CallerReference=str(hash('bar')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="test org", - ) + CallerReference=str(hash("bar")), + HostedZoneConfig=dict(PrivateZone=True, Comment="test org"), ) conn.create_hosted_zone( Name="test.a.org.", - CallerReference=str(hash('bar')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="test org 2", - ) + CallerReference=str(hash("bar")), + HostedZoneConfig=dict(PrivateZone=True, Comment="test org 2"), ) # test lookup @@ -521,14 +508,11 @@ def test_list_hosted_zones_by_name(): @mock_route53 def test_change_resource_record_sets_crud_valid(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") conn.create_hosted_zone( Name="db.", - CallerReference=str(hash('foo')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="db", - ) + CallerReference=str(hash("foo")), + HostedZoneConfig=dict(PrivateZone=True, Comment="db"), ) zones = conn.list_hosted_zones_by_name(DNSName="db.") @@ -538,244 +522,244 @@ def test_change_resource_record_sets_crud_valid(): # Create A Record. a_record_endpoint_payload = { - 'Comment': 'Create A record prod.redis.db', - 'Changes': [ + "Comment": "Create A record prod.redis.db", + "Changes": [ { - 'Action': 'CREATE', - 'ResourceRecordSet': { - 'Name': 'prod.redis.db.', - 'Type': 'A', - 'TTL': 10, - 'ResourceRecords': [{ - 'Value': '127.0.0.1' - }] - } + "Action": "CREATE", + "ResourceRecordSet": { + "Name": "prod.redis.db.", + "Type": "A", + "TTL": 10, + "ResourceRecords": [{"Value": "127.0.0.1"}], + }, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=a_record_endpoint_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=a_record_endpoint_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(1) - a_record_detail = response['ResourceRecordSets'][0] - a_record_detail['Name'].should.equal('prod.redis.db.') - a_record_detail['Type'].should.equal('A') - a_record_detail['TTL'].should.equal(10) - a_record_detail['ResourceRecords'].should.equal([{'Value': '127.0.0.1'}]) + len(response["ResourceRecordSets"]).should.equal(1) + a_record_detail = response["ResourceRecordSets"][0] + a_record_detail["Name"].should.equal("prod.redis.db.") + a_record_detail["Type"].should.equal("A") + a_record_detail["TTL"].should.equal(10) + a_record_detail["ResourceRecords"].should.equal([{"Value": "127.0.0.1"}]) # Update A Record. cname_record_endpoint_payload = { - 'Comment': 'Update A record prod.redis.db', - 'Changes': [ + "Comment": "Update A record prod.redis.db", + "Changes": [ { - 'Action': 'UPSERT', - 'ResourceRecordSet': { - 'Name': 'prod.redis.db.', - 'Type': 'A', - 'TTL': 60, - 'ResourceRecords': [{ - 'Value': '192.168.1.1' - }] - } + "Action": "UPSERT", + "ResourceRecordSet": { + "Name": "prod.redis.db.", + "Type": "A", + "TTL": 60, + "ResourceRecords": [{"Value": "192.168.1.1"}], + }, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=cname_record_endpoint_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=cname_record_endpoint_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(1) - cname_record_detail = response['ResourceRecordSets'][0] - cname_record_detail['Name'].should.equal('prod.redis.db.') - cname_record_detail['Type'].should.equal('A') - cname_record_detail['TTL'].should.equal(60) - cname_record_detail['ResourceRecords'].should.equal([{'Value': '192.168.1.1'}]) + len(response["ResourceRecordSets"]).should.equal(1) + cname_record_detail = response["ResourceRecordSets"][0] + cname_record_detail["Name"].should.equal("prod.redis.db.") + cname_record_detail["Type"].should.equal("A") + cname_record_detail["TTL"].should.equal(60) + cname_record_detail["ResourceRecords"].should.equal([{"Value": "192.168.1.1"}]) # Update to add Alias. cname_alias_record_endpoint_payload = { - 'Comment': 'Update to Alias prod.redis.db', - 'Changes': [ + "Comment": "Update to Alias prod.redis.db", + "Changes": [ { - 'Action': 'UPSERT', - 'ResourceRecordSet': { - 'Name': 'prod.redis.db.', - 'Type': 'A', - 'TTL': 60, - 'AliasTarget': { - 'HostedZoneId': hosted_zone_id, - 'DNSName': 'prod.redis.alias.', - 'EvaluateTargetHealth': False, - } - } + "Action": "UPSERT", + "ResourceRecordSet": { + "Name": "prod.redis.db.", + "Type": "A", + "TTL": 60, + "AliasTarget": { + "HostedZoneId": hosted_zone_id, + "DNSName": "prod.redis.alias.", + "EvaluateTargetHealth": False, + }, + }, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=cname_alias_record_endpoint_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=cname_alias_record_endpoint_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - cname_alias_record_detail = response['ResourceRecordSets'][0] - cname_alias_record_detail['Name'].should.equal('prod.redis.db.') - cname_alias_record_detail['Type'].should.equal('A') - cname_alias_record_detail['TTL'].should.equal(60) - cname_alias_record_detail['AliasTarget'].should.equal({ - 'HostedZoneId': hosted_zone_id, - 'DNSName': 'prod.redis.alias.', - 'EvaluateTargetHealth': False, - }) - cname_alias_record_detail.should_not.contain('ResourceRecords') + cname_alias_record_detail = response["ResourceRecordSets"][0] + cname_alias_record_detail["Name"].should.equal("prod.redis.db.") + cname_alias_record_detail["Type"].should.equal("A") + cname_alias_record_detail["TTL"].should.equal(60) + cname_alias_record_detail["AliasTarget"].should.equal( + { + "HostedZoneId": hosted_zone_id, + "DNSName": "prod.redis.alias.", + "EvaluateTargetHealth": False, + } + ) + cname_alias_record_detail.should_not.contain("ResourceRecords") # Delete record with wrong type. delete_payload = { - 'Comment': 'delete prod.redis.db', - 'Changes': [ + "Comment": "delete prod.redis.db", + "Changes": [ { - 'Action': 'DELETE', - 'ResourceRecordSet': { - 'Name': 'prod.redis.db', - 'Type': 'CNAME', - } + "Action": "DELETE", + "ResourceRecordSet": {"Name": "prod.redis.db", "Type": "CNAME"}, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=delete_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=delete_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(1) + len(response["ResourceRecordSets"]).should.equal(1) # Delete record. delete_payload = { - 'Comment': 'delete prod.redis.db', - 'Changes': [ + "Comment": "delete prod.redis.db", + "Changes": [ { - 'Action': 'DELETE', - 'ResourceRecordSet': { - 'Name': 'prod.redis.db', - 'Type': 'A', - } + "Action": "DELETE", + "ResourceRecordSet": {"Name": "prod.redis.db", "Type": "A"}, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=delete_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=delete_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(0) + len(response["ResourceRecordSets"]).should.equal(0) + @mock_route53 def test_change_weighted_resource_record_sets(): - conn = boto3.client('route53', region_name='us-east-2') + conn = boto3.client("route53", region_name="us-east-2") conn.create_hosted_zone( - Name='test.vpc.internal.', - CallerReference=str(hash('test')) + Name="test.vpc.internal.", CallerReference=str(hash("test")) ) - zones = conn.list_hosted_zones_by_name( - DNSName='test.vpc.internal.' - ) + zones = conn.list_hosted_zones_by_name(DNSName="test.vpc.internal.") - hosted_zone_id = zones['HostedZones'][0]['Id'] + hosted_zone_id = zones["HostedZones"][0]["Id"] - #Create 2 weighted records + # Create 2 weighted records conn.change_resource_record_sets( HostedZoneId=hosted_zone_id, ChangeBatch={ - 'Changes': [ + "Changes": [ { - 'Action': 'CREATE', - 'ResourceRecordSet': { - 'Name': 'test.vpc.internal', - 'Type': 'A', - 'SetIdentifier': 'test1', - 'Weight': 50, - 'AliasTarget': { - 'HostedZoneId': 'Z3AADJGX6KTTL2', - 'DNSName': 'internal-test1lb-447688172.us-east-2.elb.amazonaws.com.', - 'EvaluateTargetHealth': True - } - } + "Action": "CREATE", + "ResourceRecordSet": { + "Name": "test.vpc.internal", + "Type": "A", + "SetIdentifier": "test1", + "Weight": 50, + "AliasTarget": { + "HostedZoneId": "Z3AADJGX6KTTL2", + "DNSName": "internal-test1lb-447688172.us-east-2.elb.amazonaws.com.", + "EvaluateTargetHealth": True, + }, + }, }, - { - 'Action': 'CREATE', - 'ResourceRecordSet': { - 'Name': 'test.vpc.internal', - 'Type': 'A', - 'SetIdentifier': 'test2', - 'Weight': 50, - 'AliasTarget': { - 'HostedZoneId': 'Z3AADJGX6KTTL2', - 'DNSName': 'internal-testlb2-1116641781.us-east-2.elb.amazonaws.com.', - 'EvaluateTargetHealth': True - } - } + "Action": "CREATE", + "ResourceRecordSet": { + "Name": "test.vpc.internal", + "Type": "A", + "SetIdentifier": "test2", + "Weight": 50, + "AliasTarget": { + "HostedZoneId": "Z3AADJGX6KTTL2", + "DNSName": "internal-testlb2-1116641781.us-east-2.elb.amazonaws.com.", + "EvaluateTargetHealth": True, + }, + }, + }, + ] + }, + ) + + response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) + record = response["ResourceRecordSets"][0] + # Update the first record to have a weight of 90 + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, + ChangeBatch={ + "Changes": [ + { + "Action": "UPSERT", + "ResourceRecordSet": { + "Name": record["Name"], + "Type": record["Type"], + "SetIdentifier": record["SetIdentifier"], + "Weight": 90, + "AliasTarget": { + "HostedZoneId": record["AliasTarget"]["HostedZoneId"], + "DNSName": record["AliasTarget"]["DNSName"], + "EvaluateTargetHealth": record["AliasTarget"][ + "EvaluateTargetHealth" + ], + }, + }, } ] - } + }, ) - response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - record = response['ResourceRecordSets'][0] - #Update the first record to have a weight of 90 + record = response["ResourceRecordSets"][1] + # Update the second record to have a weight of 10 conn.change_resource_record_sets( HostedZoneId=hosted_zone_id, ChangeBatch={ - 'Changes' : [ + "Changes": [ { - 'Action' : 'UPSERT', - 'ResourceRecordSet' : { - 'Name' : record['Name'], - 'Type' : record['Type'], - 'SetIdentifier' : record['SetIdentifier'], - 'Weight' : 90, - 'AliasTarget' : { - 'HostedZoneId' : record['AliasTarget']['HostedZoneId'], - 'DNSName' : record['AliasTarget']['DNSName'], - 'EvaluateTargetHealth' : record['AliasTarget']['EvaluateTargetHealth'] - } - } - }, + "Action": "UPSERT", + "ResourceRecordSet": { + "Name": record["Name"], + "Type": record["Type"], + "SetIdentifier": record["SetIdentifier"], + "Weight": 10, + "AliasTarget": { + "HostedZoneId": record["AliasTarget"]["HostedZoneId"], + "DNSName": record["AliasTarget"]["DNSName"], + "EvaluateTargetHealth": record["AliasTarget"][ + "EvaluateTargetHealth" + ], + }, + }, + } ] - } - ) - - record = response['ResourceRecordSets'][1] - #Update the second record to have a weight of 10 - conn.change_resource_record_sets( - HostedZoneId=hosted_zone_id, - ChangeBatch={ - 'Changes' : [ - { - 'Action' : 'UPSERT', - 'ResourceRecordSet' : { - 'Name' : record['Name'], - 'Type' : record['Type'], - 'SetIdentifier' : record['SetIdentifier'], - 'Weight' : 10, - 'AliasTarget' : { - 'HostedZoneId' : record['AliasTarget']['HostedZoneId'], - 'DNSName' : record['AliasTarget']['DNSName'], - 'EvaluateTargetHealth' : record['AliasTarget']['EvaluateTargetHealth'] - } - } - }, - ] - } + }, ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - for record in response['ResourceRecordSets']: - if record['SetIdentifier'] == 'test1': - record['Weight'].should.equal(90) - if record['SetIdentifier'] == 'test2': - record['Weight'].should.equal(10) + for record in response["ResourceRecordSets"]: + if record["SetIdentifier"] == "test1": + record["Weight"].should.equal(90) + if record["SetIdentifier"] == "test2": + record["Weight"].should.equal(10) @mock_route53 def test_change_resource_record_invalid(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") conn.create_hosted_zone( Name="db.", - CallerReference=str(hash('foo')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="db", - ) + CallerReference=str(hash("foo")), + HostedZoneConfig=dict(PrivateZone=True, Comment="db"), ) zones = conn.list_hosted_zones_by_name(DNSName="db.") @@ -784,92 +768,89 @@ def test_change_resource_record_invalid(): hosted_zone_id = zones["HostedZones"][0]["Id"] invalid_a_record_payload = { - 'Comment': 'this should fail', - 'Changes': [ + "Comment": "this should fail", + "Changes": [ { - 'Action': 'CREATE', - 'ResourceRecordSet': { - 'Name': 'prod.scooby.doo', - 'Type': 'A', - 'TTL': 10, - 'ResourceRecords': [{ - 'Value': '127.0.0.1' - }] - } + "Action": "CREATE", + "ResourceRecordSet": { + "Name": "prod.scooby.doo", + "Type": "A", + "TTL": 10, + "ResourceRecords": [{"Value": "127.0.0.1"}], + }, } - ] + ], } with assert_raises(botocore.exceptions.ClientError): - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=invalid_a_record_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=invalid_a_record_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(0) + len(response["ResourceRecordSets"]).should.equal(0) invalid_cname_record_payload = { - 'Comment': 'this should also fail', - 'Changes': [ + "Comment": "this should also fail", + "Changes": [ { - 'Action': 'UPSERT', - 'ResourceRecordSet': { - 'Name': 'prod.scooby.doo', - 'Type': 'CNAME', - 'TTL': 10, - 'ResourceRecords': [{ - 'Value': '127.0.0.1' - }] - } + "Action": "UPSERT", + "ResourceRecordSet": { + "Name": "prod.scooby.doo", + "Type": "CNAME", + "TTL": 10, + "ResourceRecords": [{"Value": "127.0.0.1"}], + }, } - ] + ], } with assert_raises(botocore.exceptions.ClientError): - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=invalid_cname_record_payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=invalid_cname_record_payload + ) response = conn.list_resource_record_sets(HostedZoneId=hosted_zone_id) - len(response['ResourceRecordSets']).should.equal(0) + len(response["ResourceRecordSets"]).should.equal(0) @mock_route53 def test_list_resource_record_sets_name_type_filters(): - conn = boto3.client('route53', region_name='us-east-1') + conn = boto3.client("route53", region_name="us-east-1") create_hosted_zone_response = conn.create_hosted_zone( Name="db.", - CallerReference=str(hash('foo')), - HostedZoneConfig=dict( - PrivateZone=True, - Comment="db", - ) + CallerReference=str(hash("foo")), + HostedZoneConfig=dict(PrivateZone=True, Comment="db"), ) - hosted_zone_id = create_hosted_zone_response['HostedZone']['Id'] + hosted_zone_id = create_hosted_zone_response["HostedZone"]["Id"] def create_resource_record_set(rec_type, rec_name): payload = { - 'Comment': 'create {} record {}'.format(rec_type, rec_name), - 'Changes': [ + "Comment": "create {} record {}".format(rec_type, rec_name), + "Changes": [ { - 'Action': 'CREATE', - 'ResourceRecordSet': { - 'Name': rec_name, - 'Type': rec_type, - 'TTL': 10, - 'ResourceRecords': [{ - 'Value': '127.0.0.1' - }] - } + "Action": "CREATE", + "ResourceRecordSet": { + "Name": rec_name, + "Type": rec_type, + "TTL": 10, + "ResourceRecords": [{"Value": "127.0.0.1"}], + }, } - ] + ], } - conn.change_resource_record_sets(HostedZoneId=hosted_zone_id, ChangeBatch=payload) + conn.change_resource_record_sets( + HostedZoneId=hosted_zone_id, ChangeBatch=payload + ) # record_type, record_name all_records = [ - ('A', 'a.a.db.'), - ('A', 'a.b.db.'), - ('A', 'b.b.db.'), - ('CNAME', 'b.b.db.'), - ('CNAME', 'b.c.db.'), - ('CNAME', 'c.c.db.') + ("A", "a.a.db."), + ("A", "a.b.db."), + ("A", "b.b.db."), + ("CNAME", "b.b.db."), + ("CNAME", "b.c.db."), + ("CNAME", "c.c.db."), ] for record_type, record_name in all_records: create_resource_record_set(record_type, record_name) @@ -878,10 +859,12 @@ def test_list_resource_record_sets_name_type_filters(): response = conn.list_resource_record_sets( HostedZoneId=hosted_zone_id, StartRecordType=all_records[start_with][0], - StartRecordName=all_records[start_with][1] + StartRecordName=all_records[start_with][1], ) - returned_records = [(record['Type'], record['Name']) for record in response['ResourceRecordSets']] + returned_records = [ + (record["Type"], record["Name"]) for record in response["ResourceRecordSets"] + ] len(returned_records).should.equal(len(all_records) - start_with) for desired_record in all_records[start_with:]: returned_records.should.contain(desired_record) diff --git a/tests/test_s3/test_s3.py b/tests/test_s3/test_s3.py index 6de511ef7..8f3c3538c 100644 --- a/tests/test_s3/test_s3.py +++ b/tests/test_s3/test_s3.py @@ -62,21 +62,20 @@ def reduced_min_part_size(f): class MyModel(object): - def __init__(self, name, value): self.name = name self.value = value def save(self): - s3 = boto3.client('s3', region_name='us-east-1') - s3.put_object(Bucket='mybucket', Key=self.name, Body=self.value) + s3 = boto3.client("s3", region_name="us-east-1") + s3.put_object(Bucket="mybucket", Key=self.name, Body=self.value) @mock_s3 def test_keys_are_pickleable(): """Keys must be pickleable due to boto3 implementation details.""" - key = s3model.FakeKey('name', b'data!') - assert key.value == b'data!' + key = s3model.FakeKey("name", b"data!") + assert key.value == b"data!" pickled = pickle.dumps(key) loaded = pickle.loads(pickled) @@ -85,72 +84,73 @@ def test_keys_are_pickleable(): @mock_s3 def test_append_to_value__basic(): - key = s3model.FakeKey('name', b'data!') - assert key.value == b'data!' + key = s3model.FakeKey("name", b"data!") + assert key.value == b"data!" assert key.size == 5 - key.append_to_value(b' And even more data') - assert key.value == b'data! And even more data' + key.append_to_value(b" And even more data") + assert key.value == b"data! And even more data" assert key.size == 24 @mock_s3 def test_append_to_value__nothing_added(): - key = s3model.FakeKey('name', b'data!') - assert key.value == b'data!' + key = s3model.FakeKey("name", b"data!") + assert key.value == b"data!" assert key.size == 5 - key.append_to_value(b'') - assert key.value == b'data!' + key.append_to_value(b"") + assert key.value == b"data!" assert key.size == 5 @mock_s3 def test_append_to_value__empty_key(): - key = s3model.FakeKey('name', b'') - assert key.value == b'' + key = s3model.FakeKey("name", b"") + assert key.value == b"" assert key.size == 0 - key.append_to_value(b'stuff') - assert key.value == b'stuff' + key.append_to_value(b"stuff") + assert key.value == b"stuff" assert key.size == 5 @mock_s3 def test_my_model_save(): # Create Bucket so that test can run - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket='mybucket') + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket="mybucket") #################################### - model_instance = MyModel('steve', 'is awesome') + model_instance = MyModel("steve", "is awesome") model_instance.save() - body = conn.Object('mybucket', 'steve').get()['Body'].read().decode() + body = conn.Object("mybucket", "steve").get()["Body"].read().decode() - assert body == 'is awesome' + assert body == "is awesome" @mock_s3 def test_key_etag(): - conn = boto3.resource('s3', region_name='us-east-1') - conn.create_bucket(Bucket='mybucket') + conn = boto3.resource("s3", region_name="us-east-1") + conn.create_bucket(Bucket="mybucket") - model_instance = MyModel('steve', 'is awesome') + model_instance = MyModel("steve", "is awesome") model_instance.save() - conn.Bucket('mybucket').Object('steve').e_tag.should.equal( - '"d32bda93738f7e03adb22e66c90fbc04"') + conn.Bucket("mybucket").Object("steve").e_tag.should.equal( + '"d32bda93738f7e03adb22e66c90fbc04"' + ) @mock_s3_deprecated def test_multipart_upload_too_small(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") multipart = bucket.initiate_multipart_upload("the-key") - multipart.upload_part_from_file(BytesIO(b'hello'), 1) - multipart.upload_part_from_file(BytesIO(b'world'), 2) + multipart.upload_part_from_file(BytesIO(b"hello"), 1) + multipart.upload_part_from_file(BytesIO(b"world"), 2) # Multipart with total size under 5MB is refused multipart.complete_upload.should.throw(S3ResponseError) @@ -158,48 +158,45 @@ def test_multipart_upload_too_small(): @mock_s3_deprecated @reduced_min_part_size def test_multipart_upload(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 1) # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" multipart.upload_part_from_file(BytesIO(part2), 2) multipart.complete_upload() # we should get both parts as the key contents - bucket.get_key( - "the-key").get_contents_as_string().should.equal(part1 + part2) + bucket.get_key("the-key").get_contents_as_string().should.equal(part1 + part2) @mock_s3_deprecated @reduced_min_part_size def test_multipart_upload_out_of_order(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") multipart = bucket.initiate_multipart_upload("the-key") # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" multipart.upload_part_from_file(BytesIO(part2), 4) - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 2) multipart.complete_upload() # we should get both parts as the key contents - bucket.get_key( - "the-key").get_contents_as_string().should.equal(part1 + part2) + bucket.get_key("the-key").get_contents_as_string().should.equal(part1 + part2) @mock_s3_deprecated @reduced_min_part_size def test_multipart_upload_with_headers(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") - multipart = bucket.initiate_multipart_upload( - "the-key", metadata={"foo": "bar"}) - part1 = b'0' * 10 + multipart = bucket.initiate_multipart_upload("the-key", metadata={"foo": "bar"}) + part1 = b"0" * 10 multipart.upload_part_from_file(BytesIO(part1), 1) multipart.complete_upload() @@ -210,29 +207,28 @@ def test_multipart_upload_with_headers(): @mock_s3_deprecated @reduced_min_part_size def test_multipart_upload_with_copy_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "original-key" key.set_contents_from_string("key_value") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 1) multipart.copy_part_from_key("foobar", "original-key", 2, 0, 3) multipart.complete_upload() - bucket.get_key( - "the-key").get_contents_as_string().should.equal(part1 + b"key_") + bucket.get_key("the-key").get_contents_as_string().should.equal(part1 + b"key_") @mock_s3_deprecated @reduced_min_part_size def test_multipart_upload_cancel(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 1) multipart.cancel_upload() # TODO we really need some sort of assertion here, but we don't currently @@ -243,14 +239,14 @@ def test_multipart_upload_cancel(): @reduced_min_part_size def test_multipart_etag(): # Create Bucket so that test can run - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('mybucket') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("mybucket") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 1) # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" multipart.upload_part_from_file(BytesIO(part2), 2) multipart.complete_upload() # we should get both parts as the key contents @@ -261,43 +257,45 @@ def test_multipart_etag(): @reduced_min_part_size def test_multipart_invalid_order(): # Create Bucket so that test can run - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('mybucket') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("mybucket") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * 5242880 + part1 = b"0" * 5242880 etag1 = multipart.upload_part_from_file(BytesIO(part1), 1).etag # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" etag2 = multipart.upload_part_from_file(BytesIO(part2), 2).etag xml = "{0}{1}" xml = xml.format(2, etag2) + xml.format(1, etag1) xml = "{0}".format(xml) bucket.complete_multipart_upload.when.called_with( - multipart.key_name, multipart.id, xml).should.throw(S3ResponseError) + multipart.key_name, multipart.id, xml + ).should.throw(S3ResponseError) @mock_s3_deprecated @reduced_min_part_size def test_multipart_etag_quotes_stripped(): # Create Bucket so that test can run - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('mybucket') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("mybucket") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE etag1 = multipart.upload_part_from_file(BytesIO(part1), 1).etag # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" etag2 = multipart.upload_part_from_file(BytesIO(part2), 2).etag # Strip quotes from etags - etag1 = etag1.replace('"', '') - etag2 = etag2.replace('"', '') + etag1 = etag1.replace('"', "") + etag2 = etag2.replace('"', "") xml = "{0}{1}" xml = xml.format(1, etag1) + xml.format(2, etag2) xml = "{0}".format(xml) bucket.complete_multipart_upload.when.called_with( - multipart.key_name, multipart.id, xml).should_not.throw(S3ResponseError) + multipart.key_name, multipart.id, xml + ).should_not.throw(S3ResponseError) # we should get both parts as the key contents bucket.get_key("the-key").etag.should.equal(EXPECTED_ETAG) @@ -305,34 +303,34 @@ def test_multipart_etag_quotes_stripped(): @mock_s3_deprecated @reduced_min_part_size def test_multipart_duplicate_upload(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") multipart = bucket.initiate_multipart_upload("the-key") - part1 = b'0' * REDUCED_PART_SIZE + part1 = b"0" * REDUCED_PART_SIZE multipart.upload_part_from_file(BytesIO(part1), 1) # same part again multipart.upload_part_from_file(BytesIO(part1), 1) - part2 = b'1' * 1024 + part2 = b"1" * 1024 multipart.upload_part_from_file(BytesIO(part2), 2) multipart.complete_upload() # We should get only one copy of part 1. - bucket.get_key( - "the-key").get_contents_as_string().should.equal(part1 + part2) + bucket.get_key("the-key").get_contents_as_string().should.equal(part1 + part2) @mock_s3_deprecated def test_list_multiparts(): # Create Bucket so that test can run - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('mybucket') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("mybucket") multipart1 = bucket.initiate_multipart_upload("one-key") multipart2 = bucket.initiate_multipart_upload("two-key") uploads = bucket.get_all_multipart_uploads() uploads.should.have.length_of(2) dict([(u.key_name, u.id) for u in uploads]).should.equal( - {'one-key': multipart1.id, 'two-key': multipart2.id}) + {"one-key": multipart1.id, "two-key": multipart2.id} + ) multipart2.cancel_upload() uploads = bucket.get_all_multipart_uploads() uploads.should.have.length_of(1) @@ -344,34 +342,36 @@ def test_list_multiparts(): @mock_s3_deprecated def test_key_save_to_missing_bucket(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.get_bucket('mybucket', validate=False) + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.get_bucket("mybucket", validate=False) key = Key(bucket) key.key = "the-key" - key.set_contents_from_string.when.called_with( - "foobar").should.throw(S3ResponseError) + key.set_contents_from_string.when.called_with("foobar").should.throw( + S3ResponseError + ) @mock_s3_deprecated def test_missing_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") bucket.get_key("the-key").should.equal(None) @mock_s3_deprecated def test_missing_key_urllib2(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") conn.create_bucket("foobar") - urlopen.when.called_with( - "http://foobar.s3.amazonaws.com/the-key").should.throw(HTTPError) + urlopen.when.called_with("http://foobar.s3.amazonaws.com/the-key").should.throw( + HTTPError + ) @mock_s3_deprecated def test_empty_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" @@ -379,12 +379,12 @@ def test_empty_key(): key = bucket.get_key("the-key") key.size.should.equal(0) - key.get_contents_as_string().should.equal(b'') + key.get_contents_as_string().should.equal(b"") @mock_s3_deprecated def test_empty_key_set_on_existing_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" @@ -392,63 +392,55 @@ def test_empty_key_set_on_existing_key(): key = bucket.get_key("the-key") key.size.should.equal(6) - key.get_contents_as_string().should.equal(b'foobar') + key.get_contents_as_string().should.equal(b"foobar") key.set_contents_from_string("") - bucket.get_key("the-key").get_contents_as_string().should.equal(b'') + bucket.get_key("the-key").get_contents_as_string().should.equal(b"") @mock_s3_deprecated def test_large_key_save(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("foobar" * 100000) - bucket.get_key( - "the-key").get_contents_as_string().should.equal(b'foobar' * 100000) + bucket.get_key("the-key").get_contents_as_string().should.equal(b"foobar" * 100000) @mock_s3_deprecated def test_copy_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("some value") - bucket.copy_key('new-key', 'foobar', 'the-key') + bucket.copy_key("new-key", "foobar", "the-key") - bucket.get_key( - "the-key").get_contents_as_string().should.equal(b"some value") - bucket.get_key( - "new-key").get_contents_as_string().should.equal(b"some value") + bucket.get_key("the-key").get_contents_as_string().should.equal(b"some value") + bucket.get_key("new-key").get_contents_as_string().should.equal(b"some value") -@parameterized([ - ("the-unicode-💩-key",), - ("key-with?question-mark",), -]) +@parameterized([("the-unicode-💩-key",), ("key-with?question-mark",)]) @mock_s3_deprecated def test_copy_key_with_special_chars(key_name): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = key_name key.set_contents_from_string("some value") - bucket.copy_key('new-key', 'foobar', key_name) + bucket.copy_key("new-key", "foobar", key_name) - bucket.get_key( - key_name).get_contents_as_string().should.equal(b"some value") - bucket.get_key( - "new-key").get_contents_as_string().should.equal(b"some value") + bucket.get_key(key_name).get_contents_as_string().should.equal(b"some value") + bucket.get_key("new-key").get_contents_as_string().should.equal(b"some value") @mock_s3_deprecated def test_copy_key_with_version(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") bucket.configure_versioning(versioning=True) key = Key(bucket) @@ -456,46 +448,40 @@ def test_copy_key_with_version(): key.set_contents_from_string("some value") key.set_contents_from_string("another value") - key = [ - key.version_id - for key in bucket.get_all_versions() - if not key.is_latest - ][0] - bucket.copy_key('new-key', 'foobar', 'the-key', src_version_id=key) + key = [key.version_id for key in bucket.get_all_versions() if not key.is_latest][0] + bucket.copy_key("new-key", "foobar", "the-key", src_version_id=key) - bucket.get_key( - "the-key").get_contents_as_string().should.equal(b"another value") - bucket.get_key( - "new-key").get_contents_as_string().should.equal(b"some value") + bucket.get_key("the-key").get_contents_as_string().should.equal(b"another value") + bucket.get_key("new-key").get_contents_as_string().should.equal(b"some value") @mock_s3_deprecated def test_set_metadata(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) - key.key = 'the-key' - key.set_metadata('md', 'Metadatastring') + key.key = "the-key" + key.set_metadata("md", "Metadatastring") key.set_contents_from_string("Testval") - bucket.get_key('the-key').get_metadata('md').should.equal('Metadatastring') + bucket.get_key("the-key").get_metadata("md").should.equal("Metadatastring") @mock_s3_deprecated def test_copy_key_replace_metadata(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" - key.set_metadata('md', 'Metadatastring') + key.set_metadata("md", "Metadatastring") key.set_contents_from_string("some value") - bucket.copy_key('new-key', 'foobar', 'the-key', - metadata={'momd': 'Mometadatastring'}) + bucket.copy_key( + "new-key", "foobar", "the-key", metadata={"momd": "Mometadatastring"} + ) - bucket.get_key("new-key").get_metadata('md').should.be.none - bucket.get_key( - "new-key").get_metadata('momd').should.equal('Mometadatastring') + bucket.get_key("new-key").get_metadata("md").should.be.none + bucket.get_key("new-key").get_metadata("momd").should.equal("Mometadatastring") @freeze_time("2012-01-01 12:00:00") @@ -509,23 +495,23 @@ def test_last_modified(): key.set_contents_from_string("some value") rs = bucket.get_all_keys() - rs[0].last_modified.should.equal('2012-01-01T12:00:00.000Z') + rs[0].last_modified.should.equal("2012-01-01T12:00:00.000Z") - bucket.get_key( - "the-key").last_modified.should.equal('Sun, 01 Jan 2012 12:00:00 GMT') + bucket.get_key("the-key").last_modified.should.equal( + "Sun, 01 Jan 2012 12:00:00 GMT" + ) @mock_s3_deprecated def test_missing_bucket(): - conn = boto.connect_s3('the_key', 'the_secret') - conn.get_bucket.when.called_with('mybucket').should.throw(S3ResponseError) + conn = boto.connect_s3("the_key", "the_secret") + conn.get_bucket.when.called_with("mybucket").should.throw(S3ResponseError) @mock_s3_deprecated def test_bucket_with_dash(): - conn = boto.connect_s3('the_key', 'the_secret') - conn.get_bucket.when.called_with( - 'mybucket-test').should.throw(S3ResponseError) + conn = boto.connect_s3("the_key", "the_secret") + conn.get_bucket.when.called_with("mybucket-test").should.throw(S3ResponseError) @mock_s3_deprecated @@ -534,7 +520,7 @@ def test_create_existing_bucket(): conn = boto.s3.connect_to_region("us-west-2") conn.create_bucket("foobar") with assert_raises(S3CreateError): - conn.create_bucket('foobar') + conn.create_bucket("foobar") @mock_s3_deprecated @@ -556,15 +542,14 @@ def test_create_existing_bucket_in_us_east_1(): @mock_s3_deprecated def test_other_region(): - conn = S3Connection( - 'key', 'secret', host='s3-website-ap-southeast-2.amazonaws.com') + conn = S3Connection("key", "secret", host="s3-website-ap-southeast-2.amazonaws.com") conn.create_bucket("foobar") list(conn.get_bucket("foobar").get_all_keys()).should.equal([]) @mock_s3_deprecated def test_bucket_deletion(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) @@ -586,7 +571,7 @@ def test_bucket_deletion(): @mock_s3_deprecated def test_get_all_buckets(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") conn.create_bucket("foobar") conn.create_bucket("foobar2") buckets = conn.get_all_buckets() @@ -597,36 +582,34 @@ def test_get_all_buckets(): @mock_s3 @mock_s3_deprecated def test_post_to_bucket(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") - requests.post("https://foobar.s3.amazonaws.com/", { - 'key': 'the-key', - 'file': 'nothing' - }) + requests.post( + "https://foobar.s3.amazonaws.com/", {"key": "the-key", "file": "nothing"} + ) - bucket.get_key('the-key').get_contents_as_string().should.equal(b'nothing') + bucket.get_key("the-key").get_contents_as_string().should.equal(b"nothing") @mock_s3 @mock_s3_deprecated def test_post_with_metadata_to_bucket(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") - requests.post("https://foobar.s3.amazonaws.com/", { - 'key': 'the-key', - 'file': 'nothing', - 'x-amz-meta-test': 'metadata' - }) + requests.post( + "https://foobar.s3.amazonaws.com/", + {"key": "the-key", "file": "nothing", "x-amz-meta-test": "metadata"}, + ) - bucket.get_key('the-key').get_metadata('test').should.equal('metadata') + bucket.get_key("the-key").get_metadata("test").should.equal("metadata") @mock_s3_deprecated def test_delete_missing_key(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") deleted_key = bucket.delete_key("foobar") deleted_key.key.should.equal("foobar") @@ -634,40 +617,40 @@ def test_delete_missing_key(): @mock_s3_deprecated def test_delete_keys(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") - Key(bucket=bucket, name='file1').set_contents_from_string('abc') - Key(bucket=bucket, name='file2').set_contents_from_string('abc') - Key(bucket=bucket, name='file3').set_contents_from_string('abc') - Key(bucket=bucket, name='file4').set_contents_from_string('abc') + Key(bucket=bucket, name="file1").set_contents_from_string("abc") + Key(bucket=bucket, name="file2").set_contents_from_string("abc") + Key(bucket=bucket, name="file3").set_contents_from_string("abc") + Key(bucket=bucket, name="file4").set_contents_from_string("abc") - result = bucket.delete_keys(['file2', 'file3']) + result = bucket.delete_keys(["file2", "file3"]) result.deleted.should.have.length_of(2) result.errors.should.have.length_of(0) keys = bucket.get_all_keys() keys.should.have.length_of(2) - keys[0].name.should.equal('file1') + keys[0].name.should.equal("file1") @mock_s3_deprecated def test_delete_keys_invalid(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") - Key(bucket=bucket, name='file1').set_contents_from_string('abc') - Key(bucket=bucket, name='file2').set_contents_from_string('abc') - Key(bucket=bucket, name='file3').set_contents_from_string('abc') - Key(bucket=bucket, name='file4').set_contents_from_string('abc') + Key(bucket=bucket, name="file1").set_contents_from_string("abc") + Key(bucket=bucket, name="file2").set_contents_from_string("abc") + Key(bucket=bucket, name="file3").set_contents_from_string("abc") + Key(bucket=bucket, name="file4").set_contents_from_string("abc") # non-existing key case - result = bucket.delete_keys(['abc', 'file3']) + result = bucket.delete_keys(["abc", "file3"]) result.deleted.should.have.length_of(1) result.errors.should.have.length_of(1) keys = bucket.get_all_keys() keys.should.have.length_of(3) - keys[0].name.should.equal('file1') + keys[0].name.should.equal("file1") # empty keys result = bucket.delete_keys([]) @@ -679,133 +662,137 @@ def test_delete_keys_invalid(): @mock_s3 def test_boto3_delete_empty_keys_list(): with assert_raises(ClientError) as err: - boto3.client('s3').delete_objects(Bucket='foobar', Delete={'Objects': []}) + boto3.client("s3").delete_objects(Bucket="foobar", Delete={"Objects": []}) assert err.exception.response["Error"]["Code"] == "MalformedXML" @mock_s3_deprecated def test_bucket_name_with_dot(): conn = boto.connect_s3() - bucket = conn.create_bucket('firstname.lastname') + bucket = conn.create_bucket("firstname.lastname") - k = Key(bucket, 'somekey') - k.set_contents_from_string('somedata') + k = Key(bucket, "somekey") + k.set_contents_from_string("somedata") @mock_s3_deprecated def test_key_with_special_characters(): conn = boto.connect_s3() - bucket = conn.create_bucket('test_bucket_name') + bucket = conn.create_bucket("test_bucket_name") - key = Key(bucket, 'test_list_keys_2/x?y') - key.set_contents_from_string('value1') + key = Key(bucket, "test_list_keys_2/x?y") + key.set_contents_from_string("value1") - key_list = bucket.list('test_list_keys_2/', '/') + key_list = bucket.list("test_list_keys_2/", "/") keys = [x for x in key_list] keys[0].name.should.equal("test_list_keys_2/x?y") @mock_s3_deprecated def test_unicode_key_with_slash(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "/the-key-unîcode/test" key.set_contents_from_string("value") key = bucket.get_key("/the-key-unîcode/test") - key.get_contents_as_string().should.equal(b'value') + key.get_contents_as_string().should.equal(b"value") @mock_s3_deprecated def test_bucket_key_listing_order(): conn = boto.connect_s3() - bucket = conn.create_bucket('test_bucket') - prefix = 'toplevel/' + bucket = conn.create_bucket("test_bucket") + prefix = "toplevel/" def store(name): k = Key(bucket, prefix + name) - k.set_contents_from_string('somedata') + k.set_contents_from_string("somedata") - names = ['x/key', 'y.key1', 'y.key2', 'y.key3', 'x/y/key', 'x/y/z/key'] + names = ["x/key", "y.key1", "y.key2", "y.key3", "x/y/key", "x/y/z/key"] for name in names: store(name) delimiter = None keys = [x.name for x in bucket.list(prefix, delimiter)] - keys.should.equal([ - 'toplevel/x/key', 'toplevel/x/y/key', 'toplevel/x/y/z/key', - 'toplevel/y.key1', 'toplevel/y.key2', 'toplevel/y.key3' - ]) + keys.should.equal( + [ + "toplevel/x/key", + "toplevel/x/y/key", + "toplevel/x/y/z/key", + "toplevel/y.key1", + "toplevel/y.key2", + "toplevel/y.key3", + ] + ) - delimiter = '/' + delimiter = "/" keys = [x.name for x in bucket.list(prefix, delimiter)] - keys.should.equal([ - 'toplevel/y.key1', 'toplevel/y.key2', 'toplevel/y.key3', 'toplevel/x/' - ]) + keys.should.equal( + ["toplevel/y.key1", "toplevel/y.key2", "toplevel/y.key3", "toplevel/x/"] + ) # Test delimiter with no prefix - delimiter = '/' + delimiter = "/" keys = [x.name for x in bucket.list(prefix=None, delimiter=delimiter)] - keys.should.equal(['toplevel/']) + keys.should.equal(["toplevel/"]) delimiter = None - keys = [x.name for x in bucket.list(prefix + 'x', delimiter)] - keys.should.equal( - [u'toplevel/x/key', u'toplevel/x/y/key', u'toplevel/x/y/z/key']) + keys = [x.name for x in bucket.list(prefix + "x", delimiter)] + keys.should.equal(["toplevel/x/key", "toplevel/x/y/key", "toplevel/x/y/z/key"]) - delimiter = '/' - keys = [x.name for x in bucket.list(prefix + 'x', delimiter)] - keys.should.equal([u'toplevel/x/']) + delimiter = "/" + keys = [x.name for x in bucket.list(prefix + "x", delimiter)] + keys.should.equal(["toplevel/x/"]) @mock_s3_deprecated def test_key_with_reduced_redundancy(): conn = boto.connect_s3() - bucket = conn.create_bucket('test_bucket_name') + bucket = conn.create_bucket("test_bucket_name") - key = Key(bucket, 'test_rr_key') - key.set_contents_from_string('value1', reduced_redundancy=True) + key = Key(bucket, "test_rr_key") + key.set_contents_from_string("value1", reduced_redundancy=True) # we use the bucket iterator because of: # https:/github.com/boto/boto/issues/1173 - list(bucket)[0].storage_class.should.equal('REDUCED_REDUNDANCY') + list(bucket)[0].storage_class.should.equal("REDUCED_REDUNDANCY") @mock_s3_deprecated def test_copy_key_reduced_redundancy(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("some value") - bucket.copy_key('new-key', 'foobar', 'the-key', - storage_class='REDUCED_REDUNDANCY') + bucket.copy_key("new-key", "foobar", "the-key", storage_class="REDUCED_REDUNDANCY") # we use the bucket iterator because of: # https:/github.com/boto/boto/issues/1173 keys = dict([(k.name, k) for k in bucket]) - keys['new-key'].storage_class.should.equal("REDUCED_REDUNDANCY") - keys['the-key'].storage_class.should.equal("STANDARD") + keys["new-key"].storage_class.should.equal("REDUCED_REDUNDANCY") + keys["the-key"].storage_class.should.equal("STANDARD") @freeze_time("2012-01-01 12:00:00") @mock_s3_deprecated def test_restore_key(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("some value") list(bucket)[0].ongoing_restore.should.be.none key.restore(1) - key = bucket.get_key('the-key') + key = bucket.get_key("the-key") key.ongoing_restore.should_not.be.none key.ongoing_restore.should.be.false key.expiry_date.should.equal("Mon, 02 Jan 2012 12:00:00 GMT") key.restore(2) - key = bucket.get_key('the-key') + key = bucket.get_key("the-key") key.ongoing_restore.should_not.be.none key.ongoing_restore.should.be.false key.expiry_date.should.equal("Tue, 03 Jan 2012 12:00:00 GMT") @@ -814,13 +801,13 @@ def test_restore_key(): @freeze_time("2012-01-01 12:00:00") @mock_s3_deprecated def test_restore_key_headers(): - conn = boto.connect_s3('the_key', 'the_secret') + conn = boto.connect_s3("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("some value") - key.restore(1, headers={'foo': 'bar'}) - key = bucket.get_key('the-key') + key.restore(1, headers={"foo": "bar"}) + key = bucket.get_key("the-key") key.ongoing_restore.should_not.be.none key.ongoing_restore.should.be.false key.expiry_date.should.equal("Mon, 02 Jan 2012 12:00:00 GMT") @@ -828,51 +815,51 @@ def test_restore_key_headers(): @mock_s3_deprecated def test_get_versioning_status(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") d = bucket.get_versioning_status() d.should.be.empty bucket.configure_versioning(versioning=True) d = bucket.get_versioning_status() d.shouldnt.be.empty - d.should.have.key('Versioning').being.equal('Enabled') + d.should.have.key("Versioning").being.equal("Enabled") bucket.configure_versioning(versioning=False) d = bucket.get_versioning_status() - d.should.have.key('Versioning').being.equal('Suspended') + d.should.have.key("Versioning").being.equal("Suspended") @mock_s3_deprecated def test_key_version(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") bucket.configure_versioning(versioning=True) versions = [] key = Key(bucket) - key.key = 'the-key' + key.key = "the-key" key.version_id.should.be.none - key.set_contents_from_string('some string') + key.set_contents_from_string("some string") versions.append(key.version_id) - key.set_contents_from_string('some string') + key.set_contents_from_string("some string") versions.append(key.version_id) set(versions).should.have.length_of(2) - key = bucket.get_key('the-key') + key = bucket.get_key("the-key") key.version_id.should.equal(versions[-1]) @mock_s3_deprecated def test_list_versions(): - conn = boto.connect_s3('the_key', 'the_secret') - bucket = conn.create_bucket('foobar') + conn = boto.connect_s3("the_key", "the_secret") + bucket = conn.create_bucket("foobar") bucket.configure_versioning(versioning=True) key_versions = [] - key = Key(bucket, 'the-key') + key = Key(bucket, "the-key") key.version_id.should.be.none key.set_contents_from_string("Version 1") key_versions.append(key.version_id) @@ -883,32 +870,32 @@ def test_list_versions(): versions = list(bucket.list_versions()) versions.should.have.length_of(2) - versions[0].name.should.equal('the-key') + versions[0].name.should.equal("the-key") versions[0].version_id.should.equal(key_versions[0]) versions[0].get_contents_as_string().should.equal(b"Version 1") - versions[1].name.should.equal('the-key') + versions[1].name.should.equal("the-key") versions[1].version_id.should.equal(key_versions[1]) versions[1].get_contents_as_string().should.equal(b"Version 2") - key = Key(bucket, 'the2-key') + key = Key(bucket, "the2-key") key.set_contents_from_string("Version 1") keys = list(bucket.list()) keys.should.have.length_of(2) - versions = list(bucket.list_versions(prefix='the2-')) + versions = list(bucket.list_versions(prefix="the2-")) versions.should.have.length_of(1) @mock_s3_deprecated def test_acl_setting(): conn = boto.connect_s3() - bucket = conn.create_bucket('foobar') - content = b'imafile' - keyname = 'test.txt' + bucket = conn.create_bucket("foobar") + content = b"imafile" + keyname = "test.txt" key = Key(bucket, name=keyname) - key.content_type = 'text/plain' + key.content_type = "text/plain" key.set_contents_from_string(content) key.make_public() @@ -917,147 +904,175 @@ def test_acl_setting(): assert key.get_contents_as_string() == content grants = key.get_acl().acl.grants - assert any(g.uri == 'http://acs.amazonaws.com/groups/global/AllUsers' and - g.permission == 'READ' for g in grants), grants + assert any( + g.uri == "http://acs.amazonaws.com/groups/global/AllUsers" + and g.permission == "READ" + for g in grants + ), grants @mock_s3_deprecated def test_acl_setting_via_headers(): conn = boto.connect_s3() - bucket = conn.create_bucket('foobar') - content = b'imafile' - keyname = 'test.txt' + bucket = conn.create_bucket("foobar") + content = b"imafile" + keyname = "test.txt" key = Key(bucket, name=keyname) - key.content_type = 'text/plain' - key.set_contents_from_string(content, headers={ - 'x-amz-grant-full-control': 'uri="http://acs.amazonaws.com/groups/global/AllUsers"' - }) + key.content_type = "text/plain" + key.set_contents_from_string( + content, + headers={ + "x-amz-grant-full-control": 'uri="http://acs.amazonaws.com/groups/global/AllUsers"' + }, + ) key = bucket.get_key(keyname) assert key.get_contents_as_string() == content grants = key.get_acl().acl.grants - assert any(g.uri == 'http://acs.amazonaws.com/groups/global/AllUsers' and - g.permission == 'FULL_CONTROL' for g in grants), grants + assert any( + g.uri == "http://acs.amazonaws.com/groups/global/AllUsers" + and g.permission == "FULL_CONTROL" + for g in grants + ), grants @mock_s3_deprecated def test_acl_switching(): conn = boto.connect_s3() - bucket = conn.create_bucket('foobar') - content = b'imafile' - keyname = 'test.txt' + bucket = conn.create_bucket("foobar") + content = b"imafile" + keyname = "test.txt" key = Key(bucket, name=keyname) - key.content_type = 'text/plain' - key.set_contents_from_string(content, policy='public-read') - key.set_acl('private') + key.content_type = "text/plain" + key.set_contents_from_string(content, policy="public-read") + key.set_acl("private") grants = key.get_acl().acl.grants - assert not any(g.uri == 'http://acs.amazonaws.com/groups/global/AllUsers' and - g.permission == 'READ' for g in grants), grants + assert not any( + g.uri == "http://acs.amazonaws.com/groups/global/AllUsers" + and g.permission == "READ" + for g in grants + ), grants @mock_s3_deprecated def test_bucket_acl_setting(): conn = boto.connect_s3() - bucket = conn.create_bucket('foobar') + bucket = conn.create_bucket("foobar") bucket.make_public() grants = bucket.get_acl().acl.grants - assert any(g.uri == 'http://acs.amazonaws.com/groups/global/AllUsers' and - g.permission == 'READ' for g in grants), grants + assert any( + g.uri == "http://acs.amazonaws.com/groups/global/AllUsers" + and g.permission == "READ" + for g in grants + ), grants @mock_s3_deprecated def test_bucket_acl_switching(): conn = boto.connect_s3() - bucket = conn.create_bucket('foobar') + bucket = conn.create_bucket("foobar") bucket.make_public() - bucket.set_acl('private') + bucket.set_acl("private") grants = bucket.get_acl().acl.grants - assert not any(g.uri == 'http://acs.amazonaws.com/groups/global/AllUsers' and - g.permission == 'READ' for g in grants), grants + assert not any( + g.uri == "http://acs.amazonaws.com/groups/global/AllUsers" + and g.permission == "READ" + for g in grants + ), grants @mock_s3 def test_s3_object_in_public_bucket(): - s3 = boto3.resource('s3') - bucket = s3.Bucket('test-bucket') - bucket.create(ACL='public-read') - bucket.put_object(Body=b'ABCD', Key='file.txt') + s3 = boto3.resource("s3") + bucket = s3.Bucket("test-bucket") + bucket.create(ACL="public-read") + bucket.put_object(Body=b"ABCD", Key="file.txt") - s3_anonymous = boto3.resource('s3') - s3_anonymous.meta.client.meta.events.register('choose-signer.s3.*', disable_signing) + s3_anonymous = boto3.resource("s3") + s3_anonymous.meta.client.meta.events.register("choose-signer.s3.*", disable_signing) - contents = s3_anonymous.Object(key='file.txt', bucket_name='test-bucket').get()['Body'].read() - contents.should.equal(b'ABCD') + contents = ( + s3_anonymous.Object(key="file.txt", bucket_name="test-bucket") + .get()["Body"] + .read() + ) + contents.should.equal(b"ABCD") - bucket.put_object(ACL='private', Body=b'ABCD', Key='file.txt') + bucket.put_object(ACL="private", Body=b"ABCD", Key="file.txt") with assert_raises(ClientError) as exc: - s3_anonymous.Object(key='file.txt', bucket_name='test-bucket').get() - exc.exception.response['Error']['Code'].should.equal('403') + s3_anonymous.Object(key="file.txt", bucket_name="test-bucket").get() + exc.exception.response["Error"]["Code"].should.equal("403") - params = {'Bucket': 'test-bucket', 'Key': 'file.txt'} - presigned_url = boto3.client('s3').generate_presigned_url('get_object', params, ExpiresIn=900) + params = {"Bucket": "test-bucket", "Key": "file.txt"} + presigned_url = boto3.client("s3").generate_presigned_url( + "get_object", params, ExpiresIn=900 + ) response = requests.get(presigned_url) assert response.status_code == 200 @mock_s3 def test_s3_object_in_private_bucket(): - s3 = boto3.resource('s3') - bucket = s3.Bucket('test-bucket') - bucket.create(ACL='private') - bucket.put_object(ACL='private', Body=b'ABCD', Key='file.txt') + s3 = boto3.resource("s3") + bucket = s3.Bucket("test-bucket") + bucket.create(ACL="private") + bucket.put_object(ACL="private", Body=b"ABCD", Key="file.txt") - s3_anonymous = boto3.resource('s3') - s3_anonymous.meta.client.meta.events.register('choose-signer.s3.*', disable_signing) + s3_anonymous = boto3.resource("s3") + s3_anonymous.meta.client.meta.events.register("choose-signer.s3.*", disable_signing) with assert_raises(ClientError) as exc: - s3_anonymous.Object(key='file.txt', bucket_name='test-bucket').get() - exc.exception.response['Error']['Code'].should.equal('403') + s3_anonymous.Object(key="file.txt", bucket_name="test-bucket").get() + exc.exception.response["Error"]["Code"].should.equal("403") - bucket.put_object(ACL='public-read', Body=b'ABCD', Key='file.txt') - contents = s3_anonymous.Object(key='file.txt', bucket_name='test-bucket').get()['Body'].read() - contents.should.equal(b'ABCD') + bucket.put_object(ACL="public-read", Body=b"ABCD", Key="file.txt") + contents = ( + s3_anonymous.Object(key="file.txt", bucket_name="test-bucket") + .get()["Body"] + .read() + ) + contents.should.equal(b"ABCD") @mock_s3_deprecated def test_unicode_key(): conn = boto.connect_s3() - bucket = conn.create_bucket('mybucket') + bucket = conn.create_bucket("mybucket") key = Key(bucket) - key.key = u'こんにちは.jpg' - key.set_contents_from_string('Hello world!') + key.key = "こんにちは.jpg" + key.set_contents_from_string("Hello world!") assert [listed_key.key for listed_key in bucket.list()] == [key.key] fetched_key = bucket.get_key(key.key) assert fetched_key.key == key.key - assert fetched_key.get_contents_as_string().decode("utf-8") == 'Hello world!' + assert fetched_key.get_contents_as_string().decode("utf-8") == "Hello world!" @mock_s3_deprecated def test_unicode_value(): conn = boto.connect_s3() - bucket = conn.create_bucket('mybucket') + bucket = conn.create_bucket("mybucket") key = Key(bucket) - key.key = 'some_key' - key.set_contents_from_string(u'こんにちは.jpg') + key.key = "some_key" + key.set_contents_from_string("こんにちは.jpg") list(bucket.list()) key = bucket.get_key(key.key) - assert key.get_contents_as_string().decode("utf-8") == u'こんにちは.jpg' + assert key.get_contents_as_string().decode("utf-8") == "こんにちは.jpg" @mock_s3_deprecated def test_setting_content_encoding(): conn = boto.connect_s3() - bucket = conn.create_bucket('mybucket') + bucket = conn.create_bucket("mybucket") key = bucket.new_key("keyname") key.set_metadata("Content-Encoding", "gzip") compressed_data = "abcdef" @@ -1070,77 +1085,57 @@ def test_setting_content_encoding(): @mock_s3_deprecated def test_bucket_location(): conn = boto.s3.connect_to_region("us-west-2") - bucket = conn.create_bucket('mybucket') + bucket = conn.create_bucket("mybucket") bucket.get_location().should.equal("us-west-2") @mock_s3 def test_bucket_location_us_east_1(): - cli = boto3.client('s3') - bucket_name = 'mybucket' + cli = boto3.client("s3") + bucket_name = "mybucket" # No LocationConstraint ==> us-east-1 cli.create_bucket(Bucket=bucket_name) - cli.get_bucket_location(Bucket=bucket_name)['LocationConstraint'].should.equal(None) + cli.get_bucket_location(Bucket=bucket_name)["LocationConstraint"].should.equal(None) @mock_s3_deprecated def test_ranged_get(): conn = boto.connect_s3() - bucket = conn.create_bucket('mybucket') + bucket = conn.create_bucket("mybucket") key = Key(bucket) - key.key = 'bigkey' + key.key = "bigkey" rep = b"0123456789" key.set_contents_from_string(rep * 10) # Implicitly bounded range requests. - key.get_contents_as_string( - headers={'Range': 'bytes=0-'}).should.equal(rep * 10) - key.get_contents_as_string( - headers={'Range': 'bytes=50-'}).should.equal(rep * 5) - key.get_contents_as_string( - headers={'Range': 'bytes=99-'}).should.equal(b'9') + key.get_contents_as_string(headers={"Range": "bytes=0-"}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=50-"}).should.equal(rep * 5) + key.get_contents_as_string(headers={"Range": "bytes=99-"}).should.equal(b"9") # Explicitly bounded range requests starting from the first byte. - key.get_contents_as_string( - headers={'Range': 'bytes=0-0'}).should.equal(b'0') - key.get_contents_as_string( - headers={'Range': 'bytes=0-49'}).should.equal(rep * 5) - key.get_contents_as_string( - headers={'Range': 'bytes=0-99'}).should.equal(rep * 10) - key.get_contents_as_string( - headers={'Range': 'bytes=0-100'}).should.equal(rep * 10) - key.get_contents_as_string( - headers={'Range': 'bytes=0-700'}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=0-0"}).should.equal(b"0") + key.get_contents_as_string(headers={"Range": "bytes=0-49"}).should.equal(rep * 5) + key.get_contents_as_string(headers={"Range": "bytes=0-99"}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=0-100"}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=0-700"}).should.equal(rep * 10) # Explicitly bounded range requests starting from the / a middle byte. - key.get_contents_as_string( - headers={'Range': 'bytes=50-54'}).should.equal(rep[:5]) - key.get_contents_as_string( - headers={'Range': 'bytes=50-99'}).should.equal(rep * 5) - key.get_contents_as_string( - headers={'Range': 'bytes=50-100'}).should.equal(rep * 5) - key.get_contents_as_string( - headers={'Range': 'bytes=50-700'}).should.equal(rep * 5) + key.get_contents_as_string(headers={"Range": "bytes=50-54"}).should.equal(rep[:5]) + key.get_contents_as_string(headers={"Range": "bytes=50-99"}).should.equal(rep * 5) + key.get_contents_as_string(headers={"Range": "bytes=50-100"}).should.equal(rep * 5) + key.get_contents_as_string(headers={"Range": "bytes=50-700"}).should.equal(rep * 5) # Explicitly bounded range requests starting from the last byte. - key.get_contents_as_string( - headers={'Range': 'bytes=99-99'}).should.equal(b'9') - key.get_contents_as_string( - headers={'Range': 'bytes=99-100'}).should.equal(b'9') - key.get_contents_as_string( - headers={'Range': 'bytes=99-700'}).should.equal(b'9') + key.get_contents_as_string(headers={"Range": "bytes=99-99"}).should.equal(b"9") + key.get_contents_as_string(headers={"Range": "bytes=99-100"}).should.equal(b"9") + key.get_contents_as_string(headers={"Range": "bytes=99-700"}).should.equal(b"9") # Suffix range requests. - key.get_contents_as_string( - headers={'Range': 'bytes=-1'}).should.equal(b'9') - key.get_contents_as_string( - headers={'Range': 'bytes=-60'}).should.equal(rep * 6) - key.get_contents_as_string( - headers={'Range': 'bytes=-100'}).should.equal(rep * 10) - key.get_contents_as_string( - headers={'Range': 'bytes=-101'}).should.equal(rep * 10) - key.get_contents_as_string( - headers={'Range': 'bytes=-700'}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=-1"}).should.equal(b"9") + key.get_contents_as_string(headers={"Range": "bytes=-60"}).should.equal(rep * 6) + key.get_contents_as_string(headers={"Range": "bytes=-100"}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=-101"}).should.equal(rep * 10) + key.get_contents_as_string(headers={"Range": "bytes=-700"}).should.equal(rep * 10) key.size.should.equal(100) @@ -1148,36 +1143,40 @@ def test_ranged_get(): @mock_s3_deprecated def test_policy(): conn = boto.connect_s3() - bucket_name = 'mybucket' + bucket_name = "mybucket" bucket = conn.create_bucket(bucket_name) - policy = json.dumps({ - "Version": "2012-10-17", - "Id": "PutObjPolicy", - "Statement": [ - { - "Sid": "DenyUnEncryptedObjectUploads", - "Effect": "Deny", - "Principal": "*", - "Action": "s3:PutObject", - "Resource": "arn:aws:s3:::{bucket_name}/*".format(bucket_name=bucket_name), - "Condition": { - "StringNotEquals": { - "s3:x-amz-server-side-encryption": "aws:kms" - } + policy = json.dumps( + { + "Version": "2012-10-17", + "Id": "PutObjPolicy", + "Statement": [ + { + "Sid": "DenyUnEncryptedObjectUploads", + "Effect": "Deny", + "Principal": "*", + "Action": "s3:PutObject", + "Resource": "arn:aws:s3:::{bucket_name}/*".format( + bucket_name=bucket_name + ), + "Condition": { + "StringNotEquals": { + "s3:x-amz-server-side-encryption": "aws:kms" + } + }, } - } - ] - }) + ], + } + ) with assert_raises(S3ResponseError) as err: bucket.get_policy() ex = err.exception ex.box_usage.should.be.none - ex.error_code.should.equal('NoSuchBucketPolicy') - ex.message.should.equal('The bucket policy does not exist') - ex.reason.should.equal('Not Found') + ex.error_code.should.equal("NoSuchBucketPolicy") + ex.message.should.equal("The bucket policy does not exist") + ex.reason.should.equal("Not Found") ex.resource.should.be.none ex.status.should.equal(404) ex.body.should.contain(bucket_name) @@ -1187,7 +1186,7 @@ def test_policy(): bucket = conn.get_bucket(bucket_name) - bucket.get_policy().decode('utf-8').should.equal(policy) + bucket.get_policy().decode("utf-8").should.equal(policy) bucket.delete_policy() @@ -1198,7 +1197,7 @@ def test_policy(): @mock_s3_deprecated def test_website_configuration_xml(): conn = boto.connect_s3() - bucket = conn.create_bucket('test-bucket') + bucket = conn.create_bucket("test-bucket") bucket.set_website_configuration_xml(TEST_XML) bucket.get_website_configuration_xml().should.equal(TEST_XML) @@ -1206,124 +1205,129 @@ def test_website_configuration_xml(): @mock_s3_deprecated def test_key_with_trailing_slash_in_ordinary_calling_format(): conn = boto.connect_s3( - 'access_key', - 'secret_key', - calling_format=boto.s3.connection.OrdinaryCallingFormat() + "access_key", + "secret_key", + calling_format=boto.s3.connection.OrdinaryCallingFormat(), ) - bucket = conn.create_bucket('test_bucket_name') + bucket = conn.create_bucket("test_bucket_name") - key_name = 'key_with_slash/' + key_name = "key_with_slash/" key = Key(bucket, key_name) - key.set_contents_from_string('some value') + key.set_contents_from_string("some value") [k.name for k in bucket.get_all_keys()].should.contain(key_name) @mock_s3 def test_boto3_key_etag(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - s3.put_object(Bucket='mybucket', Key='steve', Body=b'is awesome') - resp = s3.get_object(Bucket='mybucket', Key='steve') - resp['ETag'].should.equal('"d32bda93738f7e03adb22e66c90fbc04"') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="steve", Body=b"is awesome") + resp = s3.get_object(Bucket="mybucket", Key="steve") + resp["ETag"].should.equal('"d32bda93738f7e03adb22e66c90fbc04"') @mock_s3 def test_website_redirect_location(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") - s3.put_object(Bucket='mybucket', Key='steve', Body=b'is awesome') - resp = s3.get_object(Bucket='mybucket', Key='steve') - resp.get('WebsiteRedirectLocation').should.be.none + s3.put_object(Bucket="mybucket", Key="steve", Body=b"is awesome") + resp = s3.get_object(Bucket="mybucket", Key="steve") + resp.get("WebsiteRedirectLocation").should.be.none - url = 'https://github.com/spulec/moto' - s3.put_object(Bucket='mybucket', Key='steve', Body=b'is awesome', WebsiteRedirectLocation=url) - resp = s3.get_object(Bucket='mybucket', Key='steve') - resp['WebsiteRedirectLocation'].should.equal(url) + url = "https://github.com/spulec/moto" + s3.put_object( + Bucket="mybucket", Key="steve", Body=b"is awesome", WebsiteRedirectLocation=url + ) + resp = s3.get_object(Bucket="mybucket", Key="steve") + resp["WebsiteRedirectLocation"].should.equal(url) @mock_s3 def test_boto3_list_objects_truncated_response(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - s3.put_object(Bucket='mybucket', Key='one', Body=b'1') - s3.put_object(Bucket='mybucket', Key='two', Body=b'22') - s3.put_object(Bucket='mybucket', Key='three', Body=b'333') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="one", Body=b"1") + s3.put_object(Bucket="mybucket", Key="two", Body=b"22") + s3.put_object(Bucket="mybucket", Key="three", Body=b"333") # First list - resp = s3.list_objects(Bucket='mybucket', MaxKeys=1) - listed_object = resp['Contents'][0] + resp = s3.list_objects(Bucket="mybucket", MaxKeys=1) + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'one' - assert resp['MaxKeys'] == 1 - assert resp['IsTruncated'] == True - assert resp['Prefix'] == 'None' - assert resp['Delimiter'] == 'None' - assert 'NextMarker' in resp + assert listed_object["Key"] == "one" + assert resp["MaxKeys"] == 1 + assert resp["IsTruncated"] == True + assert resp["Prefix"] == "None" + assert resp["Delimiter"] == "None" + assert "NextMarker" in resp next_marker = resp["NextMarker"] # Second list - resp = s3.list_objects( - Bucket='mybucket', MaxKeys=1, Marker=next_marker) - listed_object = resp['Contents'][0] + resp = s3.list_objects(Bucket="mybucket", MaxKeys=1, Marker=next_marker) + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'three' - assert resp['MaxKeys'] == 1 - assert resp['IsTruncated'] == True - assert resp['Prefix'] == 'None' - assert resp['Delimiter'] == 'None' - assert 'NextMarker' in resp + assert listed_object["Key"] == "three" + assert resp["MaxKeys"] == 1 + assert resp["IsTruncated"] == True + assert resp["Prefix"] == "None" + assert resp["Delimiter"] == "None" + assert "NextMarker" in resp next_marker = resp["NextMarker"] # Third list - resp = s3.list_objects( - Bucket='mybucket', MaxKeys=1, Marker=next_marker) - listed_object = resp['Contents'][0] + resp = s3.list_objects(Bucket="mybucket", MaxKeys=1, Marker=next_marker) + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'two' - assert resp['MaxKeys'] == 1 - assert resp['IsTruncated'] == False - assert resp['Prefix'] == 'None' - assert resp['Delimiter'] == 'None' - assert 'NextMarker' not in resp + assert listed_object["Key"] == "two" + assert resp["MaxKeys"] == 1 + assert resp["IsTruncated"] == False + assert resp["Prefix"] == "None" + assert resp["Delimiter"] == "None" + assert "NextMarker" not in resp @mock_s3 def test_boto3_list_keys_xml_escaped(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - key_name = 'Q&A.txt' - s3.put_object(Bucket='mybucket', Key=key_name, Body=b'is awesome') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + key_name = "Q&A.txt" + s3.put_object(Bucket="mybucket", Key=key_name, Body=b"is awesome") - resp = s3.list_objects_v2(Bucket='mybucket', Prefix=key_name) + resp = s3.list_objects_v2(Bucket="mybucket", Prefix=key_name) - assert resp['Contents'][0]['Key'] == key_name - assert resp['KeyCount'] == 1 - assert resp['MaxKeys'] == 1000 - assert resp['Prefix'] == key_name - assert resp['IsTruncated'] == False - assert 'Delimiter' not in resp - assert 'StartAfter' not in resp - assert 'NextContinuationToken' not in resp - assert 'Owner' not in resp['Contents'][0] + assert resp["Contents"][0]["Key"] == key_name + assert resp["KeyCount"] == 1 + assert resp["MaxKeys"] == 1000 + assert resp["Prefix"] == key_name + assert resp["IsTruncated"] == False + assert "Delimiter" not in resp + assert "StartAfter" not in resp + assert "NextContinuationToken" not in resp + assert "Owner" not in resp["Contents"][0] @mock_s3 def test_boto3_list_objects_v2_common_prefix_pagination(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") max_keys = 1 - keys = ['test/{i}/{i}'.format(i=i) for i in range(3)] + keys = ["test/{i}/{i}".format(i=i) for i in range(3)] for key in keys: - s3.put_object(Bucket='mybucket', Key=key, Body=b'v') + s3.put_object(Bucket="mybucket", Key=key, Body=b"v") prefixes = [] - args = {"Bucket": 'mybucket', "Delimiter": "/", "Prefix": "test/", "MaxKeys": max_keys} + args = { + "Bucket": "mybucket", + "Delimiter": "/", + "Prefix": "test/", + "MaxKeys": max_keys, + } resp = {"IsTruncated": True} while resp.get("IsTruncated", False): if "NextContinuationToken" in resp: @@ -1333,214 +1337,220 @@ def test_boto3_list_objects_v2_common_prefix_pagination(): assert len(resp["CommonPrefixes"]) == max_keys prefixes.extend(i["Prefix"] for i in resp["CommonPrefixes"]) - assert prefixes == [k[:k.rindex('/') + 1] for k in keys] + assert prefixes == [k[: k.rindex("/") + 1] for k in keys] @mock_s3 def test_boto3_list_objects_v2_truncated_response(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - s3.put_object(Bucket='mybucket', Key='one', Body=b'1') - s3.put_object(Bucket='mybucket', Key='two', Body=b'22') - s3.put_object(Bucket='mybucket', Key='three', Body=b'333') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="one", Body=b"1") + s3.put_object(Bucket="mybucket", Key="two", Body=b"22") + s3.put_object(Bucket="mybucket", Key="three", Body=b"333") # First list - resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=1) - listed_object = resp['Contents'][0] + resp = s3.list_objects_v2(Bucket="mybucket", MaxKeys=1) + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'one' - assert resp['MaxKeys'] == 1 - assert resp['Prefix'] == '' - assert resp['KeyCount'] == 1 - assert resp['IsTruncated'] == True - assert 'Delimiter' not in resp - assert 'StartAfter' not in resp - assert 'Owner' not in listed_object # owner info was not requested + assert listed_object["Key"] == "one" + assert resp["MaxKeys"] == 1 + assert resp["Prefix"] == "" + assert resp["KeyCount"] == 1 + assert resp["IsTruncated"] == True + assert "Delimiter" not in resp + assert "StartAfter" not in resp + assert "Owner" not in listed_object # owner info was not requested - next_token = resp['NextContinuationToken'] + next_token = resp["NextContinuationToken"] # Second list resp = s3.list_objects_v2( - Bucket='mybucket', MaxKeys=1, ContinuationToken=next_token) - listed_object = resp['Contents'][0] + Bucket="mybucket", MaxKeys=1, ContinuationToken=next_token + ) + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'three' - assert resp['MaxKeys'] == 1 - assert resp['Prefix'] == '' - assert resp['KeyCount'] == 1 - assert resp['IsTruncated'] == True - assert 'Delimiter' not in resp - assert 'StartAfter' not in resp - assert 'Owner' not in listed_object + assert listed_object["Key"] == "three" + assert resp["MaxKeys"] == 1 + assert resp["Prefix"] == "" + assert resp["KeyCount"] == 1 + assert resp["IsTruncated"] == True + assert "Delimiter" not in resp + assert "StartAfter" not in resp + assert "Owner" not in listed_object - next_token = resp['NextContinuationToken'] + next_token = resp["NextContinuationToken"] # Third list resp = s3.list_objects_v2( - Bucket='mybucket', MaxKeys=1, ContinuationToken=next_token) - listed_object = resp['Contents'][0] + Bucket="mybucket", MaxKeys=1, ContinuationToken=next_token + ) + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'two' - assert resp['MaxKeys'] == 1 - assert resp['Prefix'] == '' - assert resp['KeyCount'] == 1 - assert resp['IsTruncated'] == False - assert 'Delimiter' not in resp - assert 'Owner' not in listed_object - assert 'StartAfter' not in resp - assert 'NextContinuationToken' not in resp + assert listed_object["Key"] == "two" + assert resp["MaxKeys"] == 1 + assert resp["Prefix"] == "" + assert resp["KeyCount"] == 1 + assert resp["IsTruncated"] == False + assert "Delimiter" not in resp + assert "Owner" not in listed_object + assert "StartAfter" not in resp + assert "NextContinuationToken" not in resp @mock_s3 def test_boto3_list_objects_v2_truncated_response_start_after(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - s3.put_object(Bucket='mybucket', Key='one', Body=b'1') - s3.put_object(Bucket='mybucket', Key='two', Body=b'22') - s3.put_object(Bucket='mybucket', Key='three', Body=b'333') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="one", Body=b"1") + s3.put_object(Bucket="mybucket", Key="two", Body=b"22") + s3.put_object(Bucket="mybucket", Key="three", Body=b"333") # First list - resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=1, StartAfter='one') - listed_object = resp['Contents'][0] + resp = s3.list_objects_v2(Bucket="mybucket", MaxKeys=1, StartAfter="one") + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'three' - assert resp['MaxKeys'] == 1 - assert resp['Prefix'] == '' - assert resp['KeyCount'] == 1 - assert resp['IsTruncated'] == True - assert resp['StartAfter'] == 'one' - assert 'Delimiter' not in resp - assert 'Owner' not in listed_object + assert listed_object["Key"] == "three" + assert resp["MaxKeys"] == 1 + assert resp["Prefix"] == "" + assert resp["KeyCount"] == 1 + assert resp["IsTruncated"] == True + assert resp["StartAfter"] == "one" + assert "Delimiter" not in resp + assert "Owner" not in listed_object - next_token = resp['NextContinuationToken'] + next_token = resp["NextContinuationToken"] # Second list # The ContinuationToken must take precedence over StartAfter. - resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=1, StartAfter='one', - ContinuationToken=next_token) - listed_object = resp['Contents'][0] + resp = s3.list_objects_v2( + Bucket="mybucket", MaxKeys=1, StartAfter="one", ContinuationToken=next_token + ) + listed_object = resp["Contents"][0] - assert listed_object['Key'] == 'two' - assert resp['MaxKeys'] == 1 - assert resp['Prefix'] == '' - assert resp['KeyCount'] == 1 - assert resp['IsTruncated'] == False + assert listed_object["Key"] == "two" + assert resp["MaxKeys"] == 1 + assert resp["Prefix"] == "" + assert resp["KeyCount"] == 1 + assert resp["IsTruncated"] == False # When ContinuationToken is given, StartAfter is ignored. This also means # AWS does not return it in the response. - assert 'StartAfter' not in resp - assert 'Delimiter' not in resp - assert 'Owner' not in listed_object + assert "StartAfter" not in resp + assert "Delimiter" not in resp + assert "Owner" not in listed_object @mock_s3 def test_boto3_list_objects_v2_fetch_owner(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - s3.put_object(Bucket='mybucket', Key='one', Body=b'11') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="one", Body=b"11") - resp = s3.list_objects_v2(Bucket='mybucket', FetchOwner=True) - owner = resp['Contents'][0]['Owner'] + resp = s3.list_objects_v2(Bucket="mybucket", FetchOwner=True) + owner = resp["Contents"][0]["Owner"] - assert 'ID' in owner - assert 'DisplayName' in owner + assert "ID" in owner + assert "DisplayName" in owner assert len(owner.keys()) == 2 @mock_s3 def test_boto3_list_objects_v2_truncate_combined_keys_and_folders(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') - s3.put_object(Bucket='mybucket', Key='1/2', Body='') - s3.put_object(Bucket='mybucket', Key='2', Body='') - s3.put_object(Bucket='mybucket', Key='3/4', Body='') - s3.put_object(Bucket='mybucket', Key='4', Body='') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") + s3.put_object(Bucket="mybucket", Key="1/2", Body="") + s3.put_object(Bucket="mybucket", Key="2", Body="") + s3.put_object(Bucket="mybucket", Key="3/4", Body="") + s3.put_object(Bucket="mybucket", Key="4", Body="") - resp = s3.list_objects_v2(Bucket='mybucket', Prefix='', MaxKeys=2, Delimiter='/') - assert 'Delimiter' in resp - assert resp['IsTruncated'] is True - assert resp['KeyCount'] == 2 - assert len(resp['Contents']) == 1 - assert resp['Contents'][0]['Key'] == '2' - assert len(resp['CommonPrefixes']) == 1 - assert resp['CommonPrefixes'][0]['Prefix'] == '1/' + resp = s3.list_objects_v2(Bucket="mybucket", Prefix="", MaxKeys=2, Delimiter="/") + assert "Delimiter" in resp + assert resp["IsTruncated"] is True + assert resp["KeyCount"] == 2 + assert len(resp["Contents"]) == 1 + assert resp["Contents"][0]["Key"] == "2" + assert len(resp["CommonPrefixes"]) == 1 + assert resp["CommonPrefixes"][0]["Prefix"] == "1/" - last_tail = resp['NextContinuationToken'] - resp = s3.list_objects_v2(Bucket='mybucket', MaxKeys=2, Prefix='', Delimiter='/', StartAfter=last_tail) - assert resp['KeyCount'] == 2 - assert resp['IsTruncated'] is False - assert len(resp['Contents']) == 1 - assert resp['Contents'][0]['Key'] == '4' - assert len(resp['CommonPrefixes']) == 1 - assert resp['CommonPrefixes'][0]['Prefix'] == '3/' + last_tail = resp["NextContinuationToken"] + resp = s3.list_objects_v2( + Bucket="mybucket", MaxKeys=2, Prefix="", Delimiter="/", StartAfter=last_tail + ) + assert resp["KeyCount"] == 2 + assert resp["IsTruncated"] is False + assert len(resp["Contents"]) == 1 + assert resp["Contents"][0]["Key"] == "4" + assert len(resp["CommonPrefixes"]) == 1 + assert resp["CommonPrefixes"][0]["Prefix"] == "3/" @mock_s3 def test_boto3_bucket_create(): - s3 = boto3.resource('s3', region_name='us-east-1') + s3 = boto3.resource("s3", region_name="us-east-1") s3.create_bucket(Bucket="blah") - s3.Object('blah', 'hello.txt').put(Body="some text") + s3.Object("blah", "hello.txt").put(Body="some text") - s3.Object('blah', 'hello.txt').get()['Body'].read().decode( - "utf-8").should.equal("some text") + s3.Object("blah", "hello.txt").get()["Body"].read().decode("utf-8").should.equal( + "some text" + ) @mock_s3 def test_bucket_create_duplicate(): - s3 = boto3.resource('s3', region_name='us-west-2') - s3.create_bucket(Bucket="blah", CreateBucketConfiguration={ - 'LocationConstraint': 'us-west-2', - }) + s3 = boto3.resource("s3", region_name="us-west-2") + s3.create_bucket( + Bucket="blah", CreateBucketConfiguration={"LocationConstraint": "us-west-2"} + ) with assert_raises(ClientError) as exc: s3.create_bucket( - Bucket="blah", - CreateBucketConfiguration={ - 'LocationConstraint': 'us-west-2', - } + Bucket="blah", CreateBucketConfiguration={"LocationConstraint": "us-west-2"} ) - exc.exception.response['Error']['Code'].should.equal('BucketAlreadyExists') + exc.exception.response["Error"]["Code"].should.equal("BucketAlreadyExists") @mock_s3 def test_bucket_create_force_us_east_1(): - s3 = boto3.resource('s3', region_name='us-east-1') + s3 = boto3.resource("s3", region_name="us-east-1") with assert_raises(ClientError) as exc: - s3.create_bucket(Bucket="blah", CreateBucketConfiguration={ - 'LocationConstraint': 'us-east-1', - }) - exc.exception.response['Error']['Code'].should.equal('InvalidLocationConstraint') + s3.create_bucket( + Bucket="blah", CreateBucketConfiguration={"LocationConstraint": "us-east-1"} + ) + exc.exception.response["Error"]["Code"].should.equal("InvalidLocationConstraint") @mock_s3 def test_boto3_bucket_create_eu_central(): - s3 = boto3.resource('s3', region_name='eu-central-1') + s3 = boto3.resource("s3", region_name="eu-central-1") s3.create_bucket(Bucket="blah") - s3.Object('blah', 'hello.txt').put(Body="some text") + s3.Object("blah", "hello.txt").put(Body="some text") - s3.Object('blah', 'hello.txt').get()['Body'].read().decode( - "utf-8").should.equal("some text") + s3.Object("blah", "hello.txt").get()["Body"].read().decode("utf-8").should.equal( + "some text" + ) @mock_s3 def test_boto3_head_object(): - s3 = boto3.resource('s3', region_name='us-east-1') + s3 = boto3.resource("s3", region_name="us-east-1") s3.create_bucket(Bucket="blah") - s3.Object('blah', 'hello.txt').put(Body="some text") + s3.Object("blah", "hello.txt").put(Body="some text") - s3.Object('blah', 'hello.txt').meta.client.head_object( - Bucket='blah', Key='hello.txt') + s3.Object("blah", "hello.txt").meta.client.head_object( + Bucket="blah", Key="hello.txt" + ) with assert_raises(ClientError) as e: - s3.Object('blah', 'hello2.txt').meta.client.head_object( - Bucket='blah', Key='hello_bad.txt') - e.exception.response['Error']['Code'].should.equal('404') + s3.Object("blah", "hello2.txt").meta.client.head_object( + Bucket="blah", Key="hello_bad.txt" + ) + e.exception.response["Error"]["Code"].should.equal("404") @mock_s3 def test_boto3_bucket_deletion(): - cli = boto3.client('s3', region_name='us-east-1') + cli = boto3.client("s3", region_name="us-east-1") cli.create_bucket(Bucket="foobar") cli.put_object(Bucket="foobar", Key="the-key", Body="some value") @@ -1548,8 +1558,11 @@ def test_boto3_bucket_deletion(): # Try to delete a bucket that still has keys cli.delete_bucket.when.called_with(Bucket="foobar").should.throw( cli.exceptions.ClientError, - ('An error occurred (BucketNotEmpty) when calling the DeleteBucket operation: ' - 'The bucket you tried to delete is not empty')) + ( + "An error occurred (BucketNotEmpty) when calling the DeleteBucket operation: " + "The bucket you tried to delete is not empty" + ), + ) cli.delete_object(Bucket="foobar", Key="the-key") cli.delete_bucket(Bucket="foobar") @@ -1557,123 +1570,158 @@ def test_boto3_bucket_deletion(): # Get non-existing bucket cli.head_bucket.when.called_with(Bucket="foobar").should.throw( cli.exceptions.ClientError, - "An error occurred (404) when calling the HeadBucket operation: Not Found") + "An error occurred (404) when calling the HeadBucket operation: Not Found", + ) # Delete non-existing bucket - cli.delete_bucket.when.called_with(Bucket="foobar").should.throw(cli.exceptions.NoSuchBucket) + cli.delete_bucket.when.called_with(Bucket="foobar").should.throw( + cli.exceptions.NoSuchBucket + ) @mock_s3 def test_boto3_get_object(): - s3 = boto3.resource('s3', region_name='us-east-1') + s3 = boto3.resource("s3", region_name="us-east-1") s3.create_bucket(Bucket="blah") - s3.Object('blah', 'hello.txt').put(Body="some text") + s3.Object("blah", "hello.txt").put(Body="some text") - s3.Object('blah', 'hello.txt').meta.client.head_object( - Bucket='blah', Key='hello.txt') + s3.Object("blah", "hello.txt").meta.client.head_object( + Bucket="blah", Key="hello.txt" + ) with assert_raises(ClientError) as e: - s3.Object('blah', 'hello2.txt').get() + s3.Object("blah", "hello2.txt").get() - e.exception.response['Error']['Code'].should.equal('NoSuchKey') + e.exception.response["Error"]["Code"].should.equal("NoSuchKey") @mock_s3 def test_boto3_get_missing_object_with_part_number(): - s3 = boto3.resource('s3', region_name='us-east-1') + s3 = boto3.resource("s3", region_name="us-east-1") s3.create_bucket(Bucket="blah") with assert_raises(ClientError) as e: - s3.Object('blah', 'hello.txt').meta.client.head_object( - Bucket='blah', Key='hello.txt', PartNumber=123) + s3.Object("blah", "hello.txt").meta.client.head_object( + Bucket="blah", Key="hello.txt", PartNumber=123 + ) - e.exception.response['Error']['Code'].should.equal('404') + e.exception.response["Error"]["Code"].should.equal("404") @mock_s3 def test_boto3_head_object_with_versioning(): - s3 = boto3.resource('s3', region_name='us-east-1') - bucket = s3.create_bucket(Bucket='blah') + s3 = boto3.resource("s3", region_name="us-east-1") + bucket = s3.create_bucket(Bucket="blah") bucket.Versioning().enable() - old_content = 'some text' - new_content = 'some new text' - s3.Object('blah', 'hello.txt').put(Body=old_content) - s3.Object('blah', 'hello.txt').put(Body=new_content) + old_content = "some text" + new_content = "some new text" + s3.Object("blah", "hello.txt").put(Body=old_content) + s3.Object("blah", "hello.txt").put(Body=new_content) - versions = list(s3.Bucket('blah').object_versions.all()) + versions = list(s3.Bucket("blah").object_versions.all()) latest = list(filter(lambda item: item.is_latest, versions))[0] oldest = list(filter(lambda item: not item.is_latest, versions))[0] - head_object = s3.Object('blah', 'hello.txt').meta.client.head_object( - Bucket='blah', Key='hello.txt') - head_object['VersionId'].should.equal(latest.id) - head_object['ContentLength'].should.equal(len(new_content)) + head_object = s3.Object("blah", "hello.txt").meta.client.head_object( + Bucket="blah", Key="hello.txt" + ) + head_object["VersionId"].should.equal(latest.id) + head_object["ContentLength"].should.equal(len(new_content)) - old_head_object = s3.Object('blah', 'hello.txt').meta.client.head_object( - Bucket='blah', Key='hello.txt', VersionId=oldest.id) - old_head_object['VersionId'].should.equal(oldest.id) - old_head_object['ContentLength'].should.equal(len(old_content)) + old_head_object = s3.Object("blah", "hello.txt").meta.client.head_object( + Bucket="blah", Key="hello.txt", VersionId=oldest.id + ) + old_head_object["VersionId"].should.equal(oldest.id) + old_head_object["ContentLength"].should.equal(len(old_content)) - old_head_object['VersionId'].should_not.equal(head_object['VersionId']) + old_head_object["VersionId"].should_not.equal(head_object["VersionId"]) @mock_s3 def test_boto3_copy_object_with_versioning(): - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") - client.create_bucket(Bucket='blah', CreateBucketConfiguration={'LocationConstraint': 'eu-west-1'}) - client.put_bucket_versioning(Bucket='blah', VersioningConfiguration={'Status': 'Enabled'}) + client.create_bucket( + Bucket="blah", CreateBucketConfiguration={"LocationConstraint": "eu-west-1"} + ) + client.put_bucket_versioning( + Bucket="blah", VersioningConfiguration={"Status": "Enabled"} + ) - client.put_object(Bucket='blah', Key='test1', Body=b'test1') - client.put_object(Bucket='blah', Key='test2', Body=b'test2') + client.put_object(Bucket="blah", Key="test1", Body=b"test1") + client.put_object(Bucket="blah", Key="test2", Body=b"test2") - obj1_version = client.get_object(Bucket='blah', Key='test1')['VersionId'] - obj2_version = client.get_object(Bucket='blah', Key='test2')['VersionId'] + obj1_version = client.get_object(Bucket="blah", Key="test1")["VersionId"] + obj2_version = client.get_object(Bucket="blah", Key="test2")["VersionId"] - client.copy_object(CopySource={'Bucket': 'blah', 'Key': 'test1'}, Bucket='blah', Key='test2') - obj2_version_new = client.get_object(Bucket='blah', Key='test2')['VersionId'] + client.copy_object( + CopySource={"Bucket": "blah", "Key": "test1"}, Bucket="blah", Key="test2" + ) + obj2_version_new = client.get_object(Bucket="blah", Key="test2")["VersionId"] # Version should be different to previous version obj2_version_new.should_not.equal(obj2_version) - client.copy_object(CopySource={'Bucket': 'blah', 'Key': 'test2', 'VersionId': obj2_version}, Bucket='blah', Key='test3') - obj3_version_new = client.get_object(Bucket='blah', Key='test3')['VersionId'] + client.copy_object( + CopySource={"Bucket": "blah", "Key": "test2", "VersionId": obj2_version}, + Bucket="blah", + Key="test3", + ) + obj3_version_new = client.get_object(Bucket="blah", Key="test3")["VersionId"] obj3_version_new.should_not.equal(obj2_version_new) # Copy file that doesn't exist with assert_raises(ClientError) as e: - client.copy_object(CopySource={'Bucket': 'blah', 'Key': 'test4', 'VersionId': obj2_version}, Bucket='blah', Key='test5') - e.exception.response['Error']['Code'].should.equal('404') + client.copy_object( + CopySource={"Bucket": "blah", "Key": "test4", "VersionId": obj2_version}, + Bucket="blah", + Key="test5", + ) + e.exception.response["Error"]["Code"].should.equal("404") - response = client.create_multipart_upload(Bucket='blah', Key='test4') - upload_id = response['UploadId'] - response = client.upload_part_copy(Bucket='blah', Key='test4', - CopySource={'Bucket': 'blah', 'Key': 'test3', 'VersionId': obj3_version_new}, - UploadId=upload_id, PartNumber=1) + response = client.create_multipart_upload(Bucket="blah", Key="test4") + upload_id = response["UploadId"] + response = client.upload_part_copy( + Bucket="blah", + Key="test4", + CopySource={"Bucket": "blah", "Key": "test3", "VersionId": obj3_version_new}, + UploadId=upload_id, + PartNumber=1, + ) etag = response["CopyPartResult"]["ETag"] client.complete_multipart_upload( - Bucket='blah', Key='test4', UploadId=upload_id, - MultipartUpload={'Parts': [{'ETag': etag, 'PartNumber': 1}]}) + Bucket="blah", + Key="test4", + UploadId=upload_id, + MultipartUpload={"Parts": [{"ETag": etag, "PartNumber": 1}]}, + ) - response = client.get_object(Bucket='blah', Key='test4') + response = client.get_object(Bucket="blah", Key="test4") data = response["Body"].read() - data.should.equal(b'test2') + data.should.equal(b"test2") @mock_s3 def test_boto3_copy_object_from_unversioned_to_versioned_bucket(): - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") - client.create_bucket(Bucket='src', CreateBucketConfiguration={'LocationConstraint': 'eu-west-1'}) - client.create_bucket(Bucket='dest', CreateBucketConfiguration={'LocationConstraint': 'eu-west-1'}) - client.put_bucket_versioning(Bucket='dest', VersioningConfiguration={'Status': 'Enabled'}) + client.create_bucket( + Bucket="src", CreateBucketConfiguration={"LocationConstraint": "eu-west-1"} + ) + client.create_bucket( + Bucket="dest", CreateBucketConfiguration={"LocationConstraint": "eu-west-1"} + ) + client.put_bucket_versioning( + Bucket="dest", VersioningConfiguration={"Status": "Enabled"} + ) - client.put_object(Bucket='src', Key='test', Body=b'content') + client.put_object(Bucket="src", Key="test", Body=b"content") - obj2_version_new = client.copy_object(CopySource={'Bucket': 'src', 'Key': 'test'}, Bucket='dest', Key='test') \ - .get('VersionId') + obj2_version_new = client.copy_object( + CopySource={"Bucket": "src", "Key": "test"}, Bucket="dest", Key="test" + ).get("VersionId") # VersionId should be present in the response obj2_version_new.should_not.equal(None) @@ -1681,125 +1729,138 @@ def test_boto3_copy_object_from_unversioned_to_versioned_bucket(): @mock_s3 def test_boto3_deleted_versionings_list(): - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") - client.create_bucket(Bucket='blah') - client.put_bucket_versioning(Bucket='blah', VersioningConfiguration={'Status': 'Enabled'}) + client.create_bucket(Bucket="blah") + client.put_bucket_versioning( + Bucket="blah", VersioningConfiguration={"Status": "Enabled"} + ) - client.put_object(Bucket='blah', Key='test1', Body=b'test1') - client.put_object(Bucket='blah', Key='test2', Body=b'test2') - client.delete_objects(Bucket='blah', Delete={'Objects': [{'Key': 'test1'}]}) + client.put_object(Bucket="blah", Key="test1", Body=b"test1") + client.put_object(Bucket="blah", Key="test2", Body=b"test2") + client.delete_objects(Bucket="blah", Delete={"Objects": [{"Key": "test1"}]}) - listed = client.list_objects_v2(Bucket='blah') - assert len(listed['Contents']) == 1 + listed = client.list_objects_v2(Bucket="blah") + assert len(listed["Contents"]) == 1 @mock_s3 def test_boto3_delete_versioned_bucket(): - client = boto3.client('s3', region_name='us-east-1') + client = boto3.client("s3", region_name="us-east-1") - client.create_bucket(Bucket='blah') - client.put_bucket_versioning(Bucket='blah', VersioningConfiguration={'Status': 'Enabled'}) + client.create_bucket(Bucket="blah") + client.put_bucket_versioning( + Bucket="blah", VersioningConfiguration={"Status": "Enabled"} + ) - resp = client.put_object(Bucket='blah', Key='test1', Body=b'test1') - client.delete_object(Bucket='blah', Key='test1', VersionId=resp["VersionId"]) + resp = client.put_object(Bucket="blah", Key="test1", Body=b"test1") + client.delete_object(Bucket="blah", Key="test1", VersionId=resp["VersionId"]) - client.delete_bucket(Bucket='blah') + client.delete_bucket(Bucket="blah") @mock_s3 def test_boto3_get_object_if_modified_since(): - s3 = boto3.client('s3', region_name='us-east-1') + s3 = boto3.client("s3", region_name="us-east-1") bucket_name = "blah" s3.create_bucket(Bucket=bucket_name) - key = 'hello.txt' + key = "hello.txt" - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") with assert_raises(botocore.exceptions.ClientError) as err: s3.get_object( Bucket=bucket_name, Key=key, - IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1) + IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1), ) e = err.exception - e.response['Error'].should.equal({'Code': '304', 'Message': 'Not Modified'}) + e.response["Error"].should.equal({"Code": "304", "Message": "Not Modified"}) @mock_s3 def test_boto3_head_object_if_modified_since(): - s3 = boto3.client('s3', region_name='us-east-1') + s3 = boto3.client("s3", region_name="us-east-1") bucket_name = "blah" s3.create_bucket(Bucket=bucket_name) - key = 'hello.txt' + key = "hello.txt" - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") with assert_raises(botocore.exceptions.ClientError) as err: s3.head_object( Bucket=bucket_name, Key=key, - IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1) + IfModifiedSince=datetime.datetime.utcnow() + datetime.timedelta(hours=1), ) e = err.exception - e.response['Error'].should.equal({'Code': '304', 'Message': 'Not Modified'}) + e.response["Error"].should.equal({"Code": "304", "Message": "Not Modified"}) @mock_s3 @reduced_min_part_size def test_boto3_multipart_etag(): # Create Bucket so that test can run - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") - upload_id = s3.create_multipart_upload( - Bucket='mybucket', Key='the-key')['UploadId'] - part1 = b'0' * REDUCED_PART_SIZE + upload_id = s3.create_multipart_upload(Bucket="mybucket", Key="the-key")["UploadId"] + part1 = b"0" * REDUCED_PART_SIZE etags = [] etags.append( - s3.upload_part(Bucket='mybucket', Key='the-key', PartNumber=1, - UploadId=upload_id, Body=part1)['ETag']) + s3.upload_part( + Bucket="mybucket", + Key="the-key", + PartNumber=1, + UploadId=upload_id, + Body=part1, + )["ETag"] + ) # last part, can be less than 5 MB - part2 = b'1' + part2 = b"1" etags.append( - s3.upload_part(Bucket='mybucket', Key='the-key', PartNumber=2, - UploadId=upload_id, Body=part2)['ETag']) + s3.upload_part( + Bucket="mybucket", + Key="the-key", + PartNumber=2, + UploadId=upload_id, + Body=part2, + )["ETag"] + ) s3.complete_multipart_upload( - Bucket='mybucket', Key='the-key', UploadId=upload_id, - MultipartUpload={'Parts': [{'ETag': etag, 'PartNumber': i} - for i, etag in enumerate(etags, 1)]}) + Bucket="mybucket", + Key="the-key", + UploadId=upload_id, + MultipartUpload={ + "Parts": [ + {"ETag": etag, "PartNumber": i} for i, etag in enumerate(etags, 1) + ] + }, + ) # we should get both parts as the key contents - resp = s3.get_object(Bucket='mybucket', Key='the-key') - resp['ETag'].should.equal(EXPECTED_ETAG) + resp = s3.get_object(Bucket="mybucket", Key="the-key") + resp["ETag"].should.equal(EXPECTED_ETAG) @mock_s3 @reduced_min_part_size def test_boto3_multipart_part_size(): - s3 = boto3.client('s3', region_name='us-east-1') - s3.create_bucket(Bucket='mybucket') + s3 = boto3.client("s3", region_name="us-east-1") + s3.create_bucket(Bucket="mybucket") - mpu = s3.create_multipart_upload(Bucket='mybucket', Key='the-key') + mpu = s3.create_multipart_upload(Bucket="mybucket", Key="the-key") mpu_id = mpu["UploadId"] parts = [] n_parts = 10 for i in range(1, n_parts + 1): part_size = REDUCED_PART_SIZE + i - body = b'1' * part_size + body = b"1" * part_size part = s3.upload_part( - Bucket='mybucket', - Key='the-key', + Bucket="mybucket", + Key="the-key", PartNumber=i, UploadId=mpu_id, Body=body, @@ -1808,34 +1869,29 @@ def test_boto3_multipart_part_size(): parts.append({"PartNumber": i, "ETag": part["ETag"]}) s3.complete_multipart_upload( - Bucket='mybucket', - Key='the-key', + Bucket="mybucket", + Key="the-key", UploadId=mpu_id, MultipartUpload={"Parts": parts}, ) for i in range(1, n_parts + 1): - obj = s3.head_object(Bucket='mybucket', Key='the-key', PartNumber=i) + obj = s3.head_object(Bucket="mybucket", Key="the-key", PartNumber=i) assert obj["ContentLength"] == REDUCED_PART_SIZE + i @mock_s3 def test_boto3_put_object_with_tagging(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-tags' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" s3.create_bucket(Bucket=bucket_name) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test', - Tagging='foo=bar', - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test", Tagging="foo=bar") resp = s3.get_object_tagging(Bucket=bucket_name, Key=key) - resp['TagSet'].should.contain({'Key': 'foo', 'Value': 'bar'}) + resp["TagSet"].should.contain({"Key": "foo", "Value": "bar"}) @mock_s3 @@ -1845,58 +1901,44 @@ def test_boto3_put_bucket_tagging(): s3.create_bucket(Bucket=bucket_name) # With 1 tag: - resp = s3.put_bucket_tagging(Bucket=bucket_name, - Tagging={ - "TagSet": [ - { - "Key": "TagOne", - "Value": "ValueOne" - } - ] - }) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp = s3.put_bucket_tagging( + Bucket=bucket_name, Tagging={"TagSet": [{"Key": "TagOne", "Value": "ValueOne"}]} + ) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # With multiple tags: - resp = s3.put_bucket_tagging(Bucket=bucket_name, - Tagging={ - "TagSet": [ - { - "Key": "TagOne", - "Value": "ValueOne" - }, - { - "Key": "TagTwo", - "Value": "ValueTwo" - } - ] - }) + resp = s3.put_bucket_tagging( + Bucket=bucket_name, + Tagging={ + "TagSet": [ + {"Key": "TagOne", "Value": "ValueOne"}, + {"Key": "TagTwo", "Value": "ValueTwo"}, + ] + }, + ) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # No tags is also OK: - resp = s3.put_bucket_tagging(Bucket=bucket_name, Tagging={ - "TagSet": [] - }) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp = s3.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": []}) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # With duplicate tag keys: with assert_raises(ClientError) as err: - resp = s3.put_bucket_tagging(Bucket=bucket_name, - Tagging={ - "TagSet": [ - { - "Key": "TagOne", - "Value": "ValueOne" - }, - { - "Key": "TagOne", - "Value": "ValueOneAgain" - } - ] - }) + resp = s3.put_bucket_tagging( + Bucket=bucket_name, + Tagging={ + "TagSet": [ + {"Key": "TagOne", "Value": "ValueOne"}, + {"Key": "TagOne", "Value": "ValueOneAgain"}, + ] + }, + ) e = err.exception e.response["Error"]["Code"].should.equal("InvalidTag") - e.response["Error"]["Message"].should.equal("Cannot provide multiple Tags with the same key") + e.response["Error"]["Message"].should.equal( + "Cannot provide multiple Tags with the same key" + ) @mock_s3 @@ -1904,29 +1946,23 @@ def test_boto3_get_bucket_tagging(): s3 = boto3.client("s3", region_name="us-east-1") bucket_name = "mybucket" s3.create_bucket(Bucket=bucket_name) - s3.put_bucket_tagging(Bucket=bucket_name, - Tagging={ - "TagSet": [ - { - "Key": "TagOne", - "Value": "ValueOne" - }, - { - "Key": "TagTwo", - "Value": "ValueTwo" - } - ] - }) + s3.put_bucket_tagging( + Bucket=bucket_name, + Tagging={ + "TagSet": [ + {"Key": "TagOne", "Value": "ValueOne"}, + {"Key": "TagTwo", "Value": "ValueTwo"}, + ] + }, + ) # Get the tags for the bucket: resp = s3.get_bucket_tagging(Bucket=bucket_name) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) len(resp["TagSet"]).should.equal(2) # With no tags: - s3.put_bucket_tagging(Bucket=bucket_name, Tagging={ - "TagSet": [] - }) + s3.put_bucket_tagging(Bucket=bucket_name, Tagging={"TagSet": []}) with assert_raises(ClientError) as err: s3.get_bucket_tagging(Bucket=bucket_name) @@ -1942,22 +1978,18 @@ def test_boto3_delete_bucket_tagging(): bucket_name = "mybucket" s3.create_bucket(Bucket=bucket_name) - s3.put_bucket_tagging(Bucket=bucket_name, - Tagging={ - "TagSet": [ - { - "Key": "TagOne", - "Value": "ValueOne" - }, - { - "Key": "TagTwo", - "Value": "ValueTwo" - } - ] - }) + s3.put_bucket_tagging( + Bucket=bucket_name, + Tagging={ + "TagSet": [ + {"Key": "TagOne", "Value": "ValueOne"}, + {"Key": "TagTwo", "Value": "ValueTwo"}, + ] + }, + ) resp = s3.delete_bucket_tagging(Bucket=bucket_name) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(204) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(204) with assert_raises(ClientError) as err: s3.get_bucket_tagging(Bucket=bucket_name) @@ -1973,76 +2005,56 @@ def test_boto3_put_bucket_cors(): bucket_name = "mybucket" s3.create_bucket(Bucket=bucket_name) - resp = s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ - "CORSRules": [ - { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "GET", - "POST" - ], - "AllowedHeaders": [ - "Authorization" - ], - "ExposeHeaders": [ - "x-amz-request-id" - ], - "MaxAgeSeconds": 123 - }, - { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "PUT" - ], - "AllowedHeaders": [ - "Authorization" - ], - "ExposeHeaders": [ - "x-amz-request-id" - ], - "MaxAgeSeconds": 123 - } - ] - }) - - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - - with assert_raises(ClientError) as err: - s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ + resp = s3.put_bucket_cors( + Bucket=bucket_name, + CORSConfiguration={ "CORSRules": [ { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "NOTREAL", - "POST" - ] - } + "AllowedOrigins": ["*"], + "AllowedMethods": ["GET", "POST"], + "AllowedHeaders": ["Authorization"], + "ExposeHeaders": ["x-amz-request-id"], + "MaxAgeSeconds": 123, + }, + { + "AllowedOrigins": ["*"], + "AllowedMethods": ["PUT"], + "AllowedHeaders": ["Authorization"], + "ExposeHeaders": ["x-amz-request-id"], + "MaxAgeSeconds": 123, + }, ] - }) - e = err.exception - e.response["Error"]["Code"].should.equal("InvalidRequest") - e.response["Error"]["Message"].should.equal("Found unsupported HTTP method in CORS config. " - "Unsupported method is NOTREAL") + }, + ) + + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) with assert_raises(ClientError) as err: - s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ - "CORSRules": [] - }) + s3.put_bucket_cors( + Bucket=bucket_name, + CORSConfiguration={ + "CORSRules": [ + {"AllowedOrigins": ["*"], "AllowedMethods": ["NOTREAL", "POST"]} + ] + }, + ) + e = err.exception + e.response["Error"]["Code"].should.equal("InvalidRequest") + e.response["Error"]["Message"].should.equal( + "Found unsupported HTTP method in CORS config. " "Unsupported method is NOTREAL" + ) + + with assert_raises(ClientError) as err: + s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={"CORSRules": []}) e = err.exception e.response["Error"]["Code"].should.equal("MalformedXML") # And 101: many_rules = [{"AllowedOrigins": ["*"], "AllowedMethods": ["GET"]}] * 101 with assert_raises(ClientError) as err: - s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ - "CORSRules": many_rules - }) + s3.put_bucket_cors( + Bucket=bucket_name, CORSConfiguration={"CORSRules": many_rules} + ) e = err.exception e.response["Error"]["Code"].should.equal("MalformedXML") @@ -2061,44 +2073,30 @@ def test_boto3_get_bucket_cors(): e.response["Error"]["Code"].should.equal("NoSuchCORSConfiguration") e.response["Error"]["Message"].should.equal("The CORS configuration does not exist") - s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ - "CORSRules": [ - { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "GET", - "POST" - ], - "AllowedHeaders": [ - "Authorization" - ], - "ExposeHeaders": [ - "x-amz-request-id" - ], - "MaxAgeSeconds": 123 - }, - { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "PUT" - ], - "AllowedHeaders": [ - "Authorization" - ], - "ExposeHeaders": [ - "x-amz-request-id" - ], - "MaxAgeSeconds": 123 - } - ] - }) + s3.put_bucket_cors( + Bucket=bucket_name, + CORSConfiguration={ + "CORSRules": [ + { + "AllowedOrigins": ["*"], + "AllowedMethods": ["GET", "POST"], + "AllowedHeaders": ["Authorization"], + "ExposeHeaders": ["x-amz-request-id"], + "MaxAgeSeconds": 123, + }, + { + "AllowedOrigins": ["*"], + "AllowedMethods": ["PUT"], + "AllowedHeaders": ["Authorization"], + "ExposeHeaders": ["x-amz-request-id"], + "MaxAgeSeconds": 123, + }, + ] + }, + ) resp = s3.get_bucket_cors(Bucket=bucket_name) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) len(resp["CORSRules"]).should.equal(2) @@ -2107,21 +2105,15 @@ def test_boto3_delete_bucket_cors(): s3 = boto3.client("s3", region_name="us-east-1") bucket_name = "mybucket" s3.create_bucket(Bucket=bucket_name) - s3.put_bucket_cors(Bucket=bucket_name, CORSConfiguration={ - "CORSRules": [ - { - "AllowedOrigins": [ - "*" - ], - "AllowedMethods": [ - "GET" - ] - } - ] - }) + s3.put_bucket_cors( + Bucket=bucket_name, + CORSConfiguration={ + "CORSRules": [{"AllowedOrigins": ["*"], "AllowedMethods": ["GET"]}] + }, + ) resp = s3.delete_bucket_cors(Bucket=bucket_name) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(204) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(204) # Verify deletion: with assert_raises(ClientError) as err: @@ -2137,25 +2129,28 @@ def test_put_bucket_acl_body(): s3 = boto3.client("s3", region_name="us-east-1") s3.create_bucket(Bucket="bucket") bucket_owner = s3.get_bucket_acl(Bucket="bucket")["Owner"] - s3.put_bucket_acl(Bucket="bucket", AccessControlPolicy={ - "Grants": [ - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" + s3.put_bucket_acl( + Bucket="bucket", + AccessControlPolicy={ + "Grants": [ + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "WRITE", }, - "Permission": "WRITE" - }, - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "READ_ACP", }, - "Permission": "READ_ACP" - } - ], - "Owner": bucket_owner - }) + ], + "Owner": bucket_owner, + }, + ) result = s3.get_bucket_acl(Bucket="bucket") assert len(result["Grants"]) == 2 @@ -2165,54 +2160,65 @@ def test_put_bucket_acl_body(): assert g["Permission"] in ["WRITE", "READ_ACP"] # With one: - s3.put_bucket_acl(Bucket="bucket", AccessControlPolicy={ - "Grants": [ - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" - }, - "Permission": "WRITE" - } - ], - "Owner": bucket_owner - }) + s3.put_bucket_acl( + Bucket="bucket", + AccessControlPolicy={ + "Grants": [ + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "WRITE", + } + ], + "Owner": bucket_owner, + }, + ) result = s3.get_bucket_acl(Bucket="bucket") assert len(result["Grants"]) == 1 # With no owner: with assert_raises(ClientError) as err: - s3.put_bucket_acl(Bucket="bucket", AccessControlPolicy={ - "Grants": [ - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" - }, - "Permission": "WRITE" - } - ] - }) + s3.put_bucket_acl( + Bucket="bucket", + AccessControlPolicy={ + "Grants": [ + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "WRITE", + } + ] + }, + ) assert err.exception.response["Error"]["Code"] == "MalformedACLError" # With incorrect permission: with assert_raises(ClientError) as err: - s3.put_bucket_acl(Bucket="bucket", AccessControlPolicy={ - "Grants": [ - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" - }, - "Permission": "lskjflkasdjflkdsjfalisdjflkdsjf" - } - ], - "Owner": bucket_owner - }) + s3.put_bucket_acl( + Bucket="bucket", + AccessControlPolicy={ + "Grants": [ + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "lskjflkasdjflkdsjfalisdjflkdsjf", + } + ], + "Owner": bucket_owner, + }, + ) assert err.exception.response["Error"]["Code"] == "MalformedACLError" # Clear the ACLs: - result = s3.put_bucket_acl(Bucket="bucket", AccessControlPolicy={"Grants": [], "Owner": bucket_owner}) + result = s3.put_bucket_acl( + Bucket="bucket", AccessControlPolicy={"Grants": [], "Owner": bucket_owner} + ) assert not result.get("Grants") @@ -2228,46 +2234,43 @@ def test_put_bucket_notification(): assert not result.get("LambdaFunctionConfigurations") # Place proper topic configuration: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "TopicConfigurations": [ - { - "TopicArn": "arn:aws:sns:us-east-1:012345678910:mytopic", - "Events": [ - "s3:ObjectCreated:*", - "s3:ObjectRemoved:*" - ] - }, - { - "TopicArn": "arn:aws:sns:us-east-1:012345678910:myothertopic", - "Events": [ - "s3:ObjectCreated:*" - ], - "Filter": { - "Key": { - "FilterRules": [ - { - "Name": "prefix", - "Value": "images/" - }, - { - "Name": "suffix", - "Value": "png" - } - ] - } - } - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "TopicConfigurations": [ + { + "TopicArn": "arn:aws:sns:us-east-1:012345678910:mytopic", + "Events": ["s3:ObjectCreated:*", "s3:ObjectRemoved:*"], + }, + { + "TopicArn": "arn:aws:sns:us-east-1:012345678910:myothertopic", + "Events": ["s3:ObjectCreated:*"], + "Filter": { + "Key": { + "FilterRules": [ + {"Name": "prefix", "Value": "images/"}, + {"Name": "suffix", "Value": "png"}, + ] + } + }, + }, + ] + }, + ) # Verify to completion: result = s3.get_bucket_notification_configuration(Bucket="bucket") assert len(result["TopicConfigurations"]) == 2 assert not result.get("QueueConfigurations") assert not result.get("LambdaFunctionConfigurations") - assert result["TopicConfigurations"][0]["TopicArn"] == "arn:aws:sns:us-east-1:012345678910:mytopic" - assert result["TopicConfigurations"][1]["TopicArn"] == "arn:aws:sns:us-east-1:012345678910:myothertopic" + assert ( + result["TopicConfigurations"][0]["TopicArn"] + == "arn:aws:sns:us-east-1:012345678910:mytopic" + ) + assert ( + result["TopicConfigurations"][1]["TopicArn"] + == "arn:aws:sns:us-east-1:012345678910:myothertopic" + ) assert len(result["TopicConfigurations"][0]["Events"]) == 2 assert len(result["TopicConfigurations"][1]["Events"]) == 1 assert result["TopicConfigurations"][0]["Events"][0] == "s3:ObjectCreated:*" @@ -2277,111 +2280,138 @@ def test_put_bucket_notification(): assert result["TopicConfigurations"][1]["Id"] assert not result["TopicConfigurations"][0].get("Filter") assert len(result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"]) == 2 - assert result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][0]["Name"] == "prefix" - assert result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][0]["Value"] == "images/" - assert result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][1]["Name"] == "suffix" - assert result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][1]["Value"] == "png" + assert ( + result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][0]["Name"] + == "prefix" + ) + assert ( + result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][0]["Value"] + == "images/" + ) + assert ( + result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][1]["Name"] + == "suffix" + ) + assert ( + result["TopicConfigurations"][1]["Filter"]["Key"]["FilterRules"][1]["Value"] + == "png" + ) # Place proper queue configuration: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "QueueConfigurations": [ - { - "Id": "SomeID", - "QueueArn": "arn:aws:sqs:us-east-1:012345678910:myQueue", - "Events": ["s3:ObjectCreated:*"], - "Filter": { - "Key": { - "FilterRules": [ - { - "Name": "prefix", - "Value": "images/" - } - ] - } - } - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "QueueConfigurations": [ + { + "Id": "SomeID", + "QueueArn": "arn:aws:sqs:us-east-1:012345678910:myQueue", + "Events": ["s3:ObjectCreated:*"], + "Filter": { + "Key": {"FilterRules": [{"Name": "prefix", "Value": "images/"}]} + }, + } + ] + }, + ) result = s3.get_bucket_notification_configuration(Bucket="bucket") assert len(result["QueueConfigurations"]) == 1 assert not result.get("TopicConfigurations") assert not result.get("LambdaFunctionConfigurations") assert result["QueueConfigurations"][0]["Id"] == "SomeID" - assert result["QueueConfigurations"][0]["QueueArn"] == "arn:aws:sqs:us-east-1:012345678910:myQueue" + assert ( + result["QueueConfigurations"][0]["QueueArn"] + == "arn:aws:sqs:us-east-1:012345678910:myQueue" + ) assert result["QueueConfigurations"][0]["Events"][0] == "s3:ObjectCreated:*" assert len(result["QueueConfigurations"][0]["Events"]) == 1 assert len(result["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"]) == 1 - assert result["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Name"] == "prefix" - assert result["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Value"] == "images/" + assert ( + result["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Name"] + == "prefix" + ) + assert ( + result["QueueConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Value"] + == "images/" + ) # Place proper Lambda configuration: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "LambdaFunctionConfigurations": [ - { - "LambdaFunctionArn": - "arn:aws:lambda:us-east-1:012345678910:function:lambda", - "Events": ["s3:ObjectCreated:*"], - "Filter": { - "Key": { - "FilterRules": [ - { - "Name": "prefix", - "Value": "images/" - } - ] - } - } - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "LambdaFunctionConfigurations": [ + { + "LambdaFunctionArn": "arn:aws:lambda:us-east-1:012345678910:function:lambda", + "Events": ["s3:ObjectCreated:*"], + "Filter": { + "Key": {"FilterRules": [{"Name": "prefix", "Value": "images/"}]} + }, + } + ] + }, + ) result = s3.get_bucket_notification_configuration(Bucket="bucket") assert len(result["LambdaFunctionConfigurations"]) == 1 assert not result.get("TopicConfigurations") assert not result.get("QueueConfigurations") assert result["LambdaFunctionConfigurations"][0]["Id"] - assert result["LambdaFunctionConfigurations"][0]["LambdaFunctionArn"] == \ - "arn:aws:lambda:us-east-1:012345678910:function:lambda" - assert result["LambdaFunctionConfigurations"][0]["Events"][0] == "s3:ObjectCreated:*" + assert ( + result["LambdaFunctionConfigurations"][0]["LambdaFunctionArn"] + == "arn:aws:lambda:us-east-1:012345678910:function:lambda" + ) + assert ( + result["LambdaFunctionConfigurations"][0]["Events"][0] == "s3:ObjectCreated:*" + ) assert len(result["LambdaFunctionConfigurations"][0]["Events"]) == 1 - assert len(result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"]) == 1 - assert result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Name"] == "prefix" - assert result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"][0]["Value"] == "images/" + assert ( + len(result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"]) + == 1 + ) + assert ( + result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"][0][ + "Name" + ] + == "prefix" + ) + assert ( + result["LambdaFunctionConfigurations"][0]["Filter"]["Key"]["FilterRules"][0][ + "Value" + ] + == "images/" + ) # And with all 3 set: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "TopicConfigurations": [ - { - "TopicArn": "arn:aws:sns:us-east-1:012345678910:mytopic", - "Events": [ - "s3:ObjectCreated:*", - "s3:ObjectRemoved:*" - ] - } - ], - "LambdaFunctionConfigurations": [ - { - "LambdaFunctionArn": - "arn:aws:lambda:us-east-1:012345678910:function:lambda", - "Events": ["s3:ObjectCreated:*"] - } - ], - "QueueConfigurations": [ - { - "QueueArn": "arn:aws:sqs:us-east-1:012345678910:myQueue", - "Events": ["s3:ObjectCreated:*"] - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "TopicConfigurations": [ + { + "TopicArn": "arn:aws:sns:us-east-1:012345678910:mytopic", + "Events": ["s3:ObjectCreated:*", "s3:ObjectRemoved:*"], + } + ], + "LambdaFunctionConfigurations": [ + { + "LambdaFunctionArn": "arn:aws:lambda:us-east-1:012345678910:function:lambda", + "Events": ["s3:ObjectCreated:*"], + } + ], + "QueueConfigurations": [ + { + "QueueArn": "arn:aws:sqs:us-east-1:012345678910:myQueue", + "Events": ["s3:ObjectCreated:*"], + } + ], + }, + ) result = s3.get_bucket_notification_configuration(Bucket="bucket") assert len(result["LambdaFunctionConfigurations"]) == 1 assert len(result["TopicConfigurations"]) == 1 assert len(result["QueueConfigurations"]) == 1 # And clear it out: - s3.put_bucket_notification_configuration(Bucket="bucket", NotificationConfiguration={}) + s3.put_bucket_notification_configuration( + Bucket="bucket", NotificationConfiguration={} + ) result = s3.get_bucket_notification_configuration(Bucket="bucket") assert not result.get("TopicConfigurations") assert not result.get("QueueConfigurations") @@ -2396,51 +2426,63 @@ def test_put_bucket_notification_errors(): # With incorrect ARNs: for tech, arn in [("Queue", "sqs"), ("Topic", "sns"), ("LambdaFunction", "lambda")]: with assert_raises(ClientError) as err: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "{}Configurations".format(tech): [ - { - "{}Arn".format(tech): - "arn:aws:{}:us-east-1:012345678910:lksajdfkldskfj", - "Events": ["s3:ObjectCreated:*"] - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "{}Configurations".format(tech): [ + { + "{}Arn".format( + tech + ): "arn:aws:{}:us-east-1:012345678910:lksajdfkldskfj", + "Events": ["s3:ObjectCreated:*"], + } + ] + }, + ) assert err.exception.response["Error"]["Code"] == "InvalidArgument" - assert err.exception.response["Error"]["Message"] == "The ARN is not well formed" + assert ( + err.exception.response["Error"]["Message"] == "The ARN is not well formed" + ) # Region not the same as the bucket: with assert_raises(ClientError) as err: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "QueueConfigurations": [ - { - "QueueArn": - "arn:aws:sqs:us-west-2:012345678910:lksajdfkldskfj", - "Events": ["s3:ObjectCreated:*"] - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "QueueConfigurations": [ + { + "QueueArn": "arn:aws:sqs:us-west-2:012345678910:lksajdfkldskfj", + "Events": ["s3:ObjectCreated:*"], + } + ] + }, + ) assert err.exception.response["Error"]["Code"] == "InvalidArgument" - assert err.exception.response["Error"]["Message"] == \ - "The notification destination service region is not valid for the bucket location constraint" + assert ( + err.exception.response["Error"]["Message"] + == "The notification destination service region is not valid for the bucket location constraint" + ) # Invalid event name: with assert_raises(ClientError) as err: - s3.put_bucket_notification_configuration(Bucket="bucket", - NotificationConfiguration={ - "QueueConfigurations": [ - { - "QueueArn": - "arn:aws:sqs:us-east-1:012345678910:lksajdfkldskfj", - "Events": ["notarealeventname"] - } - ] - }) + s3.put_bucket_notification_configuration( + Bucket="bucket", + NotificationConfiguration={ + "QueueConfigurations": [ + { + "QueueArn": "arn:aws:sqs:us-east-1:012345678910:lksajdfkldskfj", + "Events": ["notarealeventname"], + } + ] + }, + ) assert err.exception.response["Error"]["Code"] == "InvalidArgument" - assert err.exception.response["Error"]["Message"] == "The event is not supported for notifications" + assert ( + err.exception.response["Error"]["Message"] + == "The event is not supported for notifications" + ) @mock_s3 @@ -2451,7 +2493,10 @@ def test_boto3_put_bucket_logging(): wrong_region_bucket = "wrongregionlogbucket" s3.create_bucket(Bucket=bucket_name) s3.create_bucket(Bucket=log_bucket) # Adding the ACL for log-delivery later... - s3.create_bucket(Bucket=wrong_region_bucket, CreateBucketConfiguration={"LocationConstraint": "us-west-2"}) + s3.create_bucket( + Bucket=wrong_region_bucket, + CreateBucketConfiguration={"LocationConstraint": "us-west-2"}, + ) # No logging config: result = s3.get_bucket_logging(Bucket=bucket_name) @@ -2459,72 +2504,78 @@ def test_boto3_put_bucket_logging(): # A log-bucket that doesn't exist: with assert_raises(ClientError) as err: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": "IAMNOTREAL", - "TargetPrefix": "" - } - }) + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": {"TargetBucket": "IAMNOTREAL", "TargetPrefix": ""} + }, + ) assert err.exception.response["Error"]["Code"] == "InvalidTargetBucketForLogging" # A log-bucket that's missing the proper ACLs for LogDelivery: with assert_raises(ClientError) as err: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": log_bucket, - "TargetPrefix": "" - } - }) + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": {"TargetBucket": log_bucket, "TargetPrefix": ""} + }, + ) assert err.exception.response["Error"]["Code"] == "InvalidTargetBucketForLogging" assert "log-delivery" in err.exception.response["Error"]["Message"] # Add the proper "log-delivery" ACL to the log buckets: bucket_owner = s3.get_bucket_acl(Bucket=log_bucket)["Owner"] for bucket in [log_bucket, wrong_region_bucket]: - s3.put_bucket_acl(Bucket=bucket, AccessControlPolicy={ - "Grants": [ - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" + s3.put_bucket_acl( + Bucket=bucket, + AccessControlPolicy={ + "Grants": [ + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "WRITE", }, - "Permission": "WRITE" - }, - { - "Grantee": { - "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", - "Type": "Group" + { + "Grantee": { + "URI": "http://acs.amazonaws.com/groups/s3/LogDelivery", + "Type": "Group", + }, + "Permission": "READ_ACP", }, - "Permission": "READ_ACP" - }, - { - "Grantee": { - "Type": "CanonicalUser", - "ID": bucket_owner["ID"] + { + "Grantee": {"Type": "CanonicalUser", "ID": bucket_owner["ID"]}, + "Permission": "FULL_CONTROL", }, - "Permission": "FULL_CONTROL" - } - ], - "Owner": bucket_owner - }) + ], + "Owner": bucket_owner, + }, + ) # A log-bucket that's in the wrong region: with assert_raises(ClientError) as err: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": wrong_region_bucket, - "TargetPrefix": "" - } - }) + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": { + "TargetBucket": wrong_region_bucket, + "TargetPrefix": "", + } + }, + ) assert err.exception.response["Error"]["Code"] == "CrossLocationLoggingProhibitted" # Correct logging: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": log_bucket, - "TargetPrefix": "{}/".format(bucket_name) - } - }) + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": { + "TargetBucket": log_bucket, + "TargetPrefix": "{}/".format(bucket_name), + } + }, + ) result = s3.get_bucket_logging(Bucket=bucket_name) assert result["LoggingEnabled"]["TargetBucket"] == log_bucket assert result["LoggingEnabled"]["TargetPrefix"] == "{}/".format(bucket_name) @@ -2535,56 +2586,9 @@ def test_boto3_put_bucket_logging(): assert not s3.get_bucket_logging(Bucket=bucket_name).get("LoggingEnabled") # And enabling with multiple target grants: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": log_bucket, - "TargetPrefix": "{}/".format(bucket_name), - "TargetGrants": [ - { - "Grantee": { - "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", - "Type": "CanonicalUser" - }, - "Permission": "READ" - }, - { - "Grantee": { - "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", - "Type": "CanonicalUser" - }, - "Permission": "WRITE" - } - ] - } - }) - - result = s3.get_bucket_logging(Bucket=bucket_name) - assert len(result["LoggingEnabled"]["TargetGrants"]) == 2 - assert result["LoggingEnabled"]["TargetGrants"][0]["Grantee"]["ID"] == \ - "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274" - - # Test with just 1 grant: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ - "LoggingEnabled": { - "TargetBucket": log_bucket, - "TargetPrefix": "{}/".format(bucket_name), - "TargetGrants": [ - { - "Grantee": { - "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", - "Type": "CanonicalUser" - }, - "Permission": "READ" - } - ] - } - }) - result = s3.get_bucket_logging(Bucket=bucket_name) - assert len(result["LoggingEnabled"]["TargetGrants"]) == 1 - - # With an invalid grant: - with assert_raises(ClientError) as err: - s3.put_bucket_logging(Bucket=bucket_name, BucketLoggingStatus={ + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ "LoggingEnabled": { "TargetBucket": log_bucket, "TargetPrefix": "{}/".format(bucket_name), @@ -2592,96 +2596,152 @@ def test_boto3_put_bucket_logging(): { "Grantee": { "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", - "Type": "CanonicalUser" + "Type": "CanonicalUser", }, - "Permission": "NOTAREALPERM" - } - ] + "Permission": "READ", + }, + { + "Grantee": { + "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", + "Type": "CanonicalUser", + }, + "Permission": "WRITE", + }, + ], } - }) + }, + ) + + result = s3.get_bucket_logging(Bucket=bucket_name) + assert len(result["LoggingEnabled"]["TargetGrants"]) == 2 + assert ( + result["LoggingEnabled"]["TargetGrants"][0]["Grantee"]["ID"] + == "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274" + ) + + # Test with just 1 grant: + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": { + "TargetBucket": log_bucket, + "TargetPrefix": "{}/".format(bucket_name), + "TargetGrants": [ + { + "Grantee": { + "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", + "Type": "CanonicalUser", + }, + "Permission": "READ", + } + ], + } + }, + ) + result = s3.get_bucket_logging(Bucket=bucket_name) + assert len(result["LoggingEnabled"]["TargetGrants"]) == 1 + + # With an invalid grant: + with assert_raises(ClientError) as err: + s3.put_bucket_logging( + Bucket=bucket_name, + BucketLoggingStatus={ + "LoggingEnabled": { + "TargetBucket": log_bucket, + "TargetPrefix": "{}/".format(bucket_name), + "TargetGrants": [ + { + "Grantee": { + "ID": "SOMEIDSTRINGHERE9238748923734823917498237489237409123840983274", + "Type": "CanonicalUser", + }, + "Permission": "NOTAREALPERM", + } + ], + } + }, + ) assert err.exception.response["Error"]["Code"] == "MalformedXML" @mock_s3 def test_boto3_put_object_tagging(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-tags' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" s3.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as err: s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'}, - ]} + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, ) e = err.exception - e.response['Error'].should.equal({ - 'Code': 'NoSuchKey', - 'Message': 'The specified key does not exist.', - 'RequestID': '7a62c49f-347e-4fc4-9331-6e8eEXAMPLE', - }) - - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' + e.response["Error"].should.equal( + { + "Code": "NoSuchKey", + "Message": "The specified key does not exist.", + "RequestID": "7a62c49f-347e-4fc4-9331-6e8eEXAMPLE", + } ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") + resp = s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'}, - ]} + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, ) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) @mock_s3 def test_boto3_put_object_tagging_on_earliest_version(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-tags' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" s3.create_bucket(Bucket=bucket_name) - s3_resource = boto3.resource('s3') + s3_resource = boto3.resource("s3") bucket_versioning = s3_resource.BucketVersioning(bucket_name) bucket_versioning.enable() - bucket_versioning.status.should.equal('Enabled') + bucket_versioning.status.should.equal("Enabled") with assert_raises(ClientError) as err: s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'}, - ]} + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, ) e = err.exception - e.response['Error'].should.equal({ - 'Code': 'NoSuchKey', - 'Message': 'The specified key does not exist.', - 'RequestID': '7a62c49f-347e-4fc4-9331-6e8eEXAMPLE', - }) + e.response["Error"].should.equal( + { + "Code": "NoSuchKey", + "Message": "The specified key does not exist.", + "RequestID": "7a62c49f-347e-4fc4-9331-6e8eEXAMPLE", + } + ) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' - ) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test_updated' - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") + s3.put_object(Bucket=bucket_name, Key=key, Body="test_updated") object_versions = list(s3_resource.Bucket(bucket_name).object_versions.all()) first_object = object_versions[0] @@ -2690,68 +2750,65 @@ def test_boto3_put_object_tagging_on_earliest_version(): resp = s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'}, - ]}, - VersionId=first_object.id + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, + VersionId=first_object.id, ) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # Older version has tags while the most recent does not resp = s3.get_object_tagging(Bucket=bucket_name, Key=key, VersionId=first_object.id) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - resp['TagSet'].should.equal( - [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'} - ] + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + resp["TagSet"].should.equal( + [{"Key": "item1", "Value": "foo"}, {"Key": "item2", "Value": "bar"}] ) - resp = s3.get_object_tagging(Bucket=bucket_name, Key=key, VersionId=second_object.id) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - resp['TagSet'].should.equal([]) + resp = s3.get_object_tagging( + Bucket=bucket_name, Key=key, VersionId=second_object.id + ) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + resp["TagSet"].should.equal([]) @mock_s3 def test_boto3_put_object_tagging_on_both_version(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-tags' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" s3.create_bucket(Bucket=bucket_name) - s3_resource = boto3.resource('s3') + s3_resource = boto3.resource("s3") bucket_versioning = s3_resource.BucketVersioning(bucket_name) bucket_versioning.enable() - bucket_versioning.status.should.equal('Enabled') + bucket_versioning.status.should.equal("Enabled") with assert_raises(ClientError) as err: s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'}, - ]} + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, ) e = err.exception - e.response['Error'].should.equal({ - 'Code': 'NoSuchKey', - 'Message': 'The specified key does not exist.', - 'RequestID': '7a62c49f-347e-4fc4-9331-6e8eEXAMPLE', - }) + e.response["Error"].should.equal( + { + "Code": "NoSuchKey", + "Message": "The specified key does not exist.", + "RequestID": "7a62c49f-347e-4fc4-9331-6e8eEXAMPLE", + } + ) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' - ) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test_updated' - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") + s3.put_object(Bucket=bucket_name, Key=key, Body="test_updated") object_versions = list(s3_resource.Bucket(bucket_name).object_versions.all()) first_object = object_versions[0] @@ -2760,380 +2817,292 @@ def test_boto3_put_object_tagging_on_both_version(): resp = s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'}, - ]}, - VersionId=first_object.id + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, + VersionId=first_object.id, ) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) resp = s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'baz'}, - {'Key': 'item2', 'Value': 'bin'}, - ]}, - VersionId=second_object.id + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "baz"}, + {"Key": "item2", "Value": "bin"}, + ] + }, + VersionId=second_object.id, ) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) resp = s3.get_object_tagging(Bucket=bucket_name, Key=key, VersionId=first_object.id) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - resp['TagSet'].should.equal( - [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'} - ] + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + resp["TagSet"].should.equal( + [{"Key": "item1", "Value": "foo"}, {"Key": "item2", "Value": "bar"}] ) - resp = s3.get_object_tagging(Bucket=bucket_name, Key=key, VersionId=second_object.id) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - resp['TagSet'].should.equal( - [ - {'Key': 'item1', 'Value': 'baz'}, - {'Key': 'item2', 'Value': 'bin'} - ] + resp = s3.get_object_tagging( + Bucket=bucket_name, Key=key, VersionId=second_object.id + ) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + resp["TagSet"].should.equal( + [{"Key": "item1", "Value": "baz"}, {"Key": "item2", "Value": "bin"}] ) @mock_s3 def test_boto3_put_object_tagging_with_single_tag(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-tags' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" s3.create_bucket(Bucket=bucket_name) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") resp = s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'} - ]} + Tagging={"TagSet": [{"Key": "item1", "Value": "foo"}]}, ) - resp['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + resp["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) @mock_s3 def test_boto3_get_object_tagging(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-tags' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-tags" s3.create_bucket(Bucket=bucket_name) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body='test' - ) + s3.put_object(Bucket=bucket_name, Key=key, Body="test") resp = s3.get_object_tagging(Bucket=bucket_name, Key=key) - resp['TagSet'].should.have.length_of(0) + resp["TagSet"].should.have.length_of(0) resp = s3.put_object_tagging( Bucket=bucket_name, Key=key, - Tagging={'TagSet': [ - {'Key': 'item1', 'Value': 'foo'}, - {'Key': 'item2', 'Value': 'bar'}, - ]} + Tagging={ + "TagSet": [ + {"Key": "item1", "Value": "foo"}, + {"Key": "item2", "Value": "bar"}, + ] + }, ) resp = s3.get_object_tagging(Bucket=bucket_name, Key=key) - resp['TagSet'].should.have.length_of(2) - resp['TagSet'].should.contain({'Key': 'item1', 'Value': 'foo'}) - resp['TagSet'].should.contain({'Key': 'item2', 'Value': 'bar'}) + resp["TagSet"].should.have.length_of(2) + resp["TagSet"].should.contain({"Key": "item1", "Value": "foo"}) + resp["TagSet"].should.contain({"Key": "item2", "Value": "bar"}) @mock_s3 def test_boto3_list_object_versions(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-versions' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions" s3.create_bucket(Bucket=bucket_name) s3.put_bucket_versioning( - Bucket=bucket_name, - VersioningConfiguration={ - 'Status': 'Enabled' - } + Bucket=bucket_name, VersioningConfiguration={"Status": "Enabled"} ) - items = (six.b('v1'), six.b('v2')) + items = (six.b("v1"), six.b("v2")) for body in items: - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) - response = s3.list_object_versions( - Bucket=bucket_name - ) + s3.put_object(Bucket=bucket_name, Key=key, Body=body) + response = s3.list_object_versions(Bucket=bucket_name) # Two object versions should be returned - len(response['Versions']).should.equal(2) - keys = set([item['Key'] for item in response['Versions']]) + len(response["Versions"]).should.equal(2) + keys = set([item["Key"] for item in response["Versions"]]) keys.should.equal({key}) # Test latest object version is returned response = s3.get_object(Bucket=bucket_name, Key=key) - response['Body'].read().should.equal(items[-1]) + response["Body"].read().should.equal(items[-1]) @mock_s3 def test_boto3_list_object_versions_with_versioning_disabled(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-versions' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions" s3.create_bucket(Bucket=bucket_name) - items = (six.b('v1'), six.b('v2')) + items = (six.b("v1"), six.b("v2")) for body in items: - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) - response = s3.list_object_versions( - Bucket=bucket_name - ) + s3.put_object(Bucket=bucket_name, Key=key, Body=body) + response = s3.list_object_versions(Bucket=bucket_name) # One object version should be returned - len(response['Versions']).should.equal(1) - response['Versions'][0]['Key'].should.equal(key) + len(response["Versions"]).should.equal(1) + response["Versions"][0]["Key"].should.equal(key) # The version id should be the string null - response['Versions'][0]['VersionId'].should.equal('null') + response["Versions"][0]["VersionId"].should.equal("null") # Test latest object version is returned response = s3.get_object(Bucket=bucket_name, Key=key) - response['Body'].read().should.equal(items[-1]) + response["Body"].read().should.equal(items[-1]) @mock_s3 def test_boto3_list_object_versions_with_versioning_enabled_late(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-versions' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions" s3.create_bucket(Bucket=bucket_name) - items = (six.b('v1'), six.b('v2')) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=six.b('v1') - ) + items = (six.b("v1"), six.b("v2")) + s3.put_object(Bucket=bucket_name, Key=key, Body=six.b("v1")) s3.put_bucket_versioning( - Bucket=bucket_name, - VersioningConfiguration={ - 'Status': 'Enabled' - } - ) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=six.b('v2') - ) - response = s3.list_object_versions( - Bucket=bucket_name + Bucket=bucket_name, VersioningConfiguration={"Status": "Enabled"} ) + s3.put_object(Bucket=bucket_name, Key=key, Body=six.b("v2")) + response = s3.list_object_versions(Bucket=bucket_name) # Two object versions should be returned - len(response['Versions']).should.equal(2) - keys = set([item['Key'] for item in response['Versions']]) + len(response["Versions"]).should.equal(2) + keys = set([item["Key"] for item in response["Versions"]]) keys.should.equal({key}) # There should still be a null version id. - versionsId = set([item['VersionId'] for item in response['Versions']]) - versionsId.should.contain('null') + versionsId = set([item["VersionId"] for item in response["Versions"]]) + versionsId.should.contain("null") # Test latest object version is returned response = s3.get_object(Bucket=bucket_name, Key=key) - response['Body'].read().should.equal(items[-1]) + response["Body"].read().should.equal(items[-1]) @mock_s3 def test_boto3_bad_prefix_list_object_versions(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = 'key-with-versions' - bad_prefix = 'key-that-does-not-exist' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions" + bad_prefix = "key-that-does-not-exist" s3.create_bucket(Bucket=bucket_name) s3.put_bucket_versioning( - Bucket=bucket_name, - VersioningConfiguration={ - 'Status': 'Enabled' - } + Bucket=bucket_name, VersioningConfiguration={"Status": "Enabled"} ) - items = (six.b('v1'), six.b('v2')) + items = (six.b("v1"), six.b("v2")) for body in items: - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) - response = s3.list_object_versions( - Bucket=bucket_name, - Prefix=bad_prefix, - ) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - response.should_not.contain('Versions') - response.should_not.contain('DeleteMarkers') + s3.put_object(Bucket=bucket_name, Key=key, Body=body) + response = s3.list_object_versions(Bucket=bucket_name, Prefix=bad_prefix) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response.should_not.contain("Versions") + response.should_not.contain("DeleteMarkers") @mock_s3 def test_boto3_delete_markers(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = u'key-with-versions-and-unicode-ó' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions-and-unicode-ó" s3.create_bucket(Bucket=bucket_name) s3.put_bucket_versioning( - Bucket=bucket_name, - VersioningConfiguration={ - 'Status': 'Enabled' - } + Bucket=bucket_name, VersioningConfiguration={"Status": "Enabled"} ) - items = (six.b('v1'), six.b('v2')) + items = (six.b("v1"), six.b("v2")) for body in items: - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) + s3.put_object(Bucket=bucket_name, Key=key, Body=body) - s3.delete_objects(Bucket=bucket_name, Delete={'Objects': [{'Key': key}]}) + s3.delete_objects(Bucket=bucket_name, Delete={"Objects": [{"Key": key}]}) with assert_raises(ClientError) as e: - s3.get_object( - Bucket=bucket_name, - Key=key - ) - e.exception.response['Error']['Code'].should.equal('NoSuchKey') + s3.get_object(Bucket=bucket_name, Key=key) + e.exception.response["Error"]["Code"].should.equal("NoSuchKey") - response = s3.list_object_versions( - Bucket=bucket_name - ) - response['Versions'].should.have.length_of(2) - response['DeleteMarkers'].should.have.length_of(1) + response = s3.list_object_versions(Bucket=bucket_name) + response["Versions"].should.have.length_of(2) + response["DeleteMarkers"].should.have.length_of(1) s3.delete_object( - Bucket=bucket_name, - Key=key, - VersionId=response['DeleteMarkers'][0]['VersionId'] + Bucket=bucket_name, Key=key, VersionId=response["DeleteMarkers"][0]["VersionId"] ) - response = s3.get_object( - Bucket=bucket_name, - Key=key - ) - response['Body'].read().should.equal(items[-1]) + response = s3.get_object(Bucket=bucket_name, Key=key) + response["Body"].read().should.equal(items[-1]) - response = s3.list_object_versions( - Bucket=bucket_name - ) - response['Versions'].should.have.length_of(2) + response = s3.list_object_versions(Bucket=bucket_name) + response["Versions"].should.have.length_of(2) # We've asserted there is only 2 records so one is newest, one is oldest - latest = list(filter(lambda item: item['IsLatest'], response['Versions']))[0] - oldest = list(filter(lambda item: not item['IsLatest'], response['Versions']))[0] + latest = list(filter(lambda item: item["IsLatest"], response["Versions"]))[0] + oldest = list(filter(lambda item: not item["IsLatest"], response["Versions"]))[0] # Double check ordering of version ID's - latest['VersionId'].should_not.equal(oldest['VersionId']) + latest["VersionId"].should_not.equal(oldest["VersionId"]) # Double check the name is still unicode - latest['Key'].should.equal('key-with-versions-and-unicode-ó') - oldest['Key'].should.equal('key-with-versions-and-unicode-ó') + latest["Key"].should.equal("key-with-versions-and-unicode-ó") + oldest["Key"].should.equal("key-with-versions-and-unicode-ó") @mock_s3 def test_boto3_multiple_delete_markers(): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - key = u'key-with-versions-and-unicode-ó' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + key = "key-with-versions-and-unicode-ó" s3.create_bucket(Bucket=bucket_name) s3.put_bucket_versioning( - Bucket=bucket_name, - VersioningConfiguration={ - 'Status': 'Enabled' - } + Bucket=bucket_name, VersioningConfiguration={"Status": "Enabled"} ) - items = (six.b('v1'), six.b('v2')) + items = (six.b("v1"), six.b("v2")) for body in items: - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) + s3.put_object(Bucket=bucket_name, Key=key, Body=body) # Delete the object twice to add multiple delete markers s3.delete_object(Bucket=bucket_name, Key=key) s3.delete_object(Bucket=bucket_name, Key=key) response = s3.list_object_versions(Bucket=bucket_name) - response['DeleteMarkers'].should.have.length_of(2) + response["DeleteMarkers"].should.have.length_of(2) with assert_raises(ClientError) as e: - s3.get_object( - Bucket=bucket_name, - Key=key - ) - e.response['Error']['Code'].should.equal('404') + s3.get_object(Bucket=bucket_name, Key=key) + e.response["Error"]["Code"].should.equal("404") # Remove both delete markers to restore the object s3.delete_object( - Bucket=bucket_name, - Key=key, - VersionId=response['DeleteMarkers'][0]['VersionId'] + Bucket=bucket_name, Key=key, VersionId=response["DeleteMarkers"][0]["VersionId"] ) s3.delete_object( - Bucket=bucket_name, - Key=key, - VersionId=response['DeleteMarkers'][1]['VersionId'] + Bucket=bucket_name, Key=key, VersionId=response["DeleteMarkers"][1]["VersionId"] ) - response = s3.get_object( - Bucket=bucket_name, - Key=key - ) - response['Body'].read().should.equal(items[-1]) + response = s3.get_object(Bucket=bucket_name, Key=key) + response["Body"].read().should.equal(items[-1]) response = s3.list_object_versions(Bucket=bucket_name) - response['Versions'].should.have.length_of(2) + response["Versions"].should.have.length_of(2) # We've asserted there is only 2 records so one is newest, one is oldest - latest = list(filter(lambda item: item['IsLatest'], response['Versions']))[0] - oldest = list(filter(lambda item: not item['IsLatest'], response['Versions']))[0] + latest = list(filter(lambda item: item["IsLatest"], response["Versions"]))[0] + oldest = list(filter(lambda item: not item["IsLatest"], response["Versions"]))[0] # Double check ordering of version ID's - latest['VersionId'].should_not.equal(oldest['VersionId']) + latest["VersionId"].should_not.equal(oldest["VersionId"]) # Double check the name is still unicode - latest['Key'].should.equal('key-with-versions-and-unicode-ó') - oldest['Key'].should.equal('key-with-versions-and-unicode-ó') + latest["Key"].should.equal("key-with-versions-and-unicode-ó") + oldest["Key"].should.equal("key-with-versions-and-unicode-ó") @mock_s3 def test_get_stream_gzipped(): payload = b"this is some stuff here" - s3_client = boto3.client("s3", region_name='us-east-1') - s3_client.create_bucket(Bucket='moto-tests') + s3_client = boto3.client("s3", region_name="us-east-1") + s3_client.create_bucket(Bucket="moto-tests") buffer_ = BytesIO() - with GzipFile(fileobj=buffer_, mode='w') as f: + with GzipFile(fileobj=buffer_, mode="w") as f: f.write(payload) payload_gz = buffer_.getvalue() s3_client.put_object( - Bucket='moto-tests', - Key='keyname', - Body=payload_gz, - ContentEncoding='gzip', + Bucket="moto-tests", Key="keyname", Body=payload_gz, ContentEncoding="gzip" ) - obj = s3_client.get_object( - Bucket='moto-tests', - Key='keyname', - ) - res = zlib.decompress(obj['Body'].read(), 16 + zlib.MAX_WBITS) + obj = s3_client.get_object(Bucket="moto-tests", Key="keyname") + res = zlib.decompress(obj["Body"].read(), 16 + zlib.MAX_WBITS) assert res == payload @@ -3159,163 +3128,153 @@ TEST_XML = """\ @mock_s3 def test_boto3_bucket_name_too_long(): - s3 = boto3.client('s3', region_name='us-east-1') + s3 = boto3.client("s3", region_name="us-east-1") with assert_raises(ClientError) as exc: - s3.create_bucket(Bucket='x' * 64) - exc.exception.response['Error']['Code'].should.equal('InvalidBucketName') + s3.create_bucket(Bucket="x" * 64) + exc.exception.response["Error"]["Code"].should.equal("InvalidBucketName") @mock_s3 def test_boto3_bucket_name_too_short(): - s3 = boto3.client('s3', region_name='us-east-1') + s3 = boto3.client("s3", region_name="us-east-1") with assert_raises(ClientError) as exc: - s3.create_bucket(Bucket='x' * 2) - exc.exception.response['Error']['Code'].should.equal('InvalidBucketName') + s3.create_bucket(Bucket="x" * 2) + exc.exception.response["Error"]["Code"].should.equal("InvalidBucketName") @mock_s3 def test_accelerated_none_when_unspecified(): - bucket_name = 'some_bucket' - s3 = boto3.client('s3') + bucket_name = "some_bucket" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) resp = s3.get_bucket_accelerate_configuration(Bucket=bucket_name) - resp.shouldnt.have.key('Status') + resp.shouldnt.have.key("Status") @mock_s3 def test_can_enable_bucket_acceleration(): - bucket_name = 'some_bucket' - s3 = boto3.client('s3') + bucket_name = "some_bucket" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) resp = s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'Enabled'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "Enabled"} ) - resp.keys().should.have.length_of(1) # Response contains nothing (only HTTP headers) + resp.keys().should.have.length_of( + 1 + ) # Response contains nothing (only HTTP headers) resp = s3.get_bucket_accelerate_configuration(Bucket=bucket_name) - resp.should.have.key('Status') - resp['Status'].should.equal('Enabled') + resp.should.have.key("Status") + resp["Status"].should.equal("Enabled") @mock_s3 def test_can_suspend_bucket_acceleration(): - bucket_name = 'some_bucket' - s3 = boto3.client('s3') + bucket_name = "some_bucket" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) resp = s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'Enabled'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "Enabled"} ) resp = s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'Suspended'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "Suspended"} ) - resp.keys().should.have.length_of(1) # Response contains nothing (only HTTP headers) + resp.keys().should.have.length_of( + 1 + ) # Response contains nothing (only HTTP headers) resp = s3.get_bucket_accelerate_configuration(Bucket=bucket_name) - resp.should.have.key('Status') - resp['Status'].should.equal('Suspended') + resp.should.have.key("Status") + resp["Status"].should.equal("Suspended") @mock_s3 def test_suspending_acceleration_on_not_configured_bucket_does_nothing(): - bucket_name = 'some_bucket' - s3 = boto3.client('s3') + bucket_name = "some_bucket" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) resp = s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'Suspended'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "Suspended"} ) - resp.keys().should.have.length_of(1) # Response contains nothing (only HTTP headers) + resp.keys().should.have.length_of( + 1 + ) # Response contains nothing (only HTTP headers) resp = s3.get_bucket_accelerate_configuration(Bucket=bucket_name) - resp.shouldnt.have.key('Status') + resp.shouldnt.have.key("Status") @mock_s3 def test_accelerate_configuration_status_validation(): - bucket_name = 'some_bucket' - s3 = boto3.client('s3') + bucket_name = "some_bucket" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as exc: s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'bad_status'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "bad_status"} ) - exc.exception.response['Error']['Code'].should.equal('MalformedXML') + exc.exception.response["Error"]["Code"].should.equal("MalformedXML") @mock_s3 def test_accelerate_configuration_is_not_supported_when_bucket_name_has_dots(): - bucket_name = 'some.bucket.with.dots' - s3 = boto3.client('s3') + bucket_name = "some.bucket.with.dots" + s3 = boto3.client("s3") s3.create_bucket(Bucket=bucket_name) with assert_raises(ClientError) as exc: s3.put_bucket_accelerate_configuration( - Bucket=bucket_name, - AccelerateConfiguration={'Status': 'Enabled'}, + Bucket=bucket_name, AccelerateConfiguration={"Status": "Enabled"} ) - exc.exception.response['Error']['Code'].should.equal('InvalidRequest') + exc.exception.response["Error"]["Code"].should.equal("InvalidRequest") def store_and_read_back_a_key(key): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - body = b'Some body' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + body = b"Some body" s3.create_bucket(Bucket=bucket_name) - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) + s3.put_object(Bucket=bucket_name, Key=key, Body=body) response = s3.get_object(Bucket=bucket_name, Key=key) - response['Body'].read().should.equal(body) + response["Body"].read().should.equal(body) @mock_s3 def test_paths_with_leading_slashes_work(): - store_and_read_back_a_key('/a-key') + store_and_read_back_a_key("/a-key") @mock_s3 def test_root_dir_with_empty_name_works(): - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': - raise SkipTest('Does not work in server mode due to error in Workzeug') - store_and_read_back_a_key('/') + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Does not work in server mode due to error in Workzeug") + store_and_read_back_a_key("/") -@parameterized([ - ('foo/bar/baz',), - ('foo',), - ('foo/run_dt%3D2019-01-01%252012%253A30%253A00',), -]) +@parameterized( + [("foo/bar/baz",), ("foo",), ("foo/run_dt%3D2019-01-01%252012%253A30%253A00",)] +) @mock_s3 def test_delete_objects_with_url_encoded_key(key): - s3 = boto3.client('s3', region_name='us-east-1') - bucket_name = 'mybucket' - body = b'Some body' + s3 = boto3.client("s3", region_name="us-east-1") + bucket_name = "mybucket" + body = b"Some body" s3.create_bucket(Bucket=bucket_name) def put_object(): - s3.put_object( - Bucket=bucket_name, - Key=key, - Body=body - ) + s3.put_object(Bucket=bucket_name, Key=key, Body=body) def assert_deleted(): with assert_raises(ClientError) as e: s3.get_object(Bucket=bucket_name, Key=key) - e.exception.response['Error']['Code'].should.equal('NoSuchKey') + e.exception.response["Error"]["Code"].should.equal("NoSuchKey") put_object() s3.delete_object(Bucket=bucket_name, Key=key) assert_deleted() put_object() - s3.delete_objects(Bucket=bucket_name, Delete={'Objects': [{'Key': key}]}) + s3.delete_objects(Bucket=bucket_name, Delete={"Objects": [{"Key": key}]}) assert_deleted() @@ -3324,67 +3283,98 @@ def test_list_config_discovered_resources(): from moto.s3.config import s3_config_query # Without any buckets: - assert s3_config_query.list_config_service_resources("global", "global", None, None, 100, None) == ([], None) + assert s3_config_query.list_config_service_resources( + "global", "global", None, None, 100, None + ) == ([], None) # With 10 buckets in us-west-2: for x in range(0, 10): - s3_config_query.backends['global'].create_bucket('bucket{}'.format(x), 'us-west-2') + s3_config_query.backends["global"].create_bucket( + "bucket{}".format(x), "us-west-2" + ) # With 2 buckets in eu-west-1: for x in range(10, 12): - s3_config_query.backends['global'].create_bucket('eu-bucket{}'.format(x), 'eu-west-1') + s3_config_query.backends["global"].create_bucket( + "eu-bucket{}".format(x), "eu-west-1" + ) - result, next_token = s3_config_query.list_config_service_resources(None, None, 100, None) + result, next_token = s3_config_query.list_config_service_resources( + None, None, 100, None + ) assert not next_token assert len(result) == 12 for x in range(0, 10): assert result[x] == { - 'type': 'AWS::S3::Bucket', - 'id': 'bucket{}'.format(x), - 'name': 'bucket{}'.format(x), - 'region': 'us-west-2' + "type": "AWS::S3::Bucket", + "id": "bucket{}".format(x), + "name": "bucket{}".format(x), + "region": "us-west-2", } for x in range(10, 12): assert result[x] == { - 'type': 'AWS::S3::Bucket', - 'id': 'eu-bucket{}'.format(x), - 'name': 'eu-bucket{}'.format(x), - 'region': 'eu-west-1' + "type": "AWS::S3::Bucket", + "id": "eu-bucket{}".format(x), + "name": "eu-bucket{}".format(x), + "region": "eu-west-1", } # With a name: - result, next_token = s3_config_query.list_config_service_resources(None, 'bucket0', 100, None) - assert len(result) == 1 and result[0]['name'] == 'bucket0' and not next_token + result, next_token = s3_config_query.list_config_service_resources( + None, "bucket0", 100, None + ) + assert len(result) == 1 and result[0]["name"] == "bucket0" and not next_token # With a region: - result, next_token = s3_config_query.list_config_service_resources(None, None, 100, None, resource_region='eu-west-1') - assert len(result) == 2 and not next_token and result[1]['name'] == 'eu-bucket11' + result, next_token = s3_config_query.list_config_service_resources( + None, None, 100, None, resource_region="eu-west-1" + ) + assert len(result) == 2 and not next_token and result[1]["name"] == "eu-bucket11" # With resource ids: - result, next_token = s3_config_query.list_config_service_resources(['bucket0', 'bucket1'], None, 100, None) - assert len(result) == 2 and result[0]['name'] == 'bucket0' and result[1]['name'] == 'bucket1' and not next_token + result, next_token = s3_config_query.list_config_service_resources( + ["bucket0", "bucket1"], None, 100, None + ) + assert ( + len(result) == 2 + and result[0]["name"] == "bucket0" + and result[1]["name"] == "bucket1" + and not next_token + ) # With duplicated resource ids: - result, next_token = s3_config_query.list_config_service_resources(['bucket0', 'bucket0'], None, 100, None) - assert len(result) == 1 and result[0]['name'] == 'bucket0' and not next_token + result, next_token = s3_config_query.list_config_service_resources( + ["bucket0", "bucket0"], None, 100, None + ) + assert len(result) == 1 and result[0]["name"] == "bucket0" and not next_token # Pagination: - result, next_token = s3_config_query.list_config_service_resources(None, None, 1, None) - assert len(result) == 1 and result[0]['name'] == 'bucket0' and next_token == 'bucket1' + result, next_token = s3_config_query.list_config_service_resources( + None, None, 1, None + ) + assert ( + len(result) == 1 and result[0]["name"] == "bucket0" and next_token == "bucket1" + ) # Last Page: - result, next_token = s3_config_query.list_config_service_resources(None, None, 1, 'eu-bucket11', resource_region='eu-west-1') - assert len(result) == 1 and result[0]['name'] == 'eu-bucket11' and not next_token + result, next_token = s3_config_query.list_config_service_resources( + None, None, 1, "eu-bucket11", resource_region="eu-west-1" + ) + assert len(result) == 1 and result[0]["name"] == "eu-bucket11" and not next_token # With a list of buckets: - result, next_token = s3_config_query.list_config_service_resources(['bucket0', 'bucket1'], None, 1, None) - assert len(result) == 1 and result[0]['name'] == 'bucket0' and next_token == 'bucket1' + result, next_token = s3_config_query.list_config_service_resources( + ["bucket0", "bucket1"], None, 1, None + ) + assert ( + len(result) == 1 and result[0]["name"] == "bucket0" and next_token == "bucket1" + ) # With an invalid page: with assert_raises(InvalidNextTokenException) as inte: - s3_config_query.list_config_service_resources(None, None, 1, 'notabucket') + s3_config_query.list_config_service_resources(None, None, 1, "notabucket") - assert 'The nextToken provided is invalid' in inte.exception.message + assert "The nextToken provided is invalid" in inte.exception.message @mock_s3 @@ -3392,132 +3382,112 @@ def test_s3_lifecycle_config_dict(): from moto.s3.config import s3_config_query # With 1 bucket in us-west-2: - s3_config_query.backends['global'].create_bucket('bucket1', 'us-west-2') + s3_config_query.backends["global"].create_bucket("bucket1", "us-west-2") # And a lifecycle policy lifecycle = [ { - 'ID': 'rule1', - 'Status': 'Enabled', - 'Filter': {'Prefix': ''}, - 'Expiration': {'Days': 1} + "ID": "rule1", + "Status": "Enabled", + "Filter": {"Prefix": ""}, + "Expiration": {"Days": 1}, }, { - 'ID': 'rule2', - 'Status': 'Enabled', - 'Filter': { - 'And': { - 'Prefix': 'some/path', - 'Tag': [ - {'Key': 'TheKey', 'Value': 'TheValue'} - ] + "ID": "rule2", + "Status": "Enabled", + "Filter": { + "And": { + "Prefix": "some/path", + "Tag": [{"Key": "TheKey", "Value": "TheValue"}], } }, - 'Expiration': {'Days': 1} + "Expiration": {"Days": 1}, }, + {"ID": "rule3", "Status": "Enabled", "Filter": {}, "Expiration": {"Days": 1}}, { - 'ID': 'rule3', - 'Status': 'Enabled', - 'Filter': {}, - 'Expiration': {'Days': 1} + "ID": "rule4", + "Status": "Enabled", + "Filter": {"Prefix": ""}, + "AbortIncompleteMultipartUpload": {"DaysAfterInitiation": 1}, }, - { - 'ID': 'rule4', - 'Status': 'Enabled', - 'Filter': {'Prefix': ''}, - 'AbortIncompleteMultipartUpload': {'DaysAfterInitiation': 1} - } ] - s3_config_query.backends['global'].set_bucket_lifecycle('bucket1', lifecycle) + s3_config_query.backends["global"].set_bucket_lifecycle("bucket1", lifecycle) # Get the rules for this: - lifecycles = [rule.to_config_dict() for rule in s3_config_query.backends['global'].buckets['bucket1'].rules] + lifecycles = [ + rule.to_config_dict() + for rule in s3_config_query.backends["global"].buckets["bucket1"].rules + ] # Verify the first: assert lifecycles[0] == { - 'id': 'rule1', - 'prefix': None, - 'status': 'Enabled', - 'expirationInDays': 1, - 'expiredObjectDeleteMarker': None, - 'noncurrentVersionExpirationInDays': -1, - 'expirationDate': None, - 'transitions': None, - 'noncurrentVersionTransitions': None, - 'abortIncompleteMultipartUpload': None, - 'filter': { - 'predicate': { - 'type': 'LifecyclePrefixPredicate', - 'prefix': '' - } - } + "id": "rule1", + "prefix": None, + "status": "Enabled", + "expirationInDays": 1, + "expiredObjectDeleteMarker": None, + "noncurrentVersionExpirationInDays": -1, + "expirationDate": None, + "transitions": None, + "noncurrentVersionTransitions": None, + "abortIncompleteMultipartUpload": None, + "filter": {"predicate": {"type": "LifecyclePrefixPredicate", "prefix": ""}}, } # Verify the second: assert lifecycles[1] == { - 'id': 'rule2', - 'prefix': None, - 'status': 'Enabled', - 'expirationInDays': 1, - 'expiredObjectDeleteMarker': None, - 'noncurrentVersionExpirationInDays': -1, - 'expirationDate': None, - 'transitions': None, - 'noncurrentVersionTransitions': None, - 'abortIncompleteMultipartUpload': None, - 'filter': { - 'predicate': { - 'type': 'LifecycleAndOperator', - 'operands': [ + "id": "rule2", + "prefix": None, + "status": "Enabled", + "expirationInDays": 1, + "expiredObjectDeleteMarker": None, + "noncurrentVersionExpirationInDays": -1, + "expirationDate": None, + "transitions": None, + "noncurrentVersionTransitions": None, + "abortIncompleteMultipartUpload": None, + "filter": { + "predicate": { + "type": "LifecycleAndOperator", + "operands": [ + {"type": "LifecyclePrefixPredicate", "prefix": "some/path"}, { - 'type': 'LifecyclePrefixPredicate', - 'prefix': 'some/path' + "type": "LifecycleTagPredicate", + "tag": {"key": "TheKey", "value": "TheValue"}, }, - { - 'type': 'LifecycleTagPredicate', - 'tag': { - 'key': 'TheKey', - 'value': 'TheValue' - } - }, - ] + ], } - } + }, } # And the third: assert lifecycles[2] == { - 'id': 'rule3', - 'prefix': None, - 'status': 'Enabled', - 'expirationInDays': 1, - 'expiredObjectDeleteMarker': None, - 'noncurrentVersionExpirationInDays': -1, - 'expirationDate': None, - 'transitions': None, - 'noncurrentVersionTransitions': None, - 'abortIncompleteMultipartUpload': None, - 'filter': {'predicate': None} + "id": "rule3", + "prefix": None, + "status": "Enabled", + "expirationInDays": 1, + "expiredObjectDeleteMarker": None, + "noncurrentVersionExpirationInDays": -1, + "expirationDate": None, + "transitions": None, + "noncurrentVersionTransitions": None, + "abortIncompleteMultipartUpload": None, + "filter": {"predicate": None}, } # And the last: assert lifecycles[3] == { - 'id': 'rule4', - 'prefix': None, - 'status': 'Enabled', - 'expirationInDays': None, - 'expiredObjectDeleteMarker': None, - 'noncurrentVersionExpirationInDays': -1, - 'expirationDate': None, - 'transitions': None, - 'noncurrentVersionTransitions': None, - 'abortIncompleteMultipartUpload': {'daysAfterInitiation': 1}, - 'filter': { - 'predicate': { - 'type': 'LifecyclePrefixPredicate', - 'prefix': '' - } - } + "id": "rule4", + "prefix": None, + "status": "Enabled", + "expirationInDays": None, + "expiredObjectDeleteMarker": None, + "noncurrentVersionExpirationInDays": -1, + "expirationDate": None, + "transitions": None, + "noncurrentVersionTransitions": None, + "abortIncompleteMultipartUpload": {"daysAfterInitiation": 1}, + "filter": {"predicate": {"type": "LifecyclePrefixPredicate", "prefix": ""}}, } @@ -3526,99 +3496,98 @@ def test_s3_notification_config_dict(): from moto.s3.config import s3_config_query # With 1 bucket in us-west-2: - s3_config_query.backends['global'].create_bucket('bucket1', 'us-west-2') + s3_config_query.backends["global"].create_bucket("bucket1", "us-west-2") # And some notifications: notifications = { - 'TopicConfiguration': [{ - 'Id': 'Topic', - "Topic": 'arn:aws:sns:us-west-2:012345678910:mytopic', - "Event": [ - "s3:ReducedRedundancyLostObject", - "s3:ObjectRestore:Completed" - ] - }], - 'QueueConfiguration': [{ - 'Id': 'Queue', - 'Queue': 'arn:aws:sqs:us-west-2:012345678910:myqueue', - 'Event': [ - "s3:ObjectRemoved:Delete" - ], - 'Filter': { - 'S3Key': { - 'FilterRule': [ - { - 'Name': 'prefix', - 'Value': 'stuff/here/' - } - ] - } + "TopicConfiguration": [ + { + "Id": "Topic", + "Topic": "arn:aws:sns:us-west-2:012345678910:mytopic", + "Event": [ + "s3:ReducedRedundancyLostObject", + "s3:ObjectRestore:Completed", + ], } - }], - 'CloudFunctionConfiguration': [{ - 'Id': 'Lambda', - 'CloudFunction': 'arn:aws:lambda:us-west-2:012345678910:function:mylambda', - 'Event': [ - "s3:ObjectCreated:Post", - "s3:ObjectCreated:Copy", - "s3:ObjectCreated:Put" - ], - 'Filter': { - 'S3Key': { - 'FilterRule': [ - { - 'Name': 'suffix', - 'Value': '.png' - } - ] - } + ], + "QueueConfiguration": [ + { + "Id": "Queue", + "Queue": "arn:aws:sqs:us-west-2:012345678910:myqueue", + "Event": ["s3:ObjectRemoved:Delete"], + "Filter": { + "S3Key": { + "FilterRule": [{"Name": "prefix", "Value": "stuff/here/"}] + } + }, } - }] + ], + "CloudFunctionConfiguration": [ + { + "Id": "Lambda", + "CloudFunction": "arn:aws:lambda:us-west-2:012345678910:function:mylambda", + "Event": [ + "s3:ObjectCreated:Post", + "s3:ObjectCreated:Copy", + "s3:ObjectCreated:Put", + ], + "Filter": { + "S3Key": {"FilterRule": [{"Name": "suffix", "Value": ".png"}]} + }, + } + ], } - s3_config_query.backends['global'].put_bucket_notification_configuration('bucket1', notifications) + s3_config_query.backends["global"].put_bucket_notification_configuration( + "bucket1", notifications + ) # Get the notifications for this: - notifications = s3_config_query.backends['global'].buckets['bucket1'].notification_configuration.to_config_dict() + notifications = ( + s3_config_query.backends["global"] + .buckets["bucket1"] + .notification_configuration.to_config_dict() + ) # Verify it all: assert notifications == { - 'configurations': { - 'Topic': { - 'events': ['s3:ReducedRedundancyLostObject', 's3:ObjectRestore:Completed'], - 'filter': None, - 'objectPrefixes': [], - 'topicARN': 'arn:aws:sns:us-west-2:012345678910:mytopic', - 'type': 'TopicConfiguration' + "configurations": { + "Topic": { + "events": [ + "s3:ReducedRedundancyLostObject", + "s3:ObjectRestore:Completed", + ], + "filter": None, + "objectPrefixes": [], + "topicARN": "arn:aws:sns:us-west-2:012345678910:mytopic", + "type": "TopicConfiguration", }, - 'Queue': { - 'events': ['s3:ObjectRemoved:Delete'], - 'filter': { - 's3KeyFilter': { - 'filterRules': [{ - 'name': 'prefix', - 'value': 'stuff/here/' - }] + "Queue": { + "events": ["s3:ObjectRemoved:Delete"], + "filter": { + "s3KeyFilter": { + "filterRules": [{"name": "prefix", "value": "stuff/here/"}] } }, - 'objectPrefixes': [], - 'queueARN': 'arn:aws:sqs:us-west-2:012345678910:myqueue', - 'type': 'QueueConfiguration' + "objectPrefixes": [], + "queueARN": "arn:aws:sqs:us-west-2:012345678910:myqueue", + "type": "QueueConfiguration", }, - 'Lambda': { - 'events': ['s3:ObjectCreated:Post', 's3:ObjectCreated:Copy', 's3:ObjectCreated:Put'], - 'filter': { - 's3KeyFilter': { - 'filterRules': [{ - 'name': 'suffix', - 'value': '.png' - }] + "Lambda": { + "events": [ + "s3:ObjectCreated:Post", + "s3:ObjectCreated:Copy", + "s3:ObjectCreated:Put", + ], + "filter": { + "s3KeyFilter": { + "filterRules": [{"name": "suffix", "value": ".png"}] } }, - 'objectPrefixes': [], - 'queueARN': 'arn:aws:lambda:us-west-2:012345678910:function:mylambda', - 'type': 'LambdaConfiguration' - } + "objectPrefixes": [], + "queueARN": "arn:aws:lambda:us-west-2:012345678910:function:mylambda", + "type": "LambdaConfiguration", + }, } } @@ -3629,129 +3598,197 @@ def test_s3_acl_to_config_dict(): from moto.s3.models import FakeAcl, FakeGrant, FakeGrantee, OWNER # With 1 bucket in us-west-2: - s3_config_query.backends['global'].create_bucket('logbucket', 'us-west-2') + s3_config_query.backends["global"].create_bucket("logbucket", "us-west-2") # Get the config dict with nothing other than the owner details: - acls = s3_config_query.backends['global'].buckets['logbucket'].acl.to_config_dict() - assert acls == { - 'grantSet': None, - 'owner': {'displayName': None, 'id': OWNER} - } + acls = s3_config_query.backends["global"].buckets["logbucket"].acl.to_config_dict() + assert acls == {"grantSet": None, "owner": {"displayName": None, "id": OWNER}} # Add some Log Bucket ACLs: - log_acls = FakeAcl([ - FakeGrant([FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], "WRITE"), - FakeGrant([FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], "READ_ACP"), - FakeGrant([FakeGrantee(id=OWNER)], "FULL_CONTROL") - ]) - s3_config_query.backends['global'].set_bucket_acl('logbucket', log_acls) + log_acls = FakeAcl( + [ + FakeGrant( + [FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], + "WRITE", + ), + FakeGrant( + [FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], + "READ_ACP", + ), + FakeGrant([FakeGrantee(id=OWNER)], "FULL_CONTROL"), + ] + ) + s3_config_query.backends["global"].set_bucket_acl("logbucket", log_acls) - acls = s3_config_query.backends['global'].buckets['logbucket'].acl.to_config_dict() + acls = s3_config_query.backends["global"].buckets["logbucket"].acl.to_config_dict() assert acls == { - 'grantSet': None, - 'grantList': [{'grantee': 'LogDelivery', 'permission': 'Write'}, {'grantee': 'LogDelivery', 'permission': 'ReadAcp'}], - 'owner': {'displayName': None, 'id': OWNER} + "grantSet": None, + "grantList": [ + {"grantee": "LogDelivery", "permission": "Write"}, + {"grantee": "LogDelivery", "permission": "ReadAcp"}, + ], + "owner": {"displayName": None, "id": OWNER}, } # Give the owner less than full_control permissions: - log_acls = FakeAcl([FakeGrant([FakeGrantee(id=OWNER)], "READ_ACP"), FakeGrant([FakeGrantee(id=OWNER)], "WRITE_ACP")]) - s3_config_query.backends['global'].set_bucket_acl('logbucket', log_acls) - acls = s3_config_query.backends['global'].buckets['logbucket'].acl.to_config_dict() + log_acls = FakeAcl( + [ + FakeGrant([FakeGrantee(id=OWNER)], "READ_ACP"), + FakeGrant([FakeGrantee(id=OWNER)], "WRITE_ACP"), + ] + ) + s3_config_query.backends["global"].set_bucket_acl("logbucket", log_acls) + acls = s3_config_query.backends["global"].buckets["logbucket"].acl.to_config_dict() assert acls == { - 'grantSet': None, - 'grantList': [ - {'grantee': {'id': OWNER, 'displayName': None}, 'permission': 'ReadAcp'}, - {'grantee': {'id': OWNER, 'displayName': None}, 'permission': 'WriteAcp'} + "grantSet": None, + "grantList": [ + {"grantee": {"id": OWNER, "displayName": None}, "permission": "ReadAcp"}, + {"grantee": {"id": OWNER, "displayName": None}, "permission": "WriteAcp"}, ], - 'owner': {'displayName': None, 'id': OWNER} + "owner": {"displayName": None, "id": OWNER}, } @mock_s3 def test_s3_config_dict(): from moto.s3.config import s3_config_query - from moto.s3.models import FakeAcl, FakeGrant, FakeGrantee, FakeTag, FakeTagging, FakeTagSet, OWNER + from moto.s3.models import ( + FakeAcl, + FakeGrant, + FakeGrantee, + FakeTag, + FakeTagging, + FakeTagSet, + OWNER, + ) # Without any buckets: - assert not s3_config_query.get_config_resource('some_bucket') + assert not s3_config_query.get_config_resource("some_bucket") - tags = FakeTagging(FakeTagSet([FakeTag('someTag', 'someValue'), FakeTag('someOtherTag', 'someOtherValue')])) + tags = FakeTagging( + FakeTagSet( + [FakeTag("someTag", "someValue"), FakeTag("someOtherTag", "someOtherValue")] + ) + ) # With 1 bucket in us-west-2: - s3_config_query.backends['global'].create_bucket('bucket1', 'us-west-2') - s3_config_query.backends['global'].put_bucket_tagging('bucket1', tags) + s3_config_query.backends["global"].create_bucket("bucket1", "us-west-2") + s3_config_query.backends["global"].put_bucket_tagging("bucket1", tags) # With a log bucket: - s3_config_query.backends['global'].create_bucket('logbucket', 'us-west-2') - log_acls = FakeAcl([ - FakeGrant([FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], "WRITE"), - FakeGrant([FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], "READ_ACP"), - FakeGrant([FakeGrantee(id=OWNER)], "FULL_CONTROL") - ]) - - s3_config_query.backends['global'].set_bucket_acl('logbucket', log_acls) - s3_config_query.backends['global'].put_bucket_logging('bucket1', {'TargetBucket': 'logbucket', 'TargetPrefix': ''}) - - policy = json.dumps({ - 'Statement': [ - { - "Effect": "Deny", - "Action": "s3:DeleteObject", - "Principal": "*", - "Resource": "arn:aws:s3:::bucket1/*" - } + s3_config_query.backends["global"].create_bucket("logbucket", "us-west-2") + log_acls = FakeAcl( + [ + FakeGrant( + [FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], + "WRITE", + ), + FakeGrant( + [FakeGrantee(uri="http://acs.amazonaws.com/groups/s3/LogDelivery")], + "READ_ACP", + ), + FakeGrant([FakeGrantee(id=OWNER)], "FULL_CONTROL"), ] - }) + ) + + s3_config_query.backends["global"].set_bucket_acl("logbucket", log_acls) + s3_config_query.backends["global"].put_bucket_logging( + "bucket1", {"TargetBucket": "logbucket", "TargetPrefix": ""} + ) + + policy = json.dumps( + { + "Statement": [ + { + "Effect": "Deny", + "Action": "s3:DeleteObject", + "Principal": "*", + "Resource": "arn:aws:s3:::bucket1/*", + } + ] + } + ) # The policy is a byte array -- need to encode in Python 3 -- for Python 2 just pass the raw string in: if sys.version_info[0] > 2: - pass_policy = bytes(policy, 'utf-8') + pass_policy = bytes(policy, "utf-8") else: pass_policy = policy - s3_config_query.backends['global'].set_bucket_policy('bucket1', pass_policy) + s3_config_query.backends["global"].set_bucket_policy("bucket1", pass_policy) # Get the us-west-2 bucket and verify that it works properly: - bucket1_result = s3_config_query.get_config_resource('bucket1') + bucket1_result = s3_config_query.get_config_resource("bucket1") # Just verify a few things: - assert bucket1_result['arn'] == 'arn:aws:s3:::bucket1' - assert bucket1_result['awsRegion'] == 'us-west-2' - assert bucket1_result['resourceName'] == bucket1_result['resourceId'] == 'bucket1' - assert bucket1_result['tags'] == {'someTag': 'someValue', 'someOtherTag': 'someOtherValue'} - assert json.loads(bucket1_result['supplementaryConfiguration']['BucketTaggingConfiguration']) == \ - {'tagSets': [{'tags': bucket1_result['tags']}]} - assert isinstance(bucket1_result['configuration'], str) - exist_list = ['AccessControlList', 'BucketAccelerateConfiguration', 'BucketLoggingConfiguration', 'BucketPolicy', - 'IsRequesterPaysEnabled', 'BucketNotificationConfiguration'] + assert bucket1_result["arn"] == "arn:aws:s3:::bucket1" + assert bucket1_result["awsRegion"] == "us-west-2" + assert bucket1_result["resourceName"] == bucket1_result["resourceId"] == "bucket1" + assert bucket1_result["tags"] == { + "someTag": "someValue", + "someOtherTag": "someOtherValue", + } + assert json.loads( + bucket1_result["supplementaryConfiguration"]["BucketTaggingConfiguration"] + ) == {"tagSets": [{"tags": bucket1_result["tags"]}]} + assert isinstance(bucket1_result["configuration"], str) + exist_list = [ + "AccessControlList", + "BucketAccelerateConfiguration", + "BucketLoggingConfiguration", + "BucketPolicy", + "IsRequesterPaysEnabled", + "BucketNotificationConfiguration", + ] for exist in exist_list: - assert isinstance(bucket1_result['supplementaryConfiguration'][exist], str) + assert isinstance(bucket1_result["supplementaryConfiguration"][exist], str) # Verify the logging config: - assert json.loads(bucket1_result['supplementaryConfiguration']['BucketLoggingConfiguration']) == \ - {'destinationBucketName': 'logbucket', 'logFilePrefix': ''} + assert json.loads( + bucket1_result["supplementaryConfiguration"]["BucketLoggingConfiguration"] + ) == {"destinationBucketName": "logbucket", "logFilePrefix": ""} # Verify that the AccessControlList is a double-wrapped JSON string: - assert json.loads(json.loads(bucket1_result['supplementaryConfiguration']['AccessControlList'])) == \ - {'grantSet': None, 'owner': {'displayName': None, 'id': '75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a'}} + assert json.loads( + json.loads(bucket1_result["supplementaryConfiguration"]["AccessControlList"]) + ) == { + "grantSet": None, + "owner": { + "displayName": None, + "id": "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a", + }, + } # Verify the policy: - assert json.loads(bucket1_result['supplementaryConfiguration']['BucketPolicy']) == {'policyText': policy} + assert json.loads(bucket1_result["supplementaryConfiguration"]["BucketPolicy"]) == { + "policyText": policy + } # Filter by correct region: - assert bucket1_result == s3_config_query.get_config_resource('bucket1', resource_region='us-west-2') + assert bucket1_result == s3_config_query.get_config_resource( + "bucket1", resource_region="us-west-2" + ) # By incorrect region: - assert not s3_config_query.get_config_resource('bucket1', resource_region='eu-west-1') + assert not s3_config_query.get_config_resource( + "bucket1", resource_region="eu-west-1" + ) # With correct resource ID and name: - assert bucket1_result == s3_config_query.get_config_resource('bucket1', resource_name='bucket1') + assert bucket1_result == s3_config_query.get_config_resource( + "bucket1", resource_name="bucket1" + ) # With an incorrect resource name: - assert not s3_config_query.get_config_resource('bucket1', resource_name='eu-bucket-1') + assert not s3_config_query.get_config_resource( + "bucket1", resource_name="eu-bucket-1" + ) # Verify that no bucket policy returns the proper value: - logging_bucket = s3_config_query.get_config_resource('logbucket') - assert json.loads(logging_bucket['supplementaryConfiguration']['BucketPolicy']) == \ - {'policyText': None} - assert not logging_bucket['tags'] - assert not logging_bucket['supplementaryConfiguration'].get('BucketTaggingConfiguration') + logging_bucket = s3_config_query.get_config_resource("logbucket") + assert json.loads(logging_bucket["supplementaryConfiguration"]["BucketPolicy"]) == { + "policyText": None + } + assert not logging_bucket["tags"] + assert not logging_bucket["supplementaryConfiguration"].get( + "BucketTaggingConfiguration" + ) diff --git a/tests/test_s3/test_s3_lifecycle.py b/tests/test_s3/test_s3_lifecycle.py index 5b05fe518..260b248f1 100644 --- a/tests/test_s3/test_s3_lifecycle.py +++ b/tests/test_s3/test_s3_lifecycle.py @@ -19,14 +19,14 @@ def test_lifecycle_create(): bucket = conn.create_bucket("foobar") lifecycle = Lifecycle() - lifecycle.add_rule('myid', '', 'Enabled', 30) + lifecycle.add_rule("myid", "", "Enabled", 30) bucket.configure_lifecycle(lifecycle) response = bucket.get_lifecycle_config() len(response).should.equal(1) lifecycle = response[0] - lifecycle.id.should.equal('myid') - lifecycle.prefix.should.equal('') - lifecycle.status.should.equal('Enabled') + lifecycle.id.should.equal("myid") + lifecycle.prefix.should.equal("") + lifecycle.status.should.equal("Enabled") list(lifecycle.transition).should.equal([]) @@ -39,21 +39,19 @@ def test_lifecycle_with_filters(): lfc = { "Rules": [ { - "Expiration": { - "Days": 7 - }, + "Expiration": {"Days": 7}, "ID": "wholebucket", - "Filter": { - "Prefix": "" - }, - "Status": "Enabled" + "Filter": {"Prefix": ""}, + "Status": "Enabled", } ] } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 - assert result["Rules"][0]["Filter"]["Prefix"] == '' + assert result["Rules"][0]["Filter"]["Prefix"] == "" assert not result["Rules"][0]["Filter"].get("And") assert not result["Rules"][0]["Filter"].get("Tag") with assert_raises(KeyError): @@ -63,39 +61,38 @@ def test_lifecycle_with_filters(): lfc = { "Rules": [ { - "Expiration": { - "Days": 7 - }, + "Expiration": {"Days": 7}, "ID": "wholebucket", "Filter": {}, - "Status": "Enabled" + "Status": "Enabled", } ] } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 with assert_raises(KeyError): assert result["Rules"][0]["Prefix"] # If we remove the filter -- and don't specify a Prefix, then this is bad: - lfc['Rules'][0].pop('Filter') + lfc["Rules"][0].pop("Filter") with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) assert err.exception.response["Error"]["Code"] == "MalformedXML" # With a tag: - lfc["Rules"][0]["Filter"] = { - 'Tag': { - "Key": "mytag", - "Value": "mytagvalue" - } - } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + lfc["Rules"][0]["Filter"] = {"Tag": {"Key": "mytag", "Value": "mytagvalue"}} + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 with assert_raises(KeyError): - assert result["Rules"][0]["Filter"]['Prefix'] + assert result["Rules"][0]["Filter"]["Prefix"] assert not result["Rules"][0]["Filter"].get("And") assert result["Rules"][0]["Filter"]["Tag"]["Key"] == "mytag" assert result["Rules"][0]["Filter"]["Tag"]["Value"] == "mytagvalue" @@ -106,15 +103,12 @@ def test_lifecycle_with_filters(): lfc["Rules"][0]["Filter"] = { "And": { "Prefix": "some/prefix", - "Tags": [ - { - "Key": "mytag", - "Value": "mytagvalue" - } - ] + "Tags": [{"Key": "mytag", "Value": "mytagvalue"}], } } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 assert not result["Rules"][0]["Filter"].get("Prefix") @@ -129,17 +123,13 @@ def test_lifecycle_with_filters(): lfc["Rules"][0]["Filter"]["And"] = { "Prefix": "some/prefix", "Tags": [ - { - "Key": "mytag", - "Value": "mytagvalue" - }, - { - "Key": "mytag2", - "Value": "mytagvalue2" - } - ] + {"Key": "mytag", "Value": "mytagvalue"}, + {"Key": "mytag2", "Value": "mytagvalue2"}, + ], } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 assert not result["Rules"][0]["Filter"].get("Prefix") @@ -155,17 +145,13 @@ def test_lifecycle_with_filters(): # And filter without Prefix but multiple Tags: lfc["Rules"][0]["Filter"]["And"] = { "Tags": [ - { - "Key": "mytag", - "Value": "mytagvalue" - }, - { - "Key": "mytag2", - "Value": "mytagvalue2" - } + {"Key": "mytag", "Value": "mytagvalue"}, + {"Key": "mytag2", "Value": "mytagvalue2"}, ] } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 with assert_raises(KeyError): @@ -179,89 +165,81 @@ def test_lifecycle_with_filters(): assert result["Rules"][0]["Prefix"] # Can't have both filter and prefix: - lfc["Rules"][0]["Prefix"] = '' + lfc["Rules"][0]["Prefix"] = "" with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) assert err.exception.response["Error"]["Code"] == "MalformedXML" - lfc["Rules"][0]["Prefix"] = 'some/path' + lfc["Rules"][0]["Prefix"] = "some/path" with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) assert err.exception.response["Error"]["Code"] == "MalformedXML" # No filters -- just a prefix: del lfc["Rules"][0]["Filter"] - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert not result["Rules"][0].get("Filter") assert result["Rules"][0]["Prefix"] == "some/path" # Can't have Tag, Prefix, and And in a filter: - del lfc['Rules'][0]['Prefix'] + del lfc["Rules"][0]["Prefix"] lfc["Rules"][0]["Filter"] = { "Prefix": "some/prefix", - "Tag": { - "Key": "mytag", - "Value": "mytagvalue" - } + "Tag": {"Key": "mytag", "Value": "mytagvalue"}, } with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) assert err.exception.response["Error"]["Code"] == "MalformedXML" lfc["Rules"][0]["Filter"] = { - "Tag": { - "Key": "mytag", - "Value": "mytagvalue" - }, + "Tag": {"Key": "mytag", "Value": "mytagvalue"}, "And": { "Prefix": "some/prefix", "Tags": [ - { - "Key": "mytag", - "Value": "mytagvalue" - }, - { - "Key": "mytag2", - "Value": "mytagvalue2" - } - ] - } + {"Key": "mytag", "Value": "mytagvalue"}, + {"Key": "mytag2", "Value": "mytagvalue2"}, + ], + }, } with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) assert err.exception.response["Error"]["Code"] == "MalformedXML" # Make sure multiple rules work: lfc = { "Rules": [ { - "Expiration": { - "Days": 7 - }, + "Expiration": {"Days": 7}, "ID": "wholebucket", - "Filter": { - "Prefix": "" - }, - "Status": "Enabled" + "Filter": {"Prefix": ""}, + "Status": "Enabled", }, { - "Expiration": { - "Days": 10 - }, + "Expiration": {"Days": 10}, "ID": "Tags", - "Filter": { - "Tag": {'Key': 'somekey', 'Value': 'somevalue'} - }, - "Status": "Enabled" - } + "Filter": {"Tag": {"Key": "somekey", "Value": "somevalue"}}, + "Status": "Enabled", + }, ] } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) - result = client.get_bucket_lifecycle_configuration(Bucket="bucket")['Rules'] + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) + result = client.get_bucket_lifecycle_configuration(Bucket="bucket")["Rules"] assert len(result) == 2 - assert result[0]['ID'] == 'wholebucket' - assert result[1]['ID'] == 'Tags' + assert result[0]["ID"] == "wholebucket" + assert result[1]["ID"] == "Tags" @mock_s3 @@ -272,25 +250,25 @@ def test_lifecycle_with_eodm(): lfc = { "Rules": [ { - "Expiration": { - "ExpiredObjectDeleteMarker": True - }, + "Expiration": {"ExpiredObjectDeleteMarker": True}, "ID": "wholebucket", - "Filter": { - "Prefix": "" - }, - "Status": "Enabled" + "Filter": {"Prefix": ""}, + "Status": "Enabled", } ] } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 assert result["Rules"][0]["Expiration"]["ExpiredObjectDeleteMarker"] # Set to False: lfc["Rules"][0]["Expiration"]["ExpiredObjectDeleteMarker"] = False - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 assert not result["Rules"][0]["Expiration"]["ExpiredObjectDeleteMarker"] @@ -298,13 +276,17 @@ def test_lifecycle_with_eodm(): # With failure: lfc["Rules"][0]["Expiration"]["Days"] = 7 with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) assert err.exception.response["Error"]["Code"] == "MalformedXML" del lfc["Rules"][0]["Expiration"]["Days"] lfc["Rules"][0]["Expiration"]["Date"] = datetime(2015, 1, 1) with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) assert err.exception.response["Error"]["Code"] == "MalformedXML" @@ -316,25 +298,25 @@ def test_lifecycle_with_nve(): lfc = { "Rules": [ { - "NoncurrentVersionExpiration": { - "NoncurrentDays": 30 - }, + "NoncurrentVersionExpiration": {"NoncurrentDays": 30}, "ID": "wholebucket", - "Filter": { - "Prefix": "" - }, - "Status": "Enabled" + "Filter": {"Prefix": ""}, + "Status": "Enabled", } ] } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 assert result["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] == 30 # Change NoncurrentDays: lfc["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] = 10 - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 assert result["Rules"][0]["NoncurrentVersionExpiration"]["NoncurrentDays"] == 10 @@ -350,48 +332,61 @@ def test_lifecycle_with_nvt(): lfc = { "Rules": [ { - "NoncurrentVersionTransitions": [{ - "NoncurrentDays": 30, - "StorageClass": "ONEZONE_IA" - }], + "NoncurrentVersionTransitions": [ + {"NoncurrentDays": 30, "StorageClass": "ONEZONE_IA"} + ], "ID": "wholebucket", - "Filter": { - "Prefix": "" - }, - "Status": "Enabled" + "Filter": {"Prefix": ""}, + "Status": "Enabled", } ] } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] == 30 - assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] == "ONEZONE_IA" + assert ( + result["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] + == "ONEZONE_IA" + ) # Change NoncurrentDays: lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] = 10 - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] == 10 # Change StorageClass: lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] = "GLACIER" - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 - assert result["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] == "GLACIER" + assert ( + result["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] + == "GLACIER" + ) # With failures for missing children: del lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) assert err.exception.response["Error"]["Code"] == "MalformedXML" lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["NoncurrentDays"] = 30 del lfc["Rules"][0]["NoncurrentVersionTransitions"][0]["StorageClass"] with assert_raises(ClientError) as err: - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) assert err.exception.response["Error"]["Code"] == "MalformedXML" @@ -403,28 +398,33 @@ def test_lifecycle_with_aimu(): lfc = { "Rules": [ { - "AbortIncompleteMultipartUpload": { - "DaysAfterInitiation": 7 - }, + "AbortIncompleteMultipartUpload": {"DaysAfterInitiation": 7}, "ID": "wholebucket", - "Filter": { - "Prefix": "" - }, - "Status": "Enabled" + "Filter": {"Prefix": ""}, + "Status": "Enabled", } ] } - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 - assert result["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] == 7 + assert ( + result["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] == 7 + ) # Change DaysAfterInitiation: lfc["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] = 30 - client.put_bucket_lifecycle_configuration(Bucket="bucket", LifecycleConfiguration=lfc) + client.put_bucket_lifecycle_configuration( + Bucket="bucket", LifecycleConfiguration=lfc + ) result = client.get_bucket_lifecycle_configuration(Bucket="bucket") assert len(result["Rules"]) == 1 - assert result["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] == 30 + assert ( + result["Rules"][0]["AbortIncompleteMultipartUpload"]["DaysAfterInitiation"] + == 30 + ) # TODO: Add test for failures due to missing children @@ -435,15 +435,16 @@ def test_lifecycle_with_glacier_transition(): bucket = conn.create_bucket("foobar") lifecycle = Lifecycle() - transition = Transition(days=30, storage_class='GLACIER') - rule = Rule('myid', prefix='', status='Enabled', expiration=None, - transition=transition) + transition = Transition(days=30, storage_class="GLACIER") + rule = Rule( + "myid", prefix="", status="Enabled", expiration=None, transition=transition + ) lifecycle.append(rule) bucket.configure_lifecycle(lifecycle) response = bucket.get_lifecycle_config() transition = response[0].transition transition.days.should.equal(30) - transition.storage_class.should.equal('GLACIER') + transition.storage_class.should.equal("GLACIER") transition.date.should.equal(None) @@ -452,16 +453,16 @@ def test_lifecycle_multi(): conn = boto.s3.connect_to_region("us-west-1") bucket = conn.create_bucket("foobar") - date = '2022-10-12T00:00:00.000Z' - sc = 'GLACIER' + date = "2022-10-12T00:00:00.000Z" + sc = "GLACIER" lifecycle = Lifecycle() lifecycle.add_rule("1", "1/", "Enabled", 1) lifecycle.add_rule("2", "2/", "Enabled", Expiration(days=2)) lifecycle.add_rule("3", "3/", "Enabled", Expiration(date=date)) - lifecycle.add_rule("4", "4/", "Enabled", None, - Transition(days=4, storage_class=sc)) - lifecycle.add_rule("5", "5/", "Enabled", None, - Transition(date=date, storage_class=sc)) + lifecycle.add_rule("4", "4/", "Enabled", None, Transition(days=4, storage_class=sc)) + lifecycle.add_rule( + "5", "5/", "Enabled", None, Transition(date=date, storage_class=sc) + ) bucket.configure_lifecycle(lifecycle) # read the lifecycle back diff --git a/tests/test_s3/test_s3_storageclass.py b/tests/test_s3/test_s3_storageclass.py index c72b773a9..dbdc85c42 100644 --- a/tests/test_s3/test_s3_storageclass.py +++ b/tests/test_s3/test_s3_storageclass.py @@ -11,30 +11,35 @@ from moto import mock_s3 @mock_s3 def test_s3_storage_class_standard(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") - # add an object to the bucket with standard storage + # add an object to the bucket with standard storage - s3.put_object(Bucket="Bucket", Key="my_key", Body="my_value") + s3.put_object(Bucket="Bucket", Key="my_key", Body="my_value") - list_of_objects = s3.list_objects(Bucket="Bucket") + list_of_objects = s3.list_objects(Bucket="Bucket") - list_of_objects['Contents'][0]["StorageClass"].should.equal("STANDARD") + list_of_objects["Contents"][0]["StorageClass"].should.equal("STANDARD") @mock_s3 def test_s3_storage_class_infrequent_access(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") - # add an object to the bucket with standard storage + # add an object to the bucket with standard storage - s3.put_object(Bucket="Bucket", Key="my_key_infrequent", Body="my_value_infrequent", StorageClass="STANDARD_IA") + s3.put_object( + Bucket="Bucket", + Key="my_key_infrequent", + Body="my_value_infrequent", + StorageClass="STANDARD_IA", + ) - D = s3.list_objects(Bucket="Bucket") + D = s3.list_objects(Bucket="Bucket") - D['Contents'][0]["StorageClass"].should.equal("STANDARD_IA") + D["Contents"][0]["StorageClass"].should.equal("STANDARD_IA") @mock_s3 @@ -42,74 +47,104 @@ def test_s3_storage_class_intelligent_tiering(): s3 = boto3.client("s3") s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="my_key_infrequent", Body="my_value_infrequent", StorageClass="INTELLIGENT_TIERING") + s3.put_object( + Bucket="Bucket", + Key="my_key_infrequent", + Body="my_value_infrequent", + StorageClass="INTELLIGENT_TIERING", + ) objects = s3.list_objects(Bucket="Bucket") - objects['Contents'][0]["StorageClass"].should.equal("INTELLIGENT_TIERING") + objects["Contents"][0]["StorageClass"].should.equal("INTELLIGENT_TIERING") @mock_s3 def test_s3_storage_class_copy(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARD") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") + s3.put_object( + Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARD" + ) - s3.create_bucket(Bucket="Bucket2") - # second object is originally of storage class REDUCED_REDUNDANCY - s3.put_object(Bucket="Bucket2", Key="Second_Object", Body="Body2") + s3.create_bucket(Bucket="Bucket2") + # second object is originally of storage class REDUCED_REDUNDANCY + s3.put_object(Bucket="Bucket2", Key="Second_Object", Body="Body2") - s3.copy_object(CopySource = {"Bucket": "Bucket", "Key": "First_Object"}, Bucket="Bucket2", Key="Second_Object", StorageClass="ONEZONE_IA") + s3.copy_object( + CopySource={"Bucket": "Bucket", "Key": "First_Object"}, + Bucket="Bucket2", + Key="Second_Object", + StorageClass="ONEZONE_IA", + ) - list_of_copied_objects = s3.list_objects(Bucket="Bucket2") + list_of_copied_objects = s3.list_objects(Bucket="Bucket2") - # checks that a copied object can be properly copied - list_of_copied_objects["Contents"][0]["StorageClass"].should.equal("ONEZONE_IA") + # checks that a copied object can be properly copied + list_of_copied_objects["Contents"][0]["StorageClass"].should.equal("ONEZONE_IA") @mock_s3 def test_s3_invalid_copied_storage_class(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARD") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") + s3.put_object( + Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARD" + ) - s3.create_bucket(Bucket="Bucket2") - s3.put_object(Bucket="Bucket2", Key="Second_Object", Body="Body2", StorageClass="REDUCED_REDUNDANCY") + s3.create_bucket(Bucket="Bucket2") + s3.put_object( + Bucket="Bucket2", + Key="Second_Object", + Body="Body2", + StorageClass="REDUCED_REDUNDANCY", + ) - # Try to copy an object with an invalid storage class - with assert_raises(ClientError) as err: - s3.copy_object(CopySource = {"Bucket": "Bucket", "Key": "First_Object"}, Bucket="Bucket2", Key="Second_Object", StorageClass="STANDARD2") + # Try to copy an object with an invalid storage class + with assert_raises(ClientError) as err: + s3.copy_object( + CopySource={"Bucket": "Bucket", "Key": "First_Object"}, + Bucket="Bucket2", + Key="Second_Object", + StorageClass="STANDARD2", + ) - e = err.exception - e.response["Error"]["Code"].should.equal("InvalidStorageClass") - e.response["Error"]["Message"].should.equal("The storage class you specified is not valid") + e = err.exception + e.response["Error"]["Code"].should.equal("InvalidStorageClass") + e.response["Error"]["Message"].should.equal( + "The storage class you specified is not valid" + ) @mock_s3 def test_s3_invalid_storage_class(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") - # Try to add an object with an invalid storage class - with assert_raises(ClientError) as err: - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARDD") + # Try to add an object with an invalid storage class + with assert_raises(ClientError) as err: + s3.put_object( + Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="STANDARDD" + ) - e = err.exception - e.response["Error"]["Code"].should.equal("InvalidStorageClass") - e.response["Error"]["Message"].should.equal("The storage class you specified is not valid") + e = err.exception + e.response["Error"]["Code"].should.equal("InvalidStorageClass") + e.response["Error"]["Message"].should.equal( + "The storage class you specified is not valid" + ) @mock_s3 def test_s3_default_storage_class(): - s3 = boto3.client("s3") - s3.create_bucket(Bucket="Bucket") + s3 = boto3.client("s3") + s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body") + s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body") - list_of_objects = s3.list_objects(Bucket="Bucket") + list_of_objects = s3.list_objects(Bucket="Bucket") - # tests that the default storage class is still STANDARD - list_of_objects["Contents"][0]["StorageClass"].should.equal("STANDARD") + # tests that the default storage class is still STANDARD + list_of_objects["Contents"][0]["StorageClass"].should.equal("STANDARD") @mock_s3 @@ -117,10 +152,16 @@ def test_s3_copy_object_error_for_glacier_storage_class(): s3 = boto3.client("s3") s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="GLACIER") + s3.put_object( + Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="GLACIER" + ) with assert_raises(ClientError) as exc: - s3.copy_object(CopySource={"Bucket": "Bucket", "Key": "First_Object"}, Bucket="Bucket", Key="Second_Object") + s3.copy_object( + CopySource={"Bucket": "Bucket", "Key": "First_Object"}, + Bucket="Bucket", + Key="Second_Object", + ) exc.exception.response["Error"]["Code"].should.equal("ObjectNotInActiveTierError") @@ -130,9 +171,15 @@ def test_s3_copy_object_error_for_deep_archive_storage_class(): s3 = boto3.client("s3") s3.create_bucket(Bucket="Bucket") - s3.put_object(Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="DEEP_ARCHIVE") + s3.put_object( + Bucket="Bucket", Key="First_Object", Body="Body", StorageClass="DEEP_ARCHIVE" + ) with assert_raises(ClientError) as exc: - s3.copy_object(CopySource={"Bucket": "Bucket", "Key": "First_Object"}, Bucket="Bucket", Key="Second_Object") + s3.copy_object( + CopySource={"Bucket": "Bucket", "Key": "First_Object"}, + Bucket="Bucket", + Key="Second_Object", + ) exc.exception.response["Error"]["Code"].should.equal("ObjectNotInActiveTierError") diff --git a/tests/test_s3/test_s3_utils.py b/tests/test_s3/test_s3_utils.py index 93a4683e6..b90225597 100644 --- a/tests/test_s3/test_s3_utils.py +++ b/tests/test_s3/test_s3_utils.py @@ -1,28 +1,36 @@ from __future__ import unicode_literals import os from sure import expect -from moto.s3.utils import bucket_name_from_url, _VersionedKeyStore, parse_region_from_url, clean_key_name, undo_clean_key_name +from moto.s3.utils import ( + bucket_name_from_url, + _VersionedKeyStore, + parse_region_from_url, + clean_key_name, + undo_clean_key_name, +) from parameterized import parameterized def test_base_url(): - expect(bucket_name_from_url('https://s3.amazonaws.com/')).should.equal(None) + expect(bucket_name_from_url("https://s3.amazonaws.com/")).should.equal(None) def test_localhost_bucket(): - expect(bucket_name_from_url('https://wfoobar.localhost:5000/abc') - ).should.equal("wfoobar") + expect(bucket_name_from_url("https://wfoobar.localhost:5000/abc")).should.equal( + "wfoobar" + ) def test_localhost_without_bucket(): - expect(bucket_name_from_url( - 'https://www.localhost:5000/def')).should.equal(None) + expect(bucket_name_from_url("https://www.localhost:5000/def")).should.equal(None) + def test_force_ignore_subdomain_for_bucketnames(): - os.environ['S3_IGNORE_SUBDOMAIN_BUCKETNAME'] = '1' - expect(bucket_name_from_url('https://subdomain.localhost:5000/abc/resource')).should.equal(None) - del(os.environ['S3_IGNORE_SUBDOMAIN_BUCKETNAME']) - + os.environ["S3_IGNORE_SUBDOMAIN_BUCKETNAME"] = "1" + expect( + bucket_name_from_url("https://subdomain.localhost:5000/abc/resource") + ).should.equal(None) + del os.environ["S3_IGNORE_SUBDOMAIN_BUCKETNAME"] def test_versioned_key_store(): @@ -30,78 +38,84 @@ def test_versioned_key_store(): d.should.have.length_of(0) - d['key'] = [1] + d["key"] = [1] d.should.have.length_of(1) - d['key'] = 2 + d["key"] = 2 d.should.have.length_of(1) - d.should.have.key('key').being.equal(2) + d.should.have.key("key").being.equal(2) - d.get.when.called_with('key').should.return_value(2) - d.get.when.called_with('badkey').should.return_value(None) - d.get.when.called_with('badkey', 'HELLO').should.return_value('HELLO') + d.get.when.called_with("key").should.return_value(2) + d.get.when.called_with("badkey").should.return_value(None) + d.get.when.called_with("badkey", "HELLO").should.return_value("HELLO") # Tests key[ - d.shouldnt.have.key('badkey') - d.__getitem__.when.called_with('badkey').should.throw(KeyError) + d.shouldnt.have.key("badkey") + d.__getitem__.when.called_with("badkey").should.throw(KeyError) - d.getlist('key').should.have.length_of(2) - d.getlist('key').should.be.equal([[1], 2]) - d.getlist('badkey').should.be.none + d.getlist("key").should.have.length_of(2) + d.getlist("key").should.be.equal([[1], 2]) + d.getlist("badkey").should.be.none - d.setlist('key', 1) - d.getlist('key').should.be.equal([1]) + d.setlist("key", 1) + d.getlist("key").should.be.equal([1]) - d.setlist('key', (1, 2)) - d.getlist('key').shouldnt.be.equal((1, 2)) - d.getlist('key').should.be.equal([1, 2]) + d.setlist("key", (1, 2)) + d.getlist("key").shouldnt.be.equal((1, 2)) + d.getlist("key").should.be.equal([1, 2]) - d.setlist('key', [[1], [2]]) - d['key'].should.have.length_of(1) - d.getlist('key').should.be.equal([[1], [2]]) + d.setlist("key", [[1], [2]]) + d["key"].should.have.length_of(1) + d.getlist("key").should.be.equal([[1], [2]]) def test_parse_region_from_url(): - expected = 'us-west-2' - for url in ['http://s3-us-west-2.amazonaws.com/bucket', - 'http://s3.us-west-2.amazonaws.com/bucket', - 'http://bucket.s3-us-west-2.amazonaws.com', - 'https://s3-us-west-2.amazonaws.com/bucket', - 'https://s3.us-west-2.amazonaws.com/bucket', - 'https://bucket.s3-us-west-2.amazonaws.com']: + expected = "us-west-2" + for url in [ + "http://s3-us-west-2.amazonaws.com/bucket", + "http://s3.us-west-2.amazonaws.com/bucket", + "http://bucket.s3-us-west-2.amazonaws.com", + "https://s3-us-west-2.amazonaws.com/bucket", + "https://s3.us-west-2.amazonaws.com/bucket", + "https://bucket.s3-us-west-2.amazonaws.com", + ]: parse_region_from_url(url).should.equal(expected) - expected = 'us-east-1' - for url in ['http://s3.amazonaws.com/bucket', - 'http://bucket.s3.amazonaws.com', - 'https://s3.amazonaws.com/bucket', - 'https://bucket.s3.amazonaws.com']: + expected = "us-east-1" + for url in [ + "http://s3.amazonaws.com/bucket", + "http://bucket.s3.amazonaws.com", + "https://s3.amazonaws.com/bucket", + "https://bucket.s3.amazonaws.com", + ]: parse_region_from_url(url).should.equal(expected) -@parameterized([ - ('foo/bar/baz', - 'foo/bar/baz'), - ('foo', - 'foo'), - ('foo/run_dt%3D2019-01-01%252012%253A30%253A00', - 'foo/run_dt=2019-01-01%2012%3A30%3A00'), -]) +@parameterized( + [ + ("foo/bar/baz", "foo/bar/baz"), + ("foo", "foo"), + ( + "foo/run_dt%3D2019-01-01%252012%253A30%253A00", + "foo/run_dt=2019-01-01%2012%3A30%3A00", + ), + ] +) def test_clean_key_name(key, expected): clean_key_name(key).should.equal(expected) -@parameterized([ - ('foo/bar/baz', - 'foo/bar/baz'), - ('foo', - 'foo'), - ('foo/run_dt%3D2019-01-01%252012%253A30%253A00', - 'foo/run_dt%253D2019-01-01%25252012%25253A30%25253A00'), -]) +@parameterized( + [ + ("foo/bar/baz", "foo/bar/baz"), + ("foo", "foo"), + ( + "foo/run_dt%3D2019-01-01%252012%253A30%253A00", + "foo/run_dt%253D2019-01-01%25252012%25253A30%25253A00", + ), + ] +) def test_undo_clean_key_name(key, expected): undo_clean_key_name(key).should.equal(expected) - - diff --git a/tests/test_s3/test_server.py b/tests/test_s3/test_server.py index b179a2329..56d46de09 100644 --- a/tests/test_s3/test_server.py +++ b/tests/test_s3/test_server.py @@ -6,16 +6,16 @@ import sure # noqa from flask.testing import FlaskClient import moto.server as server -''' +""" Test the different server responses -''' +""" class AuthenticatedClient(FlaskClient): def open(self, *args, **kwargs): - kwargs['headers'] = kwargs.get('headers', {}) - kwargs['headers']['Authorization'] = "Any authorization header" - kwargs['content_length'] = 0 # Fixes content-length complaints. + kwargs["headers"] = kwargs.get("headers", {}) + kwargs["headers"]["Authorization"] = "Any authorization header" + kwargs["content_length"] = 0 # Fixes content-length complaints. return super(AuthenticatedClient, self).open(*args, **kwargs) @@ -27,30 +27,29 @@ def authenticated_client(): def test_s3_server_get(): test_client = authenticated_client() - res = test_client.get('/') + res = test_client.get("/") - res.data.should.contain(b'ListAllMyBucketsResult') + res.data.should.contain(b"ListAllMyBucketsResult") def test_s3_server_bucket_create(): test_client = authenticated_client() - res = test_client.put('/', 'http://foobaz.localhost:5000/') + res = test_client.put("/", "http://foobaz.localhost:5000/") res.status_code.should.equal(200) - res = test_client.get('/') - res.data.should.contain(b'foobaz') + res = test_client.get("/") + res.data.should.contain(b"foobaz") - res = test_client.get('/', 'http://foobaz.localhost:5000/') + res = test_client.get("/", "http://foobaz.localhost:5000/") res.status_code.should.equal(200) res.data.should.contain(b"ListBucketResult") - res = test_client.put( - '/bar', 'http://foobaz.localhost:5000/', data='test value') + res = test_client.put("/bar", "http://foobaz.localhost:5000/", data="test value") res.status_code.should.equal(200) - assert 'ETag' in dict(res.headers) + assert "ETag" in dict(res.headers) - res = test_client.get('/bar', 'http://foobaz.localhost:5000/') + res = test_client.get("/bar", "http://foobaz.localhost:5000/") res.status_code.should.equal(200) res.data.should.equal(b"test value") @@ -59,24 +58,24 @@ def test_s3_server_bucket_versioning(): test_client = authenticated_client() # Just enough XML to enable versioning - body = 'Enabled' - res = test_client.put( - '/?versioning', 'http://foobaz.localhost:5000', data=body) + body = "Enabled" + res = test_client.put("/?versioning", "http://foobaz.localhost:5000", data=body) res.status_code.should.equal(200) def test_s3_server_post_to_bucket(): test_client = authenticated_client() - res = test_client.put('/', 'http://tester.localhost:5000/') + res = test_client.put("/", "http://tester.localhost:5000/") res.status_code.should.equal(200) - test_client.post('/', "https://tester.localhost:5000/", data={ - 'key': 'the-key', - 'file': 'nothing' - }) + test_client.post( + "/", + "https://tester.localhost:5000/", + data={"key": "the-key", "file": "nothing"}, + ) - res = test_client.get('/the-key', 'http://tester.localhost:5000/') + res = test_client.get("/the-key", "http://tester.localhost:5000/") res.status_code.should.equal(200) res.data.should.equal(b"nothing") @@ -84,23 +83,28 @@ def test_s3_server_post_to_bucket(): def test_s3_server_post_without_content_length(): test_client = authenticated_client() - res = test_client.put('/', 'http://tester.localhost:5000/', environ_overrides={'CONTENT_LENGTH': ''}) + res = test_client.put( + "/", "http://tester.localhost:5000/", environ_overrides={"CONTENT_LENGTH": ""} + ) res.status_code.should.equal(411) - res = test_client.post('/', "https://tester.localhost:5000/", environ_overrides={'CONTENT_LENGTH': ''}) + res = test_client.post( + "/", "https://tester.localhost:5000/", environ_overrides={"CONTENT_LENGTH": ""} + ) res.status_code.should.equal(411) def test_s3_server_post_unicode_bucket_key(): # Make sure that we can deal with non-ascii characters in request URLs (e.g., S3 object names) dispatcher = server.DomainDispatcherApplication(server.create_backend_app) - backend_app = dispatcher.get_application({ - 'HTTP_HOST': 's3.amazonaws.com', - 'PATH_INFO': '/test-bucket/test-object-てすと' - }) + backend_app = dispatcher.get_application( + {"HTTP_HOST": "s3.amazonaws.com", "PATH_INFO": "/test-bucket/test-object-てすと"} + ) assert backend_app - backend_app = dispatcher.get_application({ - 'HTTP_HOST': 's3.amazonaws.com', - 'PATH_INFO': '/test-bucket/test-object-てすと'.encode('utf-8') - }) + backend_app = dispatcher.get_application( + { + "HTTP_HOST": "s3.amazonaws.com", + "PATH_INFO": "/test-bucket/test-object-てすと".encode("utf-8"), + } + ) assert backend_app diff --git a/tests/test_s3bucket_path/test_bucket_path_server.py b/tests/test_s3bucket_path/test_bucket_path_server.py index f6238dd28..2fe606799 100644 --- a/tests/test_s3bucket_path/test_bucket_path_server.py +++ b/tests/test_s3bucket_path/test_bucket_path_server.py @@ -4,16 +4,16 @@ import sure # noqa from flask.testing import FlaskClient import moto.server as server -''' +""" Test the different server responses -''' +""" class AuthenticatedClient(FlaskClient): def open(self, *args, **kwargs): - kwargs['headers'] = kwargs.get('headers', {}) - kwargs['headers']['Authorization'] = "Any authorization header" - kwargs['content_length'] = 0 # Fixes content-length complaints. + kwargs["headers"] = kwargs.get("headers", {}) + kwargs["headers"]["Authorization"] = "Any authorization header" + kwargs["content_length"] = 0 # Fixes content-length complaints. return super(AuthenticatedClient, self).open(*args, **kwargs) @@ -26,42 +26,41 @@ def authenticated_client(): def test_s3_server_get(): test_client = authenticated_client() - res = test_client.get('/') + res = test_client.get("/") - res.data.should.contain(b'ListAllMyBucketsResult') + res.data.should.contain(b"ListAllMyBucketsResult") def test_s3_server_bucket_create(): test_client = authenticated_client() - res = test_client.put('/foobar', 'http://localhost:5000') + res = test_client.put("/foobar", "http://localhost:5000") res.status_code.should.equal(200) - res = test_client.get('/') - res.data.should.contain(b'foobar') + res = test_client.get("/") + res.data.should.contain(b"foobar") - res = test_client.get('/foobar', 'http://localhost:5000') + res = test_client.get("/foobar", "http://localhost:5000") res.status_code.should.equal(200) res.data.should.contain(b"ListBucketResult") - res = test_client.put('/foobar2/', 'http://localhost:5000') + res = test_client.put("/foobar2/", "http://localhost:5000") res.status_code.should.equal(200) - res = test_client.get('/') - res.data.should.contain(b'foobar2') + res = test_client.get("/") + res.data.should.contain(b"foobar2") - res = test_client.get('/foobar2/', 'http://localhost:5000') + res = test_client.get("/foobar2/", "http://localhost:5000") res.status_code.should.equal(200) res.data.should.contain(b"ListBucketResult") - res = test_client.get('/missing-bucket', 'http://localhost:5000') + res = test_client.get("/missing-bucket", "http://localhost:5000") res.status_code.should.equal(404) - res = test_client.put( - '/foobar/bar', 'http://localhost:5000', data='test value') + res = test_client.put("/foobar/bar", "http://localhost:5000", data="test value") res.status_code.should.equal(200) - res = test_client.get('/foobar/bar', 'http://localhost:5000') + res = test_client.get("/foobar/bar", "http://localhost:5000") res.status_code.should.equal(200) res.data.should.equal(b"test value") @@ -69,15 +68,16 @@ def test_s3_server_bucket_create(): def test_s3_server_post_to_bucket(): test_client = authenticated_client() - res = test_client.put('/foobar2', 'http://localhost:5000/') + res = test_client.put("/foobar2", "http://localhost:5000/") res.status_code.should.equal(200) - test_client.post('/foobar2', "https://localhost:5000/", data={ - 'key': 'the-key', - 'file': 'nothing' - }) + test_client.post( + "/foobar2", + "https://localhost:5000/", + data={"key": "the-key", "file": "nothing"}, + ) - res = test_client.get('/foobar2/the-key', 'http://localhost:5000/') + res = test_client.get("/foobar2/the-key", "http://localhost:5000/") res.status_code.should.equal(200) res.data.should.equal(b"nothing") @@ -85,15 +85,14 @@ def test_s3_server_post_to_bucket(): def test_s3_server_put_ipv6(): test_client = authenticated_client() - res = test_client.put('/foobar2', 'http://[::]:5000/') + res = test_client.put("/foobar2", "http://[::]:5000/") res.status_code.should.equal(200) - test_client.post('/foobar2', "https://[::]:5000/", data={ - 'key': 'the-key', - 'file': 'nothing' - }) + test_client.post( + "/foobar2", "https://[::]:5000/", data={"key": "the-key", "file": "nothing"} + ) - res = test_client.get('/foobar2/the-key', 'http://[::]:5000/') + res = test_client.get("/foobar2/the-key", "http://[::]:5000/") res.status_code.should.equal(200) res.data.should.equal(b"nothing") @@ -101,14 +100,15 @@ def test_s3_server_put_ipv6(): def test_s3_server_put_ipv4(): test_client = authenticated_client() - res = test_client.put('/foobar2', 'http://127.0.0.1:5000/') + res = test_client.put("/foobar2", "http://127.0.0.1:5000/") res.status_code.should.equal(200) - test_client.post('/foobar2', "https://127.0.0.1:5000/", data={ - 'key': 'the-key', - 'file': 'nothing' - }) + test_client.post( + "/foobar2", + "https://127.0.0.1:5000/", + data={"key": "the-key", "file": "nothing"}, + ) - res = test_client.get('/foobar2/the-key', 'http://127.0.0.1:5000/') + res = test_client.get("/foobar2/the-key", "http://127.0.0.1:5000/") res.status_code.should.equal(200) res.data.should.equal(b"nothing") diff --git a/tests/test_s3bucket_path/test_s3bucket_path.py b/tests/test_s3bucket_path/test_s3bucket_path.py index 21d786c61..e204d0527 100644 --- a/tests/test_s3bucket_path/test_s3bucket_path.py +++ b/tests/test_s3bucket_path/test_s3bucket_path.py @@ -20,14 +20,13 @@ def create_connection(key=None, secret=None): class MyModel(object): - def __init__(self, name, value): self.name = name self.value = value def save(self): - conn = create_connection('the_key', 'the_secret') - bucket = conn.get_bucket('mybucket') + conn = create_connection("the_key", "the_secret") + bucket = conn.get_bucket("mybucket") k = Key(bucket) k.key = self.name k.set_contents_from_string(self.value) @@ -36,96 +35,95 @@ class MyModel(object): @mock_s3_deprecated def test_my_model_save(): # Create Bucket so that test can run - conn = create_connection('the_key', 'the_secret') - conn.create_bucket('mybucket') + conn = create_connection("the_key", "the_secret") + conn.create_bucket("mybucket") #################################### - model_instance = MyModel('steve', 'is awesome') + model_instance = MyModel("steve", "is awesome") model_instance.save() - conn.get_bucket('mybucket').get_key( - 'steve').get_contents_as_string().should.equal(b'is awesome') + conn.get_bucket("mybucket").get_key("steve").get_contents_as_string().should.equal( + b"is awesome" + ) @mock_s3_deprecated def test_missing_key(): - conn = create_connection('the_key', 'the_secret') + conn = create_connection("the_key", "the_secret") bucket = conn.create_bucket("foobar") bucket.get_key("the-key").should.equal(None) @mock_s3_deprecated def test_missing_key_urllib2(): - conn = create_connection('the_key', 'the_secret') + conn = create_connection("the_key", "the_secret") conn.create_bucket("foobar") - urlopen.when.called_with( - "http://s3.amazonaws.com/foobar/the-key").should.throw(HTTPError) + urlopen.when.called_with("http://s3.amazonaws.com/foobar/the-key").should.throw( + HTTPError + ) @mock_s3_deprecated def test_empty_key(): - conn = create_connection('the_key', 'the_secret') + conn = create_connection("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("") - bucket.get_key("the-key").get_contents_as_string().should.equal(b'') + bucket.get_key("the-key").get_contents_as_string().should.equal(b"") @mock_s3_deprecated def test_empty_key_set_on_existing_key(): - conn = create_connection('the_key', 'the_secret') + conn = create_connection("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("foobar") - bucket.get_key("the-key").get_contents_as_string().should.equal(b'foobar') + bucket.get_key("the-key").get_contents_as_string().should.equal(b"foobar") key.set_contents_from_string("") - bucket.get_key("the-key").get_contents_as_string().should.equal(b'') + bucket.get_key("the-key").get_contents_as_string().should.equal(b"") @mock_s3_deprecated def test_large_key_save(): - conn = create_connection('the_key', 'the_secret') + conn = create_connection("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("foobar" * 100000) - bucket.get_key( - "the-key").get_contents_as_string().should.equal(b'foobar' * 100000) + bucket.get_key("the-key").get_contents_as_string().should.equal(b"foobar" * 100000) @mock_s3_deprecated def test_copy_key(): - conn = create_connection('the_key', 'the_secret') + conn = create_connection("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) key.key = "the-key" key.set_contents_from_string("some value") - bucket.copy_key('new-key', 'foobar', 'the-key') + bucket.copy_key("new-key", "foobar", "the-key") - bucket.get_key( - "the-key").get_contents_as_string().should.equal(b"some value") - bucket.get_key( - "new-key").get_contents_as_string().should.equal(b"some value") + bucket.get_key("the-key").get_contents_as_string().should.equal(b"some value") + bucket.get_key("new-key").get_contents_as_string().should.equal(b"some value") @mock_s3_deprecated def test_set_metadata(): - conn = create_connection('the_key', 'the_secret') + conn = create_connection("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) - key.key = 'the-key' - key.set_metadata('md', 'Metadatastring') + key.key = "the-key" + key.set_metadata("md", "Metadatastring") key.set_contents_from_string("Testval") - bucket.get_key('the-key').get_metadata('md').should.equal('Metadatastring') + bucket.get_key("the-key").get_metadata("md").should.equal("Metadatastring") @freeze_time("2012-01-01 12:00:00") @@ -139,28 +137,28 @@ def test_last_modified(): key.set_contents_from_string("some value") rs = bucket.get_all_keys() - rs[0].last_modified.should.equal('2012-01-01T12:00:00.000Z') + rs[0].last_modified.should.equal("2012-01-01T12:00:00.000Z") - bucket.get_key( - "the-key").last_modified.should.equal('Sun, 01 Jan 2012 12:00:00 GMT') + bucket.get_key("the-key").last_modified.should.equal( + "Sun, 01 Jan 2012 12:00:00 GMT" + ) @mock_s3_deprecated def test_missing_bucket(): - conn = create_connection('the_key', 'the_secret') - conn.get_bucket.when.called_with('mybucket').should.throw(S3ResponseError) + conn = create_connection("the_key", "the_secret") + conn.get_bucket.when.called_with("mybucket").should.throw(S3ResponseError) @mock_s3_deprecated def test_bucket_with_dash(): - conn = create_connection('the_key', 'the_secret') - conn.get_bucket.when.called_with( - 'mybucket-test').should.throw(S3ResponseError) + conn = create_connection("the_key", "the_secret") + conn.get_bucket.when.called_with("mybucket-test").should.throw(S3ResponseError) @mock_s3_deprecated def test_bucket_deletion(): - conn = create_connection('the_key', 'the_secret') + conn = create_connection("the_key", "the_secret") bucket = conn.create_bucket("foobar") key = Key(bucket) @@ -182,7 +180,7 @@ def test_bucket_deletion(): @mock_s3_deprecated def test_get_all_buckets(): - conn = create_connection('the_key', 'the_secret') + conn = create_connection("the_key", "the_secret") conn.create_bucket("foobar") conn.create_bucket("foobar2") buckets = conn.get_all_buckets() @@ -193,50 +191,48 @@ def test_get_all_buckets(): @mock_s3 @mock_s3_deprecated def test_post_to_bucket(): - conn = create_connection('the_key', 'the_secret') + conn = create_connection("the_key", "the_secret") bucket = conn.create_bucket("foobar") - requests.post("https://s3.amazonaws.com/foobar", { - 'key': 'the-key', - 'file': 'nothing' - }) + requests.post( + "https://s3.amazonaws.com/foobar", {"key": "the-key", "file": "nothing"} + ) - bucket.get_key('the-key').get_contents_as_string().should.equal(b'nothing') + bucket.get_key("the-key").get_contents_as_string().should.equal(b"nothing") @mock_s3 @mock_s3_deprecated def test_post_with_metadata_to_bucket(): - conn = create_connection('the_key', 'the_secret') + conn = create_connection("the_key", "the_secret") bucket = conn.create_bucket("foobar") - requests.post("https://s3.amazonaws.com/foobar", { - 'key': 'the-key', - 'file': 'nothing', - 'x-amz-meta-test': 'metadata' - }) + requests.post( + "https://s3.amazonaws.com/foobar", + {"key": "the-key", "file": "nothing", "x-amz-meta-test": "metadata"}, + ) - bucket.get_key('the-key').get_metadata('test').should.equal('metadata') + bucket.get_key("the-key").get_metadata("test").should.equal("metadata") @mock_s3_deprecated def test_bucket_name_with_dot(): conn = create_connection() - bucket = conn.create_bucket('firstname.lastname') + bucket = conn.create_bucket("firstname.lastname") - k = Key(bucket, 'somekey') - k.set_contents_from_string('somedata') + k = Key(bucket, "somekey") + k.set_contents_from_string("somedata") @mock_s3_deprecated def test_key_with_special_characters(): conn = create_connection() - bucket = conn.create_bucket('test_bucket_name') + bucket = conn.create_bucket("test_bucket_name") - key = Key(bucket, 'test_list_keys_2/*x+?^@~!y') - key.set_contents_from_string('value1') + key = Key(bucket, "test_list_keys_2/*x+?^@~!y") + key.set_contents_from_string("value1") - key_list = bucket.list('test_list_keys_2/', '/') + key_list = bucket.list("test_list_keys_2/", "/") keys = [x for x in key_list] keys[0].name.should.equal("test_list_keys_2/*x+?^@~!y") @@ -244,78 +240,83 @@ def test_key_with_special_characters(): @mock_s3_deprecated def test_bucket_key_listing_order(): conn = create_connection() - bucket = conn.create_bucket('test_bucket') - prefix = 'toplevel/' + bucket = conn.create_bucket("test_bucket") + prefix = "toplevel/" def store(name): k = Key(bucket, prefix + name) - k.set_contents_from_string('somedata') + k.set_contents_from_string("somedata") - names = ['x/key', 'y.key1', 'y.key2', 'y.key3', 'x/y/key', 'x/y/z/key'] + names = ["x/key", "y.key1", "y.key2", "y.key3", "x/y/key", "x/y/z/key"] for name in names: store(name) delimiter = None keys = [x.name for x in bucket.list(prefix, delimiter)] - keys.should.equal([ - 'toplevel/x/key', 'toplevel/x/y/key', 'toplevel/x/y/z/key', - 'toplevel/y.key1', 'toplevel/y.key2', 'toplevel/y.key3' - ]) + keys.should.equal( + [ + "toplevel/x/key", + "toplevel/x/y/key", + "toplevel/x/y/z/key", + "toplevel/y.key1", + "toplevel/y.key2", + "toplevel/y.key3", + ] + ) - delimiter = '/' + delimiter = "/" keys = [x.name for x in bucket.list(prefix, delimiter)] - keys.should.equal([ - 'toplevel/y.key1', 'toplevel/y.key2', 'toplevel/y.key3', 'toplevel/x/' - ]) + keys.should.equal( + ["toplevel/y.key1", "toplevel/y.key2", "toplevel/y.key3", "toplevel/x/"] + ) # Test delimiter with no prefix - delimiter = '/' + delimiter = "/" keys = [x.name for x in bucket.list(prefix=None, delimiter=delimiter)] - keys.should.equal(['toplevel/']) + keys.should.equal(["toplevel/"]) delimiter = None - keys = [x.name for x in bucket.list(prefix + 'x', delimiter)] - keys.should.equal( - ['toplevel/x/key', 'toplevel/x/y/key', 'toplevel/x/y/z/key']) + keys = [x.name for x in bucket.list(prefix + "x", delimiter)] + keys.should.equal(["toplevel/x/key", "toplevel/x/y/key", "toplevel/x/y/z/key"]) - delimiter = '/' - keys = [x.name for x in bucket.list(prefix + 'x', delimiter)] - keys.should.equal(['toplevel/x/']) + delimiter = "/" + keys = [x.name for x in bucket.list(prefix + "x", delimiter)] + keys.should.equal(["toplevel/x/"]) @mock_s3_deprecated def test_delete_keys(): conn = create_connection() - bucket = conn.create_bucket('foobar') + bucket = conn.create_bucket("foobar") - Key(bucket=bucket, name='file1').set_contents_from_string('abc') - Key(bucket=bucket, name='file2').set_contents_from_string('abc') - Key(bucket=bucket, name='file3').set_contents_from_string('abc') - Key(bucket=bucket, name='file4').set_contents_from_string('abc') + Key(bucket=bucket, name="file1").set_contents_from_string("abc") + Key(bucket=bucket, name="file2").set_contents_from_string("abc") + Key(bucket=bucket, name="file3").set_contents_from_string("abc") + Key(bucket=bucket, name="file4").set_contents_from_string("abc") - result = bucket.delete_keys(['file2', 'file3']) + result = bucket.delete_keys(["file2", "file3"]) result.deleted.should.have.length_of(2) result.errors.should.have.length_of(0) keys = bucket.get_all_keys() keys.should.have.length_of(2) - keys[0].name.should.equal('file1') + keys[0].name.should.equal("file1") @mock_s3_deprecated def test_delete_keys_with_invalid(): conn = create_connection() - bucket = conn.create_bucket('foobar') + bucket = conn.create_bucket("foobar") - Key(bucket=bucket, name='file1').set_contents_from_string('abc') - Key(bucket=bucket, name='file2').set_contents_from_string('abc') - Key(bucket=bucket, name='file3').set_contents_from_string('abc') - Key(bucket=bucket, name='file4').set_contents_from_string('abc') + Key(bucket=bucket, name="file1").set_contents_from_string("abc") + Key(bucket=bucket, name="file2").set_contents_from_string("abc") + Key(bucket=bucket, name="file3").set_contents_from_string("abc") + Key(bucket=bucket, name="file4").set_contents_from_string("abc") - result = bucket.delete_keys(['abc', 'file3']) + result = bucket.delete_keys(["abc", "file3"]) result.deleted.should.have.length_of(1) result.errors.should.have.length_of(1) keys = bucket.get_all_keys() keys.should.have.length_of(3) - keys[0].name.should.equal('file1') + keys[0].name.should.equal("file1") diff --git a/tests/test_s3bucket_path/test_s3bucket_path_combo.py b/tests/test_s3bucket_path/test_s3bucket_path_combo.py index e1b1075ee..2ca7107d9 100644 --- a/tests/test_s3bucket_path/test_s3bucket_path_combo.py +++ b/tests/test_s3bucket_path/test_s3bucket_path_combo.py @@ -14,12 +14,12 @@ def test_bucketpath_combo_serial(): @mock_s3_deprecated def make_bucket_path(): conn = create_connection() - conn.create_bucket('mybucketpath') + conn.create_bucket("mybucketpath") @mock_s3_deprecated def make_bucket(): - conn = boto.connect_s3('the_key', 'the_secret') - conn.create_bucket('mybucket') + conn = boto.connect_s3("the_key", "the_secret") + conn.create_bucket("mybucket") make_bucket() make_bucket_path() diff --git a/tests/test_s3bucket_path/test_s3bucket_path_utils.py b/tests/test_s3bucket_path/test_s3bucket_path_utils.py index c607ea2ec..072968929 100644 --- a/tests/test_s3bucket_path/test_s3bucket_path_utils.py +++ b/tests/test_s3bucket_path/test_s3bucket_path_utils.py @@ -4,13 +4,14 @@ from moto.s3bucket_path.utils import bucket_name_from_url def test_base_url(): - expect(bucket_name_from_url('https://s3.amazonaws.com/')).should.equal(None) + expect(bucket_name_from_url("https://s3.amazonaws.com/")).should.equal(None) def test_localhost_bucket(): - expect(bucket_name_from_url('https://localhost:5000/wfoobar/abc') - ).should.equal("wfoobar") + expect(bucket_name_from_url("https://localhost:5000/wfoobar/abc")).should.equal( + "wfoobar" + ) def test_localhost_without_bucket(): - expect(bucket_name_from_url('https://www.localhost:5000')).should.equal(None) + expect(bucket_name_from_url("https://www.localhost:5000")).should.equal(None) diff --git a/tests/test_secretsmanager/test_secretsmanager.py b/tests/test_secretsmanager/test_secretsmanager.py index e2fc266ea..bf688ec12 100644 --- a/tests/test_secretsmanager/test_secretsmanager.py +++ b/tests/test_secretsmanager/test_secretsmanager.py @@ -12,310 +12,325 @@ import sure # noqa from nose.tools import assert_raises, assert_equal from six import b -DEFAULT_SECRET_NAME = 'test-secret' +DEFAULT_SECRET_NAME = "test-secret" @mock_secretsmanager def test_get_secret_value(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - create_secret = conn.create_secret(Name='java-util-test-password', - SecretString="foosecret") - result = conn.get_secret_value(SecretId='java-util-test-password') - assert result['SecretString'] == 'foosecret' + create_secret = conn.create_secret( + Name="java-util-test-password", SecretString="foosecret" + ) + result = conn.get_secret_value(SecretId="java-util-test-password") + assert result["SecretString"] == "foosecret" @mock_secretsmanager def test_get_secret_value_binary(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - create_secret = conn.create_secret(Name='java-util-test-password', - SecretBinary=b("foosecret")) - result = conn.get_secret_value(SecretId='java-util-test-password') - assert result['SecretBinary'] == b('foosecret') + create_secret = conn.create_secret( + Name="java-util-test-password", SecretBinary=b("foosecret") + ) + result = conn.get_secret_value(SecretId="java-util-test-password") + assert result["SecretBinary"] == b("foosecret") @mock_secretsmanager def test_get_secret_that_does_not_exist(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(ClientError) as cm: - result = conn.get_secret_value(SecretId='i-dont-exist') + result = conn.get_secret_value(SecretId="i-dont-exist") assert_equal( - u"Secrets Manager can\u2019t find the specified secret.", - cm.exception.response['Error']['Message'] + "Secrets Manager can\u2019t find the specified secret.", + cm.exception.response["Error"]["Message"], ) @mock_secretsmanager def test_get_secret_that_does_not_match(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - create_secret = conn.create_secret(Name='java-util-test-password', - SecretString="foosecret") + conn = boto3.client("secretsmanager", region_name="us-west-2") + create_secret = conn.create_secret( + Name="java-util-test-password", SecretString="foosecret" + ) with assert_raises(ClientError) as cm: - result = conn.get_secret_value(SecretId='i-dont-match') + result = conn.get_secret_value(SecretId="i-dont-match") assert_equal( - u"Secrets Manager can\u2019t find the specified secret.", - cm.exception.response['Error']['Message'] + "Secrets Manager can\u2019t find the specified secret.", + cm.exception.response["Error"]["Message"], ) @mock_secretsmanager def test_get_secret_value_that_is_marked_deleted(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - conn.delete_secret(SecretId='test-secret') + conn.delete_secret(SecretId="test-secret") with assert_raises(ClientError): - result = conn.get_secret_value(SecretId='test-secret') + result = conn.get_secret_value(SecretId="test-secret") @mock_secretsmanager def test_get_secret_that_has_no_value(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") create_secret = conn.create_secret(Name="java-util-test-password") with assert_raises(ClientError) as cm: - result = conn.get_secret_value(SecretId='java-util-test-password') + result = conn.get_secret_value(SecretId="java-util-test-password") assert_equal( - u"Secrets Manager can\u2019t find the specified secret value for staging label: AWSCURRENT", - cm.exception.response['Error']['Message'] + "Secrets Manager can\u2019t find the specified secret value for staging label: AWSCURRENT", + cm.exception.response["Error"]["Message"], ) @mock_secretsmanager def test_create_secret(): - conn = boto3.client('secretsmanager', region_name='us-east-1') + conn = boto3.client("secretsmanager", region_name="us-east-1") - result = conn.create_secret(Name='test-secret', SecretString="foosecret") - assert result['ARN'] - assert result['Name'] == 'test-secret' - secret = conn.get_secret_value(SecretId='test-secret') - assert secret['SecretString'] == 'foosecret' + result = conn.create_secret(Name="test-secret", SecretString="foosecret") + assert result["ARN"] + assert result["Name"] == "test-secret" + secret = conn.get_secret_value(SecretId="test-secret") + assert secret["SecretString"] == "foosecret" @mock_secretsmanager def test_create_secret_with_tags(): - conn = boto3.client('secretsmanager', region_name='us-east-1') - secret_name = 'test-secret-with-tags' + conn = boto3.client("secretsmanager", region_name="us-east-1") + secret_name = "test-secret-with-tags" result = conn.create_secret( Name=secret_name, SecretString="foosecret", - Tags=[{"Key": "Foo", "Value": "Bar"}, {"Key": "Mykey", "Value": "Myvalue"}] + Tags=[{"Key": "Foo", "Value": "Bar"}, {"Key": "Mykey", "Value": "Myvalue"}], ) - assert result['ARN'] - assert result['Name'] == secret_name + assert result["ARN"] + assert result["Name"] == secret_name secret_value = conn.get_secret_value(SecretId=secret_name) - assert secret_value['SecretString'] == 'foosecret' + assert secret_value["SecretString"] == "foosecret" secret_details = conn.describe_secret(SecretId=secret_name) - assert secret_details['Tags'] == [{"Key": "Foo", "Value": "Bar"}, {"Key": "Mykey", "Value": "Myvalue"}] + assert secret_details["Tags"] == [ + {"Key": "Foo", "Value": "Bar"}, + {"Key": "Mykey", "Value": "Myvalue"}, + ] @mock_secretsmanager def test_delete_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - deleted_secret = conn.delete_secret(SecretId='test-secret') + deleted_secret = conn.delete_secret(SecretId="test-secret") - assert deleted_secret['ARN'] - assert deleted_secret['Name'] == 'test-secret' - assert deleted_secret['DeletionDate'] > datetime.fromtimestamp(1, pytz.utc) + assert deleted_secret["ARN"] + assert deleted_secret["Name"] == "test-secret" + assert deleted_secret["DeletionDate"] > datetime.fromtimestamp(1, pytz.utc) - secret_details = conn.describe_secret(SecretId='test-secret') + secret_details = conn.describe_secret(SecretId="test-secret") - assert secret_details['ARN'] - assert secret_details['Name'] == 'test-secret' - assert secret_details['DeletedDate'] > datetime.fromtimestamp(1, pytz.utc) + assert secret_details["ARN"] + assert secret_details["Name"] == "test-secret" + assert secret_details["DeletedDate"] > datetime.fromtimestamp(1, pytz.utc) @mock_secretsmanager def test_delete_secret_force(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - result = conn.delete_secret(SecretId='test-secret', ForceDeleteWithoutRecovery=True) + result = conn.delete_secret(SecretId="test-secret", ForceDeleteWithoutRecovery=True) - assert result['ARN'] - assert result['DeletionDate'] > datetime.fromtimestamp(1, pytz.utc) - assert result['Name'] == 'test-secret' + assert result["ARN"] + assert result["DeletionDate"] > datetime.fromtimestamp(1, pytz.utc) + assert result["Name"] == "test-secret" with assert_raises(ClientError): - result = conn.get_secret_value(SecretId='test-secret') + result = conn.get_secret_value(SecretId="test-secret") @mock_secretsmanager def test_delete_secret_that_does_not_exist(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(ClientError): - result = conn.delete_secret(SecretId='i-dont-exist', ForceDeleteWithoutRecovery=True) + result = conn.delete_secret( + SecretId="i-dont-exist", ForceDeleteWithoutRecovery=True + ) @mock_secretsmanager def test_delete_secret_fails_with_both_force_delete_flag_and_recovery_window_flag(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") with assert_raises(ClientError): - result = conn.delete_secret(SecretId='test-secret', RecoveryWindowInDays=1, ForceDeleteWithoutRecovery=True) + result = conn.delete_secret( + SecretId="test-secret", + RecoveryWindowInDays=1, + ForceDeleteWithoutRecovery=True, + ) @mock_secretsmanager def test_delete_secret_recovery_window_too_short(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") with assert_raises(ClientError): - result = conn.delete_secret(SecretId='test-secret', RecoveryWindowInDays=6) + result = conn.delete_secret(SecretId="test-secret", RecoveryWindowInDays=6) @mock_secretsmanager def test_delete_secret_recovery_window_too_long(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") with assert_raises(ClientError): - result = conn.delete_secret(SecretId='test-secret', RecoveryWindowInDays=31) + result = conn.delete_secret(SecretId="test-secret", RecoveryWindowInDays=31) @mock_secretsmanager def test_delete_secret_that_is_marked_deleted(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - deleted_secret = conn.delete_secret(SecretId='test-secret') + deleted_secret = conn.delete_secret(SecretId="test-secret") with assert_raises(ClientError): - result = conn.delete_secret(SecretId='test-secret') + result = conn.delete_secret(SecretId="test-secret") @mock_secretsmanager def test_get_random_password_default_length(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") random_password = conn.get_random_password() - assert len(random_password['RandomPassword']) == 32 + assert len(random_password["RandomPassword"]) == 32 + @mock_secretsmanager def test_get_random_password_default_requirements(): # When require_each_included_type, default true - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") random_password = conn.get_random_password() # Should contain lowercase, upppercase, digit, special character - assert any(c.islower() for c in random_password['RandomPassword']) - assert any(c.isupper() for c in random_password['RandomPassword']) - assert any(c.isdigit() for c in random_password['RandomPassword']) - assert any(c in string.punctuation - for c in random_password['RandomPassword']) + assert any(c.islower() for c in random_password["RandomPassword"]) + assert any(c.isupper() for c in random_password["RandomPassword"]) + assert any(c.isdigit() for c in random_password["RandomPassword"]) + assert any(c in string.punctuation for c in random_password["RandomPassword"]) + @mock_secretsmanager def test_get_random_password_custom_length(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") random_password = conn.get_random_password(PasswordLength=50) - assert len(random_password['RandomPassword']) == 50 + assert len(random_password["RandomPassword"]) == 50 + @mock_secretsmanager def test_get_random_exclude_lowercase(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - random_password = conn.get_random_password(PasswordLength=55, - ExcludeLowercase=True) - assert any(c.islower() for c in random_password['RandomPassword']) == False + random_password = conn.get_random_password(PasswordLength=55, ExcludeLowercase=True) + assert any(c.islower() for c in random_password["RandomPassword"]) == False @mock_secretsmanager def test_get_random_exclude_uppercase(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - random_password = conn.get_random_password(PasswordLength=55, - ExcludeUppercase=True) - assert any(c.isupper() for c in random_password['RandomPassword']) == False + random_password = conn.get_random_password(PasswordLength=55, ExcludeUppercase=True) + assert any(c.isupper() for c in random_password["RandomPassword"]) == False @mock_secretsmanager def test_get_random_exclude_characters_and_symbols(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - random_password = conn.get_random_password(PasswordLength=20, - ExcludeCharacters='xyzDje@?!.') - assert any(c in 'xyzDje@?!.' for c in random_password['RandomPassword']) == False + random_password = conn.get_random_password( + PasswordLength=20, ExcludeCharacters="xyzDje@?!." + ) + assert any(c in "xyzDje@?!." for c in random_password["RandomPassword"]) == False @mock_secretsmanager def test_get_random_exclude_numbers(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - random_password = conn.get_random_password(PasswordLength=100, - ExcludeNumbers=True) - assert any(c.isdigit() for c in random_password['RandomPassword']) == False + random_password = conn.get_random_password(PasswordLength=100, ExcludeNumbers=True) + assert any(c.isdigit() for c in random_password["RandomPassword"]) == False @mock_secretsmanager def test_get_random_exclude_punctuation(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - random_password = conn.get_random_password(PasswordLength=100, - ExcludePunctuation=True) - assert any(c in string.punctuation - for c in random_password['RandomPassword']) == False + random_password = conn.get_random_password( + PasswordLength=100, ExcludePunctuation=True + ) + assert ( + any(c in string.punctuation for c in random_password["RandomPassword"]) == False + ) @mock_secretsmanager def test_get_random_include_space_false(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") random_password = conn.get_random_password(PasswordLength=300) - assert any(c.isspace() for c in random_password['RandomPassword']) == False + assert any(c.isspace() for c in random_password["RandomPassword"]) == False @mock_secretsmanager def test_get_random_include_space_true(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - random_password = conn.get_random_password(PasswordLength=4, - IncludeSpace=True) - assert any(c.isspace() for c in random_password['RandomPassword']) == True + random_password = conn.get_random_password(PasswordLength=4, IncludeSpace=True) + assert any(c.isspace() for c in random_password["RandomPassword"]) == True @mock_secretsmanager def test_get_random_require_each_included_type(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - random_password = conn.get_random_password(PasswordLength=4, - RequireEachIncludedType=True) - assert any(c in string.punctuation for c in random_password['RandomPassword']) == True - assert any(c in string.ascii_lowercase for c in random_password['RandomPassword']) == True - assert any(c in string.ascii_uppercase for c in random_password['RandomPassword']) == True - assert any(c in string.digits for c in random_password['RandomPassword']) == True + random_password = conn.get_random_password( + PasswordLength=4, RequireEachIncludedType=True + ) + assert ( + any(c in string.punctuation for c in random_password["RandomPassword"]) == True + ) + assert ( + any(c in string.ascii_lowercase for c in random_password["RandomPassword"]) + == True + ) + assert ( + any(c in string.ascii_uppercase for c in random_password["RandomPassword"]) + == True + ) + assert any(c in string.digits for c in random_password["RandomPassword"]) == True @mock_secretsmanager def test_get_random_too_short_password(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(ClientError): random_password = conn.get_random_password(PasswordLength=3) @@ -323,7 +338,7 @@ def test_get_random_too_short_password(): @mock_secretsmanager def test_get_random_too_long_password(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(Exception): random_password = conn.get_random_password(PasswordLength=5555) @@ -331,177 +346,167 @@ def test_get_random_too_long_password(): @mock_secretsmanager def test_describe_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name='test-secret', - SecretString='foosecret') - - conn.create_secret(Name='test-secret-2', - SecretString='barsecret') - - secret_description = conn.describe_secret(SecretId='test-secret') - secret_description_2 = conn.describe_secret(SecretId='test-secret-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name="test-secret", SecretString="foosecret") - assert secret_description # Returned dict is not empty - assert secret_description['Name'] == ('test-secret') - assert secret_description['ARN'] != '' # Test arn not empty - assert secret_description_2['Name'] == ('test-secret-2') - assert secret_description_2['ARN'] != '' # Test arn not empty + conn.create_secret(Name="test-secret-2", SecretString="barsecret") + + secret_description = conn.describe_secret(SecretId="test-secret") + secret_description_2 = conn.describe_secret(SecretId="test-secret-2") + + assert secret_description # Returned dict is not empty + assert secret_description["Name"] == ("test-secret") + assert secret_description["ARN"] != "" # Test arn not empty + assert secret_description_2["Name"] == ("test-secret-2") + assert secret_description_2["ARN"] != "" # Test arn not empty @mock_secretsmanager def test_describe_secret_that_does_not_exist(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(ClientError): - result = conn.get_secret_value(SecretId='i-dont-exist') + result = conn.get_secret_value(SecretId="i-dont-exist") @mock_secretsmanager def test_describe_secret_that_does_not_match(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name='test-secret', - SecretString='foosecret') - + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name="test-secret", SecretString="foosecret") + with assert_raises(ClientError): - result = conn.get_secret_value(SecretId='i-dont-match') + result = conn.get_secret_value(SecretId="i-dont-match") @mock_secretsmanager def test_list_secrets_empty(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") secrets = conn.list_secrets() - assert secrets['SecretList'] == [] + assert secrets["SecretList"] == [] @mock_secretsmanager def test_list_secrets(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - conn.create_secret(Name='test-secret-2', - SecretString='barsecret', - Tags=[{ - 'Key': 'a', - 'Value': '1' - }]) + conn.create_secret( + Name="test-secret-2", + SecretString="barsecret", + Tags=[{"Key": "a", "Value": "1"}], + ) secrets = conn.list_secrets() - assert secrets['SecretList'][0]['ARN'] is not None - assert secrets['SecretList'][0]['Name'] == 'test-secret' - assert secrets['SecretList'][1]['ARN'] is not None - assert secrets['SecretList'][1]['Name'] == 'test-secret-2' - assert secrets['SecretList'][1]['Tags'] == [{ - 'Key': 'a', - 'Value': '1' - }] + assert secrets["SecretList"][0]["ARN"] is not None + assert secrets["SecretList"][0]["Name"] == "test-secret" + assert secrets["SecretList"][1]["ARN"] is not None + assert secrets["SecretList"][1]["Name"] == "test-secret-2" + assert secrets["SecretList"][1]["Tags"] == [{"Key": "a", "Value": "1"}] @mock_secretsmanager def test_restore_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - conn.delete_secret(SecretId='test-secret') + conn.delete_secret(SecretId="test-secret") - described_secret_before = conn.describe_secret(SecretId='test-secret') - assert described_secret_before['DeletedDate'] > datetime.fromtimestamp(1, pytz.utc) + described_secret_before = conn.describe_secret(SecretId="test-secret") + assert described_secret_before["DeletedDate"] > datetime.fromtimestamp(1, pytz.utc) - restored_secret = conn.restore_secret(SecretId='test-secret') - assert restored_secret['ARN'] - assert restored_secret['Name'] == 'test-secret' + restored_secret = conn.restore_secret(SecretId="test-secret") + assert restored_secret["ARN"] + assert restored_secret["Name"] == "test-secret" - described_secret_after = conn.describe_secret(SecretId='test-secret') - assert 'DeletedDate' not in described_secret_after + described_secret_after = conn.describe_secret(SecretId="test-secret") + assert "DeletedDate" not in described_secret_after @mock_secretsmanager def test_restore_secret_that_is_not_deleted(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - restored_secret = conn.restore_secret(SecretId='test-secret') - assert restored_secret['ARN'] - assert restored_secret['Name'] == 'test-secret' + restored_secret = conn.restore_secret(SecretId="test-secret") + assert restored_secret["ARN"] + assert restored_secret["Name"] == "test-secret" @mock_secretsmanager def test_restore_secret_that_does_not_exist(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(ClientError): - result = conn.restore_secret(SecretId='i-dont-exist') + result = conn.restore_secret(SecretId="i-dont-exist") @mock_secretsmanager def test_rotate_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name=DEFAULT_SECRET_NAME, - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretString="foosecret") rotated_secret = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME) assert rotated_secret - assert rotated_secret['ARN'] != '' # Test arn not empty - assert rotated_secret['Name'] == DEFAULT_SECRET_NAME - assert rotated_secret['VersionId'] != '' + assert rotated_secret["ARN"] != "" # Test arn not empty + assert rotated_secret["Name"] == DEFAULT_SECRET_NAME + assert rotated_secret["VersionId"] != "" + @mock_secretsmanager def test_rotate_secret_enable_rotation(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name=DEFAULT_SECRET_NAME, - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretString="foosecret") initial_description = conn.describe_secret(SecretId=DEFAULT_SECRET_NAME) assert initial_description - assert initial_description['RotationEnabled'] is False - assert initial_description['RotationRules']['AutomaticallyAfterDays'] == 0 + assert initial_description["RotationEnabled"] is False + assert initial_description["RotationRules"]["AutomaticallyAfterDays"] == 0 - conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME, - RotationRules={'AutomaticallyAfterDays': 42}) + conn.rotate_secret( + SecretId=DEFAULT_SECRET_NAME, RotationRules={"AutomaticallyAfterDays": 42} + ) rotated_description = conn.describe_secret(SecretId=DEFAULT_SECRET_NAME) assert rotated_description - assert rotated_description['RotationEnabled'] is True - assert rotated_description['RotationRules']['AutomaticallyAfterDays'] == 42 + assert rotated_description["RotationEnabled"] is True + assert rotated_description["RotationRules"]["AutomaticallyAfterDays"] == 42 @mock_secretsmanager def test_rotate_secret_that_is_marked_deleted(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn.create_secret(Name="test-secret", SecretString="foosecret") - conn.delete_secret(SecretId='test-secret') + conn.delete_secret(SecretId="test-secret") with assert_raises(ClientError): - result = conn.rotate_secret(SecretId='test-secret') + result = conn.rotate_secret(SecretId="test-secret") @mock_secretsmanager def test_rotate_secret_that_does_not_exist(): - conn = boto3.client('secretsmanager', 'us-west-2') + conn = boto3.client("secretsmanager", "us-west-2") with assert_raises(ClientError): - result = conn.rotate_secret(SecretId='i-dont-exist') + result = conn.rotate_secret(SecretId="i-dont-exist") + @mock_secretsmanager def test_rotate_secret_that_does_not_match(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name='test-secret', - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name="test-secret", SecretString="foosecret") with assert_raises(ClientError): - result = conn.rotate_secret(SecretId='i-dont-match') + result = conn.rotate_secret(SecretId="i-dont-match") + @mock_secretsmanager def test_rotate_secret_client_request_token_too_short(): @@ -510,30 +515,32 @@ def test_rotate_secret_client_request_token_too_short(): # test_server actually handles this error. assert True + @mock_secretsmanager def test_rotate_secret_client_request_token_too_long(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name=DEFAULT_SECRET_NAME, - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretString="foosecret") client_request_token = ( - 'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C-' - 'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C' + "ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C-" "ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C" ) with assert_raises(ClientError): - result = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME, - ClientRequestToken=client_request_token) + result = conn.rotate_secret( + SecretId=DEFAULT_SECRET_NAME, ClientRequestToken=client_request_token + ) + @mock_secretsmanager def test_rotate_secret_rotation_lambda_arn_too_long(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name=DEFAULT_SECRET_NAME, - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretString="foosecret") - rotation_lambda_arn = '85B7-446A-B7E4' * 147 # == 2058 characters + rotation_lambda_arn = "85B7-446A-B7E4" * 147 # == 2058 characters with assert_raises(ClientError): - result = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME, - RotationLambdaARN=rotation_lambda_arn) + result = conn.rotate_secret( + SecretId=DEFAULT_SECRET_NAME, RotationLambdaARN=rotation_lambda_arn + ) + @mock_secretsmanager def test_rotate_secret_rotation_period_zero(): @@ -545,118 +552,138 @@ def test_rotate_secret_rotation_period_zero(): @mock_secretsmanager def test_rotate_secret_rotation_period_too_long(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - conn.create_secret(Name=DEFAULT_SECRET_NAME, - SecretString='foosecret') + conn = boto3.client("secretsmanager", region_name="us-west-2") + conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretString="foosecret") - rotation_rules = {'AutomaticallyAfterDays': 1001} + rotation_rules = {"AutomaticallyAfterDays": 1001} with assert_raises(ClientError): - result = conn.rotate_secret(SecretId=DEFAULT_SECRET_NAME, - RotationRules=rotation_rules) + result = conn.rotate_secret( + SecretId=DEFAULT_SECRET_NAME, RotationRules=rotation_rules + ) @mock_secretsmanager def test_put_secret_value_puts_new_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='foosecret', - VersionStages=['AWSCURRENT']) - version_id = put_secret_value_dict['VersionId'] + conn = boto3.client("secretsmanager", region_name="us-west-2") + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="foosecret", + VersionStages=["AWSCURRENT"], + ) + version_id = put_secret_value_dict["VersionId"] - get_secret_value_dict = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME, - VersionId=version_id, - VersionStage='AWSCURRENT') + get_secret_value_dict = conn.get_secret_value( + SecretId=DEFAULT_SECRET_NAME, VersionId=version_id, VersionStage="AWSCURRENT" + ) assert get_secret_value_dict - assert get_secret_value_dict['SecretString'] == 'foosecret' + assert get_secret_value_dict["SecretString"] == "foosecret" @mock_secretsmanager def test_put_secret_binary_value_puts_new_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretBinary=b('foosecret'), - VersionStages=['AWSCURRENT']) - version_id = put_secret_value_dict['VersionId'] + conn = boto3.client("secretsmanager", region_name="us-west-2") + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretBinary=b("foosecret"), + VersionStages=["AWSCURRENT"], + ) + version_id = put_secret_value_dict["VersionId"] - get_secret_value_dict = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME, - VersionId=version_id, - VersionStage='AWSCURRENT') + get_secret_value_dict = conn.get_secret_value( + SecretId=DEFAULT_SECRET_NAME, VersionId=version_id, VersionStage="AWSCURRENT" + ) assert get_secret_value_dict - assert get_secret_value_dict['SecretBinary'] == b('foosecret') + assert get_secret_value_dict["SecretBinary"] == b("foosecret") @mock_secretsmanager def test_create_and_put_secret_binary_value_puts_new_secret(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") conn.create_secret(Name=DEFAULT_SECRET_NAME, SecretBinary=b("foosecret")) - conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, SecretBinary=b('foosecret_update')) + conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, SecretBinary=b("foosecret_update") + ) latest_secret = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME) assert latest_secret - assert latest_secret['SecretBinary'] == b('foosecret_update') + assert latest_secret["SecretBinary"] == b("foosecret_update") @mock_secretsmanager def test_put_secret_binary_requires_either_string_or_binary(): - conn = boto3.client('secretsmanager', region_name='us-west-2') + conn = boto3.client("secretsmanager", region_name="us-west-2") with assert_raises(ClientError) as ire: conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME) - ire.exception.response['Error']['Code'].should.equal('InvalidRequestException') - ire.exception.response['Error']['Message'].should.equal('You must provide either SecretString or SecretBinary.') + ire.exception.response["Error"]["Code"].should.equal("InvalidRequestException") + ire.exception.response["Error"]["Message"].should.equal( + "You must provide either SecretString or SecretBinary." + ) @mock_secretsmanager def test_put_secret_value_can_get_first_version_if_put_twice(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='first_secret', - VersionStages=['AWSCURRENT']) - first_version_id = put_secret_value_dict['VersionId'] - conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='second_secret', - VersionStages=['AWSCURRENT']) + conn = boto3.client("secretsmanager", region_name="us-west-2") + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="first_secret", + VersionStages=["AWSCURRENT"], + ) + first_version_id = put_secret_value_dict["VersionId"] + conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="second_secret", + VersionStages=["AWSCURRENT"], + ) - first_secret_value_dict = conn.get_secret_value(SecretId=DEFAULT_SECRET_NAME, - VersionId=first_version_id) - first_secret_value = first_secret_value_dict['SecretString'] + first_secret_value_dict = conn.get_secret_value( + SecretId=DEFAULT_SECRET_NAME, VersionId=first_version_id + ) + first_secret_value = first_secret_value_dict["SecretString"] - assert first_secret_value == 'first_secret' + assert first_secret_value == "first_secret" @mock_secretsmanager def test_put_secret_value_versions_differ_if_same_secret_put_twice(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='dupe_secret', - VersionStages=['AWSCURRENT']) - first_version_id = put_secret_value_dict['VersionId'] - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='dupe_secret', - VersionStages=['AWSCURRENT']) - second_version_id = put_secret_value_dict['VersionId'] + conn = boto3.client("secretsmanager", region_name="us-west-2") + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="dupe_secret", + VersionStages=["AWSCURRENT"], + ) + first_version_id = put_secret_value_dict["VersionId"] + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="dupe_secret", + VersionStages=["AWSCURRENT"], + ) + second_version_id = put_secret_value_dict["VersionId"] assert first_version_id != second_version_id @mock_secretsmanager def test_can_list_secret_version_ids(): - conn = boto3.client('secretsmanager', region_name='us-west-2') - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='dupe_secret', - VersionStages=['AWSCURRENT']) - first_version_id = put_secret_value_dict['VersionId'] - put_secret_value_dict = conn.put_secret_value(SecretId=DEFAULT_SECRET_NAME, - SecretString='dupe_secret', - VersionStages=['AWSCURRENT']) - second_version_id = put_secret_value_dict['VersionId'] + conn = boto3.client("secretsmanager", region_name="us-west-2") + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="dupe_secret", + VersionStages=["AWSCURRENT"], + ) + first_version_id = put_secret_value_dict["VersionId"] + put_secret_value_dict = conn.put_secret_value( + SecretId=DEFAULT_SECRET_NAME, + SecretString="dupe_secret", + VersionStages=["AWSCURRENT"], + ) + second_version_id = put_secret_value_dict["VersionId"] versions_list = conn.list_secret_version_ids(SecretId=DEFAULT_SECRET_NAME) - returned_version_ids = [v['VersionId'] for v in versions_list['Versions']] + returned_version_ids = [v["VersionId"] for v in versions_list["Versions"]] assert [first_version_id, second_version_id].sort() == returned_version_ids.sort() - diff --git a/tests/test_secretsmanager/test_server.py b/tests/test_secretsmanager/test_server.py index 6955d8232..89cb90185 100644 --- a/tests/test_secretsmanager/test_server.py +++ b/tests/test_secretsmanager/test_server.py @@ -7,11 +7,11 @@ import sure # noqa import moto.server as server from moto import mock_secretsmanager -''' +""" Test the different server responses for secretsmanager -''' +""" -DEFAULT_SECRET_NAME = 'test-secret' +DEFAULT_SECRET_NAME = "test-secret" @mock_secretsmanager @@ -20,22 +20,21 @@ def test_get_secret_value(): backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foo-secret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret"}, - ) - get_secret = test_client.post('/', - data={"SecretId": DEFAULT_SECRET_NAME, - "VersionStage": "AWSCURRENT"}, - headers={ - "X-Amz-Target": "secretsmanager.GetSecretValue"}, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foo-secret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + get_secret = test_client.post( + "/", + data={"SecretId": DEFAULT_SECRET_NAME, "VersionStage": "AWSCURRENT"}, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) json_data = json.loads(get_secret.data.decode("utf-8")) - assert json_data['SecretString'] == 'foo-secret' + assert json_data["SecretString"] == "foo-secret" + @mock_secretsmanager def test_get_secret_that_does_not_exist(): @@ -43,56 +42,63 @@ def test_get_secret_that_does_not_exist(): backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - get_secret = test_client.post('/', - data={"SecretId": "i-dont-exist", - "VersionStage": "AWSCURRENT"}, - headers={ - "X-Amz-Target": "secretsmanager.GetSecretValue"}, - ) + get_secret = test_client.post( + "/", + data={"SecretId": "i-dont-exist", "VersionStage": "AWSCURRENT"}, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) json_data = json.loads(get_secret.data.decode("utf-8")) - assert json_data['message'] == u"Secrets Manager can\u2019t find the specified secret." - assert json_data['__type'] == 'ResourceNotFoundException' + assert ( + json_data["message"] == "Secrets Manager can\u2019t find the specified secret." + ) + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_get_secret_that_does_not_match(): backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foo-secret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret"}, - ) - get_secret = test_client.post('/', - data={"SecretId": "i-dont-match", - "VersionStage": "AWSCURRENT"}, - headers={ - "X-Amz-Target": "secretsmanager.GetSecretValue"}, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foo-secret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + get_secret = test_client.post( + "/", + data={"SecretId": "i-dont-match", "VersionStage": "AWSCURRENT"}, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) json_data = json.loads(get_secret.data.decode("utf-8")) - assert json_data['message'] == u"Secrets Manager can\u2019t find the specified secret." - assert json_data['__type'] == 'ResourceNotFoundException' + assert ( + json_data["message"] == "Secrets Manager can\u2019t find the specified secret." + ) + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_get_secret_that_has_no_value(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret"}, - ) - get_secret = test_client.post('/', - data={"SecretId": DEFAULT_SECRET_NAME}, - headers={ - "X-Amz-Target": "secretsmanager.GetSecretValue"}, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + get_secret = test_client.post( + "/", + data={"SecretId": DEFAULT_SECRET_NAME}, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) json_data = json.loads(get_secret.data.decode("utf-8")) - assert json_data['message'] == u"Secrets Manager can\u2019t find the specified secret value for staging label: AWSCURRENT" - assert json_data['__type'] == 'ResourceNotFoundException' + assert ( + json_data["message"] + == "Secrets Manager can\u2019t find the specified secret value for staging label: AWSCURRENT" + ) + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_create_secret(): @@ -100,139 +106,135 @@ def test_create_secret(): backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - res = test_client.post('/', - data={"Name": "test-secret", - "SecretString": "foo-secret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret"}, - ) - res_2 = test_client.post('/', - data={"Name": "test-secret-2", - "SecretString": "bar-secret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret"}, - ) + res = test_client.post( + "/", + data={"Name": "test-secret", "SecretString": "foo-secret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + res_2 = test_client.post( + "/", + data={"Name": "test-secret-2", "SecretString": "bar-secret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) json_data = json.loads(res.data.decode("utf-8")) - assert json_data['ARN'] != '' - assert json_data['Name'] == 'test-secret' - + assert json_data["ARN"] != "" + assert json_data["Name"] == "test-secret" + json_data_2 = json.loads(res_2.data.decode("utf-8")) - assert json_data_2['ARN'] != '' - assert json_data_2['Name'] == 'test-secret-2' + assert json_data_2["ARN"] != "" + assert json_data_2["Name"] == "test-secret-2" + @mock_secretsmanager def test_describe_secret(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": "test-secret", - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) - describe_secret = test_client.post('/', - data={"SecretId": "test-secret"}, - headers={ - "X-Amz-Target": "secretsmanager.DescribeSecret" - }, - ) - - create_secret_2 = test_client.post('/', - data={"Name": "test-secret-2", - "SecretString": "barsecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) - describe_secret_2 = test_client.post('/', - data={"SecretId": "test-secret-2"}, - headers={ - "X-Amz-Target": "secretsmanager.DescribeSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": "test-secret", "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + describe_secret = test_client.post( + "/", + data={"SecretId": "test-secret"}, + headers={"X-Amz-Target": "secretsmanager.DescribeSecret"}, + ) + + create_secret_2 = test_client.post( + "/", + data={"Name": "test-secret-2", "SecretString": "barsecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + describe_secret_2 = test_client.post( + "/", + data={"SecretId": "test-secret-2"}, + headers={"X-Amz-Target": "secretsmanager.DescribeSecret"}, + ) json_data = json.loads(describe_secret.data.decode("utf-8")) - assert json_data # Returned dict is not empty - assert json_data['ARN'] != '' - assert json_data['Name'] == 'test-secret' - + assert json_data # Returned dict is not empty + assert json_data["ARN"] != "" + assert json_data["Name"] == "test-secret" + json_data_2 = json.loads(describe_secret_2.data.decode("utf-8")) - assert json_data_2 # Returned dict is not empty - assert json_data_2['ARN'] != '' - assert json_data_2['Name'] == 'test-secret-2' + assert json_data_2 # Returned dict is not empty + assert json_data_2["ARN"] != "" + assert json_data_2["Name"] == "test-secret-2" + @mock_secretsmanager def test_describe_secret_that_does_not_exist(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - describe_secret = test_client.post('/', - data={"SecretId": "i-dont-exist"}, - headers={ - "X-Amz-Target": "secretsmanager.DescribeSecret" - }, - ) + describe_secret = test_client.post( + "/", + data={"SecretId": "i-dont-exist"}, + headers={"X-Amz-Target": "secretsmanager.DescribeSecret"}, + ) json_data = json.loads(describe_secret.data.decode("utf-8")) - assert json_data['message'] == u"Secrets Manager can\u2019t find the specified secret." - assert json_data['__type'] == 'ResourceNotFoundException' + assert ( + json_data["message"] == "Secrets Manager can\u2019t find the specified secret." + ) + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_describe_secret_that_does_not_match(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) - describe_secret = test_client.post('/', - data={"SecretId": "i-dont-match"}, - headers={ - "X-Amz-Target": "secretsmanager.DescribeSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) + describe_secret = test_client.post( + "/", + data={"SecretId": "i-dont-match"}, + headers={"X-Amz-Target": "secretsmanager.DescribeSecret"}, + ) json_data = json.loads(describe_secret.data.decode("utf-8")) - assert json_data['message'] == u"Secrets Manager can\u2019t find the specified secret." - assert json_data['__type'] == 'ResourceNotFoundException' + assert ( + json_data["message"] == "Secrets Manager can\u2019t find the specified secret." + ) + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_rotate_secret(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) client_request_token = "EXAMPLE2-90ab-cdef-fedc-ba987SECRET2" - rotate_secret = test_client.post('/', - data={"SecretId": DEFAULT_SECRET_NAME, - "ClientRequestToken": client_request_token}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) + rotate_secret = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "ClientRequestToken": client_request_token, + }, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, + ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data # Returned dict is not empty - assert json_data['ARN'] != '' - assert json_data['Name'] == DEFAULT_SECRET_NAME - assert json_data['VersionId'] == client_request_token + assert json_data # Returned dict is not empty + assert json_data["ARN"] != "" + assert json_data["Name"] == DEFAULT_SECRET_NAME + assert json_data["VersionId"] == client_request_token + # @mock_secretsmanager # def test_rotate_secret_enable_rotation(): @@ -291,291 +293,316 @@ def test_rotate_secret(): # assert json_data['RotationEnabled'] is True # assert json_data['RotationRules']['AutomaticallyAfterDays'] == 42 + @mock_secretsmanager def test_rotate_secret_that_does_not_exist(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - rotate_secret = test_client.post('/', - data={"SecretId": "i-dont-exist"}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) + rotate_secret = test_client.post( + "/", + data={"SecretId": "i-dont-exist"}, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, + ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data['message'] == u"Secrets Manager can\u2019t find the specified secret." - assert json_data['__type'] == 'ResourceNotFoundException' + assert ( + json_data["message"] == "Secrets Manager can\u2019t find the specified secret." + ) + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_rotate_secret_that_does_not_match(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) - rotate_secret = test_client.post('/', - data={"SecretId": "i-dont-match"}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) + rotate_secret = test_client.post( + "/", + data={"SecretId": "i-dont-match"}, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, + ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data['message'] == u"Secrets Manager can\u2019t find the specified secret." - assert json_data['__type'] == 'ResourceNotFoundException' + assert ( + json_data["message"] == "Secrets Manager can\u2019t find the specified secret." + ) + assert json_data["__type"] == "ResourceNotFoundException" + @mock_secretsmanager def test_rotate_secret_client_request_token_too_short(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) client_request_token = "ED9F8B6C-85B7-B7E4-38F2A3BEB13C" - rotate_secret = test_client.post('/', - data={"SecretId": DEFAULT_SECRET_NAME, - "ClientRequestToken": client_request_token}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) + rotate_secret = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "ClientRequestToken": client_request_token, + }, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, + ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data['message'] == "ClientRequestToken must be 32-64 characters long." - assert json_data['__type'] == 'InvalidParameterException' + assert json_data["message"] == "ClientRequestToken must be 32-64 characters long." + assert json_data["__type"] == "InvalidParameterException" + @mock_secretsmanager def test_rotate_secret_client_request_token_too_long(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) client_request_token = ( - 'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C-' - 'ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C' + "ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C-" "ED9F8B6C-85B7-446A-B7E4-38F2A3BEB13C" + ) + rotate_secret = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "ClientRequestToken": client_request_token, + }, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, ) - rotate_secret = test_client.post('/', - data={"SecretId": DEFAULT_SECRET_NAME, - "ClientRequestToken": client_request_token}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data['message'] == "ClientRequestToken must be 32-64 characters long." - assert json_data['__type'] == 'InvalidParameterException' + assert json_data["message"] == "ClientRequestToken must be 32-64 characters long." + assert json_data["__type"] == "InvalidParameterException" + @mock_secretsmanager def test_rotate_secret_rotation_lambda_arn_too_long(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - create_secret = test_client.post('/', - data={"Name": DEFAULT_SECRET_NAME, - "SecretString": "foosecret"}, - headers={ - "X-Amz-Target": "secretsmanager.CreateSecret" - }, - ) + create_secret = test_client.post( + "/", + data={"Name": DEFAULT_SECRET_NAME, "SecretString": "foosecret"}, + headers={"X-Amz-Target": "secretsmanager.CreateSecret"}, + ) - rotation_lambda_arn = '85B7-446A-B7E4' * 147 # == 2058 characters - rotate_secret = test_client.post('/', - data={"SecretId": DEFAULT_SECRET_NAME, - "RotationLambdaARN": rotation_lambda_arn}, - headers={ - "X-Amz-Target": "secretsmanager.RotateSecret" - }, - ) + rotation_lambda_arn = "85B7-446A-B7E4" * 147 # == 2058 characters + rotate_secret = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "RotationLambdaARN": rotation_lambda_arn, + }, + headers={"X-Amz-Target": "secretsmanager.RotateSecret"}, + ) json_data = json.loads(rotate_secret.data.decode("utf-8")) - assert json_data['message'] == "RotationLambdaARN must <= 2048 characters long." - assert json_data['__type'] == 'InvalidParameterException' - - - + assert json_data["message"] == "RotationLambdaARN must <= 2048 characters long." + assert json_data["__type"] == "InvalidParameterException" @mock_secretsmanager def test_put_secret_value_puts_new_secret(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "foosecret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) + test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "foosecret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) - put_second_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "foosecret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) - second_secret_json_data = json.loads(put_second_secret_value_json.data.decode("utf-8")) + put_second_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "foosecret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) + second_secret_json_data = json.loads( + put_second_secret_value_json.data.decode("utf-8") + ) - version_id = second_secret_json_data['VersionId'] + version_id = second_secret_json_data["VersionId"] - secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "VersionId": version_id, - "VersionStage": 'AWSCURRENT'}, - headers={ - "X-Amz-Target": "secretsmanager.GetSecretValue"}, - ) + secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "VersionId": version_id, + "VersionStage": "AWSCURRENT", + }, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) second_secret_json_data = json.loads(secret_value_json.data.decode("utf-8")) assert second_secret_json_data - assert second_secret_json_data['SecretString'] == 'foosecret' + assert second_secret_json_data["SecretString"] == "foosecret" @mock_secretsmanager def test_put_secret_value_can_get_first_version_if_put_twice(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - first_secret_string = 'first_secret' - second_secret_string = 'second_secret' + first_secret_string = "first_secret" + second_secret_string = "second_secret" - put_first_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": first_secret_string, - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) + put_first_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": first_secret_string, + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) - first_secret_json_data = json.loads(put_first_secret_value_json.data.decode("utf-8")) + first_secret_json_data = json.loads( + put_first_secret_value_json.data.decode("utf-8") + ) - first_secret_version_id = first_secret_json_data['VersionId'] + first_secret_version_id = first_secret_json_data["VersionId"] - test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": second_secret_string, - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) + test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": second_secret_string, + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) - get_first_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "VersionId": first_secret_version_id, - "VersionStage": 'AWSCURRENT'}, - headers={ - "X-Amz-Target": "secretsmanager.GetSecretValue"}, - ) + get_first_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "VersionId": first_secret_version_id, + "VersionStage": "AWSCURRENT", + }, + headers={"X-Amz-Target": "secretsmanager.GetSecretValue"}, + ) - get_first_secret_json_data = json.loads(get_first_secret_value_json.data.decode("utf-8")) + get_first_secret_json_data = json.loads( + get_first_secret_value_json.data.decode("utf-8") + ) assert get_first_secret_json_data - assert get_first_secret_json_data['SecretString'] == first_secret_string + assert get_first_secret_json_data["SecretString"] == first_secret_string @mock_secretsmanager def test_put_secret_value_versions_differ_if_same_secret_put_twice(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - put_first_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "secret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) - first_secret_json_data = json.loads(put_first_secret_value_json.data.decode("utf-8")) - first_secret_version_id = first_secret_json_data['VersionId'] + put_first_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "secret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) + first_secret_json_data = json.loads( + put_first_secret_value_json.data.decode("utf-8") + ) + first_secret_version_id = first_secret_json_data["VersionId"] - put_second_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "secret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) - second_secret_json_data = json.loads(put_second_secret_value_json.data.decode("utf-8")) - second_secret_version_id = second_secret_json_data['VersionId'] + put_second_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "secret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) + second_secret_json_data = json.loads( + put_second_secret_value_json.data.decode("utf-8") + ) + second_secret_version_id = second_secret_json_data["VersionId"] assert first_secret_version_id != second_secret_version_id @mock_secretsmanager def test_can_list_secret_version_ids(): - backend = server.create_backend_app('secretsmanager') + backend = server.create_backend_app("secretsmanager") test_client = backend.test_client() - put_first_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "secret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) - first_secret_json_data = json.loads(put_first_secret_value_json.data.decode("utf-8")) - first_secret_version_id = first_secret_json_data['VersionId'] - put_second_secret_value_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, - "SecretString": "secret", - "VersionStages": ["AWSCURRENT"]}, - headers={ - "X-Amz-Target": "secretsmanager.PutSecretValue"}, - ) - second_secret_json_data = json.loads(put_second_secret_value_json.data.decode("utf-8")) - second_secret_version_id = second_secret_json_data['VersionId'] + put_first_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "secret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) + first_secret_json_data = json.loads( + put_first_secret_value_json.data.decode("utf-8") + ) + first_secret_version_id = first_secret_json_data["VersionId"] + put_second_secret_value_json = test_client.post( + "/", + data={ + "SecretId": DEFAULT_SECRET_NAME, + "SecretString": "secret", + "VersionStages": ["AWSCURRENT"], + }, + headers={"X-Amz-Target": "secretsmanager.PutSecretValue"}, + ) + second_secret_json_data = json.loads( + put_second_secret_value_json.data.decode("utf-8") + ) + second_secret_version_id = second_secret_json_data["VersionId"] - list_secret_versions_json = test_client.post('/', - data={ - "SecretId": DEFAULT_SECRET_NAME, }, - headers={ - "X-Amz-Target": "secretsmanager.ListSecretVersionIds"}, - ) + list_secret_versions_json = test_client.post( + "/", + data={"SecretId": DEFAULT_SECRET_NAME}, + headers={"X-Amz-Target": "secretsmanager.ListSecretVersionIds"}, + ) versions_list = json.loads(list_secret_versions_json.data.decode("utf-8")) - returned_version_ids = [v['VersionId'] for v in versions_list['Versions']] + returned_version_ids = [v["VersionId"] for v in versions_list["Versions"]] + + assert [ + first_secret_version_id, + second_secret_version_id, + ].sort() == returned_version_ids.sort() - assert [first_secret_version_id, second_secret_version_id].sort() == returned_version_ids.sort() # # The following tests should work, but fail on the embedded dict in # RotationRules. The error message suggests a problem deeper in the code, which # needs further investigation. -# +# # @mock_secretsmanager # def test_rotate_secret_rotation_period_zero(): diff --git a/tests/test_ses/test_server.py b/tests/test_ses/test_server.py index 6af656000..b9d2252ce 100644 --- a/tests/test_ses/test_server.py +++ b/tests/test_ses/test_server.py @@ -3,14 +3,14 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_ses_list_identities(): backend = server.create_backend_app("ses") test_client = backend.test_client() - res = test_client.get('/?Action=ListIdentities') + res = test_client.get("/?Action=ListIdentities") res.data.should.contain(b"ListIdentitiesResponse") diff --git a/tests/test_ses/test_ses.py b/tests/test_ses/test_ses.py index 431d42e1d..851327b9d 100644 --- a/tests/test_ses/test_ses.py +++ b/tests/test_ses/test_ses.py @@ -11,106 +11,119 @@ from moto import mock_ses_deprecated @mock_ses_deprecated def test_verify_email_identity(): - conn = boto.connect_ses('the_key', 'the_secret') + conn = boto.connect_ses("the_key", "the_secret") conn.verify_email_identity("test@example.com") identities = conn.list_identities() - address = identities['ListIdentitiesResponse'][ - 'ListIdentitiesResult']['Identities'][0] - address.should.equal('test@example.com') + address = identities["ListIdentitiesResponse"]["ListIdentitiesResult"][ + "Identities" + ][0] + address.should.equal("test@example.com") @mock_ses_deprecated def test_domain_verify(): - conn = boto.connect_ses('the_key', 'the_secret') + conn = boto.connect_ses("the_key", "the_secret") conn.verify_domain_dkim("domain1.com") conn.verify_domain_identity("domain2.com") identities = conn.list_identities() - domains = list(identities['ListIdentitiesResponse'][ - 'ListIdentitiesResult']['Identities']) - domains.should.equal(['domain1.com', 'domain2.com']) + domains = list( + identities["ListIdentitiesResponse"]["ListIdentitiesResult"]["Identities"] + ) + domains.should.equal(["domain1.com", "domain2.com"]) @mock_ses_deprecated def test_delete_identity(): - conn = boto.connect_ses('the_key', 'the_secret') + conn = boto.connect_ses("the_key", "the_secret") conn.verify_email_identity("test@example.com") - conn.list_identities()['ListIdentitiesResponse']['ListIdentitiesResult'][ - 'Identities'].should.have.length_of(1) + conn.list_identities()["ListIdentitiesResponse"]["ListIdentitiesResult"][ + "Identities" + ].should.have.length_of(1) conn.delete_identity("test@example.com") - conn.list_identities()['ListIdentitiesResponse']['ListIdentitiesResult'][ - 'Identities'].should.have.length_of(0) + conn.list_identities()["ListIdentitiesResponse"]["ListIdentitiesResult"][ + "Identities" + ].should.have.length_of(0) @mock_ses_deprecated def test_send_email(): - conn = boto.connect_ses('the_key', 'the_secret') + conn = boto.connect_ses("the_key", "the_secret") conn.send_email.when.called_with( - "test@example.com", "test subject", - "test body", "test_to@example.com").should.throw(BotoServerError) + "test@example.com", "test subject", "test body", "test_to@example.com" + ).should.throw(BotoServerError) conn.verify_email_identity("test@example.com") - conn.send_email("test@example.com", "test subject", - "test body", "test_to@example.com") + conn.send_email( + "test@example.com", "test subject", "test body", "test_to@example.com" + ) send_quota = conn.get_send_quota() - sent_count = int(send_quota['GetSendQuotaResponse'][ - 'GetSendQuotaResult']['SentLast24Hours']) + sent_count = int( + send_quota["GetSendQuotaResponse"]["GetSendQuotaResult"]["SentLast24Hours"] + ) sent_count.should.equal(1) @mock_ses_deprecated def test_send_html_email(): - conn = boto.connect_ses('the_key', 'the_secret') + conn = boto.connect_ses("the_key", "the_secret") conn.send_email.when.called_with( - "test@example.com", "test subject", - "test body", "test_to@example.com", format="html").should.throw(BotoServerError) + "test@example.com", + "test subject", + "test body", + "test_to@example.com", + format="html", + ).should.throw(BotoServerError) conn.verify_email_identity("test@example.com") - conn.send_email("test@example.com", "test subject", - "test body", "test_to@example.com", format="html") + conn.send_email( + "test@example.com", + "test subject", + "test body", + "test_to@example.com", + format="html", + ) send_quota = conn.get_send_quota() - sent_count = int(send_quota['GetSendQuotaResponse'][ - 'GetSendQuotaResult']['SentLast24Hours']) + sent_count = int( + send_quota["GetSendQuotaResponse"]["GetSendQuotaResult"]["SentLast24Hours"] + ) sent_count.should.equal(1) @mock_ses_deprecated def test_send_raw_email(): - conn = boto.connect_ses('the_key', 'the_secret') + conn = boto.connect_ses("the_key", "the_secret") message = email.mime.multipart.MIMEMultipart() - message['Subject'] = 'Test' - message['From'] = 'test@example.com' - message['To'] = 'to@example.com' + message["Subject"] = "Test" + message["From"] = "test@example.com" + message["To"] = "to@example.com" # Message body - part = email.mime.text.MIMEText('test file attached') + part = email.mime.text.MIMEText("test file attached") message.attach(part) # Attachment - part = email.mime.text.MIMEText('contents of test file here') - part.add_header('Content-Disposition', 'attachment; filename=test.txt') + part = email.mime.text.MIMEText("contents of test file here") + part.add_header("Content-Disposition", "attachment; filename=test.txt") message.attach(part) conn.send_raw_email.when.called_with( - source=message['From'], - raw_message=message.as_string(), + source=message["From"], raw_message=message.as_string() ).should.throw(BotoServerError) conn.verify_email_identity("test@example.com") - conn.send_raw_email( - source=message['From'], - raw_message=message.as_string(), - ) + conn.send_raw_email(source=message["From"], raw_message=message.as_string()) send_quota = conn.get_send_quota() - sent_count = int(send_quota['GetSendQuotaResponse'][ - 'GetSendQuotaResult']['SentLast24Hours']) + sent_count = int( + send_quota["GetSendQuotaResponse"]["GetSendQuotaResult"]["SentLast24Hours"] + ) sent_count.should.equal(1) diff --git a/tests/test_ses/test_ses_boto3.py b/tests/test_ses/test_ses_boto3.py index fa042164d..ee7c92aa1 100644 --- a/tests/test_ses/test_ses_boto3.py +++ b/tests/test_ses/test_ses_boto3.py @@ -12,46 +12,48 @@ from moto import mock_ses @mock_ses def test_verify_email_identity(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") conn.verify_email_identity(EmailAddress="test@example.com") identities = conn.list_identities() - address = identities['Identities'][0] - address.should.equal('test@example.com') + address = identities["Identities"][0] + address.should.equal("test@example.com") + @mock_ses def test_verify_email_address(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") conn.verify_email_address(EmailAddress="test@example.com") email_addresses = conn.list_verified_email_addresses() - email = email_addresses['VerifiedEmailAddresses'][0] - email.should.equal('test@example.com') + email = email_addresses["VerifiedEmailAddresses"][0] + email.should.equal("test@example.com") + @mock_ses def test_domain_verify(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") conn.verify_domain_dkim(Domain="domain1.com") conn.verify_domain_identity(Domain="domain2.com") identities = conn.list_identities() - domains = list(identities['Identities']) - domains.should.equal(['domain1.com', 'domain2.com']) + domains = list(identities["Identities"]) + domains.should.equal(["domain1.com", "domain2.com"]) @mock_ses def test_delete_identity(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") conn.verify_email_identity(EmailAddress="test@example.com") - conn.list_identities()['Identities'].should.have.length_of(1) + conn.list_identities()["Identities"].should.have.length_of(1) conn.delete_identity(Identity="test@example.com") - conn.list_identities()['Identities'].should.have.length_of(0) + conn.list_identities()["Identities"].should.have.length_of(0) @mock_ses def test_send_email(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") kwargs = dict( Source="test@example.com", @@ -62,27 +64,27 @@ def test_send_email(): }, Message={ "Subject": {"Data": "test subject"}, - "Body": {"Text": {"Data": "test body"}} - } + "Body": {"Text": {"Data": "test body"}}, + }, ) conn.send_email.when.called_with(**kwargs).should.throw(ClientError) - conn.verify_domain_identity(Domain='example.com') + conn.verify_domain_identity(Domain="example.com") conn.send_email(**kwargs) - too_many_addresses = list('to%s@example.com' % i for i in range(51)) + too_many_addresses = list("to%s@example.com" % i for i in range(51)) conn.send_email.when.called_with( - **dict(kwargs, Destination={'ToAddresses': too_many_addresses}) + **dict(kwargs, Destination={"ToAddresses": too_many_addresses}) ).should.throw(ClientError) send_quota = conn.get_send_quota() - sent_count = int(send_quota['SentLast24Hours']) + sent_count = int(send_quota["SentLast24Hours"]) sent_count.should.equal(3) @mock_ses def test_send_templated_email(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") kwargs = dict( Source="test@example.com", @@ -92,38 +94,35 @@ def test_send_templated_email(): "BccAddresses": ["test_bcc@example.com"], }, Template="test_template", - TemplateData='{\"name\": \"test\"}' + TemplateData='{"name": "test"}', ) - conn.send_templated_email.when.called_with( - **kwargs).should.throw(ClientError) + conn.send_templated_email.when.called_with(**kwargs).should.throw(ClientError) - conn.verify_domain_identity(Domain='example.com') + conn.verify_domain_identity(Domain="example.com") conn.send_templated_email(**kwargs) - too_many_addresses = list('to%s@example.com' % i for i in range(51)) + too_many_addresses = list("to%s@example.com" % i for i in range(51)) conn.send_templated_email.when.called_with( - **dict(kwargs, Destination={'ToAddresses': too_many_addresses}) + **dict(kwargs, Destination={"ToAddresses": too_many_addresses}) ).should.throw(ClientError) send_quota = conn.get_send_quota() - sent_count = int(send_quota['SentLast24Hours']) + sent_count = int(send_quota["SentLast24Hours"]) sent_count.should.equal(3) @mock_ses def test_send_html_email(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") kwargs = dict( Source="test@example.com", - Destination={ - "ToAddresses": ["test_to@example.com"] - }, + Destination={"ToAddresses": ["test_to@example.com"]}, Message={ "Subject": {"Data": "test subject"}, - "Body": {"Html": {"Data": "test body"}} - } + "Body": {"Html": {"Data": "test body"}}, + }, ) conn.send_email.when.called_with(**kwargs).should.throw(ClientError) @@ -132,32 +131,29 @@ def test_send_html_email(): conn.send_email(**kwargs) send_quota = conn.get_send_quota() - sent_count = int(send_quota['SentLast24Hours']) + sent_count = int(send_quota["SentLast24Hours"]) sent_count.should.equal(1) @mock_ses def test_send_raw_email(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") message = MIMEMultipart() - message['Subject'] = 'Test' - message['From'] = 'test@example.com' - message['To'] = 'to@example.com, foo@example.com' + message["Subject"] = "Test" + message["From"] = "test@example.com" + message["To"] = "to@example.com, foo@example.com" # Message body - part = MIMEText('test file attached') + part = MIMEText("test file attached") message.attach(part) # Attachment - part = MIMEText('contents of test file here') - part.add_header('Content-Disposition', 'attachment; filename=test.txt') + part = MIMEText("contents of test file here") + part.add_header("Content-Disposition", "attachment; filename=test.txt") message.attach(part) - kwargs = dict( - Source=message['From'], - RawMessage={'Data': message.as_string()}, - ) + kwargs = dict(Source=message["From"], RawMessage={"Data": message.as_string()}) conn.send_raw_email.when.called_with(**kwargs).should.throw(ClientError) @@ -165,31 +161,29 @@ def test_send_raw_email(): conn.send_raw_email(**kwargs) send_quota = conn.get_send_quota() - sent_count = int(send_quota['SentLast24Hours']) + sent_count = int(send_quota["SentLast24Hours"]) sent_count.should.equal(2) @mock_ses def test_send_raw_email_without_source(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") message = MIMEMultipart() - message['Subject'] = 'Test' - message['From'] = 'test@example.com' - message['To'] = 'to@example.com, foo@example.com' + message["Subject"] = "Test" + message["From"] = "test@example.com" + message["To"] = "to@example.com, foo@example.com" # Message body - part = MIMEText('test file attached') + part = MIMEText("test file attached") message.attach(part) # Attachment - part = MIMEText('contents of test file here') - part.add_header('Content-Disposition', 'attachment; filename=test.txt') + part = MIMEText("contents of test file here") + part.add_header("Content-Disposition", "attachment; filename=test.txt") message.attach(part) - kwargs = dict( - RawMessage={'Data': message.as_string()}, - ) + kwargs = dict(RawMessage={"Data": message.as_string()}) conn.send_raw_email.when.called_with(**kwargs).should.throw(ClientError) @@ -197,29 +191,26 @@ def test_send_raw_email_without_source(): conn.send_raw_email(**kwargs) send_quota = conn.get_send_quota() - sent_count = int(send_quota['SentLast24Hours']) + sent_count = int(send_quota["SentLast24Hours"]) sent_count.should.equal(2) @mock_ses def test_send_raw_email_without_source_or_from(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") message = MIMEMultipart() - message['Subject'] = 'Test' - message['To'] = 'to@example.com, foo@example.com' + message["Subject"] = "Test" + message["To"] = "to@example.com, foo@example.com" # Message body - part = MIMEText('test file attached') + part = MIMEText("test file attached") message.attach(part) # Attachment - part = MIMEText('contents of test file here') - part.add_header('Content-Disposition', 'attachment; filename=test.txt') + part = MIMEText("contents of test file here") + part.add_header("Content-Disposition", "attachment; filename=test.txt") message.attach(part) - kwargs = dict( - RawMessage={'Data': message.as_string()}, - ) + kwargs = dict(RawMessage={"Data": message.as_string()}) conn.send_raw_email.when.called_with(**kwargs).should.throw(ClientError) - diff --git a/tests/test_ses/test_ses_sns_boto3.py b/tests/test_ses/test_ses_sns_boto3.py index 37f79a8b0..a55c150ff 100644 --- a/tests/test_ses/test_ses_sns_boto3.py +++ b/tests/test_ses/test_ses_sns_boto3.py @@ -14,19 +14,16 @@ from moto.ses.models import SESFeedback @mock_ses def test_enable_disable_ses_sns_communication(): - conn = boto3.client('ses', region_name='us-east-1') + conn = boto3.client("ses", region_name="us-east-1") conn.set_identity_notification_topic( - Identity='test.com', - NotificationType='Bounce', - SnsTopic='the-arn' - ) - conn.set_identity_notification_topic( - Identity='test.com', - NotificationType='Bounce' + Identity="test.com", NotificationType="Bounce", SnsTopic="the-arn" ) + conn.set_identity_notification_topic(Identity="test.com", NotificationType="Bounce") -def __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region, expected_msg): +def __setup_feedback_env__( + ses_conn, sns_conn, sqs_conn, domain, topic, queue, region, expected_msg +): """Setup the AWS environment to test the SES SNS Feedback""" # Environment setup # Create SQS queue @@ -35,30 +32,32 @@ def __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, r create_topic_response = sns_conn.create_topic(Name=topic) topic_arn = create_topic_response["TopicArn"] # Subscribe the SNS topic to the SQS queue - sns_conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:%s:123456789012:%s" % (region, queue)) + sns_conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:%s:123456789012:%s" % (region, queue), + ) # Verify SES domain ses_conn.verify_domain_identity(Domain=domain) # Setup SES notification topic if expected_msg is not None: ses_conn.set_identity_notification_topic( - Identity=domain, - NotificationType=expected_msg, - SnsTopic=topic_arn + Identity=domain, NotificationType=expected_msg, SnsTopic=topic_arn ) def __test_sns_feedback__(addr, expected_msg): region_name = "us-east-1" - ses_conn = boto3.client('ses', region_name=region_name) - sns_conn = boto3.client('sns', region_name=region_name) - sqs_conn = boto3.resource('sqs', region_name=region_name) + ses_conn = boto3.client("ses", region_name=region_name) + sns_conn = boto3.client("sns", region_name=region_name) + sqs_conn = boto3.resource("sqs", region_name=region_name) domain = "example.com" topic = "bounce-arn-feedback" queue = "feedback-test-queue" - __setup_feedback_env__(ses_conn, sns_conn, sqs_conn, domain, topic, queue, region_name, expected_msg) + __setup_feedback_env__( + ses_conn, sns_conn, sqs_conn, domain, topic, queue, region_name, expected_msg + ) # Send the message kwargs = dict( @@ -70,8 +69,8 @@ def __test_sns_feedback__(addr, expected_msg): }, Message={ "Subject": {"Data": "test subject"}, - "Body": {"Text": {"Data": "test body"}} - } + "Body": {"Text": {"Data": "test body"}}, + }, ) ses_conn.send_email(**kwargs) diff --git a/tests/test_sns/test_application.py b/tests/test_sns/test_application.py index 319e4a6f8..efa5e0f3e 100644 --- a/tests/test_sns/test_application.py +++ b/tests/test_sns/test_application.py @@ -17,10 +17,12 @@ def test_create_platform_application(): "PlatformPrincipal": "platform_principal", }, ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] application_arn.should.equal( - 'arn:aws:sns:us-east-1:123456789012:app/APNS/my-application') + "arn:aws:sns:us-east-1:123456789012:app/APNS/my-application" + ) @mock_sns_deprecated @@ -34,21 +36,26 @@ def test_get_platform_application_attributes(): "PlatformPrincipal": "platform_principal", }, ) - arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] - attributes = conn.get_platform_application_attributes(arn)['GetPlatformApplicationAttributesResponse'][ - 'GetPlatformApplicationAttributesResult']['Attributes'] - attributes.should.equal({ - "PlatformCredential": "platform_credential", - "PlatformPrincipal": "platform_principal", - }) + arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] + attributes = conn.get_platform_application_attributes(arn)[ + "GetPlatformApplicationAttributesResponse" + ]["GetPlatformApplicationAttributesResult"]["Attributes"] + attributes.should.equal( + { + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "platform_principal", + } + ) @mock_sns_deprecated def test_get_missing_platform_application_attributes(): conn = boto.connect_sns() conn.get_platform_application_attributes.when.called_with( - "a-fake-arn").should.throw(BotoServerError) + "a-fake-arn" + ).should.throw(BotoServerError) @mock_sns_deprecated @@ -62,60 +69,50 @@ def test_set_platform_application_attributes(): "PlatformPrincipal": "platform_principal", }, ) - arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] - conn.set_platform_application_attributes(arn, - {"PlatformPrincipal": "other"} - ) - attributes = conn.get_platform_application_attributes(arn)['GetPlatformApplicationAttributesResponse'][ - 'GetPlatformApplicationAttributesResult']['Attributes'] - attributes.should.equal({ - "PlatformCredential": "platform_credential", - "PlatformPrincipal": "other", - }) + arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] + conn.set_platform_application_attributes(arn, {"PlatformPrincipal": "other"}) + attributes = conn.get_platform_application_attributes(arn)[ + "GetPlatformApplicationAttributesResponse" + ]["GetPlatformApplicationAttributesResult"]["Attributes"] + attributes.should.equal( + {"PlatformCredential": "platform_credential", "PlatformPrincipal": "other"} + ) @mock_sns_deprecated def test_list_platform_applications(): conn = boto.connect_sns() - conn.create_platform_application( - name="application1", - platform="APNS", - ) - conn.create_platform_application( - name="application2", - platform="APNS", - ) + conn.create_platform_application(name="application1", platform="APNS") + conn.create_platform_application(name="application2", platform="APNS") applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['ListPlatformApplicationsResponse'][ - 'ListPlatformApplicationsResult']['PlatformApplications'] + applications = applications_repsonse["ListPlatformApplicationsResponse"][ + "ListPlatformApplicationsResult" + ]["PlatformApplications"] applications.should.have.length_of(2) @mock_sns_deprecated def test_delete_platform_application(): conn = boto.connect_sns() - conn.create_platform_application( - name="application1", - platform="APNS", - ) - conn.create_platform_application( - name="application2", - platform="APNS", - ) + conn.create_platform_application(name="application1", platform="APNS") + conn.create_platform_application(name="application2", platform="APNS") applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['ListPlatformApplicationsResponse'][ - 'ListPlatformApplicationsResult']['PlatformApplications'] + applications = applications_repsonse["ListPlatformApplicationsResponse"][ + "ListPlatformApplicationsResult" + ]["PlatformApplications"] applications.should.have.length_of(2) - application_arn = applications[0]['PlatformApplicationArn'] + application_arn = applications[0]["PlatformApplicationArn"] conn.delete_platform_application(application_arn) applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['ListPlatformApplicationsResponse'][ - 'ListPlatformApplicationsResult']['PlatformApplications'] + applications = applications_repsonse["ListPlatformApplicationsResponse"][ + "ListPlatformApplicationsResult" + ]["PlatformApplications"] applications.should.have.length_of(1) @@ -123,154 +120,152 @@ def test_delete_platform_application(): def test_create_platform_endpoint(): conn = boto.connect_sns() platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", + name="my-application", platform="APNS" ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( platform_application_arn=application_arn, token="some_unique_id", custom_user_data="some user data", - attributes={ - "Enabled": False, - }, + attributes={"Enabled": False}, ) - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] endpoint_arn.should.contain( - "arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/") + "arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/" + ) @mock_sns_deprecated def test_get_list_endpoints_by_platform_application(): conn = boto.connect_sns() platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", + name="my-application", platform="APNS" ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( platform_application_arn=application_arn, token="some_unique_id", custom_user_data="some user data", - attributes={ - "CustomUserData": "some data", - }, + attributes={"CustomUserData": "some data"}, ) - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] endpoint_list = conn.list_endpoints_by_platform_application( platform_application_arn=application_arn - )['ListEndpointsByPlatformApplicationResponse']['ListEndpointsByPlatformApplicationResult']['Endpoints'] + )["ListEndpointsByPlatformApplicationResponse"][ + "ListEndpointsByPlatformApplicationResult" + ][ + "Endpoints" + ] endpoint_list.should.have.length_of(1) - endpoint_list[0]['Attributes']['CustomUserData'].should.equal('some data') - endpoint_list[0]['EndpointArn'].should.equal(endpoint_arn) + endpoint_list[0]["Attributes"]["CustomUserData"].should.equal("some data") + endpoint_list[0]["EndpointArn"].should.equal(endpoint_arn) @mock_sns_deprecated def test_get_endpoint_attributes(): conn = boto.connect_sns() platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", + name="my-application", platform="APNS" ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( platform_application_arn=application_arn, token="some_unique_id", custom_user_data="some user data", - attributes={ - "Enabled": False, - "CustomUserData": "some data", - }, + attributes={"Enabled": False, "CustomUserData": "some data"}, ) - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] - attributes = conn.get_endpoint_attributes(endpoint_arn)['GetEndpointAttributesResponse'][ - 'GetEndpointAttributesResult']['Attributes'] - attributes.should.equal({ - "Token": "some_unique_id", - "Enabled": 'False', - "CustomUserData": "some data", - }) + attributes = conn.get_endpoint_attributes(endpoint_arn)[ + "GetEndpointAttributesResponse" + ]["GetEndpointAttributesResult"]["Attributes"] + attributes.should.equal( + {"Token": "some_unique_id", "Enabled": "False", "CustomUserData": "some data"} + ) @mock_sns_deprecated def test_get_missing_endpoint_attributes(): conn = boto.connect_sns() - conn.get_endpoint_attributes.when.called_with( - "a-fake-arn").should.throw(BotoServerError) + conn.get_endpoint_attributes.when.called_with("a-fake-arn").should.throw( + BotoServerError + ) @mock_sns_deprecated def test_set_endpoint_attributes(): conn = boto.connect_sns() platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", + name="my-application", platform="APNS" ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( platform_application_arn=application_arn, token="some_unique_id", custom_user_data="some user data", - attributes={ - "Enabled": False, - "CustomUserData": "some data", - }, + attributes={"Enabled": False, "CustomUserData": "some data"}, ) - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] - conn.set_endpoint_attributes(endpoint_arn, - {"CustomUserData": "other data"} - ) - attributes = conn.get_endpoint_attributes(endpoint_arn)['GetEndpointAttributesResponse'][ - 'GetEndpointAttributesResult']['Attributes'] - attributes.should.equal({ - "Token": "some_unique_id", - "Enabled": 'False', - "CustomUserData": "other data", - }) + conn.set_endpoint_attributes(endpoint_arn, {"CustomUserData": "other data"}) + attributes = conn.get_endpoint_attributes(endpoint_arn)[ + "GetEndpointAttributesResponse" + ]["GetEndpointAttributesResult"]["Attributes"] + attributes.should.equal( + {"Token": "some_unique_id", "Enabled": "False", "CustomUserData": "other data"} + ) @mock_sns_deprecated def test_delete_endpoint(): conn = boto.connect_sns() platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", + name="my-application", platform="APNS" ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( platform_application_arn=application_arn, token="some_unique_id", custom_user_data="some user data", - attributes={ - "Enabled": False, - "CustomUserData": "some data", - }, + attributes={"Enabled": False, "CustomUserData": "some data"}, ) - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] endpoint_list = conn.list_endpoints_by_platform_application( platform_application_arn=application_arn - )['ListEndpointsByPlatformApplicationResponse']['ListEndpointsByPlatformApplicationResult']['Endpoints'] + )["ListEndpointsByPlatformApplicationResponse"][ + "ListEndpointsByPlatformApplicationResult" + ][ + "Endpoints" + ] endpoint_list.should.have.length_of(1) @@ -278,7 +273,11 @@ def test_delete_endpoint(): endpoint_list = conn.list_endpoints_by_platform_application( platform_application_arn=application_arn - )['ListEndpointsByPlatformApplicationResponse']['ListEndpointsByPlatformApplicationResult']['Endpoints'] + )["ListEndpointsByPlatformApplicationResponse"][ + "ListEndpointsByPlatformApplicationResult" + ][ + "Endpoints" + ] endpoint_list.should.have.length_of(0) @@ -286,23 +285,23 @@ def test_delete_endpoint(): def test_publish_to_platform_endpoint(): conn = boto.connect_sns() platform_application = conn.create_platform_application( - name="my-application", - platform="APNS", + name="my-application", platform="APNS" ) - application_arn = platform_application['CreatePlatformApplicationResponse'][ - 'CreatePlatformApplicationResult']['PlatformApplicationArn'] + application_arn = platform_application["CreatePlatformApplicationResponse"][ + "CreatePlatformApplicationResult" + ]["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( platform_application_arn=application_arn, token="some_unique_id", custom_user_data="some user data", - attributes={ - "Enabled": True, - }, + attributes={"Enabled": True}, ) - endpoint_arn = endpoint['CreatePlatformEndpointResponse'][ - 'CreatePlatformEndpointResult']['EndpointArn'] + endpoint_arn = endpoint["CreatePlatformEndpointResponse"][ + "CreatePlatformEndpointResult" + ]["EndpointArn"] - conn.publish(message="some message", message_structure="json", - target_arn=endpoint_arn) + conn.publish( + message="some message", message_structure="json", target_arn=endpoint_arn + ) diff --git a/tests/test_sns/test_application_boto3.py b/tests/test_sns/test_application_boto3.py index 1c9695fea..6f683b051 100644 --- a/tests/test_sns/test_application_boto3.py +++ b/tests/test_sns/test_application_boto3.py @@ -8,7 +8,7 @@ import sure # noqa @mock_sns def test_create_platform_application(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") response = conn.create_platform_application( Name="my-application", Platform="APNS", @@ -17,14 +17,15 @@ def test_create_platform_application(): "PlatformPrincipal": "platform_principal", }, ) - application_arn = response['PlatformApplicationArn'] + application_arn = response["PlatformApplicationArn"] application_arn.should.equal( - 'arn:aws:sns:us-east-1:123456789012:app/APNS/my-application') + "arn:aws:sns:us-east-1:123456789012:app/APNS/my-application" + ) @mock_sns def test_get_platform_application_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( Name="my-application", Platform="APNS", @@ -33,25 +34,29 @@ def test_get_platform_application_attributes(): "PlatformPrincipal": "platform_principal", }, ) - arn = platform_application['PlatformApplicationArn'] - attributes = conn.get_platform_application_attributes( - PlatformApplicationArn=arn)['Attributes'] - attributes.should.equal({ - "PlatformCredential": "platform_credential", - "PlatformPrincipal": "platform_principal", - }) + arn = platform_application["PlatformApplicationArn"] + attributes = conn.get_platform_application_attributes(PlatformApplicationArn=arn)[ + "Attributes" + ] + attributes.should.equal( + { + "PlatformCredential": "platform_credential", + "PlatformPrincipal": "platform_principal", + } + ) @mock_sns def test_get_missing_platform_application_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.get_platform_application_attributes.when.called_with( - PlatformApplicationArn="a-fake-arn").should.throw(ClientError) + PlatformApplicationArn="a-fake-arn" + ).should.throw(ClientError) @mock_sns def test_set_platform_application_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( Name="my-application", Platform="APNS", @@ -60,291 +65,249 @@ def test_set_platform_application_attributes(): "PlatformPrincipal": "platform_principal", }, ) - arn = platform_application['PlatformApplicationArn'] - conn.set_platform_application_attributes(PlatformApplicationArn=arn, - Attributes={ - "PlatformPrincipal": "other"} - ) - attributes = conn.get_platform_application_attributes( - PlatformApplicationArn=arn)['Attributes'] - attributes.should.equal({ - "PlatformCredential": "platform_credential", - "PlatformPrincipal": "other", - }) + arn = platform_application["PlatformApplicationArn"] + conn.set_platform_application_attributes( + PlatformApplicationArn=arn, Attributes={"PlatformPrincipal": "other"} + ) + attributes = conn.get_platform_application_attributes(PlatformApplicationArn=arn)[ + "Attributes" + ] + attributes.should.equal( + {"PlatformCredential": "platform_credential", "PlatformPrincipal": "other"} + ) @mock_sns def test_list_platform_applications(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_platform_application( - Name="application1", - Platform="APNS", - Attributes={}, + Name="application1", Platform="APNS", Attributes={} ) conn.create_platform_application( - Name="application2", - Platform="APNS", - Attributes={}, + Name="application2", Platform="APNS", Attributes={} ) applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['PlatformApplications'] + applications = applications_repsonse["PlatformApplications"] applications.should.have.length_of(2) @mock_sns def test_delete_platform_application(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_platform_application( - Name="application1", - Platform="APNS", - Attributes={}, + Name="application1", Platform="APNS", Attributes={} ) conn.create_platform_application( - Name="application2", - Platform="APNS", - Attributes={}, + Name="application2", Platform="APNS", Attributes={} ) applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['PlatformApplications'] + applications = applications_repsonse["PlatformApplications"] applications.should.have.length_of(2) - application_arn = applications[0]['PlatformApplicationArn'] + application_arn = applications[0]["PlatformApplicationArn"] conn.delete_platform_application(PlatformApplicationArn=application_arn) applications_repsonse = conn.list_platform_applications() - applications = applications_repsonse['PlatformApplications'] + applications = applications_repsonse["PlatformApplications"] applications.should.have.length_of(1) @mock_sns def test_create_platform_endpoint(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - }, + Attributes={"Enabled": "false"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] endpoint_arn.should.contain( - "arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/") + "arn:aws:sns:us-east-1:123456789012:endpoint/APNS/my-application/" + ) @mock_sns def test_create_duplicate_platform_endpoint(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - }, + Attributes={"Enabled": "false"}, ) endpoint = conn.create_platform_endpoint.when.called_with( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - }, + Attributes={"Enabled": "false"}, ).should.throw(ClientError) @mock_sns def test_get_list_endpoints_by_platform_application(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "CustomUserData": "some data", - }, + Attributes={"CustomUserData": "some data"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] endpoint_list = conn.list_endpoints_by_platform_application( PlatformApplicationArn=application_arn - )['Endpoints'] + )["Endpoints"] endpoint_list.should.have.length_of(1) - endpoint_list[0]['Attributes']['CustomUserData'].should.equal('some data') - endpoint_list[0]['EndpointArn'].should.equal(endpoint_arn) + endpoint_list[0]["Attributes"]["CustomUserData"].should.equal("some data") + endpoint_list[0]["EndpointArn"].should.equal(endpoint_arn) @mock_sns def test_get_endpoint_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - "CustomUserData": "some data", - }, + Attributes={"Enabled": "false", "CustomUserData": "some data"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] - attributes = conn.get_endpoint_attributes( - EndpointArn=endpoint_arn)['Attributes'] - attributes.should.equal({ - "Token": "some_unique_id", - "Enabled": 'false', - "CustomUserData": "some data", - }) + attributes = conn.get_endpoint_attributes(EndpointArn=endpoint_arn)["Attributes"] + attributes.should.equal( + {"Token": "some_unique_id", "Enabled": "false", "CustomUserData": "some data"} + ) @mock_sns def test_get_missing_endpoint_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.get_endpoint_attributes.when.called_with( - EndpointArn="a-fake-arn").should.throw(ClientError) + EndpointArn="a-fake-arn" + ).should.throw(ClientError) @mock_sns def test_set_endpoint_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - "CustomUserData": "some data", - }, + Attributes={"Enabled": "false", "CustomUserData": "some data"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] - conn.set_endpoint_attributes(EndpointArn=endpoint_arn, - Attributes={"CustomUserData": "other data"} - ) - attributes = conn.get_endpoint_attributes( - EndpointArn=endpoint_arn)['Attributes'] - attributes.should.equal({ - "Token": "some_unique_id", - "Enabled": 'false', - "CustomUserData": "other data", - }) + conn.set_endpoint_attributes( + EndpointArn=endpoint_arn, Attributes={"CustomUserData": "other data"} + ) + attributes = conn.get_endpoint_attributes(EndpointArn=endpoint_arn)["Attributes"] + attributes.should.equal( + {"Token": "some_unique_id", "Enabled": "false", "CustomUserData": "other data"} + ) @mock_sns def test_publish_to_platform_endpoint(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'true', - }, + Attributes={"Enabled": "true"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] - conn.publish(Message="some message", - MessageStructure="json", TargetArn=endpoint_arn) + conn.publish( + Message="some message", MessageStructure="json", TargetArn=endpoint_arn + ) @mock_sns def test_publish_to_disabled_platform_endpoint(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") platform_application = conn.create_platform_application( - Name="my-application", - Platform="APNS", - Attributes={}, + Name="my-application", Platform="APNS", Attributes={} ) - application_arn = platform_application['PlatformApplicationArn'] + application_arn = platform_application["PlatformApplicationArn"] endpoint = conn.create_platform_endpoint( PlatformApplicationArn=application_arn, Token="some_unique_id", CustomUserData="some user data", - Attributes={ - "Enabled": 'false', - }, + Attributes={"Enabled": "false"}, ) - endpoint_arn = endpoint['EndpointArn'] + endpoint_arn = endpoint["EndpointArn"] conn.publish.when.called_with( - Message="some message", - MessageStructure="json", - TargetArn=endpoint_arn, + Message="some message", MessageStructure="json", TargetArn=endpoint_arn ).should.throw(ClientError) @mock_sns def test_set_sms_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") - conn.set_sms_attributes(attributes={'DefaultSMSType': 'Transactional', 'test': 'test'}) + conn.set_sms_attributes( + attributes={"DefaultSMSType": "Transactional", "test": "test"} + ) response = conn.get_sms_attributes() - response.should.contain('attributes') - response['attributes'].should.contain('DefaultSMSType') - response['attributes'].should.contain('test') - response['attributes']['DefaultSMSType'].should.equal('Transactional') - response['attributes']['test'].should.equal('test') + response.should.contain("attributes") + response["attributes"].should.contain("DefaultSMSType") + response["attributes"].should.contain("test") + response["attributes"]["DefaultSMSType"].should.equal("Transactional") + response["attributes"]["test"].should.equal("test") @mock_sns def test_get_sms_attributes_filtered(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") - conn.set_sms_attributes(attributes={'DefaultSMSType': 'Transactional', 'test': 'test'}) + conn.set_sms_attributes( + attributes={"DefaultSMSType": "Transactional", "test": "test"} + ) - response = conn.get_sms_attributes(attributes=['DefaultSMSType']) - response.should.contain('attributes') - response['attributes'].should.contain('DefaultSMSType') - response['attributes'].should_not.contain('test') - response['attributes']['DefaultSMSType'].should.equal('Transactional') + response = conn.get_sms_attributes(attributes=["DefaultSMSType"]) + response.should.contain("attributes") + response["attributes"].should.contain("DefaultSMSType") + response["attributes"].should_not.contain("test") + response["attributes"]["DefaultSMSType"].should.equal("Transactional") diff --git a/tests/test_sns/test_publishing.py b/tests/test_sns/test_publishing.py index 964296837..b45277bde 100644 --- a/tests/test_sns/test_publishing.py +++ b/tests/test_sns/test_publishing.py @@ -18,25 +18,38 @@ def test_publish_to_sqs(): conn = boto.connect_sns() conn.create_topic("some-topic") topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] sqs_conn = boto.connect_sqs() sqs_conn.create_queue("test-queue") - conn.subscribe(topic_arn, "sqs", - "arn:aws:sqs:us-east-1:123456789012:test-queue") + conn.subscribe(topic_arn, "sqs", "arn:aws:sqs:us-east-1:123456789012:test-queue") - message_to_publish = 'my message' + message_to_publish = "my message" subject_to_publish = "test subject" with freeze_time("2015-01-01 12:00:00"): - published_message = conn.publish(topic=topic_arn, message=message_to_publish, subject=subject_to_publish) - published_message_id = published_message['PublishResponse']['PublishResult']['MessageId'] + published_message = conn.publish( + topic=topic_arn, message=message_to_publish, subject=subject_to_publish + ) + published_message_id = published_message["PublishResponse"]["PublishResult"][ + "MessageId" + ] queue = sqs_conn.get_queue("test-queue") message = queue.read(1) - expected = MESSAGE_FROM_SQS_TEMPLATE % (message_to_publish, published_message_id, subject_to_publish, 'us-east-1') - acquired_message = re.sub("\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", '2015-01-01T12:00:00.000Z', message.get_body()) + expected = MESSAGE_FROM_SQS_TEMPLATE % ( + message_to_publish, + published_message_id, + subject_to_publish, + "us-east-1", + ) + acquired_message = re.sub( + "\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", + "2015-01-01T12:00:00.000Z", + message.get_body(), + ) acquired_message.should.equal(expected) @@ -46,24 +59,37 @@ def test_publish_to_sqs_in_different_region(): conn = boto.sns.connect_to_region("us-west-1") conn.create_topic("some-topic") topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] sqs_conn = boto.sqs.connect_to_region("us-west-2") sqs_conn.create_queue("test-queue") - conn.subscribe(topic_arn, "sqs", - "arn:aws:sqs:us-west-2:123456789012:test-queue") + conn.subscribe(topic_arn, "sqs", "arn:aws:sqs:us-west-2:123456789012:test-queue") - message_to_publish = 'my message' + message_to_publish = "my message" subject_to_publish = "test subject" with freeze_time("2015-01-01 12:00:00"): - published_message = conn.publish(topic=topic_arn, message=message_to_publish, subject=subject_to_publish) - published_message_id = published_message['PublishResponse']['PublishResult']['MessageId'] + published_message = conn.publish( + topic=topic_arn, message=message_to_publish, subject=subject_to_publish + ) + published_message_id = published_message["PublishResponse"]["PublishResult"][ + "MessageId" + ] queue = sqs_conn.get_queue("test-queue") message = queue.read(1) - expected = MESSAGE_FROM_SQS_TEMPLATE % (message_to_publish, published_message_id, subject_to_publish, 'us-west-1') + expected = MESSAGE_FROM_SQS_TEMPLATE % ( + message_to_publish, + published_message_id, + subject_to_publish, + "us-west-1", + ) - acquired_message = re.sub("\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", '2015-01-01T12:00:00.000Z', message.get_body()) + acquired_message = re.sub( + "\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", + "2015-01-01T12:00:00.000Z", + message.get_body(), + ) acquired_message.should.equal(expected) diff --git a/tests/test_sns/test_publishing_boto3.py b/tests/test_sns/test_publishing_boto3.py index d7bf32e51..64669d5e0 100644 --- a/tests/test_sns/test_publishing_boto3.py +++ b/tests/test_sns/test_publishing_boto3.py @@ -20,45 +20,53 @@ MESSAGE_FROM_SQS_TEMPLATE = '{\n "Message": "%s",\n "MessageId": "%s",\n "Sig @mock_sqs @mock_sns def test_publish_to_sqs(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] - sqs_conn = boto3.resource('sqs', region_name='us-east-1') + sqs_conn = boto3.resource("sqs", region_name="us-east-1") sqs_conn.create_queue(QueueName="test-queue") - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue") - message = 'my message' + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue", + ) + message = "my message" with freeze_time("2015-01-01 12:00:00"): published_message = conn.publish(TopicArn=topic_arn, Message=message) - published_message_id = published_message['MessageId'] + published_message_id = published_message["MessageId"] queue = sqs_conn.get_queue_by_name(QueueName="test-queue") messages = queue.receive_messages(MaxNumberOfMessages=1) - expected = MESSAGE_FROM_SQS_TEMPLATE % (message, published_message_id, 'us-east-1') - acquired_message = re.sub("\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", u'2015-01-01T12:00:00.000Z', messages[0].body) + expected = MESSAGE_FROM_SQS_TEMPLATE % (message, published_message_id, "us-east-1") + acquired_message = re.sub( + "\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", + "2015-01-01T12:00:00.000Z", + messages[0].body, + ) acquired_message.should.equal(expected) @mock_sqs @mock_sns def test_publish_to_sqs_raw(): - sns = boto3.resource('sns', region_name='us-east-1') - topic = sns.create_topic(Name='some-topic') + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic(Name="some-topic") - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName='test-queue') + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue(QueueName="test-queue") subscription = topic.subscribe( - Protocol='sqs', Endpoint=queue.attributes['QueueArn']) + Protocol="sqs", Endpoint=queue.attributes["QueueArn"] + ) subscription.set_attributes( - AttributeName='RawMessageDelivery', AttributeValue='true') + AttributeName="RawMessageDelivery", AttributeValue="true" + ) - message = 'my message' + message = "my message" with freeze_time("2015-01-01 12:00:00"): topic.publish(Message=message) @@ -69,195 +77,213 @@ def test_publish_to_sqs_raw(): @mock_sqs @mock_sns def test_publish_to_sqs_bad(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] - sqs_conn = boto3.resource('sqs', region_name='us-east-1') + sqs_conn = boto3.resource("sqs", region_name="us-east-1") sqs_conn.create_queue(QueueName="test-queue") - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue") - message = 'my message' + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue", + ) + message = "my message" try: # Test missing Value conn.publish( - TopicArn=topic_arn, Message=message, - MessageAttributes={'store': {'DataType': 'String'}}) + TopicArn=topic_arn, + Message=message, + MessageAttributes={"store": {"DataType": "String"}}, + ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") try: # Test empty DataType (if the DataType field is missing entirely # botocore throws an exception during validation) conn.publish( - TopicArn=topic_arn, Message=message, - MessageAttributes={'store': { - 'DataType': '', - 'StringValue': 'example_corp' - }}) + TopicArn=topic_arn, + Message=message, + MessageAttributes={ + "store": {"DataType": "", "StringValue": "example_corp"} + }, + ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") try: # Test empty Value conn.publish( - TopicArn=topic_arn, Message=message, - MessageAttributes={'store': { - 'DataType': 'String', - 'StringValue': '' - }}) + TopicArn=topic_arn, + Message=message, + MessageAttributes={"store": {"DataType": "String", "StringValue": ""}}, + ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") try: # Test Number DataType, with a non numeric value conn.publish( - TopicArn=topic_arn, Message=message, - MessageAttributes={'price': { - 'DataType': 'Number', - 'StringValue': 'error' - }}) + TopicArn=topic_arn, + Message=message, + MessageAttributes={"price": {"DataType": "Number", "StringValue": "error"}}, + ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') - err.response['Error']['Message'].should.equal("An error occurred (ParameterValueInvalid) when calling the Publish operation: Could not cast message attribute 'price' value to number.") + err.response["Error"]["Code"].should.equal("InvalidParameterValue") + err.response["Error"]["Message"].should.equal( + "An error occurred (ParameterValueInvalid) when calling the Publish operation: Could not cast message attribute 'price' value to number." + ) @mock_sqs @mock_sns def test_publish_to_sqs_msg_attr_byte_value(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] - sqs_conn = boto3.resource('sqs', region_name='us-east-1') + sqs_conn = boto3.resource("sqs", region_name="us-east-1") queue = sqs_conn.create_queue(QueueName="test-queue") - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue") - message = 'my message' + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue", + ) + message = "my message" conn.publish( - TopicArn=topic_arn, Message=message, - MessageAttributes={'store': { - 'DataType': 'Binary', - 'BinaryValue': b'\x02\x03\x04' - }}) + TopicArn=topic_arn, + Message=message, + MessageAttributes={ + "store": {"DataType": "Binary", "BinaryValue": b"\x02\x03\x04"} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([{ - 'store': { - 'Type': 'Binary', - 'Value': base64.b64encode(b'\x02\x03\x04').decode() - } - }]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "store": { + "Type": "Binary", + "Value": base64.b64encode(b"\x02\x03\x04").decode(), + } + } + ] + ) @mock_sns def test_publish_sms(): - client = boto3.client('sns', region_name='us-east-1') + client = boto3.client("sns", region_name="us-east-1") client.create_topic(Name="some-topic") resp = client.create_topic(Name="some-topic") - arn = resp['TopicArn'] + arn = resp["TopicArn"] - client.subscribe( - TopicArn=arn, - Protocol='sms', - Endpoint='+15551234567' - ) + client.subscribe(TopicArn=arn, Protocol="sms", Endpoint="+15551234567") result = client.publish(PhoneNumber="+15551234567", Message="my message") - result.should.contain('MessageId') + result.should.contain("MessageId") @mock_sns def test_publish_bad_sms(): - client = boto3.client('sns', region_name='us-east-1') + client = boto3.client("sns", region_name="us-east-1") client.create_topic(Name="some-topic") resp = client.create_topic(Name="some-topic") - arn = resp['TopicArn'] + arn = resp["TopicArn"] - client.subscribe( - TopicArn=arn, - Protocol='sms', - Endpoint='+15551234567' - ) + client.subscribe(TopicArn=arn, Protocol="sms", Endpoint="+15551234567") try: # Test invalid number client.publish(PhoneNumber="NAA+15551234567", Message="my message") except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameter') + err.response["Error"]["Code"].should.equal("InvalidParameter") try: # Test not found number client.publish(PhoneNumber="+44001234567", Message="my message") except ClientError as err: - err.response['Error']['Code'].should.equal('ParameterValueInvalid') + err.response["Error"]["Code"].should.equal("ParameterValueInvalid") @mock_sqs @mock_sns def test_publish_to_sqs_dump_json(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] - sqs_conn = boto3.resource('sqs', region_name='us-east-1') + sqs_conn = boto3.resource("sqs", region_name="us-east-1") sqs_conn.create_queue(QueueName="test-queue") - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue") + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue", + ) - message = json.dumps({ - "Records": [{ - "eventVersion": "2.0", - "eventSource": "aws:s3", - "s3": { - "s3SchemaVersion": "1.0" - } - }] - }, sort_keys=True) + message = json.dumps( + { + "Records": [ + { + "eventVersion": "2.0", + "eventSource": "aws:s3", + "s3": {"s3SchemaVersion": "1.0"}, + } + ] + }, + sort_keys=True, + ) with freeze_time("2015-01-01 12:00:00"): published_message = conn.publish(TopicArn=topic_arn, Message=message) - published_message_id = published_message['MessageId'] + published_message_id = published_message["MessageId"] queue = sqs_conn.get_queue_by_name(QueueName="test-queue") messages = queue.receive_messages(MaxNumberOfMessages=1) escaped = message.replace('"', '\\"') - expected = MESSAGE_FROM_SQS_TEMPLATE % (escaped, published_message_id, 'us-east-1') - acquired_message = re.sub("\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", u'2015-01-01T12:00:00.000Z', messages[0].body) + expected = MESSAGE_FROM_SQS_TEMPLATE % (escaped, published_message_id, "us-east-1") + acquired_message = re.sub( + "\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", + "2015-01-01T12:00:00.000Z", + messages[0].body, + ) acquired_message.should.equal(expected) @mock_sqs @mock_sns def test_publish_to_sqs_in_different_region(): - conn = boto3.client('sns', region_name='us-west-1') + conn = boto3.client("sns", region_name="us-west-1") conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] - sqs_conn = boto3.resource('sqs', region_name='us-west-2') + sqs_conn = boto3.resource("sqs", region_name="us-west-2") sqs_conn.create_queue(QueueName="test-queue") - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-west-2:123456789012:test-queue") + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-west-2:123456789012:test-queue", + ) - message = 'my message' + message = "my message" with freeze_time("2015-01-01 12:00:00"): published_message = conn.publish(TopicArn=topic_arn, Message=message) - published_message_id = published_message['MessageId'] + published_message_id = published_message["MessageId"] queue = sqs_conn.get_queue_by_name(QueueName="test-queue") messages = queue.receive_messages(MaxNumberOfMessages=1) - expected = MESSAGE_FROM_SQS_TEMPLATE % (message, published_message_id, 'us-west-1') - acquired_message = re.sub("\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", u'2015-01-01T12:00:00.000Z', messages[0].body) + expected = MESSAGE_FROM_SQS_TEMPLATE % (message, published_message_id, "us-west-1") + acquired_message = re.sub( + "\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{3}Z", + "2015-01-01T12:00:00.000Z", + messages[0].body, + ) acquired_message.should.equal(expected) @@ -266,47 +292,46 @@ def test_publish_to_sqs_in_different_region(): def test_publish_to_http(): def callback(request): request.headers["Content-Type"].should.equal("text/plain; charset=UTF-8") - json.loads.when.called_with( - request.body.decode() - ).should_not.throw(Exception) + json.loads.when.called_with(request.body.decode()).should_not.throw(Exception) return 200, {}, "" responses.add_callback( - method="POST", - url="http://example.com/foobar", - callback=callback, + method="POST", url="http://example.com/foobar", callback=callback ) - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/foobar") + conn.subscribe( + TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/foobar" + ) response = conn.publish( - TopicArn=topic_arn, Message="my message", Subject="my subject") + TopicArn=topic_arn, Message="my message", Subject="my subject" + ) @mock_sqs @mock_sns def test_publish_subject(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] - sqs_conn = boto3.resource('sqs', region_name='us-east-1') + sqs_conn = boto3.resource("sqs", region_name="us-east-1") sqs_conn.create_queue(QueueName="test-queue") - conn.subscribe(TopicArn=topic_arn, - Protocol="sqs", - Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue") - message = 'my message' - subject1 = 'test subject' - subject2 = 'test subject' * 20 + conn.subscribe( + TopicArn=topic_arn, + Protocol="sqs", + Endpoint="arn:aws:sqs:us-east-1:123456789012:test-queue", + ) + message = "my message" + subject1 = "test subject" + subject2 = "test subject" * 20 with freeze_time("2015-01-01 12:00:00"): conn.publish(TopicArn=topic_arn, Message=message, Subject=subject1) @@ -315,37 +340,37 @@ def test_publish_subject(): with freeze_time("2015-01-01 12:00:00"): conn.publish(TopicArn=topic_arn, Message=message, Subject=subject2) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameter') + err.response["Error"]["Code"].should.equal("InvalidParameter") else: - raise RuntimeError('Should have raised an InvalidParameter exception') + raise RuntimeError("Should have raised an InvalidParameter exception") @mock_sns def test_publish_message_too_long(): - sns = boto3.resource('sns', region_name='us-east-1') - topic = sns.create_topic(Name='some-topic') + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic(Name="some-topic") with assert_raises(ClientError): - topic.publish( - Message="".join(["." for i in range(0, 262145)])) + topic.publish(Message="".join(["." for i in range(0, 262145)])) # message short enough - does not raise an error - topic.publish( - Message="".join(["." for i in range(0, 262144)])) + topic.publish(Message="".join(["." for i in range(0, 262144)])) def _setup_filter_policy_test(filter_policy): - sns = boto3.resource('sns', region_name='us-east-1') - topic = sns.create_topic(Name='some-topic') + sns = boto3.resource("sns", region_name="us-east-1") + topic = sns.create_topic(Name="some-topic") - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName='test-queue') + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue(QueueName="test-queue") subscription = topic.subscribe( - Protocol='sqs', Endpoint=queue.attributes['QueueArn']) + Protocol="sqs", Endpoint=queue.attributes["QueueArn"] + ) subscription.set_attributes( - AttributeName='FilterPolicy', AttributeValue=json.dumps(filter_policy)) + AttributeName="FilterPolicy", AttributeValue=json.dumps(filter_policy) + ) return topic, subscription, queue @@ -353,248 +378,252 @@ def _setup_filter_policy_test(filter_policy): @mock_sqs @mock_sns def test_filtering_exact_string(): - topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp']}) + topic, subscription, queue = _setup_filter_policy_test({"store": ["example_corp"]}) topic.publish( - Message='match', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}}) + Message="match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal( - [{'store': {'Type': 'String', 'Value': 'example_corp'}}]) + [{"store": {"Type": "String", "Value": "example_corp"}}] + ) @mock_sqs @mock_sns def test_filtering_exact_string_multiple_message_attributes(): - topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp']}) + topic, subscription, queue = _setup_filter_policy_test({"store": ["example_corp"]}) topic.publish( - Message='match', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}, - 'event': {'DataType': 'String', - 'StringValue': 'order_cancelled'}}) + Message="match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"}, + "event": {"DataType": "String", "StringValue": "order_cancelled"}, + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([{ - 'store': {'Type': 'String', 'Value': 'example_corp'}, - 'event': {'Type': 'String', 'Value': 'order_cancelled'}}]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "store": {"Type": "String", "Value": "example_corp"}, + "event": {"Type": "String", "Value": "order_cancelled"}, + } + ] + ) @mock_sqs @mock_sns def test_filtering_exact_string_OR_matching(): topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp', 'different_corp']}) + {"store": ["example_corp", "different_corp"]} + ) topic.publish( - Message='match example_corp', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}}) + Message="match example_corp", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"} + }, + ) topic.publish( - Message='match different_corp', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'different_corp'}}) + Message="match different_corp", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "different_corp"} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal( - ['match example_corp', 'match different_corp']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([ - {'store': {'Type': 'String', 'Value': 'example_corp'}}, - {'store': {'Type': 'String', 'Value': 'different_corp'}}]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match example_corp", "match different_corp"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + {"store": {"Type": "String", "Value": "example_corp"}}, + {"store": {"Type": "String", "Value": "different_corp"}}, + ] + ) @mock_sqs @mock_sns def test_filtering_exact_string_AND_matching_positive(): topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp'], - 'event': ['order_cancelled']}) + {"store": ["example_corp"], "event": ["order_cancelled"]} + ) topic.publish( - Message='match example_corp order_cancelled', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}, - 'event': {'DataType': 'String', - 'StringValue': 'order_cancelled'}}) + Message="match example_corp order_cancelled", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"}, + "event": {"DataType": "String", "StringValue": "order_cancelled"}, + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal( - ['match example_corp order_cancelled']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([{ - 'store': {'Type': 'String', 'Value': 'example_corp'}, - 'event': {'Type': 'String', 'Value': 'order_cancelled'}}]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match example_corp order_cancelled"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "store": {"Type": "String", "Value": "example_corp"}, + "event": {"Type": "String", "Value": "order_cancelled"}, + } + ] + ) @mock_sqs @mock_sns def test_filtering_exact_string_AND_matching_no_match(): topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp'], - 'event': ['order_cancelled']}) + {"store": ["example_corp"], "event": ["order_cancelled"]} + ) topic.publish( - Message='match example_corp order_accepted', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}, - 'event': {'DataType': 'String', - 'StringValue': 'order_accepted'}}) + Message="match example_corp order_accepted", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"}, + "event": {"DataType": "String", "StringValue": "order_accepted"}, + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal([]) @mock_sqs @mock_sns def test_filtering_exact_string_no_match(): - topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp']}) + topic, subscription, queue = _setup_filter_policy_test({"store": ["example_corp"]}) topic.publish( - Message='no match', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'different_corp'}}) + Message="no match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "different_corp"} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal([]) @mock_sqs @mock_sns def test_filtering_exact_string_no_attributes_no_match(): - topic, subscription, queue = _setup_filter_policy_test( - {'store': ['example_corp']}) + topic, subscription, queue = _setup_filter_policy_test({"store": ["example_corp"]}) - topic.publish(Message='no match') + topic.publish(Message="no match") messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal([]) @mock_sqs @mock_sns def test_filtering_exact_number_int(): - topic, subscription, queue = _setup_filter_policy_test( - {'price': [100]}) + topic, subscription, queue = _setup_filter_policy_test({"price": [100]}) topic.publish( - Message='match', - MessageAttributes={'price': {'DataType': 'Number', - 'StringValue': '100'}}) + Message="match", + MessageAttributes={"price": {"DataType": "Number", "StringValue": "100"}}, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal( - [{'price': {'Type': 'Number', 'Value': 100}}]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([{"price": {"Type": "Number", "Value": 100}}]) @mock_sqs @mock_sns def test_filtering_exact_number_float(): - topic, subscription, queue = _setup_filter_policy_test( - {'price': [100.1]}) + topic, subscription, queue = _setup_filter_policy_test({"price": [100.1]}) topic.publish( - Message='match', - MessageAttributes={'price': {'DataType': 'Number', - 'StringValue': '100.1'}}) + Message="match", + MessageAttributes={"price": {"DataType": "Number", "StringValue": "100.1"}}, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal( - [{'price': {'Type': 'Number', 'Value': 100.1}}]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal([{"price": {"Type": "Number", "Value": 100.1}}]) @mock_sqs @mock_sns def test_filtering_exact_number_float_accuracy(): - topic, subscription, queue = _setup_filter_policy_test( - {'price': [100.123456789]}) + topic, subscription, queue = _setup_filter_policy_test({"price": [100.123456789]}) topic.publish( - Message='match', - MessageAttributes={'price': {'DataType': 'Number', - 'StringValue': '100.1234561'}}) + Message="match", + MessageAttributes={ + "price": {"DataType": "Number", "StringValue": "100.1234561"} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal( - [{'price': {'Type': 'Number', 'Value': 100.1234561}}]) + [{"price": {"Type": "Number", "Value": 100.1234561}}] + ) @mock_sqs @mock_sns def test_filtering_exact_number_no_match(): - topic, subscription, queue = _setup_filter_policy_test( - {'price': [100]}) + topic, subscription, queue = _setup_filter_policy_test({"price": [100]}) topic.publish( - Message='no match', - MessageAttributes={'price': {'DataType': 'Number', - 'StringValue': '101'}}) + Message="no match", + MessageAttributes={"price": {"DataType": "Number", "StringValue": "101"}}, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal([]) @mock_sqs @mock_sns def test_filtering_exact_number_with_string_no_match(): - topic, subscription, queue = _setup_filter_policy_test( - {'price': [100]}) + topic, subscription, queue = _setup_filter_policy_test({"price": [100]}) topic.publish( - Message='no match', - MessageAttributes={'price': {'DataType': 'String', - 'StringValue': '100'}}) + Message="no match", + MessageAttributes={"price": {"DataType": "String", "StringValue": "100"}}, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal([]) @@ -602,118 +631,142 @@ def test_filtering_exact_number_with_string_no_match(): @mock_sns def test_filtering_string_array_match(): topic, subscription, queue = _setup_filter_policy_test( - {'customer_interests': ['basketball', 'baseball']}) + {"customer_interests": ["basketball", "baseball"]} + ) topic.publish( - Message='match', - MessageAttributes={'customer_interests': {'DataType': 'String.Array', - 'StringValue': json.dumps(['basketball', 'rugby'])}}) + Message="match", + MessageAttributes={ + "customer_interests": { + "DataType": "String.Array", + "StringValue": json.dumps(["basketball", "rugby"]), + } + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal( - [{'customer_interests': {'Type': 'String.Array', 'Value': json.dumps(['basketball', 'rugby'])}}]) + [ + { + "customer_interests": { + "Type": "String.Array", + "Value": json.dumps(["basketball", "rugby"]), + } + } + ] + ) @mock_sqs @mock_sns def test_filtering_string_array_no_match(): topic, subscription, queue = _setup_filter_policy_test( - {'customer_interests': ['baseball']}) + {"customer_interests": ["baseball"]} + ) topic.publish( - Message='no_match', - MessageAttributes={'customer_interests': {'DataType': 'String.Array', - 'StringValue': json.dumps(['basketball', 'rugby'])}}) + Message="no_match", + MessageAttributes={ + "customer_interests": { + "DataType": "String.Array", + "StringValue": json.dumps(["basketball", "rugby"]), + } + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal([]) @mock_sqs @mock_sns def test_filtering_string_array_with_number_match(): - topic, subscription, queue = _setup_filter_policy_test( - {'price': [100, 500]}) + topic, subscription, queue = _setup_filter_policy_test({"price": [100, 500]}) topic.publish( - Message='match', - MessageAttributes={'price': {'DataType': 'String.Array', - 'StringValue': json.dumps([100, 50])}}) + Message="match", + MessageAttributes={ + "price": {"DataType": "String.Array", "StringValue": json.dumps([100, 50])} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal( - [{'price': {'Type': 'String.Array', 'Value': json.dumps([100, 50])}}]) + [{"price": {"Type": "String.Array", "Value": json.dumps([100, 50])}}] + ) @mock_sqs @mock_sns def test_filtering_string_array_with_number_float_accuracy_match(): topic, subscription, queue = _setup_filter_policy_test( - {'price': [100.123456789, 500]}) + {"price": [100.123456789, 500]} + ) topic.publish( - Message='match', - MessageAttributes={'price': {'DataType': 'String.Array', - 'StringValue': json.dumps([100.1234561, 50])}}) + Message="match", + MessageAttributes={ + "price": { + "DataType": "String.Array", + "StringValue": json.dumps([100.1234561, 50]), + } + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal( - [{'price': {'Type': 'String.Array', 'Value': json.dumps([100.1234561, 50])}}]) + [{"price": {"Type": "String.Array", "Value": json.dumps([100.1234561, 50])}}] + ) @mock_sqs @mock_sns # this is the correct behavior from SNS def test_filtering_string_array_with_number_no_array_match(): - topic, subscription, queue = _setup_filter_policy_test( - {'price': [100, 500]}) + topic, subscription, queue = _setup_filter_policy_test({"price": [100, 500]}) topic.publish( - Message='match', - MessageAttributes={'price': {'DataType': 'String.Array', - 'StringValue': '100'}}) + Message="match", + MessageAttributes={"price": {"DataType": "String.Array", "StringValue": "100"}}, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal( - [{'price': {'Type': 'String.Array', 'Value': '100'}}]) + [{"price": {"Type": "String.Array", "Value": "100"}}] + ) @mock_sqs @mock_sns def test_filtering_string_array_with_number_no_match(): - topic, subscription, queue = _setup_filter_policy_test( - {'price': [500]}) + topic, subscription, queue = _setup_filter_policy_test({"price": [500]}) topic.publish( - Message='no_match', - MessageAttributes={'price': {'DataType': 'String.Array', - 'StringValue': json.dumps([100, 50])}}) + Message="no_match", + MessageAttributes={ + "price": {"DataType": "String.Array", "StringValue": json.dumps([100, 50])} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal([]) @@ -721,19 +774,19 @@ def test_filtering_string_array_with_number_no_match(): @mock_sns # this is the correct behavior from SNS def test_filtering_string_array_with_string_no_array_no_match(): - topic, subscription, queue = _setup_filter_policy_test( - {'price': [100]}) + topic, subscription, queue = _setup_filter_policy_test({"price": [100]}) topic.publish( - Message='no_match', - MessageAttributes={'price': {'DataType': 'String.Array', - 'StringValue': 'one hundread'}}) + Message="no_match", + MessageAttributes={ + "price": {"DataType": "String.Array", "StringValue": "one hundread"} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal([]) @@ -741,38 +794,43 @@ def test_filtering_string_array_with_string_no_array_no_match(): @mock_sns def test_filtering_attribute_key_exists_match(): topic, subscription, queue = _setup_filter_policy_test( - {'store': [{'exists': True}]}) + {"store": [{"exists": True}]} + ) topic.publish( - Message='match', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}}) + Message="match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal( - [{'store': {'Type': 'String', 'Value': 'example_corp'}}]) + [{"store": {"Type": "String", "Value": "example_corp"}}] + ) @mock_sqs @mock_sns def test_filtering_attribute_key_exists_no_match(): topic, subscription, queue = _setup_filter_policy_test( - {'store': [{'exists': True}]}) + {"store": [{"exists": True}]} + ) topic.publish( - Message='no match', - MessageAttributes={'event': {'DataType': 'String', - 'StringValue': 'order_cancelled'}}) + Message="no match", + MessageAttributes={ + "event": {"DataType": "String", "StringValue": "order_cancelled"} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal([]) @@ -780,38 +838,43 @@ def test_filtering_attribute_key_exists_no_match(): @mock_sns def test_filtering_attribute_key_not_exists_match(): topic, subscription, queue = _setup_filter_policy_test( - {'store': [{'exists': False}]}) + {"store": [{"exists": False}]} + ) topic.publish( - Message='match', - MessageAttributes={'event': {'DataType': 'String', - 'StringValue': 'order_cancelled'}}) + Message="match", + MessageAttributes={ + "event": {"DataType": "String", "StringValue": "order_cancelled"} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal(['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal( - [{'event': {'Type': 'String', 'Value': 'order_cancelled'}}]) + [{"event": {"Type": "String", "Value": "order_cancelled"}}] + ) @mock_sqs @mock_sns def test_filtering_attribute_key_not_exists_no_match(): topic, subscription, queue = _setup_filter_policy_test( - {'store': [{'exists': False}]}) + {"store": [{"exists": False}]} + ) topic.publish( - Message='no match', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}}) + Message="no match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"} + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal([]) @@ -819,59 +882,74 @@ def test_filtering_attribute_key_not_exists_no_match(): @mock_sns def test_filtering_all_AND_matching_match(): topic, subscription, queue = _setup_filter_policy_test( - {'store': [{'exists': True}], - 'event': ['order_cancelled'], - 'customer_interests': ['basketball', 'baseball'], - 'price': [100]}) + { + "store": [{"exists": True}], + "event": ["order_cancelled"], + "customer_interests": ["basketball", "baseball"], + "price": [100], + } + ) topic.publish( - Message='match', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}, - 'event': {'DataType': 'String', - 'StringValue': 'order_cancelled'}, - 'customer_interests': {'DataType': 'String.Array', - 'StringValue': json.dumps(['basketball', 'rugby'])}, - 'price': {'DataType': 'Number', - 'StringValue': '100'}}) + Message="match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"}, + "event": {"DataType": "String", "StringValue": "order_cancelled"}, + "customer_interests": { + "DataType": "String.Array", + "StringValue": json.dumps(["basketball", "rugby"]), + }, + "price": {"DataType": "Number", "StringValue": "100"}, + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] - message_bodies.should.equal( - ['match']) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] - message_attributes.should.equal([{ - 'store': {'Type': 'String', 'Value': 'example_corp'}, - 'event': {'Type': 'String', 'Value': 'order_cancelled'}, - 'customer_interests': {'Type': 'String.Array', 'Value': json.dumps(['basketball', 'rugby'])}, - 'price': {'Type': 'Number', 'Value': 100}}]) + message_bodies = [json.loads(m.body)["Message"] for m in messages] + message_bodies.should.equal(["match"]) + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] + message_attributes.should.equal( + [ + { + "store": {"Type": "String", "Value": "example_corp"}, + "event": {"Type": "String", "Value": "order_cancelled"}, + "customer_interests": { + "Type": "String.Array", + "Value": json.dumps(["basketball", "rugby"]), + }, + "price": {"Type": "Number", "Value": 100}, + } + ] + ) @mock_sqs @mock_sns def test_filtering_all_AND_matching_no_match(): topic, subscription, queue = _setup_filter_policy_test( - {'store': [{'exists': True}], - 'event': ['order_cancelled'], - 'customer_interests': ['basketball', 'baseball'], - 'price': [100], - "encrypted": [False]}) + { + "store": [{"exists": True}], + "event": ["order_cancelled"], + "customer_interests": ["basketball", "baseball"], + "price": [100], + "encrypted": [False], + } + ) topic.publish( - Message='no match', - MessageAttributes={'store': {'DataType': 'String', - 'StringValue': 'example_corp'}, - 'event': {'DataType': 'String', - 'StringValue': 'order_cancelled'}, - 'customer_interests': {'DataType': 'String.Array', - 'StringValue': json.dumps(['basketball', 'rugby'])}, - 'price': {'DataType': 'Number', - 'StringValue': '100'}}) + Message="no match", + MessageAttributes={ + "store": {"DataType": "String", "StringValue": "example_corp"}, + "event": {"DataType": "String", "StringValue": "order_cancelled"}, + "customer_interests": { + "DataType": "String.Array", + "StringValue": json.dumps(["basketball", "rugby"]), + }, + "price": {"DataType": "Number", "StringValue": "100"}, + }, + ) messages = queue.receive_messages(MaxNumberOfMessages=5) - message_bodies = [json.loads(m.body)['Message'] for m in messages] + message_bodies = [json.loads(m.body)["Message"] for m in messages] message_bodies.should.equal([]) - message_attributes = [ - json.loads(m.body)['MessageAttributes'] for m in messages] + message_attributes = [json.loads(m.body)["MessageAttributes"] for m in messages] message_attributes.should.equal([]) diff --git a/tests/test_sns/test_server.py b/tests/test_sns/test_server.py index 465dfa2c2..ec8bbe201 100644 --- a/tests/test_sns/test_server.py +++ b/tests/test_sns/test_server.py @@ -4,9 +4,9 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_sns_server_get(): @@ -16,9 +16,11 @@ def test_sns_server_get(): topic_data = test_client.action_data("CreateTopic", Name="testtopic") topic_data.should.contain("CreateTopicResult") topic_data.should.contain( - "arn:aws:sns:us-east-1:123456789012:testtopic") + "arn:aws:sns:us-east-1:123456789012:testtopic" + ) topics_data = test_client.action_data("ListTopics") topics_data.should.contain("ListTopicsResult") topic_data.should.contain( - "arn:aws:sns:us-east-1:123456789012:testtopic") + "arn:aws:sns:us-east-1:123456789012:testtopic" + ) diff --git a/tests/test_sns/test_subscriptions.py b/tests/test_sns/test_subscriptions.py index ba241ba44..fbd4274f4 100644 --- a/tests/test_sns/test_subscriptions.py +++ b/tests/test_sns/test_subscriptions.py @@ -12,13 +12,15 @@ def test_creating_subscription(): conn = boto.connect_sns() conn.create_topic("some-topic") topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] conn.subscribe(topic_arn, "http", "http://example.com/") subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["Subscriptions"] + "ListSubscriptionsResult" + ]["Subscriptions"] subscriptions.should.have.length_of(1) subscription = subscriptions[0] subscription["TopicArn"].should.equal(topic_arn) @@ -31,7 +33,8 @@ def test_creating_subscription(): # And there should be zero subscriptions left subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["Subscriptions"] + "ListSubscriptionsResult" + ]["Subscriptions"] subscriptions.should.have.length_of(0) @@ -40,13 +43,15 @@ def test_deleting_subscriptions_by_deleting_topic(): conn = boto.connect_sns() conn.create_topic("some-topic") topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] conn.subscribe(topic_arn, "http", "http://example.com/") subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["Subscriptions"] + "ListSubscriptionsResult" + ]["Subscriptions"] subscriptions.should.have.length_of(1) subscription = subscriptions[0] subscription["TopicArn"].should.equal(topic_arn) @@ -64,7 +69,8 @@ def test_deleting_subscriptions_by_deleting_topic(): # And there should be zero subscriptions left subscriptions = conn.get_all_subscriptions()["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["Subscriptions"] + "ListSubscriptionsResult" + ]["Subscriptions"] subscriptions.should.have.length_of(0) @@ -76,16 +82,17 @@ def test_getting_subscriptions_by_topic(): topics_json = conn.get_all_topics() topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] - topic1_arn = topics[0]['TopicArn'] - topic2_arn = topics[1]['TopicArn'] + topic1_arn = topics[0]["TopicArn"] + topic2_arn = topics[1]["TopicArn"] conn.subscribe(topic1_arn, "http", "http://example1.com/") conn.subscribe(topic2_arn, "http", "http://example2.com/") topic1_subscriptions = conn.get_all_subscriptions_by_topic(topic1_arn)[ - "ListSubscriptionsByTopicResponse"]["ListSubscriptionsByTopicResult"]["Subscriptions"] + "ListSubscriptionsByTopicResponse" + ]["ListSubscriptionsByTopicResult"]["Subscriptions"] topic1_subscriptions.should.have.length_of(1) - topic1_subscriptions[0]['Endpoint'].should.equal("http://example1.com/") + topic1_subscriptions[0]["Endpoint"].should.equal("http://example1.com/") @mock_sns_deprecated @@ -96,40 +103,47 @@ def test_subscription_paging(): topics_json = conn.get_all_topics() topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] - topic1_arn = topics[0]['TopicArn'] - topic2_arn = topics[1]['TopicArn'] + topic1_arn = topics[0]["TopicArn"] + topic2_arn = topics[1]["TopicArn"] for index in range(DEFAULT_PAGE_SIZE + int(DEFAULT_PAGE_SIZE / 3)): - conn.subscribe(topic1_arn, 'email', 'email_' + - str(index) + '@test.com') - conn.subscribe(topic2_arn, 'email', 'email_' + - str(index) + '@test.com') + conn.subscribe(topic1_arn, "email", "email_" + str(index) + "@test.com") + conn.subscribe(topic2_arn, "email", "email_" + str(index) + "@test.com") all_subscriptions = conn.get_all_subscriptions() all_subscriptions["ListSubscriptionsResponse"]["ListSubscriptionsResult"][ - "Subscriptions"].should.have.length_of(DEFAULT_PAGE_SIZE) + "Subscriptions" + ].should.have.length_of(DEFAULT_PAGE_SIZE) next_token = all_subscriptions["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["NextToken"] + "ListSubscriptionsResult" + ]["NextToken"] next_token.should.equal(DEFAULT_PAGE_SIZE) all_subscriptions = conn.get_all_subscriptions(next_token=next_token * 2) all_subscriptions["ListSubscriptionsResponse"]["ListSubscriptionsResult"][ - "Subscriptions"].should.have.length_of(int(DEFAULT_PAGE_SIZE * 2 / 3)) + "Subscriptions" + ].should.have.length_of(int(DEFAULT_PAGE_SIZE * 2 / 3)) next_token = all_subscriptions["ListSubscriptionsResponse"][ - "ListSubscriptionsResult"]["NextToken"] + "ListSubscriptionsResult" + ]["NextToken"] next_token.should.equal(None) topic1_subscriptions = conn.get_all_subscriptions_by_topic(topic1_arn) - topic1_subscriptions["ListSubscriptionsByTopicResponse"]["ListSubscriptionsByTopicResult"][ - "Subscriptions"].should.have.length_of(DEFAULT_PAGE_SIZE) + topic1_subscriptions["ListSubscriptionsByTopicResponse"][ + "ListSubscriptionsByTopicResult" + ]["Subscriptions"].should.have.length_of(DEFAULT_PAGE_SIZE) next_token = topic1_subscriptions["ListSubscriptionsByTopicResponse"][ - "ListSubscriptionsByTopicResult"]["NextToken"] + "ListSubscriptionsByTopicResult" + ]["NextToken"] next_token.should.equal(DEFAULT_PAGE_SIZE) topic1_subscriptions = conn.get_all_subscriptions_by_topic( - topic1_arn, next_token=next_token) - topic1_subscriptions["ListSubscriptionsByTopicResponse"]["ListSubscriptionsByTopicResult"][ - "Subscriptions"].should.have.length_of(int(DEFAULT_PAGE_SIZE / 3)) + topic1_arn, next_token=next_token + ) + topic1_subscriptions["ListSubscriptionsByTopicResponse"][ + "ListSubscriptionsByTopicResult" + ]["Subscriptions"].should.have.length_of(int(DEFAULT_PAGE_SIZE / 3)) next_token = topic1_subscriptions["ListSubscriptionsByTopicResponse"][ - "ListSubscriptionsByTopicResult"]["NextToken"] + "ListSubscriptionsByTopicResult" + ]["NextToken"] next_token.should.equal(None) diff --git a/tests/test_sns/test_subscriptions_boto3.py b/tests/test_sns/test_subscriptions_boto3.py index 282ec4652..04d4eec6e 100644 --- a/tests/test_sns/test_subscriptions_boto3.py +++ b/tests/test_sns/test_subscriptions_boto3.py @@ -13,64 +13,53 @@ from moto.sns.models import DEFAULT_PAGE_SIZE @mock_sns def test_subscribe_sms(): - client = boto3.client('sns', region_name='us-east-1') + client = boto3.client("sns", region_name="us-east-1") client.create_topic(Name="some-topic") resp = client.create_topic(Name="some-topic") - arn = resp['TopicArn'] + arn = resp["TopicArn"] + + resp = client.subscribe(TopicArn=arn, Protocol="sms", Endpoint="+15551234567") + resp.should.contain("SubscriptionArn") - resp = client.subscribe( - TopicArn=arn, - Protocol='sms', - Endpoint='+15551234567' - ) - resp.should.contain('SubscriptionArn') @mock_sns def test_double_subscription(): - client = boto3.client('sns', region_name='us-east-1') + client = boto3.client("sns", region_name="us-east-1") client.create_topic(Name="some-topic") resp = client.create_topic(Name="some-topic") - arn = resp['TopicArn'] + arn = resp["TopicArn"] do_subscribe_sqs = lambda sqs_arn: client.subscribe( - TopicArn=arn, - Protocol='sqs', - Endpoint=sqs_arn + TopicArn=arn, Protocol="sqs", Endpoint=sqs_arn ) - resp1 = do_subscribe_sqs('arn:aws:sqs:elasticmq:000000000000:foo') - resp2 = do_subscribe_sqs('arn:aws:sqs:elasticmq:000000000000:foo') + resp1 = do_subscribe_sqs("arn:aws:sqs:elasticmq:000000000000:foo") + resp2 = do_subscribe_sqs("arn:aws:sqs:elasticmq:000000000000:foo") - resp1['SubscriptionArn'].should.equal(resp2['SubscriptionArn']) + resp1["SubscriptionArn"].should.equal(resp2["SubscriptionArn"]) @mock_sns def test_subscribe_bad_sms(): - client = boto3.client('sns', region_name='us-east-1') + client = boto3.client("sns", region_name="us-east-1") client.create_topic(Name="some-topic") resp = client.create_topic(Name="some-topic") - arn = resp['TopicArn'] + arn = resp["TopicArn"] try: # Test invalid number - client.subscribe( - TopicArn=arn, - Protocol='sms', - Endpoint='NAA+15551234567' - ) + client.subscribe(TopicArn=arn, Protocol="sms", Endpoint="NAA+15551234567") except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameter') + err.response["Error"]["Code"].should.equal("InvalidParameter") @mock_sns def test_creating_subscription(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/") + conn.subscribe(TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/") subscriptions = conn.list_subscriptions()["Subscriptions"] subscriptions.should.have.length_of(1) @@ -90,14 +79,12 @@ def test_creating_subscription(): @mock_sns def test_deleting_subscriptions_by_deleting_topic(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/") + conn.subscribe(TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/") subscriptions = conn.list_subscriptions()["Subscriptions"] subscriptions.should.have.length_of(1) @@ -122,41 +109,44 @@ def test_deleting_subscriptions_by_deleting_topic(): @mock_sns def test_getting_subscriptions_by_topic(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="topic1") conn.create_topic(Name="topic2") response = conn.list_topics() topics = response["Topics"] - topic1_arn = topics[0]['TopicArn'] - topic2_arn = topics[1]['TopicArn'] + topic1_arn = topics[0]["TopicArn"] + topic2_arn = topics[1]["TopicArn"] - conn.subscribe(TopicArn=topic1_arn, - Protocol="http", - Endpoint="http://example1.com/") - conn.subscribe(TopicArn=topic2_arn, - Protocol="http", - Endpoint="http://example2.com/") + conn.subscribe( + TopicArn=topic1_arn, Protocol="http", Endpoint="http://example1.com/" + ) + conn.subscribe( + TopicArn=topic2_arn, Protocol="http", Endpoint="http://example2.com/" + ) topic1_subscriptions = conn.list_subscriptions_by_topic(TopicArn=topic1_arn)[ - "Subscriptions"] + "Subscriptions" + ] topic1_subscriptions.should.have.length_of(1) - topic1_subscriptions[0]['Endpoint'].should.equal("http://example1.com/") + topic1_subscriptions[0]["Endpoint"].should.equal("http://example1.com/") @mock_sns def test_subscription_paging(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="topic1") response = conn.list_topics() topics = response["Topics"] - topic1_arn = topics[0]['TopicArn'] + topic1_arn = topics[0]["TopicArn"] for index in range(DEFAULT_PAGE_SIZE + int(DEFAULT_PAGE_SIZE / 3)): - conn.subscribe(TopicArn=topic1_arn, - Protocol='email', - Endpoint='email_' + str(index) + '@test.com') + conn.subscribe( + TopicArn=topic1_arn, + Protocol="email", + Endpoint="email_" + str(index) + "@test.com", + ) all_subscriptions = conn.list_subscriptions() all_subscriptions["Subscriptions"].should.have.length_of(DEFAULT_PAGE_SIZE) @@ -164,85 +154,87 @@ def test_subscription_paging(): next_token.should.equal(str(DEFAULT_PAGE_SIZE)) all_subscriptions = conn.list_subscriptions(NextToken=next_token) - all_subscriptions["Subscriptions"].should.have.length_of( - int(DEFAULT_PAGE_SIZE / 3)) + all_subscriptions["Subscriptions"].should.have.length_of(int(DEFAULT_PAGE_SIZE / 3)) all_subscriptions.shouldnt.have("NextToken") - topic1_subscriptions = conn.list_subscriptions_by_topic( - TopicArn=topic1_arn) - topic1_subscriptions["Subscriptions"].should.have.length_of( - DEFAULT_PAGE_SIZE) + topic1_subscriptions = conn.list_subscriptions_by_topic(TopicArn=topic1_arn) + topic1_subscriptions["Subscriptions"].should.have.length_of(DEFAULT_PAGE_SIZE) next_token = topic1_subscriptions["NextToken"] next_token.should.equal(str(DEFAULT_PAGE_SIZE)) topic1_subscriptions = conn.list_subscriptions_by_topic( - TopicArn=topic1_arn, NextToken=next_token) + TopicArn=topic1_arn, NextToken=next_token + ) topic1_subscriptions["Subscriptions"].should.have.length_of( - int(DEFAULT_PAGE_SIZE / 3)) + int(DEFAULT_PAGE_SIZE / 3) + ) topic1_subscriptions.shouldnt.have("NextToken") + @mock_sns def test_subscribe_attributes(): - client = boto3.client('sns', region_name='us-east-1') + client = boto3.client("sns", region_name="us-east-1") client.create_topic(Name="some-topic") resp = client.create_topic(Name="some-topic") - arn = resp['TopicArn'] + arn = resp["TopicArn"] - resp = client.subscribe( - TopicArn=arn, - Protocol='http', - Endpoint='http://test.com' - ) + resp = client.subscribe(TopicArn=arn, Protocol="http", Endpoint="http://test.com") attributes = client.get_subscription_attributes( - SubscriptionArn=resp['SubscriptionArn'] + SubscriptionArn=resp["SubscriptionArn"] ) - attributes.should.contain('Attributes') - attributes['Attributes'].should.contain('PendingConfirmation') - attributes['Attributes']['PendingConfirmation'].should.equal('false') - attributes['Attributes'].should.contain('Endpoint') - attributes['Attributes']['Endpoint'].should.equal('http://test.com') - attributes['Attributes'].should.contain('TopicArn') - attributes['Attributes']['TopicArn'].should.equal(arn) - attributes['Attributes'].should.contain('Protocol') - attributes['Attributes']['Protocol'].should.equal('http') - attributes['Attributes'].should.contain('SubscriptionArn') - attributes['Attributes']['SubscriptionArn'].should.equal(resp['SubscriptionArn']) + attributes.should.contain("Attributes") + attributes["Attributes"].should.contain("PendingConfirmation") + attributes["Attributes"]["PendingConfirmation"].should.equal("false") + attributes["Attributes"].should.contain("Endpoint") + attributes["Attributes"]["Endpoint"].should.equal("http://test.com") + attributes["Attributes"].should.contain("TopicArn") + attributes["Attributes"]["TopicArn"].should.equal(arn) + attributes["Attributes"].should.contain("Protocol") + attributes["Attributes"]["Protocol"].should.equal("http") + attributes["Attributes"].should.contain("SubscriptionArn") + attributes["Attributes"]["SubscriptionArn"].should.equal(resp["SubscriptionArn"]) @mock_sns def test_creating_subscription_with_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] - delivery_policy = json.dumps({ - 'healthyRetryPolicy': { - "numRetries": 10, - "minDelayTarget": 1, - "maxDelayTarget":2 + delivery_policy = json.dumps( + { + "healthyRetryPolicy": { + "numRetries": 10, + "minDelayTarget": 1, + "maxDelayTarget": 2, + } } - }) + ) - filter_policy = json.dumps({ - "store": ["example_corp"], - "event": ["order_cancelled"], - "encrypted": [False], - "customer_interests": ["basketball", "baseball"], - "price": [100, 100.12], - "error": [None] - }) + filter_policy = json.dumps( + { + "store": ["example_corp"], + "event": ["order_cancelled"], + "encrypted": [False], + "customer_interests": ["basketball", "baseball"], + "price": [100, 100.12], + "error": [None], + } + ) - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/", - Attributes={ - 'RawMessageDelivery': 'true', - 'DeliveryPolicy': delivery_policy, - 'FilterPolicy': filter_policy - }) + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={ + "RawMessageDelivery": "true", + "DeliveryPolicy": delivery_policy, + "FilterPolicy": filter_policy, + }, + ) subscriptions = conn.list_subscriptions()["Subscriptions"] subscriptions.should.have.length_of(1) @@ -254,13 +246,11 @@ def test_creating_subscription_with_attributes(): # Test the subscription attributes have been set subscription_arn = subscription["SubscriptionArn"] - attrs = conn.get_subscription_attributes( - SubscriptionArn=subscription_arn - ) + attrs = conn.get_subscription_attributes(SubscriptionArn=subscription_arn) - attrs['Attributes']['RawMessageDelivery'].should.equal('true') - attrs['Attributes']['DeliveryPolicy'].should.equal(delivery_policy) - attrs['Attributes']['FilterPolicy'].should.equal(filter_policy) + attrs["Attributes"]["RawMessageDelivery"].should.equal("true") + attrs["Attributes"]["DeliveryPolicy"].should.equal(delivery_policy) + attrs["Attributes"]["FilterPolicy"].should.equal(filter_policy) # Now unsubscribe the subscription conn.unsubscribe(SubscriptionArn=subscription["SubscriptionArn"]) @@ -271,24 +261,22 @@ def test_creating_subscription_with_attributes(): # invalid attr name with assert_raises(ClientError): - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/", - Attributes={ - 'InvalidName': 'true' - }) + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={"InvalidName": "true"}, + ) @mock_sns def test_set_subscription_attributes(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response["Topics"][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] - conn.subscribe(TopicArn=topic_arn, - Protocol="http", - Endpoint="http://example.com/") + conn.subscribe(TopicArn=topic_arn, Protocol="http", Endpoint="http://example.com/") subscriptions = conn.list_subscriptions()["Subscriptions"] subscriptions.should.have.length_of(1) @@ -299,202 +287,200 @@ def test_set_subscription_attributes(): subscription["Endpoint"].should.equal("http://example.com/") subscription_arn = subscription["SubscriptionArn"] - attrs = conn.get_subscription_attributes( - SubscriptionArn=subscription_arn - ) - attrs.should.have.key('Attributes') + attrs = conn.get_subscription_attributes(SubscriptionArn=subscription_arn) + attrs.should.have.key("Attributes") conn.set_subscription_attributes( SubscriptionArn=subscription_arn, - AttributeName='RawMessageDelivery', - AttributeValue='true' + AttributeName="RawMessageDelivery", + AttributeValue="true", ) - delivery_policy = json.dumps({ - 'healthyRetryPolicy': { - "numRetries": 10, - "minDelayTarget": 1, - "maxDelayTarget":2 + delivery_policy = json.dumps( + { + "healthyRetryPolicy": { + "numRetries": 10, + "minDelayTarget": 1, + "maxDelayTarget": 2, + } } - }) + ) conn.set_subscription_attributes( SubscriptionArn=subscription_arn, - AttributeName='DeliveryPolicy', - AttributeValue=delivery_policy + AttributeName="DeliveryPolicy", + AttributeValue=delivery_policy, ) - filter_policy = json.dumps({ - "store": ["example_corp"], - "event": ["order_cancelled"], - "encrypted": [False], - "customer_interests": ["basketball", "baseball"], - "price": [100, 100.12], - "error": [None] - }) + filter_policy = json.dumps( + { + "store": ["example_corp"], + "event": ["order_cancelled"], + "encrypted": [False], + "customer_interests": ["basketball", "baseball"], + "price": [100, 100.12], + "error": [None], + } + ) conn.set_subscription_attributes( SubscriptionArn=subscription_arn, - AttributeName='FilterPolicy', - AttributeValue=filter_policy + AttributeName="FilterPolicy", + AttributeValue=filter_policy, ) - attrs = conn.get_subscription_attributes( - SubscriptionArn=subscription_arn - ) + attrs = conn.get_subscription_attributes(SubscriptionArn=subscription_arn) - attrs['Attributes']['RawMessageDelivery'].should.equal('true') - attrs['Attributes']['DeliveryPolicy'].should.equal(delivery_policy) - attrs['Attributes']['FilterPolicy'].should.equal(filter_policy) + attrs["Attributes"]["RawMessageDelivery"].should.equal("true") + attrs["Attributes"]["DeliveryPolicy"].should.equal(delivery_policy) + attrs["Attributes"]["FilterPolicy"].should.equal(filter_policy) # not existing subscription with assert_raises(ClientError): conn.set_subscription_attributes( - SubscriptionArn='invalid', - AttributeName='RawMessageDelivery', - AttributeValue='true' + SubscriptionArn="invalid", + AttributeName="RawMessageDelivery", + AttributeValue="true", ) with assert_raises(ClientError): - attrs = conn.get_subscription_attributes( - SubscriptionArn='invalid' - ) - + attrs = conn.get_subscription_attributes(SubscriptionArn="invalid") # invalid attr name with assert_raises(ClientError): conn.set_subscription_attributes( SubscriptionArn=subscription_arn, - AttributeName='InvalidName', - AttributeValue='true' + AttributeName="InvalidName", + AttributeValue="true", ) @mock_sns def test_subscribe_invalid_filter_policy(): - conn = boto3.client('sns', region_name = 'us-east-1') - conn.create_topic(Name = 'some-topic') + conn = boto3.client("sns", region_name="us-east-1") + conn.create_topic(Name="some-topic") response = conn.list_topics() - topic_arn = response['Topics'][0]['TopicArn'] + topic_arn = response["Topics"][0]["TopicArn"] try: - conn.subscribe(TopicArn = topic_arn, - Protocol = 'http', - Endpoint = 'http://example.com/', - Attributes = { - 'FilterPolicy': json.dumps({ - 'store': [str(i) for i in range(151)] - }) - }) + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={ + "FilterPolicy": json.dumps({"store": [str(i) for i in range(151)]}) + }, + ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameter') - err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Filter policy is too complex') + err.response["Error"]["Code"].should.equal("InvalidParameter") + err.response["Error"]["Message"].should.equal( + "Invalid parameter: FilterPolicy: Filter policy is too complex" + ) try: - conn.subscribe(TopicArn = topic_arn, - Protocol = 'http', - Endpoint = 'http://example.com/', - Attributes = { - 'FilterPolicy': json.dumps({ - 'store': [['example_corp']] - }) - }) + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={"FilterPolicy": json.dumps({"store": [["example_corp"]]})}, + ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameter') - err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null') + err.response["Error"]["Code"].should.equal("InvalidParameter") + err.response["Error"]["Message"].should.equal( + "Invalid parameter: FilterPolicy: Match value must be String, number, true, false, or null" + ) try: - conn.subscribe(TopicArn = topic_arn, - Protocol = 'http', - Endpoint = 'http://example.com/', - Attributes = { - 'FilterPolicy': json.dumps({ - 'store': [{'exists': None}] - }) - }) + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={"FilterPolicy": json.dumps({"store": [{"exists": None}]})}, + ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameter') - err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: exists match pattern must be either true or false.') + err.response["Error"]["Code"].should.equal("InvalidParameter") + err.response["Error"]["Message"].should.equal( + "Invalid parameter: FilterPolicy: exists match pattern must be either true or false." + ) try: - conn.subscribe(TopicArn = topic_arn, - Protocol = 'http', - Endpoint = 'http://example.com/', - Attributes = { - 'FilterPolicy': json.dumps({ - 'store': [{'error': True}] - }) - }) + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={"FilterPolicy": json.dumps({"store": [{"error": True}]})}, + ) except ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameter') - err.response['Error']['Message'].should.equal('Invalid parameter: FilterPolicy: Unrecognized match type error') + err.response["Error"]["Code"].should.equal("InvalidParameter") + err.response["Error"]["Message"].should.equal( + "Invalid parameter: FilterPolicy: Unrecognized match type error" + ) try: - conn.subscribe(TopicArn = topic_arn, - Protocol = 'http', - Endpoint = 'http://example.com/', - Attributes = { - 'FilterPolicy': json.dumps({ - 'store': [1000000001] - }) - }) + conn.subscribe( + TopicArn=topic_arn, + Protocol="http", + Endpoint="http://example.com/", + Attributes={"FilterPolicy": json.dumps({"store": [1000000001]})}, + ) except ClientError as err: - err.response['Error']['Code'].should.equal('InternalFailure') + err.response["Error"]["Code"].should.equal("InternalFailure") + @mock_sns def test_check_not_opted_out(): - conn = boto3.client('sns', region_name='us-east-1') - response = conn.check_if_phone_number_is_opted_out(phoneNumber='+447428545375') + conn = boto3.client("sns", region_name="us-east-1") + response = conn.check_if_phone_number_is_opted_out(phoneNumber="+447428545375") - response.should.contain('isOptedOut') - response['isOptedOut'].should.be(False) + response.should.contain("isOptedOut") + response["isOptedOut"].should.be(False) @mock_sns def test_check_opted_out(): # Phone number ends in 99 so is hardcoded in the endpoint to return opted # out status - conn = boto3.client('sns', region_name='us-east-1') - response = conn.check_if_phone_number_is_opted_out(phoneNumber='+447428545399') + conn = boto3.client("sns", region_name="us-east-1") + response = conn.check_if_phone_number_is_opted_out(phoneNumber="+447428545399") - response.should.contain('isOptedOut') - response['isOptedOut'].should.be(True) + response.should.contain("isOptedOut") + response["isOptedOut"].should.be(True) @mock_sns def test_check_opted_out_invalid(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") # Invalid phone number with assert_raises(ClientError): - conn.check_if_phone_number_is_opted_out(phoneNumber='+44742LALALA') + conn.check_if_phone_number_is_opted_out(phoneNumber="+44742LALALA") @mock_sns def test_list_opted_out(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") response = conn.list_phone_numbers_opted_out() - response.should.contain('phoneNumbers') - len(response['phoneNumbers']).should.be.greater_than(0) + response.should.contain("phoneNumbers") + len(response["phoneNumbers"]).should.be.greater_than(0) @mock_sns def test_opt_in(): - conn = boto3.client('sns', region_name='us-east-1') + conn = boto3.client("sns", region_name="us-east-1") response = conn.list_phone_numbers_opted_out() - current_len = len(response['phoneNumbers']) + current_len = len(response["phoneNumbers"]) assert current_len > 0 - conn.opt_in_phone_number(phoneNumber=response['phoneNumbers'][0]) + conn.opt_in_phone_number(phoneNumber=response["phoneNumbers"][0]) response = conn.list_phone_numbers_opted_out() - len(response['phoneNumbers']).should.be.greater_than(0) - len(response['phoneNumbers']).should.be.lower_than(current_len) + len(response["phoneNumbers"]).should.be.greater_than(0) + len(response["phoneNumbers"]).should.be.lower_than(current_len) @mock_sns def test_confirm_subscription(): - conn = boto3.client('sns', region_name='us-east-1') - response = conn.create_topic(Name='testconfirm') + conn = boto3.client("sns", region_name="us-east-1") + response = conn.create_topic(Name="testconfirm") conn.confirm_subscription( - TopicArn=response['TopicArn'], - Token='2336412f37fb687f5d51e6e241d59b68c4e583a5cee0be6f95bbf97ab8d2441cf47b99e848408adaadf4c197e65f03473d53c4ba398f6abbf38ce2e8ebf7b4ceceb2cd817959bcde1357e58a2861b05288c535822eb88cac3db04f592285249971efc6484194fc4a4586147f16916692', - AuthenticateOnUnsubscribe='true' + TopicArn=response["TopicArn"], + Token="2336412f37fb687f5d51e6e241d59b68c4e583a5cee0be6f95bbf97ab8d2441cf47b99e848408adaadf4c197e65f03473d53c4ba398f6abbf38ce2e8ebf7b4ceceb2cd817959bcde1357e58a2861b05288c535822eb88cac3db04f592285249971efc6484194fc4a4586147f16916692", + AuthenticateOnUnsubscribe="true", ) diff --git a/tests/test_sns/test_topics.py b/tests/test_sns/test_topics.py index a7d9723cc..4a5100c94 100644 --- a/tests/test_sns/test_topics.py +++ b/tests/test_sns/test_topics.py @@ -18,13 +18,12 @@ def test_create_and_delete_topic(): topics_json = conn.get_all_topics() topics = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] topics.should.have.length_of(1) - topics[0]['TopicArn'].should.equal( - "arn:aws:sns:{0}:123456789012:some-topic" - .format(conn.region.name) + topics[0]["TopicArn"].should.equal( + "arn:aws:sns:{0}:123456789012:some-topic".format(conn.region.name) ) # Delete the topic - conn.delete_topic(topics[0]['TopicArn']) + conn.delete_topic(topics[0]["TopicArn"]) # And there should now be 0 topics topics_json = conn.get_all_topics() @@ -35,29 +34,31 @@ def test_create_and_delete_topic(): @mock_sns_deprecated def test_get_missing_topic(): conn = boto.connect_sns() - conn.get_topic_attributes.when.called_with( - "a-fake-arn").should.throw(BotoServerError) + conn.get_topic_attributes.when.called_with("a-fake-arn").should.throw( + BotoServerError + ) @mock_sns_deprecated def test_create_topic_in_multiple_regions(): - for region in ['us-west-1', 'us-west-2']: + for region in ["us-west-1", "us-west-2"]: conn = boto.sns.connect_to_region(region) conn.create_topic("some-topic") - list(conn.get_all_topics()["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"]).should.have.length_of(1) + list( + conn.get_all_topics()["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + ).should.have.length_of(1) @mock_sns_deprecated def test_topic_corresponds_to_region(): - for region in ['us-east-1', 'us-west-2']: + for region in ["us-east-1", "us-west-2"]: conn = boto.sns.connect_to_region(region) conn.create_topic("some-topic") topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] - topic_arn.should.equal( - "arn:aws:sns:{0}:123456789012:some-topic".format(region)) + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] + topic_arn.should.equal("arn:aws:sns:{0}:123456789012:some-topic".format(region)) @mock_sns_deprecated @@ -66,51 +67,51 @@ def test_topic_attributes(): conn.create_topic("some-topic") topics_json = conn.get_all_topics() - topic_arn = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"][0]['TopicArn'] + topic_arn = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"][0][ + "TopicArn" + ] - attributes = conn.get_topic_attributes(topic_arn)['GetTopicAttributesResponse'][ - 'GetTopicAttributesResult']['Attributes'] + attributes = conn.get_topic_attributes(topic_arn)["GetTopicAttributesResponse"][ + "GetTopicAttributesResult" + ]["Attributes"] attributes["TopicArn"].should.equal( - "arn:aws:sns:{0}:123456789012:some-topic" - .format(conn.region.name) + "arn:aws:sns:{0}:123456789012:some-topic".format(conn.region.name) ) attributes["Owner"].should.equal(123456789012) - json.loads(attributes["Policy"]).should.equal({ - "Version": "2008-10-17", - "Id": "__default_policy_ID", - "Statement": [{ - "Effect": "Allow", - "Sid": "__default_statement_ID", - "Principal": { - "AWS": "*" - }, - "Action": [ - "SNS:GetTopicAttributes", - "SNS:SetTopicAttributes", - "SNS:AddPermission", - "SNS:RemovePermission", - "SNS:DeleteTopic", - "SNS:Subscribe", - "SNS:ListSubscriptionsByTopic", - "SNS:Publish", - "SNS:Receive", - ], - "Resource": "arn:aws:sns:us-east-1:123456789012:some-topic", - "Condition": { - "StringEquals": { - "AWS:SourceOwner": "123456789012" + json.loads(attributes["Policy"]).should.equal( + { + "Version": "2008-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Effect": "Allow", + "Sid": "__default_statement_ID", + "Principal": {"AWS": "*"}, + "Action": [ + "SNS:GetTopicAttributes", + "SNS:SetTopicAttributes", + "SNS:AddPermission", + "SNS:RemovePermission", + "SNS:DeleteTopic", + "SNS:Subscribe", + "SNS:ListSubscriptionsByTopic", + "SNS:Publish", + "SNS:Receive", + ], + "Resource": "arn:aws:sns:us-east-1:123456789012:some-topic", + "Condition": {"StringEquals": {"AWS:SourceOwner": "123456789012"}}, } - } - }] - }) + ], + } + ) attributes["DisplayName"].should.equal("") attributes["SubscriptionsPending"].should.equal(0) attributes["SubscriptionsConfirmed"].should.equal(0) attributes["SubscriptionsDeleted"].should.equal(0) attributes["DeliveryPolicy"].should.equal("") json.loads(attributes["EffectiveDeliveryPolicy"]).should.equal( - DEFAULT_EFFECTIVE_DELIVERY_POLICY) + DEFAULT_EFFECTIVE_DELIVERY_POLICY + ) # boto can't handle prefix-mandatory strings: # i.e. unicode on Python 2 -- u"foobar" @@ -120,19 +121,21 @@ def test_topic_attributes(): displayname = b"My display name" delivery = {b"http": {b"defaultHealthyRetryPolicy": {b"numRetries": 5}}} else: - policy = json.dumps({u"foo": u"bar"}) - displayname = u"My display name" - delivery = {u"http": {u"defaultHealthyRetryPolicy": {u"numRetries": 5}}} + policy = json.dumps({"foo": "bar"}) + displayname = "My display name" + delivery = {"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}} conn.set_topic_attributes(topic_arn, "Policy", policy) conn.set_topic_attributes(topic_arn, "DisplayName", displayname) conn.set_topic_attributes(topic_arn, "DeliveryPolicy", delivery) - attributes = conn.get_topic_attributes(topic_arn)['GetTopicAttributesResponse'][ - 'GetTopicAttributesResult']['Attributes'] + attributes = conn.get_topic_attributes(topic_arn)["GetTopicAttributesResponse"][ + "GetTopicAttributesResult" + ]["Attributes"] attributes["Policy"].should.equal('{"foo": "bar"}') attributes["DisplayName"].should.equal("My display name") attributes["DeliveryPolicy"].should.equal( - "{'http': {'defaultHealthyRetryPolicy': {'numRetries': 5}}}") + "{'http': {'defaultHealthyRetryPolicy': {'numRetries': 5}}}" + ) @mock_sns_deprecated @@ -142,19 +145,15 @@ def test_topic_paging(): conn.create_topic("some-topic_" + str(index)) topics_json = conn.get_all_topics() - topics_list = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"] - next_token = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["NextToken"] + topics_list = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + next_token = topics_json["ListTopicsResponse"]["ListTopicsResult"]["NextToken"] len(topics_list).should.equal(DEFAULT_PAGE_SIZE) next_token.should.equal(DEFAULT_PAGE_SIZE) topics_json = conn.get_all_topics(next_token=next_token) - topics_list = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["Topics"] - next_token = topics_json["ListTopicsResponse"][ - "ListTopicsResult"]["NextToken"] + topics_list = topics_json["ListTopicsResponse"]["ListTopicsResult"]["Topics"] + next_token = topics_json["ListTopicsResponse"]["ListTopicsResult"]["NextToken"] topics_list.should.have.length_of(int(DEFAULT_PAGE_SIZE / 2)) next_token.should.equal(None) diff --git a/tests/test_sns/test_topics_boto3.py b/tests/test_sns/test_topics_boto3.py index de7bb10cc..e4c9d303f 100644 --- a/tests/test_sns/test_topics_boto3.py +++ b/tests/test_sns/test_topics_boto3.py @@ -13,19 +13,20 @@ from moto.sns.models import DEFAULT_EFFECTIVE_DELIVERY_POLICY, DEFAULT_PAGE_SIZE @mock_sns def test_create_and_delete_topic(): conn = boto3.client("sns", region_name="us-east-1") - for topic_name in ('some-topic', '-some-topic-', '_some-topic_', 'a' * 256): + for topic_name in ("some-topic", "-some-topic-", "_some-topic_", "a" * 256): conn.create_topic(Name=topic_name) topics_json = conn.list_topics() topics = topics_json["Topics"] topics.should.have.length_of(1) - topics[0]['TopicArn'].should.equal( - "arn:aws:sns:{0}:123456789012:{1}" - .format(conn._client_config.region_name, topic_name) + topics[0]["TopicArn"].should.equal( + "arn:aws:sns:{0}:123456789012:{1}".format( + conn._client_config.region_name, topic_name + ) ) # Delete the topic - conn.delete_topic(TopicArn=topics[0]['TopicArn']) + conn.delete_topic(TopicArn=topics[0]["TopicArn"]) # And there should now be 0 topics topics_json = conn.list_topics() @@ -36,96 +37,89 @@ def test_create_and_delete_topic(): @mock_sns def test_create_topic_with_attributes(): conn = boto3.client("sns", region_name="us-east-1") - conn.create_topic(Name='some-topic-with-attribute', Attributes={'DisplayName': 'test-topic'}) + conn.create_topic( + Name="some-topic-with-attribute", Attributes={"DisplayName": "test-topic"} + ) topics_json = conn.list_topics() - topic_arn = topics_json["Topics"][0]['TopicArn'] + topic_arn = topics_json["Topics"][0]["TopicArn"] - attributes = conn.get_topic_attributes(TopicArn=topic_arn)['Attributes'] - attributes['DisplayName'].should.equal('test-topic') + attributes = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"] + attributes["DisplayName"].should.equal("test-topic") @mock_sns def test_create_topic_with_tags(): conn = boto3.client("sns", region_name="us-east-1") response = conn.create_topic( - Name='some-topic-with-tags', + Name="some-topic-with-tags", Tags=[ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_1' - }, - { - 'Key': 'tag_key_2', - 'Value': 'tag_value_2' - } + {"Key": "tag_key_1", "Value": "tag_value_1"}, + {"Key": "tag_key_2", "Value": "tag_value_2"}, + ], + ) + topic_arn = response["TopicArn"] + + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [ + {"Key": "tag_key_1", "Value": "tag_value_1"}, + {"Key": "tag_key_2", "Value": "tag_value_2"}, ] ) - topic_arn = response['TopicArn'] - - conn.list_tags_for_resource(ResourceArn=topic_arn)['Tags'].should.equal([ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_1' - }, - { - 'Key': 'tag_key_2', - 'Value': 'tag_value_2' - } - ]) @mock_sns def test_create_topic_should_be_indempodent(): conn = boto3.client("sns", region_name="us-east-1") - topic_arn = conn.create_topic(Name="some-topic")['TopicArn'] + topic_arn = conn.create_topic(Name="some-topic")["TopicArn"] conn.set_topic_attributes( - TopicArn=topic_arn, - AttributeName="DisplayName", - AttributeValue="should_be_set" + TopicArn=topic_arn, AttributeName="DisplayName", AttributeValue="should_be_set" ) - topic_display_name = conn.get_topic_attributes( - TopicArn=topic_arn - )['Attributes']['DisplayName'] + topic_display_name = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"][ + "DisplayName" + ] topic_display_name.should.be.equal("should_be_set") - #recreate topic to prove indempodentcy - topic_arn = conn.create_topic(Name="some-topic")['TopicArn'] - topic_display_name = conn.get_topic_attributes( - TopicArn=topic_arn - )['Attributes']['DisplayName'] + # recreate topic to prove indempodentcy + topic_arn = conn.create_topic(Name="some-topic")["TopicArn"] + topic_display_name = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"][ + "DisplayName" + ] topic_display_name.should.be.equal("should_be_set") + @mock_sns def test_get_missing_topic(): conn = boto3.client("sns", region_name="us-east-1") - conn.get_topic_attributes.when.called_with( - TopicArn="a-fake-arn").should.throw(ClientError) + conn.get_topic_attributes.when.called_with(TopicArn="a-fake-arn").should.throw( + ClientError + ) + @mock_sns def test_create_topic_must_meet_constraints(): conn = boto3.client("sns", region_name="us-east-1") - common_random_chars = [':', ";", "!", "@", "|", "^", "%"] + common_random_chars = [":", ";", "!", "@", "|", "^", "%"] for char in common_random_chars: - conn.create_topic.when.called_with( - Name="no%s_invalidchar" % char).should.throw(ClientError) - conn.create_topic.when.called_with( - Name="no spaces allowed").should.throw(ClientError) + conn.create_topic.when.called_with(Name="no%s_invalidchar" % char).should.throw( + ClientError + ) + conn.create_topic.when.called_with(Name="no spaces allowed").should.throw( + ClientError + ) @mock_sns def test_create_topic_should_be_of_certain_length(): conn = boto3.client("sns", region_name="us-east-1") too_short = "" - conn.create_topic.when.called_with( - Name=too_short).should.throw(ClientError) + conn.create_topic.when.called_with(Name=too_short).should.throw(ClientError) too_long = "x" * 257 - conn.create_topic.when.called_with( - Name=too_long).should.throw(ClientError) + conn.create_topic.when.called_with(Name=too_long).should.throw(ClientError) @mock_sns def test_create_topic_in_multiple_regions(): - for region in ['us-west-1', 'us-west-2']: + for region in ["us-west-1", "us-west-2"]: conn = boto3.client("sns", region_name=region) conn.create_topic(Name="some-topic") list(conn.list_topics()["Topics"]).should.have.length_of(1) @@ -133,13 +127,12 @@ def test_create_topic_in_multiple_regions(): @mock_sns def test_topic_corresponds_to_region(): - for region in ['us-east-1', 'us-west-2']: + for region in ["us-east-1", "us-west-2"]: conn = boto3.client("sns", region_name=region) conn.create_topic(Name="some-topic") topics_json = conn.list_topics() - topic_arn = topics_json["Topics"][0]['TopicArn'] - topic_arn.should.equal( - "arn:aws:sns:{0}:123456789012:some-topic".format(region)) + topic_arn = topics_json["Topics"][0]["TopicArn"] + topic_arn.should.equal("arn:aws:sns:{0}:123456789012:some-topic".format(region)) @mock_sns @@ -148,49 +141,49 @@ def test_topic_attributes(): conn.create_topic(Name="some-topic") topics_json = conn.list_topics() - topic_arn = topics_json["Topics"][0]['TopicArn'] + topic_arn = topics_json["Topics"][0]["TopicArn"] - attributes = conn.get_topic_attributes(TopicArn=topic_arn)['Attributes'] + attributes = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"] attributes["TopicArn"].should.equal( - "arn:aws:sns:{0}:123456789012:some-topic" - .format(conn._client_config.region_name) + "arn:aws:sns:{0}:123456789012:some-topic".format( + conn._client_config.region_name + ) ) - attributes["Owner"].should.equal('123456789012') - json.loads(attributes["Policy"]).should.equal({ - "Version": "2008-10-17", - "Id": "__default_policy_ID", - "Statement": [{ - "Effect": "Allow", - "Sid": "__default_statement_ID", - "Principal": { - "AWS": "*" - }, - "Action": [ - "SNS:GetTopicAttributes", - "SNS:SetTopicAttributes", - "SNS:AddPermission", - "SNS:RemovePermission", - "SNS:DeleteTopic", - "SNS:Subscribe", - "SNS:ListSubscriptionsByTopic", - "SNS:Publish", - "SNS:Receive", - ], - "Resource": "arn:aws:sns:us-east-1:123456789012:some-topic", - "Condition": { - "StringEquals": { - "AWS:SourceOwner": "123456789012" + attributes["Owner"].should.equal("123456789012") + json.loads(attributes["Policy"]).should.equal( + { + "Version": "2008-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Effect": "Allow", + "Sid": "__default_statement_ID", + "Principal": {"AWS": "*"}, + "Action": [ + "SNS:GetTopicAttributes", + "SNS:SetTopicAttributes", + "SNS:AddPermission", + "SNS:RemovePermission", + "SNS:DeleteTopic", + "SNS:Subscribe", + "SNS:ListSubscriptionsByTopic", + "SNS:Publish", + "SNS:Receive", + ], + "Resource": "arn:aws:sns:us-east-1:123456789012:some-topic", + "Condition": {"StringEquals": {"AWS:SourceOwner": "123456789012"}}, } - } - }] - }) + ], + } + ) attributes["DisplayName"].should.equal("") - attributes["SubscriptionsPending"].should.equal('0') - attributes["SubscriptionsConfirmed"].should.equal('0') - attributes["SubscriptionsDeleted"].should.equal('0') + attributes["SubscriptionsPending"].should.equal("0") + attributes["SubscriptionsConfirmed"].should.equal("0") + attributes["SubscriptionsDeleted"].should.equal("0") attributes["DeliveryPolicy"].should.equal("") json.loads(attributes["EffectiveDeliveryPolicy"]).should.equal( - DEFAULT_EFFECTIVE_DELIVERY_POLICY) + DEFAULT_EFFECTIVE_DELIVERY_POLICY + ) # boto can't handle prefix-mandatory strings: # i.e. unicode on Python 2 -- u"foobar" @@ -199,27 +192,30 @@ def test_topic_attributes(): policy = json.dumps({b"foo": b"bar"}) displayname = b"My display name" delivery = json.dumps( - {b"http": {b"defaultHealthyRetryPolicy": {b"numRetries": 5}}}) + {b"http": {b"defaultHealthyRetryPolicy": {b"numRetries": 5}}} + ) else: - policy = json.dumps({u"foo": u"bar"}) - displayname = u"My display name" + policy = json.dumps({"foo": "bar"}) + displayname = "My display name" delivery = json.dumps( - {u"http": {u"defaultHealthyRetryPolicy": {u"numRetries": 5}}}) - conn.set_topic_attributes(TopicArn=topic_arn, - AttributeName="Policy", - AttributeValue=policy) - conn.set_topic_attributes(TopicArn=topic_arn, - AttributeName="DisplayName", - AttributeValue=displayname) - conn.set_topic_attributes(TopicArn=topic_arn, - AttributeName="DeliveryPolicy", - AttributeValue=delivery) + {"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}} + ) + conn.set_topic_attributes( + TopicArn=topic_arn, AttributeName="Policy", AttributeValue=policy + ) + conn.set_topic_attributes( + TopicArn=topic_arn, AttributeName="DisplayName", AttributeValue=displayname + ) + conn.set_topic_attributes( + TopicArn=topic_arn, AttributeName="DeliveryPolicy", AttributeValue=delivery + ) - attributes = conn.get_topic_attributes(TopicArn=topic_arn)['Attributes'] + attributes = conn.get_topic_attributes(TopicArn=topic_arn)["Attributes"] attributes["Policy"].should.equal('{"foo": "bar"}') attributes["DisplayName"].should.equal("My display name") attributes["DeliveryPolicy"].should.equal( - '{"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}}') + '{"http": {"defaultHealthyRetryPolicy": {"numRetries": 5}}}' + ) @mock_sns @@ -244,389 +240,263 @@ def test_topic_paging(): @mock_sns def test_add_remove_permissions(): - client = boto3.client('sns', region_name='us-east-1') - topic_arn = client.create_topic(Name='test-permissions')['TopicArn'] + client = boto3.client("sns", region_name="us-east-1") + topic_arn = client.create_topic(Name="test-permissions")["TopicArn"] client.add_permission( TopicArn=topic_arn, - Label='test', - AWSAccountId=['999999999999'], - ActionName=['Publish'] + Label="test", + AWSAccountId=["999999999999"], + ActionName=["Publish"], ) response = client.get_topic_attributes(TopicArn=topic_arn) - json.loads(response['Attributes']['Policy']).should.equal({ - 'Version': '2008-10-17', - 'Id': '__default_policy_ID', - 'Statement': [ - { - 'Effect': 'Allow', - 'Sid': '__default_statement_ID', - 'Principal': { - 'AWS': '*' + json.loads(response["Attributes"]["Policy"]).should.equal( + { + "Version": "2008-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Effect": "Allow", + "Sid": "__default_statement_ID", + "Principal": {"AWS": "*"}, + "Action": [ + "SNS:GetTopicAttributes", + "SNS:SetTopicAttributes", + "SNS:AddPermission", + "SNS:RemovePermission", + "SNS:DeleteTopic", + "SNS:Subscribe", + "SNS:ListSubscriptionsByTopic", + "SNS:Publish", + "SNS:Receive", + ], + "Resource": "arn:aws:sns:us-east-1:123456789012:test-permissions", + "Condition": {"StringEquals": {"AWS:SourceOwner": "123456789012"}}, }, - 'Action': [ - 'SNS:GetTopicAttributes', - 'SNS:SetTopicAttributes', - 'SNS:AddPermission', - 'SNS:RemovePermission', - 'SNS:DeleteTopic', - 'SNS:Subscribe', - 'SNS:ListSubscriptionsByTopic', - 'SNS:Publish', - 'SNS:Receive', - ], - 'Resource': 'arn:aws:sns:us-east-1:123456789012:test-permissions', - 'Condition': { - 'StringEquals': { - 'AWS:SourceOwner': '123456789012' - } + { + "Sid": "test", + "Effect": "Allow", + "Principal": {"AWS": "arn:aws:iam::999999999999:root"}, + "Action": "SNS:Publish", + "Resource": "arn:aws:sns:us-east-1:123456789012:test-permissions", + }, + ], + } + ) + + client.remove_permission(TopicArn=topic_arn, Label="test") + + response = client.get_topic_attributes(TopicArn=topic_arn) + json.loads(response["Attributes"]["Policy"]).should.equal( + { + "Version": "2008-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Effect": "Allow", + "Sid": "__default_statement_ID", + "Principal": {"AWS": "*"}, + "Action": [ + "SNS:GetTopicAttributes", + "SNS:SetTopicAttributes", + "SNS:AddPermission", + "SNS:RemovePermission", + "SNS:DeleteTopic", + "SNS:Subscribe", + "SNS:ListSubscriptionsByTopic", + "SNS:Publish", + "SNS:Receive", + ], + "Resource": "arn:aws:sns:us-east-1:123456789012:test-permissions", + "Condition": {"StringEquals": {"AWS:SourceOwner": "123456789012"}}, } + ], + } + ) + + client.add_permission( + TopicArn=topic_arn, + Label="test", + AWSAccountId=["888888888888", "999999999999"], + ActionName=["Publish", "Subscribe"], + ) + + response = client.get_topic_attributes(TopicArn=topic_arn) + json.loads(response["Attributes"]["Policy"])["Statement"][1].should.equal( + { + "Sid": "test", + "Effect": "Allow", + "Principal": { + "AWS": [ + "arn:aws:iam::888888888888:root", + "arn:aws:iam::999999999999:root", + ] }, - { - 'Sid': 'test', - 'Effect': 'Allow', - 'Principal': { - 'AWS': 'arn:aws:iam::999999999999:root' - }, - 'Action': 'SNS:Publish', - 'Resource': 'arn:aws:sns:us-east-1:123456789012:test-permissions' - } - ] - }) - - client.remove_permission( - TopicArn=topic_arn, - Label='test' + "Action": ["SNS:Publish", "SNS:Subscribe"], + "Resource": "arn:aws:sns:us-east-1:123456789012:test-permissions", + } ) - response = client.get_topic_attributes(TopicArn=topic_arn) - json.loads(response['Attributes']['Policy']).should.equal({ - 'Version': '2008-10-17', - 'Id': '__default_policy_ID', - 'Statement': [ - { - 'Effect': 'Allow', - 'Sid': '__default_statement_ID', - 'Principal': { - 'AWS': '*' - }, - 'Action': [ - 'SNS:GetTopicAttributes', - 'SNS:SetTopicAttributes', - 'SNS:AddPermission', - 'SNS:RemovePermission', - 'SNS:DeleteTopic', - 'SNS:Subscribe', - 'SNS:ListSubscriptionsByTopic', - 'SNS:Publish', - 'SNS:Receive', - ], - 'Resource': 'arn:aws:sns:us-east-1:123456789012:test-permissions', - 'Condition': { - 'StringEquals': { - 'AWS:SourceOwner': '123456789012' - } - } - } - ] - }) - - client.add_permission( - TopicArn=topic_arn, - Label='test', - AWSAccountId=[ - '888888888888', - '999999999999' - ], - ActionName=[ - 'Publish', - 'Subscribe' - ] - ) - - response = client.get_topic_attributes(TopicArn=topic_arn) - json.loads(response['Attributes']['Policy'])['Statement'][1].should.equal({ - 'Sid': 'test', - 'Effect': 'Allow', - 'Principal': { - 'AWS': [ - 'arn:aws:iam::888888888888:root', - 'arn:aws:iam::999999999999:root' - ] - }, - 'Action': [ - 'SNS:Publish', - 'SNS:Subscribe' - ], - 'Resource': 'arn:aws:sns:us-east-1:123456789012:test-permissions' - }) - # deleting non existing permission should be successful - client.remove_permission( - TopicArn=topic_arn, - Label='non-existing' - ) + client.remove_permission(TopicArn=topic_arn, Label="non-existing") @mock_sns def test_add_permission_errors(): - client = boto3.client('sns', region_name='us-east-1') - topic_arn = client.create_topic(Name='test-permissions')['TopicArn'] + client = boto3.client("sns", region_name="us-east-1") + topic_arn = client.create_topic(Name="test-permissions")["TopicArn"] client.add_permission( TopicArn=topic_arn, - Label='test', - AWSAccountId=['999999999999'], - ActionName=['Publish'] + Label="test", + AWSAccountId=["999999999999"], + ActionName=["Publish"], ) client.add_permission.when.called_with( TopicArn=topic_arn, - Label='test', - AWSAccountId=['999999999999'], - ActionName=['AddPermission'] - ).should.throw( - ClientError, - 'Statement already exists' - ) + Label="test", + AWSAccountId=["999999999999"], + ActionName=["AddPermission"], + ).should.throw(ClientError, "Statement already exists") client.add_permission.when.called_with( - TopicArn=topic_arn + '-not-existing', - Label='test-2', - AWSAccountId=['999999999999'], - ActionName=['AddPermission'] - ).should.throw( - ClientError, - 'Topic does not exist' - ) + TopicArn=topic_arn + "-not-existing", + Label="test-2", + AWSAccountId=["999999999999"], + ActionName=["AddPermission"], + ).should.throw(ClientError, "Topic does not exist") client.add_permission.when.called_with( TopicArn=topic_arn, - Label='test-2', - AWSAccountId=['999999999999'], - ActionName=['NotExistingAction'] - ).should.throw( - ClientError, - 'Policy statement action out of service scope!' - ) + Label="test-2", + AWSAccountId=["999999999999"], + ActionName=["NotExistingAction"], + ).should.throw(ClientError, "Policy statement action out of service scope!") @mock_sns def test_remove_permission_errors(): - client = boto3.client('sns', region_name='us-east-1') - topic_arn = client.create_topic(Name='test-permissions')['TopicArn'] + client = boto3.client("sns", region_name="us-east-1") + topic_arn = client.create_topic(Name="test-permissions")["TopicArn"] client.add_permission( TopicArn=topic_arn, - Label='test', - AWSAccountId=['999999999999'], - ActionName=['Publish'] + Label="test", + AWSAccountId=["999999999999"], + ActionName=["Publish"], ) client.remove_permission.when.called_with( - TopicArn=topic_arn + '-not-existing', - Label='test', - ).should.throw( - ClientError, - 'Topic does not exist' - ) + TopicArn=topic_arn + "-not-existing", Label="test" + ).should.throw(ClientError, "Topic does not exist") @mock_sns def test_tag_topic(): - conn = boto3.client('sns', region_name='us-east-1') - response = conn.create_topic( - Name = 'some-topic-with-tags' - ) - topic_arn = response['TopicArn'] + conn = boto3.client("sns", region_name="us-east-1") + response = conn.create_topic(Name="some-topic-with-tags") + topic_arn = response["TopicArn"] conn.tag_resource( - ResourceArn=topic_arn, - Tags=[ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_1' - } - ] + ResourceArn=topic_arn, Tags=[{"Key": "tag_key_1", "Value": "tag_value_1"}] + ) + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [{"Key": "tag_key_1", "Value": "tag_value_1"}] ) - conn.list_tags_for_resource(ResourceArn = topic_arn)['Tags'].should.equal([ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_1' - } - ]) conn.tag_resource( - ResourceArn=topic_arn, - Tags=[ - { - 'Key': 'tag_key_2', - 'Value': 'tag_value_2' - } + ResourceArn=topic_arn, Tags=[{"Key": "tag_key_2", "Value": "tag_value_2"}] + ) + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [ + {"Key": "tag_key_1", "Value": "tag_value_1"}, + {"Key": "tag_key_2", "Value": "tag_value_2"}, ] ) - conn.list_tags_for_resource(ResourceArn = topic_arn)['Tags'].should.equal([ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_1' - }, - { - 'Key': 'tag_key_2', - 'Value': 'tag_value_2' - } - ]) conn.tag_resource( - ResourceArn = topic_arn, - Tags = [ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_X' - } + ResourceArn=topic_arn, Tags=[{"Key": "tag_key_1", "Value": "tag_value_X"}] + ) + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [ + {"Key": "tag_key_1", "Value": "tag_value_X"}, + {"Key": "tag_key_2", "Value": "tag_value_2"}, ] ) - conn.list_tags_for_resource(ResourceArn = topic_arn)['Tags'].should.equal([ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_X' - }, - { - 'Key': 'tag_key_2', - 'Value': 'tag_value_2' - } - ]) @mock_sns def test_untag_topic(): - conn = boto3.client('sns', region_name = 'us-east-1') + conn = boto3.client("sns", region_name="us-east-1") response = conn.create_topic( - Name = 'some-topic-with-tags', - Tags = [ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_1' - }, - { - 'Key': 'tag_key_2', - 'Value': 'tag_value_2' - } - ] + Name="some-topic-with-tags", + Tags=[ + {"Key": "tag_key_1", "Value": "tag_value_1"}, + {"Key": "tag_key_2", "Value": "tag_value_2"}, + ], ) - topic_arn = response['TopicArn'] + topic_arn = response["TopicArn"] - conn.untag_resource( - ResourceArn = topic_arn, - TagKeys = [ - 'tag_key_1' - ] + conn.untag_resource(ResourceArn=topic_arn, TagKeys=["tag_key_1"]) + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [{"Key": "tag_key_2", "Value": "tag_value_2"}] ) - conn.list_tags_for_resource(ResourceArn = topic_arn)['Tags'].should.equal([ - { - 'Key': 'tag_key_2', - 'Value': 'tag_value_2' - } - ]) # removing a non existing tag should not raise any error - conn.untag_resource( - ResourceArn = topic_arn, - TagKeys = [ - 'not-existing-tag' - ] + conn.untag_resource(ResourceArn=topic_arn, TagKeys=["not-existing-tag"]) + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [{"Key": "tag_key_2", "Value": "tag_value_2"}] ) - conn.list_tags_for_resource(ResourceArn = topic_arn)['Tags'].should.equal([ - { - 'Key': 'tag_key_2', - 'Value': 'tag_value_2' - } - ]) @mock_sns def test_list_tags_for_resource_error(): - conn = boto3.client('sns', region_name = 'us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic( - Name = 'some-topic-with-tags', - Tags = [ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_X' - } - ] + Name="some-topic-with-tags", Tags=[{"Key": "tag_key_1", "Value": "tag_value_X"}] ) conn.list_tags_for_resource.when.called_with( - ResourceArn = 'not-existing-topic' - ).should.throw( - ClientError, - 'Resource does not exist' - ) + ResourceArn="not-existing-topic" + ).should.throw(ClientError, "Resource does not exist") @mock_sns def test_tag_resource_errors(): - conn = boto3.client('sns', region_name = 'us-east-1') + conn = boto3.client("sns", region_name="us-east-1") response = conn.create_topic( - Name = 'some-topic-with-tags', - Tags = [ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_X' - } - ] + Name="some-topic-with-tags", Tags=[{"Key": "tag_key_1", "Value": "tag_value_X"}] ) - topic_arn = response['TopicArn'] + topic_arn = response["TopicArn"] conn.tag_resource.when.called_with( - ResourceArn = 'not-existing-topic', - Tags = [ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_1' - } - ] - ).should.throw( - ClientError, - 'Resource does not exist' - ) + ResourceArn="not-existing-topic", + Tags=[{"Key": "tag_key_1", "Value": "tag_value_1"}], + ).should.throw(ClientError, "Resource does not exist") - too_many_tags = [{'Key': 'tag_key_{}'.format(i), 'Value': 'tag_value_{}'.format(i)} for i in range(51)] + too_many_tags = [ + {"Key": "tag_key_{}".format(i), "Value": "tag_value_{}".format(i)} + for i in range(51) + ] conn.tag_resource.when.called_with( - ResourceArn = topic_arn, - Tags = too_many_tags + ResourceArn=topic_arn, Tags=too_many_tags ).should.throw( - ClientError, - 'Could not complete request: tag quota of per resource exceeded' + ClientError, "Could not complete request: tag quota of per resource exceeded" ) # when the request fails, the tags should not be updated - conn.list_tags_for_resource(ResourceArn = topic_arn)['Tags'].should.equal([ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_X' - } - ]) + conn.list_tags_for_resource(ResourceArn=topic_arn)["Tags"].should.equal( + [{"Key": "tag_key_1", "Value": "tag_value_X"}] + ) @mock_sns def test_untag_resource_error(): - conn = boto3.client('sns', region_name = 'us-east-1') + conn = boto3.client("sns", region_name="us-east-1") conn.create_topic( - Name = 'some-topic-with-tags', - Tags = [ - { - 'Key': 'tag_key_1', - 'Value': 'tag_value_X' - } - ] + Name="some-topic-with-tags", Tags=[{"Key": "tag_key_1", "Value": "tag_value_X"}] ) conn.untag_resource.when.called_with( - ResourceArn = 'not-existing-topic', - TagKeys = [ - 'tag_key_1' - ] - ).should.throw( - ClientError, - 'Resource does not exist' - ) + ResourceArn="not-existing-topic", TagKeys=["tag_key_1"] + ).should.throw(ClientError, "Resource does not exist") diff --git a/tests/test_sqs/test_server.py b/tests/test_sqs/test_server.py index e7f745fd2..0116a93ef 100644 --- a/tests/test_sqs/test_server.py +++ b/tests/test_sqs/test_server.py @@ -7,40 +7,40 @@ import time import moto.server as server -''' +""" Test the different server responses -''' +""" def test_sqs_list_identities(): backend = server.create_backend_app("sqs") test_client = backend.test_client() - res = test_client.get('/?Action=ListQueues') + res = test_client.get("/?Action=ListQueues") res.data.should.contain(b"ListQueuesResponse") # Make sure that we can receive messages from queues whose name contains dots (".") # The AWS API mandates that the names of FIFO queues use the suffix ".fifo" # See: https://github.com/spulec/moto/issues/866 - for queue_name in ('testqueue', 'otherqueue.fifo'): - - res = test_client.put('/?Action=CreateQueue&QueueName=%s' % queue_name) + for queue_name in ("testqueue", "otherqueue.fifo"): + res = test_client.put("/?Action=CreateQueue&QueueName=%s" % queue_name) res = test_client.put( - '/123/%s?MessageBody=test-message&Action=SendMessage' % queue_name) + "/123/%s?MessageBody=test-message&Action=SendMessage" % queue_name + ) res = test_client.get( - '/123/%s?Action=ReceiveMessage&MaxNumberOfMessages=1' % queue_name) + "/123/%s?Action=ReceiveMessage&MaxNumberOfMessages=1" % queue_name + ) - message = re.search("(.*?)", - res.data.decode('utf-8')).groups()[0] - message.should.equal('test-message') + message = re.search("(.*?)", res.data.decode("utf-8")).groups()[0] + message.should.equal("test-message") - res = test_client.get('/?Action=ListQueues&QueueNamePrefix=other') - res.data.should.contain(b'otherqueue.fifo') - res.data.should_not.contain(b'testqueue') + res = test_client.get("/?Action=ListQueues&QueueNamePrefix=other") + res.data.should.contain(b"otherqueue.fifo") + res.data.should_not.contain(b"testqueue") def test_messages_polling(): @@ -48,26 +48,25 @@ def test_messages_polling(): test_client = backend.test_client() messages = [] - test_client.put('/?Action=CreateQueue&QueueName=testqueue') + test_client.put("/?Action=CreateQueue&QueueName=testqueue") def insert_messages(): messages_count = 5 while messages_count > 0: test_client.put( - '/123/testqueue?MessageBody=test-message&Action=SendMessage' - '&Attribute.1.Name=WaitTimeSeconds&Attribute.1.Value=10' + "/123/testqueue?MessageBody=test-message&Action=SendMessage" + "&Attribute.1.Name=WaitTimeSeconds&Attribute.1.Value=10" ) messages_count -= 1 - time.sleep(.5) + time.sleep(0.5) def get_messages(): count = 0 while count < 5: msg_res = test_client.get( - '/123/testqueue?Action=ReceiveMessage&MaxNumberOfMessages=1&WaitTimeSeconds=5' + "/123/testqueue?Action=ReceiveMessage&MaxNumberOfMessages=1&WaitTimeSeconds=5" ) - new_msgs = re.findall("(.*?)", - msg_res.data.decode('utf-8')) + new_msgs = re.findall("(.*?)", msg_res.data.decode("utf-8")) count += len(new_msgs) messages.append(new_msgs) diff --git a/tests/test_sqs/test_sqs.py b/tests/test_sqs/test_sqs.py index a2111c9d7..2c1cdd524 100644 --- a/tests/test_sqs/test_sqs.py +++ b/tests/test_sqs/test_sqs.py @@ -26,188 +26,155 @@ from nose import SkipTest @mock_sqs def test_create_fifo_queue_fail(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") try: - sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'FifoQueue': 'true', - } - ) + sqs.create_queue(QueueName="test-queue", Attributes={"FifoQueue": "true"}) except botocore.exceptions.ClientError as err: - err.response['Error']['Code'].should.equal('InvalidParameterValue') + err.response["Error"]["Code"].should.equal("InvalidParameterValue") else: - raise RuntimeError('Should of raised InvalidParameterValue Exception') + raise RuntimeError("Should of raised InvalidParameterValue Exception") @mock_sqs def test_create_queue_with_same_attributes(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") - dlq_url = sqs.create_queue(QueueName='test-queue-dlq')['QueueUrl'] - dlq_arn = sqs.get_queue_attributes(QueueUrl=dlq_url)['Attributes']['QueueArn'] + dlq_url = sqs.create_queue(QueueName="test-queue-dlq")["QueueUrl"] + dlq_arn = sqs.get_queue_attributes(QueueUrl=dlq_url)["Attributes"]["QueueArn"] attributes = { - 'DelaySeconds': '900', - 'MaximumMessageSize': '262144', - 'MessageRetentionPeriod': '1209600', - 'ReceiveMessageWaitTimeSeconds': '20', - 'RedrivePolicy': '{"deadLetterTargetArn": "%s", "maxReceiveCount": 100}' % (dlq_arn), - 'VisibilityTimeout': '43200' + "DelaySeconds": "900", + "MaximumMessageSize": "262144", + "MessageRetentionPeriod": "1209600", + "ReceiveMessageWaitTimeSeconds": "20", + "RedrivePolicy": '{"deadLetterTargetArn": "%s", "maxReceiveCount": 100}' + % (dlq_arn), + "VisibilityTimeout": "43200", } - sqs.create_queue( - QueueName='test-queue', - Attributes=attributes - ) + sqs.create_queue(QueueName="test-queue", Attributes=attributes) - sqs.create_queue( - QueueName='test-queue', - Attributes=attributes - ) + sqs.create_queue(QueueName="test-queue", Attributes=attributes) @mock_sqs def test_create_queue_with_different_attributes_fail(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") - sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'VisibilityTimeout': '10', - } - ) + sqs.create_queue(QueueName="test-queue", Attributes={"VisibilityTimeout": "10"}) try: - sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'VisibilityTimeout': '60', - } - ) + sqs.create_queue(QueueName="test-queue", Attributes={"VisibilityTimeout": "60"}) except botocore.exceptions.ClientError as err: - err.response['Error']['Code'].should.equal('QueueAlreadyExists') + err.response["Error"]["Code"].should.equal("QueueAlreadyExists") else: - raise RuntimeError('Should of raised QueueAlreadyExists Exception') + raise RuntimeError("Should of raised QueueAlreadyExists Exception") @mock_sqs def test_create_fifo_queue(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") resp = sqs.create_queue( - QueueName='test-queue.fifo', - Attributes={ - 'FifoQueue': 'true', - } + QueueName="test-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url = resp['QueueUrl'] + queue_url = resp["QueueUrl"] response = sqs.get_queue_attributes(QueueUrl=queue_url) - response['Attributes'].should.contain('FifoQueue') - response['Attributes']['FifoQueue'].should.equal('true') + response["Attributes"].should.contain("FifoQueue") + response["Attributes"]["FifoQueue"].should.equal("true") @mock_sqs def test_create_queue(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") - new_queue = sqs.create_queue(QueueName='test-queue') + new_queue = sqs.create_queue(QueueName="test-queue") new_queue.should_not.be.none - new_queue.should.have.property('url').should.contain('test-queue') + new_queue.should.have.property("url").should.contain("test-queue") - queue = sqs.get_queue_by_name(QueueName='test-queue') - queue.attributes.get('QueueArn').should_not.be.none - queue.attributes.get('QueueArn').split(':')[-1].should.equal('test-queue') - queue.attributes.get('QueueArn').split(':')[3].should.equal('us-east-1') - queue.attributes.get('VisibilityTimeout').should_not.be.none - queue.attributes.get('VisibilityTimeout').should.equal('30') + queue = sqs.get_queue_by_name(QueueName="test-queue") + queue.attributes.get("QueueArn").should_not.be.none + queue.attributes.get("QueueArn").split(":")[-1].should.equal("test-queue") + queue.attributes.get("QueueArn").split(":")[3].should.equal("us-east-1") + queue.attributes.get("VisibilityTimeout").should_not.be.none + queue.attributes.get("VisibilityTimeout").should.equal("30") @mock_sqs def test_create_queue_kms(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") new_queue = sqs.create_queue( - QueueName='test-queue', + QueueName="test-queue", Attributes={ - 'KmsMasterKeyId': 'master-key-id', - 'KmsDataKeyReusePeriodSeconds': '600' - }) + "KmsMasterKeyId": "master-key-id", + "KmsDataKeyReusePeriodSeconds": "600", + }, + ) new_queue.should_not.be.none - queue = sqs.get_queue_by_name(QueueName='test-queue') + queue = sqs.get_queue_by_name(QueueName="test-queue") - queue.attributes.get('KmsMasterKeyId').should.equal('master-key-id') - queue.attributes.get('KmsDataKeyReusePeriodSeconds').should.equal('600') + queue.attributes.get("KmsMasterKeyId").should.equal("master-key-id") + queue.attributes.get("KmsDataKeyReusePeriodSeconds").should.equal("600") @mock_sqs def test_create_queue_with_tags(): - client = boto3.client('sqs', region_name='us-east-1') + client = boto3.client("sqs", region_name="us-east-1") response = client.create_queue( - QueueName='test-queue-with-tags', - tags={ - 'tag_key_1': 'tag_value_1' - } + QueueName="test-queue-with-tags", tags={"tag_key_1": "tag_value_1"} ) - queue_url = response['QueueUrl'] + queue_url = response["QueueUrl"] - client.list_queue_tags(QueueUrl=queue_url)['Tags'].should.equal({ - 'tag_key_1': 'tag_value_1' - }) + client.list_queue_tags(QueueUrl=queue_url)["Tags"].should.equal( + {"tag_key_1": "tag_value_1"} + ) @mock_sqs def test_get_queue_url(): - client = boto3.client('sqs', region_name='us-east-1') - client.create_queue(QueueName='test-queue') + client = boto3.client("sqs", region_name="us-east-1") + client.create_queue(QueueName="test-queue") - response = client.get_queue_url(QueueName='test-queue') + response = client.get_queue_url(QueueName="test-queue") - response.should.have.key('QueueUrl').which.should.contain('test-queue') + response.should.have.key("QueueUrl").which.should.contain("test-queue") @mock_sqs def test_get_queue_url_errors(): - client = boto3.client('sqs', region_name='us-east-1') + client = boto3.client("sqs", region_name="us-east-1") - client.get_queue_url.when.called_with( - QueueName='non-existing-queue' - ).should.throw( - ClientError, - 'The specified queue does not exist for this wsdl version.' + client.get_queue_url.when.called_with(QueueName="non-existing-queue").should.throw( + ClientError, "The specified queue does not exist for this wsdl version." ) @mock_sqs def test_get_nonexistent_queue(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") with assert_raises(ClientError) as err: - sqs.get_queue_by_name(QueueName='nonexisting-queue') + sqs.get_queue_by_name(QueueName="nonexisting-queue") ex = err.exception - ex.operation_name.should.equal('GetQueueUrl') - ex.response['Error']['Code'].should.equal( - 'AWS.SimpleQueueService.NonExistentQueue') + ex.operation_name.should.equal("GetQueueUrl") + ex.response["Error"]["Code"].should.equal("AWS.SimpleQueueService.NonExistentQueue") with assert_raises(ClientError) as err: - sqs.Queue('http://whatever-incorrect-queue-address').load() + sqs.Queue("http://whatever-incorrect-queue-address").load() ex = err.exception - ex.operation_name.should.equal('GetQueueAttributes') - ex.response['Error']['Code'].should.equal( - 'AWS.SimpleQueueService.NonExistentQueue') + ex.operation_name.should.equal("GetQueueAttributes") + ex.response["Error"]["Code"].should.equal("AWS.SimpleQueueService.NonExistentQueue") @mock_sqs def test_message_send_without_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") - msg = queue.send_message( - MessageBody="derp" - ) - msg.get('MD5OfMessageBody').should.equal( - '58fd9edd83341c29f1aebba81c31e257') - msg.shouldnt.have.key('MD5OfMessageAttributes') - msg.get('MessageId').should_not.contain(' \n') + msg = queue.send_message(MessageBody="derp") + msg.get("MD5OfMessageBody").should.equal("58fd9edd83341c29f1aebba81c31e257") + msg.shouldnt.have.key("MD5OfMessageAttributes") + msg.get("MessageId").should_not.contain(" \n") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -215,22 +182,17 @@ def test_message_send_without_attributes(): @mock_sqs def test_message_send_with_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") msg = queue.send_message( MessageBody="derp", MessageAttributes={ - 'timestamp': { - 'StringValue': '1493147359900', - 'DataType': 'Number', - } - } + "timestamp": {"StringValue": "1493147359900", "DataType": "Number"} + }, ) - msg.get('MD5OfMessageBody').should.equal( - '58fd9edd83341c29f1aebba81c31e257') - msg.get('MD5OfMessageAttributes').should.equal( - '235c5c510d26fb653d073faed50ae77c') - msg.get('MessageId').should_not.contain(' \n') + msg.get("MD5OfMessageBody").should.equal("58fd9edd83341c29f1aebba81c31e257") + msg.get("MD5OfMessageAttributes").should.equal("235c5c510d26fb653d073faed50ae77c") + msg.get("MessageId").should_not.contain(" \n") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -238,22 +200,20 @@ def test_message_send_with_attributes(): @mock_sqs def test_message_with_complex_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") msg = queue.send_message( MessageBody="derp", MessageAttributes={ - 'ccc': {'StringValue': 'testjunk', 'DataType': 'String'}, - 'aaa': {'BinaryValue': b'\x02\x03\x04', 'DataType': 'Binary'}, - 'zzz': {'DataType': 'Number', 'StringValue': '0230.01'}, - 'öther_encodings': {'DataType': 'String', 'StringValue': 'T\xFCst'} - } + "ccc": {"StringValue": "testjunk", "DataType": "String"}, + "aaa": {"BinaryValue": b"\x02\x03\x04", "DataType": "Binary"}, + "zzz": {"DataType": "Number", "StringValue": "0230.01"}, + "öther_encodings": {"DataType": "String", "StringValue": "T\xFCst"}, + }, ) - msg.get('MD5OfMessageBody').should.equal( - '58fd9edd83341c29f1aebba81c31e257') - msg.get('MD5OfMessageAttributes').should.equal( - '8ae21a7957029ef04146b42aeaa18a22') - msg.get('MessageId').should_not.contain(' \n') + msg.get("MD5OfMessageBody").should.equal("58fd9edd83341c29f1aebba81c31e257") + msg.get("MD5OfMessageAttributes").should.equal("8ae21a7957029ef04146b42aeaa18a22") + msg.get("MessageId").should_not.contain(" \n") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -261,9 +221,10 @@ def test_message_with_complex_attributes(): @mock_sqs def test_send_message_with_message_group_id(): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName="test-group-id.fifo", - Attributes={'FifoQueue': 'true'}) + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="test-group-id.fifo", Attributes={"FifoQueue": "true"} + ) sent = queue.send_message( MessageBody="mydata", @@ -275,17 +236,17 @@ def test_send_message_with_message_group_id(): messages.should.have.length_of(1) message_attributes = messages[0].attributes - message_attributes.should.contain('MessageGroupId') - message_attributes['MessageGroupId'].should.equal('group_id_1') - message_attributes.should.contain('MessageDeduplicationId') - message_attributes['MessageDeduplicationId'].should.equal('dedupe_id_1') + message_attributes.should.contain("MessageGroupId") + message_attributes["MessageGroupId"].should.equal("group_id_1") + message_attributes.should.contain("MessageDeduplicationId") + message_attributes["MessageDeduplicationId"].should.equal("dedupe_id_1") @mock_sqs def test_send_message_with_unicode_characters(): - body_one = 'Héllo!😀' + body_one = "Héllo!😀" - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") msg = queue.send_message(MessageBody=body_one) @@ -297,68 +258,69 @@ def test_send_message_with_unicode_characters(): @mock_sqs def test_set_queue_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") - queue.attributes['VisibilityTimeout'].should.equal("30") + queue.attributes["VisibilityTimeout"].should.equal("30") queue.set_attributes(Attributes={"VisibilityTimeout": "45"}) - queue.attributes['VisibilityTimeout'].should.equal("45") + queue.attributes["VisibilityTimeout"].should.equal("45") @mock_sqs def test_create_queues_in_multiple_region(): - west1_conn = boto3.client('sqs', region_name='us-west-1') + west1_conn = boto3.client("sqs", region_name="us-west-1") west1_conn.create_queue(QueueName="blah") - west2_conn = boto3.client('sqs', region_name='us-west-2') + west2_conn = boto3.client("sqs", region_name="us-west-2") west2_conn.create_queue(QueueName="test-queue") - list(west1_conn.list_queues()['QueueUrls']).should.have.length_of(1) - list(west2_conn.list_queues()['QueueUrls']).should.have.length_of(1) + list(west1_conn.list_queues()["QueueUrls"]).should.have.length_of(1) + list(west2_conn.list_queues()["QueueUrls"]).should.have.length_of(1) if settings.TEST_SERVER_MODE: - base_url = 'http://localhost:5000' + base_url = "http://localhost:5000" else: - base_url = 'https://us-west-1.queue.amazonaws.com' + base_url = "https://us-west-1.queue.amazonaws.com" - west1_conn.list_queues()['QueueUrls'][0].should.equal( - '{base_url}/123456789012/blah'.format(base_url=base_url)) + west1_conn.list_queues()["QueueUrls"][0].should.equal( + "{base_url}/123456789012/blah".format(base_url=base_url) + ) @mock_sqs def test_get_queue_with_prefix(): - conn = boto3.client("sqs", region_name='us-west-1') + conn = boto3.client("sqs", region_name="us-west-1") conn.create_queue(QueueName="prefixa-queue") conn.create_queue(QueueName="prefixb-queue") conn.create_queue(QueueName="test-queue") - conn.list_queues()['QueueUrls'].should.have.length_of(3) + conn.list_queues()["QueueUrls"].should.have.length_of(3) - queue = conn.list_queues(QueueNamePrefix="test-")['QueueUrls'] + queue = conn.list_queues(QueueNamePrefix="test-")["QueueUrls"] queue.should.have.length_of(1) if settings.TEST_SERVER_MODE: - base_url = 'http://localhost:5000' + base_url = "http://localhost:5000" else: - base_url = 'https://us-west-1.queue.amazonaws.com' + base_url = "https://us-west-1.queue.amazonaws.com" queue[0].should.equal( - "{base_url}/123456789012/test-queue".format(base_url=base_url)) + "{base_url}/123456789012/test-queue".format(base_url=base_url) + ) @mock_sqs def test_delete_queue(): - sqs = boto3.resource('sqs', region_name='us-east-1') - conn = boto3.client("sqs", region_name='us-east-1') - conn.create_queue(QueueName="test-queue", - Attributes={"VisibilityTimeout": "3"}) - queue = sqs.Queue('test-queue') + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") + conn.create_queue(QueueName="test-queue", Attributes={"VisibilityTimeout": "3"}) + queue = sqs.Queue("test-queue") - conn.list_queues()['QueueUrls'].should.have.length_of(1) + conn.list_queues()["QueueUrls"].should.have.length_of(1) queue.delete() - conn.list_queues().get('QueueUrls').should.equal(None) + conn.list_queues().get("QueueUrls").should.equal(None) with assert_raises(botocore.exceptions.ClientError): queue.delete() @@ -366,196 +328,181 @@ def test_delete_queue(): @mock_sqs def test_get_queue_attributes(): - client = boto3.client('sqs', region_name='us-east-1') - response = client.create_queue(QueueName='test-queue') - queue_url = response['QueueUrl'] + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] response = client.get_queue_attributes(QueueUrl=queue_url) - response['Attributes']['ApproximateNumberOfMessages'].should.equal('0') - response['Attributes']['ApproximateNumberOfMessagesDelayed'].should.equal('0') - response['Attributes']['ApproximateNumberOfMessagesNotVisible'].should.equal('0') - response['Attributes']['CreatedTimestamp'].should.be.a(six.string_types) - response['Attributes']['DelaySeconds'].should.equal('0') - response['Attributes']['LastModifiedTimestamp'].should.be.a(six.string_types) - response['Attributes']['MaximumMessageSize'].should.equal('65536') - response['Attributes']['MessageRetentionPeriod'].should.equal('345600') - response['Attributes']['QueueArn'].should.equal('arn:aws:sqs:us-east-1:123456789012:test-queue') - response['Attributes']['ReceiveMessageWaitTimeSeconds'].should.equal('0') - response['Attributes']['VisibilityTimeout'].should.equal('30') + response["Attributes"]["ApproximateNumberOfMessages"].should.equal("0") + response["Attributes"]["ApproximateNumberOfMessagesDelayed"].should.equal("0") + response["Attributes"]["ApproximateNumberOfMessagesNotVisible"].should.equal("0") + response["Attributes"]["CreatedTimestamp"].should.be.a(six.string_types) + response["Attributes"]["DelaySeconds"].should.equal("0") + response["Attributes"]["LastModifiedTimestamp"].should.be.a(six.string_types) + response["Attributes"]["MaximumMessageSize"].should.equal("65536") + response["Attributes"]["MessageRetentionPeriod"].should.equal("345600") + response["Attributes"]["QueueArn"].should.equal( + "arn:aws:sqs:us-east-1:123456789012:test-queue" + ) + response["Attributes"]["ReceiveMessageWaitTimeSeconds"].should.equal("0") + response["Attributes"]["VisibilityTimeout"].should.equal("30") response = client.get_queue_attributes( QueueUrl=queue_url, AttributeNames=[ - 'ApproximateNumberOfMessages', - 'MaximumMessageSize', - 'QueueArn', - 'VisibilityTimeout' - ] + "ApproximateNumberOfMessages", + "MaximumMessageSize", + "QueueArn", + "VisibilityTimeout", + ], ) - response['Attributes'].should.equal({ - 'ApproximateNumberOfMessages': '0', - 'MaximumMessageSize': '65536', - 'QueueArn': 'arn:aws:sqs:us-east-1:123456789012:test-queue', - 'VisibilityTimeout': '30' - }) + response["Attributes"].should.equal( + { + "ApproximateNumberOfMessages": "0", + "MaximumMessageSize": "65536", + "QueueArn": "arn:aws:sqs:us-east-1:123456789012:test-queue", + "VisibilityTimeout": "30", + } + ) # should not return any attributes, if it was not set before response = client.get_queue_attributes( - QueueUrl=queue_url, - AttributeNames=[ - 'KmsMasterKeyId' - ] + QueueUrl=queue_url, AttributeNames=["KmsMasterKeyId"] ) - response.should_not.have.key('Attributes') + response.should_not.have.key("Attributes") @mock_sqs def test_get_queue_attributes_errors(): - client = boto3.client('sqs', region_name='us-east-1') - response = client.create_queue(QueueName='test-queue') - queue_url = response['QueueUrl'] + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] client.get_queue_attributes.when.called_with( - QueueUrl=queue_url + '-non-existing' + QueueUrl=queue_url + "-non-existing" ).should.throw( - ClientError, - 'The specified queue does not exist for this wsdl version.' + ClientError, "The specified queue does not exist for this wsdl version." ) client.get_queue_attributes.when.called_with( QueueUrl=queue_url, - AttributeNames=[ - 'QueueArn', - 'not-existing', - 'VisibilityTimeout' - ] - ).should.throw( - ClientError, - 'Unknown Attribute not-existing.' - ) + AttributeNames=["QueueArn", "not-existing", "VisibilityTimeout"], + ).should.throw(ClientError, "Unknown Attribute not-existing.") client.get_queue_attributes.when.called_with( - QueueUrl=queue_url, - AttributeNames=[ - '' - ] - ).should.throw( - ClientError, - 'Unknown Attribute .' - ) + QueueUrl=queue_url, AttributeNames=[""] + ).should.throw(ClientError, "Unknown Attribute .") client.get_queue_attributes.when.called_with( - QueueUrl = queue_url, - AttributeNames = [] - ).should.throw( - ClientError, - 'Unknown Attribute .' - ) + QueueUrl=queue_url, AttributeNames=[] + ).should.throw(ClientError, "Unknown Attribute .") @mock_sqs def test_set_queue_attribute(): - sqs = boto3.resource('sqs', region_name='us-east-1') - conn = boto3.client("sqs", region_name='us-east-1') - conn.create_queue(QueueName="test-queue", - Attributes={"VisibilityTimeout": '3'}) + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") + conn.create_queue(QueueName="test-queue", Attributes={"VisibilityTimeout": "3"}) queue = sqs.Queue("test-queue") - queue.attributes['VisibilityTimeout'].should.equal('3') + queue.attributes["VisibilityTimeout"].should.equal("3") - queue.set_attributes(Attributes={"VisibilityTimeout": '45'}) + queue.set_attributes(Attributes={"VisibilityTimeout": "45"}) queue = sqs.Queue("test-queue") - queue.attributes['VisibilityTimeout'].should.equal('45') + queue.attributes["VisibilityTimeout"].should.equal("45") @mock_sqs def test_send_receive_message_without_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') - conn = boto3.client("sqs", region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") conn.create_queue(QueueName="test-queue") queue = sqs.Queue("test-queue") - body_one = 'this is a test message' - body_two = 'this is another test message' + body_one = "this is a test message" + body_two = "this is another test message" queue.send_message(MessageBody=body_one) queue.send_message(MessageBody=body_two) - messages = conn.receive_message( - QueueUrl=queue.url, MaxNumberOfMessages=2)['Messages'] + messages = conn.receive_message(QueueUrl=queue.url, MaxNumberOfMessages=2)[ + "Messages" + ] message1 = messages[0] message2 = messages[1] - message1['Body'].should.equal(body_one) - message2['Body'].should.equal(body_two) + message1["Body"].should.equal(body_one) + message2["Body"].should.equal(body_two) - message1.shouldnt.have.key('MD5OfMessageAttributes') - message2.shouldnt.have.key('MD5OfMessageAttributes') + message1.shouldnt.have.key("MD5OfMessageAttributes") + message2.shouldnt.have.key("MD5OfMessageAttributes") @mock_sqs def test_send_receive_message_with_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') - conn = boto3.client("sqs", region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") conn.create_queue(QueueName="test-queue") queue = sqs.Queue("test-queue") - body_one = 'this is a test message' - body_two = 'this is another test message' + body_one = "this is a test message" + body_two = "this is another test message" queue.send_message( MessageBody=body_one, MessageAttributes={ - 'timestamp': { - 'StringValue': '1493147359900', - 'DataType': 'Number', - } - } + "timestamp": {"StringValue": "1493147359900", "DataType": "Number"} + }, ) queue.send_message( MessageBody=body_two, MessageAttributes={ - 'timestamp': { - 'StringValue': '1493147359901', - 'DataType': 'Number', - } - } + "timestamp": {"StringValue": "1493147359901", "DataType": "Number"} + }, ) - messages = conn.receive_message( - QueueUrl=queue.url, MaxNumberOfMessages=2)['Messages'] + messages = conn.receive_message(QueueUrl=queue.url, MaxNumberOfMessages=2)[ + "Messages" + ] message1 = messages[0] message2 = messages[1] - message1.get('Body').should.equal(body_one) - message2.get('Body').should.equal(body_two) + message1.get("Body").should.equal(body_one) + message2.get("Body").should.equal(body_two) - message1.get('MD5OfMessageAttributes').should.equal('235c5c510d26fb653d073faed50ae77c') - message2.get('MD5OfMessageAttributes').should.equal('994258b45346a2cc3f9cbb611aa7af30') + message1.get("MD5OfMessageAttributes").should.equal( + "235c5c510d26fb653d073faed50ae77c" + ) + message2.get("MD5OfMessageAttributes").should.equal( + "994258b45346a2cc3f9cbb611aa7af30" + ) @mock_sqs def test_send_receive_message_timestamps(): - sqs = boto3.resource('sqs', region_name='us-east-1') - conn = boto3.client("sqs", region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") + conn = boto3.client("sqs", region_name="us-east-1") conn.create_queue(QueueName="test-queue") queue = sqs.Queue("test-queue") response = queue.send_message(MessageBody="derp") - assert response['ResponseMetadata']['RequestId'] + assert response["ResponseMetadata"]["RequestId"] - messages = conn.receive_message( - QueueUrl=queue.url, MaxNumberOfMessages=1)['Messages'] + messages = conn.receive_message(QueueUrl=queue.url, MaxNumberOfMessages=1)[ + "Messages" + ] message = messages[0] - sent_timestamp = message.get('Attributes').get('SentTimestamp') - approximate_first_receive_timestamp = message.get('Attributes').get('ApproximateFirstReceiveTimestamp') + sent_timestamp = message.get("Attributes").get("SentTimestamp") + approximate_first_receive_timestamp = message.get("Attributes").get( + "ApproximateFirstReceiveTimestamp" + ) int.when.called_with(sent_timestamp).shouldnt.throw(ValueError) int.when.called_with(approximate_first_receive_timestamp).shouldnt.throw(ValueError) @@ -563,8 +510,8 @@ def test_send_receive_message_timestamps(): @mock_sqs def test_max_number_of_messages_invalid_param(): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName='test-queue') + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue(QueueName="test-queue") with assert_raises(ClientError): queue.receive_messages(MaxNumberOfMessages=11) @@ -578,8 +525,8 @@ def test_max_number_of_messages_invalid_param(): @mock_sqs def test_wait_time_seconds_invalid_param(): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName='test-queue') + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue(QueueName="test-queue") with assert_raises(ClientError): queue.receive_messages(WaitTimeSeconds=-1) @@ -599,7 +546,7 @@ def test_receive_messages_with_wait_seconds_timeout_of_zero(): :return: """ - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue(QueueName="blah") messages = queue.receive_messages(WaitTimeSeconds=0) @@ -608,11 +555,11 @@ def test_receive_messages_with_wait_seconds_timeout_of_zero(): @mock_sqs_deprecated def test_send_message_with_xml_characters(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - body_one = '< & >' + body_one = "< & >" queue.write(queue.new_message(body_one)) @@ -624,17 +571,23 @@ def test_send_message_with_xml_characters(): @requires_boto_gte("2.28") @mock_sqs_deprecated def test_send_message_with_attributes(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - body = 'this is a test message' + body = "this is a test message" message = queue.new_message(body) - BASE64_BINARY = base64.b64encode(b'binary value').decode('utf-8') + BASE64_BINARY = base64.b64encode(b"binary value").decode("utf-8") message_attributes = { - 'test.attribute_name': {'data_type': 'String', 'string_value': 'attribute value'}, - 'test.binary_attribute': {'data_type': 'Binary', 'binary_value': BASE64_BINARY}, - 'test.number_attribute': {'data_type': 'Number', 'string_value': 'string value'} + "test.attribute_name": { + "data_type": "String", + "string_value": "attribute value", + }, + "test.binary_attribute": {"data_type": "Binary", "binary_value": BASE64_BINARY}, + "test.number_attribute": { + "data_type": "Number", + "string_value": "string value", + }, } message.message_attributes = message_attributes @@ -650,12 +603,12 @@ def test_send_message_with_attributes(): @mock_sqs_deprecated def test_send_message_with_delay(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - body_one = 'this is a test message' - body_two = 'this is another test message' + body_one = "this is a test message" + body_two = "this is another test message" queue.write(queue.new_message(body_one), delay_seconds=3) queue.write(queue.new_message(body_two)) @@ -671,11 +624,11 @@ def test_send_message_with_delay(): @mock_sqs_deprecated def test_send_large_message_fails(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - body_one = 'test message' * 200000 + body_one = "test message" * 200000 huge_message = queue.new_message(body_one) queue.write.when.called_with(huge_message).should.throw(SQSError) @@ -683,11 +636,11 @@ def test_send_large_message_fails(): @mock_sqs_deprecated def test_message_becomes_inflight_when_received(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=2) queue.set_message_class(RawMessage) - body_one = 'this is a test message' + body_one = "this is a test message" queue.write(queue.new_message(body_one)) queue.count().should.equal(1) @@ -704,16 +657,15 @@ def test_message_becomes_inflight_when_received(): @mock_sqs_deprecated def test_receive_message_with_explicit_visibility_timeout(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - body_one = 'this is another test message' + body_one = "this is another test message" queue.write(queue.new_message(body_one)) queue.count().should.equal(1) - messages = conn.receive_message( - queue, number_messages=1, visibility_timeout=0) + messages = conn.receive_message(queue, number_messages=1, visibility_timeout=0) assert len(messages) == 1 @@ -723,11 +675,11 @@ def test_receive_message_with_explicit_visibility_timeout(): @mock_sqs_deprecated def test_change_message_visibility(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=2) queue.set_message_class(RawMessage) - body_one = 'this is another test message' + body_one = "this is another test message" queue.write(queue.new_message(body_one)) queue.count().should.equal(1) @@ -757,11 +709,11 @@ def test_change_message_visibility(): @mock_sqs_deprecated def test_message_attributes(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=2) queue.set_message_class(RawMessage) - body_one = 'this is another test message' + body_one = "this is another test message" queue.write(queue.new_message(body_one)) queue.count().should.equal(1) @@ -773,19 +725,19 @@ def test_message_attributes(): message_attributes = messages[0].attributes - assert message_attributes.get('ApproximateFirstReceiveTimestamp') - assert int(message_attributes.get('ApproximateReceiveCount')) == 1 - assert message_attributes.get('SentTimestamp') - assert message_attributes.get('SenderId') + assert message_attributes.get("ApproximateFirstReceiveTimestamp") + assert int(message_attributes.get("ApproximateReceiveCount")) == 1 + assert message_attributes.get("SentTimestamp") + assert message_attributes.get("SenderId") @mock_sqs_deprecated def test_read_message_from_queue(): conn = boto.connect_sqs() - queue = conn.create_queue('testqueue') + queue = conn.create_queue("testqueue") queue.set_message_class(RawMessage) - body = 'foo bar baz' + body = "foo bar baz" queue.write(queue.new_message(body)) message = queue.read(1) message.get_body().should.equal(body) @@ -793,23 +745,23 @@ def test_read_message_from_queue(): @mock_sqs_deprecated def test_queue_length(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - queue.write(queue.new_message('this is a test message')) - queue.write(queue.new_message('this is another test message')) + queue.write(queue.new_message("this is a test message")) + queue.write(queue.new_message("this is another test message")) queue.count().should.equal(2) @mock_sqs_deprecated def test_delete_message(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - queue.write(queue.new_message('this is a test message')) - queue.write(queue.new_message('this is another test message')) + queue.write(queue.new_message("this is a test message")) + queue.write(queue.new_message("this is another test message")) queue.count().should.equal(2) messages = conn.receive_message(queue, number_messages=1) @@ -825,17 +777,19 @@ def test_delete_message(): @mock_sqs_deprecated def test_send_batch_operation(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) # See https://github.com/boto/boto/issues/831 queue.set_message_class(RawMessage) - queue.write_batch([ - ("my_first_message", 'test message 1', 0), - ("my_second_message", 'test message 2', 0), - ("my_third_message", 'test message 3', 0), - ]) + queue.write_batch( + [ + ("my_first_message", "test message 1", 0), + ("my_second_message", "test message 2", 0), + ("my_third_message", "test message 3", 0), + ] + ) messages = queue.get_messages(3) messages[0].get_body().should.equal("test message 1") @@ -847,12 +801,16 @@ def test_send_batch_operation(): @requires_boto_gte("2.28") @mock_sqs_deprecated def test_send_batch_operation_with_message_attributes(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) queue.set_message_class(RawMessage) - message_tuple = ("my_first_message", 'test message 1', 0, { - 'name1': {'data_type': 'String', 'string_value': 'foo'}}) + message_tuple = ( + "my_first_message", + "test message 1", + 0, + {"name1": {"data_type": "String", "string_value": "foo"}}, + ) queue.write_batch([message_tuple]) messages = queue.get_messages() @@ -864,14 +822,17 @@ def test_send_batch_operation_with_message_attributes(): @mock_sqs_deprecated def test_delete_batch_operation(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=3) - conn.send_message_batch(queue, [ - ("my_first_message", 'test message 1', 0), - ("my_second_message", 'test message 2', 0), - ("my_third_message", 'test message 3', 0), - ]) + conn.send_message_batch( + queue, + [ + ("my_first_message", "test message 1", 0), + ("my_second_message", "test message 2", 0), + ("my_third_message", "test message 3", 0), + ], + ) messages = queue.get_messages(2) queue.delete_message_batch(messages) @@ -881,42 +842,42 @@ def test_delete_batch_operation(): @mock_sqs_deprecated def test_queue_attributes(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") - queue_name = 'test-queue' + queue_name = "test-queue" visibility_timeout = 3 - queue = conn.create_queue( - queue_name, visibility_timeout=visibility_timeout) + queue = conn.create_queue(queue_name, visibility_timeout=visibility_timeout) attributes = queue.get_attributes() - attributes['QueueArn'].should.look_like( - 'arn:aws:sqs:us-east-1:123456789012:%s' % queue_name) + attributes["QueueArn"].should.look_like( + "arn:aws:sqs:us-east-1:123456789012:%s" % queue_name + ) - attributes['VisibilityTimeout'].should.look_like(str(visibility_timeout)) + attributes["VisibilityTimeout"].should.look_like(str(visibility_timeout)) attribute_names = queue.get_attributes().keys() - attribute_names.should.contain('ApproximateNumberOfMessagesNotVisible') - attribute_names.should.contain('MessageRetentionPeriod') - attribute_names.should.contain('ApproximateNumberOfMessagesDelayed') - attribute_names.should.contain('MaximumMessageSize') - attribute_names.should.contain('CreatedTimestamp') - attribute_names.should.contain('ApproximateNumberOfMessages') - attribute_names.should.contain('ReceiveMessageWaitTimeSeconds') - attribute_names.should.contain('DelaySeconds') - attribute_names.should.contain('VisibilityTimeout') - attribute_names.should.contain('LastModifiedTimestamp') - attribute_names.should.contain('QueueArn') + attribute_names.should.contain("ApproximateNumberOfMessagesNotVisible") + attribute_names.should.contain("MessageRetentionPeriod") + attribute_names.should.contain("ApproximateNumberOfMessagesDelayed") + attribute_names.should.contain("MaximumMessageSize") + attribute_names.should.contain("CreatedTimestamp") + attribute_names.should.contain("ApproximateNumberOfMessages") + attribute_names.should.contain("ReceiveMessageWaitTimeSeconds") + attribute_names.should.contain("DelaySeconds") + attribute_names.should.contain("VisibilityTimeout") + attribute_names.should.contain("LastModifiedTimestamp") + attribute_names.should.contain("QueueArn") @mock_sqs_deprecated def test_change_message_visibility_on_invalid_receipt(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=1) queue.set_message_class(RawMessage) - queue.write(queue.new_message('this is another test message')) + queue.write(queue.new_message("this is another test message")) queue.count().should.equal(1) messages = conn.receive_message(queue, number_messages=1) @@ -934,17 +895,16 @@ def test_change_message_visibility_on_invalid_receipt(): assert len(messages) == 1 - original_message.change_visibility.when.called_with( - 100).should.throw(SQSError) + original_message.change_visibility.when.called_with(100).should.throw(SQSError) @mock_sqs_deprecated def test_change_message_visibility_on_visible_message(): - conn = boto.connect_sqs('the_key', 'the_secret') + conn = boto.connect_sqs("the_key", "the_secret") queue = conn.create_queue("test-queue", visibility_timeout=1) queue.set_message_class(RawMessage) - queue.write(queue.new_message('this is another test message')) + queue.write(queue.new_message("this is another test message")) queue.count().should.equal(1) messages = conn.receive_message(queue, number_messages=1) @@ -958,16 +918,15 @@ def test_change_message_visibility_on_visible_message(): queue.count().should.equal(1) - original_message.change_visibility.when.called_with( - 100).should.throw(SQSError) + original_message.change_visibility.when.called_with(100).should.throw(SQSError) @mock_sqs_deprecated def test_purge_action(): conn = boto.sqs.connect_to_region("us-east-1") - queue = conn.create_queue('new-queue') - queue.write(queue.new_message('this is another test message')) + queue = conn.create_queue("new-queue") + queue.write(queue.new_message("this is another test message")) queue.count().should.equal(1) queue.purge() @@ -979,11 +938,10 @@ def test_purge_action(): def test_delete_message_after_visibility_timeout(): VISIBILITY_TIMEOUT = 1 conn = boto.sqs.connect_to_region("us-east-1") - new_queue = conn.create_queue( - 'new-queue', visibility_timeout=VISIBILITY_TIMEOUT) + new_queue = conn.create_queue("new-queue", visibility_timeout=VISIBILITY_TIMEOUT) m1 = Message() - m1.set_body('Message 1!') + m1.set_body("Message 1!") new_queue.write(m1) assert new_queue.count() == 1 @@ -999,613 +957,519 @@ def test_delete_message_after_visibility_timeout(): @mock_sqs def test_delete_message_errors(): - client = boto3.client('sqs', region_name='us-east-1') - response = client.create_queue(QueueName='test-queue') - queue_url = response['QueueUrl'] - client.send_message( - QueueUrl=queue_url, - MessageBody='body' - ) - response = client.receive_message( - QueueUrl=queue_url - ) - receipt_handle = response['Messages'][0]['ReceiptHandle'] + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] + client.send_message(QueueUrl=queue_url, MessageBody="body") + response = client.receive_message(QueueUrl=queue_url) + receipt_handle = response["Messages"][0]["ReceiptHandle"] client.delete_message.when.called_with( - QueueUrl=queue_url + '-not-existing', - ReceiptHandle=receipt_handle + QueueUrl=queue_url + "-not-existing", ReceiptHandle=receipt_handle ).should.throw( - ClientError, - 'The specified queue does not exist for this wsdl version.' + ClientError, "The specified queue does not exist for this wsdl version." ) client.delete_message.when.called_with( - QueueUrl=queue_url, - ReceiptHandle='not-existing' - ).should.throw( - ClientError, - 'The input receipt handle is invalid.' - ) + QueueUrl=queue_url, ReceiptHandle="not-existing" + ).should.throw(ClientError, "The input receipt handle is invalid.") + @mock_sqs def test_send_message_batch(): - client = boto3.client('sqs', region_name='us-east-1') - response = client.create_queue(QueueName='test-queue') - queue_url = response['QueueUrl'] + client = boto3.client("sqs", region_name="us-east-1") + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] response = client.send_message_batch( QueueUrl=queue_url, Entries=[ { - 'Id': 'id_1', - 'MessageBody': 'body_1', - 'DelaySeconds': 0, - 'MessageAttributes': { - 'attribute_name_1': { - 'StringValue': 'attribute_value_1', - 'DataType': 'String' + "Id": "id_1", + "MessageBody": "body_1", + "DelaySeconds": 0, + "MessageAttributes": { + "attribute_name_1": { + "StringValue": "attribute_value_1", + "DataType": "String", } - } + }, }, { - 'Id': 'id_2', - 'MessageBody': 'body_2', - 'DelaySeconds': 0, - 'MessageAttributes': { - 'attribute_name_2': { - 'StringValue': '123', - 'DataType': 'Number' - } - } - } - ] + "Id": "id_2", + "MessageBody": "body_2", + "DelaySeconds": 0, + "MessageAttributes": { + "attribute_name_2": {"StringValue": "123", "DataType": "Number"} + }, + }, + ], ) - sorted([entry['Id'] for entry in response['Successful']]).should.equal([ - 'id_1', - 'id_2' - ]) - - response = client.receive_message( - QueueUrl=queue_url, - MaxNumberOfMessages=10 + sorted([entry["Id"] for entry in response["Successful"]]).should.equal( + ["id_1", "id_2"] ) - response['Messages'][0]['Body'].should.equal('body_1') - response['Messages'][0]['MessageAttributes'].should.equal({ - 'attribute_name_1': { - 'StringValue': 'attribute_value_1', - 'DataType': 'String' - } - }) - response['Messages'][1]['Body'].should.equal('body_2') - response['Messages'][1]['MessageAttributes'].should.equal({ - 'attribute_name_2': { - 'StringValue': '123', - 'DataType': 'Number' - } - }) + response = client.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=10) + + response["Messages"][0]["Body"].should.equal("body_1") + response["Messages"][0]["MessageAttributes"].should.equal( + {"attribute_name_1": {"StringValue": "attribute_value_1", "DataType": "String"}} + ) + response["Messages"][1]["Body"].should.equal("body_2") + response["Messages"][1]["MessageAttributes"].should.equal( + {"attribute_name_2": {"StringValue": "123", "DataType": "Number"}} + ) @mock_sqs def test_send_message_batch_errors(): - client = boto3.client('sqs', region_name='us-east-1') + client = boto3.client("sqs", region_name="us-east-1") - response = client.create_queue(QueueName='test-queue') - queue_url = response['QueueUrl'] + response = client.create_queue(QueueName="test-queue") + queue_url = response["QueueUrl"] client.send_message_batch.when.called_with( - QueueUrl=queue_url + '-not-existing', - Entries=[ - { - 'Id': 'id_1', - 'MessageBody': 'body_1' - } - ] + QueueUrl=queue_url + "-not-existing", + Entries=[{"Id": "id_1", "MessageBody": "body_1"}], ).should.throw( - ClientError, - 'The specified queue does not exist for this wsdl version.' + ClientError, "The specified queue does not exist for this wsdl version." ) client.send_message_batch.when.called_with( - QueueUrl=queue_url, - Entries=[] + QueueUrl=queue_url, Entries=[] ).should.throw( ClientError, - 'There should be at least one SendMessageBatchRequestEntry in the request.' + "There should be at least one SendMessageBatchRequestEntry in the request.", ) client.send_message_batch.when.called_with( - QueueUrl=queue_url, - Entries=[ - { - 'Id': '', - 'MessageBody': 'body_1' - } - ] + QueueUrl=queue_url, Entries=[{"Id": "", "MessageBody": "body_1"}] ).should.throw( ClientError, - 'A batch entry id can only contain alphanumeric characters, ' - 'hyphens and underscores. It can be at most 80 letters long.' + "A batch entry id can only contain alphanumeric characters, " + "hyphens and underscores. It can be at most 80 letters long.", ) client.send_message_batch.when.called_with( - QueueUrl=queue_url, - Entries=[ - { - 'Id': '.!@#$%^&*()+=', - 'MessageBody': 'body_1' - } - ] + QueueUrl=queue_url, Entries=[{"Id": ".!@#$%^&*()+=", "MessageBody": "body_1"}] ).should.throw( ClientError, - 'A batch entry id can only contain alphanumeric characters, ' - 'hyphens and underscores. It can be at most 80 letters long.' + "A batch entry id can only contain alphanumeric characters, " + "hyphens and underscores. It can be at most 80 letters long.", ) client.send_message_batch.when.called_with( - QueueUrl=queue_url, - Entries=[ - { - 'Id': 'i' * 81, - 'MessageBody': 'body_1' - } - ] + QueueUrl=queue_url, Entries=[{"Id": "i" * 81, "MessageBody": "body_1"}] ).should.throw( ClientError, - 'A batch entry id can only contain alphanumeric characters, ' - 'hyphens and underscores. It can be at most 80 letters long.' + "A batch entry id can only contain alphanumeric characters, " + "hyphens and underscores. It can be at most 80 letters long.", ) client.send_message_batch.when.called_with( - QueueUrl=queue_url, - Entries=[ - { - 'Id': 'id_1', - 'MessageBody': 'b' * 262145 - } - ] + QueueUrl=queue_url, Entries=[{"Id": "id_1", "MessageBody": "b" * 262145}] ).should.throw( ClientError, - 'Batch requests cannot be longer than 262144 bytes. ' - 'You have sent 262145 bytes.' + "Batch requests cannot be longer than 262144 bytes. " + "You have sent 262145 bytes.", ) # only the first duplicated Id is reported client.send_message_batch.when.called_with( QueueUrl=queue_url, Entries=[ - { - 'Id': 'id_1', - 'MessageBody': 'body_1' - }, - { - 'Id': 'id_2', - 'MessageBody': 'body_2' - }, - { - 'Id': 'id_2', - 'MessageBody': 'body_2' - }, - { - 'Id': 'id_1', - 'MessageBody': 'body_1' - } - ] - ).should.throw( - ClientError, - 'Id id_2 repeated.' - ) + {"Id": "id_1", "MessageBody": "body_1"}, + {"Id": "id_2", "MessageBody": "body_2"}, + {"Id": "id_2", "MessageBody": "body_2"}, + {"Id": "id_1", "MessageBody": "body_1"}, + ], + ).should.throw(ClientError, "Id id_2 repeated.") - entries = [{'Id': 'id_{}'.format(i), 'MessageBody': 'body_{}'.format(i)} for i in range(11)] + entries = [ + {"Id": "id_{}".format(i), "MessageBody": "body_{}".format(i)} for i in range(11) + ] client.send_message_batch.when.called_with( - QueueUrl=queue_url, - Entries=entries + QueueUrl=queue_url, Entries=entries ).should.throw( ClientError, - 'Maximum number of entries per request are 10. ' - 'You have sent 11.' + "Maximum number of entries per request are 10. " "You have sent 11.", ) @mock_sqs def test_batch_change_message_visibility(): - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': - raise SkipTest('Cant manipulate time in server mode') + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Cant manipulate time in server mode") with freeze_time("2015-01-01 12:00:00"): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") resp = sqs.create_queue( - QueueName='test-dlr-queue.fifo', - Attributes={'FifoQueue': 'true'} + QueueName="test-dlr-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url = resp['QueueUrl'] + queue_url = resp["QueueUrl"] - sqs.send_message(QueueUrl=queue_url, MessageBody='msg1') - sqs.send_message(QueueUrl=queue_url, MessageBody='msg2') - sqs.send_message(QueueUrl=queue_url, MessageBody='msg3') + sqs.send_message(QueueUrl=queue_url, MessageBody="msg1") + sqs.send_message(QueueUrl=queue_url, MessageBody="msg2") + sqs.send_message(QueueUrl=queue_url, MessageBody="msg3") with freeze_time("2015-01-01 12:01:00"): receive_resp = sqs.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=2) - len(receive_resp['Messages']).should.equal(2) + len(receive_resp["Messages"]).should.equal(2) - handles = [item['ReceiptHandle'] for item in receive_resp['Messages']] - entries = [{'Id': str(uuid.uuid4()), 'ReceiptHandle': handle, 'VisibilityTimeout': 43200} for handle in handles] + handles = [item["ReceiptHandle"] for item in receive_resp["Messages"]] + entries = [ + { + "Id": str(uuid.uuid4()), + "ReceiptHandle": handle, + "VisibilityTimeout": 43200, + } + for handle in handles + ] resp = sqs.change_message_visibility_batch(QueueUrl=queue_url, Entries=entries) - len(resp['Successful']).should.equal(2) + len(resp["Successful"]).should.equal(2) with freeze_time("2015-01-01 14:00:00"): resp = sqs.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=3) - len(resp['Messages']).should.equal(1) + len(resp["Messages"]).should.equal(1) with freeze_time("2015-01-01 16:00:00"): resp = sqs.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=3) - len(resp['Messages']).should.equal(1) + len(resp["Messages"]).should.equal(1) with freeze_time("2015-01-02 12:00:00"): resp = sqs.receive_message(QueueUrl=queue_url, MaxNumberOfMessages=3) - len(resp['Messages']).should.equal(3) + len(resp["Messages"]).should.equal(3) @mock_sqs def test_permissions(): - client = boto3.client('sqs', region_name='us-east-1') + client = boto3.client("sqs", region_name="us-east-1") resp = client.create_queue( - QueueName='test-dlr-queue.fifo', - Attributes={'FifoQueue': 'true'} + QueueName="test-dlr-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url = resp['QueueUrl'] + queue_url = resp["QueueUrl"] - client.add_permission(QueueUrl=queue_url, Label='account1', AWSAccountIds=['111111111111'], Actions=['*']) - client.add_permission(QueueUrl=queue_url, Label='account2', AWSAccountIds=['222211111111'], Actions=['SendMessage']) + client.add_permission( + QueueUrl=queue_url, + Label="account1", + AWSAccountIds=["111111111111"], + Actions=["*"], + ) + client.add_permission( + QueueUrl=queue_url, + Label="account2", + AWSAccountIds=["222211111111"], + Actions=["SendMessage"], + ) with assert_raises(ClientError): - client.add_permission(QueueUrl=queue_url, Label='account2', AWSAccountIds=['222211111111'], Actions=['SomeRubbish']) + client.add_permission( + QueueUrl=queue_url, + Label="account2", + AWSAccountIds=["222211111111"], + Actions=["SomeRubbish"], + ) - client.remove_permission(QueueUrl=queue_url, Label='account2') + client.remove_permission(QueueUrl=queue_url, Label="account2") with assert_raises(ClientError): - client.remove_permission(QueueUrl=queue_url, Label='non_existant') + client.remove_permission(QueueUrl=queue_url, Label="non_existant") @mock_sqs def test_tags(): - client = boto3.client('sqs', region_name='us-east-1') + client = boto3.client("sqs", region_name="us-east-1") resp = client.create_queue( - QueueName='test-dlr-queue.fifo', - Attributes={'FifoQueue': 'true'} + QueueName="test-dlr-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url = resp['QueueUrl'] + queue_url = resp["QueueUrl"] - client.tag_queue( - QueueUrl=queue_url, - Tags={ - 'test1': 'value1', - 'test2': 'value2', - } - ) + client.tag_queue(QueueUrl=queue_url, Tags={"test1": "value1", "test2": "value2"}) resp = client.list_queue_tags(QueueUrl=queue_url) - resp['Tags'].should.contain('test1') - resp['Tags'].should.contain('test2') + resp["Tags"].should.contain("test1") + resp["Tags"].should.contain("test2") - client.untag_queue( - QueueUrl=queue_url, - TagKeys=['test2'] - ) + client.untag_queue(QueueUrl=queue_url, TagKeys=["test2"]) resp = client.list_queue_tags(QueueUrl=queue_url) - resp['Tags'].should.contain('test1') - resp['Tags'].should_not.contain('test2') + resp["Tags"].should.contain("test1") + resp["Tags"].should_not.contain("test2") # removing a non existing tag should not raise any error - client.untag_queue( - QueueUrl=queue_url, - TagKeys=[ - 'not-existing-tag' - ] - ) - client.list_queue_tags(QueueUrl=queue_url)['Tags'].should.equal({ - 'test1': 'value1' - }) + client.untag_queue(QueueUrl=queue_url, TagKeys=["not-existing-tag"]) + client.list_queue_tags(QueueUrl=queue_url)["Tags"].should.equal({"test1": "value1"}) @mock_sqs def test_list_queue_tags_errors(): - client = boto3.client('sqs', region_name='us-east-1') + client = boto3.client("sqs", region_name="us-east-1") response = client.create_queue( - QueueName='test-queue-with-tags', - tags={ - 'tag_key_1': 'tag_value_X' - } + QueueName="test-queue-with-tags", tags={"tag_key_1": "tag_value_X"} ) - queue_url = response['QueueUrl'] + queue_url = response["QueueUrl"] client.list_queue_tags.when.called_with( - QueueUrl=queue_url + '-not-existing', + QueueUrl=queue_url + "-not-existing" ).should.throw( - ClientError, - 'The specified queue does not exist for this wsdl version.' + ClientError, "The specified queue does not exist for this wsdl version." ) @mock_sqs def test_tag_queue_errors(): - client = boto3.client('sqs', region_name='us-east-1') + client = boto3.client("sqs", region_name="us-east-1") response = client.create_queue( - QueueName='test-queue-with-tags', - tags={ - 'tag_key_1': 'tag_value_X' - } + QueueName="test-queue-with-tags", tags={"tag_key_1": "tag_value_X"} ) - queue_url = response['QueueUrl'] + queue_url = response["QueueUrl"] client.tag_queue.when.called_with( - QueueUrl=queue_url + '-not-existing', - Tags={ - 'tag_key_1': 'tag_value_1' - } + QueueUrl=queue_url + "-not-existing", Tags={"tag_key_1": "tag_value_1"} ).should.throw( - ClientError, - 'The specified queue does not exist for this wsdl version.' + ClientError, "The specified queue does not exist for this wsdl version." ) - client.tag_queue.when.called_with( - QueueUrl=queue_url, - Tags={} - ).should.throw( - ClientError, - 'The request must contain the parameter Tags.' + client.tag_queue.when.called_with(QueueUrl=queue_url, Tags={}).should.throw( + ClientError, "The request must contain the parameter Tags." ) - too_many_tags = {'tag_key_{}'.format(i): 'tag_value_{}'.format(i) for i in range(51)} + too_many_tags = { + "tag_key_{}".format(i): "tag_value_{}".format(i) for i in range(51) + } client.tag_queue.when.called_with( - QueueUrl=queue_url, - Tags=too_many_tags - ).should.throw( - ClientError, - 'Too many tags added for queue test-queue-with-tags.' - ) + QueueUrl=queue_url, Tags=too_many_tags + ).should.throw(ClientError, "Too many tags added for queue test-queue-with-tags.") # when the request fails, the tags should not be updated - client.list_queue_tags(QueueUrl=queue_url)['Tags'].should.equal( - { - 'tag_key_1': 'tag_value_X' - } + client.list_queue_tags(QueueUrl=queue_url)["Tags"].should.equal( + {"tag_key_1": "tag_value_X"} ) @mock_sqs def test_untag_queue_errors(): - client = boto3.client('sqs', region_name='us-east-1') + client = boto3.client("sqs", region_name="us-east-1") response = client.create_queue( - QueueName='test-queue-with-tags', - tags={ - 'tag_key_1': 'tag_value_1' - } + QueueName="test-queue-with-tags", tags={"tag_key_1": "tag_value_1"} ) - queue_url = response['QueueUrl'] + queue_url = response["QueueUrl"] client.untag_queue.when.called_with( - QueueUrl=queue_url + '-not-existing', - TagKeys=[ - 'tag_key_1' - ] + QueueUrl=queue_url + "-not-existing", TagKeys=["tag_key_1"] ).should.throw( - ClientError, - 'The specified queue does not exist for this wsdl version.' + ClientError, "The specified queue does not exist for this wsdl version." ) - client.untag_queue.when.called_with( - QueueUrl=queue_url, - TagKeys=[] - ).should.throw( - ClientError, - 'Tag keys must be between 1 and 128 characters in length.' + client.untag_queue.when.called_with(QueueUrl=queue_url, TagKeys=[]).should.throw( + ClientError, "Tag keys must be between 1 and 128 characters in length." ) @mock_sqs def test_create_fifo_queue_with_dlq(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") resp = sqs.create_queue( - QueueName='test-dlr-queue.fifo', - Attributes={'FifoQueue': 'true'} + QueueName="test-dlr-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url1 = resp['QueueUrl'] - queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)['Attributes']['QueueArn'] + queue_url1 = resp["QueueUrl"] + queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)["Attributes"]["QueueArn"] resp = sqs.create_queue( - QueueName='test-dlr-queue', - Attributes={'FifoQueue': 'false'} + QueueName="test-dlr-queue", Attributes={"FifoQueue": "false"} ) - queue_url2 = resp['QueueUrl'] - queue_arn2 = sqs.get_queue_attributes(QueueUrl=queue_url2)['Attributes']['QueueArn'] + queue_url2 = resp["QueueUrl"] + queue_arn2 = sqs.get_queue_attributes(QueueUrl=queue_url2)["Attributes"]["QueueArn"] sqs.create_queue( - QueueName='test-queue.fifo', + QueueName="test-queue.fifo", Attributes={ - 'FifoQueue': 'true', - 'RedrivePolicy': json.dumps({'deadLetterTargetArn': queue_arn1, 'maxReceiveCount': 2}) - } + "FifoQueue": "true", + "RedrivePolicy": json.dumps( + {"deadLetterTargetArn": queue_arn1, "maxReceiveCount": 2} + ), + }, ) # Cant have fifo queue with non fifo DLQ with assert_raises(ClientError): sqs.create_queue( - QueueName='test-queue2.fifo', + QueueName="test-queue2.fifo", Attributes={ - 'FifoQueue': 'true', - 'RedrivePolicy': json.dumps({'deadLetterTargetArn': queue_arn2, 'maxReceiveCount': 2}) - } + "FifoQueue": "true", + "RedrivePolicy": json.dumps( + {"deadLetterTargetArn": queue_arn2, "maxReceiveCount": 2} + ), + }, ) @mock_sqs def test_queue_with_dlq(): - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': - raise SkipTest('Cant manipulate time in server mode') + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Cant manipulate time in server mode") - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") with freeze_time("2015-01-01 12:00:00"): resp = sqs.create_queue( - QueueName='test-dlr-queue.fifo', - Attributes={'FifoQueue': 'true'} + QueueName="test-dlr-queue.fifo", Attributes={"FifoQueue": "true"} ) - queue_url1 = resp['QueueUrl'] - queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)['Attributes']['QueueArn'] + queue_url1 = resp["QueueUrl"] + queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)["Attributes"][ + "QueueArn" + ] resp = sqs.create_queue( - QueueName='test-queue.fifo', + QueueName="test-queue.fifo", Attributes={ - 'FifoQueue': 'true', - 'RedrivePolicy': json.dumps({'deadLetterTargetArn': queue_arn1, 'maxReceiveCount': 2}) - } + "FifoQueue": "true", + "RedrivePolicy": json.dumps( + {"deadLetterTargetArn": queue_arn1, "maxReceiveCount": 2} + ), + }, ) - queue_url2 = resp['QueueUrl'] + queue_url2 = resp["QueueUrl"] - sqs.send_message(QueueUrl=queue_url2, MessageBody='msg1') - sqs.send_message(QueueUrl=queue_url2, MessageBody='msg2') + sqs.send_message(QueueUrl=queue_url2, MessageBody="msg1") + sqs.send_message(QueueUrl=queue_url2, MessageBody="msg2") with freeze_time("2015-01-01 13:00:00"): - resp = sqs.receive_message(QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0) - resp['Messages'][0]['Body'].should.equal('msg1') + resp = sqs.receive_message( + QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0 + ) + resp["Messages"][0]["Body"].should.equal("msg1") with freeze_time("2015-01-01 13:01:00"): - resp = sqs.receive_message(QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0) - resp['Messages'][0]['Body'].should.equal('msg1') + resp = sqs.receive_message( + QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0 + ) + resp["Messages"][0]["Body"].should.equal("msg1") with freeze_time("2015-01-01 13:02:00"): - resp = sqs.receive_message(QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0) - len(resp['Messages']).should.equal(1) + resp = sqs.receive_message( + QueueUrl=queue_url2, VisibilityTimeout=30, WaitTimeSeconds=0 + ) + len(resp["Messages"]).should.equal(1) - resp = sqs.receive_message(QueueUrl=queue_url1, VisibilityTimeout=30, WaitTimeSeconds=0) - resp['Messages'][0]['Body'].should.equal('msg1') + resp = sqs.receive_message( + QueueUrl=queue_url1, VisibilityTimeout=30, WaitTimeSeconds=0 + ) + resp["Messages"][0]["Body"].should.equal("msg1") # Might as well test list source queues resp = sqs.list_dead_letter_source_queues(QueueUrl=queue_url1) - resp['queueUrls'][0].should.equal(queue_url2) + resp["queueUrls"][0].should.equal(queue_url2) @mock_sqs def test_redrive_policy_available(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") - resp = sqs.create_queue(QueueName='test-deadletter') - queue_url1 = resp['QueueUrl'] - queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)['Attributes']['QueueArn'] - redrive_policy = { - 'deadLetterTargetArn': queue_arn1, - 'maxReceiveCount': 1, - } + resp = sqs.create_queue(QueueName="test-deadletter") + queue_url1 = resp["QueueUrl"] + queue_arn1 = sqs.get_queue_attributes(QueueUrl=queue_url1)["Attributes"]["QueueArn"] + redrive_policy = {"deadLetterTargetArn": queue_arn1, "maxReceiveCount": 1} resp = sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'RedrivePolicy': json.dumps(redrive_policy) - } + QueueName="test-queue", Attributes={"RedrivePolicy": json.dumps(redrive_policy)} ) - queue_url2 = resp['QueueUrl'] - attributes = sqs.get_queue_attributes(QueueUrl=queue_url2)['Attributes'] - assert 'RedrivePolicy' in attributes - assert json.loads(attributes['RedrivePolicy']) == redrive_policy + queue_url2 = resp["QueueUrl"] + attributes = sqs.get_queue_attributes(QueueUrl=queue_url2)["Attributes"] + assert "RedrivePolicy" in attributes + assert json.loads(attributes["RedrivePolicy"]) == redrive_policy # Cant have redrive policy without maxReceiveCount with assert_raises(ClientError): sqs.create_queue( - QueueName='test-queue2', + QueueName="test-queue2", Attributes={ - 'FifoQueue': 'true', - 'RedrivePolicy': json.dumps({'deadLetterTargetArn': queue_arn1}) - } + "FifoQueue": "true", + "RedrivePolicy": json.dumps({"deadLetterTargetArn": queue_arn1}), + }, ) @mock_sqs def test_redrive_policy_non_existent_queue(): - sqs = boto3.client('sqs', region_name='us-east-1') + sqs = boto3.client("sqs", region_name="us-east-1") redrive_policy = { - 'deadLetterTargetArn': 'arn:aws:sqs:us-east-1:123456789012:no-queue', - 'maxReceiveCount': 1, + "deadLetterTargetArn": "arn:aws:sqs:us-east-1:123456789012:no-queue", + "maxReceiveCount": 1, } with assert_raises(ClientError): sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'RedrivePolicy': json.dumps(redrive_policy) - } + QueueName="test-queue", + Attributes={"RedrivePolicy": json.dumps(redrive_policy)}, ) @mock_sqs def test_redrive_policy_set_attributes(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") - queue = sqs.create_queue(QueueName='test-queue') - deadletter_queue = sqs.create_queue(QueueName='test-deadletter') + queue = sqs.create_queue(QueueName="test-queue") + deadletter_queue = sqs.create_queue(QueueName="test-deadletter") redrive_policy = { - 'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'], - 'maxReceiveCount': 1, + "deadLetterTargetArn": deadletter_queue.attributes["QueueArn"], + "maxReceiveCount": 1, } - queue.set_attributes(Attributes={ - 'RedrivePolicy': json.dumps(redrive_policy)}) + queue.set_attributes(Attributes={"RedrivePolicy": json.dumps(redrive_policy)}) - copy = sqs.get_queue_by_name(QueueName='test-queue') - assert 'RedrivePolicy' in copy.attributes - copy_policy = json.loads(copy.attributes['RedrivePolicy']) + copy = sqs.get_queue_by_name(QueueName="test-queue") + assert "RedrivePolicy" in copy.attributes + copy_policy = json.loads(copy.attributes["RedrivePolicy"]) assert copy_policy == redrive_policy @mock_sqs def test_redrive_policy_set_attributes_with_string_value(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") - queue = sqs.create_queue(QueueName='test-queue') - deadletter_queue = sqs.create_queue(QueueName='test-deadletter') + queue = sqs.create_queue(QueueName="test-queue") + deadletter_queue = sqs.create_queue(QueueName="test-deadletter") - queue.set_attributes(Attributes={ - 'RedrivePolicy': json.dumps({ - 'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'], - 'maxReceiveCount': '1', - })}) + queue.set_attributes( + Attributes={ + "RedrivePolicy": json.dumps( + { + "deadLetterTargetArn": deadletter_queue.attributes["QueueArn"], + "maxReceiveCount": "1", + } + ) + } + ) - copy = sqs.get_queue_by_name(QueueName='test-queue') - assert 'RedrivePolicy' in copy.attributes - copy_policy = json.loads(copy.attributes['RedrivePolicy']) + copy = sqs.get_queue_by_name(QueueName="test-queue") + assert "RedrivePolicy" in copy.attributes + copy_policy = json.loads(copy.attributes["RedrivePolicy"]) assert copy_policy == { - 'deadLetterTargetArn': deadletter_queue.attributes['QueueArn'], - 'maxReceiveCount': 1, + "deadLetterTargetArn": deadletter_queue.attributes["QueueArn"], + "maxReceiveCount": 1, } @mock_sqs def test_receive_messages_with_message_group_id(): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName="test-queue.fifo", - Attributes={ - 'FifoQueue': 'true', - }) + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="test-queue.fifo", Attributes={"FifoQueue": "true"} + ) queue.set_attributes(Attributes={"VisibilityTimeout": "3600"}) - queue.send_message( - MessageBody="message-1", - MessageGroupId="group" - ) - queue.send_message( - MessageBody="message-2", - MessageGroupId="group" - ) + queue.send_message(MessageBody="message-1", MessageGroupId="group") + queue.send_message(MessageBody="message-2", MessageGroupId="group") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -1624,20 +1488,13 @@ def test_receive_messages_with_message_group_id(): @mock_sqs def test_receive_messages_with_message_group_id_on_requeue(): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName="test-queue.fifo", - Attributes={ - 'FifoQueue': 'true', - }) + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="test-queue.fifo", Attributes={"FifoQueue": "true"} + ) queue.set_attributes(Attributes={"VisibilityTimeout": "3600"}) - queue.send_message( - MessageBody="message-1", - MessageGroupId="group" - ) - queue.send_message( - MessageBody="message-2", - MessageGroupId="group" - ) + queue.send_message(MessageBody="message-1", MessageGroupId="group") + queue.send_message(MessageBody="message-2", MessageGroupId="group") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -1657,24 +1514,17 @@ def test_receive_messages_with_message_group_id_on_requeue(): @mock_sqs def test_receive_messages_with_message_group_id_on_visibility_timeout(): - if os.environ.get('TEST_SERVER_MODE', 'false').lower() == 'true': - raise SkipTest('Cant manipulate time in server mode') + if os.environ.get("TEST_SERVER_MODE", "false").lower() == "true": + raise SkipTest("Cant manipulate time in server mode") with freeze_time("2015-01-01 12:00:00"): - sqs = boto3.resource('sqs', region_name='us-east-1') - queue = sqs.create_queue(QueueName="test-queue.fifo", - Attributes={ - 'FifoQueue': 'true', - }) + sqs = boto3.resource("sqs", region_name="us-east-1") + queue = sqs.create_queue( + QueueName="test-queue.fifo", Attributes={"FifoQueue": "true"} + ) queue.set_attributes(Attributes={"VisibilityTimeout": "3600"}) - queue.send_message( - MessageBody="message-1", - MessageGroupId="group" - ) - queue.send_message( - MessageBody="message-2", - MessageGroupId="group" - ) + queue.send_message(MessageBody="message-1", MessageGroupId="group") + queue.send_message(MessageBody="message-2", MessageGroupId="group") messages = queue.receive_messages() messages.should.have.length_of(1) @@ -1698,15 +1548,13 @@ def test_receive_messages_with_message_group_id_on_visibility_timeout(): messages.should.have.length_of(1) messages[0].message_id.should.equal(message.message_id) + @mock_sqs def test_receive_message_for_queue_with_receive_message_wait_time_seconds_set(): - sqs = boto3.resource('sqs', region_name='us-east-1') + sqs = boto3.resource("sqs", region_name="us-east-1") queue = sqs.create_queue( - QueueName='test-queue', - Attributes={ - 'ReceiveMessageWaitTimeSeconds': '2', - } + QueueName="test-queue", Attributes={"ReceiveMessageWaitTimeSeconds": "2"} ) queue.receive_messages() diff --git a/tests/test_ssm/test_ssm_boto3.py b/tests/test_ssm/test_ssm_boto3.py index 33870e383..d50ceb528 100644 --- a/tests/test_ssm/test_ssm_boto3.py +++ b/tests/test_ssm/test_ssm_boto3.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals import boto3 import botocore.exceptions -import sure # noqa +import sure # noqa import datetime import uuid import json @@ -15,1219 +15,952 @@ from moto import mock_ssm, mock_cloudformation @mock_ssm def test_delete_parameter(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String') + Name="test", Description="A test parameter", Value="value", Type="String" + ) - response = client.get_parameters(Names=['test']) - len(response['Parameters']).should.equal(1) + response = client.get_parameters(Names=["test"]) + len(response["Parameters"]).should.equal(1) - client.delete_parameter(Name='test') + client.delete_parameter(Name="test") - response = client.get_parameters(Names=['test']) - len(response['Parameters']).should.equal(0) + response = client.get_parameters(Names=["test"]) + len(response["Parameters"]).should.equal(0) @mock_ssm def test_delete_parameters(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String') + Name="test", Description="A test parameter", Value="value", Type="String" + ) - response = client.get_parameters(Names=['test']) - len(response['Parameters']).should.equal(1) + response = client.get_parameters(Names=["test"]) + len(response["Parameters"]).should.equal(1) - result = client.delete_parameters(Names=['test', 'invalid']) - len(result['DeletedParameters']).should.equal(1) - len(result['InvalidParameters']).should.equal(1) + result = client.delete_parameters(Names=["test", "invalid"]) + len(result["DeletedParameters"]).should.equal(1) + len(result["InvalidParameters"]).should.equal(1) - response = client.get_parameters(Names=['test']) - len(response['Parameters']).should.equal(0) + response = client.get_parameters(Names=["test"]) + len(response["Parameters"]).should.equal(0) @mock_ssm def test_get_parameters_by_path(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='/foo/name1', - Description='A test parameter', - Value='value1', - Type='String') - - client.put_parameter( - Name='/foo/name2', - Description='A test parameter', - Value='value2', - Type='String') - - client.put_parameter( - Name='/bar/name3', - Description='A test parameter', - Value='value3', - Type='String') - - client.put_parameter( - Name='/bar/name3/name4', - Description='A test parameter', - Value='value4', - Type='String') - - client.put_parameter( - Name='/baz/name1', - Description='A test parameter (list)', - Value='value1,value2,value3', - Type='StringList') - - client.put_parameter( - Name='/baz/name2', - Description='A test parameter', - Value='value1', - Type='String') - - client.put_parameter( - Name='/baz/pwd', - Description='A secure test parameter', - Value='my_secret', - Type='SecureString', - KeyId='alias/aws/ssm') - - client.put_parameter( - Name='foo', - Description='A test parameter', - Value='bar', - Type='String') - - client.put_parameter( - Name='baz', - Description='A test parameter', - Value='qux', - Type='String') - - response = client.get_parameters_by_path(Path='/', Recursive=False) - len(response['Parameters']).should.equal(2) - {p['Value'] for p in response['Parameters']}.should.equal( - set(['bar', 'qux']) + Name="/foo/name1", Description="A test parameter", Value="value1", Type="String" ) - response = client.get_parameters_by_path(Path='/', Recursive=True) - len(response['Parameters']).should.equal(9) - - response = client.get_parameters_by_path(Path='/foo') - len(response['Parameters']).should.equal(2) - {p['Value'] for p in response['Parameters']}.should.equal( - set(['value1', 'value2']) + client.put_parameter( + Name="/foo/name2", Description="A test parameter", Value="value2", Type="String" ) - response = client.get_parameters_by_path(Path='/bar', Recursive=False) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Value'].should.equal('value3') - - response = client.get_parameters_by_path(Path='/bar', Recursive=True) - len(response['Parameters']).should.equal(2) - {p['Value'] for p in response['Parameters']}.should.equal( - set(['value3', 'value4']) + client.put_parameter( + Name="/bar/name3", Description="A test parameter", Value="value3", Type="String" ) - response = client.get_parameters_by_path(Path='/baz') - len(response['Parameters']).should.equal(3) - - filters = [{ - 'Key': 'Type', - 'Option': 'Equals', - 'Values': ['StringList'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(1) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/name1']) + client.put_parameter( + Name="/bar/name3/name4", + Description="A test parameter", + Value="value4", + Type="String", ) + client.put_parameter( + Name="/baz/name1", + Description="A test parameter (list)", + Value="value1,value2,value3", + Type="StringList", + ) + + client.put_parameter( + Name="/baz/name2", Description="A test parameter", Value="value1", Type="String" + ) + + client.put_parameter( + Name="/baz/pwd", + Description="A secure test parameter", + Value="my_secret", + Type="SecureString", + KeyId="alias/aws/ssm", + ) + + client.put_parameter( + Name="foo", Description="A test parameter", Value="bar", Type="String" + ) + + client.put_parameter( + Name="baz", Description="A test parameter", Value="qux", Type="String" + ) + + response = client.get_parameters_by_path(Path="/", Recursive=False) + len(response["Parameters"]).should.equal(2) + {p["Value"] for p in response["Parameters"]}.should.equal(set(["bar", "qux"])) + + response = client.get_parameters_by_path(Path="/", Recursive=True) + len(response["Parameters"]).should.equal(9) + + response = client.get_parameters_by_path(Path="/foo") + len(response["Parameters"]).should.equal(2) + {p["Value"] for p in response["Parameters"]}.should.equal(set(["value1", "value2"])) + + response = client.get_parameters_by_path(Path="/bar", Recursive=False) + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Value"].should.equal("value3") + + response = client.get_parameters_by_path(Path="/bar", Recursive=True) + len(response["Parameters"]).should.equal(2) + {p["Value"] for p in response["Parameters"]}.should.equal(set(["value3", "value4"])) + + response = client.get_parameters_by_path(Path="/baz") + len(response["Parameters"]).should.equal(3) + + filters = [{"Key": "Type", "Option": "Equals", "Values": ["StringList"]}] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(1) + {p["Name"] for p in response["Parameters"]}.should.equal(set(["/baz/name1"])) + # note: 'Option' is optional (default: 'Equals') - filters = [{ - 'Key': 'Type', - 'Values': ['StringList'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(1) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/name1']) + filters = [{"Key": "Type", "Values": ["StringList"]}] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(1) + {p["Name"] for p in response["Parameters"]}.should.equal(set(["/baz/name1"])) + + filters = [{"Key": "Type", "Option": "Equals", "Values": ["String"]}] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(1) + {p["Name"] for p in response["Parameters"]}.should.equal(set(["/baz/name2"])) + + filters = [ + {"Key": "Type", "Option": "Equals", "Values": ["String", "SecureString"]} + ] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(2) + {p["Name"] for p in response["Parameters"]}.should.equal( + set(["/baz/name2", "/baz/pwd"]) ) - filters = [{ - 'Key': 'Type', - 'Option': 'Equals', - 'Values': ['String'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(1) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/name2']) + filters = [{"Key": "Type", "Option": "BeginsWith", "Values": ["String"]}] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(2) + {p["Name"] for p in response["Parameters"]}.should.equal( + set(["/baz/name1", "/baz/name2"]) ) - filters = [{ - 'Key': 'Type', - 'Option': 'Equals', - 'Values': ['String', 'SecureString'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(2) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/name2', '/baz/pwd']) - ) - - filters = [{ - 'Key': 'Type', - 'Option': 'BeginsWith', - 'Values': ['String'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(2) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/name1', '/baz/name2']) - ) - - filters = [{ - 'Key': 'KeyId', - 'Option': 'Equals', - 'Values': ['alias/aws/ssm'], - }] - response = client.get_parameters_by_path(Path='/baz', ParameterFilters=filters) - len(response['Parameters']).should.equal(1) - {p['Name'] for p in response['Parameters']}.should.equal( - set(['/baz/pwd']) - ) + filters = [{"Key": "KeyId", "Option": "Equals", "Values": ["alias/aws/ssm"]}] + response = client.get_parameters_by_path(Path="/baz", ParameterFilters=filters) + len(response["Parameters"]).should.equal(1) + {p["Name"] for p in response["Parameters"]}.should.equal(set(["/baz/pwd"])) @mock_ssm def test_put_parameter(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") response = client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String') + Name="test", Description="A test parameter", Value="value", Type="String" + ) - response['Version'].should.equal(1) + response["Version"].should.equal(1) - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=False) + response = client.get_parameters(Names=["test"], WithDecryption=False) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('value') - response['Parameters'][0]['Type'].should.equal('String') - response['Parameters'][0]['Version'].should.equal(1) + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("value") + response["Parameters"][0]["Type"].should.equal("String") + response["Parameters"][0]["Version"].should.equal(1) try: client.put_parameter( - Name='test', - Description='desc 2', - Value='value 2', - Type='String') - raise RuntimeError('Should fail') + Name="test", Description="desc 2", Value="value 2", Type="String" + ) + raise RuntimeError("Should fail") except botocore.exceptions.ClientError as err: - err.operation_name.should.equal('PutParameter') - err.response['Error']['Message'].should.equal('Parameter test already exists.') + err.operation_name.should.equal("PutParameter") + err.response["Error"]["Message"].should.equal("Parameter test already exists.") - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=False) + response = client.get_parameters(Names=["test"], WithDecryption=False) # without overwrite nothing change - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('value') - response['Parameters'][0]['Type'].should.equal('String') - response['Parameters'][0]['Version'].should.equal(1) + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("value") + response["Parameters"][0]["Type"].should.equal("String") + response["Parameters"][0]["Version"].should.equal(1) response = client.put_parameter( - Name='test', - Description='desc 3', - Value='value 3', - Type='String', - Overwrite=True) + Name="test", + Description="desc 3", + Value="value 3", + Type="String", + Overwrite=True, + ) - response['Version'].should.equal(2) + response["Version"].should.equal(2) - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=False) + response = client.get_parameters(Names=["test"], WithDecryption=False) # without overwrite nothing change - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('value 3') - response['Parameters'][0]['Type'].should.equal('String') - response['Parameters'][0]['Version'].should.equal(2) + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("value 3") + response["Parameters"][0]["Type"].should.equal("String") + response["Parameters"][0]["Version"].should.equal(2) + @mock_ssm def test_put_parameter_china(): - client = boto3.client('ssm', region_name='cn-north-1') + client = boto3.client("ssm", region_name="cn-north-1") response = client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String') + Name="test", Description="A test parameter", Value="value", Type="String" + ) - response['Version'].should.equal(1) + response["Version"].should.equal(1) @mock_ssm def test_get_parameter(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String') + Name="test", Description="A test parameter", Value="value", Type="String" + ) - response = client.get_parameter( - Name='test', - WithDecryption=False) + response = client.get_parameter(Name="test", WithDecryption=False) - response['Parameter']['Name'].should.equal('test') - response['Parameter']['Value'].should.equal('value') - response['Parameter']['Type'].should.equal('String') + response["Parameter"]["Name"].should.equal("test") + response["Parameter"]["Value"].should.equal("value") + response["Parameter"]["Type"].should.equal("String") @mock_ssm def test_get_nonexistant_parameter(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") try: - client.get_parameter( - Name='test_noexist', - WithDecryption=False) - raise RuntimeError('Should of failed') + client.get_parameter(Name="test_noexist", WithDecryption=False) + raise RuntimeError("Should of failed") except botocore.exceptions.ClientError as err: - err.operation_name.should.equal('GetParameter') - err.response['Error']['Message'].should.equal('Parameter test_noexist not found.') + err.operation_name.should.equal("GetParameter") + err.response["Error"]["Message"].should.equal( + "Parameter test_noexist not found." + ) @mock_ssm def test_describe_parameters(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='String', - AllowedPattern=r'.*') + Name="test", + Description="A test parameter", + Value="value", + Type="String", + AllowedPattern=r".*", + ) response = client.describe_parameters() - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(1) - parameters[0]['Name'].should.equal('test') - parameters[0]['Type'].should.equal('String') - parameters[0]['AllowedPattern'].should.equal(r'.*') + parameters[0]["Name"].should.equal("test") + parameters[0]["Type"].should.equal("String") + parameters[0]["AllowedPattern"].should.equal(r".*") @mock_ssm def test_describe_parameters_paging(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") for i in range(50): - client.put_parameter( - Name="param-%d" % i, - Value="value-%d" % i, - Type="String" - ) + client.put_parameter(Name="param-%d" % i, Value="value-%d" % i, Type="String") response = client.describe_parameters() - response['Parameters'].should.have.length_of(10) - response['NextToken'].should.equal('10') + response["Parameters"].should.have.length_of(10) + response["NextToken"].should.equal("10") - response = client.describe_parameters(NextToken=response['NextToken']) - response['Parameters'].should.have.length_of(10) - response['NextToken'].should.equal('20') + response = client.describe_parameters(NextToken=response["NextToken"]) + response["Parameters"].should.have.length_of(10) + response["NextToken"].should.equal("20") - response = client.describe_parameters(NextToken=response['NextToken']) - response['Parameters'].should.have.length_of(10) - response['NextToken'].should.equal('30') + response = client.describe_parameters(NextToken=response["NextToken"]) + response["Parameters"].should.have.length_of(10) + response["NextToken"].should.equal("30") - response = client.describe_parameters(NextToken=response['NextToken']) - response['Parameters'].should.have.length_of(10) - response['NextToken'].should.equal('40') + response = client.describe_parameters(NextToken=response["NextToken"]) + response["Parameters"].should.have.length_of(10) + response["NextToken"].should.equal("40") - response = client.describe_parameters(NextToken=response['NextToken']) - response['Parameters'].should.have.length_of(10) - response['NextToken'].should.equal('50') + response = client.describe_parameters(NextToken=response["NextToken"]) + response["Parameters"].should.have.length_of(10) + response["NextToken"].should.equal("50") - response = client.describe_parameters(NextToken=response['NextToken']) - response['Parameters'].should.have.length_of(0) - response.should_not.have.key('NextToken') + response = client.describe_parameters(NextToken=response["NextToken"]) + response["Parameters"].should.have.length_of(0) + response.should_not.have.key("NextToken") @mock_ssm def test_describe_parameters_filter_names(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") for i in range(50): - p = { - 'Name': "param-%d" % i, - 'Value': "value-%d" % i, - 'Type': "String" - } + p = {"Name": "param-%d" % i, "Value": "value-%d" % i, "Type": "String"} if i % 5 == 0: - p['Type'] = 'SecureString' - p['KeyId'] = 'a key' + p["Type"] = "SecureString" + p["KeyId"] = "a key" client.put_parameter(**p) - response = client.describe_parameters(Filters=[ - { - 'Key': 'Name', - 'Values': ['param-22'] - }, - ]) + response = client.describe_parameters( + Filters=[{"Key": "Name", "Values": ["param-22"]}] + ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(1) - parameters[0]['Name'].should.equal('param-22') - parameters[0]['Type'].should.equal('String') - response.should_not.have.key('NextToken') + parameters[0]["Name"].should.equal("param-22") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") @mock_ssm def test_describe_parameters_filter_type(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") for i in range(50): - p = { - 'Name': "param-%d" % i, - 'Value': "value-%d" % i, - 'Type': "String" - } + p = {"Name": "param-%d" % i, "Value": "value-%d" % i, "Type": "String"} if i % 5 == 0: - p['Type'] = 'SecureString' - p['KeyId'] = 'a key' + p["Type"] = "SecureString" + p["KeyId"] = "a key" client.put_parameter(**p) - response = client.describe_parameters(Filters=[ - { - 'Key': 'Type', - 'Values': ['SecureString'] - }, - ]) + response = client.describe_parameters( + Filters=[{"Key": "Type", "Values": ["SecureString"]}] + ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(10) - parameters[0]['Type'].should.equal('SecureString') - response.should.have.key('NextToken').which.should.equal('10') + parameters[0]["Type"].should.equal("SecureString") + response.should.have.key("NextToken").which.should.equal("10") @mock_ssm def test_describe_parameters_filter_keyid(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") for i in range(50): - p = { - 'Name': "param-%d" % i, - 'Value': "value-%d" % i, - 'Type': "String" - } + p = {"Name": "param-%d" % i, "Value": "value-%d" % i, "Type": "String"} if i % 5 == 0: - p['Type'] = 'SecureString' - p['KeyId'] = "key:%d" % i + p["Type"] = "SecureString" + p["KeyId"] = "key:%d" % i client.put_parameter(**p) - response = client.describe_parameters(Filters=[ - { - 'Key': 'KeyId', - 'Values': ['key:10'] - }, - ]) + response = client.describe_parameters( + Filters=[{"Key": "KeyId", "Values": ["key:10"]}] + ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(1) - parameters[0]['Name'].should.equal('param-10') - parameters[0]['Type'].should.equal('SecureString') - response.should_not.have.key('NextToken') + parameters[0]["Name"].should.equal("param-10") + parameters[0]["Type"].should.equal("SecureString") + response.should_not.have.key("NextToken") @mock_ssm def test_describe_parameters_with_parameter_filters_keyid(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") + client.put_parameter(Name="secure-param", Value="secure-value", Type="SecureString") client.put_parameter( - Name='secure-param', - Value='secure-value', - Type='SecureString' - ) - client.put_parameter( - Name='custom-secure-param', - Value='custom-secure-value', - Type='SecureString', - KeyId='alias/custom' - ) - client.put_parameter( - Name = 'param', - Value = 'value', - Type = 'String' + Name="custom-secure-param", + Value="custom-secure-value", + Type="SecureString", + KeyId="alias/custom", ) + client.put_parameter(Name="param", Value="value", Type="String") response = client.describe_parameters( - ParameterFilters=[{ - 'Key': 'KeyId', - 'Values': ['alias/aws/ssm'] - }] + ParameterFilters=[{"Key": "KeyId", "Values": ["alias/aws/ssm"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(1) - parameters[0]['Name'].should.equal('secure-param') - parameters[0]['Type'].should.equal('SecureString') - response.should_not.have.key('NextToken') + parameters[0]["Name"].should.equal("secure-param") + parameters[0]["Type"].should.equal("SecureString") + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'KeyId', - 'Values': ['alias/custom'] - }] + ParameterFilters=[{"Key": "KeyId", "Values": ["alias/custom"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(1) - parameters[0]['Name'].should.equal('custom-secure-param') - parameters[0]['Type'].should.equal('SecureString') - response.should_not.have.key('NextToken') + parameters[0]["Name"].should.equal("custom-secure-param") + parameters[0]["Type"].should.equal("SecureString") + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'KeyId', - 'Option': 'BeginsWith', - 'Values': ['alias'] - }] + ParameterFilters=[{"Key": "KeyId", "Option": "BeginsWith", "Values": ["alias"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(2) - response.should_not.have.key('NextToken') + response.should_not.have.key("NextToken") @mock_ssm def test_describe_parameters_with_parameter_filters_name(): - client = boto3.client('ssm', region_name='us-east-1') - client.put_parameter( - Name='param', - Value='value', - Type='String' - ) - client.put_parameter( - Name = '/param-2', - Value = 'value-2', - Type = 'String' - ) + client = boto3.client("ssm", region_name="us-east-1") + client.put_parameter(Name="param", Value="value", Type="String") + client.put_parameter(Name="/param-2", Value="value-2", Type="String") response = client.describe_parameters( - ParameterFilters=[{ - 'Key': 'Name', - 'Values': ['param'] - }] + ParameterFilters=[{"Key": "Name", "Values": ["param"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(1) - parameters[0]['Name'].should.equal('param') - parameters[0]['Type'].should.equal('String') - response.should_not.have.key('NextToken') + parameters[0]["Name"].should.equal("param") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters=[{ - 'Key': 'Name', - 'Values': ['/param'] - }] + ParameterFilters=[{"Key": "Name", "Values": ["/param"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(1) - parameters[0]['Name'].should.equal('param') - parameters[0]['Type'].should.equal('String') - response.should_not.have.key('NextToken') + parameters[0]["Name"].should.equal("param") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters=[{ - 'Key': 'Name', - 'Values': ['param-2'] - }] + ParameterFilters=[{"Key": "Name", "Values": ["param-2"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(1) - parameters[0]['Name'].should.equal('/param-2') - parameters[0]['Type'].should.equal('String') - response.should_not.have.key('NextToken') + parameters[0]["Name"].should.equal("/param-2") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'Name', - 'Option': 'BeginsWith', - 'Values': ['param'] - }] + ParameterFilters=[{"Key": "Name", "Option": "BeginsWith", "Values": ["param"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(2) - response.should_not.have.key('NextToken') + response.should_not.have.key("NextToken") @mock_ssm def test_describe_parameters_with_parameter_filters_path(): - client = boto3.client('ssm', region_name='us-east-1') - client.put_parameter( - Name='/foo/name1', - Value='value1', - Type='String') + client = boto3.client("ssm", region_name="us-east-1") + client.put_parameter(Name="/foo/name1", Value="value1", Type="String") - client.put_parameter( - Name='/foo/name2', - Value='value2', - Type='String') + client.put_parameter(Name="/foo/name2", Value="value2", Type="String") - client.put_parameter( - Name='/bar/name3', - Value='value3', - Type='String') + client.put_parameter(Name="/bar/name3", Value="value3", Type="String") - client.put_parameter( - Name='/bar/name3/name4', - Value='value4', - Type='String') + client.put_parameter(Name="/bar/name3/name4", Value="value4", Type="String") - client.put_parameter( - Name='foo', - Value='bar', - Type='String') + client.put_parameter(Name="foo", Value="bar", Type="String") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'Path', - 'Values': ['/fo'] - }] + ParameterFilters=[{"Key": "Path", "Values": ["/fo"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(0) - response.should_not.have.key('NextToken') + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'Path', - 'Values': ['/'] - }] + ParameterFilters=[{"Key": "Path", "Values": ["/"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(1) - parameters[0]['Name'].should.equal('foo') - parameters[0]['Type'].should.equal('String') - response.should_not.have.key('NextToken') + parameters[0]["Name"].should.equal("foo") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'Path', - 'Values': ['/', '/foo'] - }] + ParameterFilters=[{"Key": "Path", "Values": ["/", "/foo"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(3) - {parameter['Name'] for parameter in response['Parameters']}.should.equal( - {'/foo/name1', '/foo/name2', 'foo'} + {parameter["Name"] for parameter in response["Parameters"]}.should.equal( + {"/foo/name1", "/foo/name2", "foo"} ) - response.should_not.have.key('NextToken') + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'Path', - 'Values': ['/foo/'] - }] + ParameterFilters=[{"Key": "Path", "Values": ["/foo/"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(2) - {parameter['Name'] for parameter in response['Parameters']}.should.equal( - {'/foo/name1', '/foo/name2'} + {parameter["Name"] for parameter in response["Parameters"]}.should.equal( + {"/foo/name1", "/foo/name2"} ) - response.should_not.have.key('NextToken') + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'Path', - 'Option': 'OneLevel', - 'Values': ['/bar/name3'] - }] + ParameterFilters=[ + {"Key": "Path", "Option": "OneLevel", "Values": ["/bar/name3"]} + ] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(1) - parameters[0]['Name'].should.equal('/bar/name3/name4') - parameters[0]['Type'].should.equal('String') - response.should_not.have.key('NextToken') - + parameters[0]["Name"].should.equal("/bar/name3/name4") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'Path', - 'Option': 'Recursive', - 'Values': ['/fo'] - }] + ParameterFilters=[{"Key": "Path", "Option": "Recursive", "Values": ["/fo"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(0) - response.should_not.have.key('NextToken') + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'Path', - 'Option': 'Recursive', - 'Values': ['/'] - }] + ParameterFilters=[{"Key": "Path", "Option": "Recursive", "Values": ["/"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(5) - response.should_not.have.key('NextToken') + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'Path', - 'Option': 'Recursive', - 'Values': ['/foo', '/bar'] - }] + ParameterFilters=[ + {"Key": "Path", "Option": "Recursive", "Values": ["/foo", "/bar"]} + ] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(4) - {parameter['Name'] for parameter in response['Parameters']}.should.equal( - {'/foo/name1', '/foo/name2', '/bar/name3', '/bar/name3/name4'} + {parameter["Name"] for parameter in response["Parameters"]}.should.equal( + {"/foo/name1", "/foo/name2", "/bar/name3", "/bar/name3/name4"} ) - response.should_not.have.key('NextToken') + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'Path', - 'Option': 'Recursive', - 'Values': ['/foo/'] - }] + ParameterFilters=[{"Key": "Path", "Option": "Recursive", "Values": ["/foo/"]}] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(2) - {parameter['Name'] for parameter in response['Parameters']}.should.equal( - {'/foo/name1', '/foo/name2'} + {parameter["Name"] for parameter in response["Parameters"]}.should.equal( + {"/foo/name1", "/foo/name2"} ) - response.should_not.have.key('NextToken') + response.should_not.have.key("NextToken") response = client.describe_parameters( - ParameterFilters = [{ - 'Key': 'Path', - 'Option': 'Recursive', - 'Values': ['/bar/name3'] - }] + ParameterFilters=[ + {"Key": "Path", "Option": "Recursive", "Values": ["/bar/name3"]} + ] ) - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(1) - parameters[0]['Name'].should.equal('/bar/name3/name4') - parameters[0]['Type'].should.equal('String') - response.should_not.have.key('NextToken') + parameters[0]["Name"].should.equal("/bar/name3/name4") + parameters[0]["Type"].should.equal("String") + response.should_not.have.key("NextToken") @mock_ssm def test_describe_parameters_invalid_parameter_filters(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.describe_parameters.when.called_with( - Filters=[{ - 'Key': 'Name', - 'Values': ['test'] - }], - ParameterFilters=[{ - 'Key': 'Name', - 'Values': ['test'] - }] + Filters=[{"Key": "Name", "Values": ["test"]}], + ParameterFilters=[{"Key": "Name", "Values": ["test"]}], ).should.throw( ClientError, - 'You can use either Filters or ParameterFilters in a single request.' + "You can use either Filters or ParameterFilters in a single request.", ) - client.describe_parameters.when.called_with( - ParameterFilters=[{}] - ).should.throw( + client.describe_parameters.when.called_with(ParameterFilters=[{}]).should.throw( ParamValidationError, - 'Parameter validation failed:\nMissing required parameter in ParameterFilters[0]: "Key"' + 'Parameter validation failed:\nMissing required parameter in ParameterFilters[0]: "Key"', ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'key', - }] + ParameterFilters=[{"Key": "key"}] ).should.throw( ClientError, '1 validation error detected: Value "key" at "parameterFilters.1.member.key" failed to satisfy constraint: ' - 'Member must satisfy regular expression pattern: tag:.+|Name|Type|KeyId|Path|Label|Tier' + "Member must satisfy regular expression pattern: tag:.+|Name|Type|KeyId|Path|Label|Tier", ) - long_key = 'tag:' + 't' * 129 + long_key = "tag:" + "t" * 129 client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': long_key, - }] + ParameterFilters=[{"Key": long_key}] ).should.throw( ClientError, '1 validation error detected: Value "{value}" at "parameterFilters.1.member.key" failed to satisfy constraint: ' - 'Member must have length less than or equal to 132'.format(value=long_key) + "Member must have length less than or equal to 132".format(value=long_key), ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Name', - 'Option': 'over 10 chars' - }] + ParameterFilters=[{"Key": "Name", "Option": "over 10 chars"}] ).should.throw( ClientError, '1 validation error detected: Value "over 10 chars" at "parameterFilters.1.member.option" failed to satisfy constraint: ' - 'Member must have length less than or equal to 10' + "Member must have length less than or equal to 10", ) - many_values = ['test'] * 51 + many_values = ["test"] * 51 client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Name', - 'Values': many_values - }] + ParameterFilters=[{"Key": "Name", "Values": many_values}] ).should.throw( ClientError, '1 validation error detected: Value "{value}" at "parameterFilters.1.member.values" failed to satisfy constraint: ' - 'Member must have length less than or equal to 50'.format(value=many_values) + "Member must have length less than or equal to 50".format(value=many_values), ) - long_value = ['t' * 1025] + long_value = ["t" * 1025] client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Name', - 'Values': long_value - }] + ParameterFilters=[{"Key": "Name", "Values": long_value}] ).should.throw( ClientError, '1 validation error detected: Value "{value}" at "parameterFilters.1.member.values" failed to satisfy constraint: ' - '[Member must have length less than or equal to 1024, Member must have length greater than or equal to 1]'.format(value=long_value) + "[Member must have length less than or equal to 1024, Member must have length greater than or equal to 1]".format( + value=long_value + ), ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Name', - 'Option': 'over 10 chars' - },{ - 'Key': 'key', - }] + ParameterFilters=[{"Key": "Name", "Option": "over 10 chars"}, {"Key": "key"}] ).should.throw( ClientError, - '2 validation errors detected: ' + "2 validation errors detected: " 'Value "over 10 chars" at "parameterFilters.1.member.option" failed to satisfy constraint: ' - 'Member must have length less than or equal to 10; ' + "Member must have length less than or equal to 10; " 'Value "key" at "parameterFilters.2.member.key" failed to satisfy constraint: ' - 'Member must satisfy regular expression pattern: tag:.+|Name|Type|KeyId|Path|Label|Tier' + "Member must satisfy regular expression pattern: tag:.+|Name|Type|KeyId|Path|Label|Tier", ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Label', - }] + ParameterFilters=[{"Key": "Label"}] ).should.throw( ClientError, - 'The following filter key is not valid: Label. Valid filter keys include: [Path, Name, Type, KeyId, Tier].' + "The following filter key is not valid: Label. Valid filter keys include: [Path, Name, Type, KeyId, Tier].", ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Name', - }] + ParameterFilters=[{"Key": "Name"}] ).should.throw( ClientError, - 'The following filter values are missing : null for filter key Name.' + "The following filter values are missing : null for filter key Name.", ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Name', - 'Values': [] - }] + ParameterFilters=[{"Key": "Name", "Values": []}] ).should.throw( ParamValidationError, - 'Invalid length for parameter ParameterFilters[0].Values, value: 0, valid range: 1-inf' + "Invalid length for parameter ParameterFilters[0].Values, value: 0, valid range: 1-inf", ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Name', - 'Values': ['test'] - },{ - 'Key': 'Name', - 'Values': ['test test'] - }] + ParameterFilters=[ + {"Key": "Name", "Values": ["test"]}, + {"Key": "Name", "Values": ["test test"]}, + ] ).should.throw( ClientError, - 'The following filter is duplicated in the request: Name. A request can contain only one occurrence of a specific filter.' + "The following filter is duplicated in the request: Name. A request can contain only one occurrence of a specific filter.", ) - for value in ['/###', '//', 'test']: + for value in ["/###", "//", "test"]: client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Path', - 'Values': [value] - }] + ParameterFilters=[{"Key": "Path", "Values": [value]}] ).should.throw( ClientError, 'The parameter doesn\'t meet the parameter name requirements. The parameter name must begin with a forward slash "/". ' - 'It can\'t be prefixed with \"aws\" or \"ssm\" (case-insensitive). ' - 'It must use only letters, numbers, or the following symbols: . (period), - (hyphen), _ (underscore). ' + 'It can\'t be prefixed with "aws" or "ssm" (case-insensitive). ' + "It must use only letters, numbers, or the following symbols: . (period), - (hyphen), _ (underscore). " 'Special characters are not allowed. All sub-paths, if specified, must use the forward slash symbol "/". ' - 'Valid example: /get/parameters2-/by1./path0_.' + "Valid example: /get/parameters2-/by1./path0_.", ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Path', - 'Values': ['/aws', '/ssm'] - }] + ParameterFilters=[{"Key": "Path", "Values": ["/aws", "/ssm"]}] ).should.throw( ClientError, 'Filters for common parameters can\'t be prefixed with "aws" or "ssm" (case-insensitive). ' - 'When using global parameters, please specify within a global namespace.' + "When using global parameters, please specify within a global namespace.", ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Path', - 'Option': 'Equals', - 'Values': ['test'] - }] + ParameterFilters=[{"Key": "Path", "Option": "Equals", "Values": ["test"]}] ).should.throw( ClientError, - 'The following filter option is not valid: Equals. Valid options include: [Recursive, OneLevel].' + "The following filter option is not valid: Equals. Valid options include: [Recursive, OneLevel].", ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Tier', - 'Values': ['test'] - }] + ParameterFilters=[{"Key": "Tier", "Values": ["test"]}] ).should.throw( ClientError, - 'The following filter value is not valid: test. Valid values include: [Standard, Advanced, Intelligent-Tiering]' + "The following filter value is not valid: test. Valid values include: [Standard, Advanced, Intelligent-Tiering]", ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Type', - 'Values': ['test'] - }] + ParameterFilters=[{"Key": "Type", "Values": ["test"]}] ).should.throw( ClientError, - 'The following filter value is not valid: test. Valid values include: [String, StringList, SecureString]' + "The following filter value is not valid: test. Valid values include: [String, StringList, SecureString]", ) client.describe_parameters.when.called_with( - ParameterFilters=[{ - 'Key': 'Name', - 'Option': 'option', - 'Values': ['test'] - }] + ParameterFilters=[{"Key": "Name", "Option": "option", "Values": ["test"]}] ).should.throw( ClientError, - 'The following filter option is not valid: option. Valid options include: [BeginsWith, Equals].' + "The following filter option is not valid: option. Valid options include: [BeginsWith, Equals].", ) @mock_ssm def test_describe_parameters_attributes(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='aa', - Value='11', - Type='String', - Description='my description' + Name="aa", Value="11", Type="String", Description="my description" ) - client.put_parameter( - Name='bb', - Value='22', - Type='String' - ) + client.put_parameter(Name="bb", Value="22", Type="String") response = client.describe_parameters() - parameters = response['Parameters'] + parameters = response["Parameters"] parameters.should.have.length_of(2) - parameters[0]['Description'].should.equal('my description') - parameters[0]['Version'].should.equal(1) - parameters[0]['LastModifiedDate'].should.be.a(datetime.date) - parameters[0]['LastModifiedUser'].should.equal('N/A') + parameters[0]["Description"].should.equal("my description") + parameters[0]["Version"].should.equal(1) + parameters[0]["LastModifiedDate"].should.be.a(datetime.date) + parameters[0]["LastModifiedUser"].should.equal("N/A") - parameters[1].should_not.have.key('Description') - parameters[1]['Version'].should.equal(1) + parameters[1].should_not.have.key("Description") + parameters[1]["Version"].should.equal(1) @mock_ssm def test_get_parameter_invalid(): - client = client = boto3.client('ssm', region_name='us-east-1') - response = client.get_parameters( - Names=[ - 'invalid' - ], - WithDecryption=False) + client = client = boto3.client("ssm", region_name="us-east-1") + response = client.get_parameters(Names=["invalid"], WithDecryption=False) - len(response['Parameters']).should.equal(0) - len(response['InvalidParameters']).should.equal(1) - response['InvalidParameters'][0].should.equal('invalid') + len(response["Parameters"]).should.equal(0) + len(response["InvalidParameters"]).should.equal(1) + response["InvalidParameters"][0].should.equal("invalid") @mock_ssm def test_put_parameter_secure_default_kms(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='SecureString') + Name="test", Description="A test parameter", Value="value", Type="SecureString" + ) - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=False) + response = client.get_parameters(Names=["test"], WithDecryption=False) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('kms:alias/aws/ssm:value') - response['Parameters'][0]['Type'].should.equal('SecureString') + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("kms:alias/aws/ssm:value") + response["Parameters"][0]["Type"].should.equal("SecureString") - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=True) + response = client.get_parameters(Names=["test"], WithDecryption=True) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('value') - response['Parameters'][0]['Type'].should.equal('SecureString') + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("value") + response["Parameters"][0]["Type"].should.equal("SecureString") @mock_ssm def test_put_parameter_secure_custom_kms(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.put_parameter( - Name='test', - Description='A test parameter', - Value='value', - Type='SecureString', - KeyId='foo') + Name="test", + Description="A test parameter", + Value="value", + Type="SecureString", + KeyId="foo", + ) - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=False) + response = client.get_parameters(Names=["test"], WithDecryption=False) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('kms:foo:value') - response['Parameters'][0]['Type'].should.equal('SecureString') + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("kms:foo:value") + response["Parameters"][0]["Type"].should.equal("SecureString") - response = client.get_parameters( - Names=[ - 'test' - ], - WithDecryption=True) + response = client.get_parameters(Names=["test"], WithDecryption=True) - len(response['Parameters']).should.equal(1) - response['Parameters'][0]['Name'].should.equal('test') - response['Parameters'][0]['Value'].should.equal('value') - response['Parameters'][0]['Type'].should.equal('SecureString') + len(response["Parameters"]).should.equal(1) + response["Parameters"][0]["Name"].should.equal("test") + response["Parameters"][0]["Value"].should.equal("value") + response["Parameters"][0]["Type"].should.equal("SecureString") @mock_ssm def test_add_remove_list_tags_for_resource(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") client.add_tags_to_resource( - ResourceId='test', - ResourceType='Parameter', - Tags=[{'Key': 'test-key', 'Value': 'test-value'}] + ResourceId="test", + ResourceType="Parameter", + Tags=[{"Key": "test-key", "Value": "test-value"}], ) response = client.list_tags_for_resource( - ResourceId='test', - ResourceType='Parameter' + ResourceId="test", ResourceType="Parameter" ) - len(response['TagList']).should.equal(1) - response['TagList'][0]['Key'].should.equal('test-key') - response['TagList'][0]['Value'].should.equal('test-value') + len(response["TagList"]).should.equal(1) + response["TagList"][0]["Key"].should.equal("test-key") + response["TagList"][0]["Value"].should.equal("test-value") client.remove_tags_from_resource( - ResourceId='test', - ResourceType='Parameter', - TagKeys=['test-key'] + ResourceId="test", ResourceType="Parameter", TagKeys=["test-key"] ) response = client.list_tags_for_resource( - ResourceId='test', - ResourceType='Parameter' + ResourceId="test", ResourceType="Parameter" ) - len(response['TagList']).should.equal(0) + len(response["TagList"]).should.equal(0) @mock_ssm def test_send_command(): - ssm_document = 'AWS-RunShellScript' - params = {'commands': ['#!/bin/bash\necho \'hello world\'']} + ssm_document = "AWS-RunShellScript" + params = {"commands": ["#!/bin/bash\necho 'hello world'"]} - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") # note the timeout is determined server side, so this is a simpler check. before = datetime.datetime.now() response = client.send_command( - InstanceIds=['i-123456'], + InstanceIds=["i-123456"], DocumentName=ssm_document, Parameters=params, - OutputS3Region='us-east-2', - OutputS3BucketName='the-bucket', - OutputS3KeyPrefix='pref' + OutputS3Region="us-east-2", + OutputS3BucketName="the-bucket", + OutputS3KeyPrefix="pref", ) - cmd = response['Command'] + cmd = response["Command"] - cmd['CommandId'].should_not.be(None) - cmd['DocumentName'].should.equal(ssm_document) - cmd['Parameters'].should.equal(params) + cmd["CommandId"].should_not.be(None) + cmd["DocumentName"].should.equal(ssm_document) + cmd["Parameters"].should.equal(params) - cmd['OutputS3Region'].should.equal('us-east-2') - cmd['OutputS3BucketName'].should.equal('the-bucket') - cmd['OutputS3KeyPrefix'].should.equal('pref') + cmd["OutputS3Region"].should.equal("us-east-2") + cmd["OutputS3BucketName"].should.equal("the-bucket") + cmd["OutputS3KeyPrefix"].should.equal("pref") - cmd['ExpiresAfter'].should.be.greater_than(before) + cmd["ExpiresAfter"].should.be.greater_than(before) # test sending a command without any optional parameters - response = client.send_command( - DocumentName=ssm_document) + response = client.send_command(DocumentName=ssm_document) - cmd = response['Command'] + cmd = response["Command"] - cmd['CommandId'].should_not.be(None) - cmd['DocumentName'].should.equal(ssm_document) + cmd["CommandId"].should_not.be(None) + cmd["DocumentName"].should.equal(ssm_document) @mock_ssm def test_list_commands(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") - ssm_document = 'AWS-RunShellScript' - params = {'commands': ['#!/bin/bash\necho \'hello world\'']} + ssm_document = "AWS-RunShellScript" + params = {"commands": ["#!/bin/bash\necho 'hello world'"]} response = client.send_command( - InstanceIds=['i-123456'], + InstanceIds=["i-123456"], DocumentName=ssm_document, Parameters=params, - OutputS3Region='us-east-2', - OutputS3BucketName='the-bucket', - OutputS3KeyPrefix='pref') + OutputS3Region="us-east-2", + OutputS3BucketName="the-bucket", + OutputS3KeyPrefix="pref", + ) - cmd = response['Command'] - cmd_id = cmd['CommandId'] + cmd = response["Command"] + cmd_id = cmd["CommandId"] # get the command by id - response = client.list_commands( - CommandId=cmd_id) + response = client.list_commands(CommandId=cmd_id) - cmds = response['Commands'] + cmds = response["Commands"] len(cmds).should.equal(1) - cmds[0]['CommandId'].should.equal(cmd_id) + cmds[0]["CommandId"].should.equal(cmd_id) # add another command with the same instance id to test listing by # instance id - client.send_command( - InstanceIds=['i-123456'], - DocumentName=ssm_document) + client.send_command(InstanceIds=["i-123456"], DocumentName=ssm_document) - response = client.list_commands( - InstanceId='i-123456') + response = client.list_commands(InstanceId="i-123456") - cmds = response['Commands'] + cmds = response["Commands"] len(cmds).should.equal(2) for cmd in cmds: - cmd['InstanceIds'].should.contain('i-123456') + cmd["InstanceIds"].should.contain("i-123456") # test the error case for an invalid command id with assert_raises(ClientError): - response = client.list_commands( - CommandId=str(uuid.uuid4())) + response = client.list_commands(CommandId=str(uuid.uuid4())) + @mock_ssm def test_get_command_invocation(): - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") - ssm_document = 'AWS-RunShellScript' - params = {'commands': ['#!/bin/bash\necho \'hello world\'']} + ssm_document = "AWS-RunShellScript" + params = {"commands": ["#!/bin/bash\necho 'hello world'"]} response = client.send_command( - InstanceIds=['i-123456', 'i-234567', 'i-345678'], + InstanceIds=["i-123456", "i-234567", "i-345678"], DocumentName=ssm_document, Parameters=params, - OutputS3Region='us-east-2', - OutputS3BucketName='the-bucket', - OutputS3KeyPrefix='pref') + OutputS3Region="us-east-2", + OutputS3BucketName="the-bucket", + OutputS3KeyPrefix="pref", + ) - cmd = response['Command'] - cmd_id = cmd['CommandId'] + cmd = response["Command"] + cmd_id = cmd["CommandId"] - instance_id = 'i-345678' + instance_id = "i-345678" invocation_response = client.get_command_invocation( - CommandId=cmd_id, - InstanceId=instance_id, - PluginName='aws:runShellScript') + CommandId=cmd_id, InstanceId=instance_id, PluginName="aws:runShellScript" + ) - invocation_response['CommandId'].should.equal(cmd_id) - invocation_response['InstanceId'].should.equal(instance_id) + invocation_response["CommandId"].should.equal(cmd_id) + invocation_response["InstanceId"].should.equal(instance_id) # test the error case for an invalid instance id with assert_raises(ClientError): invocation_response = client.get_command_invocation( - CommandId=cmd_id, - InstanceId='i-FAKE') + CommandId=cmd_id, InstanceId="i-FAKE" + ) # test the error case for an invalid plugin name with assert_raises(ClientError): invocation_response = client.get_command_invocation( - CommandId=cmd_id, - InstanceId=instance_id, - PluginName='FAKE') + CommandId=cmd_id, InstanceId=instance_id, PluginName="FAKE" + ) + @mock_ssm @mock_cloudformation @@ -1243,63 +976,52 @@ def test_get_command_invocations_from_stack(): "KeyName": "test", "InstanceType": "t2.micro", "Tags": [ - { - "Key": "Test Description", - "Value": "Test tag" - }, - { - "Key": "Test Name", - "Value": "Name tag for tests" - } - ] - } + {"Key": "Test Description", "Value": "Test tag"}, + {"Key": "Test Name", "Value": "Name tag for tests"}, + ], + }, } }, "Outputs": { "test": { "Description": "Test Output", "Value": "Test output value", - "Export": { - "Name": "Test value to export" - } + "Export": {"Name": "Test value to export"}, }, - "PublicIP": { - "Value": "Test public ip" - } - } + "PublicIP": {"Value": "Test public ip"}, + }, } - cloudformation_client = boto3.client( - 'cloudformation', - region_name='us-east-1') + cloudformation_client = boto3.client("cloudformation", region_name="us-east-1") stack_template_str = json.dumps(stack_template) response = cloudformation_client.create_stack( - StackName='test_stack', + StackName="test_stack", TemplateBody=stack_template_str, - Capabilities=('CAPABILITY_IAM', )) + Capabilities=("CAPABILITY_IAM",), + ) - client = boto3.client('ssm', region_name='us-east-1') + client = boto3.client("ssm", region_name="us-east-1") - ssm_document = 'AWS-RunShellScript' - params = {'commands': ['#!/bin/bash\necho \'hello world\'']} + ssm_document = "AWS-RunShellScript" + params = {"commands": ["#!/bin/bash\necho 'hello world'"]} response = client.send_command( - Targets=[{ - 'Key': 'tag:aws:cloudformation:stack-name', - 'Values': ('test_stack', )}], + Targets=[ + {"Key": "tag:aws:cloudformation:stack-name", "Values": ("test_stack",)} + ], DocumentName=ssm_document, Parameters=params, - OutputS3Region='us-east-2', - OutputS3BucketName='the-bucket', - OutputS3KeyPrefix='pref') + OutputS3Region="us-east-2", + OutputS3BucketName="the-bucket", + OutputS3KeyPrefix="pref", + ) - cmd = response['Command'] - cmd_id = cmd['CommandId'] - instance_ids = cmd['InstanceIds'] + cmd = response["Command"] + cmd_id = cmd["CommandId"] + instance_ids = cmd["InstanceIds"] invocation_response = client.get_command_invocation( - CommandId=cmd_id, - InstanceId=instance_ids[0], - PluginName='aws:runShellScript') + CommandId=cmd_id, InstanceId=instance_ids[0], PluginName="aws:runShellScript" + ) diff --git a/tests/test_stepfunctions/test_stepfunctions.py b/tests/test_stepfunctions/test_stepfunctions.py index 6c1e7e4c8..77b9fbfb3 100644 --- a/tests/test_stepfunctions/test_stepfunctions.py +++ b/tests/test_stepfunctions/test_stepfunctions.py @@ -1,7 +1,7 @@ from __future__ import unicode_literals import boto3 -import sure # noqa +import sure # noqa import datetime from datetime import datetime @@ -11,372 +11,513 @@ from nose.tools import assert_raises from moto import mock_sts, mock_stepfunctions -region = 'us-east-1' -simple_definition = '{"Comment": "An example of the Amazon States Language using a choice state.",' \ - '"StartAt": "DefaultState",' \ - '"States": ' \ - '{"DefaultState": {"Type": "Fail","Error": "DefaultStateError","Cause": "No Matches!"}}}' +region = "us-east-1" +simple_definition = ( + '{"Comment": "An example of the Amazon States Language using a choice state.",' + '"StartAt": "DefaultState",' + '"States": ' + '{"DefaultState": {"Type": "Fail","Error": "DefaultStateError","Cause": "No Matches!"}}}' +) account_id = None @mock_stepfunctions @mock_sts def test_state_machine_creation_succeeds(): - client = boto3.client('stepfunctions', region_name=region) - name = 'example_step_function' + client = boto3.client("stepfunctions", region_name=region) + name = "example_step_function" # - response = client.create_state_machine(name=name, - definition=str(simple_definition), - roleArn=_get_default_role()) + response = client.create_state_machine( + name=name, definition=str(simple_definition), roleArn=_get_default_role() + ) # - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - response['creationDate'].should.be.a(datetime) - response['stateMachineArn'].should.equal('arn:aws:states:' + region + ':123456789012:stateMachine:' + name) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + response["creationDate"].should.be.a(datetime) + response["stateMachineArn"].should.equal( + "arn:aws:states:" + region + ":123456789012:stateMachine:" + name + ) @mock_stepfunctions def test_state_machine_creation_fails_with_invalid_names(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) invalid_names = [ - 'with space', - 'withbracket', 'with{bracket', 'with}bracket', 'with[bracket', 'with]bracket', - 'with?wildcard', 'with*wildcard', - 'special"char', 'special#char', 'special%char', 'special\\char', 'special^char', 'special|char', - 'special~char', 'special`char', 'special$char', 'special&char', 'special,char', 'special;char', - 'special:char', 'special/char', - u'uni\u0000code', u'uni\u0001code', u'uni\u0002code', u'uni\u0003code', u'uni\u0004code', - u'uni\u0005code', u'uni\u0006code', u'uni\u0007code', u'uni\u0008code', u'uni\u0009code', - u'uni\u000Acode', u'uni\u000Bcode', u'uni\u000Ccode', - u'uni\u000Dcode', u'uni\u000Ecode', u'uni\u000Fcode', - u'uni\u0010code', u'uni\u0011code', u'uni\u0012code', u'uni\u0013code', u'uni\u0014code', - u'uni\u0015code', u'uni\u0016code', u'uni\u0017code', u'uni\u0018code', u'uni\u0019code', - u'uni\u001Acode', u'uni\u001Bcode', u'uni\u001Ccode', - u'uni\u001Dcode', u'uni\u001Ecode', u'uni\u001Fcode', - u'uni\u007Fcode', - u'uni\u0080code', u'uni\u0081code', u'uni\u0082code', u'uni\u0083code', u'uni\u0084code', - u'uni\u0085code', u'uni\u0086code', u'uni\u0087code', u'uni\u0088code', u'uni\u0089code', - u'uni\u008Acode', u'uni\u008Bcode', u'uni\u008Ccode', - u'uni\u008Dcode', u'uni\u008Ecode', u'uni\u008Fcode', - u'uni\u0090code', u'uni\u0091code', u'uni\u0092code', u'uni\u0093code', u'uni\u0094code', - u'uni\u0095code', u'uni\u0096code', u'uni\u0097code', u'uni\u0098code', u'uni\u0099code', - u'uni\u009Acode', u'uni\u009Bcode', u'uni\u009Ccode', - u'uni\u009Dcode', u'uni\u009Ecode', u'uni\u009Fcode'] + "with space", + "withbracket", + "with{bracket", + "with}bracket", + "with[bracket", + "with]bracket", + "with?wildcard", + "with*wildcard", + 'special"char', + "special#char", + "special%char", + "special\\char", + "special^char", + "special|char", + "special~char", + "special`char", + "special$char", + "special&char", + "special,char", + "special;char", + "special:char", + "special/char", + "uni\u0000code", + "uni\u0001code", + "uni\u0002code", + "uni\u0003code", + "uni\u0004code", + "uni\u0005code", + "uni\u0006code", + "uni\u0007code", + "uni\u0008code", + "uni\u0009code", + "uni\u000Acode", + "uni\u000Bcode", + "uni\u000Ccode", + "uni\u000Dcode", + "uni\u000Ecode", + "uni\u000Fcode", + "uni\u0010code", + "uni\u0011code", + "uni\u0012code", + "uni\u0013code", + "uni\u0014code", + "uni\u0015code", + "uni\u0016code", + "uni\u0017code", + "uni\u0018code", + "uni\u0019code", + "uni\u001Acode", + "uni\u001Bcode", + "uni\u001Ccode", + "uni\u001Dcode", + "uni\u001Ecode", + "uni\u001Fcode", + "uni\u007Fcode", + "uni\u0080code", + "uni\u0081code", + "uni\u0082code", + "uni\u0083code", + "uni\u0084code", + "uni\u0085code", + "uni\u0086code", + "uni\u0087code", + "uni\u0088code", + "uni\u0089code", + "uni\u008Acode", + "uni\u008Bcode", + "uni\u008Ccode", + "uni\u008Dcode", + "uni\u008Ecode", + "uni\u008Fcode", + "uni\u0090code", + "uni\u0091code", + "uni\u0092code", + "uni\u0093code", + "uni\u0094code", + "uni\u0095code", + "uni\u0096code", + "uni\u0097code", + "uni\u0098code", + "uni\u0099code", + "uni\u009Acode", + "uni\u009Bcode", + "uni\u009Ccode", + "uni\u009Dcode", + "uni\u009Ecode", + "uni\u009Fcode", + ] # for invalid_name in invalid_names: with assert_raises(ClientError) as exc: - client.create_state_machine(name=invalid_name, - definition=str(simple_definition), - roleArn=_get_default_role()) + client.create_state_machine( + name=invalid_name, + definition=str(simple_definition), + roleArn=_get_default_role(), + ) @mock_stepfunctions def test_state_machine_creation_requires_valid_role_arn(): - client = boto3.client('stepfunctions', region_name=region) - name = 'example_step_function' + client = boto3.client("stepfunctions", region_name=region) + name = "example_step_function" # with assert_raises(ClientError) as exc: - client.create_state_machine(name=name, - definition=str(simple_definition), - roleArn='arn:aws:iam::1234:role/unknown_role') + client.create_state_machine( + name=name, + definition=str(simple_definition), + roleArn="arn:aws:iam::1234:role/unknown_role", + ) @mock_stepfunctions def test_state_machine_list_returns_empty_list_by_default(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # list = client.list_state_machines() - list['stateMachines'].should.be.empty + list["stateMachines"].should.be.empty @mock_stepfunctions @mock_sts def test_state_machine_list_returns_created_state_machines(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - machine2 = client.create_state_machine(name='name2', - definition=str(simple_definition), - roleArn=_get_default_role()) - machine1 = client.create_state_machine(name='name1', - definition=str(simple_definition), - roleArn=_get_default_role(), - tags=[{'key': 'tag_key', 'value': 'tag_value'}]) + machine2 = client.create_state_machine( + name="name2", definition=str(simple_definition), roleArn=_get_default_role() + ) + machine1 = client.create_state_machine( + name="name1", + definition=str(simple_definition), + roleArn=_get_default_role(), + tags=[{"key": "tag_key", "value": "tag_value"}], + ) list = client.list_state_machines() # - list['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - list['stateMachines'].should.have.length_of(2) - list['stateMachines'][0]['creationDate'].should.be.a(datetime) - list['stateMachines'][0]['creationDate'].should.equal(machine1['creationDate']) - list['stateMachines'][0]['name'].should.equal('name1') - list['stateMachines'][0]['stateMachineArn'].should.equal(machine1['stateMachineArn']) - list['stateMachines'][1]['creationDate'].should.be.a(datetime) - list['stateMachines'][1]['creationDate'].should.equal(machine2['creationDate']) - list['stateMachines'][1]['name'].should.equal('name2') - list['stateMachines'][1]['stateMachineArn'].should.equal(machine2['stateMachineArn']) + list["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + list["stateMachines"].should.have.length_of(2) + list["stateMachines"][0]["creationDate"].should.be.a(datetime) + list["stateMachines"][0]["creationDate"].should.equal(machine1["creationDate"]) + list["stateMachines"][0]["name"].should.equal("name1") + list["stateMachines"][0]["stateMachineArn"].should.equal( + machine1["stateMachineArn"] + ) + list["stateMachines"][1]["creationDate"].should.be.a(datetime) + list["stateMachines"][1]["creationDate"].should.equal(machine2["creationDate"]) + list["stateMachines"][1]["name"].should.equal("name2") + list["stateMachines"][1]["stateMachineArn"].should.equal( + machine2["stateMachineArn"] + ) @mock_stepfunctions @mock_sts def test_state_machine_creation_is_idempotent_by_name(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) sm_list = client.list_state_machines() - sm_list['stateMachines'].should.have.length_of(1) + sm_list["stateMachines"].should.have.length_of(1) # - client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) sm_list = client.list_state_machines() - sm_list['stateMachines'].should.have.length_of(1) + sm_list["stateMachines"].should.have.length_of(1) # - client.create_state_machine(name='diff_name', definition=str(simple_definition), roleArn=_get_default_role()) + client.create_state_machine( + name="diff_name", definition=str(simple_definition), roleArn=_get_default_role() + ) sm_list = client.list_state_machines() - sm_list['stateMachines'].should.have.length_of(2) + sm_list["stateMachines"].should.have.length_of(2) @mock_stepfunctions @mock_sts def test_state_machine_creation_can_be_described(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) - desc = client.describe_state_machine(stateMachineArn=sm['stateMachineArn']) - desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - desc['creationDate'].should.equal(sm['creationDate']) - desc['definition'].should.equal(str(simple_definition)) - desc['name'].should.equal('name') - desc['roleArn'].should.equal(_get_default_role()) - desc['stateMachineArn'].should.equal(sm['stateMachineArn']) - desc['status'].should.equal('ACTIVE') + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + desc = client.describe_state_machine(stateMachineArn=sm["stateMachineArn"]) + desc["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + desc["creationDate"].should.equal(sm["creationDate"]) + desc["definition"].should.equal(str(simple_definition)) + desc["name"].should.equal("name") + desc["roleArn"].should.equal(_get_default_role()) + desc["stateMachineArn"].should.equal(sm["stateMachineArn"]) + desc["status"].should.equal("ACTIVE") @mock_stepfunctions @mock_sts def test_state_machine_throws_error_when_describing_unknown_machine(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # with assert_raises(ClientError) as exc: - unknown_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown' + unknown_state_machine = ( + "arn:aws:states:" + + region + + ":" + + _get_account_id() + + ":stateMachine:unknown" + ) client.describe_state_machine(stateMachineArn=unknown_state_machine) @mock_stepfunctions @mock_sts def test_state_machine_throws_error_when_describing_machine_in_different_account(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # with assert_raises(ClientError) as exc: - unknown_state_machine = 'arn:aws:states:' + region + ':000000000000:stateMachine:unknown' + unknown_state_machine = ( + "arn:aws:states:" + region + ":000000000000:stateMachine:unknown" + ) client.describe_state_machine(stateMachineArn=unknown_state_machine) @mock_stepfunctions @mock_sts def test_state_machine_can_be_deleted(): - client = boto3.client('stepfunctions', region_name=region) - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) + client = boto3.client("stepfunctions", region_name=region) + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) # - response = client.delete_state_machine(stateMachineArn=sm['stateMachineArn']) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response = client.delete_state_machine(stateMachineArn=sm["stateMachineArn"]) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # sm_list = client.list_state_machines() - sm_list['stateMachines'].should.have.length_of(0) + sm_list["stateMachines"].should.have.length_of(0) @mock_stepfunctions @mock_sts def test_state_machine_can_deleted_nonexisting_machine(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - unknown_state_machine = 'arn:aws:states:' + region + ':123456789012:stateMachine:unknown' + unknown_state_machine = ( + "arn:aws:states:" + region + ":123456789012:stateMachine:unknown" + ) response = client.delete_state_machine(stateMachineArn=unknown_state_machine) - response['ResponseMetadata']['HTTPStatusCode'].should.equal(200) + response["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) # sm_list = client.list_state_machines() - sm_list['stateMachines'].should.have.length_of(0) + sm_list["stateMachines"].should.have.length_of(0) @mock_stepfunctions @mock_sts def test_state_machine_list_tags_for_created_machine(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - machine = client.create_state_machine(name='name1', - definition=str(simple_definition), - roleArn=_get_default_role(), - tags=[{'key': 'tag_key', 'value': 'tag_value'}]) - response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn']) - tags = response['tags'] + machine = client.create_state_machine( + name="name1", + definition=str(simple_definition), + roleArn=_get_default_role(), + tags=[{"key": "tag_key", "value": "tag_value"}], + ) + response = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + tags = response["tags"] tags.should.have.length_of(1) - tags[0].should.equal({'key': 'tag_key', 'value': 'tag_value'}) + tags[0].should.equal({"key": "tag_key", "value": "tag_value"}) @mock_stepfunctions @mock_sts def test_state_machine_list_tags_for_machine_without_tags(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - machine = client.create_state_machine(name='name1', - definition=str(simple_definition), - roleArn=_get_default_role()) - response = client.list_tags_for_resource(resourceArn=machine['stateMachineArn']) - tags = response['tags'] + machine = client.create_state_machine( + name="name1", definition=str(simple_definition), roleArn=_get_default_role() + ) + response = client.list_tags_for_resource(resourceArn=machine["stateMachineArn"]) + tags = response["tags"] tags.should.have.length_of(0) @mock_stepfunctions @mock_sts def test_state_machine_list_tags_for_nonexisting_machine(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - non_existing_state_machine = 'arn:aws:states:' + region + ':' + _get_account_id() + ':stateMachine:unknown' + non_existing_state_machine = ( + "arn:aws:states:" + region + ":" + _get_account_id() + ":stateMachine:unknown" + ) response = client.list_tags_for_resource(resourceArn=non_existing_state_machine) - tags = response['tags'] + tags = response["tags"] tags.should.have.length_of(0) @mock_stepfunctions @mock_sts def test_state_machine_start_execution(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) - execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution(stateMachineArn=sm["stateMachineArn"]) # - execution['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - uuid_regex = '[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}' - expected_exec_name = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:name:' + uuid_regex - execution['executionArn'].should.match(expected_exec_name) - execution['startDate'].should.be.a(datetime) + execution["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + uuid_regex = "[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}" + expected_exec_name = ( + "arn:aws:states:" + + region + + ":" + + _get_account_id() + + ":execution:name:" + + uuid_regex + ) + execution["executionArn"].should.match(expected_exec_name) + execution["startDate"].should.be.a(datetime) @mock_stepfunctions @mock_sts def test_state_machine_start_execution_with_custom_name(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) - execution = client.start_execution(stateMachineArn=sm['stateMachineArn'], name='execution_name') + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution( + stateMachineArn=sm["stateMachineArn"], name="execution_name" + ) # - execution['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - expected_exec_name = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:name:execution_name' - execution['executionArn'].should.equal(expected_exec_name) - execution['startDate'].should.be.a(datetime) + execution["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + expected_exec_name = ( + "arn:aws:states:" + + region + + ":" + + _get_account_id() + + ":execution:name:execution_name" + ) + execution["executionArn"].should.equal(expected_exec_name) + execution["startDate"].should.be.a(datetime) @mock_stepfunctions @mock_sts def test_state_machine_list_executions(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) - execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) - execution_arn = execution['executionArn'] - execution_name = execution_arn[execution_arn.rindex(':')+1:] - executions = client.list_executions(stateMachineArn=sm['stateMachineArn']) + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + execution_arn = execution["executionArn"] + execution_name = execution_arn[execution_arn.rindex(":") + 1 :] + executions = client.list_executions(stateMachineArn=sm["stateMachineArn"]) # - executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - executions['executions'].should.have.length_of(1) - executions['executions'][0]['executionArn'].should.equal(execution_arn) - executions['executions'][0]['name'].should.equal(execution_name) - executions['executions'][0]['startDate'].should.equal(execution['startDate']) - executions['executions'][0]['stateMachineArn'].should.equal(sm['stateMachineArn']) - executions['executions'][0]['status'].should.equal('RUNNING') - executions['executions'][0].shouldnt.have('stopDate') + executions["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + executions["executions"].should.have.length_of(1) + executions["executions"][0]["executionArn"].should.equal(execution_arn) + executions["executions"][0]["name"].should.equal(execution_name) + executions["executions"][0]["startDate"].should.equal(execution["startDate"]) + executions["executions"][0]["stateMachineArn"].should.equal(sm["stateMachineArn"]) + executions["executions"][0]["status"].should.equal("RUNNING") + executions["executions"][0].shouldnt.have("stopDate") @mock_stepfunctions @mock_sts def test_state_machine_list_executions_when_none_exist(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) - executions = client.list_executions(stateMachineArn=sm['stateMachineArn']) + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + executions = client.list_executions(stateMachineArn=sm["stateMachineArn"]) # - executions['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - executions['executions'].should.have.length_of(0) + executions["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + executions["executions"].should.have.length_of(0) @mock_stepfunctions @mock_sts def test_state_machine_describe_execution(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) - execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) - description = client.describe_execution(executionArn=execution['executionArn']) + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + description = client.describe_execution(executionArn=execution["executionArn"]) # - description['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - description['executionArn'].should.equal(execution['executionArn']) - description['input'].should.equal("{}") - description['name'].shouldnt.be.empty - description['startDate'].should.equal(execution['startDate']) - description['stateMachineArn'].should.equal(sm['stateMachineArn']) - description['status'].should.equal('RUNNING') - description.shouldnt.have('stopDate') + description["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + description["executionArn"].should.equal(execution["executionArn"]) + description["input"].should.equal("{}") + description["name"].shouldnt.be.empty + description["startDate"].should.equal(execution["startDate"]) + description["stateMachineArn"].should.equal(sm["stateMachineArn"]) + description["status"].should.equal("RUNNING") + description.shouldnt.have("stopDate") @mock_stepfunctions @mock_sts def test_state_machine_throws_error_when_describing_unknown_machine(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # with assert_raises(ClientError) as exc: - unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown' + unknown_execution = ( + "arn:aws:states:" + region + ":" + _get_account_id() + ":execution:unknown" + ) client.describe_execution(executionArn=unknown_execution) @mock_stepfunctions @mock_sts def test_state_machine_can_be_described_by_execution(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) - execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) - desc = client.describe_state_machine_for_execution(executionArn=execution['executionArn']) - desc['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - desc['definition'].should.equal(str(simple_definition)) - desc['name'].should.equal('name') - desc['roleArn'].should.equal(_get_default_role()) - desc['stateMachineArn'].should.equal(sm['stateMachineArn']) + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + desc = client.describe_state_machine_for_execution( + executionArn=execution["executionArn"] + ) + desc["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + desc["definition"].should.equal(str(simple_definition)) + desc["name"].should.equal("name") + desc["roleArn"].should.equal(_get_default_role()) + desc["stateMachineArn"].should.equal(sm["stateMachineArn"]) @mock_stepfunctions @mock_sts def test_state_machine_throws_error_when_describing_unknown_execution(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # with assert_raises(ClientError) as exc: - unknown_execution = 'arn:aws:states:' + region + ':' + _get_account_id() + ':execution:unknown' + unknown_execution = ( + "arn:aws:states:" + region + ":" + _get_account_id() + ":execution:unknown" + ) client.describe_state_machine_for_execution(executionArn=unknown_execution) @mock_stepfunctions @mock_sts def test_state_machine_stop_execution(): - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) - start = client.start_execution(stateMachineArn=sm['stateMachineArn']) - stop = client.stop_execution(executionArn=start['executionArn']) + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + start = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + stop = client.stop_execution(executionArn=start["executionArn"]) # - stop['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - stop['stopDate'].should.be.a(datetime) + stop["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + stop["stopDate"].should.be.a(datetime) @mock_stepfunctions @mock_sts def test_state_machine_describe_execution_after_stoppage(): account_id - client = boto3.client('stepfunctions', region_name=region) + client = boto3.client("stepfunctions", region_name=region) # - sm = client.create_state_machine(name='name', definition=str(simple_definition), roleArn=_get_default_role()) - execution = client.start_execution(stateMachineArn=sm['stateMachineArn']) - client.stop_execution(executionArn=execution['executionArn']) - description = client.describe_execution(executionArn=execution['executionArn']) + sm = client.create_state_machine( + name="name", definition=str(simple_definition), roleArn=_get_default_role() + ) + execution = client.start_execution(stateMachineArn=sm["stateMachineArn"]) + client.stop_execution(executionArn=execution["executionArn"]) + description = client.describe_execution(executionArn=execution["executionArn"]) # - description['ResponseMetadata']['HTTPStatusCode'].should.equal(200) - description['status'].should.equal('SUCCEEDED') - description['stopDate'].should.be.a(datetime) + description["ResponseMetadata"]["HTTPStatusCode"].should.equal(200) + description["status"].should.equal("SUCCEEDED") + description["stopDate"].should.be.a(datetime) def _get_account_id(): @@ -385,9 +526,9 @@ def _get_account_id(): return account_id sts = boto3.client("sts") identity = sts.get_caller_identity() - account_id = identity['Account'] + account_id = identity["Account"] return account_id def _get_default_role(): - return 'arn:aws:iam::' + _get_account_id() + ':role/unknown_sf_role' + return "arn:aws:iam::" + _get_account_id() + ":role/unknown_sf_role" diff --git a/tests/test_sts/test_server.py b/tests/test_sts/test_server.py index 40260a49f..8903477d7 100644 --- a/tests/test_sts/test_server.py +++ b/tests/test_sts/test_server.py @@ -3,16 +3,16 @@ import sure # noqa import moto.server as server -''' +""" Test the different server responses -''' +""" def test_sts_get_session_token(): backend = server.create_backend_app("sts") test_client = backend.test_client() - res = test_client.get('/?Action=GetSessionToken') + res = test_client.get("/?Action=GetSessionToken") res.status_code.should.equal(200) res.data.should.contain(b"SessionToken") res.data.should.contain(b"AccessKeyId") @@ -22,7 +22,7 @@ def test_sts_get_federation_token(): backend = server.create_backend_app("sts") test_client = backend.test_client() - res = test_client.get('/?Action=GetFederationToken&Name=Bob') + res = test_client.get("/?Action=GetFederationToken&Name=Bob") res.status_code.should.equal(200) res.data.should.contain(b"SessionToken") res.data.should.contain(b"AccessKeyId") @@ -32,7 +32,7 @@ def test_sts_get_caller_identity(): backend = server.create_backend_app("sts") test_client = backend.test_client() - res = test_client.get('/?Action=GetCallerIdentity') + res = test_client.get("/?Action=GetCallerIdentity") res.status_code.should.equal(200) res.data.should.contain(b"Arn") res.data.should.contain(b"UserId") diff --git a/tests/test_sts/test_sts.py b/tests/test_sts/test_sts.py index b047a8d13..2cb1c49e7 100644 --- a/tests/test_sts/test_sts.py +++ b/tests/test_sts/test_sts.py @@ -20,9 +20,10 @@ def test_get_session_token(): conn = boto.connect_sts() token = conn.get_session_token(duration=123) - token.expiration.should.equal('2012-01-01T12:02:03.000Z') + token.expiration.should.equal("2012-01-01T12:02:03.000Z") token.session_token.should.equal( - "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/LTo6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4IgRmpRV3zrkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15fjrBs2+cTQtpZ3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE") + "AQoEXAMPLEH4aoAH0gNCAPyJxz4BlCFFxWNE1OPTgk5TthT+FvwqnKwRcOIfrRh3c/LTo6UDdyJwOOvEVPvLXCrrrUtdnniCEXAMPLE/IvU1dYUg2RVAJBanLiHb4IgRmpRV3zrkuWJOgQs8IZZaIv2BXIa2R4OlgkBN9bkUDNCJiBeb/AXlzBBko7b15fjrBs2+cTQtpZ3CYWFXG8C5zqx37wnOE49mRl/+OtkIKGO7fAE" + ) token.access_key.should.equal("AKIAIOSFODNN7EXAMPLE") token.secret_key.should.equal("wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") @@ -34,57 +35,72 @@ def test_get_federation_token(): token_name = "Bob" token = conn.get_federation_token(duration=123, name=token_name) - token.credentials.expiration.should.equal('2012-01-01T12:02:03.000Z') + token.credentials.expiration.should.equal("2012-01-01T12:02:03.000Z") token.credentials.session_token.should.equal( - "AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQWLWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGdQrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz+scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA==") + "AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQWLWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGdQrmGdeehM4IC1NtBmUpp2wUE8phUZampKsburEDy0KPkyQDYwT7WZ0wq5VSXDvp75YU9HFvlRd8Tx6q6fE8YQcHNVXAkiY9q6d+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz+scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA==" + ) token.credentials.access_key.should.equal("AKIAIOSFODNN7EXAMPLE") token.credentials.secret_key.should.equal( - "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY") + "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" + ) token.federated_user_arn.should.equal( - "arn:aws:sts::{account_id}:federated-user/{token_name}".format(account_id=ACCOUNT_ID, token_name=token_name)) + "arn:aws:sts::{account_id}:federated-user/{token_name}".format( + account_id=ACCOUNT_ID, token_name=token_name + ) + ) token.federated_user_id.should.equal(str(ACCOUNT_ID) + ":" + token_name) @freeze_time("2012-01-01 12:00:00") @mock_sts def test_assume_role(): - client = boto3.client( - "sts", region_name='us-east-1') + client = boto3.client("sts", region_name="us-east-1") session_name = "session-name" - policy = json.dumps({ - "Statement": [ - { - "Sid": "Stmt13690092345534", - "Action": [ - "S3:ListBucket" - ], - "Effect": "Allow", - "Resource": [ - "arn:aws:s3:::foobar-tester" - ] - }, - ] - }) + policy = json.dumps( + { + "Statement": [ + { + "Sid": "Stmt13690092345534", + "Action": ["S3:ListBucket"], + "Effect": "Allow", + "Resource": ["arn:aws:s3:::foobar-tester"], + } + ] + } + ) role_name = "test-role" - s3_role = "arn:aws:iam::{account_id}:role/{role_name}".format(account_id=ACCOUNT_ID, role_name=role_name) - assume_role_response = client.assume_role(RoleArn=s3_role, RoleSessionName=session_name, - Policy=policy, DurationSeconds=900) + s3_role = "arn:aws:iam::{account_id}:role/{role_name}".format( + account_id=ACCOUNT_ID, role_name=role_name + ) + assume_role_response = client.assume_role( + RoleArn=s3_role, + RoleSessionName=session_name, + Policy=policy, + DurationSeconds=900, + ) - credentials = assume_role_response['Credentials'] + credentials = assume_role_response["Credentials"] if not settings.TEST_SERVER_MODE: - credentials['Expiration'].isoformat().should.equal('2012-01-01T12:15:00+00:00') - credentials['SessionToken'].should.have.length_of(356) - assert credentials['SessionToken'].startswith("FQoGZXIvYXdzE") - credentials['AccessKeyId'].should.have.length_of(20) - assert credentials['AccessKeyId'].startswith("ASIA") - credentials['SecretAccessKey'].should.have.length_of(40) + credentials["Expiration"].isoformat().should.equal("2012-01-01T12:15:00+00:00") + credentials["SessionToken"].should.have.length_of(356) + assert credentials["SessionToken"].startswith("FQoGZXIvYXdzE") + credentials["AccessKeyId"].should.have.length_of(20) + assert credentials["AccessKeyId"].startswith("ASIA") + credentials["SecretAccessKey"].should.have.length_of(40) - assume_role_response['AssumedRoleUser']['Arn'].should.equal("arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( - account_id=ACCOUNT_ID, role_name=role_name, session_name=session_name)) - assert assume_role_response['AssumedRoleUser']['AssumedRoleId'].startswith("AROA") - assert assume_role_response['AssumedRoleUser']['AssumedRoleId'].endswith(":" + session_name) - assume_role_response['AssumedRoleUser']['AssumedRoleId'].should.have.length_of(21 + 1 + len(session_name)) + assume_role_response["AssumedRoleUser"]["Arn"].should.equal( + "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( + account_id=ACCOUNT_ID, role_name=role_name, session_name=session_name + ) + ) + assert assume_role_response["AssumedRoleUser"]["AssumedRoleId"].startswith("AROA") + assert assume_role_response["AssumedRoleUser"]["AssumedRoleId"].endswith( + ":" + session_name + ) + assume_role_response["AssumedRoleUser"]["AssumedRoleId"].should.have.length_of( + 21 + 1 + len(session_name) + ) @freeze_time("2012-01-01 12:00:00") @@ -92,122 +108,135 @@ def test_assume_role(): def test_assume_role_with_web_identity(): conn = boto.connect_sts() - policy = json.dumps({ - "Statement": [ - { - "Sid": "Stmt13690092345534", - "Action": [ - "S3:ListBucket" - ], - "Effect": "Allow", - "Resource": [ - "arn:aws:s3:::foobar-tester" - ] - }, - ] - }) + policy = json.dumps( + { + "Statement": [ + { + "Sid": "Stmt13690092345534", + "Action": ["S3:ListBucket"], + "Effect": "Allow", + "Resource": ["arn:aws:s3:::foobar-tester"], + } + ] + } + ) role_name = "test-role" - s3_role = "arn:aws:iam::{account_id}:role/{role_name}".format(account_id=ACCOUNT_ID, role_name=role_name) + s3_role = "arn:aws:iam::{account_id}:role/{role_name}".format( + account_id=ACCOUNT_ID, role_name=role_name + ) session_name = "session-name" role = conn.assume_role_with_web_identity( - s3_role, session_name, policy, duration_seconds=123) + s3_role, session_name, policy, duration_seconds=123 + ) credentials = role.credentials - credentials.expiration.should.equal('2012-01-01T12:02:03.000Z') + credentials.expiration.should.equal("2012-01-01T12:02:03.000Z") credentials.session_token.should.have.length_of(356) assert credentials.session_token.startswith("FQoGZXIvYXdzE") credentials.access_key.should.have.length_of(20) assert credentials.access_key.startswith("ASIA") credentials.secret_key.should.have.length_of(40) - role.user.arn.should.equal("arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( - account_id=ACCOUNT_ID, role_name=role_name, session_name=session_name)) + role.user.arn.should.equal( + "arn:aws:sts::{account_id}:assumed-role/{role_name}/{session_name}".format( + account_id=ACCOUNT_ID, role_name=role_name, session_name=session_name + ) + ) role.user.assume_role_id.should.contain("session-name") @mock_sts def test_get_caller_identity_with_default_credentials(): - identity = boto3.client( - "sts", region_name='us-east-1').get_caller_identity() + identity = boto3.client("sts", region_name="us-east-1").get_caller_identity() - identity['Arn'].should.equal('arn:aws:sts::{account_id}:user/moto'.format(account_id=ACCOUNT_ID)) - identity['UserId'].should.equal('AKIAIOSFODNN7EXAMPLE') - identity['Account'].should.equal(str(ACCOUNT_ID)) + identity["Arn"].should.equal( + "arn:aws:sts::{account_id}:user/moto".format(account_id=ACCOUNT_ID) + ) + identity["UserId"].should.equal("AKIAIOSFODNN7EXAMPLE") + identity["Account"].should.equal(str(ACCOUNT_ID)) @mock_sts @mock_iam def test_get_caller_identity_with_iam_user_credentials(): - iam_client = boto3.client("iam", region_name='us-east-1') + iam_client = boto3.client("iam", region_name="us-east-1") iam_user_name = "new-user" - iam_user = iam_client.create_user(UserName=iam_user_name)['User'] - access_key = iam_client.create_access_key(UserName=iam_user_name)['AccessKey'] + iam_user = iam_client.create_user(UserName=iam_user_name)["User"] + access_key = iam_client.create_access_key(UserName=iam_user_name)["AccessKey"] identity = boto3.client( - "sts", region_name='us-east-1', aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']).get_caller_identity() + "sts", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ).get_caller_identity() - identity['Arn'].should.equal(iam_user['Arn']) - identity['UserId'].should.equal(iam_user['UserId']) - identity['Account'].should.equal(str(ACCOUNT_ID)) + identity["Arn"].should.equal(iam_user["Arn"]) + identity["UserId"].should.equal(iam_user["UserId"]) + identity["Account"].should.equal(str(ACCOUNT_ID)) @mock_sts @mock_iam def test_get_caller_identity_with_assumed_role_credentials(): - iam_client = boto3.client("iam", region_name='us-east-1') - sts_client = boto3.client("sts", region_name='us-east-1') + iam_client = boto3.client("iam", region_name="us-east-1") + sts_client = boto3.client("sts", region_name="us-east-1") iam_role_name = "new-user" trust_policy_document = { "Version": "2012-10-17", "Statement": { "Effect": "Allow", - "Principal": {"AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID)}, - "Action": "sts:AssumeRole" - } + "Principal": { + "AWS": "arn:aws:iam::{account_id}:root".format(account_id=ACCOUNT_ID) + }, + "Action": "sts:AssumeRole", + }, } iam_role_arn = iam_client.role_arn = iam_client.create_role( RoleName=iam_role_name, - AssumeRolePolicyDocument=json.dumps(trust_policy_document) - )['Role']['Arn'] + AssumeRolePolicyDocument=json.dumps(trust_policy_document), + )["Role"]["Arn"] session_name = "new-session" - assumed_role = sts_client.assume_role(RoleArn=iam_role_arn, - RoleSessionName=session_name) - access_key = assumed_role['Credentials'] + assumed_role = sts_client.assume_role( + RoleArn=iam_role_arn, RoleSessionName=session_name + ) + access_key = assumed_role["Credentials"] identity = boto3.client( - "sts", region_name='us-east-1', aws_access_key_id=access_key['AccessKeyId'], - aws_secret_access_key=access_key['SecretAccessKey']).get_caller_identity() + "sts", + region_name="us-east-1", + aws_access_key_id=access_key["AccessKeyId"], + aws_secret_access_key=access_key["SecretAccessKey"], + ).get_caller_identity() - identity['Arn'].should.equal(assumed_role['AssumedRoleUser']['Arn']) - identity['UserId'].should.equal(assumed_role['AssumedRoleUser']['AssumedRoleId']) - identity['Account'].should.equal(str(ACCOUNT_ID)) + identity["Arn"].should.equal(assumed_role["AssumedRoleUser"]["Arn"]) + identity["UserId"].should.equal(assumed_role["AssumedRoleUser"]["AssumedRoleId"]) + identity["Account"].should.equal(str(ACCOUNT_ID)) @mock_sts def test_federation_token_with_too_long_policy(): "Trying to get a federation token with a policy longer than 2048 character should fail" - cli = boto3.client("sts", region_name='us-east-1') - resource_tmpl = 'arn:aws:s3:::yyyy-xxxxx-cloud-default/my_default_folder/folder-name-%s/*' + cli = boto3.client("sts", region_name="us-east-1") + resource_tmpl = ( + "arn:aws:s3:::yyyy-xxxxx-cloud-default/my_default_folder/folder-name-%s/*" + ) statements = [] for num in range(30): statements.append( { - 'Effect': 'Allow', - 'Action': ['s3:*'], - 'Resource': resource_tmpl % str(num) + "Effect": "Allow", + "Action": ["s3:*"], + "Resource": resource_tmpl % str(num), } ) - policy = { - 'Version': '2012-10-17', - 'Statement': statements - } + policy = {"Version": "2012-10-17", "Statement": statements} json_policy = json.dumps(policy) assert len(json_policy) > MAX_FEDERATION_TOKEN_POLICY_LENGTH with assert_raises(ClientError) as exc: - cli.get_federation_token(Name='foo', DurationSeconds=3600, Policy=json_policy) - exc.exception.response['Error']['Code'].should.equal('ValidationError') - exc.exception.response['Error']['Message'].should.contain( + cli.get_federation_token(Name="foo", DurationSeconds=3600, Policy=json_policy) + exc.exception.response["Error"]["Code"].should.equal("ValidationError") + exc.exception.response["Error"]["Message"].should.contain( str(MAX_FEDERATION_TOKEN_POLICY_LENGTH) ) diff --git a/tests/test_swf/models/test_activity_task.py b/tests/test_swf/models/test_activity_task.py index 41c88cafe..96f7c345f 100644 --- a/tests/test_swf/models/test_activity_task.py +++ b/tests/test_swf/models/test_activity_task.py @@ -2,11 +2,7 @@ from freezegun import freeze_time import sure # noqa from moto.swf.exceptions import SWFWorkflowExecutionClosedError -from moto.swf.models import ( - ActivityTask, - ActivityType, - Timeout, -) +from moto.swf.models import ActivityTask, ActivityType, Timeout from ..utils import ( ACTIVITY_TASK_TIMEOUTS, @@ -149,6 +145,7 @@ def test_activity_task_cannot_change_state_on_closed_workflow_execution(): wfe.complete(123) task.timeout.when.called_with(Timeout(task, 0, "foo")).should.throw( - SWFWorkflowExecutionClosedError) + SWFWorkflowExecutionClosedError + ) task.complete.when.called_with().should.throw(SWFWorkflowExecutionClosedError) task.fail.when.called_with().should.throw(SWFWorkflowExecutionClosedError) diff --git a/tests/test_swf/models/test_decision_task.py b/tests/test_swf/models/test_decision_task.py index b5e23eaca..0661adffb 100644 --- a/tests/test_swf/models/test_decision_task.py +++ b/tests/test_swf/models/test_decision_task.py @@ -76,5 +76,6 @@ def test_decision_task_cannot_change_state_on_closed_workflow_execution(): wfe.complete(123) task.timeout.when.called_with(Timeout(task, 0, "foo")).should.throw( - SWFWorkflowExecutionClosedError) + SWFWorkflowExecutionClosedError + ) task.complete.when.called_with().should.throw(SWFWorkflowExecutionClosedError) diff --git a/tests/test_swf/models/test_domain.py b/tests/test_swf/models/test_domain.py index 1a8a1268d..389e516df 100644 --- a/tests/test_swf/models/test_domain.py +++ b/tests/test_swf/models/test_domain.py @@ -9,15 +9,13 @@ import tests.backport_assert_raises # noqa # Fake WorkflowExecution for tests purposes WorkflowExecution = namedtuple( - "WorkflowExecution", - ["workflow_id", "run_id", "execution_status", "open"] + "WorkflowExecution", ["workflow_id", "run_id", "execution_status", "open"] ) def test_domain_short_dict_representation(): domain = Domain("foo", "52") - domain.to_short_dict().should.equal( - {"name": "foo", "status": "REGISTERED"}) + domain.to_short_dict().should.equal({"name": "foo", "status": "REGISTERED"}) domain.description = "foo bar" domain.to_short_dict()["description"].should.equal("foo bar") @@ -39,9 +37,7 @@ def test_domain_string_representation(): def test_domain_add_to_activity_task_list(): domain = Domain("my-domain", "60") domain.add_to_activity_task_list("foo", "bar") - domain.activity_task_lists.should.equal({ - "foo": ["bar"] - }) + domain.activity_task_lists.should.equal({"foo": ["bar"]}) def test_domain_activity_tasks(): @@ -54,9 +50,7 @@ def test_domain_activity_tasks(): def test_domain_add_to_decision_task_list(): domain = Domain("my-domain", "60") domain.add_to_decision_task_list("foo", "bar") - domain.decision_task_lists.should.equal({ - "foo": ["bar"] - }) + domain.decision_task_lists.should.equal({"foo": ["bar"]}) def test_domain_decision_tasks(): @@ -70,50 +64,44 @@ def test_domain_get_workflow_execution(): domain = Domain("my-domain", "60") wfe1 = WorkflowExecution( - workflow_id="wf-id-1", run_id="run-id-1", execution_status="OPEN", open=True) + workflow_id="wf-id-1", run_id="run-id-1", execution_status="OPEN", open=True + ) wfe2 = WorkflowExecution( - workflow_id="wf-id-1", run_id="run-id-2", execution_status="CLOSED", open=False) + workflow_id="wf-id-1", run_id="run-id-2", execution_status="CLOSED", open=False + ) wfe3 = WorkflowExecution( - workflow_id="wf-id-2", run_id="run-id-3", execution_status="OPEN", open=True) + workflow_id="wf-id-2", run_id="run-id-3", execution_status="OPEN", open=True + ) wfe4 = WorkflowExecution( - workflow_id="wf-id-3", run_id="run-id-4", execution_status="CLOSED", open=False) + workflow_id="wf-id-3", run_id="run-id-4", execution_status="CLOSED", open=False + ) domain.workflow_executions = [wfe1, wfe2, wfe3, wfe4] # get workflow execution through workflow_id and run_id - domain.get_workflow_execution( - "wf-id-1", run_id="run-id-1").should.equal(wfe1) - domain.get_workflow_execution( - "wf-id-1", run_id="run-id-2").should.equal(wfe2) - domain.get_workflow_execution( - "wf-id-3", run_id="run-id-4").should.equal(wfe4) + domain.get_workflow_execution("wf-id-1", run_id="run-id-1").should.equal(wfe1) + domain.get_workflow_execution("wf-id-1", run_id="run-id-2").should.equal(wfe2) + domain.get_workflow_execution("wf-id-3", run_id="run-id-4").should.equal(wfe4) domain.get_workflow_execution.when.called_with( "wf-id-1", run_id="non-existent" - ).should.throw( - SWFUnknownResourceFault, - ) + ).should.throw(SWFUnknownResourceFault) # get OPEN workflow execution by default if no run_id domain.get_workflow_execution("wf-id-1").should.equal(wfe1) - domain.get_workflow_execution.when.called_with( - "wf-id-3" - ).should.throw( + domain.get_workflow_execution.when.called_with("wf-id-3").should.throw( SWFUnknownResourceFault ) - domain.get_workflow_execution.when.called_with( - "wf-id-non-existent" - ).should.throw( + domain.get_workflow_execution.when.called_with("wf-id-non-existent").should.throw( SWFUnknownResourceFault ) # raise_if_closed attribute domain.get_workflow_execution( - "wf-id-1", run_id="run-id-1", raise_if_closed=True).should.equal(wfe1) + "wf-id-1", run_id="run-id-1", raise_if_closed=True + ).should.equal(wfe1) domain.get_workflow_execution.when.called_with( "wf-id-3", run_id="run-id-4", raise_if_closed=True - ).should.throw( - SWFUnknownResourceFault - ) + ).should.throw(SWFUnknownResourceFault) # raise_if_none attribute domain.get_workflow_execution("foo", raise_if_none=False).should.be.none diff --git a/tests/test_swf/models/test_generic_type.py b/tests/test_swf/models/test_generic_type.py index 294df9f84..ef7378d06 100644 --- a/tests/test_swf/models/test_generic_type.py +++ b/tests/test_swf/models/test_generic_type.py @@ -4,7 +4,6 @@ import sure # noqa # Tests for GenericType (ActivityType, WorkflowType) class FooType(GenericType): - @property def kind(self): return "foo" @@ -40,12 +39,12 @@ def test_type_full_dict_representation(): _type.to_full_dict()["configuration"].should.equal({}) _type.task_list = "foo" - _type.to_full_dict()["configuration"][ - "defaultTaskList"].should.equal({"name": "foo"}) + _type.to_full_dict()["configuration"]["defaultTaskList"].should.equal( + {"name": "foo"} + ) _type.just_an_example_timeout = "60" - _type.to_full_dict()["configuration"][ - "justAnExampleTimeout"].should.equal("60") + _type.to_full_dict()["configuration"]["justAnExampleTimeout"].should.equal("60") _type.non_whitelisted_property = "34" keys = _type.to_full_dict()["configuration"].keys() @@ -55,4 +54,5 @@ def test_type_full_dict_representation(): def test_type_string_representation(): _type = FooType("test-foo", "v1.0") str(_type).should.equal( - "FooType(name: test-foo, version: v1.0, status: REGISTERED)") + "FooType(name: test-foo, version: v1.0, status: REGISTERED)" + ) diff --git a/tests/test_swf/models/test_history_event.py b/tests/test_swf/models/test_history_event.py index b869408ce..8b8234187 100644 --- a/tests/test_swf/models/test_history_event.py +++ b/tests/test_swf/models/test_history_event.py @@ -15,17 +15,17 @@ def test_history_event_creation(): @freeze_time("2015-01-01 12:00:00") def test_history_event_to_dict_representation(): he = HistoryEvent(123, "DecisionTaskStarted", scheduled_event_id=2) - he.to_dict().should.equal({ - "eventId": 123, - "eventType": "DecisionTaskStarted", - "eventTimestamp": 1420113600.0, - "decisionTaskStartedEventAttributes": { - "scheduledEventId": 2 + he.to_dict().should.equal( + { + "eventId": 123, + "eventType": "DecisionTaskStarted", + "eventTimestamp": 1420113600.0, + "decisionTaskStartedEventAttributes": {"scheduledEventId": 2}, } - }) + ) def test_history_event_breaks_on_initialization_if_not_implemented(): - HistoryEvent.when.called_with( - 123, "UnknownHistoryEvent" - ).should.throw(NotImplementedError) + HistoryEvent.when.called_with(123, "UnknownHistoryEvent").should.throw( + NotImplementedError + ) diff --git a/tests/test_swf/models/test_workflow_execution.py b/tests/test_swf/models/test_workflow_execution.py index 45b91c86a..6c73a9686 100644 --- a/tests/test_swf/models/test_workflow_execution.py +++ b/tests/test_swf/models/test_workflow_execution.py @@ -1,12 +1,7 @@ from freezegun import freeze_time import sure # noqa -from moto.swf.models import ( - ActivityType, - Timeout, - WorkflowType, - WorkflowExecution, -) +from moto.swf.models import ActivityType, Timeout, WorkflowType, WorkflowExecution from moto.swf.exceptions import SWFDefaultUndefinedFault from ..utils import ( auto_start_decision_tasks, @@ -43,28 +38,31 @@ def test_workflow_execution_creation_child_policy_logic(): WorkflowExecution( domain, WorkflowType( - "test-workflow", "v1.0", - task_list="queue", default_child_policy="ABANDON", + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="ABANDON", default_execution_start_to_close_timeout="300", default_task_start_to_close_timeout="300", ), - "ab1234" + "ab1234", ).child_policy.should.equal("ABANDON") WorkflowExecution( domain, WorkflowType( - "test-workflow", "v1.0", task_list="queue", + "test-workflow", + "v1.0", + task_list="queue", default_execution_start_to_close_timeout="300", default_task_start_to_close_timeout="300", ), "ab1234", - child_policy="REQUEST_CANCEL" + child_policy="REQUEST_CANCEL", ).child_policy.should.equal("REQUEST_CANCEL") WorkflowExecution.when.called_with( - domain, - WorkflowType("test-workflow", "v1.0"), "ab1234" + domain, WorkflowType("test-workflow", "v1.0"), "ab1234" ).should.throw(SWFDefaultUndefinedFault) @@ -84,8 +82,10 @@ def test_workflow_execution_generates_a_random_run_id(): def test_workflow_execution_short_dict_representation(): domain = get_basic_domain() wf_type = WorkflowType( - "test-workflow", "v1.0", - task_list="queue", default_child_policy="ABANDON", + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="ABANDON", default_execution_start_to_close_timeout="300", default_task_start_to_close_timeout="300", ) @@ -99,8 +99,10 @@ def test_workflow_execution_short_dict_representation(): def test_workflow_execution_medium_dict_representation(): domain = get_basic_domain() wf_type = WorkflowType( - "test-workflow", "v1.0", - task_list="queue", default_child_policy="ABANDON", + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="ABANDON", default_execution_start_to_close_timeout="300", default_task_start_to_close_timeout="300", ) @@ -109,7 +111,7 @@ def test_workflow_execution_medium_dict_representation(): md = wfe.to_medium_dict() md["execution"].should.equal(wfe.to_short_dict()) md["workflowType"].should.equal(wf_type.to_short_dict()) - md["startTimestamp"].should.be.a('float') + md["startTimestamp"].should.be.a("float") md["executionStatus"].should.equal("OPEN") md["cancelRequested"].should.be.falsy md.should_not.contain("tagList") @@ -122,8 +124,10 @@ def test_workflow_execution_medium_dict_representation(): def test_workflow_execution_full_dict_representation(): domain = get_basic_domain() wf_type = WorkflowType( - "test-workflow", "v1.0", - task_list="queue", default_child_policy="ABANDON", + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="ABANDON", default_execution_start_to_close_timeout="300", default_task_start_to_close_timeout="300", ) @@ -134,32 +138,36 @@ def test_workflow_execution_full_dict_representation(): fd["openCounts"]["openTimers"].should.equal(0) fd["openCounts"]["openDecisionTasks"].should.equal(0) fd["openCounts"]["openActivityTasks"].should.equal(0) - fd["executionConfiguration"].should.equal({ - "childPolicy": "ABANDON", - "executionStartToCloseTimeout": "300", - "taskList": {"name": "queue"}, - "taskStartToCloseTimeout": "300", - }) + fd["executionConfiguration"].should.equal( + { + "childPolicy": "ABANDON", + "executionStartToCloseTimeout": "300", + "taskList": {"name": "queue"}, + "taskStartToCloseTimeout": "300", + } + ) def test_workflow_execution_list_dict_representation(): domain = get_basic_domain() wf_type = WorkflowType( - 'test-workflow', 'v1.0', - task_list='queue', default_child_policy='ABANDON', - default_execution_start_to_close_timeout='300', - default_task_start_to_close_timeout='300', + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="ABANDON", + default_execution_start_to_close_timeout="300", + default_task_start_to_close_timeout="300", ) - wfe = WorkflowExecution(domain, wf_type, 'ab1234') + wfe = WorkflowExecution(domain, wf_type, "ab1234") ld = wfe.to_list_dict() - ld['workflowType']['version'].should.equal('v1.0') - ld['workflowType']['name'].should.equal('test-workflow') - ld['executionStatus'].should.equal('OPEN') - ld['execution']['workflowId'].should.equal('ab1234') - ld['execution'].should.contain('runId') - ld['cancelRequested'].should.be.false - ld.should.contain('startTimestamp') + ld["workflowType"]["version"].should.equal("v1.0") + ld["workflowType"]["name"].should.equal("test-workflow") + ld["executionStatus"].should.equal("OPEN") + ld["execution"]["workflowId"].should.equal("ab1234") + ld["execution"].should.contain("runId") + ld["cancelRequested"].should.be.false + ld.should.contain("startTimestamp") def test_workflow_execution_schedule_decision_task(): @@ -240,10 +248,8 @@ def test_workflow_execution_schedule_activity_task(): wfe.open_counts["openActivityTasks"].should.equal(1) last_event = wfe.events()[-1] last_event.event_type.should.equal("ActivityTaskScheduled") - last_event.event_attributes[ - "decisionTaskCompletedEventId"].should.equal(123) - last_event.event_attributes["taskList"][ - "name"].should.equal("task-list-name") + last_event.event_attributes["decisionTaskCompletedEventId"].should.equal(123) + last_event.event_attributes["taskList"]["name"].should.equal("task-list-name") wfe.activity_tasks.should.have.length_of(1) task = wfe.activity_tasks[0] @@ -254,17 +260,18 @@ def test_workflow_execution_schedule_activity_task(): def test_workflow_execution_schedule_activity_task_without_task_list_should_take_default(): wfe = make_workflow_execution() - wfe.domain.add_type( - ActivityType("test-activity", "v1.2", task_list="foobar") + wfe.domain.add_type(ActivityType("test-activity", "v1.2", task_list="foobar")) + wfe.schedule_activity_task( + 123, + { + "activityId": "my-activity-001", + "activityType": {"name": "test-activity", "version": "v1.2"}, + "scheduleToStartTimeout": "600", + "scheduleToCloseTimeout": "600", + "startToCloseTimeout": "600", + "heartbeatTimeout": "300", + }, ) - wfe.schedule_activity_task(123, { - "activityId": "my-activity-001", - "activityType": {"name": "test-activity", "version": "v1.2"}, - "scheduleToStartTimeout": "600", - "scheduleToCloseTimeout": "600", - "startToCloseTimeout": "600", - "heartbeatTimeout": "300", - }) wfe.open_counts["openActivityTasks"].should.equal(1) last_event = wfe.events()[-1] @@ -290,50 +297,51 @@ def test_workflow_execution_schedule_activity_task_should_fail_if_wrong_attribut wfe.schedule_activity_task(123, hsh) last_event = wfe.events()[-1] last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "ACTIVITY_TYPE_DOES_NOT_EXIST") + last_event.event_attributes["cause"].should.equal("ACTIVITY_TYPE_DOES_NOT_EXIST") hsh["activityType"]["name"] = "test-activity" wfe.schedule_activity_task(123, hsh) last_event = wfe.events()[-1] last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "ACTIVITY_TYPE_DEPRECATED") + last_event.event_attributes["cause"].should.equal("ACTIVITY_TYPE_DEPRECATED") hsh["activityType"]["version"] = "v1.2" wfe.schedule_activity_task(123, hsh) last_event = wfe.events()[-1] last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "DEFAULT_TASK_LIST_UNDEFINED") + last_event.event_attributes["cause"].should.equal("DEFAULT_TASK_LIST_UNDEFINED") hsh["taskList"] = {"name": "foobar"} wfe.schedule_activity_task(123, hsh) last_event = wfe.events()[-1] last_event.event_type.should.equal("ScheduleActivityTaskFailed") last_event.event_attributes["cause"].should.equal( - "DEFAULT_SCHEDULE_TO_START_TIMEOUT_UNDEFINED") + "DEFAULT_SCHEDULE_TO_START_TIMEOUT_UNDEFINED" + ) hsh["scheduleToStartTimeout"] = "600" wfe.schedule_activity_task(123, hsh) last_event = wfe.events()[-1] last_event.event_type.should.equal("ScheduleActivityTaskFailed") last_event.event_attributes["cause"].should.equal( - "DEFAULT_SCHEDULE_TO_CLOSE_TIMEOUT_UNDEFINED") + "DEFAULT_SCHEDULE_TO_CLOSE_TIMEOUT_UNDEFINED" + ) hsh["scheduleToCloseTimeout"] = "600" wfe.schedule_activity_task(123, hsh) last_event = wfe.events()[-1] last_event.event_type.should.equal("ScheduleActivityTaskFailed") last_event.event_attributes["cause"].should.equal( - "DEFAULT_START_TO_CLOSE_TIMEOUT_UNDEFINED") + "DEFAULT_START_TO_CLOSE_TIMEOUT_UNDEFINED" + ) hsh["startToCloseTimeout"] = "600" wfe.schedule_activity_task(123, hsh) last_event = wfe.events()[-1] last_event.event_type.should.equal("ScheduleActivityTaskFailed") last_event.event_attributes["cause"].should.equal( - "DEFAULT_HEARTBEAT_TIMEOUT_UNDEFINED") + "DEFAULT_HEARTBEAT_TIMEOUT_UNDEFINED" + ) wfe.open_counts["openActivityTasks"].should.equal(0) wfe.activity_tasks.should.have.length_of(0) @@ -365,9 +373,9 @@ def test_workflow_execution_schedule_activity_task_failure_triggers_new_decision "activityId": "my-activity-001", "activityType": { "name": "test-activity-does-not-exist", - "version": "v1.2" + "version": "v1.2", }, - } + }, }, { "decisionType": "ScheduleActivityTask", @@ -375,11 +383,12 @@ def test_workflow_execution_schedule_activity_task_failure_triggers_new_decision "activityId": "my-activity-001", "activityType": { "name": "test-activity-does-not-exist", - "version": "v1.2" + "version": "v1.2", }, - } + }, }, - ]) + ], + ) wfe.latest_execution_context.should.equal("free-form execution context") wfe.open_counts["openActivityTasks"].should.equal(0) @@ -402,8 +411,7 @@ def test_workflow_execution_schedule_activity_task_with_same_activity_id(): wfe.open_counts["openActivityTasks"].should.equal(1) last_event = wfe.events()[-1] last_event.event_type.should.equal("ScheduleActivityTaskFailed") - last_event.event_attributes["cause"].should.equal( - "ACTIVITY_ID_ALREADY_IN_USE") + last_event.event_attributes["cause"].should.equal("ACTIVITY_ID_ALREADY_IN_USE") def test_workflow_execution_start_activity_task(): @@ -481,8 +489,7 @@ def test_timeouts_are_processed_in_order_and_reevaluated(): # - but the last scheduled decision task should *not* timeout (workflow closed) with freeze_time("2015-01-01 12:00:00"): wfe = make_workflow_execution( - execution_start_to_close_timeout=8 * 60, - task_start_to_close_timeout=5 * 60, + execution_start_to_close_timeout=8 * 60, task_start_to_close_timeout=5 * 60 ) # decision will automatically start wfe = auto_start_decision_tasks(wfe) @@ -493,9 +500,11 @@ def test_timeouts_are_processed_in_order_and_reevaluated(): wfe._process_timeouts() event_types = [e.event_type for e in wfe.events()[event_idx:]] - event_types.should.equal([ - "DecisionTaskTimedOut", - "DecisionTaskScheduled", - "DecisionTaskStarted", - "WorkflowExecutionTimedOut", - ]) + event_types.should.equal( + [ + "DecisionTaskTimedOut", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "WorkflowExecutionTimedOut", + ] + ) diff --git a/tests/test_swf/responses/test_activity_tasks.py b/tests/test_swf/responses/test_activity_tasks.py index c0b8897b9..0b72b7ca7 100644 --- a/tests/test_swf/responses/test_activity_tasks.py +++ b/tests/test_swf/responses/test_activity_tasks.py @@ -12,18 +12,19 @@ from ..utils import setup_workflow, SCHEDULE_ACTIVITY_TASK_DECISION @mock_swf_deprecated def test_poll_for_activity_task_when_one(): conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) resp = conn.poll_for_activity_task( - "test-domain", "activity-task-list", identity="surprise") + "test-domain", "activity-task-list", identity="surprise" + ) resp["activityId"].should.equal("my-activity-001") resp["taskToken"].should_not.be.none resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) resp["events"][-1]["eventType"].should.equal("ActivityTaskStarted") resp["events"][-1]["activityTaskStartedEventAttributes"].should.equal( {"identity": "surprise", "scheduledEventId": 5} @@ -48,14 +49,12 @@ def test_poll_for_activity_task_on_non_existent_queue(): @mock_swf_deprecated def test_count_pending_activity_tasks(): conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) - resp = conn.count_pending_activity_tasks( - "test-domain", "activity-task-list") + resp = conn.count_pending_activity_tasks("test-domain", "activity-task-list") resp.should.equal({"count": 1, "truncated": False}) @@ -70,20 +69,22 @@ def test_count_pending_decision_tasks_on_non_existent_task_list(): @mock_swf_deprecated def test_respond_activity_task_completed(): conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + activity_token = conn.poll_for_activity_task("test-domain", "activity-task-list")[ + "taskToken" + ] resp = conn.respond_activity_task_completed( - activity_token, result="result of the task") + activity_token, result="result of the task" + ) resp.should.be.none resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) resp["events"][-2]["eventType"].should.equal("ActivityTaskCompleted") resp["events"][-2]["activityTaskCompletedEventAttributes"].should.equal( {"result": "result of the task", "scheduledEventId": 5, "startedEventId": 6} @@ -93,13 +94,13 @@ def test_respond_activity_task_completed(): @mock_swf_deprecated def test_respond_activity_task_completed_on_closed_workflow_execution(): conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + activity_token = conn.poll_for_activity_task("test-domain", "activity-task-list")[ + "taskToken" + ] # bad: we're closing workflow execution manually, but endpoints are not # coded for now.. @@ -107,52 +108,57 @@ def test_respond_activity_task_completed_on_closed_workflow_execution(): wfe.execution_status = "CLOSED" # /bad - conn.respond_activity_task_completed.when.called_with( - activity_token - ).should.throw(SWFResponseError, "WorkflowExecution=") + conn.respond_activity_task_completed.when.called_with(activity_token).should.throw( + SWFResponseError, "WorkflowExecution=" + ) @mock_swf_deprecated def test_respond_activity_task_completed_with_task_already_completed(): conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + activity_token = conn.poll_for_activity_task("test-domain", "activity-task-list")[ + "taskToken" + ] conn.respond_activity_task_completed(activity_token) - conn.respond_activity_task_completed.when.called_with( - activity_token - ).should.throw(SWFResponseError, "Unknown activity, scheduledEventId = 5") + conn.respond_activity_task_completed.when.called_with(activity_token).should.throw( + SWFResponseError, "Unknown activity, scheduledEventId = 5" + ) # RespondActivityTaskFailed endpoint @mock_swf_deprecated def test_respond_activity_task_failed(): conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + activity_token = conn.poll_for_activity_task("test-domain", "activity-task-list")[ + "taskToken" + ] - resp = conn.respond_activity_task_failed(activity_token, - reason="short reason", - details="long details") + resp = conn.respond_activity_task_failed( + activity_token, reason="short reason", details="long details" + ) resp.should.be.none resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) resp["events"][-2]["eventType"].should.equal("ActivityTaskFailed") resp["events"][-2]["activityTaskFailedEventAttributes"].should.equal( - {"reason": "short reason", "details": "long details", - "scheduledEventId": 5, "startedEventId": 6} + { + "reason": "short reason", + "details": "long details", + "scheduledEventId": 5, + "startedEventId": 6, + } ) @@ -162,11 +168,10 @@ def test_respond_activity_task_completed_with_wrong_token(): # because the safeguards are shared with RespondActivityTaskCompleted, so # no need to retest everything end-to-end. conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) conn.poll_for_activity_task("test-domain", "activity-task-list") conn.respond_activity_task_failed.when.called_with( "not-a-correct-token" @@ -177,13 +182,13 @@ def test_respond_activity_task_completed_with_wrong_token(): @mock_swf_deprecated def test_record_activity_task_heartbeat(): conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + activity_token = conn.poll_for_activity_task("test-domain", "activity-task-list")[ + "taskToken" + ] resp = conn.record_activity_task_heartbeat(activity_token) resp.should.equal({"cancelRequested": False}) @@ -192,13 +197,11 @@ def test_record_activity_task_heartbeat(): @mock_swf_deprecated def test_record_activity_task_heartbeat_with_wrong_token(): conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) - conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) + conn.poll_for_activity_task("test-domain", "activity-task-list")["taskToken"] conn.record_activity_task_heartbeat.when.called_with( "bad-token", details="some progress details" @@ -208,21 +211,23 @@ def test_record_activity_task_heartbeat_with_wrong_token(): @mock_swf_deprecated def test_record_activity_task_heartbeat_sets_details_in_case_of_timeout(): conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) + decision_token = conn.poll_for_decision_task("test-domain", "queue")["taskToken"] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) with freeze_time("2015-01-01 12:00:00"): activity_token = conn.poll_for_activity_task( - "test-domain", "activity-task-list")["taskToken"] + "test-domain", "activity-task-list" + )["taskToken"] conn.record_activity_task_heartbeat( - activity_token, details="some progress details") + activity_token, details="some progress details" + ) with freeze_time("2015-01-01 12:05:30"): # => Activity Task Heartbeat timeout reached!! resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) resp["events"][-2]["eventType"].should.equal("ActivityTaskTimedOut") attrs = resp["events"][-2]["activityTaskTimedOutEventAttributes"] attrs["details"].should.equal("some progress details") diff --git a/tests/test_swf/responses/test_activity_types.py b/tests/test_swf/responses/test_activity_types.py index 95d8a3733..3fa9ad6b1 100644 --- a/tests/test_swf/responses/test_activity_types.py +++ b/tests/test_swf/responses/test_activity_types.py @@ -49,10 +49,11 @@ def test_list_activity_types(): conn.register_activity_type("test-domain", "c-test-activity", "v1.0") all_activity_types = conn.list_activity_types("test-domain", "REGISTERED") - names = [activity_type["activityType"]["name"] - for activity_type in all_activity_types["typeInfos"]] - names.should.equal( - ["a-test-activity", "b-test-activity", "c-test-activity"]) + names = [ + activity_type["activityType"]["name"] + for activity_type in all_activity_types["typeInfos"] + ] + names.should.equal(["a-test-activity", "b-test-activity", "c-test-activity"]) @mock_swf_deprecated @@ -63,12 +64,14 @@ def test_list_activity_types_reverse_order(): conn.register_activity_type("test-domain", "a-test-activity", "v1.0") conn.register_activity_type("test-domain", "c-test-activity", "v1.0") - all_activity_types = conn.list_activity_types("test-domain", "REGISTERED", - reverse_order=True) - names = [activity_type["activityType"]["name"] - for activity_type in all_activity_types["typeInfos"]] - names.should.equal( - ["c-test-activity", "b-test-activity", "a-test-activity"]) + all_activity_types = conn.list_activity_types( + "test-domain", "REGISTERED", reverse_order=True + ) + names = [ + activity_type["activityType"]["name"] + for activity_type in all_activity_types["typeInfos"] + ] + names.should.equal(["c-test-activity", "b-test-activity", "a-test-activity"]) # DeprecateActivityType endpoint @@ -112,11 +115,15 @@ def test_deprecate_non_existent_activity_type(): def test_describe_activity_type(): conn = boto.connect_swf("the_key", "the_secret") conn.register_domain("test-domain", "60") - conn.register_activity_type("test-domain", "test-activity", "v1.0", - task_list="foo", default_task_heartbeat_timeout="32") + conn.register_activity_type( + "test-domain", + "test-activity", + "v1.0", + task_list="foo", + default_task_heartbeat_timeout="32", + ) - actype = conn.describe_activity_type( - "test-domain", "test-activity", "v1.0") + actype = conn.describe_activity_type("test-domain", "test-activity", "v1.0") actype["configuration"]["defaultTaskList"]["name"].should.equal("foo") infos = actype["typeInfo"] infos["activityType"]["name"].should.equal("test-activity") diff --git a/tests/test_swf/responses/test_decision_tasks.py b/tests/test_swf/responses/test_decision_tasks.py index 972b1053b..6389536e6 100644 --- a/tests/test_swf/responses/test_decision_tasks.py +++ b/tests/test_swf/responses/test_decision_tasks.py @@ -14,18 +14,20 @@ def test_poll_for_decision_task_when_one(): conn = setup_workflow() resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) types = [evt["eventType"] for evt in resp["events"]] types.should.equal(["WorkflowExecutionStarted", "DecisionTaskScheduled"]) - resp = conn.poll_for_decision_task( - "test-domain", "queue", identity="srv01") + resp = conn.poll_for_decision_task("test-domain", "queue", identity="srv01") types = [evt["eventType"] for evt in resp["events"]] - types.should.equal(["WorkflowExecutionStarted", - "DecisionTaskScheduled", "DecisionTaskStarted"]) + types.should.equal( + ["WorkflowExecutionStarted", "DecisionTaskScheduled", "DecisionTaskStarted"] + ) - resp[ - "events"][-1]["decisionTaskStartedEventAttributes"]["identity"].should.equal("srv01") + resp["events"][-1]["decisionTaskStartedEventAttributes"]["identity"].should.equal( + "srv01" + ) @mock_swf_deprecated @@ -49,11 +51,11 @@ def test_poll_for_decision_task_on_non_existent_queue(): @mock_swf_deprecated def test_poll_for_decision_task_with_reverse_order(): conn = setup_workflow() - resp = conn.poll_for_decision_task( - "test-domain", "queue", reverse_order=True) + resp = conn.poll_for_decision_task("test-domain", "queue", reverse_order=True) types = [evt["eventType"] for evt in resp["events"]] types.should.equal( - ["DecisionTaskStarted", "DecisionTaskScheduled", "WorkflowExecutionStarted"]) + ["DecisionTaskStarted", "DecisionTaskScheduled", "WorkflowExecutionStarted"] + ) # CountPendingDecisionTasks endpoint @@ -91,29 +93,32 @@ def test_respond_decision_task_completed_with_no_decision(): task_token = resp["taskToken"] resp = conn.respond_decision_task_completed( - task_token, - execution_context="free-form context", + task_token, execution_context="free-form context" ) resp.should.be.none resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) types = [evt["eventType"] for evt in resp["events"]] - types.should.equal([ - "WorkflowExecutionStarted", - "DecisionTaskScheduled", - "DecisionTaskStarted", - "DecisionTaskCompleted", - ]) + types.should.equal( + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "DecisionTaskCompleted", + ] + ) evt = resp["events"][-1] - evt["decisionTaskCompletedEventAttributes"].should.equal({ - "executionContext": "free-form context", - "scheduledEventId": 2, - "startedEventId": 3, - }) + evt["decisionTaskCompletedEventAttributes"].should.equal( + { + "executionContext": "free-form context", + "scheduledEventId": 2, + "startedEventId": 3, + } + ) - resp = conn.describe_workflow_execution( - "test-domain", conn.run_id, "uid-abcd1234") + resp = conn.describe_workflow_execution("test-domain", conn.run_id, "uid-abcd1234") resp["latestExecutionContext"].should.equal("free-form context") @@ -138,9 +143,9 @@ def test_respond_decision_task_completed_on_close_workflow_execution(): wfe.execution_status = "CLOSED" # /bad - conn.respond_decision_task_completed.when.called_with( - task_token - ).should.throw(SWFResponseError) + conn.respond_decision_task_completed.when.called_with(task_token).should.throw( + SWFResponseError + ) @mock_swf_deprecated @@ -150,9 +155,9 @@ def test_respond_decision_task_completed_with_task_already_completed(): task_token = resp["taskToken"] conn.respond_decision_task_completed(task_token) - conn.respond_decision_task_completed.when.called_with( - task_token - ).should.throw(SWFResponseError) + conn.respond_decision_task_completed.when.called_with(task_token).should.throw( + SWFResponseError + ) @mock_swf_deprecated @@ -161,26 +166,31 @@ def test_respond_decision_task_completed_with_complete_workflow_execution(): resp = conn.poll_for_decision_task("test-domain", "queue") task_token = resp["taskToken"] - decisions = [{ - "decisionType": "CompleteWorkflowExecution", - "completeWorkflowExecutionDecisionAttributes": {"result": "foo bar"} - }] - resp = conn.respond_decision_task_completed( - task_token, decisions=decisions) + decisions = [ + { + "decisionType": "CompleteWorkflowExecution", + "completeWorkflowExecutionDecisionAttributes": {"result": "foo bar"}, + } + ] + resp = conn.respond_decision_task_completed(task_token, decisions=decisions) resp.should.be.none resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) types = [evt["eventType"] for evt in resp["events"]] - types.should.equal([ - "WorkflowExecutionStarted", - "DecisionTaskScheduled", - "DecisionTaskStarted", - "DecisionTaskCompleted", - "WorkflowExecutionCompleted", - ]) + types.should.equal( + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "DecisionTaskCompleted", + "WorkflowExecutionCompleted", + ] + ) resp["events"][-1]["workflowExecutionCompletedEventAttributes"][ - "result"].should.equal("foo bar") + "result" + ].should.equal("foo bar") @mock_swf_deprecated @@ -211,9 +221,10 @@ def test_respond_decision_task_completed_with_invalid_decision_type(): ] conn.respond_decision_task_completed.when.called_with( - task_token, decisions=decisions).should.throw( - SWFResponseError, - r"Value 'BadDecisionType' at 'decisions.1.member.decisionType'" + task_token, decisions=decisions + ).should.throw( + SWFResponseError, + r"Value 'BadDecisionType' at 'decisions.1.member.decisionType'", ) @@ -226,8 +237,8 @@ def test_respond_decision_task_completed_with_missing_attributes(): decisions = [ { "decisionType": "should trigger even with incorrect decision type", - "startTimerDecisionAttributes": {} - }, + "startTimerDecisionAttributes": {}, + } ] conn.respond_decision_task_completed.when.called_with( @@ -235,7 +246,7 @@ def test_respond_decision_task_completed_with_missing_attributes(): ).should.throw( SWFResponseError, r"Value null at 'decisions.1.member.startTimerDecisionAttributes.timerId' " - r"failed to satisfy constraint: Member must not be null" + r"failed to satisfy constraint: Member must not be null", ) @@ -245,16 +256,14 @@ def test_respond_decision_task_completed_with_missing_attributes_totally(): resp = conn.poll_for_decision_task("test-domain", "queue") task_token = resp["taskToken"] - decisions = [ - {"decisionType": "StartTimer"}, - ] + decisions = [{"decisionType": "StartTimer"}] conn.respond_decision_task_completed.when.called_with( task_token, decisions=decisions ).should.throw( SWFResponseError, r"Value null at 'decisions.1.member.startTimerDecisionAttributes.timerId' " - r"failed to satisfy constraint: Member must not be null" + r"failed to satisfy constraint: Member must not be null", ) @@ -264,24 +273,31 @@ def test_respond_decision_task_completed_with_fail_workflow_execution(): resp = conn.poll_for_decision_task("test-domain", "queue") task_token = resp["taskToken"] - decisions = [{ - "decisionType": "FailWorkflowExecution", - "failWorkflowExecutionDecisionAttributes": {"reason": "my rules", "details": "foo"} - }] - resp = conn.respond_decision_task_completed( - task_token, decisions=decisions) + decisions = [ + { + "decisionType": "FailWorkflowExecution", + "failWorkflowExecutionDecisionAttributes": { + "reason": "my rules", + "details": "foo", + }, + } + ] + resp = conn.respond_decision_task_completed(task_token, decisions=decisions) resp.should.be.none resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) types = [evt["eventType"] for evt in resp["events"]] - types.should.equal([ - "WorkflowExecutionStarted", - "DecisionTaskScheduled", - "DecisionTaskStarted", - "DecisionTaskCompleted", - "WorkflowExecutionFailed", - ]) + types.should.equal( + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "DecisionTaskCompleted", + "WorkflowExecutionFailed", + ] + ) attrs = resp["events"][-1]["workflowExecutionFailedEventAttributes"] attrs["reason"].should.equal("my rules") attrs["details"].should.equal("foo") @@ -294,49 +310,44 @@ def test_respond_decision_task_completed_with_schedule_activity_task(): resp = conn.poll_for_decision_task("test-domain", "queue") task_token = resp["taskToken"] - decisions = [{ - "decisionType": "ScheduleActivityTask", - "scheduleActivityTaskDecisionAttributes": { - "activityId": "my-activity-001", - "activityType": { - "name": "test-activity", - "version": "v1.1" - }, - "heartbeatTimeout": "60", - "input": "123", - "taskList": { - "name": "my-task-list" + decisions = [ + { + "decisionType": "ScheduleActivityTask", + "scheduleActivityTaskDecisionAttributes": { + "activityId": "my-activity-001", + "activityType": {"name": "test-activity", "version": "v1.1"}, + "heartbeatTimeout": "60", + "input": "123", + "taskList": {"name": "my-task-list"}, }, } - }] - resp = conn.respond_decision_task_completed( - task_token, decisions=decisions) + ] + resp = conn.respond_decision_task_completed(task_token, decisions=decisions) resp.should.be.none resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) types = [evt["eventType"] for evt in resp["events"]] - types.should.equal([ - "WorkflowExecutionStarted", - "DecisionTaskScheduled", - "DecisionTaskStarted", - "DecisionTaskCompleted", - "ActivityTaskScheduled", - ]) - resp["events"][-1]["activityTaskScheduledEventAttributes"].should.equal({ - "decisionTaskCompletedEventId": 4, - "activityId": "my-activity-001", - "activityType": { - "name": "test-activity", - "version": "v1.1", - }, - "heartbeatTimeout": "60", - "input": "123", - "taskList": { - "name": "my-task-list" - }, - }) + types.should.equal( + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "DecisionTaskCompleted", + "ActivityTaskScheduled", + ] + ) + resp["events"][-1]["activityTaskScheduledEventAttributes"].should.equal( + { + "decisionTaskCompletedEventId": 4, + "activityId": "my-activity-001", + "activityType": {"name": "test-activity", "version": "v1.1"}, + "heartbeatTimeout": "60", + "input": "123", + "taskList": {"name": "my-task-list"}, + } + ) - resp = conn.describe_workflow_execution( - "test-domain", conn.run_id, "uid-abcd1234") + resp = conn.describe_workflow_execution("test-domain", conn.run_id, "uid-abcd1234") resp["latestActivityTaskTimestamp"].should.equal(1420113600.0) diff --git a/tests/test_swf/responses/test_domains.py b/tests/test_swf/responses/test_domains.py index 8edc76432..638bd410e 100644 --- a/tests/test_swf/responses/test_domains.py +++ b/tests/test_swf/responses/test_domains.py @@ -82,18 +82,16 @@ def test_deprecate_already_deprecated_domain(): conn.register_domain("test-domain", "60", description="A test domain") conn.deprecate_domain("test-domain") - conn.deprecate_domain.when.called_with( - "test-domain" - ).should.throw(SWFResponseError) + conn.deprecate_domain.when.called_with("test-domain").should.throw(SWFResponseError) @mock_swf_deprecated def test_deprecate_non_existent_domain(): conn = boto.connect_swf("the_key", "the_secret") - conn.deprecate_domain.when.called_with( - "non-existent" - ).should.throw(SWFResponseError) + conn.deprecate_domain.when.called_with("non-existent").should.throw( + SWFResponseError + ) # DescribeDomain endpoint @@ -103,8 +101,7 @@ def test_describe_domain(): conn.register_domain("test-domain", "60", description="A test domain") domain = conn.describe_domain("test-domain") - domain["configuration"][ - "workflowExecutionRetentionPeriodInDays"].should.equal("60") + domain["configuration"]["workflowExecutionRetentionPeriodInDays"].should.equal("60") domain["domainInfo"]["description"].should.equal("A test domain") domain["domainInfo"]["name"].should.equal("test-domain") domain["domainInfo"]["status"].should.equal("REGISTERED") @@ -114,6 +111,4 @@ def test_describe_domain(): def test_describe_non_existent_domain(): conn = boto.connect_swf("the_key", "the_secret") - conn.describe_domain.when.called_with( - "non-existent" - ).should.throw(SWFResponseError) + conn.describe_domain.when.called_with("non-existent").should.throw(SWFResponseError) diff --git a/tests/test_swf/responses/test_timeouts.py b/tests/test_swf/responses/test_timeouts.py index f49c597a4..25ca8ae7d 100644 --- a/tests/test_swf/responses/test_timeouts.py +++ b/tests/test_swf/responses/test_timeouts.py @@ -12,23 +12,27 @@ from ..utils import setup_workflow, SCHEDULE_ACTIVITY_TASK_DECISION def test_activity_task_heartbeat_timeout(): with freeze_time("2015-01-01 12:00:00"): conn = setup_workflow() - decision_token = conn.poll_for_decision_task( - "test-domain", "queue")["taskToken"] - conn.respond_decision_task_completed(decision_token, decisions=[ - SCHEDULE_ACTIVITY_TASK_DECISION - ]) + decision_token = conn.poll_for_decision_task("test-domain", "queue")[ + "taskToken" + ] + conn.respond_decision_task_completed( + decision_token, decisions=[SCHEDULE_ACTIVITY_TASK_DECISION] + ) conn.poll_for_activity_task( - "test-domain", "activity-task-list", identity="surprise") + "test-domain", "activity-task-list", identity="surprise" + ) with freeze_time("2015-01-01 12:04:30"): resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) resp["events"][-1]["eventType"].should.equal("ActivityTaskStarted") with freeze_time("2015-01-01 12:05:30"): # => Activity Task Heartbeat timeout reached!! resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) resp["events"][-2]["eventType"].should.equal("ActivityTaskTimedOut") attrs = resp["events"][-2]["activityTaskTimedOutEventAttributes"] @@ -50,7 +54,8 @@ def test_decision_task_start_to_close_timeout(): with freeze_time("2015-01-01 12:04:30"): resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) event_types = [evt["eventType"] for evt in resp["events"]] event_types.should.equal( @@ -60,17 +65,27 @@ def test_decision_task_start_to_close_timeout(): with freeze_time("2015-01-01 12:05:30"): # => Decision Task Start to Close timeout reached!! resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) event_types = [evt["eventType"] for evt in resp["events"]] event_types.should.equal( - ["WorkflowExecutionStarted", "DecisionTaskScheduled", "DecisionTaskStarted", - "DecisionTaskTimedOut", "DecisionTaskScheduled"] + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "DecisionTaskStarted", + "DecisionTaskTimedOut", + "DecisionTaskScheduled", + ] ) attrs = resp["events"][-2]["decisionTaskTimedOutEventAttributes"] - attrs.should.equal({ - "scheduledEventId": 2, "startedEventId": 3, "timeoutType": "START_TO_CLOSE" - }) + attrs.should.equal( + { + "scheduledEventId": 2, + "startedEventId": 3, + "timeoutType": "START_TO_CLOSE", + } + ) # checks that event has been emitted at 12:05:00, not 12:05:30 resp["events"][-2]["eventTimestamp"].should.equal(1420113900.0) @@ -85,26 +100,27 @@ def test_workflow_execution_start_to_close_timeout(): with freeze_time("2015-01-01 13:59:30"): resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) event_types = [evt["eventType"] for evt in resp["events"]] - event_types.should.equal( - ["WorkflowExecutionStarted", "DecisionTaskScheduled"] - ) + event_types.should.equal(["WorkflowExecutionStarted", "DecisionTaskScheduled"]) with freeze_time("2015-01-01 14:00:30"): # => Workflow Execution Start to Close timeout reached!! resp = conn.get_workflow_execution_history( - "test-domain", conn.run_id, "uid-abcd1234") + "test-domain", conn.run_id, "uid-abcd1234" + ) event_types = [evt["eventType"] for evt in resp["events"]] event_types.should.equal( - ["WorkflowExecutionStarted", "DecisionTaskScheduled", - "WorkflowExecutionTimedOut"] + [ + "WorkflowExecutionStarted", + "DecisionTaskScheduled", + "WorkflowExecutionTimedOut", + ] ) attrs = resp["events"][-1]["workflowExecutionTimedOutEventAttributes"] - attrs.should.equal({ - "childPolicy": "ABANDON", "timeoutType": "START_TO_CLOSE" - }) + attrs.should.equal({"childPolicy": "ABANDON", "timeoutType": "START_TO_CLOSE"}) # checks that event has been emitted at 14:00:00, not 14:00:30 resp["events"][-1]["eventTimestamp"].should.equal(1420120800.0) diff --git a/tests/test_swf/responses/test_workflow_executions.py b/tests/test_swf/responses/test_workflow_executions.py index 88e3caa75..bec352ce8 100644 --- a/tests/test_swf/responses/test_workflow_executions.py +++ b/tests/test_swf/responses/test_workflow_executions.py @@ -3,6 +3,7 @@ from boto.swf.exceptions import SWFResponseError from datetime import datetime, timedelta import sure # noqa + # Ensure 'assert_raises' context manager support for Python 2.6 import tests.backport_assert_raises # noqa @@ -16,8 +17,11 @@ def setup_swf_environment(): conn = boto.connect_swf("the_key", "the_secret") conn.register_domain("test-domain", "60", description="A test domain") conn.register_workflow_type( - "test-domain", "test-workflow", "v1.0", - task_list="queue", default_child_policy="TERMINATE", + "test-domain", + "test-workflow", + "v1.0", + task_list="queue", + default_child_policy="TERMINATE", default_execution_start_to_close_timeout="300", default_task_start_to_close_timeout="300", ) @@ -31,29 +35,34 @@ def test_start_workflow_execution(): conn = setup_swf_environment() wf = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) wf.should.contain("runId") + @mock_swf_deprecated def test_signal_workflow_execution(): conn = setup_swf_environment() hsh = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) run_id = hsh["runId"] wfe = conn.signal_workflow_execution( - "test-domain", "my_signal", "uid-abcd1234", "my_input", run_id) + "test-domain", "my_signal", "uid-abcd1234", "my_input", run_id + ) - wfe = conn.describe_workflow_execution( - "test-domain", run_id, "uid-abcd1234") + wfe = conn.describe_workflow_execution("test-domain", run_id, "uid-abcd1234") wfe["openCounts"]["openDecisionTasks"].should.equal(2) + @mock_swf_deprecated def test_start_already_started_workflow_execution(): conn = setup_swf_environment() conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) conn.start_workflow_execution.when.called_with( "test-domain", "uid-abcd1234", "test-workflow", "v1.0" @@ -75,13 +84,12 @@ def test_start_workflow_execution_on_deprecated_type(): def test_describe_workflow_execution(): conn = setup_swf_environment() hsh = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) run_id = hsh["runId"] - wfe = conn.describe_workflow_execution( - "test-domain", run_id, "uid-abcd1234") - wfe["executionInfo"]["execution"][ - "workflowId"].should.equal("uid-abcd1234") + wfe = conn.describe_workflow_execution("test-domain", run_id, "uid-abcd1234") + wfe["executionInfo"]["execution"]["workflowId"].should.equal("uid-abcd1234") wfe["executionInfo"]["executionStatus"].should.equal("OPEN") @@ -99,11 +107,11 @@ def test_describe_non_existent_workflow_execution(): def test_get_workflow_execution_history(): conn = setup_swf_environment() hsh = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) run_id = hsh["runId"] - resp = conn.get_workflow_execution_history( - "test-domain", run_id, "uid-abcd1234") + resp = conn.get_workflow_execution_history("test-domain", run_id, "uid-abcd1234") types = [evt["eventType"] for evt in resp["events"]] types.should.equal(["WorkflowExecutionStarted", "DecisionTaskScheduled"]) @@ -112,11 +120,13 @@ def test_get_workflow_execution_history(): def test_get_workflow_execution_history_with_reverse_order(): conn = setup_swf_environment() hsh = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) run_id = hsh["runId"] - resp = conn.get_workflow_execution_history("test-domain", run_id, "uid-abcd1234", - reverse_order=True) + resp = conn.get_workflow_execution_history( + "test-domain", run_id, "uid-abcd1234", reverse_order=True + ) types = [evt["eventType"] for evt in resp["events"]] types.should.equal(["DecisionTaskScheduled", "WorkflowExecutionStarted"]) @@ -136,32 +146,36 @@ def test_list_open_workflow_executions(): conn = setup_swf_environment() # One open workflow execution conn.start_workflow_execution( - 'test-domain', 'uid-abcd1234', 'test-workflow', 'v1.0' + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" ) # One closed workflow execution to make sure it isn't displayed run_id = conn.start_workflow_execution( - 'test-domain', 'uid-abcd12345', 'test-workflow', 'v1.0' - )['runId'] - conn.terminate_workflow_execution('test-domain', 'uid-abcd12345', - details='some details', - reason='a more complete reason', - run_id=run_id) + "test-domain", "uid-abcd12345", "test-workflow", "v1.0" + )["runId"] + conn.terminate_workflow_execution( + "test-domain", + "uid-abcd12345", + details="some details", + reason="a more complete reason", + run_id=run_id, + ) yesterday = datetime.utcnow() - timedelta(days=1) oldest_date = unix_time(yesterday) - response = conn.list_open_workflow_executions('test-domain', - oldest_date, - workflow_id='test-workflow') - execution_infos = response['executionInfos'] + response = conn.list_open_workflow_executions( + "test-domain", oldest_date, workflow_id="test-workflow" + ) + execution_infos = response["executionInfos"] len(execution_infos).should.equal(1) open_workflow = execution_infos[0] - open_workflow['workflowType'].should.equal({'version': 'v1.0', - 'name': 'test-workflow'}) - open_workflow.should.contain('startTimestamp') - open_workflow['execution']['workflowId'].should.equal('uid-abcd1234') - open_workflow['execution'].should.contain('runId') - open_workflow['cancelRequested'].should.be(False) - open_workflow['executionStatus'].should.equal('OPEN') + open_workflow["workflowType"].should.equal( + {"version": "v1.0", "name": "test-workflow"} + ) + open_workflow.should.contain("startTimestamp") + open_workflow["execution"]["workflowId"].should.equal("uid-abcd1234") + open_workflow["execution"].should.contain("runId") + open_workflow["cancelRequested"].should.be(False) + open_workflow["executionStatus"].should.equal("OPEN") # ListClosedWorkflowExecutions endpoint @@ -170,33 +184,36 @@ def test_list_closed_workflow_executions(): conn = setup_swf_environment() # Leave one workflow execution open to make sure it isn't displayed conn.start_workflow_execution( - 'test-domain', 'uid-abcd1234', 'test-workflow', 'v1.0' + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" ) # One closed workflow execution run_id = conn.start_workflow_execution( - 'test-domain', 'uid-abcd12345', 'test-workflow', 'v1.0' - )['runId'] - conn.terminate_workflow_execution('test-domain', 'uid-abcd12345', - details='some details', - reason='a more complete reason', - run_id=run_id) + "test-domain", "uid-abcd12345", "test-workflow", "v1.0" + )["runId"] + conn.terminate_workflow_execution( + "test-domain", + "uid-abcd12345", + details="some details", + reason="a more complete reason", + run_id=run_id, + ) yesterday = datetime.utcnow() - timedelta(days=1) oldest_date = unix_time(yesterday) response = conn.list_closed_workflow_executions( - 'test-domain', - start_oldest_date=oldest_date, - workflow_id='test-workflow') - execution_infos = response['executionInfos'] + "test-domain", start_oldest_date=oldest_date, workflow_id="test-workflow" + ) + execution_infos = response["executionInfos"] len(execution_infos).should.equal(1) open_workflow = execution_infos[0] - open_workflow['workflowType'].should.equal({'version': 'v1.0', - 'name': 'test-workflow'}) - open_workflow.should.contain('startTimestamp') - open_workflow['execution']['workflowId'].should.equal('uid-abcd12345') - open_workflow['execution'].should.contain('runId') - open_workflow['cancelRequested'].should.be(False) - open_workflow['executionStatus'].should.equal('CLOSED') + open_workflow["workflowType"].should.equal( + {"version": "v1.0", "name": "test-workflow"} + ) + open_workflow.should.contain("startTimestamp") + open_workflow["execution"]["workflowId"].should.equal("uid-abcd12345") + open_workflow["execution"].should.contain("runId") + open_workflow["cancelRequested"].should.be(False) + open_workflow["executionStatus"].should.equal("CLOSED") # TerminateWorkflowExecution endpoint @@ -207,14 +224,16 @@ def test_terminate_workflow_execution(): "test-domain", "uid-abcd1234", "test-workflow", "v1.0" )["runId"] - resp = conn.terminate_workflow_execution("test-domain", "uid-abcd1234", - details="some details", - reason="a more complete reason", - run_id=run_id) + resp = conn.terminate_workflow_execution( + "test-domain", + "uid-abcd1234", + details="some details", + reason="a more complete reason", + run_id=run_id, + ) resp.should.be.none - resp = conn.get_workflow_execution_history( - "test-domain", run_id, "uid-abcd1234") + resp = conn.get_workflow_execution_history("test-domain", run_id, "uid-abcd1234") evt = resp["events"][-1] evt["eventType"].should.equal("WorkflowExecutionTerminated") attrs = evt["workflowExecutionTerminatedEventAttributes"] @@ -243,16 +262,12 @@ def test_terminate_workflow_execution_with_wrong_workflow_or_run_id(): # already closed, without run_id conn.terminate_workflow_execution.when.called_with( "test-domain", "uid-abcd1234" - ).should.throw( - SWFResponseError, "Unknown execution, workflowId = uid-abcd1234" - ) + ).should.throw(SWFResponseError, "Unknown execution, workflowId = uid-abcd1234") # wrong workflow id conn.terminate_workflow_execution.when.called_with( "test-domain", "uid-non-existent" - ).should.throw( - SWFResponseError, "Unknown execution, workflowId = uid-non-existent" - ) + ).should.throw(SWFResponseError, "Unknown execution, workflowId = uid-non-existent") # wrong run_id conn.terminate_workflow_execution.when.called_with( diff --git a/tests/test_swf/responses/test_workflow_types.py b/tests/test_swf/responses/test_workflow_types.py index 9e097a873..4c92d7762 100644 --- a/tests/test_swf/responses/test_workflow_types.py +++ b/tests/test_swf/responses/test_workflow_types.py @@ -49,10 +49,11 @@ def test_list_workflow_types(): conn.register_workflow_type("test-domain", "c-test-workflow", "v1.0") all_workflow_types = conn.list_workflow_types("test-domain", "REGISTERED") - names = [activity_type["workflowType"]["name"] - for activity_type in all_workflow_types["typeInfos"]] - names.should.equal( - ["a-test-workflow", "b-test-workflow", "c-test-workflow"]) + names = [ + activity_type["workflowType"]["name"] + for activity_type in all_workflow_types["typeInfos"] + ] + names.should.equal(["a-test-workflow", "b-test-workflow", "c-test-workflow"]) @mock_swf_deprecated @@ -63,12 +64,14 @@ def test_list_workflow_types_reverse_order(): conn.register_workflow_type("test-domain", "a-test-workflow", "v1.0") conn.register_workflow_type("test-domain", "c-test-workflow", "v1.0") - all_workflow_types = conn.list_workflow_types("test-domain", "REGISTERED", - reverse_order=True) - names = [activity_type["workflowType"]["name"] - for activity_type in all_workflow_types["typeInfos"]] - names.should.equal( - ["c-test-workflow", "b-test-workflow", "a-test-workflow"]) + all_workflow_types = conn.list_workflow_types( + "test-domain", "REGISTERED", reverse_order=True + ) + names = [ + activity_type["workflowType"]["name"] + for activity_type in all_workflow_types["typeInfos"] + ] + names.should.equal(["c-test-workflow", "b-test-workflow", "a-test-workflow"]) # DeprecateWorkflowType endpoint @@ -112,15 +115,18 @@ def test_deprecate_non_existent_workflow_type(): def test_describe_workflow_type(): conn = boto.connect_swf("the_key", "the_secret") conn.register_domain("test-domain", "60") - conn.register_workflow_type("test-domain", "test-workflow", "v1.0", - task_list="foo", default_child_policy="TERMINATE") + conn.register_workflow_type( + "test-domain", + "test-workflow", + "v1.0", + task_list="foo", + default_child_policy="TERMINATE", + ) - actype = conn.describe_workflow_type( - "test-domain", "test-workflow", "v1.0") + actype = conn.describe_workflow_type("test-domain", "test-workflow", "v1.0") actype["configuration"]["defaultTaskList"]["name"].should.equal("foo") actype["configuration"]["defaultChildPolicy"].should.equal("TERMINATE") - actype["configuration"].keys().should_not.contain( - "defaultTaskStartToCloseTimeout") + actype["configuration"].keys().should_not.contain("defaultTaskStartToCloseTimeout") infos = actype["typeInfo"] infos["workflowType"]["name"].should.equal("test-workflow") infos["workflowType"]["version"].should.equal("v1.0") diff --git a/tests/test_swf/test_exceptions.py b/tests/test_swf/test_exceptions.py index 8617242b9..2e42cdb9b 100644 --- a/tests/test_swf/test_exceptions.py +++ b/tests/test_swf/test_exceptions.py @@ -16,69 +16,76 @@ from moto.swf.exceptions import ( SWFValidationException, SWFDecisionValidationException, ) -from moto.swf.models import ( - WorkflowType, -) +from moto.swf.models import WorkflowType def test_swf_client_error(): ex = SWFClientError("ASpecificType", "error message") ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "ASpecificType", - "message": "error message" - }) + json.loads(ex.get_body()).should.equal( + {"__type": "ASpecificType", "message": "error message"} + ) def test_swf_unknown_resource_fault(): ex = SWFUnknownResourceFault("type", "detail") ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#UnknownResourceFault", - "message": "Unknown type: detail" - }) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#UnknownResourceFault", + "message": "Unknown type: detail", + } + ) def test_swf_unknown_resource_fault_with_only_one_parameter(): ex = SWFUnknownResourceFault("foo bar baz") ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#UnknownResourceFault", - "message": "Unknown foo bar baz" - }) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#UnknownResourceFault", + "message": "Unknown foo bar baz", + } + ) def test_swf_domain_already_exists_fault(): ex = SWFDomainAlreadyExistsFault("domain-name") ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#DomainAlreadyExistsFault", - "message": "domain-name" - }) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#DomainAlreadyExistsFault", + "message": "domain-name", + } + ) def test_swf_domain_deprecated_fault(): ex = SWFDomainDeprecatedFault("domain-name") ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#DomainDeprecatedFault", - "message": "domain-name" - }) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#DomainDeprecatedFault", + "message": "domain-name", + } + ) def test_swf_serialization_exception(): ex = SWFSerializationException("value") ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#SerializationException", - "message": "class java.lang.Foo can not be converted to an String (not a real SWF exception ; happened on: value)" - }) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#SerializationException", + "message": "class java.lang.Foo can not be converted to an String (not a real SWF exception ; happened on: value)", + } + ) def test_swf_type_already_exists_fault(): @@ -86,10 +93,12 @@ def test_swf_type_already_exists_fault(): ex = SWFTypeAlreadyExistsFault(wft) ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#TypeAlreadyExistsFault", - "message": "WorkflowType=[name=wf-name, version=wf-version]" - }) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#TypeAlreadyExistsFault", + "message": "WorkflowType=[name=wf-name, version=wf-version]", + } + ) def test_swf_type_deprecated_fault(): @@ -97,51 +106,65 @@ def test_swf_type_deprecated_fault(): ex = SWFTypeDeprecatedFault(wft) ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#TypeDeprecatedFault", - "message": "WorkflowType=[name=wf-name, version=wf-version]" - }) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#TypeDeprecatedFault", + "message": "WorkflowType=[name=wf-name, version=wf-version]", + } + ) def test_swf_workflow_execution_already_started_fault(): ex = SWFWorkflowExecutionAlreadyStartedFault() ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#WorkflowExecutionAlreadyStartedFault", - 'message': 'Already Started', - }) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#WorkflowExecutionAlreadyStartedFault", + "message": "Already Started", + } + ) def test_swf_default_undefined_fault(): ex = SWFDefaultUndefinedFault("execution_start_to_close_timeout") ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazonaws.swf.base.model#DefaultUndefinedFault", - "message": "executionStartToCloseTimeout", - }) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazonaws.swf.base.model#DefaultUndefinedFault", + "message": "executionStartToCloseTimeout", + } + ) def test_swf_validation_exception(): ex = SWFValidationException("Invalid token") ex.code.should.equal(400) - json.loads(ex.get_body()).should.equal({ - "__type": "com.amazon.coral.validate#ValidationException", - "message": "Invalid token", - }) + json.loads(ex.get_body()).should.equal( + { + "__type": "com.amazon.coral.validate#ValidationException", + "message": "Invalid token", + } + ) def test_swf_decision_validation_error(): - ex = SWFDecisionValidationException([ - {"type": "null_value", - "where": "decisions.1.member.startTimerDecisionAttributes.startToFireTimeout"}, - {"type": "bad_decision_type", - "value": "FooBar", - "where": "decisions.1.member.decisionType", - "possible_values": "Foo, Bar, Baz"}, - ]) + ex = SWFDecisionValidationException( + [ + { + "type": "null_value", + "where": "decisions.1.member.startTimerDecisionAttributes.startToFireTimeout", + }, + { + "type": "bad_decision_type", + "value": "FooBar", + "where": "decisions.1.member.decisionType", + "possible_values": "Foo, Bar, Baz", + }, + ] + ) ex.code.should.equal(400) ex.error_type.should.equal("com.amazon.coral.validate#ValidationException") diff --git a/tests/test_swf/test_utils.py b/tests/test_swf/test_utils.py index ffa147037..328342bbe 100644 --- a/tests/test_swf/test_utils.py +++ b/tests/test_swf/test_utils.py @@ -4,10 +4,6 @@ from moto.swf.utils import decapitalize def test_decapitalize(): - cases = { - "fooBar": "fooBar", - "FooBar": "fooBar", - "FOO BAR": "fOO BAR", - } + cases = {"fooBar": "fooBar", "FooBar": "fooBar", "FOO BAR": "fOO BAR"} for before, after in cases.items(): decapitalize(before).should.equal(after) diff --git a/tests/test_swf/utils.py b/tests/test_swf/utils.py index 2197b71df..48c2cbd94 100644 --- a/tests/test_swf/utils.py +++ b/tests/test_swf/utils.py @@ -1,11 +1,6 @@ import boto -from moto.swf.models import ( - ActivityType, - Domain, - WorkflowType, - WorkflowExecution, -) +from moto.swf.models import ActivityType, Domain, WorkflowType, WorkflowExecution # Some useful constants @@ -13,9 +8,9 @@ from moto.swf.models import ( # from semi-real world example, the goal is mostly to have predictible and # intuitive behaviour in moto/swf own tests... ACTIVITY_TASK_TIMEOUTS = { - "heartbeatTimeout": "300", # 5 mins + "heartbeatTimeout": "300", # 5 mins "scheduleToStartTimeout": "1800", # 30 mins - "startToCloseTimeout": "1800", # 30 mins + "startToCloseTimeout": "1800", # 30 mins "scheduleToCloseTimeout": "2700", # 45 mins } @@ -26,11 +21,12 @@ SCHEDULE_ACTIVITY_TASK_DECISION = { "activityId": "my-activity-001", "activityType": {"name": "test-activity", "version": "v1.1"}, "taskList": {"name": "activity-task-list"}, - } + }, } for key, value in ACTIVITY_TASK_TIMEOUTS.items(): - SCHEDULE_ACTIVITY_TASK_DECISION[ - "scheduleActivityTaskDecisionAttributes"][key] = value + SCHEDULE_ACTIVITY_TASK_DECISION["scheduleActivityTaskDecisionAttributes"][ + key + ] = value # A test Domain @@ -40,14 +36,15 @@ def get_basic_domain(): # A test WorkflowType def _generic_workflow_type_attributes(): - return [ - "test-workflow", "v1.0" - ], { - "task_list": "queue", - "default_child_policy": "ABANDON", - "default_execution_start_to_close_timeout": "7200", - "default_task_start_to_close_timeout": "300", - } + return ( + ["test-workflow", "v1.0"], + { + "task_list": "queue", + "default_child_policy": "ABANDON", + "default_execution_start_to_close_timeout": "7200", + "default_task_start_to_close_timeout": "300", + }, + ) def get_basic_workflow_type(): @@ -81,14 +78,17 @@ def setup_workflow(): conn.register_domain("test-domain", "60", description="A test domain") conn = mock_basic_workflow_type("test-domain", conn) conn.register_activity_type( - "test-domain", "test-activity", "v1.1", + "test-domain", + "test-activity", + "v1.1", default_task_heartbeat_timeout="600", default_task_schedule_to_close_timeout="600", default_task_schedule_to_start_timeout="600", default_task_start_to_close_timeout="600", ) wfe = conn.start_workflow_execution( - "test-domain", "uid-abcd1234", "test-workflow", "v1.0") + "test-domain", "uid-abcd1234", "test-workflow", "v1.0" + ) conn.run_id = wfe["runId"] return conn diff --git a/tests/test_xray/test_xray_boto3.py b/tests/test_xray/test_xray_boto3.py index 5ad8f8bc7..4089abd2e 100644 --- a/tests/test_xray/test_xray_boto3.py +++ b/tests/test_xray/test_xray_boto3.py @@ -3,7 +3,7 @@ from __future__ import unicode_literals import boto3 import json import botocore.exceptions -import sure # noqa +import sure # noqa from moto import mock_xray @@ -12,128 +12,137 @@ import datetime @mock_xray def test_put_telemetry(): - client = boto3.client('xray', region_name='us-east-1') + client = boto3.client("xray", region_name="us-east-1") client.put_telemetry_records( TelemetryRecords=[ { - 'Timestamp': datetime.datetime(2015, 1, 1), - 'SegmentsReceivedCount': 123, - 'SegmentsSentCount': 123, - 'SegmentsSpilloverCount': 123, - 'SegmentsRejectedCount': 123, - 'BackendConnectionErrors': { - 'TimeoutCount': 123, - 'ConnectionRefusedCount': 123, - 'HTTPCode4XXCount': 123, - 'HTTPCode5XXCount': 123, - 'UnknownHostCount': 123, - 'OtherCount': 123 - } - }, + "Timestamp": datetime.datetime(2015, 1, 1), + "SegmentsReceivedCount": 123, + "SegmentsSentCount": 123, + "SegmentsSpilloverCount": 123, + "SegmentsRejectedCount": 123, + "BackendConnectionErrors": { + "TimeoutCount": 123, + "ConnectionRefusedCount": 123, + "HTTPCode4XXCount": 123, + "HTTPCode5XXCount": 123, + "UnknownHostCount": 123, + "OtherCount": 123, + }, + } ], - EC2InstanceId='string', - Hostname='string', - ResourceARN='string' + EC2InstanceId="string", + Hostname="string", + ResourceARN="string", ) @mock_xray def test_put_trace_segments(): - client = boto3.client('xray', region_name='us-east-1') + client = boto3.client("xray", region_name="us-east-1") client.put_trace_segments( TraceSegmentDocuments=[ - json.dumps({ - 'name': 'example.com', - 'id': '70de5b6f19ff9a0a', - 'start_time': 1.478293361271E9, - 'trace_id': '1-581cf771-a006649127e371903a2de979', - 'end_time': 1.478293361449E9 - }) + json.dumps( + { + "name": "example.com", + "id": "70de5b6f19ff9a0a", + "start_time": 1.478293361271e9, + "trace_id": "1-581cf771-a006649127e371903a2de979", + "end_time": 1.478293361449e9, + } + ) ] ) @mock_xray def test_trace_summary(): - client = boto3.client('xray', region_name='us-east-1') + client = boto3.client("xray", region_name="us-east-1") client.put_trace_segments( TraceSegmentDocuments=[ - json.dumps({ - 'name': 'example.com', - 'id': '70de5b6f19ff9a0a', - 'start_time': 1.478293361271E9, - 'trace_id': '1-581cf771-a006649127e371903a2de979', - 'in_progress': True - }), - json.dumps({ - 'name': 'example.com', - 'id': '70de5b6f19ff9a0b', - 'start_time': 1478293365, - 'trace_id': '1-581cf771-a006649127e371903a2de979', - 'end_time': 1478293385 - }) + json.dumps( + { + "name": "example.com", + "id": "70de5b6f19ff9a0a", + "start_time": 1.478293361271e9, + "trace_id": "1-581cf771-a006649127e371903a2de979", + "in_progress": True, + } + ), + json.dumps( + { + "name": "example.com", + "id": "70de5b6f19ff9a0b", + "start_time": 1478293365, + "trace_id": "1-581cf771-a006649127e371903a2de979", + "end_time": 1478293385, + } + ), ] ) client.get_trace_summaries( - StartTime=datetime.datetime(2014, 1, 1), - EndTime=datetime.datetime(2017, 1, 1) + StartTime=datetime.datetime(2014, 1, 1), EndTime=datetime.datetime(2017, 1, 1) ) @mock_xray def test_batch_get_trace(): - client = boto3.client('xray', region_name='us-east-1') + client = boto3.client("xray", region_name="us-east-1") client.put_trace_segments( TraceSegmentDocuments=[ - json.dumps({ - 'name': 'example.com', - 'id': '70de5b6f19ff9a0a', - 'start_time': 1.478293361271E9, - 'trace_id': '1-581cf771-a006649127e371903a2de979', - 'in_progress': True - }), - json.dumps({ - 'name': 'example.com', - 'id': '70de5b6f19ff9a0b', - 'start_time': 1478293365, - 'trace_id': '1-581cf771-a006649127e371903a2de979', - 'end_time': 1478293385 - }) + json.dumps( + { + "name": "example.com", + "id": "70de5b6f19ff9a0a", + "start_time": 1.478293361271e9, + "trace_id": "1-581cf771-a006649127e371903a2de979", + "in_progress": True, + } + ), + json.dumps( + { + "name": "example.com", + "id": "70de5b6f19ff9a0b", + "start_time": 1478293365, + "trace_id": "1-581cf771-a006649127e371903a2de979", + "end_time": 1478293385, + } + ), ] ) resp = client.batch_get_traces( - TraceIds=['1-581cf771-a006649127e371903a2de979', '1-581cf772-b006649127e371903a2de979'] + TraceIds=[ + "1-581cf771-a006649127e371903a2de979", + "1-581cf772-b006649127e371903a2de979", + ] ) - len(resp['UnprocessedTraceIds']).should.equal(1) - len(resp['Traces']).should.equal(1) + len(resp["UnprocessedTraceIds"]).should.equal(1) + len(resp["Traces"]).should.equal(1) # Following are not implemented, just testing it returns what boto expects @mock_xray def test_batch_get_service_graph(): - client = boto3.client('xray', region_name='us-east-1') + client = boto3.client("xray", region_name="us-east-1") client.get_service_graph( - StartTime=datetime.datetime(2014, 1, 1), - EndTime=datetime.datetime(2017, 1, 1) + StartTime=datetime.datetime(2014, 1, 1), EndTime=datetime.datetime(2017, 1, 1) ) @mock_xray def test_batch_get_trace_graph(): - client = boto3.client('xray', region_name='us-east-1') + client = boto3.client("xray", region_name="us-east-1") client.batch_get_traces( - TraceIds=['1-581cf771-a006649127e371903a2de979', '1-581cf772-b006649127e371903a2de979'] + TraceIds=[ + "1-581cf771-a006649127e371903a2de979", + "1-581cf772-b006649127e371903a2de979", + ] ) - - - - - diff --git a/tests/test_xray/test_xray_client.py b/tests/test_xray/test_xray_client.py index 0cd948950..6b74136c9 100644 --- a/tests/test_xray/test_xray_client.py +++ b/tests/test_xray/test_xray_client.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals from moto import mock_xray_client, XRaySegment, mock_dynamodb2 -import sure # noqa +import sure # noqa import boto3 from moto.xray.mock_client import MockEmitter @@ -9,10 +9,12 @@ import aws_xray_sdk.core.patcher as xray_core_patcher import botocore.client import botocore.endpoint + original_make_api_call = botocore.client.BaseClient._make_api_call original_encode_headers = botocore.endpoint.Endpoint._encode_headers import requests + original_session_request = requests.Session.request original_session_prep_request = requests.Session.prepare_request @@ -24,24 +26,24 @@ def test_xray_dynamo_request_id(): xray_core_patcher._PATCHED_MODULES = set() xray_core.patch_all() - client = boto3.client('dynamodb', region_name='us-east-1') + client = boto3.client("dynamodb", region_name="us-east-1") with XRaySegment(): resp = client.list_tables() - resp['ResponseMetadata'].should.contain('RequestId') - id1 = resp['ResponseMetadata']['RequestId'] + resp["ResponseMetadata"].should.contain("RequestId") + id1 = resp["ResponseMetadata"]["RequestId"] with XRaySegment(): client.list_tables() resp = client.list_tables() - id2 = resp['ResponseMetadata']['RequestId'] + id2 = resp["ResponseMetadata"]["RequestId"] id1.should_not.equal(id2) - setattr(botocore.client.BaseClient, '_make_api_call', original_make_api_call) - setattr(botocore.endpoint.Endpoint, '_encode_headers', original_encode_headers) - setattr(requests.Session, 'request', original_session_request) - setattr(requests.Session, 'prepare_request', original_session_prep_request) + setattr(botocore.client.BaseClient, "_make_api_call", original_make_api_call) + setattr(botocore.endpoint.Endpoint, "_encode_headers", original_encode_headers) + setattr(requests.Session, "request", original_session_request) + setattr(requests.Session, "prepare_request", original_session_prep_request) @mock_xray_client @@ -52,10 +54,10 @@ def test_xray_udp_emitter_patched(): assert isinstance(xray_core.xray_recorder._emitter, MockEmitter) - setattr(botocore.client.BaseClient, '_make_api_call', original_make_api_call) - setattr(botocore.endpoint.Endpoint, '_encode_headers', original_encode_headers) - setattr(requests.Session, 'request', original_session_request) - setattr(requests.Session, 'prepare_request', original_session_prep_request) + setattr(botocore.client.BaseClient, "_make_api_call", original_make_api_call) + setattr(botocore.endpoint.Endpoint, "_encode_headers", original_encode_headers) + setattr(requests.Session, "request", original_session_request) + setattr(requests.Session, "prepare_request", original_session_prep_request) @mock_xray_client @@ -64,9 +66,9 @@ def test_xray_context_patched(): xray_core_patcher._PATCHED_MODULES = set() xray_core.patch_all() - xray_core.xray_recorder._context.context_missing.should.equal('LOG_ERROR') + xray_core.xray_recorder._context.context_missing.should.equal("LOG_ERROR") - setattr(botocore.client.BaseClient, '_make_api_call', original_make_api_call) - setattr(botocore.endpoint.Endpoint, '_encode_headers', original_encode_headers) - setattr(requests.Session, 'request', original_session_request) - setattr(requests.Session, 'prepare_request', original_session_prep_request) + setattr(botocore.client.BaseClient, "_make_api_call", original_make_api_call) + setattr(botocore.endpoint.Endpoint, "_encode_headers", original_encode_headers) + setattr(requests.Session, "request", original_session_request) + setattr(requests.Session, "prepare_request", original_session_prep_request) diff --git a/tox.ini b/tox.ini index 570b5790f..9dacca18c 100644 --- a/tox.ini +++ b/tox.ini @@ -15,5 +15,5 @@ commands = nosetests {posargs} [flake8] -ignore = E128,E501 +ignore = W503,W605,E128,E501,E203,E266,E501,E231 exclude = moto/packages,dist